commit df7784956ce051bfbdd02503e9a45b679f866be2 parent a49b09a1fbd3e7495c953880b792167f89c3a5a7 Author: Greg Tatum <tatum.creative@gmail.com> Date: Thu, 2 Oct 2025 19:35:07 +0000 Bug 1990084 - Add static embeddings as an ml backend; r=tarek Differential Revision: https://phabricator.services.mozilla.com/D265709 Diffstat:
17 files changed, 782 insertions(+), 49 deletions(-)
diff --git a/toolkit/components/aboutinference/content/aboutInference.js b/toolkit/components/aboutinference/content/aboutInference.js @@ -42,6 +42,7 @@ const TASKS = [ "question-answering", "fill-mask", "summarization", + "static-embeddings", "translation", "text2text-generation", "text-generation", @@ -310,6 +311,54 @@ const INFERENCE_PAD_PRESETS = { device: "cpu", backend: "onnx", }, + "static-embeddings": { + inputArgs: [ + "This is an example of encoding", + "The quick brown fox jumps over the lazy dog.", + "Curaçao, naïve fiancé, jalapeño, déjà vu.", + "Привет, как дела?", + "Бързата кафява лисица прескача мързеливото куче.", + "Γρήγορη καφέ αλεπού πηδάει πάνω από τον τεμπέλη σκύλο.", + "اللغة العربية جميلة وغنية بالتاريخ.", + "مرحبا بالعالم!", + "Simplified: 快速的棕色狐狸跳过懒狗。", + "Traditional: 快速的棕色狐狸跳過懶狗。", + "素早い茶色の狐が怠け者の犬を飛び越える。", + "コンピュータープログラミング", + "빠른 갈색 여우가 게으른 개를 뛰어넘습니다.", + "तेज़ भूरी लोमड़ी आलसी कुत्ते के ऊपर कूदती है।", + "দ্রুত বাদামী শিয়াল অলস কুকুরের উপর দিয়ে লাফ দেয়।", + "வேகமான பழுப்பு நரி சோம்பேறி நாயின் மேல் குதிக்கிறது.", + "สุนัขจิ้งจอกสีน้ำตาลกระโดดข้ามสุนัขขี้เกียจ.", + "ብሩክ ቡናማ ቀበሮ ሰነፍ ውሻን ተዘልሏል።", + // Mixed scripts: + "Hello 世界 مرحبا 🌍", + "123, αβγ, абв, العربية, 中文, हिन्दी.", + ], + runOptions: { + // Use mean pooling, where each static embedding is averaged together into + // a new vector. + pooling: "mean", + // Normalize the resulting vector. + normalize: true, + }, + task: "static-embeddings", + modelHub: "mozilla", + modelId: "mozilla/static-embeddings", + modelRevision: "v1.0.0", + backend: "static-embeddings", + staticEmbeddingsOptions: { + // View the available models here: + // https://huggingface.co/gregtatum/static-embeddings/tree/main/models + subfolder: "models/minishlab/potion-retrieval-32M", + // The precision of the embeddings: fp32, fp16, fp8_e5m2, fp8_e4m3 + dtype: "fp8_e4m3", + // The dimensions available: 32, 64, 128, 256. + dimensions: 128, + // Whether or not to use ZST compression. 
+ compression: true, + }, + }, "link-preview": { inputArgs: `Summarize this: ${TINY_ARTICLE}`, runOptions: { @@ -664,6 +713,10 @@ async function displayInfo() { await refreshPage(); } +/** + * @param {string} selectId + * @param {string} optionValue + */ function setSelectOption(selectId, optionValue) { const selectElement = document.getElementById(selectId); if (!selectElement) { @@ -684,7 +737,9 @@ function setSelectOption(selectId, optionValue) { } } - console.warn(`No option found with value: ${optionValue}`); + console.warn( + `No option found for "${selectId}" with value: "${optionValue}"` + ); } function loadExample(name) { @@ -796,6 +851,13 @@ async function runInference() { }; } + if (taskName == "static-embeddings") { + const config = INFERENCE_PAD_PRESETS["static-embeddings"]; + additionalEngineOptions = { + staticEmbeddingsOptions: config.staticEmbeddingsOptions, + }; + } + const initData = { featureId: "about-inference", modelId, diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs @@ -463,7 +463,7 @@ class EngineDispatcher { // Ignore errors from tests intentionally causing errors. 
!error?.message?.startsWith("Intentionally") ) { - lazy.console.error("Could not initalize the engine", error); + lazy.console.error("Could not initialize the engine", error); } }); diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs @@ -5,12 +5,16 @@ import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs"; /** * @typedef {object} Lazy - * @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams - * @property {typeof console} console - * @property {typeof import("../content/Utils.sys.mjs").getRuntimeWasmFilename} getRuntimeWasmFilename - * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings - * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent - * @typedef {import("../../translations").WasmRecord} WasmRecord + * @property {typeof import("resource://services-settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings + * @property {typeof import("resource://services-settings/Utils.sys.mjs").Utils} Utils + * @property {typeof import("resource://gre/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent + * @property {typeof setTimeout} setTimeout + * @property {typeof clearTimeout} clearTimeout + * @property {typeof import("chrome://global/content/ml/ModelHub.sys.mjs").ModelHub} ModelHub + * @property {typeof import("chrome://global/content/ml/Utils.sys.mjs").Progress} Progress + * @property {typeof import("chrome://global/content/ml/Utils.sys.mjs").isAddonEngineId} isAddonEngineId + * @property {typeof import("chrome://global/content/ml/OPFS.sys.mjs").OPFS} OPFS + * @property {typeof import("chrome://global/content/ml/EngineProcess.sys.mjs").BACKENDS} BACKENDS */ /** @type {Lazy} */ @@ -949,7 +953,7 @@ class ResponseOrChunkResolvers { * * @template Response */ -class MLEngine { 
+export class MLEngine { /** * The cached engines. * diff --git a/toolkit/components/ml/content/EngineProcess.sys.mjs b/toolkit/components/ml/content/EngineProcess.sys.mjs @@ -16,9 +16,7 @@ export const DEFAULT_ENGINE_ID = "default-engine"; /** - * @constant - * @type {Array<string>} - * @description Supported backends. + * Supported backends. */ export const BACKENDS = Object.freeze({ onnx: "onnx", @@ -27,6 +25,7 @@ export const BACKENDS = Object.freeze({ llamaCpp: "llama.cpp", bestLlama: "best-llama", openai: "openai", + staticEmbeddings: "static-embeddings", }); /** @@ -231,72 +230,74 @@ export const ModelHub = { }; /** - * Enum for execution priority. - * - * Defines the priority of the task: - * - * - "High" is absolutely needed for Firefox. - * - "Normal" is the default priority. - * - "Low" is for 3rd party calls. + * Enum for execution priority of a task. * * @readonly * @enum {string} */ export const ExecutionPriority = { + // The task is absolutely needed for Firefox. HIGH: "HIGH", + // The default priority. NORMAL: "NORMAL", + // 3rd party calls. LOW: "LOW", }; /** - * Enum for model quantization levels. - * - * Defines the quantization level of the task: - * - * - 'fp32': Full precision 32-bit floating point (`''`) - * - 'fp16': Half precision 16-bit floating point (`'_fp16'`) - * - 'q8': Quantized 8-bit (`'_quantized'`) - * - 'int8': Integer 8-bit quantization (`'_int8'`) - * - 'uint8': Unsigned integer 8-bit quantization (`'_uint8'`) - * - 'q4': Quantized 4-bit (`'_q4'`) - * - 'bnb4': Binary/Boolean 4-bit quantization (`'_bnb4'`) - * - 'q4f16': 16-bit floating point model with 4-bit block weight quantization (`'_q4f16'`) + * Enum for model quantization levels. Not all models support all values. 
* * @readonly * @enum {string} */ export const QuantizationLevel = { + // Full precision 32-bit floating point (`''`) FP32: "fp32", + // Half precision 16-bit floating point (`'_fp16'`) FP16: "fp16", + // Floating point 8 with the exponential taking 5 bits, and the mantissa taking 2. + // This format can express a wide dynamic range of float values because of the + // extra bits in the exponential, but with the trade off of lower precision of stored + // values because of the small mantissa. The max finite values are ±57,344. + FP8_E5M2: "fp8_e5m2", + // Floating point 8 with the exponential taking 4 bits, and the mantissa taking 3. + // This format is best for values without a wide dynamic range. The higher bits + // in the mantissa retains more precision The max finite values are ±448. + FP8_E4M3: "fp8_e4m3", + // Quantized 8-bit (`'_quantized'`) Q8: "q8", + // Integer 8-bit quantization (`'_int8'`) INT8: "int8", + // Unsigned integer 8-bit quantization (`'_uint8'`) UINT8: "uint8", + // Quantized 4-bit (`'_q4'`) Q4: "q4", + // Binary/Boolean 4-bit quantization (`'_bnb4'`) BNB4: "bnb4", + // 16-bit floating point model with 4-bit block weight quantization (`'_q4f16'`) Q4F16: "q4f16", }; /** * Enum for KV cache quantization levels. 
* - * - 'q8_0': Quantized 8-bit with optimized storage (`'_q8_0'`) (block-based) - * - 'q4_0': Quantized 4-bit version 0 (`'_q4_0'`) (block-based) - * - 'q4_1': Quantized 4-bit version 1 (`'_q4_1'`) (block-based) - * - 'q5_1': Quantized 5-bit version 1 (`'_q5_1'`) (block-based) - * - 'q5_0': Quantized 5-bit version 0 (`'_q5_0'`) (block-based) - * - 'f16': Half-precision (16-bit floating point) (`'_f16'`) - * - 'f32': Full precision (32-bit floating point) (`'_f32'`) - * * @readonly * @enum {string} */ export const KVCacheQuantizationLevel = { + // Quantized 8-bit with optimized storage (`'_q8_0'`) (block-based) Q8_0: "q8_0", + // Quantized 4-bit version 0 (`'_q4_0'`) (block-based) Q4_0: "q4_0", + // Quantized 4-bit version 1 (`'_q4_1'`) (block-based) Q4_1: "q4_1", + // Quantized 5-bit version 1 (`'_q5_1'`) (block-based) Q5_1: "q5_1", + // Quantized 5-bit version 0 (`'_q5_0'`) (block-based) Q5_0: "q5_0", + // Half-precision (16-bit floating point) (`'_f16'`) F16: "f16", + // Full precision (32-bit floating point) (`'_f32'`) F32: "f32", }; @@ -337,7 +338,8 @@ export const LogLevel = { export const AllowedBoolean = [false, true]; /** - * @typedef {import("../../translations/actors/TranslationsEngineParent.sys.mjs").TranslationsEngineParent} TranslationsEngineParent + * @import { TranslationsEngineParent } from "../../translations/actors/TranslationsEngineParent.sys.mjs" + * @import { StaticEmbeddingsOptions } from "./backends/StaticEmbeddingsPipeline.d.ts" */ const PIPELINE_TEST_NAMES = ["moz-echo", "test-echo"]; @@ -583,6 +585,13 @@ export class PipelineOptions { apiKey = null; /** + * The options for the engine when using static embeddings. + * + * @type {?StaticEmbeddingsOptions} + */ + staticEmbeddingsOptions = null; + + /** * Create a PipelineOptions instance. * * @param {object} options - The options for the pipeline. Must include mandatory fields. 
@@ -787,6 +796,7 @@ export class PipelineOptions { "backend", "baseURL", "apiKey", + "staticEmbeddingsOptions", ]; if (options instanceof PipelineOptions) { @@ -925,6 +935,7 @@ export class PipelineOptions { backend: this.backend, baseURL: this.baseURL, apiKey: this.apiKey, + staticEmbeddingsOptions: this.staticEmbeddingsOptions, }; } diff --git a/toolkit/components/ml/content/MLEngine.worker.mjs b/toolkit/components/ml/content/MLEngine.worker.mjs @@ -18,7 +18,7 @@ ChromeUtils.defineESModuleGetters( /** * The actual MLEngine lives here in a worker. */ -class MLEngineWorker { +export class MLEngineWorker { #pipeline; #sessionId; @@ -134,7 +134,14 @@ class MLEngineWorker { self.callMainThread = worker.callMainThread.bind(worker); self.addEventListener("message", msg => worker.handleMessage(msg)); self.addEventListener("unhandledrejection", function (error) { - throw error.reason?.fail ?? error.reason; + const reason = error?.reason?.fail ?? error?.reason; + if (reason) { + // The PromiseWorker message passing doesn't properly expose the call stack of + // errors which makes it really hard to debug code. Log the error here to + // ensure that nice call stacks are preserved. + console.error(reason); + } + throw new Error("MLEngine.worker.mjs had an unhandled error."); }); } } diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs @@ -82,7 +82,9 @@ const NO_ETAG = "NO_ETAG"; */ class ForbiddenURLError extends Error { constructor(url, rejectionType) { - super(`Forbidden URL: ${url} (${rejectionType})`); + super( + `Forbidden URL: ${url} (${rejectionType}). 
Set MOZ_ALLOW_EXTERNAL_ML_HUB=1 to allow external URLs.` + ); this.name = "ForbiddenURLError"; this.url = url; } @@ -1389,6 +1391,9 @@ export class ModelHub { this.reset = reset; } + /** + * @param {string} url + */ allowedURL(url) { if (this.allowDenyList === null) { return { allowed: true, rejectionType: lazy.RejectionType.NONE }; diff --git a/toolkit/components/ml/content/Utils.sys.mjs b/toolkit/components/ml/content/Utils.sys.mjs @@ -186,7 +186,7 @@ export class ProgressAndStatusCallbackParams { * @param {string} config.file - filename * @param {string} config.rootUrl - root url of the model hub * @param {string} config.urlTemplate - url template of the model hub - * @param {boolean} config.addDownloadParams - Whether to add a download query parameter. + * @param {boolean} [config.addDownloadParams] - Whether to add a download query parameter. * @returns {string} The full URL */ export function createFileUrl({ diff --git a/toolkit/components/ml/content/backends/ONNXPipeline.mjs b/toolkit/components/ml/content/backends/ONNXPipeline.mjs @@ -62,11 +62,12 @@ let transformers = null; * @async * @function importTransformers * @param {string} backend - The backend to use (e.g. "onnx-native" or "onnx"). - * @returns {Promise<void>} A promise that resolves once the Transformers library is imported. + * @returns {Promise<import("chrome://global/content/ml/transformers-dev.js")>} + * A promise that resolves once the Transformers library is imported. 
*/ -async function importTransformers(backend) { +export async function importTransformers(backend) { if (transformers) { - return; + return transformers; } lazy.console.debug(`Using backend ${backend}`); @@ -95,6 +96,8 @@ async function importTransformers(backend) { lazy.console.debug("Beta or Release detected, using transformers.js"); transformers = await import("chrome://global/content/ml/transformers.js"); } + + return transformers; } /** diff --git a/toolkit/components/ml/content/backends/Pipeline.mjs b/toolkit/components/ml/content/backends/Pipeline.mjs @@ -13,6 +13,8 @@ ChromeUtils.defineESModuleGetters( "chrome://global/content/ml/backends/LlamaCppPipeline.mjs", PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs", OpenAIPipeline: "chrome://global/content/ml/backends/OpenAIPipeline.mjs", + StaticEmbeddingsPipeline: + "chrome://global/content/ml/backends/StaticEmbeddingsPipeline.mjs", }, { global: "current" } ); @@ -44,6 +46,9 @@ export async function getBackend(consumer, wasm, options) { case "openai": factory = lazy.OpenAIPipeline.initialize; break; + case "static-embeddings": + factory = lazy.StaticEmbeddingsPipeline.initialize; + break; default: factory = lazy.ONNXPipeline.initialize; } diff --git a/toolkit/components/ml/content/backends/StaticEmbeddingsPipeline.d.ts b/toolkit/components/ml/content/backends/StaticEmbeddingsPipeline.d.ts @@ -0,0 +1,94 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +export type EmbeddingPooling = "mean" | "sum" | "max"; + +export type EmbeddingDType = "fp32" | "fp16" | "fp8_e5m2" | "fp8_e4m3"; + +/** + * The options that can be passed in for embedding. + */ +export interface EmbeddingOptions { + pooling: EmbeddingPooling; + dimensions: EmbeddingDType; + normalize: boolean; +} + +/** + * The full request for embedding text. 
Takes a list of strings to embed. + */
There is a small trade off between the speed + * of loading the model, and the size on disk and download time. + */ + compression: boolean; + + /** + * Mock the engine for tests. + */ + mockedValues?: {}; +} + +/** + * Stub type defintion for transfomers.js + * + * @see https://huggingface.co/docs/transformers.js/api/tokenizers#module_tokenizers.PreTrainedTokenizer + */ +interface PreTrainedTokenizer { + model: { vocab: list[any] }; + encode(text: string): number[]; +} diff --git a/toolkit/components/ml/content/backends/StaticEmbeddingsPipeline.mjs b/toolkit/components/ml/content/backends/StaticEmbeddingsPipeline.mjs @@ -0,0 +1,394 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// @ts-check + +/** + * @import { PipelineOptions } from "chrome://global/content/ml/EngineProcess.sys.mjs" + * @import { BackendError } from "./Pipeline.mjs" + * @import { MLEngineWorker } from "../MLEngine.worker.mjs" + * @import { EmbeddingDType, EmbeddingRequest, EmbeddingResponse, PreTrainedTokenizer } from "./StaticEmbeddingsPipeline.d.ts" + */ + +/** + * @typedef {object} Lazy + * @property {typeof import("chrome://global/content/ml/Utils.sys.mjs").createFileUrl} createFileUrl + * @property {typeof import("chrome://global/content/ml/Utils.sys.mjs").parseNpy} parseNpy + * @property {typeof import("chrome://global/content/ml/OPFS.sys.mjs").OPFS} OPFS + * @property {typeof import("chrome://global/content/ml/backends/ONNXPipeline.mjs").importTransformers} importTransformers + * @property {typeof import("chrome://global/content/ml/EngineProcess.sys.mjs").QuantizationLevel} QuantizationLevel + */ + +/** @type {Lazy} */ +const lazy = /** @type {any} */ ({}); + +ChromeUtils.defineESModuleGetters( + lazy, + { + createFileUrl: "chrome://global/content/ml/Utils.sys.mjs", + parseNpy: "chrome://global/content/ml/Utils.sys.mjs", + 
OPFS: "chrome://global/content/ml/OPFS.sys.mjs", + importTransformers: "chrome://global/content/ml/backends/ONNXPipeline.mjs", + QuantizationLevel: "chrome://global/content/ml/EngineProcess.sys.mjs", + }, + { global: "current" } +); + +/** + * Mock out a response object for tests. This should have the same type as @see {Response} + */ +class MockedResponse { + /** @type {any} */ + #value; + /** + * @param {any} value + */ + constructor(value) { + this.#value = value; + } + + /** + * @returns {ReturnType<Response["json"]>} + */ + json() { + return this.#value; + } + + /** + * @returns {ReturnType<Response["arrayBuffer"]>} + */ + arrayBuffer() { + return this.#value; + } +} + +/** + * Embeddings are typically generated through running text through a BERT-like model + * that is an encoder-only transformer. However, this is expensive and slow. Static + * embeddings allow for a cheaper way to generate an embedding by just averaging the + * values of each token's embedding vector. This involves a simple lookup per token, and + * then some vector math. These embeddings are often good enough for looking up + * semantically similar values. + */ +export class StaticEmbeddingsPipeline { + /** @type {PreTrainedTokenizer} */ + #tokenizer; + + /** + * The embedding dimensions size, e.g. 128, 256, 512 + * + * @type {number} + */ + #dimensions; + + /** @type {number | null} */ + #initializeStart; + + /** + * Get a native JS double out of the backing data array. 
+ * + * @type {(index: number) => number} + */ + #getFloat; + + /** + * @param {PreTrainedTokenizer} tokenizer + * @param {ArrayBuffer} npyData + * @param {EmbeddingDType} dtype + * @param {number} dimensions + * @param {number} initializeStart + */ + constructor(tokenizer, npyData, dtype, dimensions, initializeStart) { + this.#tokenizer = tokenizer; + this.#dimensions = dimensions; + this.#initializeStart = initializeStart; + + const { + data: embeddings, + shape: [vocabSize, dimActual], + } = lazy.parseNpy(npyData); + + if (dimActual != this.#dimensions) { + throw new Error( + `The dimensions requested (${this.#dimensions}) and the dimensions received (${dimActual}) did not match` + ); + } + if (tokenizer.model.vocab.length != vocabSize) { + throw new Error( + `The tokenizer vocab size (${this.#dimensions}) did not match the data vocab size (${vocabSize})` + ); + } + + switch (dtype) { + case lazy.QuantizationLevel.FP32: + case lazy.QuantizationLevel.FP16: + // No processing is needed. + this.#getFloat = index => this.embeddings[index]; + break; + case lazy.QuantizationLevel.FP8_E5M2: + this.#getFloat = this.#getFp8_E5M2; + break; + case lazy.QuantizationLevel.FP8_E4M3: + this.#getFloat = this.#getFp8_E4M3; + break; + default: + throw new Error("Unsupported dtype: " + dtype); + } + + /** @type {ArrayBufferLike} */ + this.embeddings = embeddings; + } + + /** + * @param {MLEngineWorker} worker + * @param {null} _wasm + * @param {PipelineOptions} pipelineOptions + * @param {(error: any) => BackendError} _createError + */ + static async initialize(worker, _wasm, pipelineOptions, _createError) { + let initializeStart = ChromeUtils.now(); + + const { + backend, + modelHubRootUrl, + modelHubUrlTemplate, + modelId, + modelRevision, + staticEmbeddingsOptions, + } = pipelineOptions; + + // These are the options that are specific to this engine. 
+ const { subfolder, dtype, dimensions, compression, mockedValues } = + staticEmbeddingsOptions; + + const extension = compression ? ".zst" : ""; + + const files = [ + `${subfolder}/tokenizer.json${extension}`, + `${subfolder}/${dtype}.d${dimensions}.npy${extension}`, + ]; + + /** + * @param {string} fileName + * @returns {Promise<Response | MockedResponse>} + */ + async function getResponse(fileName) { + const url = lazy.createFileUrl({ + file: fileName, + model: modelId, + revision: modelRevision, + urlTemplate: modelHubUrlTemplate, + rootUrl: modelHubRootUrl, + }); + if (mockedValues) { + const mockedValue = mockedValues[url]; + if (!mockedValue) { + throw new Error( + "Could not find mocked value for requested url: " + url + ); + } + if (url.endsWith(`.json${extension}`)) { + return new MockedResponse(mockedValue); + } + return new MockedResponse(new Uint8Array(mockedValue).buffer); + } + const modelFile = await worker.getModelFile({ url }); + const filePath = modelFile.ok[2]; + const fileHandle = await lazy.OPFS.getFileHandle(filePath); + const file = await fileHandle.getFile(); + let stream = file.stream(); + if (compression) { + const decompressionStream = new DecompressionStream("zstd"); + stream = stream.pipeThrough(decompressionStream); + } + return new Response(stream); + } + + const [tokenizerJsonResponse, npyDataResponse] = await Promise.all( + files.map(getResponse) + ); + + const npyData = await npyDataResponse.arrayBuffer(); + const tokenizerJson = await tokenizerJsonResponse.json(); + + let assetsLoad = ChromeUtils.now(); + ChromeUtils.addProfilerMarker( + "StaticEmbeddingsPipeline", + initializeStart, + "Assets load" + ); + const { PreTrainedTokenizer } = await lazy.importTransformers(backend); + const tokenizer = new PreTrainedTokenizer(tokenizerJson, {}); + ChromeUtils.addProfilerMarker( + "StaticEmbeddingsPipeline", + assetsLoad, + "Tokenizer load" + ); + + return new StaticEmbeddingsPipeline( + tokenizer, + npyData, + dtype, + dimensions, + 
initializeStart + ); + } + + /** + * @param {number} index + */ + #getFp8_E5M2 = index => { + const byte = this.embeddings[index]; + + // Do some bit manipulation to extract the sign (S), the exponent (E), and the + // mantissa (M) + // This is format: | S E E E | E E M M | + // To do the manipulation, shift the bits to the right (>>) and mask off the relevant + // bits with an & operation. + const sign = (byte >> 7) & 0b0000_0001; + const exponent = (byte >> 2) & 0b0001_1111; + const mantissa = byte & 0b0000_0011; + const bias = 15; + + if (exponent === 0) { + if (mantissa === 0) { + // Zero + return sign ? -0 : 0; + } + // Subnormal: exponent = 1 - bias, no implicit leading 1 + const frac = mantissa / 4; // 2 mantissa bits → divide by 2^2 + const value = frac * Math.pow(2, 1 - bias); + return sign ? -value : value; + } else if (exponent === 0x1f) { + if (mantissa === 0) { + return sign ? -Infinity : Infinity; + } + return NaN; + } + // Normalized + const frac = 1 + mantissa / 4; + const value = frac * Math.pow(2, exponent - bias); + return sign ? -value : value; + }; + + /** + * @param {number} index + */ + #getFp8_E4M3 = index => { + const byte = this.embeddings[index]; + + // Do some bit manipulation to extract the sign (S), the exponent (E), and the + // mantissa (M) + // This is format: | S E E E | E M M M | + // To do the manipulation, shift the bits to the right (>>) and mask off the relevant + // bits with an & operation. + const sign = (byte >> 7) & 0b0000_0001; + const exponent = (byte >> 3) & 0b0000_1111; + const mantissa = byte & 0b0000_0111; + const bias = 7; + + if (exponent === 0) { + if (mantissa === 0) { + return sign ? -0 : 0; + } + // Subnormal + const frac = mantissa / 8; // 3 mantissa bits → divide by 2^3 + const value = frac * Math.pow(2, 1 - bias); + return sign ? -value : value; + } else if (exponent === 0xf) { + if (mantissa === 0) { + return sign ? 
-Infinity : Infinity; + } + return NaN; + } + + // Normalized + const frac = 1 + mantissa / 8; + const value = frac * Math.pow(2, exponent - bias); + + return sign ? -value : value; + }; + + /** + * @param {EmbeddingRequest} request + * @param {number} _requestId + * @param {null} _engineRunOptions + * @returns {EmbeddingResponse} + */ + run(request, _requestId, _engineRunOptions) { + if (request.options.pooling != "mean") { + throw new Error( + `Only "mean" pooling is currently supported, please add support "${request.options.pooling}" here.` + ); + } + + let tokenCount = 0; + const sequenceCount = request.args.length; + + let beforeResponse = ChromeUtils.now(); + const response = { + metrics: [], + output: request.args.map(text => { + // Always do the vector math in f32 space, even if the underlying precision + // is lower. + const embedding = new Float32Array(this.#dimensions); + + /** @type {number[]} */ + const tokenIds = this.#tokenizer.encode(text); + tokenCount += tokenIds.length; + + // Sum up the embeddings. + for (const tokenId of tokenIds) { + for (let i = 0; i < this.#dimensions; i++) { + // Inflate the double into a JavaScript double, then add it. + embedding[i] += this.#getFloat(tokenId * this.#dimensions + i); + } + } + + if (request.options.normalize) { + // Compute the average by dividing by the tokens provided. + // Also compute the sum of the squares while we're here. + let sumSquares = 0; + for (let i = 0; i < this.#dimensions; i++) { + const n = embedding[i] / tokenIds.length; + embedding[i] = n; + sumSquares += n * n; + } + + // Apply the normalization. + const magnitude = Math.sqrt(sumSquares); + if (magnitude != 0) { + for (let i = 0; i < this.#dimensions; i++) { + embedding[i] = embedding[i] / magnitude; + } + } + } else { + // Only compute the average by dividing by the tokens provided. 
+ for (let i = 0; i < this.#dimensions; i++) { + embedding[i] = embedding[i] / tokenIds.length; + } + } + + return embedding; + }), + }; + + ChromeUtils.addProfilerMarker( + "StaticEmbeddingsPipeline", + beforeResponse, + `Processed ${sequenceCount} sequences with ${tokenCount} tokens.` + ); + + if (this.#initializeStart) { + ChromeUtils.addProfilerMarker( + "StaticEmbeddingsPipeline", + this.#initializeStart, + "Time to first response" + ); + this.#initializeStart = null; + } + + return response; + } +} diff --git a/toolkit/components/ml/jar.mn b/toolkit/components/ml/jar.mn @@ -25,6 +25,7 @@ toolkit.jar: content/global/ml/ThomSample.mjs (content/ThomSample.mjs) content/global/ml/backends/LlamaCppPipeline.mjs (content/backends/LlamaCppPipeline.mjs) content/global/ml/backends/OpenAIPipeline.mjs (content/backends/OpenAIPipeline.mjs) + content/global/ml/backends/StaticEmbeddingsPipeline.mjs (content/backends/StaticEmbeddingsPipeline.mjs) content/global/ml/openai.mjs (vendor/openai/dist/openai.mjs) #ifdef NIGHTLY_BUILD content/global/ml/ort.webgpu-dev.mjs (vendor/ort.webgpu-dev.mjs) diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml @@ -35,6 +35,8 @@ skip-if = [ ["browser_ml_privatebrowsing.js"] +["browser_ml_static_embeddings.js"] + ["browser_ml_telemetry.js"] skip-if = [ "verify" ] diff --git a/toolkit/components/ml/tests/browser/browser_ml_static_embeddings.js b/toolkit/components/ml/tests/browser/browser_ml_static_embeddings.js @@ -0,0 +1,109 @@ +/* Any copyright is dedicated to the Public Domain. 
+   https://creativecommons.org/publicdomain/zero/1.0/ */
+
+"use strict";
+
+/**
+ * @import { Request as EngineRequest, MLEngine as MLEngineClass } from "../../actors/MLEngineParent.sys.mjs"
+ * @import { StaticEmbeddingsOptions } from "../../content/backends/StaticEmbeddingsPipeline.d.ts"
+ */
+
+const { parseNpy } = ChromeUtils.importESModule(
+  "chrome://global/content/ml/Utils.sys.mjs"
+);
+
+const vocabSize = 9;
+const dimensions = 8;
+
+/**
+ * Mock out the URL requests with a small fake embeddings model.
+ */
+function getMockedValues() {
+  const { encoding } = generateFloat16Numpy(vocabSize, dimensions);
+  const tokenizer =
+    // prettier-ignore
+    {
+      version: "1.0",
+      truncation: null,
+      padding: null,
+      added_tokens: [{ id: 0, content: "[UNK]", single_word: false, lstrip: false, rstrip: false, normalized: false, special: true }],
+      normalizer: { type: "BertNormalizer", clean_text: true, handle_chinese_chars: true, strip_accents: null, lowercase: true },
+      pre_tokenizer: { type: "BertPreTokenizer" },
+      post_processor: {
+        type: "TemplateProcessing",
+        single: [
+          { SpecialToken: { id: "[CLS]", type_id: 0 } },
+          { Sequence: { id: "A", type_id: 0 } },
+          { SpecialToken: { id: "[SEP]", type_id: 0 } },
+        ],
+        pair: [],
+        special_tokens: {},
+      },
+      decoder: { type: "WordPiece", prefix: "##", cleanup: true },
+      model: {
+        type: "WordPiece", unk_token: "[UNK]", continuing_subword_prefix: "##", max_input_chars_per_word: 100,
+        vocab: { "[UNK]": 0, the: 1, quick: 2, brown: 3, dog: 4, jumped: 5, over: 6, lazy: 7, fox: 8 },
+      },
+    };
+
+  return {
+    "https://model-hub.mozilla.org/mozilla/static-embeddings/v1.0.0/models/minishlab/potion-retrieval-32M/tokenizer.json":
+      tokenizer,
+    [`https://model-hub.mozilla.org/mozilla/static-embeddings/v1.0.0/models/minishlab/potion-retrieval-32M/fp16.d${dimensions}.npy`]:
+      encoding,
+  };
+}
+
+add_task(async function test_static_embeddings() {
+  /** @type {StaticEmbeddingsOptions} */
+  const staticEmbeddingsOptions = {
+    dtype: "fp16",
+    subfolder: "models/minishlab/potion-retrieval-32M",
+    dimensions,
+    mockedValues: getMockedValues(),
+    compression: false,
+  };
+
+  /** @type {MLEngineClass} */
+  const engine = await createEngine(
+    new PipelineOptions({
+      featureId: "simple-text-embedder",
+      engineId: "test-static-embeddings",
+
+      modelId: "mozilla/static-embeddings",
+      modelRevision: "v1.0.0",
+      taskName: "static-embeddings",
+      modelHub: "mozilla",
+      backend: "static-embeddings",
+
+      staticEmbeddingsOptions,
+    })
+  );
+
+  const { output } = await engine.run({
+    args: ["The quick brown fox jumped over the lazy fox"],
+    options: {
+      pooling: "mean",
+      normalize: true,
+    },
+  });
+
+  is(output.length, 1, "One embedding was returned");
+  const [embedding] = output;
+  is(embedding.length, dimensions, "The dimensions match");
+  is(
+    embedding.constructor.name,
+    "Float32Array",
+    "The embedding was returned as a Float32Array"
+  );
+
+  assertFloatArraysMatch(
+    embedding,
+    [
+      0.3156551122, 0.3262447714, 0.3368626534, 0.3474076688, 0.3580137789,
+      0.3685869872, 0.3791790008, 0.3898085951,
+    ],
+    "The embeddings were computed as expected.",
+    0.00001 // epsilon
+  );
+});
diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js
@@ -729,9 +729,41 @@ async function perfTest({
 
 /**
  * Measures floating point value within epsilon tolerance
+ * NOTE(review): compares |a| to |b|, so opposite-sign values (e.g. 1 and -1)
+ * @param {number} a
+ * @param {number} b
+ * @param {number} [epsilon]
+ * @returns {boolean}
+ */
+function isEqualWithTolerance(a, b, epsilon = 0.000001) {
+  return Math.abs(Math.abs(a) - Math.abs(b)) < epsilon;
+}
+
+/**
+ * Asserts whether two float arrays are equal within epsilon tolerance.
+ *
+ * @param {number[] | Float32Array} a
+ * @param {number[] | Float32Array} b
+ * @param {string} message
+ * @param {number} [epsilon]
  */
-function isEqualWithTolerance(A, B, epsilon = 0.000001) {
-  return Math.abs(Math.abs(A) - Math.abs(B)) < epsilon;
+function assertFloatArraysMatch(a, b, message, epsilon) {
+  const raise = () => {
+    // When logging errors, spread into a new array so that the logging is nice
+    // for typed-array values. This makes it easy to see how the arrays differ.
+    console.log("a:", [...a]);
+    console.log("b:", [...b]);
+    throw new Error(message);
+  };
+  if (a.length !== b.length) {
+    raise();
+  }
+  for (let i = 0; i < a.length; i++) {
+    if (!isEqualWithTolerance(a[i], b[i], epsilon)) {
+      raise();
+    }
+  }
+  ok(true, message);
 }
 
 // Mock OpenAI Chat Completions server for mochitests
diff --git a/tools/@types/generated/lib.gecko.modules.d.ts b/tools/@types/generated/lib.gecko.modules.d.ts
@@ -38,6 +38,7 @@ export interface Modules {
   "chrome://global/content/ml/backends/ONNXPipeline.mjs": typeof import("chrome://global/content/ml/backends/ONNXPipeline.mjs"),
   "chrome://global/content/ml/backends/OpenAIPipeline.mjs": typeof import("chrome://global/content/ml/backends/OpenAIPipeline.mjs"),
   "chrome://global/content/ml/backends/Pipeline.mjs": typeof import("chrome://global/content/ml/backends/Pipeline.mjs"),
+  "chrome://global/content/ml/backends/StaticEmbeddingsPipeline.mjs": typeof import("chrome://global/content/ml/backends/StaticEmbeddingsPipeline.mjs"),
   "chrome://global/content/preferences/Preferences.mjs": typeof import("chrome://global/content/preferences/Preferences.mjs"),
   "chrome://global/content/translations/TranslationsTelemetry.sys.mjs": typeof import("chrome://global/content/translations/TranslationsTelemetry.sys.mjs"),
   "chrome://global/content/translations/TranslationsUtils.mjs": typeof import("chrome://global/content/translations/TranslationsUtils.mjs"),
diff --git a/tools/@types/generated/tspaths.json
b/tools/@types/generated/tspaths.json @@ -308,6 +308,9 @@ "chrome://global/content/ml/backends/Pipeline.mjs": [ "toolkit/components/ml/content/backends/Pipeline.mjs" ], + "chrome://global/content/ml/backends/StaticEmbeddingsPipeline.mjs": [ + "toolkit/components/ml/content/backends/StaticEmbeddingsPipeline.mjs" + ], "chrome://global/content/ml/openai-dev.mjs": [ "toolkit/components/ml/vendor/openai/dist/openai-dev.mjs" ],