tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

commit a3e0f470619270b114bb33647e6680522713d80f
parent 60e7487cbe7036fbbce533f20a41a5cdfe7725e2
Author: Nick Grato <ngrato@gmail.com>
Date:   Wed,  1 Oct 2025 19:12:20 +0000

Bug 1990387 - Add streaming to openai compatible endpoint r=firefox-ai-ml-reviewers,atossou

The new OpenAI-compatible endpoint does not have the option to add streaming. We need to be able to pass `stream: true` to the engine creation as a config option.

Differential Revision: https://phabricator.services.mozilla.com/D266030

Diffstat:
Mbrowser/components/genai/SmartAssistEngine.sys.mjs | 15++++++++++-----
Mbrowser/components/genai/content/smart-assist.mjs | 41++++++++++++++++++++++++++++++++++++-----
Mbrowser/components/genai/tests/xpcshell/test_smart_assist_engine.js | 148++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------
Mtoolkit/components/ml/content/backends/OpenAIPipeline.mjs | 174++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
4 files changed, 307 insertions(+), 71 deletions(-)

diff --git a/browser/components/genai/SmartAssistEngine.sys.mjs b/browser/components/genai/SmartAssistEngine.sys.mjs @@ -33,7 +33,6 @@ export const SmartAssistEngine = { modelId: Services.prefs.getStringPref("browser.ml.smartAssist.model"), modelRevision: "main", taskName: "text-generation", - // TODO need to add stream support - https://bugzilla.mozilla.org/show_bug.cgi?id=1990387 }); return engineInstance; @@ -51,10 +50,16 @@ export const SmartAssistEngine = { * @returns {string} AI response */ - async fetchWithHistory(messages) { + async *fetchWithHistory(messages) { const engineInstance = await this.createOpenAIEngine(); - const resp = await engineInstance.run({ args: messages }); - - return resp.finalOutput; + // Use runWithGenerator to get streaming chunks directly + for await (const chunk of engineInstance.runWithGenerator({ + streamOptions: { enabled: true }, + args: messages, + })) { + if (chunk.text) { + yield chunk.text; + } + } }, }; diff --git a/browser/components/genai/content/smart-assist.mjs b/browser/components/genai/content/smart-assist.mjs @@ -64,12 +64,40 @@ export class SmartAssist extends MozLitElement { }; _handleSubmit = async () => { - this._updateConversationState({ role: "user", content: this.userPrompt }); - const resp = await lazy.SmartAssistEngine.fetchWithHistory( - this.conversationState - ); + const formattedPrompt = (this.userPrompt || "").trim(); + if (!formattedPrompt) { + return; + } + + // Push user prompt + this._updateConversationState({ role: "user", content: formattedPrompt }); this.userPrompt = ""; - this._updateConversationState({ role: "assistant", content: resp }); + + // Create an empty assistant placeholder. 
+ this._updateConversationState({ role: "assistant", content: "" }); + const latestAssistantMessageIndex = this.conversationState.length - 1; + + let acc = ""; + try { + const stream = lazy.SmartAssistEngine.fetchWithHistory( + this.conversationState + ); + + for await (const chunk of stream) { + acc += chunk; + this.conversationState[latestAssistantMessageIndex] = { + ...this.conversationState[latestAssistantMessageIndex], + content: acc, + }; + this.requestUpdate?.(); + } + } catch (e) { + this.conversationState[latestAssistantMessageIndex] = { + role: "assistant", + content: `There was an error`, + }; + this.requestUpdate?.(); + } }; /** @@ -118,6 +146,9 @@ export class SmartAssist extends MozLitElement { msg => html`<div class="message ${msg.role}"> <strong>${msg.role}:</strong> ${msg.content} + ${msg.role === "assistant" && msg.content.length === 0 + ? html`<span>Thinking</span>` + : ""} </div>` )} </div> diff --git a/browser/components/genai/tests/xpcshell/test_smart_assist_engine.js b/browser/components/genai/tests/xpcshell/test_smart_assist_engine.js @@ -24,15 +24,17 @@ registerCleanupFunction(() => { }); add_task(async function test_createOpenAIEngine_uses_prefs_and_static_fields() { - // Arrange known prefs Services.prefs.setStringPref(PREF_API_KEY, "test-key-123"); Services.prefs.setStringPref(PREF_ENDPOINT, "https://example.test/v1"); Services.prefs.setStringPref(PREF_MODEL, "gpt-fake"); const sb = sinon.createSandbox(); try { - // Stub _createEngine to capture options - const fakeEngine = { run: sb.stub().resolves({ finalOutput: "" }) }; + const fakeEngine = { + runWithGenerator() { + throw new Error("not used"); + }, + }; const stub = sb .stub(SmartAssistEngine, "_createEngine") .resolves(fakeEngine); @@ -42,55 +44,76 @@ add_task(async function test_createOpenAIEngine_uses_prefs_and_static_fields() { Assert.strictEqual( engine, fakeEngine, - "Should return engine resolved by _createEngine" + "Should return engine from _createEngine" ); 
Assert.ok(stub.calledOnce, "_createEngine should be called once"); - const passed = stub.firstCall.args[0]; - Assert.equal(passed.apiKey, "test-key-123", "apiKey should come from pref"); + const opts = stub.firstCall.args[0]; + Assert.equal(opts.apiKey, "test-key-123", "apiKey should come from pref"); Assert.equal( - passed.baseURL, + opts.baseURL, "https://example.test/v1", "baseURL should come from pref" ); - Assert.equal(passed.modelId, "gpt-fake", "modelId should come from pref"); + Assert.equal(opts.modelId, "gpt-fake", "modelId should come from pref"); } finally { sb.restore(); } }); -add_task( - async function test_fetchWithHistory_returns_finalOutput_and_forwards_args() { - const sb = sinon.createSandbox(); - try { - let capturedArgs = null; - const fakeEngine = { - async run({ args }) { - capturedArgs = args; - return { finalOutput: "Hello from fake engine!" }; - }, - }; - - sb.stub(SmartAssistEngine, "_createEngine").resolves(fakeEngine); - - const messages = [ - { role: "system", content: "You are helpful" }, - { role: "user", content: "Hi there" }, - ]; - - const out = await SmartAssistEngine.fetchWithHistory(messages); - - Assert.equal(out, "Hello from fake engine!", "Should return finalOutput"); - Assert.deepEqual( - capturedArgs, - messages, - "Should forward messages unmodified as 'args' to engine.run()" - ); - } finally { - sb.restore(); +add_task(async function test_fetchWithHistory_streams_and_forwards_args() { + const sb = sinon.createSandbox(); + try { + let capturedArgs = null; + let capturedStreamOption = null; + + // Fake async generator that yields three text chunks and one empty (ignored) + const fakeEngine = { + runWithGenerator({ streamOptions, args }) { + capturedArgs = args; + capturedStreamOption = streamOptions; + async function* gen() { + yield { text: "Hello" }; + yield { text: " from" }; + yield { text: " fake engine!" 
}; + yield {}; // ignored by SmartAssistEngine + } + return gen(); + }, + }; + + sb.stub(SmartAssistEngine, "_createEngine").resolves(fakeEngine); + + const messages = [ + { role: "system", content: "You are helpful" }, + { role: "user", content: "Hi there" }, + ]; + + // Collect streamed output + let acc = ""; + for await (const t of SmartAssistEngine.fetchWithHistory(messages)) { + acc += t; } + + Assert.equal( + acc, + "Hello from fake engine!", + "Should concatenate streamed chunks" + ); + Assert.deepEqual( + capturedArgs, + messages, + "Should forward messages as args to runWithGenerator()" + ); + Assert.deepEqual( + capturedStreamOption.enabled, + true, + "Should enable streaming in runWithGenerator()" + ); + } finally { + sb.restore(); } -); +}); add_task( async function test_fetchWithHistory_propagates_engine_creation_rejection() { @@ -98,18 +121,59 @@ add_task( try { const err = new Error("creation failed (generic)"); const stub = sb.stub(SmartAssistEngine, "_createEngine").rejects(err); - const messages = [{ role: "user", content: "Hi" }]; + // Must CONSUME the async generator to trigger the rejection + const consume = async () => { + for await (const _message of SmartAssistEngine.fetchWithHistory( + messages + )) { + void _message; + } + }; + await Assert.rejects( - SmartAssistEngine.fetchWithHistory(messages), - e => e === err, // exact error propagated + consume(), + e => e === err, "Should propagate the same error thrown by _createEngine" ); - Assert.ok(stub.calledOnce, "_createEngine should be called once"); } finally { sb.restore(); } } ); + +add_task(async function test_fetchWithHistory_propagates_stream_error() { + const sb = sinon.createSandbox(); + try { + const fakeEngine = { + runWithGenerator() { + async function* gen() { + yield { text: "partial" }; + throw new Error("engine stream boom"); + } + return gen(); + }, + }; + sb.stub(SmartAssistEngine, "_createEngine").resolves(fakeEngine); + + const consume = async () => { + let acc = ""; + for 
await (const t of SmartAssistEngine.fetchWithHistory([ + { role: "user", content: "x" }, + ])) { + acc += t; + } + return acc; + }; + + await Assert.rejects( + consume(), + e => /engine stream boom/.test(e.message), + "Should propagate errors thrown during streaming" + ); + } finally { + sb.restore(); + } +}); diff --git a/toolkit/components/ml/content/backends/OpenAIPipeline.mjs b/toolkit/components/ml/content/backends/OpenAIPipeline.mjs @@ -77,6 +77,142 @@ export class OpenAIPipeline { return new OpenAIPipeline(config, errorFactory); } + /** + * Sends progress updates to both the port and inference progress callback. + * + * @private + * @param {object} args - The arguments object + * @param {string} args.content - The text content to send in the progress update + * @param {string|null} args.requestId - Unique identifier for the request + * @param {Function|null} args.inferenceProgressCallback - Callback function to report inference progress + * @param {MessagePort|null} args.port - Port for posting messages to the caller + * @param {boolean} args.isDone - Whether this is the final progress update + */ + #sendProgress(args) { + const { content, requestId, inferenceProgressCallback, port, isDone } = + args; + port?.postMessage({ + text: content, + ...(isDone && { done: true, finalOutput: content }), + ok: true, + }); + + inferenceProgressCallback?.({ + ok: true, + metadata: { + text: content, + requestId, + tokens: [], + }, + type: Progress.ProgressType.INFERENCE, + statusText: isDone + ? Progress.ProgressStatusText.DONE + : Progress.ProgressStatusText.IN_PROGRESS, + }); + } + + /** + * Handles streaming response from the OpenAI API. + * Processes each chunk as it arrives and sends progress updates. 
+ * + * @private + * @param {object} args - The arguments object + * @param {OpenAI} args.client - OpenAI client instance + * @param {object} args.completionParams - Parameters for the completion request + * @param {string|null} args.requestId - Unique identifier for the request + * @param {Function|null} args.inferenceProgressCallback - Callback function to report inference progress + * @param {MessagePort|null} args.port - Port for posting messages to the caller + * @returns {Promise<object>} Result object with done, finalOutput, ok, and metrics properties + */ + async #handleStreamingResponse(args) { + const { + client, + completionParams, + requestId, + inferenceProgressCallback, + port, + } = args; + const stream = await client.chat.completions.create(completionParams); + let streamOutput = ""; + + for await (const chunk of stream) { + const content = chunk.choices[0]?.delta?.content || ""; + if (content) { + streamOutput += content; + this.#sendProgress({ + content, + requestId, + inferenceProgressCallback, + port, + isDone: false, + }); + } + } + + this.#sendProgress({ + content: "", + requestId, + inferenceProgressCallback, + port, + isDone: true, + }); + + return { + finalOutput: streamOutput, + metrics: [], + }; + } + + /** + * Handles non-streaming response from the OpenAI API. + * Waits for the complete response and sends it as a single update. 
+ * + * @private + * @param {object} args - The arguments object + * @param {OpenAI} args.client - OpenAI client instance + * @param {object} args.completionParams - Parameters for the completion request + * @param {string|null} args.requestId - Unique identifier for the request + * @param {Function|null} args.inferenceProgressCallback - Callback function to report inference progress + * @param {MessagePort|null} args.port - Port for posting messages to the caller + * @returns {Promise<object>} Result object with done, finalOutput, ok, and metrics properties + */ + async #handleNonStreamingResponse(args) { + const { + client, + completionParams, + requestId, + inferenceProgressCallback, + port, + } = args; + const completion = await client.chat.completions.create(completionParams); + const output = completion.choices[0].message.content; + + this.#sendProgress({ + content: output, + requestId, + inferenceProgressCallback, + port, + isDone: true, + }); + + return { + finalOutput: output, + metrics: [], + }; + } + + /** + * Executes the OpenAI pipeline with the given request. + * Supports both streaming and non-streaming modes based on options configuration. 
+ * + * @param {object} request - The request object containing the messages + * @param {Array} request.args - Array of message objects for the chat completion + * @param {string|null} [requestId=null] - Unique identifier for this request + * @param {Function|null} [inferenceProgressCallback=null] - Callback function to report progress during inference + * @param {MessagePort|null} [port=null] - Port for posting messages back to the caller + * @returns {Promise<object>} Result object containing completion status, output, and metrics + * @throws {Error} Throws backend error if the API request fails + */ async run( request, requestId = null, @@ -85,32 +221,32 @@ export class OpenAIPipeline { ) { lazy.console.debug("Running OpenAI pipeline"); try { - const baseURL = this.#options.baseURL || "http://localhost:11434/v1"; + const { baseURL, apiKey, modelId } = this.#options; const client = new OpenAIPipeline.OpenAILib.OpenAI({ - baseURL, - apiKey: this.#options.apiKey || "ollama", + baseURL: baseURL ? 
baseURL : "http://localhost:11434/v1", + apiKey: apiKey || "ollama", }); + const stream = request.streamOptions?.enabled || false; - const completion = await client.chat.completions.create({ - model: this.#options.modelId, + const completionParams = { + model: modelId, messages: request.args, - }); + stream, + }; - const output = completion.choices[0].message.content; - port?.postMessage({ done: true, finalOutput: output, ok: true }); + const args = { + client, + completionParams, + requestId, + inferenceProgressCallback, + port, + }; - inferenceProgressCallback?.({ - ok: true, - metadata: { - text: output, - requestId, - tokens: [], - }, - type: Progress.ProgressType.INFERENCE, - statusText: Progress.ProgressStatusText.DONE, - }); + if (stream) { + return await this.#handleStreamingResponse(args); + } - return { done: true, finalOutput: output, ok: true, metrics: [] }; + return await this.#handleNonStreamingResponse(args); } catch (error) { const backendError = this.#errorFactory(error); port?.postMessage({ done: true, ok: false, error: backendError });