commit d2c64cf7da4818155aabf5fca53c686bda14ff69
parent cd57b348e07bd562cfa6aa5f12b7c16baf31e7fd
Author: Marco Bonardo <mbonardo@mozilla.com>
Date: Wed, 22 Oct 2025 08:37:09 +0000
Bug 1992265 - Use static embeddings in semantic search. r=cgopal,places-reviewers,Standard8
Differential Revision: https://phabricator.services.mozilla.com/D268671
Diffstat:
7 files changed, 242 insertions(+), 72 deletions(-)
diff --git a/browser/components/urlbar/UrlbarProviderSemanticHistorySearch.sys.mjs b/browser/components/urlbar/UrlbarProviderSemanticHistorySearch.sys.mjs
@@ -42,7 +42,6 @@ ChromeUtils.defineLazyGetter(lazy, "semanticManager", function () {
0.6
);
return getPlacesSemanticHistoryManager({
- embeddingSize: 384,
rowLimit: 10000,
samplingAttrib: "frecency",
changeThresholdCount: 3,
diff --git a/toolkit/components/ml/content/nlp/EmbeddingsGenerator.sys.mjs b/toolkit/components/ml/content/nlp/EmbeddingsGenerator.sys.mjs
@@ -44,31 +44,80 @@ ChromeUtils.defineLazyGetter(lazy, "console", () => {
const REQUIRED_MEMORY_BYTES = 7 * 1024 * 1024 * 1024;
const REQUIRED_CPU_CORES = 2;
+const staticEmbeddingsOptions = {
+ // See https://huggingface.co/Mozilla/static-embeddings/blob/main/models/minishlab/potion-retrieval-32M/README.md
+ subfolder: "models/minishlab/potion-retrieval-32M",
+ // Available: fp32, fp16, fp8_e5m2, fp8_e4m3
+ dtype: "fp16",
+ // Avalable dimsensions: 32, 64, 128, 256, 512
+ dimensions: 256,
+ // Use zstd compression, probably set it to true.
+ compression: true,
+};
+
/**
*
*/
export class EmbeddingsGenerator {
#engine = undefined;
#promiseEngine;
- #embeddingSize = 384;
- options = {
- taskName: "feature-extraction",
- featureId: "simple-text-embedder",
- timeoutMS: -1,
- numThreads: 2,
- backend: "onnx-native",
- };
- // wasm as fallback
- optionsFallback = {
- taskName: "feature-extraction",
- featureId: "simple-text-embedder",
- timeoutMS: -1,
- numThreads: 2,
- backend: "onnx",
- };
+ #embeddingSize;
+ options;
+ #optionsByEngine = new Map([
+ [
+ "onnx-native",
+ {
+ taskName: "feature-extraction",
+ featureId: "simple-text-embedder",
+ timeoutMS: -1,
+ numThreads: 2,
+ backend: "onnx-native",
+
+ supportedDimensions: [384],
+ fallbackEngine: "onnx-wasm",
+ },
+ ],
+ [
+ "onnx-wasm",
+ {
+ taskName: "feature-extraction",
+ featureId: "simple-text-embedder",
+ timeoutMS: -1,
+ numThreads: 2,
+ backend: "onnx",
+
+ supportedDimensions: [384],
+ },
+ ],
+ [
+ "static-embeddings",
+ {
+ featureId: "simple-text-embedder",
+ modelId: "mozilla/static-embeddings",
+ modelRevision: "v1.0.0",
+ taskName: "static-embeddings",
+ modelHub: "mozilla",
+ backend: "static-embeddings",
+ staticEmbeddingsOptions,
- constructor(embeddingSize = 384) {
+ supportedDimensions: [32, 64, 128, 256, 512],
+ setDimensions(embeddingSize) {
+ this.staticEmbeddingsOptions.dimensions = embeddingSize;
+ },
+ },
+ ],
+ ]);
+
+ constructor({ backend = "static-embeddings", embeddingSize = 256 } = {}) {
this.#embeddingSize = embeddingSize;
+ this.options = this.#optionsByEngine.get(backend);
+ if (!this.options) {
+ throw new TypeError("Unsupported embedding engine");
+ }
+ if (!this.options.supportedDimensions.includes(embeddingSize)) {
+ throw new TypeError("Unsupported embedding size");
+ }
+ this.options.setDimensions?.(embeddingSize);
}
/**
@@ -130,19 +179,23 @@ export class EmbeddingsGenerator {
this.#engine = await lazy.createEngine(this.options);
} catch (ex) {
lazy.console.warn(
- "Native engine init failed. Falling back to wasm. Error:" + ex
+ `Engine ${this.options.backend} init failed. Falling back to wasm. Error:` +
+ ex
);
- // Fallback to wasm
- if (this.optionsFallback) {
+ // Use a fallback engine if available.
+ if (this.options.fallbackEngine) {
+ let options = this.#optionsByEngine.get(this.options.fallbackEngine);
+ options.setDimensions?.(this.#embeddingSize);
try {
- this.#engine = await lazy.createEngine(this.optionsFallback);
+ this.#engine = await lazy.createEngine(options);
} catch (fallbackEx) {
lazy.console.error(
- "Fallback engine also failed. Error:" + fallbackEx
+ `Fallback engine ${options.backend} also failed. Error:` +
+ fallbackEx
);
throw new Error(
- "Unable to initialize the ML engine (including fallback wasm).",
+ "Unable to initialize the ML engine (including fallback).",
{ cause: fallbackEx }
);
}
@@ -221,7 +274,7 @@ export class EmbeddingsGenerator {
// call the engine once with the batch of texts.
let batchTensors = await this.engineRun({
- args: [texts],
+ args: this.options.backend == "static-embeddings" ? texts : [texts],
options: { pooling: "mean", normalize: true, max_length: 100 },
});
diff --git a/toolkit/components/ml/tests/browser/browser_ml_embeddings_generator.js b/toolkit/components/ml/tests/browser/browser_ml_embeddings_generator.js
@@ -3,7 +3,7 @@
* You can obtain one at http://mozilla.org/MPL/2.0/. */
/**
- * Test for EmbeddingGenerator.sys.mjs
+ * Test for EmbeddingsGenerator.sys.mjs
*/
"use strict";
@@ -13,6 +13,8 @@ ChromeUtils.defineESModuleGetters(this, {
sinon: "resource://testing-common/Sinon.sys.mjs",
});
+const EMBEDDING_SIZE = 256;
+
async function setup() {
const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({
autoDownloadFromRemoteSettings: false,
@@ -59,13 +61,13 @@ add_task(async function test_EmbeddingsGenerator_for_minimum_cpu_cores() {
class MockMLEngineForEmbedMany {
async run(request) {
- const texts = request.args[0];
+ const texts = request.args;
return texts.map(text => {
if (typeof text !== "string" || text.trim() === "") {
throw new Error("Invalid input: text must be a non-empty string");
}
// Return a mock embedding vector (e.g., an array of zeros)
- return Array(384).fill(0);
+ return Array(EMBEDDING_SIZE).fill(0);
});
}
}
@@ -83,7 +85,7 @@ add_task(async function test_embedMany_valid_inputs() {
Assert.ok(Array.isArray(result), "Result should be an array");
Assert.equal(result.length, 2, "Should return 2 embeddings");
for (const vector of result) {
- Assert.equal(vector.length, 384, "Each embedding should be of size 384");
+ Assert.equal(vector.length, EMBEDDING_SIZE, "Check embeddings dimension");
}
sinon.restore();
@@ -160,7 +162,7 @@ class MockMLEngineForEmbed {
throw new Error("Invalid input: text must be a non-empty string");
}
// Return a mock embedding vector (e.g., an array of zeros)
- return Array(384).fill(0);
+ return Array(EMBEDDING_SIZE).fill(0);
});
}
}
@@ -175,7 +177,7 @@ add_task(async function test_embed_valid_input() {
const result = await embeddingsGenerator.embed("test string");
Assert.ok(Array.isArray(result), "Embedding result should be an array");
- Assert.equal(result[0].length, 384, "Embedding should be of size 384");
+ Assert.equal(result[0].length, EMBEDDING_SIZE, "Check embedding dimension");
sinon.restore();
});
@@ -202,12 +204,25 @@ add_task(async function test_embed_invalid_input_empty_string() {
sinon.restore();
});
-add_task(async function test_default_backend_is_onnx_native() {
+add_task(async function test_default_backend_is_static_emebddings() {
const embeddingsGenerator = new EmbeddingsGenerator();
Assert.equal(
embeddingsGenerator.options.backend,
+ "static-embeddings",
+ "Check default backend"
+ );
+});
+
+add_task(async function test_onnx() {
+ const embeddingsGenerator = new EmbeddingsGenerator({
+ backend: "onnx-native",
+ embeddingSize: 384,
+ });
+
+ Assert.equal(
+ embeddingsGenerator.options.backend,
"onnx-native",
- "Default backend should be onnx-native"
+ "Check other backend"
);
});
diff --git a/toolkit/components/places/PlacesSemanticHistoryManager.sys.mjs b/toolkit/components/places/PlacesSemanticHistoryManager.sys.mjs
@@ -73,10 +73,22 @@ class PlacesSemanticHistoryManager {
#lastMaxChunksCount = 0;
/**
+ * Checks if a value is an array or a typed array.
+ *
+ * @param {Array|ArrayBufferView} val
+ * @returns {boolean} Whether the input is like an array.
+ */
+ #isArrayLike(val) {
+ return Array.isArray(val) || ArrayBuffer.isView(val);
+ }
+
+ /**
* Constructor for PlacesSemanticHistoryManager.
*
* @param {object} options - Configuration options.
- * @param {number} [options.embeddingSize=384] - Size of embeddings used for vector operations.
+ * @param {string} [options.backend] - The backend to use for embeddings.
+ * See EmbeddingsGenerator.sys.mjs for a list of available backends.
+ * @param {number} [options.embeddingSize=512] - Size of embeddings used for vector operations.
* @param {number} [options.rowLimit=10000] - Maximum number of rows to process from the database.
* @param {string} [options.samplingAttrib="frecency"] - Attribute used for sampling rows.
* @param {number} [options.changeThresholdCount=3] - Threshold of changed rows to trigger updates.
@@ -84,7 +96,8 @@ class PlacesSemanticHistoryManager {
* @param {boolean} [options.testFlag=false] - Flag for test behavior.
*/
constructor({
- embeddingSize = 384,
+ backend = "static-embeddings",
+ embeddingSize = 512,
rowLimit = 10000,
samplingAttrib = "frecency",
changeThresholdCount = 3,
@@ -105,7 +118,10 @@ class PlacesSemanticHistoryManager {
this.#finalized = true;
return;
}
- this.embedder = new lazy.EmbeddingsGenerator(embeddingSize);
+ this.embedder = new lazy.EmbeddingsGenerator({
+ backend,
+ embeddingSize,
+ });
this.semanticDB = new lazy.PlacesSemanticHistoryDatabase({
embeddingSize,
fileName: "places_semantic.sqlite",
@@ -576,13 +592,16 @@ class PlacesSemanticHistoryManager {
if (rowsToAdd.length) {
// Instead of calling engineRun in a loop for each row,
// you prepare an array of requests.
- batchTensors = await this.embedder.embedMany(
- rowsToAdd.map(r => r.content)
- );
- if (batchTensors.length != rowsToAdd.length) {
- throw new Error(
- `Expected ${rowsToAdd.length} tensors, got ${batchTensors.length}`
+ try {
+ batchTensors = await this.embedder.embedMany(
+ rowsToAdd.map(r => r.content)
);
+ batchTensors = this.#convertTensor(batchTensors, rowsToAdd.length);
+ } catch (ex) {
+ lazy.logger.error(`Error processing tensors: ${ex}`);
+ // If we failed generating tensors skip the addition, but proceed
+ // with removals below.
+ rowsToAdd.clear();
}
}
@@ -590,15 +609,8 @@ class PlacesSemanticHistoryManager {
// Process each new row and the corresponding tensor.
for (let i = 0; i < rowsToAdd.length; i++) {
const { url_hash } = rowsToAdd[i];
- const tensor = batchTensors[i];
+ const tensor = batchTensors.values[i];
try {
- if (!Array.isArray(tensor) || tensor.length !== this.#embeddingSize) {
- lazy.logger.error(
- `Got tensor with invalid length: ${Array.isArray(tensor) ? tensor.length : "non-array value"}`
- );
- continue;
- }
-
// We first insert the url into vec_history_mapping, get the rowid
// and then insert the embedding into vec_history using that.
// Doing the opposite doesn't work, as RETURNING is not properly
@@ -752,27 +764,18 @@ class PlacesSemanticHistoryManager {
const inferStartTime = ChromeUtils.now();
let results = [];
await this.embedder.ensureEngine();
- let tensor = await this.embedder.embed(queryContext.searchString);
- if (!tensor) {
+ let tensor;
+ try {
+ tensor = await this.embedder.embed(queryContext.searchString);
+ tensor = this.#convertTensor(tensor, 1);
+ } catch (ex) {
+ lazy.logger.error(`Error processing tensor: ${ex}`);
return results;
}
let metrics = tensor.metrics;
- // If tensor is a nested array with a single element, extract the inner array.
- if (
- Array.isArray(tensor) &&
- tensor.length === 1 &&
- Array.isArray(tensor[0])
- ) {
- tensor = tensor[0];
- }
-
- if (!Array.isArray(tensor) || tensor.length !== this.#embeddingSize) {
- lazy.logger.info(`Got tensor with length ${tensor.length}`);
- return results;
- }
let conn = await this.getConnection();
let rows = await conn.executeCached(
@@ -805,7 +808,7 @@ class PlacesSemanticHistoryManager {
ORDER BY distance
`,
{
- vector: lazy.PlacesUtils.tensorToSQLBindable(tensor),
+ vector: lazy.PlacesUtils.tensorToSQLBindable(tensor.values[0]),
distanceThreshold: this.#distanceThreshold,
}
);
@@ -837,6 +840,63 @@ class PlacesSemanticHistoryManager {
}
/**
+ * Converts result of an engine run into a consistent structure.
+ *
+ * @param {Array|object} tensor
+ * @param {number} expectedLength
+ * @returns {{ metrics: object, values: Array <Array|Float32Array>[]}}
+ */
+ #convertTensor(tensor, expectedLength) {
+ if (!tensor) {
+ throw new Error("Unexpected empty tensor");
+ }
+ let result = { metrics: tensor?.metrics ?? null, values: [] };
+ if (expectedLength == 0) {
+ return result;
+ }
+
+ // It may be a { metrics, output } object.
+ if (tensor.output) {
+ if (Array.isArray(tensor.output) && this.#isArrayLike(tensor.output[0])) {
+ result.values = tensor.output;
+ } else {
+ result.values.push(tensor.output);
+ }
+ } else {
+ // It may be a nested array, then we must extract it first.
+ if (
+ Array.isArray(tensor) &&
+ tensor.length === 1 &&
+ Array.isArray(tensor[0])
+ ) {
+ tensor = tensor[0];
+ }
+
+ // Then we check if it's an array of arrays or just a single value.
+ if (Array.isArray(tensor) && this.#isArrayLike(tensor[0])) {
+ result.values = tensor;
+ } else {
+ result.values.push(tensor);
+ }
+ }
+
+ if (result.values.length != expectedLength) {
+ throw new Error(
+ `Got ${result.values.length} embeddings instead of ${expectedLength}`
+ );
+ }
+ if (
+ !this.#isArrayLike(result.values[0]) ||
+ result.values[0].length != this.#embeddingSize
+ ) {
+ throw new Error(
+ `Got tensors with dimension ${tensor.values[0].length} instead of ${this.#embeddingSize}`
+ );
+ }
+ return result;
+ }
+
+ /**
* Performs a WAL checkpoint to flush all pending writes from WAL to the main database file.
* Then measures the final disk size of semantic.sqlite.
* **This method is for test purposes only.**
diff --git a/toolkit/components/places/PlacesUtils.sys.mjs b/toolkit/components/places/PlacesUtils.sys.mjs
@@ -2147,14 +2147,19 @@ export var PlacesUtils = {
/**
* Converts an array of Float32 into a SQL bindable blob format.
*
- * @param {Array<number>} tensor
+ * @param {Array<number>|Float32Array} tensor
* @returns {Uint8ClampedArray} SQL bindable blob.
*/
tensorToSQLBindable(tensor) {
- if (!tensor || !Array.isArray(tensor)) {
+ if (!tensor) {
+ throw new Error("tensorToSQLBindable received an invalid tensor");
+ } else if (Array.isArray(tensor)) {
+ return new Uint8ClampedArray(new Float32Array(tensor).buffer);
+ } else if (tensor instanceof Float32Array) {
+ return new Uint8ClampedArray(tensor.buffer);
+ } else {
throw new Error("tensorToSQLBindable received an invalid tensor");
}
- return new Uint8ClampedArray(new Float32Array(tensor).buffer);
},
/**
diff --git a/toolkit/components/places/tests/unit/test_PlacesSemanticHistoryDatabase.js b/toolkit/components/places/tests/unit/test_PlacesSemanticHistoryDatabase.js
@@ -98,6 +98,43 @@ add_task(async function test_corruptdb() {
await db.removeDatabaseFiles();
});
+add_task(async function test_differentDimensionsReplacesDatabase() {
+ let db = new PlacesSemanticHistoryDatabase({
+ embeddingSize: EMBEDDING_SIZE,
+ fileName: "places_semantic.sqlite",
+ });
+ let conn = await db.getConnection();
+ Assert.ok(
+ !!(
+ await conn.execute(
+ `SELECT INSTR(sql, :needle) > 0
+ FROM sqlite_master WHERE name = 'vec_history'`,
+ { needle: `FLOAT[${EMBEDDING_SIZE}]` }
+ )
+ )[0].getResultByIndex(0),
+ "Check embeddings size for the table"
+ );
+ await db.closeConnection();
+
+ db = new PlacesSemanticHistoryDatabase({
+ embeddingSize: EMBEDDING_SIZE + 16,
+ fileName: "places_semantic.sqlite",
+ });
+ conn = await db.getConnection();
+ Assert.ok(
+ !!(
+ await conn.execute(
+ `SELECT INSTR(sql, :needle) > 0
+ FROM sqlite_master WHERE name = 'vec_history'`,
+ { needle: `FLOAT[${EMBEDDING_SIZE + 16}]` }
+ )
+ )[0].getResultByIndex(0),
+ "Check that the database was replaced"
+ );
+ await db.closeConnection();
+ await db.removeDatabaseFiles();
+});
+
add_task(async function test_healthydb() {
let db = new PlacesSemanticHistoryDatabase({
embeddingSize: EMBEDDING_SIZE,
diff --git a/toolkit/components/places/tests/unit/test_PlacesSemanticHistoryManager.js b/toolkit/components/places/tests/unit/test_PlacesSemanticHistoryManager.js
@@ -10,8 +10,9 @@ ChromeUtils.defineESModuleGetters(this, {
"resource://gre/modules/PlacesSemanticHistoryManager.sys.mjs",
});
-// Must be divisible by 8.
-const EMBEDDING_SIZE = 16;
+// Must be supported, and a multiple of 8. see EmbeddingsGenerator.sys.mjs for
+// a list of supported values.
+const EMBEDDING_SIZE = 32;
function approxEqual(a, b, tolerance = 1e-6) {
return Math.abs(a - b) < tolerance;
@@ -43,7 +44,7 @@ class MockMLEngine {
}
async run(request) {
- const texts = request.args[0];
+ const texts = request.args;
return texts.map(text => {
if (typeof text !== "string" || text.trim() === "") {
throw new Error("Invalid input: text must be a non-empty string");