tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

commit f690cf8ee9eafbe9d79eb3e828d9f2b9be83f0a2
parent 312ca7f12b0f5fa6f76430fa3d38d4fe1218fd30
Author: Serban Stanca <sstanca@mozilla.com>
Date:   Thu, 30 Oct 2025 23:35:42 +0200

Revert "Bug 1996291 - Rework getStatus to getStatusByEngineId and fix status reporting in about:inference r=firefox-ai-ml-reviewers,tarek" for causing mochitests failures in browser_aboutinference.js.

This reverts commit 4903e6275b2f9d65ad17d1dbb9b367ac2415955e.

This reverts commit 0974467706f2067587546c466e799d484eed8dd9.

This reverts commit 702c391f4f72ae7fa9c409e5d157a8a3e1c4b80c.

This reverts commit ed987c5eb7651f1bd0ecec3958d01042c9169326.

This reverts commit c98dad23a0364c009588418ad89f63e8cdc6e81d.

Diffstat:
Mtoolkit/components/aboutinference/content/aboutInference.js | 40+++++++++++++++-------------------------
Mtoolkit/components/ml/actors/MLEngineChild.sys.mjs | 171+++++++++++++++++++++++++++++++++++++++++++++----------------------------------
Mtoolkit/components/ml/actors/MLEngineParent.sys.mjs | 23++++++++---------------
Dtoolkit/components/ml/ml.d.ts | 60------------------------------------------------------------
Mtoolkit/components/ml/tests/browser/browser.toml | 11++---------
Atoolkit/components/ml/tests/browser/browser_ml_engine.js | 2162+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Dtoolkit/components/ml/tests/browser/browser_ml_engine_e2e.js | 544-------------------------------------------------------------------------------
Dtoolkit/components/ml/tests/browser/browser_ml_engine_lifetime.js | 520-------------------------------------------------------------------------------
Dtoolkit/components/ml/tests/browser/browser_ml_engine_pipeline_options.js | 825-------------------------------------------------------------------------------
Dtoolkit/components/ml/tests/browser/browser_ml_engine_rs_hub.js | 97-------------------------------------------------------------------------------
Dtoolkit/components/ml/tests/browser/browser_ml_openai.js | 253-------------------------------------------------------------------------------
Mtoolkit/components/ml/tests/browser/head.js | 21++-------------------
12 files changed, 2286 insertions(+), 2441 deletions(-)

