commit a3e0f470619270b114bb33647e6680522713d80f
parent 60e7487cbe7036fbbce533f20a41a5cdfe7725e2
Author: Nick Grato <ngrato@gmail.com>
Date: Wed, 1 Oct 2025 19:12:20 +0000
Bug 1990387 - Add streaming to openai compatible endpoint r=firefox-ai-ml-reviewers,atossou
The new OpenAI-compatible endpoint does not support streaming. We need to be able to pass `stream: true` as a config option when creating the engine.
Differential Revision: https://phabricator.services.mozilla.com/D266030
Diffstat:
4 files changed, 307 insertions(+), 71 deletions(-)
diff --git a/browser/components/genai/SmartAssistEngine.sys.mjs b/browser/components/genai/SmartAssistEngine.sys.mjs
@@ -33,7 +33,6 @@ export const SmartAssistEngine = {
modelId: Services.prefs.getStringPref("browser.ml.smartAssist.model"),
modelRevision: "main",
taskName: "text-generation",
- // TODO need to add stream support - https://bugzilla.mozilla.org/show_bug.cgi?id=1990387
});
return engineInstance;
@@ -51,10 +50,16 @@ export const SmartAssistEngine = {
* @returns {string} AI response
*/
- async fetchWithHistory(messages) {
+ async *fetchWithHistory(messages) {
const engineInstance = await this.createOpenAIEngine();
- const resp = await engineInstance.run({ args: messages });
-
- return resp.finalOutput;
+ // Use runWithGenerator to get streaming chunks directly
+ for await (const chunk of engineInstance.runWithGenerator({
+ streamOptions: { enabled: true },
+ args: messages,
+ })) {
+ if (chunk.text) {
+ yield chunk.text;
+ }
+ }
},
};
diff --git a/browser/components/genai/content/smart-assist.mjs b/browser/components/genai/content/smart-assist.mjs
@@ -64,12 +64,40 @@ export class SmartAssist extends MozLitElement {
};
_handleSubmit = async () => {
- this._updateConversationState({ role: "user", content: this.userPrompt });
- const resp = await lazy.SmartAssistEngine.fetchWithHistory(
- this.conversationState
- );
+ const formattedPrompt = (this.userPrompt || "").trim();
+ if (!formattedPrompt) {
+ return;
+ }
+
+ // Push user prompt
+ this._updateConversationState({ role: "user", content: formattedPrompt });
this.userPrompt = "";
- this._updateConversationState({ role: "assistant", content: resp });
+
+ // Create an empty assistant placeholder.
+ this._updateConversationState({ role: "assistant", content: "" });
+ const latestAssistantMessageIndex = this.conversationState.length - 1;
+
+ let acc = "";
+ try {
+ const stream = lazy.SmartAssistEngine.fetchWithHistory(
+ this.conversationState
+ );
+
+ for await (const chunk of stream) {
+ acc += chunk;
+ this.conversationState[latestAssistantMessageIndex] = {
+ ...this.conversationState[latestAssistantMessageIndex],
+ content: acc,
+ };
+ this.requestUpdate?.();
+ }
+ } catch (e) {
+ this.conversationState[latestAssistantMessageIndex] = {
+ role: "assistant",
+ content: `There was an error`,
+ };
+ this.requestUpdate?.();
+ }
};
/**
@@ -118,6 +146,9 @@ export class SmartAssist extends MozLitElement {
msg =>
html`<div class="message ${msg.role}">
<strong>${msg.role}:</strong> ${msg.content}
+ ${msg.role === "assistant" && msg.content.length === 0
+ ? html`<span>Thinking</span>`
+ : ""}
</div>`
)}
</div>
diff --git a/browser/components/genai/tests/xpcshell/test_smart_assist_engine.js b/browser/components/genai/tests/xpcshell/test_smart_assist_engine.js
@@ -24,15 +24,17 @@ registerCleanupFunction(() => {
});
add_task(async function test_createOpenAIEngine_uses_prefs_and_static_fields() {
- // Arrange known prefs
Services.prefs.setStringPref(PREF_API_KEY, "test-key-123");
Services.prefs.setStringPref(PREF_ENDPOINT, "https://example.test/v1");
Services.prefs.setStringPref(PREF_MODEL, "gpt-fake");
const sb = sinon.createSandbox();
try {
- // Stub _createEngine to capture options
- const fakeEngine = { run: sb.stub().resolves({ finalOutput: "" }) };
+ const fakeEngine = {
+ runWithGenerator() {
+ throw new Error("not used");
+ },
+ };
const stub = sb
.stub(SmartAssistEngine, "_createEngine")
.resolves(fakeEngine);
@@ -42,55 +44,76 @@ add_task(async function test_createOpenAIEngine_uses_prefs_and_static_fields() {
Assert.strictEqual(
engine,
fakeEngine,
- "Should return engine resolved by _createEngine"
+ "Should return engine from _createEngine"
);
Assert.ok(stub.calledOnce, "_createEngine should be called once");
- const passed = stub.firstCall.args[0];
- Assert.equal(passed.apiKey, "test-key-123", "apiKey should come from pref");
+ const opts = stub.firstCall.args[0];
+ Assert.equal(opts.apiKey, "test-key-123", "apiKey should come from pref");
Assert.equal(
- passed.baseURL,
+ opts.baseURL,
"https://example.test/v1",
"baseURL should come from pref"
);
- Assert.equal(passed.modelId, "gpt-fake", "modelId should come from pref");
+ Assert.equal(opts.modelId, "gpt-fake", "modelId should come from pref");
} finally {
sb.restore();
}
});
-add_task(
- async function test_fetchWithHistory_returns_finalOutput_and_forwards_args() {
- const sb = sinon.createSandbox();
- try {
- let capturedArgs = null;
- const fakeEngine = {
- async run({ args }) {
- capturedArgs = args;
- return { finalOutput: "Hello from fake engine!" };
- },
- };
-
- sb.stub(SmartAssistEngine, "_createEngine").resolves(fakeEngine);
-
- const messages = [
- { role: "system", content: "You are helpful" },
- { role: "user", content: "Hi there" },
- ];
-
- const out = await SmartAssistEngine.fetchWithHistory(messages);
-
- Assert.equal(out, "Hello from fake engine!", "Should return finalOutput");
- Assert.deepEqual(
- capturedArgs,
- messages,
- "Should forward messages unmodified as 'args' to engine.run()"
- );
- } finally {
- sb.restore();
+add_task(async function test_fetchWithHistory_streams_and_forwards_args() {
+ const sb = sinon.createSandbox();
+ try {
+ let capturedArgs = null;
+ let capturedStreamOption = null;
+
+ // Fake async generator that yields three text chunks and one empty (ignored)
+ const fakeEngine = {
+ runWithGenerator({ streamOptions, args }) {
+ capturedArgs = args;
+ capturedStreamOption = streamOptions;
+ async function* gen() {
+ yield { text: "Hello" };
+ yield { text: " from" };
+ yield { text: " fake engine!" };
+ yield {}; // ignored by SmartAssistEngine
+ }
+ return gen();
+ },
+ };
+
+ sb.stub(SmartAssistEngine, "_createEngine").resolves(fakeEngine);
+
+ const messages = [
+ { role: "system", content: "You are helpful" },
+ { role: "user", content: "Hi there" },
+ ];
+
+ // Collect streamed output
+ let acc = "";
+ for await (const t of SmartAssistEngine.fetchWithHistory(messages)) {
+ acc += t;
}
+
+ Assert.equal(
+ acc,
+ "Hello from fake engine!",
+ "Should concatenate streamed chunks"
+ );
+ Assert.deepEqual(
+ capturedArgs,
+ messages,
+ "Should forward messages as args to runWithGenerator()"
+ );
+ Assert.deepEqual(
+ capturedStreamOption.enabled,
+ true,
+ "Should enable streaming in runWithGenerator()"
+ );
+ } finally {
+ sb.restore();
}
-);
+});
add_task(
async function test_fetchWithHistory_propagates_engine_creation_rejection() {
@@ -98,18 +121,59 @@ add_task(
try {
const err = new Error("creation failed (generic)");
const stub = sb.stub(SmartAssistEngine, "_createEngine").rejects(err);
-
const messages = [{ role: "user", content: "Hi" }];
+ // Must CONSUME the async generator to trigger the rejection
+ const consume = async () => {
+ for await (const _message of SmartAssistEngine.fetchWithHistory(
+ messages
+ )) {
+ void _message;
+ }
+ };
+
await Assert.rejects(
- SmartAssistEngine.fetchWithHistory(messages),
- e => e === err, // exact error propagated
+ consume(),
+ e => e === err,
"Should propagate the same error thrown by _createEngine"
);
-
Assert.ok(stub.calledOnce, "_createEngine should be called once");
} finally {
sb.restore();
}
}
);
+
+add_task(async function test_fetchWithHistory_propagates_stream_error() {
+ const sb = sinon.createSandbox();
+ try {
+ const fakeEngine = {
+ runWithGenerator() {
+ async function* gen() {
+ yield { text: "partial" };
+ throw new Error("engine stream boom");
+ }
+ return gen();
+ },
+ };
+ sb.stub(SmartAssistEngine, "_createEngine").resolves(fakeEngine);
+
+ const consume = async () => {
+ let acc = "";
+ for await (const t of SmartAssistEngine.fetchWithHistory([
+ { role: "user", content: "x" },
+ ])) {
+ acc += t;
+ }
+ return acc;
+ };
+
+ await Assert.rejects(
+ consume(),
+ e => /engine stream boom/.test(e.message),
+ "Should propagate errors thrown during streaming"
+ );
+ } finally {
+ sb.restore();
+ }
+});
diff --git a/toolkit/components/ml/content/backends/OpenAIPipeline.mjs b/toolkit/components/ml/content/backends/OpenAIPipeline.mjs
@@ -77,6 +77,142 @@ export class OpenAIPipeline {
return new OpenAIPipeline(config, errorFactory);
}
+ /**
+ * Sends progress updates to both the port and inference progress callback.
+ *
+ * @private
+ * @param {object} args - The arguments object
+ * @param {string} args.content - The text content to send in the progress update
+ * @param {string|null} args.requestId - Unique identifier for the request
+ * @param {Function|null} args.inferenceProgressCallback - Callback function to report inference progress
+ * @param {MessagePort|null} args.port - Port for posting messages to the caller
+ * @param {boolean} args.isDone - Whether this is the final progress update
+ */
+ #sendProgress(args) {
+ const { content, requestId, inferenceProgressCallback, port, isDone } =
+ args;
+ port?.postMessage({
+ text: content,
+ ...(isDone && { done: true, finalOutput: content }),
+ ok: true,
+ });
+
+ inferenceProgressCallback?.({
+ ok: true,
+ metadata: {
+ text: content,
+ requestId,
+ tokens: [],
+ },
+ type: Progress.ProgressType.INFERENCE,
+ statusText: isDone
+ ? Progress.ProgressStatusText.DONE
+ : Progress.ProgressStatusText.IN_PROGRESS,
+ });
+ }
+
+ /**
+ * Handles streaming response from the OpenAI API.
+ * Processes each chunk as it arrives and sends progress updates.
+ *
+ * @private
+ * @param {object} args - The arguments object
+ * @param {OpenAI} args.client - OpenAI client instance
+ * @param {object} args.completionParams - Parameters for the completion request
+ * @param {string|null} args.requestId - Unique identifier for the request
+ * @param {Function|null} args.inferenceProgressCallback - Callback function to report inference progress
+ * @param {MessagePort|null} args.port - Port for posting messages to the caller
+ * @returns {Promise<object>} Result object with done, finalOutput, ok, and metrics properties
+ */
+ async #handleStreamingResponse(args) {
+ const {
+ client,
+ completionParams,
+ requestId,
+ inferenceProgressCallback,
+ port,
+ } = args;
+ const stream = await client.chat.completions.create(completionParams);
+ let streamOutput = "";
+
+ for await (const chunk of stream) {
+ const content = chunk.choices[0]?.delta?.content || "";
+ if (content) {
+ streamOutput += content;
+ this.#sendProgress({
+ content,
+ requestId,
+ inferenceProgressCallback,
+ port,
+ isDone: false,
+ });
+ }
+ }
+
+ this.#sendProgress({
+ content: "",
+ requestId,
+ inferenceProgressCallback,
+ port,
+ isDone: true,
+ });
+
+ return {
+ finalOutput: streamOutput,
+ metrics: [],
+ };
+ }
+
+ /**
+ * Handles non-streaming response from the OpenAI API.
+ * Waits for the complete response and sends it as a single update.
+ *
+ * @private
+ * @param {object} args - The arguments object
+ * @param {OpenAI} args.client - OpenAI client instance
+ * @param {object} args.completionParams - Parameters for the completion request
+ * @param {string|null} args.requestId - Unique identifier for the request
+ * @param {Function|null} args.inferenceProgressCallback - Callback function to report inference progress
+ * @param {MessagePort|null} args.port - Port for posting messages to the caller
+ * @returns {Promise<object>} Result object with done, finalOutput, ok, and metrics properties
+ */
+ async #handleNonStreamingResponse(args) {
+ const {
+ client,
+ completionParams,
+ requestId,
+ inferenceProgressCallback,
+ port,
+ } = args;
+ const completion = await client.chat.completions.create(completionParams);
+ const output = completion.choices[0].message.content;
+
+ this.#sendProgress({
+ content: output,
+ requestId,
+ inferenceProgressCallback,
+ port,
+ isDone: true,
+ });
+
+ return {
+ finalOutput: output,
+ metrics: [],
+ };
+ }
+
+ /**
+ * Executes the OpenAI pipeline with the given request.
+ * Supports both streaming and non-streaming modes based on options configuration.
+ *
+ * @param {object} request - The request object containing the messages
+ * @param {Array} request.args - Array of message objects for the chat completion
+ * @param {string|null} [requestId=null] - Unique identifier for this request
+ * @param {Function|null} [inferenceProgressCallback=null] - Callback function to report progress during inference
+ * @param {MessagePort|null} [port=null] - Port for posting messages back to the caller
+ * @returns {Promise<object>} Result object containing completion status, output, and metrics
+ * @throws {Error} Throws backend error if the API request fails
+ */
async run(
request,
requestId = null,
@@ -85,32 +221,32 @@ export class OpenAIPipeline {
) {
lazy.console.debug("Running OpenAI pipeline");
try {
- const baseURL = this.#options.baseURL || "http://localhost:11434/v1";
+ const { baseURL, apiKey, modelId } = this.#options;
const client = new OpenAIPipeline.OpenAILib.OpenAI({
- baseURL,
- apiKey: this.#options.apiKey || "ollama",
+ baseURL: baseURL ? baseURL : "http://localhost:11434/v1",
+ apiKey: apiKey || "ollama",
});
+ const stream = request.streamOptions?.enabled || false;
- const completion = await client.chat.completions.create({
- model: this.#options.modelId,
+ const completionParams = {
+ model: modelId,
messages: request.args,
- });
+ stream,
+ };
- const output = completion.choices[0].message.content;
- port?.postMessage({ done: true, finalOutput: output, ok: true });
+ const args = {
+ client,
+ completionParams,
+ requestId,
+ inferenceProgressCallback,
+ port,
+ };
- inferenceProgressCallback?.({
- ok: true,
- metadata: {
- text: output,
- requestId,
- tokens: [],
- },
- type: Progress.ProgressType.INFERENCE,
- statusText: Progress.ProgressStatusText.DONE,
- });
+ if (stream) {
+ return await this.#handleStreamingResponse(args);
+ }
- return { done: true, finalOutput: output, ok: true, metrics: [] };
+ return await this.#handleNonStreamingResponse(args);
} catch (error) {
const backendError = this.#errorFactory(error);
port?.postMessage({ done: true, ok: false, error: backendError });