diff --git a/toolkit/components/aboutinference/content/aboutInference.js b/toolkit/components/aboutinference/content/aboutInference.js @@ -6,11 +6,6 @@ "use strict"; /** - * @import { MLEngineParent } from "resource://gre/actors/MLEngineParent.sys.mjs" - * @import { StatusByEngineId } from "../../ml/ml.d.ts" - */ - -/** * Imports necessary modules from ChromeUtils. */ const lazy = {}; @@ -80,7 +75,6 @@ function getNumThreadsArray() { ); } -/** @type {MLEngineParent | null} */ let engineParent = null; const TINY_ARTICLE = @@ -466,26 +460,25 @@ function ts2str(ts) { */ async function updateStatus() { - const engineParent = await getEngineParent(); + if (!engineParent) { + return; + } - /** - * @type {StatusByEngineId} - */ - let statusByEngineId; + let info; // Fetch the engine status info try { - statusByEngineId = await engineParent.getStatusByEngineId(); - } catch (error) { - console.error("Failed to get the engine status", error); - statusByEngineId = new Map(); + info = await engineParent.getStatus(); + } catch (e) { + engineParent = null; // let's re-create it on errors. + info = new Map(); } // Get the container where the table will be displayed let tableContainer = document.getElementById("statusTableContainer"); // Clear the container if the map is empty - if (statusByEngineId.size === 0) { + if (info.size === 0) { tableContainer.innerHTML = ""; // Clear any existing table if (updateStatusInterval) { clearInterval(updateStatusInterval); // Clear the interval if it exists @@ -527,7 +520,7 @@ async function updateStatus() { let tbody = document.createElement("tbody"); // Iterate over the info map - for (let [engineId, { status, options }] of statusByEngineId.entries()) { + for (let [engineId, engineInfo] of info.entries()) { let row = document.createElement("tr"); // Create a cell for each piece of data @@ -536,23 +529,23 @@ async function updateStatus() { row.appendChild(engineIdCell); let statusCell = document.createElement("td"); - statusCell.textContent = status; + statusCell.textContent = engineInfo.status; row.appendChild(statusCell); let modelIdCell = document.createElement("td"); - modelIdCell.textContent = options?.modelId || "N/A"; + modelIdCell.textContent = engineInfo.options?.modelId || "N/A"; row.appendChild(modelIdCell); let dtypeCell = document.createElement("td"); - dtypeCell.textContent = options?.dtype || "N/A"; + dtypeCell.textContent = engineInfo.options?.dtype || "N/A"; row.appendChild(dtypeCell); let deviceCell = document.createElement("td"); - deviceCell.textContent = options?.device || "N/A"; + deviceCell.textContent = engineInfo.options?.device || "N/A"; row.appendChild(deviceCell); let timeoutCell = document.createElement("td"); - timeoutCell.textContent = options?.timeoutMS || "N/A"; + timeoutCell.textContent = engineInfo.options?.timeoutMS || "N/A"; row.appendChild(timeoutCell); // Append the row to the table body @@ -1140,9 +1133,6 @@ function showTab(button) { button.setAttribute("selected", "true"); } -/** - * @returns {Promise<MLEngineParent>} - */ async function getEngineParent() { if (!engineParent) { engineParent = await EngineProcess.getMLEngineParent(); diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs @@ -2,18 +2,23 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -// @ts-check - import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs"; /** - * @import { BasePromiseWorker } from "resource://gre/modules/PromiseWorker.sys.mjs" - * @import { PipelineOptions } from "chrome://global/content/ml/EngineProcess.sys.mjs" - * @import { EngineStatus, EngineId, StatusByEngineId } from "../ml.d.ts" - * @import { ProgressAndStatusCallbackParams } from "chrome://global/content/ml/Utils.sys.mjs" + * @typedef {import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker + */ + +/** + * @typedef {object} Lazy + * @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams + * @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker + * @property {typeof setTimeout} setTimeout + * @property {typeof clearTimeout} clearTimeout */ -const lazy = XPCOMUtils.declareLazy({ +/** @type {Lazy} */ +const lazy = {}; +ChromeUtils.defineESModuleGetters(lazy, { BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs", setTimeout: "resource://gre/modules/Timer.sys.mjs", clearTimeout: "resource://gre/modules/Timer.sys.mjs", @@ -22,24 +27,45 @@ const lazy = XPCOMUtils.declareLazy({ DEFAULT_MODELS: "chrome://global/content/ml/EngineProcess.sys.mjs", WASM_BACKENDS: "chrome://global/content/ml/EngineProcess.sys.mjs", BACKENDS: "chrome://global/content/ml/EngineProcess.sys.mjs", - console: () => - console.createInstance({ - maxLogLevelPref: "browser.ml.logLevel", - prefix: "GeckoMLEngineChild", - }), - // Prefs: - CACHE_TIMEOUT_MS: { pref: "browser.ml.modelCacheTimeout" }, - MODEL_HUB_ROOT_URL: { pref: "browser.ml.modelHubRootUrl" }, - MODEL_HUB_URL_TEMPLATE: { pref: "browser.ml.modelHubUrlTemplate" }, - LOG_LEVEL: { pref: "browser.ml.logLevel" }, - PIPELINE_OVERRIDE_OPTIONS: { - pref: "browser.ml.overridePipelineOptions", - default: "{}", - }, - // Services - mlUtils: { service: "@mozilla.org/ml-utils;1", iid: Ci.nsIMLUtils }, }); +ChromeUtils.defineLazyGetter(lazy, "console", () => { + return console.createInstance({ + maxLogLevelPref: "browser.ml.logLevel", + prefix: "GeckoMLEngineChild", + }); +}); + +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "CACHE_TIMEOUT_MS", + "browser.ml.modelCacheTimeout" +); +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "MODEL_HUB_ROOT_URL", + "browser.ml.modelHubRootUrl" +); +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "MODEL_HUB_URL_TEMPLATE", + "browser.ml.modelHubUrlTemplate" +); +XPCOMUtils.defineLazyPreferenceGetter(lazy, "LOG_LEVEL", "browser.ml.logLevel"); +XPCOMUtils.defineLazyServiceGetter( + lazy, + "mlUtils", + "@mozilla.org/ml-utils;1", + Ci.nsIMLUtils +); + +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "PIPELINE_OVERRIDE_OPTIONS", + "browser.ml.overridePipelineOptions", + "{}" +); + const SAFE_OVERRIDE_OPTIONS = [ "dtype", "logLevel", @@ -63,11 +89,11 @@ export class MLEngineChild extends JSProcessActorChild { #engineDispatchers = new Map(); /** - * Tracks that an engine is present, even if the dispatcher is not present yet. + * Engine statuses * - * @type {Map<EngineId, PipelineOptions>} + * @type {Map<string, string>} */ - #enginesPresent = new Map(); + #engineStatuses = new Map(); // eslint-disable-next-line consistent-return async receiveMessage({ name, data }) { @@ -76,8 +102,8 @@ export class MLEngineChild extends JSProcessActorChild { await this.#onNewPortCreated(data); break; } - case "MLEngine:GetStatusByEngineId": { - return this.getStatusByEngineId(); + case "MLEngine:GetStatus": { + return this.getStatus(); } case "MLEngine:ForceShutdown": { for (const engineDispatcher of this.#engineDispatchers.values()) { @@ -115,7 +141,7 @@ export class MLEngineChild extends JSProcessActorChild { this.getUpdatedPipelineOptions(pipelineOptions); options.updateOptions(updatedPipelineOptions); const engineId = options.engineId; - this.#enginesPresent.set(engineId, options); + this.#engineStatuses.set(engineId, "INITIALIZING"); // Check if we already have an engine under this id. if (this.#engineDispatchers.has(engineId)) { @@ -127,6 +153,8 @@ export class MLEngineChild extends JSProcessActorChild { type: "EnginePort:EngineReady", error: null, }); + this.#engineStatuses.set(engineId, "READY"); + return; } @@ -138,6 +166,8 @@ export class MLEngineChild extends JSProcessActorChild { this.#engineDispatchers.delete(engineId); } + this.#engineStatuses.set(engineId, "CREATING"); + const dispatcher = new EngineDispatcher(this, port, options); this.#engineDispatchers.set(engineId, dispatcher); @@ -148,9 +178,10 @@ export class MLEngineChild extends JSProcessActorChild { // NOTE: This is done after adding to #engineDispatchers to ensure other // async calls see the new dispatcher. if (!lazy.PipelineOptions.isMocked(pipelineOptions)) { - await dispatcher.isReady(); + await dispatcher.ensureInferenceEngineIsReady(); } + this.#engineStatuses.set(engineId, "READY"); port.postMessage({ type: "EnginePort:EngineReady", error: null, @@ -229,12 +260,12 @@ export class MLEngineChild extends JSProcessActorChild { * Removes an engine by its ID. Optionally shuts down if no engines remain. * * @param {string} engineId - The ID of the engine to remove. - * @param {boolean} shutDownIfEmpty - If true, shuts down the engine process if no engines remain. + * @param {boolean} [shutDownIfEmpty] - If true, shuts down the engine process if no engines remain. * @param {boolean} replacement - Flag indicating whether the engine is being replaced. */ removeEngine(engineId, shutDownIfEmpty, replacement) { this.#engineDispatchers.delete(engineId); - this.#enginesPresent.delete(engineId); + this.#engineStatuses.delete(engineId); try { this.sendAsyncMessage("MLEngine:Removed", { @@ -258,26 +289,19 @@ export class MLEngineChild extends JSProcessActorChild { } } - /** + /* * Collects information about the current status. - * - * @returns {StatusByEngineId} */ - getStatusByEngineId() { - /** @type {StatusByEngineId} */ + async getStatus() { const statusMap = new Map(); - for (let [engineId, options] of this.#enginesPresent) { - const dispatcher = this.#engineDispatchers.get(engineId); - let status = dispatcher.getStatus(); - if (!status) { - // This engine doesn't have a dispatcher yet. - status = { - status: "SHUTTING_DOWN_PREVIOUS_ENGINE", - options, - }; + for (const [key, value] of this.#engineStatuses) { + if (this.#engineDispatchers.has(key)) { + statusMap.set(key, this.#engineDispatchers.get(key).getStatus()); + } else { + // The engine is probably being created + statusMap.set(key, { status: value }); } - statusMap.set(engineId, status); } return statusMap; } @@ -311,14 +335,14 @@ class EngineDispatcher { /** @type {MessagePort | null} */ #port = null; - /** @type {number | null} */ + /** @type {TimeoutID | null} */ #keepAliveTimeout = null; /** @type {PromiseWithResolvers} */ #modelRequest; - /** @type {Promise<InferenceEngine>} */ - #engine; + /** @type {Promise<Engine> | null} */ + #engine = null; /** @type {string} */ #taskName; @@ -332,7 +356,7 @@ class EngineDispatcher { /** @type {PipelineOptions | null} */ pipelineOptions = null; - /** @type {EngineStatus} */ + /** @type {string} */ #status; /** @@ -347,7 +371,7 @@ class EngineDispatcher { * * @param {PipelineOptions} pipelineOptions * @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback The callback to call for updating about notifications such as dowload progress status. - * @returns {Promise<InferenceEngine>} + * @returns {Promise<Engine>} */ async initializeInferenceEngine(pipelineOptions, notificationsCallback) { let remoteSettingsOptions = await this.mlEngineChild.getInferenceOptions( @@ -417,8 +441,7 @@ class EngineDispatcher { * @param {PipelineOptions} pipelineOptions */ constructor(mlEngineChild, port, pipelineOptions) { - this.#status = "INITIALIZING"; - /** @type {MLEngineChild} */ + this.#status = "CREATED"; this.mlEngineChild = mlEngineChild; this.#featureId = pipelineOptions.featureId; this.#taskName = pipelineOptions.taskName; @@ -432,12 +455,9 @@ class EngineDispatcher { } ); + // Trigger the keep alive timer. this.#engine - .then(() => { - this.#status = "IDLE"; - // Trigger the keep alive timer. - void this.keepAlive(); - }) + .then(() => void this.keepAlive()) .catch(error => { if ( // Ignore errors from tests intentionally causing errors. @@ -457,9 +477,18 @@ class EngineDispatcher { return { status: this.#status, options: this.pipelineOptions, + engineId: this.#engineId, }; } + /** + * Resolves the engine to fully initialize it. + */ + async ensureInferenceEngineIsReady() { + this.#engine = await this.#engine; + this.#status = "READY"; + } + handleInitProgressStatus(port, notificationsData) { port.postMessage({ type: "EnginePort:InitProgress", @@ -505,20 +534,11 @@ class EngineDispatcher { } /** - * Wait for the engine to be ready. - */ - async isReady() { - await this.#engine; - } - - /** * @param {MessagePort} port */ #setupMessageHandler(port) { this.#port = port; - port.onmessage = async event => { - const { data } = /** @type {any} */ (event); - + port.onmessage = async ({ data }) => { switch (data.type) { case "EnginePort:Discard": { port.close(); @@ -548,7 +568,7 @@ class EngineDispatcher { case "EnginePort:Run": { const { requestId, request, engineRunOptions } = data; try { - await this.isReady(); + await this.ensureInferenceEngineIsReady(); } catch (error) { port.postMessage({ type: "EnginePort:RunResponse", @@ -569,12 +589,15 @@ class EngineDispatcher { this.keepAlive(); this.#status = "RUNNING"; - const engine = await this.#engine; try { port.postMessage({ type: "EnginePort:RunResponse", requestId, - response: await engine.run(request, requestId, engineRunOptions), + response: await this.#engine.run( + request, + requestId, + engineRunOptions + ), error: null, }); } catch (error) { @@ -585,7 +608,7 @@ class EngineDispatcher { error, }); } - this.#status = "IDLE"; + this.#status = "IDLING"; break; } default: @@ -687,7 +710,7 @@ class InferenceEngine { * @param {?function(ProgressAndStatusCallbackParams):void} config.notificationsCallback The callback to call for updating about notifications such as dowload progress status. * @param {?function(object):Promise<[string, object]>} config.getModelFileFn - A function that actually retrieves the model and headers. * @param {?function(object):Promise<void>} config.notifyModelDownloadCompleteFn - A function to notify that all files needing downloads are completed. - * @returns {Promise<InferenceEngine>} + * @returns {InferenceEngine} */ static async create({ workerUrl, diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs @@ -3,10 +3,6 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs"; -/** - * @import { MLEngineChild } from "./MLEngineChild.sys.mjs" - */ - const lazy = XPCOMUtils.declareLazy({ RemoteSettings: "resource://services-settings/remote-settings.sys.mjs", Utils: "resource://services-settings/Utils.sys.mjs", @@ -231,8 +227,10 @@ export class MLEngineParent extends JSProcessActorParent { // Wait for the existing lock to resolve await MLEngineParent.engineLocks.get(engineId); } - const { promise: lockPromise, resolve: resolveLock } = - Promise.withResolvers(); + let resolveLock; + const lockPromise = new Promise(resolve => { + resolveLock = resolve; + }); MLEngineParent.engineLocks.set(engineId, lockPromise); MLEngineParent.engineCreationAbortSignal.set(engineId, abortSignal); try { @@ -785,15 +783,10 @@ export class MLEngineParent extends JSProcessActorParent { } /** - * Goes through the engines and determines their status. This is used by about:inference - * to display debug information about the engines. - * - * @see MLEngineChild#getStatusByEngineId - * - * @returns {Promise<StatusByEngineId>} + * Gets a status */ - getStatusByEngineId() { - return this.sendQuery("MLEngine:GetStatusByEngineId"); + getStatus() { + return this.sendQuery("MLEngine:GetStatus"); } /** @@ -948,7 +941,7 @@ export class MLEngine { #requests = new Map(); /** - * @type {"uninitialized" | "ready" | "error" | "closed" | "crashed"} + * @type {"uninitialized" | "ready" | "error" | "closed"} */ engineStatus = "uninitialized"; diff --git a/toolkit/components/ml/ml.d.ts b/toolkit/components/ml/ml.d.ts @@ -1,60 +0,0 @@ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this - * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ - -/** - * This file contains the shared types for the machine learning component. The intended - * use is for defining types to be used in JSDoc. They are used in a form that the - * TypeScript language server can read them, and provide code hints. - * - * @see https://firefox-source-docs.mozilla.org/code-quality/typescript/ - */ - -import { type PipelineOptions } from "chrome://global/content/ml/EngineProcess.sys.mjs"; - -export type EngineStatus = - // The engine is waiting for a previous one to shut down. - | "SHUTTING_DOWN_PREVIOUS_ENGINE" - // The engine dispatcher has been created, and the engine is still initializing. - | "INITIALIZING" - // The engine is fully ready and idle. - | "IDLE" - // The engine is currently processing a run request. - | "RUNNING" - // The engine is in the process of terminating, but hasn't fully shut down. - | "TERMINATING" - // The engine has been fully terminated and removed. - | "TERMINATED"; - -/** - * The EngineId is used to identify a unique engine that can be shared across multiple - * consumers. This way a single model can be loaded into memory and used in different - * locations, assuming the other parameters match as well. - */ -export type EngineId = string; - -/** - * Utility type to extract the data fields from a class. It removes all of the - * functions. - */ -type DataFields<T> = { - [K in keyof T as T[K] extends Function ? never : K]: T[K]; -}; - -/** - * The PipelineOptions are a nominal class that validates the options. The - * PipelineOptionsRaw are the raw subset of those. - */ -type PipelineOptionsRaw = Partial<DataFields<PipelineOptions>>; - -/** - * Tracks the current status of the engines for about:inference. It's not used - * for deciding any business logic of the engines, only for debug info. - */ -export type StatusByEngineId = Map< - EngineId, - { - status: EngineStatus; - options: PipelineOptions | PipelineOptionsRaw; - } ->; diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml @@ -12,16 +12,11 @@ support-files = [ ["browser_ml_embeddings_generator.js"] -["browser_ml_engine_e2e.js"] - -["browser_ml_engine_lifetime.js"] - -["browser_ml_engine_pipeline_options.js"] +["browser_ml_engine.js"] +skip-if = [ "verify" ] ["browser_ml_engine_process.js"] -["browser_ml_engine_rs_hub.js"] - ["browser_ml_native.js"] skip-if = [ "os == 'android'", @@ -34,8 +29,6 @@ skip-if = [ ["browser_ml_nlp_utils.js"] -["browser_ml_openai.js"] - ["browser_ml_opfs.js"] ["browser_ml_pipeline.js"] diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js @@ -0,0 +1,2162 @@ +/* Any copyright is dedicated to the Public Domain. + http://creativecommons.org/publicdomain/zero/1.0/ */ + +"use strict"; + +/// <reference path="head.js" /> + +requestLongerTimeout(2); + +const { BACKENDS } = ChromeUtils.importESModule( + "chrome://global/content/ml/EngineProcess.sys.mjs" +); + +const { sinon } = ChromeUtils.importESModule( + "resource://testing-common/Sinon.sys.mjs" +); + +const { MLUtils } = ChromeUtils.importESModule( + "chrome://global/content/ml/Utils.sys.mjs" +); + +const RAW_PIPELINE_OPTIONS = { taskName: "moz-echo", timeoutMS: -1 }; +const PIPELINE_OPTIONS = new PipelineOptions({ + taskName: "moz-echo", + timeoutMS: -1, +}); + +async function checkForRemoteType(remoteType) { + let procinfo3 = await ChromeUtils.requestProcInfo(); + for (const child of procinfo3.children) { + if (child.type === remoteType) { + return true; + } + } + return false; +} + +const SHARED_TOOLS = [ + { + type: "function", + function: { + name: "search_open_tabs", + description: "Search open tabs by type.", + parameters: { + type: "object", + properties: { type: { type: "string" } }, + required: ["type"], + }, + }, + }, +]; + +const BASE_ENGINE_OPTIONS = { + featureId: "about-inference", + taskName: "text-generation", + modelId: "qwen3:0.6b", + modelRevision: "main", +}; + +/** + * End to End test that the engine is indeed initialized with wllama when it is the + * best-llama. + */ +add_task(async function test_e2e_choose_backend_best_wllama() { + // Allow any url + Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); + + const backendData = new Uint8Array([10, 20, 30]); + + const expectedBackendData = JSON.stringify(backendData); + + // Mocking function used in the workers or child doesn't work. + // So we are stubbing the code run by the worker. + const workerCode = ` + // Inject the MLEngine.worker.mjs code + + ${await getMLEngineWorkerCode()} + + // Stub + ChromeUtils.defineESModuleGetters( + lazy, + { + createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", + + }, + { global: "current" } +); + + // Change the getBackend to a mocked version that doesn't actually do inference + // but does initiate model downloads and engine initialization + + lazy.getBackend = async function ( + mlEngineWorker, + backendData, + { + modelHubUrlTemplate, + modelHubRootUrl, + modelId, + modelRevision, + modelFile, + engineId, + } = {} + ) { + + + const receivedBackendData = JSON.stringify(backendData); + if (receivedBackendData !== '${expectedBackendData}'){ + throw new Error("BackendData not equal Received: ".concat(receivedBackendData, " Expected: ", '${expectedBackendData}')); + } + + return { + run: () => {}, + }; + }; +`; + + const blob = new Blob([workerCode], { type: "application/javascript" }); + const blobURL = URL.createObjectURL(blob); + + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + + let promiseStub = sinon + .stub(MLEngineParent, "getWorkerConfig") + .callsFake(function () { + return { url: blobURL, options: { type: "module" } }; + }); + + let wasmBufferStub = sinon + .stub(MLEngineParent, "getWasmArrayBuffer") + .returns(backendData); + + let chooseBestBackendStub = sinon + .stub(MLEngineParent, "chooseBestBackend") + .returns(BACKENDS.wllama); + + try { + await createEngine({ + engineId: "main", + taskName: "real-wllama-text-generation", + featureId: "link-preview", + backend: BACKENDS.bestLlama, + modelId: "acme/bert", + modelHubUrlTemplate: "{model}/resolve/{revision}", + modelRevision: "v0.4", + modelHubRootUrl: + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", + modelFile: "onnx/config.json", + }); + } finally { + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + wasmBufferStub.restore(); + promiseStub.restore(); + chooseBestBackendStub.restore(); + } +}); + +/** + * End to End test that the engine can indeed fail if it doesn't use best-llama. + */ +add_task(async function test_e2e_choose_backend_can_detect_failure() { + // Allow any url + Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); + + const backendData = new Uint8Array([10, 20, 30]); + + const expectedBackendData = JSON.stringify("data so no matches"); + + // Mocking function used in the workers or child doesn't work. + // So we are stubbing the code run by the worker. + const workerCode = ` + // Inject the MLEngine.worker.mjs code + + ${await getMLEngineWorkerCode()} + + // Stub + ChromeUtils.defineESModuleGetters( + lazy, + { + createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", + + }, + { global: "current" } +); + + // Change the getBackend to a mocked version that doesn't actually do inference + // but does initiate model downloads and engine initialization + + lazy.getBackend = async function ( + mlEngineWorker, + backendData, + { + modelHubUrlTemplate, + modelHubRootUrl, + modelId, + modelRevision, + modelFile, + engineId, + } = {} + ) { + + + const receivedBackendData = JSON.stringify(backendData); + if (receivedBackendData !== '${expectedBackendData}'){ + throw new Error("BackendData not equal Received: ".concat(receivedBackendData, " Expected: ", '${expectedBackendData}')); + } + + return { + run: () => {}, + }; + }; +`; + + const blob = new Blob([workerCode], { type: "application/javascript" }); + const blobURL = URL.createObjectURL(blob); + + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + + let promiseStub = sinon + .stub(MLEngineParent, "getWorkerConfig") + .callsFake(function () { + return { url: blobURL, options: { type: "module" } }; + }); + + let wasmBufferStub = sinon + .stub(MLEngineParent, "getWasmArrayBuffer") + .returns(backendData); + + let chooseBestBackendStub = sinon + .stub(MLEngineParent, "chooseBestBackend") + .returns(BACKENDS.wllama); + + try { + await Assert.rejects( + createEngine({ + engineId: "main", + taskName: "real-wllama-text-generation", + featureId: "link-preview", + backend: BACKENDS.bestLlama, + modelId: "acme/bert", + modelHubUrlTemplate: "{model}/resolve/{revision}", + modelRevision: "v0.4", + modelHubRootUrl: + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", + modelFile: "onnx/config.json", + }), + /BackendData not equal Received:/, + "The call should be rejected because it used the wrong backend" + ); + } finally { + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + wasmBufferStub.restore(); + promiseStub.restore(); + chooseBestBackendStub.restore(); + } +}); + +/** + * End to End test that the engine is indeed initialized with llama.cpp when it is the + * best-llama. + */ +add_task(async function test_e2e_choose_backend_best_llamma_cpp() { + // Allow any url + Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); + + const backendData = new Uint8Array([10, 20, 30]); + + const expectedBackendData = JSON.stringify(null); + + // Mocking function used in the workers or child doesn't work. + // So we are stubbing the code run by the worker. + const workerCode = ` + // Inject the MLEngine.worker.mjs code + + ${await getMLEngineWorkerCode()} + + // Stub + ChromeUtils.defineESModuleGetters( + lazy, + { + createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", + + }, + { global: "current" } +); + + // Change the getBackend to a mocked version that doesn't actually do inference + // but does initiate model downloads and engine initialization + + lazy.getBackend = async function ( + mlEngineWorker, + backendData, + { + modelHubUrlTemplate, + modelHubRootUrl, + modelId, + modelRevision, + modelFile, + engineId, + } = {} + ) { + + + const receivedBackendData = JSON.stringify(backendData); + if (receivedBackendData !== '${expectedBackendData}'){ + throw new Error("BackendData not equal Received: ".concat(receivedBackendData, " Expected: ", '${expectedBackendData}')); + } + + return { + run: () => {}, + }; + }; +`; + + const blob = new Blob([workerCode], { type: "application/javascript" }); + const blobURL = URL.createObjectURL(blob); + + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + + let promiseStub = sinon + .stub(MLEngineParent, "getWorkerConfig") + .callsFake(function () { + return { url: blobURL, options: { type: "module" } }; + }); + + let wasmBufferStub = sinon + .stub(MLEngineParent, "getWasmArrayBuffer") + .returns(backendData); + + let chooseBestBackendStub = sinon + .stub(MLEngineParent, "chooseBestBackend") + .returns(BACKENDS.llamaCpp); + + try { + await createEngine({ + engineId: "main", + taskName: "real-wllama-text-generation", + featureId: "link-preview", + backend: BACKENDS.bestLlama, + modelId: "acme/bert", + modelHubUrlTemplate: "{model}/resolve/{revision}", + modelRevision: "v0.4", + modelHubRootUrl: + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", + modelFile: "onnx/config.json", + }); + } finally { + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + wasmBufferStub.restore(); + promiseStub.restore(); + chooseBestBackendStub.restore(); + } +}); + +/** + * End to End test that the engine can be cancelled. + */ +add_task(async function test_e2e_engine_can_be_cancelled() { + // Allow any url + Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); + + const backendData = new Uint8Array([10, 20, 30]); + + // Mocking function used in the workers or child doesn't work. + // So we are stubbing the code run by the worker. + const workerCode = ` + // Inject the MLEngine.worker.mjs code + + ${await getMLEngineWorkerCode()} + + // Stub + ChromeUtils.defineESModuleGetters( + lazy, + { + createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", + + }, + { global: "current" } +); + + // Change the getBackend to a mocked version that doesn't actually do inference + // but does initiate model downloads and engine initialization + + lazy.getBackend = async function ( + mlEngineWorker, + backendData, + { + modelHubUrlTemplate, + modelHubRootUrl, + modelId, + modelRevision, + modelFile, + engineId, + } = {} + ) { + + const url = lazy.createFileUrl({ + model: modelId, + revision: modelRevision, + file: modelFile, + urlTemplate: modelHubUrlTemplate, + rootUrl: modelHubRootUrl, + }); + + await mlEngineWorker.getModelFile({url}); + + return { + run: () => {}, + }; + }; +`; + + const blob = new Blob([workerCode], { type: "application/javascript" }); + const blobURL = URL.createObjectURL(blob); + + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + + let promiseStub = sinon + .stub(MLEngineParent, "getWorkerConfig") + .callsFake(function () { + return { url: blobURL, options: { type: "module" } }; + }); + + let wasmBufferStub = sinon + .stub(MLEngineParent, "getWasmArrayBuffer") + .returns(backendData); + + const controller = new AbortController(); + const { signal } = controller; + controller.abort(); + + try { + await Assert.rejects( + createEngine( + { + engineId: "main5", + taskName: "real-wllama-text-generation", + featureId: "link-preview", + backend: BACKENDS.llamaCpp, + modelId: "acme/bert", + modelHubUrlTemplate: "{model}/resolve/{revision}", + modelRevision: "v0.1", + modelHubRootUrl: + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", + modelFile: "onnx/config.json", + }, + null, + signal + ), + /AbortError:/, + "The call should be cancelled" + ); + } catch (err) { + Assert.ok(false, `Expected AbortError. Got ${err}`); + } finally { + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + wasmBufferStub.restore(); + promiseStub.restore(); + } +}); + +/** + * End to End test that the engine can be cancelled after fetch success. + */ +add_task(async function test_e2e_engine_can_be_cancelled_after_fetch() { + // Allow any url + Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); + + const backendData = new Uint8Array([10, 20, 30]); + + // Mocking function used in the workers or child doesn't work. + // So we are stubbing the code run by the worker. + const workerCode = ` + // Inject the MLEngine.worker.mjs code + + ${await getMLEngineWorkerCode()} + + // Stub + ChromeUtils.defineESModuleGetters( + lazy, + { + createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", + + }, + { global: "current" } +); + + // Change the getBackend to a mocked version that doesn't actually do inference + // but does initiate model downloads and engine initialization + + lazy.getBackend = async function ( + mlEngineWorker, + backendData, + { + modelHubUrlTemplate, + modelHubRootUrl, + modelId, + modelRevision, + modelFile, + engineId, + } = {} + ) { + + const url = lazy.createFileUrl({ + model: modelId, + revision: modelRevision, + file: modelFile, + urlTemplate: modelHubUrlTemplate, + rootUrl: modelHubRootUrl, + }); + + await mlEngineWorker.getModelFile({url}); + + return { + run: () => {}, + }; + }; +`; + + const blob = new Blob([workerCode], { type: "application/javascript" }); + const blobURL = URL.createObjectURL(blob); + + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + + let promiseStub = sinon + .stub(MLEngineParent, "getWorkerConfig") + .callsFake(function () { + return { url: blobURL, options: { type: "module" } }; + }); + + let wasmBufferStub = sinon + .stub(MLEngineParent, "getWasmArrayBuffer") + .returns(backendData); + + const controller = new AbortController(); + const { signal } = controller; + + const fetchUrlStub = sinon + .stub(MLUtils, "fetchUrl") + .callsFake((url, { signal: _, ...rest } = {}) => { + const p = fetch(url, rest); + + controller.abort(); + + return p; + }); + + try { + await Assert.rejects( + createEngine( + { + engineId: "main5", + taskName: "real-wllama-text-generation", + featureId: "link-preview", + backend: BACKENDS.llamaCpp, + modelId: "acme/bert", + modelHubUrlTemplate: "{model}/resolve/{revision}", + modelRevision: "v0.1", + modelHubRootUrl: + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", + modelFile: "onnx/config.json", + }, + null, + signal + ), + /AbortError:/, + "The call should be cancelled" + ); + } catch (err) { + Assert.ok(false, `Expected AbortError. Got ${err}`); + } finally { + await EngineProcess.destroyMLEngine(); + await IndexedDBCache.init({ reset: true }); + wasmBufferStub.restore(); + promiseStub.restore(); + fetchUrlStub.restore(); + } +}); + +add_task(async function test_ml_engine_basics() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine(RAW_PIPELINE_OPTIONS); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + const res = await inferencePromise; + Assert.equal( + res.output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal(res.output.dtype, "q8", "The config was enriched by RS"); + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +add_task(async function test_ml_engine_pick_feature_id() { + // one record sent back from RS contains featureId + const records = [ + { + taskName: "moz-echo", + modelId: "mozilla/distilvit", + processorId: "mozilla/distilvit", + tokenizerId: "mozilla/distilvit", + modelRevision: "main", + processorRevision: "main", + tokenizerRevision: "main", + dtype: "q8", + id: "74a71cfd-1734-44e6-85c0-69cf3e874138", + }, + { + featureId: "pdfjs-alt-text", + taskName: "moz-echo", + modelId: "mozilla/distilvit", + processorId: "mozilla/distilvit", + tokenizerId: "mozilla/distilvit", + modelRevision: "v1.0", + processorRevision: "v1.0", + tokenizerRevision: "v1.0", + dtype: "fp16", + id: "74a71cfd-1734-44e6-85c0-69cf3e874138", + }, + ]; + + const { cleanup, remoteClients } = await setup({ records }); + + info("Get the engine"); + const engineInstance = await createEngine({ + featureId: "pdfjs-alt-text", + taskName: "moz-echo", + }); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + const res = await inferencePromise; + Assert.equal( + res.output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + res.output.dtype, + "fp16", + "The config was enriched by RS - using a feature Id" + ); + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +add_task(async function test_ml_engine_wasm_rejection() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine(RAW_PIPELINE_OPTIONS); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].rejectPendingDownloads(1); + //await remoteClients.models.resolvePendingDownloads(1); + + let error; + try { + await inferencePromise; + } catch (e) { + error = e; + } + + is( + error?.message, + "Intentionally rejecting downloads.", + "The error is correctly surfaced." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Tests that the engineInstanceModel's internal errors are correctly surfaced. + */ +add_task(async function test_ml_engine_model_error() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine(RAW_PIPELINE_OPTIONS); + + info("Run the inference with a throwing example."); + const inferencePromise = engineInstance.run("throw"); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + //await remoteClients.models.resolvePendingDownloads(1); + + let error; + try { + await inferencePromise; + } catch (e) { + error = e; + } + is( + error?.message, + 'Error: Received the message "throw", so intentionally throwing an error.', + "The error is correctly surfaced." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * This test is really similar to the "basic" test, but tests manually destroying + * the engineInstance. + */ +add_task(async function test_ml_engine_destruction() { + const { cleanup, remoteClients } = await setup(); + + info("Get engineInstance"); + const engineInstance = await createEngine(PIPELINE_OPTIONS); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await engineInstance.terminate( + /* shutDownIfEmpty */ true, + /* replacement */ false + ); + + info( + "The engineInstance is manually destroyed. The cleanup function should wait for the engine process to be destroyed." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Tests that we display a nice error message when the pref is off + */ +add_task(async function test_pref_is_off() { + await SpecialPowers.pushPrefEnv({ + set: [["browser.ml.enable", false]], + }); + + info("Get the engine process"); + let error; + + try { + await EngineProcess.getMLEngineParent(); + } catch (e) { + error = e; + } + is( + error?.message, + "MLEngine is disabled. Check the browser.ml prefs.", + "The error is correctly surfaced." + ); + + await SpecialPowers.pushPrefEnv({ + set: [["browser.ml.enable", true]], + }); +}); + +/** + * Tests the generic pipeline API + */ +add_task(async function test_ml_generic_pipeline() { + const { cleanup, remoteClients } = await setup(); + + info("Get engineInstance"); + + const options = new PipelineOptions({ + taskName: "summarization", + modelId: "test-echo", + modelRevision: "main", + }); + + const engineInstance = await createEngine(options); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ + args: ["This gets echoed."], + }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Tests that the engine is reused. + */ +add_task(async function test_ml_engine_reuse_same() { + const { cleanup, remoteClients } = await setup(); + + const options = { taskName: "moz-echo", engineId: "echo" }; + const engineInstance = await createEngine(options); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + let engineInstance2 = await createEngine(options); + is(engineInstance2.engineId, "echo", "The engine ID matches"); + is(engineInstance, engineInstance2, "The engine is reused."); + const inferencePromise2 = engineInstance2.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise2).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Tests that we can have two competing engines + */ +add_task(async function test_ml_two_engines() { + const { cleanup, remoteClients } = await setup(); + + const engineInstance = await createEngine({ + taskName: "moz-echo", + engineId: "engine1", + }); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + let engineInstance2 = await createEngine({ + taskName: "moz-echo", + engineId: "engine2", + }); + + const inferencePromise2 = engineInstance2.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise2).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.notEqual( + engineInstance.engineId, + engineInstance2.engineId, + "Should be different engines" + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Tests that we can have the same engine reinitialized + */ +add_task(async function test_ml_dupe_engines() { + const { cleanup, remoteClients } = await setup(); + + const engineInstance = await createEngine({ + taskName: "moz-echo", + engineId: "engine1", + }); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + let engineInstance2 = await createEngine({ + taskName: "moz-echo", + engineId: "engine1", + timeoutMS: 2000, // that makes the options different + }); + const inferencePromise2 = engineInstance2.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise2).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.notEqual( + engineInstance, + engineInstance2, + "Should be different engines" + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +add_task(async function test_ml_engine_override_options() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine({ + taskName: "moz-echo", + modelRevision: "v1", + }); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + (await inferencePromise).output.modelRevision, + "v1", + "The config options goes through and overrides." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Tests a custom model hub + */ +add_task(async function test_ml_custom_hub() { + const { cleanup, remoteClients } = await setup(); + + info("Get engineInstance"); + + const options = new PipelineOptions({ + taskName: "summarization", + modelId: "test-echo", + modelRevision: "main", + modelHubRootUrl: "https://example.com", + modelHubUrlTemplate: "models/{model}/{revision}", + }); + + const engineInstance = await createEngine(options); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ + args: ["This gets echoed."], + }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + let res = await inferencePromise; + + Assert.equal( + res.output, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + res.config.modelHubRootUrl, + "https://example.com", + "The pipeline used the custom hub" + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Make sure we don't get race conditions when running several inference runs in parallel + * + */ + +add_task(async function test_ml_engine_parallel() { + const { cleanup, remoteClients } = await setup(); + + // We're doing 10 calls and each echo call will take from 0 to 1000ms + // So we're sure we're mixing runs. + let sleepTimes = [300, 1000, 700, 0, 500, 900, 400, 800, 600, 100]; + let numCalls = 10; + + async function run(x) { + const engineInstance = await createEngine(RAW_PIPELINE_OPTIONS); + + let msg = `${x} - This gets echoed.`; + let res = engineInstance.run({ + data: msg, + sleepTime: sleepTimes[x], + }); + + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + res = await res; + + return res; + } + + info(`Run ${numCalls} inferences in parallel`); + let runs = []; + for (let x = 0; x < numCalls; x++) { + runs.push(run(x)); + } + + // await all runs + const results = await Promise.all(runs); + Assert.equal(results.length, numCalls, `All ${numCalls} were successful`); + + // check that each one got their own stuff + for (let y = 0; y < numCalls; y++) { + Assert.equal( + results[y].output.echo, + `${y} - This gets echoed.`, + `Result ${y} is correct` + ); + } + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +/** + * Test threading support + */ +add_task(async function test_ml_threading_support() { + const { cleanup, remoteClients } = await setup(); + + info("Get engineInstance"); + + const options = new PipelineOptions({ + taskName: "summarization", + modelId: "test-echo", + modelRevision: "main", + }); + + const engineInstance = await createEngine(options); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ + args: ["This gets echoed."], + }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + let res = await inferencePromise; + + ok(res.multiThreadSupported, "Multi-thread should be supported"); + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +add_task(async function test_ml_engine_get_status() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine({ taskName: "moz-echo" }); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + const res = await inferencePromise; + Assert.equal( + res.output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + const expected = { + "default-engine": { + status: "IDLING", + options: { + useExternalDataFormat: false, + engineId: "default-engine", + featureId: null, + taskName: "moz-echo", + timeoutMS: 1000, + modelHubRootUrl: + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", + modelHubUrlTemplate: "{model}/{revision}", + modelId: "mozilla/distilvit", + modelRevision: "main", + tokenizerId: "mozilla/distilvit", + tokenizerRevision: "main", + processorId: "mozilla/distilvit", + processorRevision: "main", + logLevel: "All", + runtimeFilename: "ort-wasm-simd-threaded.jsep.wasm", + staticEmbeddingsOptions: null, + device: null, + dtype: "q8", + numThreads: "NOT_COMPARED", + executionPriority: null, + kvCacheDtype: null, + numContext: 1024, + numBatch: 1024, + numUbatch: 1024, + flashAttn: false, + useMmap: false, + useMlock: true, + numThreadsDecoding: null, + modelFile: null, + backend: null, + modelHub: null, + baseURL: null, + apiKey: null, + }, + engineId: "default-engine", + }, + }; + + let status = await engineInstance.mlEngineParent.getStatus(); + status = JSON.parse(JSON.stringify(Object.fromEntries(status))); + + status["default-engine"].options.numThreads = "NOT_COMPARED"; + Assert.deepEqual(status, expected); + + await ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +add_task(async function test_ml_engine_not_enough_memory() { + const { cleanup } = await setup({ + prefs: [ + ["browser.ml.checkForMemory", true], + ["browser.ml.minimumPhysicalMemory", 99999], + ], + }); + + info("Get the greedy engine"); + + await Assert.rejects( + createEngine({ + modelId: "testing/greedy", + taskName: "moz-echo", + dtype: "q8", + numThreads: 1, + device: "wasm", + }), + /Not enough physical memory/, + "The call should be rejected because of a lack of memory" + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +/** + * Helper function to create a basic set of valid options + */ +function getValidOptions(overrides = {}) { + return Object.assign( + { + engineId: "validEngine1", + featureId: "pdfjs-alt-text", + taskName: "valid_task", + modelHubRootUrl: "https://example.com", + modelHubUrlTemplate: "https://example.com/{modelId}", + timeoutMS: 5000, + modelId: "validModel", + modelRevision: "v1", + tokenizerId: "validTokenizer", + tokenizerRevision: "v1", + processorId: "validProcessor", + processorRevision: "v1", + logLevel: null, + runtimeFilename: "runtime.wasm", + device: InferenceDevice.GPU, + numThreads: 4, + executionPriority: ExecutionPriority.NORMAL, + }, + overrides + ); +} + +/** + * A collection of test cases for invalid and valid values. + */ +const commonInvalidCases = [ + { description: "Invalid value (special characters)", value: "org1/my!value" }, + { + description: "Invalid value (special characters in organization)", + value: "org@1/my-value", + }, + { description: "Invalid value (missing name part)", value: "org1/" }, + { + description: "Invalid value (invalid characters in name)", + value: "my$value", + }, +]; + +const commonValidCases = [ + { description: "Valid organization/name", value: "org1/my-value" }, + { description: "Valid name only", value: "my-value" }, + { + description: "Valid name with underscores and dashes", + value: "my_value-123", + }, + { + description: "Valid organization with underscores and dashes", + value: "org_123/my-value", + }, +]; + +const pipelineOptionsCases = [ + // Invalid cases for various fields + ...commonInvalidCases.map(test => ({ + description: `Invalid processorId (${test.description})`, + options: { processorId: test.value }, + expectedError: /Invalid value/, + })), + ...commonInvalidCases.map(test => ({ + description: `Invalid tokenizerId (${test.description})`, + options: { tokenizerId: test.value }, + expectedError: /Invalid value/, + })), + ...commonInvalidCases.map(test => ({ + description: `Invalid modelId (${test.description})`, + options: { modelId: test.value }, + expectedError: /Invalid value/, + })), + + // Valid cases for various fields + ...commonValidCases.map(test => ({ + description: `Valid processorId (${test.description})`, + options: { processorId: test.value }, + expected: { processorId: test.value }, + })), + ...commonValidCases.map(test => ({ + description: `Valid tokenizerId (${test.description})`, + options: { tokenizerId: test.value }, + expected: { tokenizerId: test.value }, + })), + ...commonValidCases.map(test => ({ + description: `Valid modelId (${test.description})`, + options: { modelId: test.value }, + expected: { modelId: test.value }, + })), + + // Invalid values + { + description: "Invalid hub", + options: { modelHub: "rogue" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid timeoutMS", + options: { timeoutMS: -3 }, + expectedError: /Invalid value/, + }, + { + description: "Invalid timeoutMS", + options: { timeoutMS: 40000000 }, + expectedError: /Invalid value/, + }, + { + description: "Invalid featureId", + options: { featureId: "unknown" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid dtype", + options: { dtype: "invalid_dtype" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid device", + options: { device: "invalid_device" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid executionPriority", + options: { executionPriority: "invalid_priority" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid logLevel", + options: { logLevel: "invalid_log_level" }, + expectedError: /Invalid value/, + }, + + // Valid values + { + description: "valid hub", + options: { modelHub: "huggingface" }, + expected: { modelHub: "huggingface" }, + }, + { + description: "valid hub", + options: { modelHub: "mozilla" }, + expected: { modelHub: "mozilla" }, + }, + { + description: "valid timeoutMS", + options: { timeoutMS: 12345 }, + expected: { timeoutMS: 12345 }, + }, + { + description: "valid timeoutMS", + options: { timeoutMS: -1 }, + expected: { timeoutMS: -1 }, + }, + + { + description: "Valid dtype", + options: { dtype: QuantizationLevel.FP16 }, + expected: { dtype: QuantizationLevel.FP16 }, + }, + { + description: "Valid device", + options: { device: InferenceDevice.WASM }, + expected: { device: InferenceDevice.WASM }, + }, + { + description: "Valid executionPriority", + options: { executionPriority: ExecutionPriority.HIGH }, + expected: { executionPriority: ExecutionPriority.HIGH }, + }, + { + description: "Valid logLevel (Info)", + options: { logLevel: LogLevel.INFO }, + expected: { logLevel: LogLevel.INFO }, + }, + { + description: "Valid logLevel (Critical)", + options: { logLevel: LogLevel.CRITICAL }, + expected: { logLevel: LogLevel.CRITICAL }, + }, + { + description: "Valid logLevel (All)", + options: { logLevel: LogLevel.ALL }, + expected: { logLevel: LogLevel.ALL }, + }, + { + description: "Valid modelId", + options: { modelId: "Qwen2.5-0.5B-Instruct" }, + expected: { modelId: "Qwen2.5-0.5B-Instruct" }, + }, + + // Invalid revision cases + { + description: "Invalid revision (random string)", + options: { modelRevision: "invalid_revision" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid revision (too many version numbers)", + options: { tokenizerRevision: "v1.0.3.4.5" }, + expectedError: /Invalid value/, + }, + { + description: "Invalid revision (unknown suffix)", + options: { processorRevision: "v1.0.0-unknown" }, + expectedError: /Invalid value/, + }, + + // Valid revision cases with new format + { + description: "Valid revision (main)", + options: { modelRevision: "main" }, + expected: { modelRevision: "main" }, + }, + { + description: "Valid revision (v-prefixed version with alpha)", + options: { tokenizerRevision: "v1.2.3-alpha1" }, + expected: { tokenizerRevision: "v1.2.3-alpha1" }, + }, + { + description: + "Valid revision (v-prefixed version with beta and dot separator)", + options: { tokenizerRevision: "v1.2.3.beta2" }, + expected: { tokenizerRevision: "v1.2.3.beta2" }, + }, + { + description: + "Valid revision (non-prefixed version with rc and dash separator)", + options: { processorRevision: "1.0.0-rc3" }, + expected: { processorRevision: "1.0.0-rc3" }, + }, + { + description: + "Valid revision (non-prefixed version with pre and dot separator)", + options: { processorRevision: "1.0.0.pre4" }, + expected: { processorRevision: "1.0.0.pre4" }, + }, + { + description: "Valid revision (version without suffix)", + options: { modelRevision: "1.0.0" }, + expected: { modelRevision: "1.0.0" }, + }, + + // Valid engineID cases + { + description: "Valid engineID (qwen)", + options: { engineId: "SUM-ONNX-COMMUNITY_QWEN2_5-0_5B-INSTRUCT_BIG" }, + expected: { engineId: "SUM-ONNX-COMMUNITY_QWEN2_5-0_5B-INSTRUCT_BIG" }, + }, +]; + +/** + * Testing PipelineOption validation + */ +add_task(async function test_pipeline_options_validation() { + pipelineOptionsCases.forEach(testCase => { + if (testCase.expectedError) { + Assert.throws( + () => new PipelineOptions(getValidOptions(testCase.options)), + testCase.expectedError, + `${testCase.description} throws the expected error` + ); + } else { + const pipelineOptions = new PipelineOptions( + getValidOptions(testCase.options) + ); + Object.keys(testCase.expected).forEach(key => { + is( + pipelineOptions[key], + testCase.expected[key], + `${testCase.description} sets ${key} correctly` + ); + }); + } + }); +}); + +add_task(async function test_ml_engine_infinite_worker() { + const { cleanup, remoteClients } = await setup(); + + const options = { taskName: "moz-echo", timeoutMS: -1 }; + const engineInstance = await createEngine(options); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + const res = await inferencePromise; + Assert.equal( + res.output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal(res.output.timeoutMS, -1, "This should be an infinite worker."); + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +add_task(async function test_ml_engine_model_hub_applied() { + const options = { + taskName: "moz-echo", + timeoutMS: -1, + modelHub: "huggingface", + }; + const parsedOptions = new PipelineOptions(options); + + Assert.equal( + parsedOptions.modelHubRootUrl, + "https://huggingface.co/", + "modelHubRootUrl is set" + ); + + Assert.equal( + parsedOptions.modelHubUrlTemplate, + "{model}/resolve/{revision}", + "modelHubUrlTemplate is set" + ); +}); + +add_task(async function test_ml_engine_blessed_model() { + const { cleanup, remoteClients } = await setup(); + + const options = { taskName: "test-echo" }; + const engineInstance = await createEngine(options); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + const res = await inferencePromise; + + Assert.equal( + res.config.modelId, + "test-echo", + "The blessed model was picked." + ); + + Assert.equal(res.config.dtype, "q8", "With the right quantization level"); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +add_task(async function test_ml_engine_two_tasknames_in_rs() { + // RS has two records with the same taskName + // we should use the modelId match in that case + const records = [ + { + taskName: "moz-echo", + modelId: "mozilla/anothermodel", + processorId: "mozilla/distilvit", + tokenizerId: "mozilla/distilvit", + modelRevision: "main", + processorRevision: "main", + tokenizerRevision: "main", + dtype: "q8", + id: "74a71cfd-1734-44e6-85c0-69cf3e874138", + }, + { + taskName: "moz-echo", + modelId: "mozilla/distilvit", + processorId: "mozilla/distilvit", + tokenizerId: "mozilla/distilvit", + modelRevision: "v1.0", + processorRevision: "v1.0", + tokenizerRevision: "v1.0", + dtype: "fp16", + id: "74a71cfd-1734-44e6-85c0-69cf3e874138", + }, + ]; + + const { cleanup, remoteClients } = await setup({ records }); + + info("Get the engine"); + const engineInstance = await createEngine({ + featureId: "pdfjs-alt-text", + taskName: "moz-echo", + }); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + const res = await inferencePromise; + Assert.equal( + res.output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + res.output.dtype, + "fp16", + "The config was enriched by RS - using a feature Id" + ); + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + + await cleanup(); +}); + +add_task( + async function test_override_ml_engine_pipeline_options_in_allow_list() { + const { cleanup, remoteClients } = await setup(); + await SpecialPowers.pushPrefEnv({ + set: [ + [ + "browser.ml.overridePipelineOptions", + '{"about-inference": {"modelRevision": "v0.2.0"}}', + ], + ], + }); + + info("Get the engine"); + const engineInstance = await createEngine({ + taskName: "moz-echo", + featureId: "about-inference", + }); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + (await inferencePromise).output.modelRevision, + "v0.2.0", + "The config options goes through and overrides." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); + } +); + +add_task(async function test_override_ml_pipeline_options_not_in_allow_list() { + const { cleanup, remoteClients } = await setup(); + await SpecialPowers.pushPrefEnv({ + set: [ + [ + "browser.ml.overridePipelineOptions", + '{"about-inferences": {"modelRevision": "v0.2.0"}}', + ], + ], + }); + + info("Get the engine"); + const engineInstance = await createEngine({ + taskName: "moz-echo", + featureId: "about-inference", + }); + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + (await inferencePromise).output.modelRevision, + "main", + "The config options goes through and overrides." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +add_task(async function test_override_ml_pipeline_options_unsafe_options() { + const { cleanup, remoteClients } = await setup(); + await SpecialPowers.pushPrefEnv({ + set: [ + [ + "browser.ml.overridePipelineOptions", + '{"about-inference": {"modelRevision": "v0.2.0", "modelId": "unsafe-model-id"}}', + ], + ], + }); + + info("Get the engine"); + const engineInstance = await createEngine({ + taskName: "moz-echo", + featureId: "about-inference", + }); + + info("Check the inference process is running"); + Assert.equal(await checkForRemoteType("inference"), true); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text get echoed exercising the whole flow." + ); + + Assert.equal( + (await inferencePromise).output.modelRevision, + "v0.2.0", + "The config options goes through and overrides." + ); + + Assert.equal( + (await inferencePromise).output.modelId, + "mozilla/distilvit", + "The config should not override." + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +add_task(async function test_q8_by_default() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine({ + taskName: "moz-echo", + modelId: "Xenova/distilbart-cnn-6-6", + modelHub: "huggingface", + }); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text gets echoed exercising the whole flow." + ); + + Assert.equal( + (await inferencePromise).output.dtype, + "q8", + "dtype should be set to q8" + ); + + // the model hub sets the revision + Assert.equal( + (await inferencePromise).output.modelRevision, + "main", + "modelRevision should be main" + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +add_task(async function test_hub_by_default() { + const { cleanup, remoteClients } = await setup(); + + info("Get the engine"); + const engineInstance = await createEngine({ + taskName: "moz-echo", + }); + + info("Run the inference"); + const inferencePromise = engineInstance.run({ data: "This gets echoed." }); + + info("Wait for the pending downloads."); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + + Assert.equal( + (await inferencePromise).output.echo, + "This gets echoed.", + "The text gets echoed exercising the whole flow." + ); + + Assert.equal( + (await inferencePromise).output.modelHubUrlTemplate, + "{model}/{revision}", + "Default template should be model/revision" + ); + + Assert.equal( + (await inferencePromise).output.modelRevision, + "main", + "modelRevision should be main" + ); + + ok( + !EngineProcess.areAllEnginesTerminated(), + "The engine process is still active." + ); + + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +add_task(async function test_openai_client() { + const records = [ + { + featureId: "about-inference", + taskName: "text-generation", + modelId: "qwen3:0.6b", + modelRevision: "main", + id: "74a71cfd-1734-44e6-85c0-69cf3e874138", + }, + ]; + const { cleanup } = await setup({ records }); + const { server: mockServer, port } = startMockOpenAI({ + echo: "This gets echoed.", + }); + + const engineInstance = await createEngine({ + featureId: "about-inference", + task: "text-generation", + modelId: "qwen3:0.6b", + modelRevision: "main", + apiKey: "ollama", + baseURL: `http://localhost:${port}/v1`, + backend: "openai", + }); + + const request = { + args: [ + { + role: "system", + content: + "You are a helpful assistant that summarizes text clearly and concisely.", + }, + { + role: "user", + content: `Please summarize the following text:\n\n blah bla`, + }, + ], + }; + + try { + info("Run the inference"); + const inferencePromise = engineInstance.run(request); + + const result = await inferencePromise; + + Assert.equal( + result.finalOutput, + "This is a mock summary for testing end-to-end flow." + ); + } finally { + await EngineProcess.destroyMLEngine(); + await cleanup(); + await stopMockOpenAI(mockServer); + } +}); + +add_task(async function test_openai_client_tools_non_streaming() { + const records = [ + { + ...BASE_ENGINE_OPTIONS, + id: "74a71cfd-1734-44e6-85c0-69cf3e874138", + }, + ]; + const { cleanup } = await setup({ records }); + const { server: mockServer, port } = startMockOpenAI(); + + const engineInstance = await createEngine({ + ...BASE_ENGINE_OPTIONS, + apiKey: "ollama", + baseURL: `http://localhost:${port}/v1`, + backend: "openai", + }); + + // First request: ask with tools; server responds with tool_calls + const requestWithTools = { + args: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "Find my open news tabs." }, + ], + tools: SHARED_TOOLS, + }; + + try { + info("Run request that triggers tool calls"); + const result1 = await engineInstance.run(requestWithTools); + + // The pipeline should surface toolCalls from the OpenAI message + Assert.ok(result1.toolCalls, "toolCalls should exist on the result"); + Assert.equal(result1.toolCalls.length, 1, "Exactly one tool call"); + Assert.equal( + result1.toolCalls[0].function.name, + "search_open_tabs", + "Tool name should match" + ); + + // Second request: append assistant tool_calls + our tool result + const assistantToolCallsMsg = { + role: "assistant", + tool_calls: result1.toolCalls.map(tc => ({ + id: tc.id, + type: "function", + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }; + + const toolResultMsg = { + role: "tool", + tool_call_id: result1.toolCalls[0].id, + content: JSON.stringify({ query: "news", allTabs: [] }), + }; + + const followup = await engineInstance.run({ + args: [...requestWithTools.args, assistantToolCallsMsg, toolResultMsg], + tools: requestWithTools.tools, // still valid to include + }); + + Assert.equal( + followup.finalOutput, + "Here are the tabs I found for you.", + "Should get assistant follow-up after tool result" + ); + } finally { + await EngineProcess.destroyMLEngine(); + await cleanup(); + await stopMockOpenAI(mockServer); + } +}); + +add_task(async function test_openai_client_tools_streaming() { + const records = [ + { + ...BASE_ENGINE_OPTIONS, + id: "b3b2b661-daa6-4b7f-8d3c-7db0df0dbeef", + }, + ]; + const { cleanup } = await setup({ records }); + const { server: mockServer, port } = startMockOpenAI(); + + const engineInstance = await createEngine({ + ...BASE_ENGINE_OPTIONS, + apiKey: "ollama", + baseURL: `http://localhost:${port}/v1`, + backend: "openai", + }); + + const starter = { + args: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "Find my open news tabs." }, + ], + tools: SHARED_TOOLS, + streamOptions: { enabled: true }, + }; + + try { + // --- First turn: expect tool_calls via streaming --- + const gen = engineInstance.runWithGenerator(starter); + + let toolCalls = null; + for await (const chunk of gen) { + // Your MLEngineParent + OpenAIPipeline put toolCalls onto the yielded chunk + if (chunk.toolCalls && chunk.toolCalls.length) { + toolCalls = chunk.toolCalls; + break; // we end the turn when model asks for tools + } + // (Optional) you could accumulate chunk.text here; expected empty in this turn + } + + Assert.ok(toolCalls, "Should receive toolCalls via streaming"); + Assert.equal(toolCalls.length, 1, "One tool call"); + Assert.equal( + toolCalls[0].function.name, + "search_open_tabs", + "Tool name should match" + ); + + // --- Second turn: send tool result, stream final answer --- + const assistantToolCallsMsg = { + role: "assistant", + tool_calls: toolCalls.map(tc => ({ + id: tc.id, + type: "function", + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }; + + const toolResultMsg = { + role: "tool", + tool_call_id: toolCalls[0].id, + content: JSON.stringify({ query: "news", allTabs: [] }), + }; + + const gen2 = engineInstance.runWithGenerator({ + args: [...starter.args, assistantToolCallsMsg, toolResultMsg], + tools: SHARED_TOOLS, + streamOptions: { enabled: true }, + }); + + let final = ""; + for await (const chunk of gen2) { + if (chunk.text) { + final += chunk.text; + } + } + + Assert.ok(final.length, "Should stream some final content"); + Assert.equal( + final, + "Here are the tabs I found for you.", + "Should stream the expected assistant follow-up" + ); + } finally { + await EngineProcess.destroyMLEngine(); + await cleanup(); + await stopMockOpenAI(mockServer); + } +}); diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine_e2e.js b/toolkit/components/ml/tests/browser/browser_ml_engine_e2e.js @@ -1,544 +0,0 @@ -/* Any copyright is dedicated to the Public Domain. - https://creativecommons.org/publicdomain/zero/1.0/ */ - -"use strict"; - -const { sinon } = ChromeUtils.importESModule( - "resource://testing-common/Sinon.sys.mjs" -); - -const { BACKENDS } = ChromeUtils.importESModule( - "chrome://global/content/ml/EngineProcess.sys.mjs" -); - -const { MLUtils } = ChromeUtils.importESModule( - "chrome://global/content/ml/Utils.sys.mjs" -); - -/** - * End to End test that the engine is indeed initialized with wllama when it is the - * best-llama. - */ -add_task(async function test_e2e_choose_backend_best_wllama() { - // Allow any url - Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); - - const backendData = new Uint8Array([10, 20, 30]); - - const expectedBackendData = JSON.stringify(backendData); - - // Mocking function used in the workers or child doesn't work. - // So we are stubbing the code run by the worker. - const workerCode = ` - // Inject the MLEngine.worker.mjs code - - ${await getMLEngineWorkerCode()} - - // Stub - ChromeUtils.defineESModuleGetters( - lazy, - { - createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", - - }, - { global: "current" } -); - - // Change the getBackend to a mocked version that doesn't actually do inference - // but does initiate model downloads and engine initialization - - lazy.getBackend = async function ( - mlEngineWorker, - backendData, - { - modelHubUrlTemplate, - modelHubRootUrl, - modelId, - modelRevision, - modelFile, - engineId, - } = {} - ) { - - - const receivedBackendData = JSON.stringify(backendData); - if (receivedBackendData !== '${expectedBackendData}'){ - throw new Error("BackendData not equal Received: ".concat(receivedBackendData, " Expected: ", '${expectedBackendData}')); - } - - return { - run: () => {}, - }; - }; -`; - - const blob = new Blob([workerCode], { type: "application/javascript" }); - const blobURL = URL.createObjectURL(blob); - - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - - let promiseStub = sinon - .stub(MLEngineParent, "getWorkerConfig") - .callsFake(function () { - return { url: blobURL, options: { type: "module" } }; - }); - - let wasmBufferStub = sinon - .stub(MLEngineParent, "getWasmArrayBuffer") - .returns(backendData); - - let chooseBestBackendStub = sinon - .stub(MLEngineParent, "chooseBestBackend") - .returns(BACKENDS.wllama); - - try { - await createEngine({ - engineId: "main", - taskName: "real-wllama-text-generation", - featureId: "link-preview", - backend: BACKENDS.bestLlama, - modelId: "acme/bert", - modelHubUrlTemplate: "{model}/resolve/{revision}", - modelRevision: "v0.4", - modelHubRootUrl: - "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", - modelFile: "onnx/config.json", - }); - } finally { - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - wasmBufferStub.restore(); - promiseStub.restore(); - chooseBestBackendStub.restore(); - } -}); - -/** - * End to End test that the engine can indeed fail if it doesn't use best-llama. - */ -add_task(async function test_e2e_choose_backend_can_detect_failure() { - // Allow any url - Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); - - const backendData = new Uint8Array([10, 20, 30]); - - const expectedBackendData = JSON.stringify("data so no matches"); - - // Mocking function used in the workers or child doesn't work. - // So we are stubbing the code run by the worker. - const workerCode = ` - // Inject the MLEngine.worker.mjs code - - ${await getMLEngineWorkerCode()} - - // Stub - ChromeUtils.defineESModuleGetters( - lazy, - { - createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", - - }, - { global: "current" } -); - - // Change the getBackend to a mocked version that doesn't actually do inference - // but does initiate model downloads and engine initialization - - lazy.getBackend = async function ( - mlEngineWorker, - backendData, - { - modelHubUrlTemplate, - modelHubRootUrl, - modelId, - modelRevision, - modelFile, - engineId, - } = {} - ) { - - - const receivedBackendData = JSON.stringify(backendData); - if (receivedBackendData !== '${expectedBackendData}'){ - throw new Error("BackendData not equal Received: ".concat(receivedBackendData, " Expected: ", '${expectedBackendData}')); - } - - return { - run: () => {}, - }; - }; -`; - - const blob = new Blob([workerCode], { type: "application/javascript" }); - const blobURL = URL.createObjectURL(blob); - - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - - let promiseStub = sinon - .stub(MLEngineParent, "getWorkerConfig") - .callsFake(function () { - return { url: blobURL, options: { type: "module" } }; - }); - - let wasmBufferStub = sinon - .stub(MLEngineParent, "getWasmArrayBuffer") - .returns(backendData); - - let chooseBestBackendStub = sinon - .stub(MLEngineParent, "chooseBestBackend") - .returns(BACKENDS.wllama); - - try { - await Assert.rejects( - createEngine({ - engineId: "main", - taskName: "real-wllama-text-generation", - featureId: "link-preview", - backend: BACKENDS.bestLlama, - modelId: "acme/bert", - modelHubUrlTemplate: "{model}/resolve/{revision}", - modelRevision: "v0.4", - modelHubRootUrl: - "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", - modelFile: "onnx/config.json", - }), - /BackendData not equal Received:/, - "The call should be rejected because it used the wrong backend" - ); - } finally { - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - wasmBufferStub.restore(); - promiseStub.restore(); - chooseBestBackendStub.restore(); - } -}); - -/** - * End to End test that the engine is indeed initialized with llama.cpp when it is the - * best-llama. - */ -add_task(async function test_e2e_choose_backend_best_llamma_cpp() { - // Allow any url - Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); - - const backendData = new Uint8Array([10, 20, 30]); - - const expectedBackendData = JSON.stringify(null); - - // Mocking function used in the workers or child doesn't work. - // So we are stubbing the code run by the worker. - const workerCode = ` - // Inject the MLEngine.worker.mjs code - - ${await getMLEngineWorkerCode()} - - // Stub - ChromeUtils.defineESModuleGetters( - lazy, - { - createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", - - }, - { global: "current" } -); - - // Change the getBackend to a mocked version that doesn't actually do inference - // but does initiate model downloads and engine initialization - - lazy.getBackend = async function ( - mlEngineWorker, - backendData, - { - modelHubUrlTemplate, - modelHubRootUrl, - modelId, - modelRevision, - modelFile, - engineId, - } = {} - ) { - - - const receivedBackendData = JSON.stringify(backendData); - if (receivedBackendData !== '${expectedBackendData}'){ - throw new Error("BackendData not equal Received: ".concat(receivedBackendData, " Expected: ", '${expectedBackendData}')); - } - - return { - run: () => {}, - }; - }; -`; - - const blob = new Blob([workerCode], { type: "application/javascript" }); - const blobURL = URL.createObjectURL(blob); - - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - - let promiseStub = sinon - .stub(MLEngineParent, "getWorkerConfig") - .callsFake(function () { - return { url: blobURL, options: { type: "module" } }; - }); - - let wasmBufferStub = sinon - .stub(MLEngineParent, "getWasmArrayBuffer") - .returns(backendData); - - let chooseBestBackendStub = sinon - .stub(MLEngineParent, "chooseBestBackend") - .returns(BACKENDS.llamaCpp); - - try { - await createEngine({ - engineId: "main", - taskName: "real-wllama-text-generation", - featureId: "link-preview", - backend: BACKENDS.bestLlama, - modelId: "acme/bert", - modelHubUrlTemplate: "{model}/resolve/{revision}", - modelRevision: "v0.4", - modelHubRootUrl: - "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", - modelFile: "onnx/config.json", - }); - } finally { - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - wasmBufferStub.restore(); - promiseStub.restore(); - chooseBestBackendStub.restore(); - } -}); - -/** - * End to End test that the engine can be cancelled. - */ -add_task(async function test_e2e_engine_can_be_cancelled() { - // Allow any url - Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); - - const backendData = new Uint8Array([10, 20, 30]); - - // Mocking function used in the workers or child doesn't work. - // So we are stubbing the code run by the worker. - const workerCode = ` - // Inject the MLEngine.worker.mjs code - - ${await getMLEngineWorkerCode()} - - // Stub - ChromeUtils.defineESModuleGetters( - lazy, - { - createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", - - }, - { global: "current" } -); - - // Change the getBackend to a mocked version that doesn't actually do inference - // but does initiate model downloads and engine initialization - - lazy.getBackend = async function ( - mlEngineWorker, - backendData, - { - modelHubUrlTemplate, - modelHubRootUrl, - modelId, - modelRevision, - modelFile, - engineId, - } = {} - ) { - - const url = lazy.createFileUrl({ - model: modelId, - revision: modelRevision, - file: modelFile, - urlTemplate: modelHubUrlTemplate, - rootUrl: modelHubRootUrl, - }); - - await mlEngineWorker.getModelFile({url}); - - return { - run: () => {}, - }; - }; -`; - - const blob = new Blob([workerCode], { type: "application/javascript" }); - const blobURL = URL.createObjectURL(blob); - - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - - let promiseStub = sinon - .stub(MLEngineParent, "getWorkerConfig") - .callsFake(function () { - return { url: blobURL, options: { type: "module" } }; - }); - - let wasmBufferStub = sinon - .stub(MLEngineParent, "getWasmArrayBuffer") - .returns(backendData); - - const controller = new AbortController(); - const { signal } = controller; - controller.abort(); - - try { - await Assert.rejects( - createEngine( - { - engineId: "main5", - taskName: "real-wllama-text-generation", - featureId: "link-preview", - backend: BACKENDS.llamaCpp, - modelId: "acme/bert", - modelHubUrlTemplate: "{model}/resolve/{revision}", - modelRevision: "v0.1", - modelHubRootUrl: - "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", - modelFile: "onnx/config.json", - }, - null, - signal - ), - /AbortError:/, - "The call should be cancelled" - ); - } catch (err) { - Assert.ok(false, `Expected AbortError. Got ${err}`); - } finally { - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - wasmBufferStub.restore(); - promiseStub.restore(); - } -}); - -/** - * End to End test that the engine can be cancelled after fetch success. - */ -add_task(async function test_e2e_engine_can_be_cancelled_after_fetch() { - // Allow any url - Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "true"); - - const backendData = new Uint8Array([10, 20, 30]); - - // Mocking function used in the workers or child doesn't work. - // So we are stubbing the code run by the worker. - const workerCode = ` - // Inject the MLEngine.worker.mjs code - - ${await getMLEngineWorkerCode()} - - // Stub - ChromeUtils.defineESModuleGetters( - lazy, - { - createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", - - }, - { global: "current" } -); - - // Change the getBackend to a mocked version that doesn't actually do inference - // but does initiate model downloads and engine initialization - - lazy.getBackend = async function ( - mlEngineWorker, - backendData, - { - modelHubUrlTemplate, - modelHubRootUrl, - modelId, - modelRevision, - modelFile, - engineId, - } = {} - ) { - - const url = lazy.createFileUrl({ - model: modelId, - revision: modelRevision, - file: modelFile, - urlTemplate: modelHubUrlTemplate, - rootUrl: modelHubRootUrl, - }); - - await mlEngineWorker.getModelFile({url}); - - return { - run: () => {}, - }; - }; -`; - - const blob = new Blob([workerCode], { type: "application/javascript" }); - const blobURL = URL.createObjectURL(blob); - - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - - let promiseStub = sinon - .stub(MLEngineParent, "getWorkerConfig") - .callsFake(function () { - return { url: blobURL, options: { type: "module" } }; - }); - - let wasmBufferStub = sinon - .stub(MLEngineParent, "getWasmArrayBuffer") - .returns(backendData); - - const controller = new AbortController(); - const { signal } = controller; - - const fetchUrlStub = sinon - .stub(MLUtils, "fetchUrl") - .callsFake((url, { signal: _, ...rest } = {}) => { - const p = fetch(url, rest); - - controller.abort(); - - return p; - }); - - try { - await Assert.rejects( - createEngine( - { - engineId: "main5", - taskName: "real-wllama-text-generation", - featureId: "link-preview", - backend: BACKENDS.llamaCpp, - modelId: "acme/bert", - modelHubUrlTemplate: "{model}/resolve/{revision}", - modelRevision: "v0.1", - modelHubRootUrl: - "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", - modelFile: "onnx/config.json", - }, - null, - signal - ), - /AbortError:/, - "The call should be cancelled" - ); - } catch (err) { - Assert.ok(false, `Expected AbortError. Got ${err}`); - } finally { - await EngineProcess.destroyMLEngine(); - await IndexedDBCache.init({ reset: true }); - wasmBufferStub.restore(); - promiseStub.restore(); - fetchUrlStub.restore(); - } -}); diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine_lifetime.js b/toolkit/components/ml/tests/browser/browser_ml_engine_lifetime.js @@ -1,520 +0,0 @@ -/* Any copyright is dedicated to the Public Domain. - https://creativecommons.org/publicdomain/zero/1.0/ */ - -"use strict"; - -const MOZ_ECHO_OPTIONS_RAW = { taskName: "moz-echo", timeoutMS: -1 }; -const MOZ_ECHO_OPTIONS = new PipelineOptions({ - taskName: "moz-echo", - timeoutMS: -1, -}); - -/** - * Performing a basic engine initialization and run. - */ -add_task(async function test_ml_engine_basics() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine(MOZ_ECHO_OPTIONS_RAW); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - const res = await inferencePromise; - Assert.equal( - res.output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal(res.output.dtype, "q8", "The config was enriched by RS"); - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); - -/** - * Test the Wasm failing to download triggering a rejection. - */ -add_task(async function test_ml_engine_wasm_rejection() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine(MOZ_ECHO_OPTIONS_RAW); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].rejectPendingDownloads(1); - - let error; - try { - await inferencePromise; - } catch (e) { - error = e; - } - - is( - error?.message, - "Intentionally rejecting downloads.", - "The error is correctly surfaced." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Make sure we don't get race conditions when running several inference runs in parallel - */ -add_task(async function test_ml_engine_parallel() { - const { cleanup, remoteClients } = await setup(); - - // We're doing 10 calls and each echo call will take from 0 to 1000ms - // So we're sure we're mixing runs. - let sleepTimes = [300, 1000, 700, 0, 500, 900, 400, 800, 600, 100]; - let numCalls = 10; - - const enginesSeen = new Set(); - async function run(x) { - const engineInstance = await createEngine(MOZ_ECHO_OPTIONS_RAW); - enginesSeen.add(engineInstance); - - let msg = `${x} - This gets echoed.`; - let res = engineInstance.run({ - data: msg, - sleepTime: sleepTimes[x], - }); - - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - res = await res; - - return res; - } - - info(`Run ${numCalls} inferences in parallel`); - let runs = []; - for (let x = 0; x < numCalls; x++) { - runs.push(run(x)); - } - - // await all runs - const results = await Promise.all(runs); - Assert.equal(results.length, numCalls, `All ${numCalls} were successful`); - - // check that each one got their own stuff - for (let y = 0; y < numCalls; y++) { - Assert.equal( - results[y].output.echo, - `${y} - This gets echoed.`, - `Result ${y} is correct` - ); - } - - Assert.equal(enginesSeen.size, 1, "Only one engine was created."); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); - -/** - * Tests that the engineInstanceModel's internal errors are correctly surfaced. - */ -add_task(async function test_ml_engine_model_error() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine(MOZ_ECHO_OPTIONS_RAW); - - info("Run the inference with a throwing example."); - const inferencePromise = engineInstance.run("throw"); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - let error; - try { - await inferencePromise; - } catch (e) { - error = e; - } - is( - error?.message, - 'Error: Received the message "throw", so intentionally throwing an error.', - "The error is correctly surfaced." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * This test is really similar to the "basic" test, but tests manually destroying - * the engineInstance. - */ -add_task(async function test_ml_engine_destruction() { - const { cleanup, remoteClients } = await setup(); - - info("Get engineInstance"); - const engineInstance = await createEngine(MOZ_ECHO_OPTIONS); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await engineInstance.terminate( - /* shutDownIfEmpty */ true, - /* replacement */ false - ); - - info( - "The engineInstance is manually destroyed. The cleanup function should wait for the engine process to be destroyed." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Tests creating an engine after an error. - */ -add_task(async function test_ml_engine_model_error() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine(MOZ_ECHO_OPTIONS_RAW); - - info("Run the inference with a throwing example."); - const inferencePromise = engineInstance.run("throw"); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - let error; - try { - await inferencePromise; - } catch (e) { - error = e; - } - is( - error?.message, - 'Error: Received the message "throw", so intentionally throwing an error.', - "The error is correctly surfaced." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Tests that we display a nice error message when the "browser.ml.enable" pref is off. - */ -add_task(async function test_pref_is_off() { - await SpecialPowers.pushPrefEnv({ - set: [["browser.ml.enable", false]], - }); - - info("Get the engine process"); - let error; - - try { - await EngineProcess.getMLEngineParent(); - } catch (e) { - error = e; - } - is( - error?.message, - "MLEngine is disabled. Check the browser.ml prefs.", - "The error is correctly surfaced." - ); - - await SpecialPowers.pushPrefEnv({ - set: [["browser.ml.enable", true]], - }); -}); - -/** - * Tests that the engine is reused. - */ -add_task(async function test_ml_engine_reuse_same() { - const { cleanup, remoteClients } = await setup(); - - const options = { taskName: "moz-echo", engineId: "echo" }; - const engineInstance = await createEngine(options); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - let engineInstance2 = await createEngine(options); - is(engineInstance2.engineId, "echo", "The engine ID matches"); - is(engineInstance, engineInstance2, "The engine is reused."); - const inferencePromise2 = engineInstance2.run({ data: "This gets echoed." }); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise2).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Tests that we can have two competing engines - */ -add_task(async function test_ml_two_engines() { - const { cleanup, remoteClients } = await setup(); - - const engineInstance = await createEngine({ - taskName: "moz-echo", - engineId: "engine1", - }); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - let engineInstance2 = await createEngine({ - taskName: "moz-echo", - engineId: "engine2", - }); - - const inferencePromise2 = engineInstance2.run({ data: "This gets echoed." }); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise2).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.notEqual( - engineInstance.engineId, - engineInstance2.engineId, - "Should be different engines" - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Tests that we can have the same engine reinitialized - */ -add_task(async function test_ml_dupe_engines() { - const { cleanup, remoteClients } = await setup(); - - const engineInstance = await createEngine({ - taskName: "moz-echo", - engineId: "engine1", - }); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - let engineInstance2 = await createEngine({ - taskName: "moz-echo", - engineId: "engine1", - timeoutMS: 2000, // that makes the options different - }); - const inferencePromise2 = engineInstance2.run({ data: "This gets echoed." }); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise2).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.notEqual( - engineInstance, - engineInstance2, - "Should be different engines" - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Tests that a worker can have an infinite timeout. - */ -add_task(async function test_ml_engine_infinite_worker() { - const { cleanup, remoteClients } = await setup(); - - const options = { taskName: "moz-echo", timeoutMS: -1 }; - const engineInstance = await createEngine(options); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - const res = await inferencePromise; - Assert.equal( - res.output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal(res.output.timeoutMS, -1, "This should be an infinite worker."); - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); - -/** - * These status are visualized in about:inference, but aren't used for business - * logic. - */ -add_task(async function test_ml_engine_get_status_by_engine_id() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine({ taskName: "moz-echo" }); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - const res = await inferencePromise; - Assert.equal( - res.output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - const expected = { - "default-engine": { - status: "IDLE", - options: { - useExternalDataFormat: false, - engineId: "default-engine", - featureId: null, - taskName: "moz-echo", - timeoutMS: 1000, - modelHubRootUrl: - "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data", - modelHubUrlTemplate: "{model}/{revision}", - modelId: "mozilla/distilvit", - modelRevision: "main", - tokenizerId: "mozilla/distilvit", - tokenizerRevision: "main", - processorId: "mozilla/distilvit", - processorRevision: "main", - logLevel: "All", - runtimeFilename: "ort-wasm-simd-threaded.jsep.wasm", - staticEmbeddingsOptions: null, - device: null, - dtype: "q8", - numThreads: "NOT_COMPARED", - executionPriority: null, - kvCacheDtype: null, - numContext: 1024, - numBatch: 1024, - numUbatch: 1024, - flashAttn: false, - useMmap: false, - useMlock: true, - numThreadsDecoding: null, - modelFile: null, - backend: null, - modelHub: null, - baseURL: null, - apiKey: null, - }, - }, - }; - - const statusByEngineId = Object.fromEntries( - await engineInstance.mlEngineParent.getStatusByEngineId() - ); - statusByEngineId["default-engine"].options.numThreads = "NOT_COMPARED"; - Assert.deepEqual(statusByEngineId, expected); - - await ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine_pipeline_options.js b/toolkit/components/ml/tests/browser/browser_ml_engine_pipeline_options.js @@ -1,825 +0,0 @@ -/* Any copyright is dedicated to the Public Domain. - https://creativecommons.org/publicdomain/zero/1.0/ */ - -"use strict"; - -/** - * Test that model PipelineOptions can override the defaults. - */ -add_task(async function test_ml_engine_override_options() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine({ - taskName: "moz-echo", - modelRevision: "v1", - }); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - (await inferencePromise).output.modelRevision, - "v1", - "The config options goes through and overrides." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Verify that features such as the dtype can be picked up via Remote Settings. - */ -add_task(async function test_ml_engine_pick_feature_id() { - // one record sent back from RS contains featureId - const records = [ - { - taskName: "moz-echo", - modelId: "mozilla/distilvit", - processorId: "mozilla/distilvit", - tokenizerId: "mozilla/distilvit", - modelRevision: "main", - processorRevision: "main", - tokenizerRevision: "main", - dtype: "q8", - id: "74a71cfd-1734-44e6-85c0-69cf3e874138", - }, - { - featureId: "pdfjs-alt-text", - taskName: "moz-echo", - modelId: "mozilla/distilvit", - processorId: "mozilla/distilvit", - tokenizerId: "mozilla/distilvit", - modelRevision: "v1.0", - processorRevision: "v1.0", - tokenizerRevision: "v1.0", - dtype: "fp16", - id: "74a71cfd-1734-44e6-85c0-69cf3e874138", - }, - ]; - - const { cleanup, remoteClients } = await setup({ records }); - - info("Get the engine"); - const engineInstance = await createEngine({ - featureId: "pdfjs-alt-text", - taskName: "moz-echo", - }); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - const res = await inferencePromise; - Assert.equal( - res.output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - res.output.dtype, - "fp16", - "The config was enriched by RS - using a feature Id" - ); - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); - -/** - * Tests the generic pipeline API - */ -add_task(async function test_ml_generic_pipeline() { - const { cleanup, remoteClients } = await setup(); - - info("Get engineInstance"); - - const options = new PipelineOptions({ - taskName: "summarization", - modelId: "test-echo", - modelRevision: "main", - }); - - const engineInstance = await createEngine(options); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ - args: ["This gets echoed."], - }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Test out the default precision values. - */ -add_task(async function test_q8_by_default() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine({ - taskName: "moz-echo", - modelId: "Xenova/distilbart-cnn-6-6", - modelHub: "huggingface", - }); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text gets echoed exercising the whole flow." - ); - - Assert.equal( - (await inferencePromise).output.dtype, - "q8", - "dtype should be set to q8" - ); - - // the model hub sets the revision - Assert.equal( - (await inferencePromise).output.modelRevision, - "main", - "modelRevision should be main" - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Test that the preference override options only work for the SAFE_OVERRIDE_OPTIONS - * defined in MLEngineChild.sys.mjs - */ -add_task( - async function test_override_ml_engine_pipeline_options_in_allow_list() { - const { cleanup, remoteClients } = await setup(); - await SpecialPowers.pushPrefEnv({ - set: [ - [ - "browser.ml.overridePipelineOptions", - '{"about-inference": {"modelRevision": "v0.2.0"}}', - ], - ], - }); - - info("Get the engine"); - const engineInstance = await createEngine({ - taskName: "moz-echo", - featureId: "about-inference", - }); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - (await inferencePromise).output.modelRevision, - "v0.2.0", - "The config options goes through and overrides." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); - } -); - -add_task(async function test_override_ml_pipeline_options_not_in_allow_list() { - const { cleanup, remoteClients } = await setup(); - await SpecialPowers.pushPrefEnv({ - set: [ - [ - "browser.ml.overridePipelineOptions", - '{"about-inferences": {"modelRevision": "v0.2.0"}}', - ], - ], - }); - - info("Get the engine"); - const engineInstance = await createEngine({ - taskName: "moz-echo", - featureId: "about-inference", - }); - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - (await inferencePromise).output.modelRevision, - "main", - "The config options goes through and overrides." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Test that an unsanctioned modelId does not get used. - */ -add_task(async function test_override_ml_pipeline_options_unsafe_options() { - const { cleanup, remoteClients } = await setup(); - await SpecialPowers.pushPrefEnv({ - set: [ - [ - "browser.ml.overridePipelineOptions", - '{"about-inference": {"modelRevision": "v0.2.0", "modelId": "unsafe-model-id"}}', - ], - ], - }); - - info("Get the engine"); - const engineInstance = await createEngine({ - taskName: "moz-echo", - featureId: "about-inference", - }); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - (await inferencePromise).output.modelRevision, - "v0.2.0", - "The config options goes through and overrides." - ); - - Assert.equal( - (await inferencePromise).output.modelId, - "mozilla/distilvit", - "The config should not override." - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Check that DEFAULT_MODELS are used to pick a preferred model for a given task. - */ -add_task(async function test_ml_engine_blessed_model() { - const { cleanup, remoteClients } = await setup(); - - const options = { taskName: "test-echo" }; - const engineInstance = await createEngine(options); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - const res = await inferencePromise; - - Assert.equal( - res.config.modelId, - "test-echo", - "The blessed model was picked." - ); - - Assert.equal(res.config.dtype, "q8", "With the right quantization level"); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); - -add_task(async function test_ml_engine_two_tasknames_in_rs() { - // RS has two records with the same taskName - // we should use the modelId match in that case - const records = [ - { - taskName: "moz-echo", - modelId: "mozilla/anothermodel", - processorId: "mozilla/distilvit", - tokenizerId: "mozilla/distilvit", - modelRevision: "main", - processorRevision: "main", - tokenizerRevision: "main", - dtype: "q8", - id: "74a71cfd-1734-44e6-85c0-69cf3e874138", - }, - { - taskName: "moz-echo", - modelId: "mozilla/distilvit", - processorId: "mozilla/distilvit", - tokenizerId: "mozilla/distilvit", - modelRevision: "v1.0", - processorRevision: "v1.0", - tokenizerRevision: "v1.0", - dtype: "fp16", - id: "74a71cfd-1734-44e6-85c0-69cf3e874138", - }, - ]; - - const { cleanup, remoteClients } = await setup({ records }); - - info("Get the engine"); - const engineInstance = await createEngine({ - featureId: "pdfjs-alt-text", - taskName: "moz-echo", - }); - - info("Check the inference process is running"); - Assert.equal(await checkForRemoteType("inference"), true); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - const res = await inferencePromise; - Assert.equal( - res.output.echo, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - res.output.dtype, - "fp16", - "The config was enriched by RS - using a feature Id" - ); - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - - await cleanup(); -}); - -/** - * The modelHub should be applied to the PipelineOptions - */ -add_task(async function test_ml_engine_model_hub_applied() { - const options = { - taskName: "moz-echo", - timeoutMS: -1, - modelHub: "huggingface", - }; - const parsedOptions = new PipelineOptions(options); - - Assert.equal( - parsedOptions.modelHubRootUrl, - "https://huggingface.co/", - "modelHubRootUrl is set" - ); - - Assert.equal( - parsedOptions.modelHubUrlTemplate, - "{model}/resolve/{revision}", - "modelHubUrlTemplate is set" - ); -}); - -/** - * Helper function to create a basic set of valid options - */ -function getValidOptions(overrides = {}) { - return Object.assign( - { - engineId: "validEngine1", - featureId: "pdfjs-alt-text", - taskName: "valid_task", - modelHubRootUrl: "https://example.com", - modelHubUrlTemplate: "https://example.com/{modelId}", - timeoutMS: 5000, - modelId: "validModel", - modelRevision: "v1", - tokenizerId: "validTokenizer", - tokenizerRevision: "v1", - processorId: "validProcessor", - processorRevision: "v1", - logLevel: null, - runtimeFilename: "runtime.wasm", - device: InferenceDevice.GPU, - numThreads: 4, - executionPriority: ExecutionPriority.NORMAL, - }, - overrides - ); -} - -/** - * A collection of test cases for invalid and valid values. - */ -const commonInvalidCases = [ - { description: "Invalid value (special characters)", value: "org1/my!value" }, - { - description: "Invalid value (special characters in organization)", - value: "org@1/my-value", - }, - { description: "Invalid value (missing name part)", value: "org1/" }, - { - description: "Invalid value (invalid characters in name)", - value: "my$value", - }, -]; - -const commonValidCases = [ - { description: "Valid organization/name", value: "org1/my-value" }, - { description: "Valid name only", value: "my-value" }, - { - description: "Valid name with underscores and dashes", - value: "my_value-123", - }, - { - description: "Valid organization with underscores and dashes", - value: "org_123/my-value", - }, -]; - -const pipelineOptionsCases = [ - // Invalid cases for various fields - ...commonInvalidCases.map(test => ({ - description: `Invalid processorId (${test.description})`, - options: { processorId: test.value }, - expectedError: /Invalid value/, - })), - ...commonInvalidCases.map(test => ({ - description: `Invalid tokenizerId (${test.description})`, - options: { tokenizerId: test.value }, - expectedError: /Invalid value/, - })), - ...commonInvalidCases.map(test => ({ - description: `Invalid modelId (${test.description})`, - options: { modelId: test.value }, - expectedError: /Invalid value/, - })), - - // Valid cases for various fields - ...commonValidCases.map(test => ({ - description: `Valid processorId (${test.description})`, - options: { processorId: test.value }, - expected: { processorId: test.value }, - })), - ...commonValidCases.map(test => ({ - description: `Valid tokenizerId (${test.description})`, - options: { tokenizerId: test.value }, - expected: { tokenizerId: test.value }, - })), - ...commonValidCases.map(test => ({ - description: `Valid modelId (${test.description})`, - options: { modelId: test.value }, - expected: { modelId: test.value }, - })), - - // Invalid values - { - description: "Invalid hub", - options: { modelHub: "rogue" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid timeoutMS", - options: { timeoutMS: -3 }, - expectedError: /Invalid value/, - }, - { - description: "Invalid timeoutMS", - options: { timeoutMS: 40000000 }, - expectedError: /Invalid value/, - }, - { - description: "Invalid featureId", - options: { featureId: "unknown" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid dtype", - options: { dtype: "invalid_dtype" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid device", - options: { device: "invalid_device" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid executionPriority", - options: { executionPriority: "invalid_priority" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid logLevel", - options: { logLevel: "invalid_log_level" }, - expectedError: /Invalid value/, - }, - - // Valid values - { - description: "valid hub", - options: { modelHub: "huggingface" }, - expected: { modelHub: "huggingface" }, - }, - { - description: "valid hub", - options: { modelHub: "mozilla" }, - expected: { modelHub: "mozilla" }, - }, - { - description: "valid timeoutMS", - options: { timeoutMS: 12345 }, - expected: { timeoutMS: 12345 }, - }, - { - description: "valid timeoutMS", - options: { timeoutMS: -1 }, - expected: { timeoutMS: -1 }, - }, - - { - description: "Valid dtype", - options: { dtype: QuantizationLevel.FP16 }, - expected: { dtype: QuantizationLevel.FP16 }, - }, - { - description: "Valid device", - options: { device: InferenceDevice.WASM }, - expected: { device: InferenceDevice.WASM }, - }, - { - description: "Valid executionPriority", - options: { executionPriority: ExecutionPriority.HIGH }, - expected: { executionPriority: ExecutionPriority.HIGH }, - }, - { - description: "Valid logLevel (Info)", - options: { logLevel: LogLevel.INFO }, - expected: { logLevel: LogLevel.INFO }, - }, - { - description: "Valid logLevel (Critical)", - options: { logLevel: LogLevel.CRITICAL }, - expected: { logLevel: LogLevel.CRITICAL }, - }, - { - description: "Valid logLevel (All)", - options: { logLevel: LogLevel.ALL }, - expected: { logLevel: LogLevel.ALL }, - }, - { - description: "Valid modelId", - options: { modelId: "Qwen2.5-0.5B-Instruct" }, - expected: { modelId: "Qwen2.5-0.5B-Instruct" }, - }, - - // Invalid revision cases - { - description: "Invalid revision (random string)", - options: { modelRevision: "invalid_revision" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid revision (too many version numbers)", - options: { tokenizerRevision: "v1.0.3.4.5" }, - expectedError: /Invalid value/, - }, - { - description: "Invalid revision (unknown suffix)", - options: { processorRevision: "v1.0.0-unknown" }, - expectedError: /Invalid value/, - }, - - // Valid revision cases with new format - { - description: "Valid revision (main)", - options: { modelRevision: "main" }, - expected: { modelRevision: "main" }, - }, - { - description: "Valid revision (v-prefixed version with alpha)", - options: { tokenizerRevision: "v1.2.3-alpha1" }, - expected: { tokenizerRevision: "v1.2.3-alpha1" }, - }, - { - description: - "Valid revision (v-prefixed version with beta and dot separator)", - options: { tokenizerRevision: "v1.2.3.beta2" }, - expected: { tokenizerRevision: "v1.2.3.beta2" }, - }, - { - description: - "Valid revision (non-prefixed version with rc and dash separator)", - options: { processorRevision: "1.0.0-rc3" }, - expected: { processorRevision: "1.0.0-rc3" }, - }, - { - description: - "Valid revision (non-prefixed version with pre and dot separator)", - options: { processorRevision: "1.0.0.pre4" }, - expected: { processorRevision: "1.0.0.pre4" }, - }, - { - description: "Valid revision (version without suffix)", - options: { modelRevision: "1.0.0" }, - expected: { modelRevision: "1.0.0" }, - }, - - // Valid engineID cases - { - description: "Valid engineID (qwen)", - options: { engineId: "SUM-ONNX-COMMUNITY_QWEN2_5-0_5B-INSTRUCT_BIG" }, - expected: { engineId: "SUM-ONNX-COMMUNITY_QWEN2_5-0_5B-INSTRUCT_BIG" }, - }, -]; - -/** - * Go through all of the pipeline validation test cases. - */ -add_task(async function test_pipeline_options_validation() { - pipelineOptionsCases.forEach(testCase => { - if (testCase.expectedError) { - Assert.throws( - () => new PipelineOptions(getValidOptions(testCase.options)), - testCase.expectedError, - `${testCase.description} throws the expected error` - ); - } else { - const pipelineOptions = new PipelineOptions( - getValidOptions(testCase.options) - ); - Object.keys(testCase.expected).forEach(key => { - is( - pipelineOptions[key], - testCase.expected[key], - `${testCase.description} sets ${key} correctly` - ); - }); - } - }); -}); - -/** - * The pipeline should only be able to be initialized when there is enough memory. - */ -add_task(async function test_ml_engine_not_enough_memory() { - const { cleanup } = await setup({ - prefs: [ - ["browser.ml.checkForMemory", true], - ["browser.ml.minimumPhysicalMemory", 99999], - ], - }); - - info("Get the greedy engine"); - - await Assert.rejects( - createEngine({ - modelId: "testing/greedy", - taskName: "moz-echo", - dtype: "q8", - numThreads: 1, - device: "wasm", - }), - /Not enough physical memory/, - "The call should be rejected because of a lack of memory" - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * This tests that threading is supported. On certain machines this could be false, - * but should be true for our testing infrastructure. - */ -add_task(async function test_ml_threading_support() { - const { cleanup, remoteClients } = await setup(); - - info("Get engineInstance"); - - const options = new PipelineOptions({ - taskName: "summarization", - modelId: "test-echo", - modelRevision: "main", - }); - - const engineInstance = await createEngine(options); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ - args: ["This gets echoed."], - }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - let res = await inferencePromise; - - ok(res.multiThreadSupported, "Multi-thread should be supported"); - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine_rs_hub.js b/toolkit/components/ml/tests/browser/browser_ml_engine_rs_hub.js @@ -1,97 +0,0 @@ -/* Any copyright is dedicated to the Public Domain. - https://creativecommons.org/publicdomain/zero/1.0/ */ - -"use strict"; - -/** - * Test the hub return values by default. - */ -add_task(async function test_hub_by_default() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine"); - const engineInstance = await createEngine({ - taskName: "moz-echo", - }); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ data: "This gets echoed." }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - Assert.equal( - (await inferencePromise).output.echo, - "This gets echoed.", - "The text gets echoed exercising the whole flow." - ); - - Assert.equal( - (await inferencePromise).output.modelHubUrlTemplate, - "{model}/{revision}", - "Default template should be model/revision" - ); - - Assert.equal( - (await inferencePromise).output.modelRevision, - "main", - "modelRevision should be main" - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); - -/** - * Tests that the pipeline can use a custom model hub. - */ -add_task(async function test_ml_custom_hub() { - const { cleanup, remoteClients } = await setup(); - - info("Get engineInstance"); - - const options = new PipelineOptions({ - taskName: "summarization", - modelId: "test-echo", - modelRevision: "main", - modelHubRootUrl: "https://example.com", - modelHubUrlTemplate: "models/{model}/{revision}", - }); - - const engineInstance = await createEngine(options); - - info("Run the inference"); - const inferencePromise = engineInstance.run({ - args: ["This gets echoed."], - }); - - info("Wait for the pending downloads."); - await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - - let res = await inferencePromise; - - Assert.equal( - res.output, - "This gets echoed.", - "The text get echoed exercising the whole flow." - ); - - Assert.equal( - res.config.modelHubRootUrl, - "https://example.com", - "The pipeline used the custom hub" - ); - - ok( - !EngineProcess.areAllEnginesTerminated(), - "The engine process is still active." - ); - - await EngineProcess.destroyMLEngine(); - await cleanup(); -}); diff --git a/toolkit/components/ml/tests/browser/browser_ml_openai.js b/toolkit/components/ml/tests/browser/browser_ml_openai.js @@ -1,253 +0,0 @@ -/* Any copyright is dedicated to the Public Domain. - https://creativecommons.org/publicdomain/zero/1.0/ */ - -"use strict"; - -const BASE_ENGINE_OPTIONS = { - featureId: "about-inference", - taskName: "text-generation", - modelId: "qwen3:0.6b", - modelRevision: "main", -}; - -const SHARED_TOOLS = [ - { - type: "function", - function: { - name: "search_open_tabs", - description: "Search open tabs by type.", - parameters: { - type: "object", - properties: { type: { type: "string" } }, - required: ["type"], - }, - }, - }, -]; - -/** - * Test that createEngine successfully talks to the OpenAI client. - */ -add_task(async function test_openai_client() { - const records = [ - { - featureId: "about-inference", - taskName: "text-generation", - modelId: "qwen3:0.6b", - modelRevision: "main", - id: "74a71cfd-1734-44e6-85c0-69cf3e874138", - }, - ]; - const { cleanup } = await setup({ records }); - const { server: mockServer, port } = startMockOpenAI({ - echo: "This gets echoed.", - }); - - const engineInstance = await createEngine({ - featureId: "about-inference", - task: "text-generation", - modelId: "qwen3:0.6b", - modelRevision: "main", - apiKey: "ollama", - baseURL: `http://localhost:${port}/v1`, - backend: "openai", - }); - - const request = { - args: [ - { - role: "system", - content: - "You are a helpful assistant that summarizes text clearly and concisely.", - }, - { - role: "user", - content: `Please summarize the following text:\n\n blah bla`, - }, - ], - }; - - try { - info("Run the inference"); - const inferencePromise = engineInstance.run(request); - - const result = await inferencePromise; - - Assert.equal( - result.finalOutput, - "This is a mock summary for testing end-to-end flow." - ); - } finally { - await EngineProcess.destroyMLEngine(); - await cleanup(); - await stopMockOpenAI(mockServer); - } -}); - -add_task(async function test_openai_client_tools_non_streaming() { - const records = [ - { - ...BASE_ENGINE_OPTIONS, - id: "74a71cfd-1734-44e6-85c0-69cf3e874138", - }, - ]; - const { cleanup } = await setup({ records }); - const { server: mockServer, port } = startMockOpenAI(); - - const engineInstance = await createEngine({ - ...BASE_ENGINE_OPTIONS, - apiKey: "ollama", - baseURL: `http://localhost:${port}/v1`, - backend: "openai", - }); - - // First request: ask with tools; server responds with tool_calls - const requestWithTools = { - args: [ - { role: "system", content: "You are a helpful assistant." }, - { role: "user", content: "Find my open news tabs." }, - ], - tools: SHARED_TOOLS, - }; - - try { - info("Run request that triggers tool calls"); - const result1 = await engineInstance.run(requestWithTools); - - // The pipeline should surface toolCalls from the OpenAI message - Assert.ok(result1.toolCalls, "toolCalls should exist on the result"); - Assert.equal(result1.toolCalls.length, 1, "Exactly one tool call"); - Assert.equal( - result1.toolCalls[0].function.name, - "search_open_tabs", - "Tool name should match" - ); - - // Second request: append assistant tool_calls + our tool result - const assistantToolCallsMsg = { - role: "assistant", - tool_calls: result1.toolCalls.map(tc => ({ - id: tc.id, - type: "function", - function: { - name: tc.function.name, - arguments: tc.function.arguments, - }, - })), - }; - - const toolResultMsg = { - role: "tool", - tool_call_id: result1.toolCalls[0].id, - content: JSON.stringify({ query: "news", allTabs: [] }), - }; - - const followup = await engineInstance.run({ - args: [...requestWithTools.args, assistantToolCallsMsg, toolResultMsg], - tools: requestWithTools.tools, // still valid to include - }); - - Assert.equal( - followup.finalOutput, - "Here are the tabs I found for you.", - "Should get assistant follow-up after tool result" - ); - } finally { - await EngineProcess.destroyMLEngine(); - await cleanup(); - await stopMockOpenAI(mockServer); - } -}); - -add_task(async function test_openai_client_tools_streaming() { - const records = [ - { - ...BASE_ENGINE_OPTIONS, - id: "b3b2b661-daa6-4b7f-8d3c-7db0df0dbeef", - }, - ]; - const { cleanup } = await setup({ records }); - const { server: mockServer, port } = startMockOpenAI(); - - const engineInstance = await createEngine({ - ...BASE_ENGINE_OPTIONS, - apiKey: "ollama", - baseURL: `http://localhost:${port}/v1`, - backend: "openai", - }); - - const starter = { - args: [ - { role: "system", content: "You are a helpful assistant." }, - { role: "user", content: "Find my open news tabs." }, - ], - tools: SHARED_TOOLS, - streamOptions: { enabled: true }, - }; - - try { - // --- First turn: expect tool_calls via streaming --- - const gen = engineInstance.runWithGenerator(starter); - - let toolCalls = null; - for await (const chunk of gen) { - // Your MLEngineParent + OpenAIPipeline put toolCalls onto the yielded chunk - if (chunk.toolCalls && chunk.toolCalls.length) { - toolCalls = chunk.toolCalls; - break; // we end the turn when model asks for tools - } - // (Optional) you could accumulate chunk.text here; expected empty in this turn - } - - Assert.ok(toolCalls, "Should receive toolCalls via streaming"); - Assert.equal(toolCalls.length, 1, "One tool call"); - Assert.equal( - toolCalls[0].function.name, - "search_open_tabs", - "Tool name should match" - ); - - // --- Second turn: send tool result, stream final answer --- - const assistantToolCallsMsg = { - role: "assistant", - tool_calls: toolCalls.map(tc => ({ - id: tc.id, - type: "function", - function: { - name: tc.function.name, - arguments: tc.function.arguments, - }, - })), - }; - - const toolResultMsg = { - role: "tool", - tool_call_id: toolCalls[0].id, - content: JSON.stringify({ query: "news", allTabs: [] }), - }; - - const gen2 = engineInstance.runWithGenerator({ - args: [...starter.args, assistantToolCallsMsg, toolResultMsg], - tools: SHARED_TOOLS, - streamOptions: { enabled: true }, - }); - - let final = ""; - for await (const chunk of gen2) { - if (chunk.text) { - final += chunk.text; - } - } - - Assert.ok(final.length, "Should stream some final content"); - Assert.equal( - final, - "Here are the tabs I found for you.", - "Should stream the expected assistant follow-up" - ); - } finally { - await EngineProcess.destroyMLEngine(); - await cleanup(); - await stopMockOpenAI(mockServer); - } -}); diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js @@ -31,9 +31,6 @@ const { HttpServer } = ChromeUtils.importESModule( const MS_PER_SEC = 1000; const IndexedDBCache = TestIndexedDBCache; -/** - * @type {import("../../../ml/content/EngineProcess.sys.mjs")} - */ const { createEngine, PipelineOptions, @@ -53,7 +50,8 @@ Services.scriptloader.loadSubScript( ); /** - * Mock out remote settings and set some default preferences for the testing environment. + * Sets up the stage for a test + * */ async function setup({ disabled = false, @@ -1118,18 +1116,3 @@ async function getMLEngineWorkerCode() { ); return response.text(); } - -/** - * Checks that a process exists. - * - * @param {string} remoteType - */ -async function checkForRemoteType(remoteType) { - let procinfo3 = await ChromeUtils.requestProcInfo(); - for (const child of procinfo3.children) { - if (child.type === remoteType) { - return true; - } - } - return false; -}