MLSuggest.sys.mjs (14959B)
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */

/**
 * MLSuggest helps with ML based suggestions around intents and location.
 */

import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";

// Both modules are resolved lazily, on first property access.
const lazy = XPCOMUtils.declareLazy({
  createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
  UrlbarPrefs: "moz-src:///browser/components/urlbar/UrlbarPrefs.sys.mjs",
});

/**
 * @import {EngineRunRequest, EngineRunResponse, MLEntry} from "moz-src:///toolkit/components/ml/actors/MLEngineParent.sys.mjs"
 * @typedef {Parameters<typeof lazy.createEngine>[0]} MLEngineOptions
 * @typedef {Awaited<ReturnType<typeof lazy.createEngine>>} MLEngine
 */

// List of prepositions used in subject cleaning: trailing occurrences are
// dropped from the subject once the location has been removed.
const PREPOSITIONS = ["in", "at", "on", "for", "to", "near"];

// Queries longer than this are rejected before any model is run.
const MAX_QUERY_LENGTH = 200;

// Punctuation that is glued onto the preceding NER token instead of starting
// a new one, and that is stripped when it ends a detected city or state.
const NAME_PUNCTUATION = [".", "-", "'"];

// Same list without the dot: a token following one of these characters is
// appended to the previous token, a token following a dot is not.
const NAME_PUNCTUATION_EXCEPT_DOT = NAME_PUNCTUATION.filter(p => p !== ".");
/**
 * Class for handling ML-based suggestions using intent and NER models.
 *
 * Engines are obtained through `createEngine`, cached per feature id, and
 * re-created in the background whenever a run fails.
 *
 * @class
 */
class _MLSuggest {
  /**
   * Successfully created engines, keyed by feature id.
   *
   * @type {Map<string, MLEngine>}
   */
  #modelEngines = new Map();

  /**
   * Engine creations that are still in flight, keyed by feature id. Lets
   * concurrent initialization requests share one creation instead of each
   * building their own engine.
   *
   * @type {Map<string, Promise<?MLEngine>>}
   */
  #pendingEngines = new Map();

  // Preferred options for the intent classifier ("onnx-native" backend).
  INTENT_OPTIONS = {
    taskName: "text-classification",
    featureId: "suggest-intent-classification",
    timeoutMS: -1,
    numThreads: 2,
    backend: "onnx-native",
  };

  // Same as INTENT_OPTIONS but on the "onnx" backend; tried when creating the
  // preferred engine fails.
  INTENT_OPTIONS_FALLBACK = {
    taskName: "text-classification",
    featureId: "suggest-intent-classification",
    timeoutMS: -1,
    numThreads: 2,
    backend: "onnx",
  };

  // Preferred options for the NER (token classification) model.
  NER_OPTIONS = {
    taskName: "token-classification",
    featureId: "suggest-NER",
    timeoutMS: -1,
    numThreads: 2,
    backend: "onnx-native",
  };

  // Same as NER_OPTIONS but on the "onnx" backend; tried when creating the
  // preferred engine fails.
  NER_OPTIONS_FALLBACK = {
    taskName: "token-classification",
    featureId: "suggest-NER",
    timeoutMS: -1,
    numThreads: 2,
    backend: "onnx",
  };

  /**
   * Helper to wrap createEngine for testing purposes.
   *
   * @param {MLEngineOptions} options
   *   Configuration options for the ML engine.
   */
  createEngine(options) {
    return lazy.createEngine(options);
  }

  /**
   * Initializes the intent and NER models.
   *
   * Creation failures are not reported: an engine that could not be created
   * is simply absent, and queries return null until a later attempt succeeds.
   */
  async initialize() {
    await Promise.all([
      this.#initializeModelEngine(
        this.INTENT_OPTIONS,
        this.INTENT_OPTIONS_FALLBACK
      ),
      this.#initializeModelEngine(this.NER_OPTIONS, this.NER_OPTIONS_FALLBACK),
    ]);
  }

  /**
   * @typedef {object} MLSuggestResult
   * @property {string} intent
   *   The predicted intent label of the query. Possible values include:
   *   - 'information_intent': For queries seeking general information.
   *   - 'yelp_intent': For queries related to local businesses or services.
   *   - 'navigation_intent': For queries with navigation-related actions.
   *   - 'travel_intent': For queries showing travel-related interests.
   *   - 'purchase_intent': For queries with purchase or shopping intent.
   *   - 'weather_intent': For queries asking about weather or forecasts.
   *   - 'translation_intent': For queries seeking translations.
   *   - 'unknown': When the intent cannot be classified with confidence.
   *   - '' (empty string): Returned when model probabilities for all intents
   *     are below the intent threshold.
   * @property {?{city: ?string, state: ?string}} location
   *   The detected location from the query.
   * @property {string} subject
   *   The subject of the query after location is removed.
   * @property {{intent: object, ner: object}} metrics
   *   The combined metrics from NER model results, representing additional
   *   information about the model's performance.
   */

  /**
   * Generates ML-based suggestions by finding intent, detecting entities, and
   * combining locations.
   *
   * @param {string} query
   *   The user's input query.
   * @returns {Promise<?MLSuggestResult>}
   *   The suggestion result including intent, location, and subject, or null if
   *   an error occurs, the query is not a string, or query length >
   *   MAX_QUERY_LENGTH
   */
  async makeSuggestions(query) {
    // Reject non-strings up front (they used to throw on `.length`) and avoid
    // a bunch of work for very long strings.
    if (typeof query !== "string" || query.length > MAX_QUERY_LENGTH) {
      return null;
    }

    let intentRes, nerResult;
    try {
      [intentRes, nerResult] = await Promise.all([
        this._findIntent(query),
        this._findNER(query),
      ]);
    } catch {
      // The finders report failures as null themselves; this only triggers if
      // they have been replaced (e.g. stubbed in tests) by something that
      // rejects.
      return null;
    }

    if (!intentRes || !nerResult) {
      return null;
    }

    const locationResVal = this.#combineLocations(
      nerResult,
      lazy.UrlbarPrefs.get("nerThreshold")
    );

    const intentLabel = this.#applyIntentThreshold(
      intentRes,
      lazy.UrlbarPrefs.get("intentThreshold")
    );

    return {
      intent: intentLabel,
      location: locationResVal,
      subject: this.#findSubjectFromQuery(query, locationResVal),
      metrics: { intent: intentRes.metrics, ner: nerResult.metrics },
    };
  }

  /**
   * Shuts down all initialized engines.
   *
   * Every cached engine is terminated even if an earlier one fails. When at
   * least one termination fails, the first error is rethrown after all
   * engines have been processed.
   */
  async shutdown() {
    // Detach the engines first so that no query started during shutdown can
    // pick up an engine that is being terminated.
    const engines = [...this.#modelEngines.values()];
    this.#modelEngines.clear();

    const errors = [];
    for (const engine of engines) {
      try {
        await engine?.terminate?.();
      } catch (error) {
        errors.push(error);
      }
    }

    if (errors.length) {
      throw errors[0];
    }
  }

  /**
   * Initializes an engine model, reusing the cached engine or an in-flight
   * creation for the same feature when there is one.
   *
   * @param {MLEngineOptions} options
   *   Configuration options for the ML engine.
   * @param {MLEngineOptions} [fallbackOptions]
   *   Fallback options if creating with the main options fails.
   * @returns {Promise<?MLEngine>}
   *   The engine, or a nullish value when it could not be created. Never
   *   rejects.
   */
  async #initializeModelEngine(options, fallbackOptions = null) {
    const { featureId } = options;

    // uses cache if engine was used
    const cachedEngine = this.#modelEngines.get(featureId);
    if (cachedEngine) {
      return cachedEngine;
    }

    const pendingCreation = this.#pendingEngines.get(featureId);
    if (pendingCreation) {
      return pendingCreation;
    }

    const creation = (async () => {
      try {
        const engine = await this.#createEngineWithFallback(
          options,
          fallbackOptions
        );
        // Only cache real engines: caching a failed creation used to make
        // shutdown() throw on the missing engine.
        if (engine) {
          this.#modelEngines.set(featureId, engine);
        }
        return engine;
      } finally {
        this.#pendingEngines.delete(featureId);
      }
    })();
    this.#pendingEngines.set(featureId, creation);
    return creation;
  }

  /**
   * Creates an engine with the main options, retrying once with the fallback
   * options when that fails.
   *
   * @param {MLEngineOptions} options
   *   Configuration options for the ML engine.
   * @param {?MLEngineOptions} fallbackOptions
   *   Fallback options, or null to make a single attempt.
   * @returns {Promise<?MLEngine>}
   *   The engine, or null when every attempt failed.
   */
  async #createEngineWithFallback(options, fallbackOptions) {
    try {
      return await this.createEngine(options);
    } catch {
      // Best effort: fall through to the fallback options, if any.
    }

    if (!fallbackOptions) {
      return null;
    }

    try {
      return await this.createEngine(fallbackOptions);
    } catch {
      // Best effort: a missing engine makes the finders return null.
      return null;
    }
  }

  /**
   * Runs the cached engine of a feature on the query. When the run fails the
   * engine is dropped from the cache and re-created in the background.
   *
   * @param {MLEngineOptions} engineOptions
   *   Main options of the engine; its `featureId` selects the cached engine.
   * @param {MLEngineOptions} fallbackOptions
   *   Fallback options used when the engine has to be re-created.
   * @param {string} query
   *   The user's input query.
   * @param {Omit<EngineRunRequest, "args">} options
   *   The options for the engine pipeline
   * @returns {Promise<?EngineRunResponse>}
   *   The engine response, or null when there is no engine or the run failed.
   */
  async #runEngine(engineOptions, fallbackOptions, query, options) {
    const { featureId } = engineOptions;
    const engine = this.#modelEngines.get(featureId);
    if (!engine) {
      return null;
    }

    try {
      // `await` is required here: returning the bare promise would let a
      // rejection bypass the catch block below.
      return await engine.run({ args: [query], options });
    } catch {
      // engine could timeout or fail, so remove that from cache and
      // reinitialize. Only evict the engine that actually failed: another
      // failed run may already have replaced it with a fresh one.
      if (this.#modelEngines.get(featureId) === engine) {
        this.#modelEngines.delete(featureId);
      }
      // Deliberately not awaited so the caller is not blocked on engine
      // creation; #initializeModelEngine never rejects.
      this.#initializeModelEngine(engineOptions, fallbackOptions);
      return null;
    }
  }

  /**
   * Finds the intent of the query using the intent classification model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {Omit<EngineRunRequest, "args">} [options]
   *   The options for the engine pipeline
   * @returns {Promise<?EngineRunResponse>}
   *   The model response, or null when the engine is missing or the run
   *   failed.
   */
  async _findIntent(query, options = {}) {
    return this.#runEngine(
      this.INTENT_OPTIONS,
      this.INTENT_OPTIONS_FALLBACK,
      query,
      options
    );
  }

  /**
   * Finds named entities in the query using the NER model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {Omit<EngineRunRequest, "args">} [options]
   *   The options for the engine pipeline
   * @returns {Promise<?EngineRunResponse>}
   *   The model response, or null when the engine is missing or the run
   *   failed.
   */
  async _findNER(query, options = {}) {
    return this.#runEngine(
      this.NER_OPTIONS,
      this.NER_OPTIONS_FALLBACK,
      query,
      options
    );
  }

  /**
   * Applies a confidence threshold to determine the intent label.
   *
   * If the score of the first intent in the result exceeds the threshold, its
   * label is returned; otherwise, the label is the empty string.
   *
   * @param {EngineRunResponse} intentResult
   *   The result of the intent classification model, where each item includes
   *   a `label` and `score`.
   * @param {number} intentThreshold
   *   The confidence threshold for accepting the intent label.
   * @returns {string}
   *   The determined intent label or '' if the threshold is not met.
   */
  #applyIntentThreshold(intentResult, intentThreshold) {
    const topIntent = intentResult[0];
    return topIntent?.score > intentThreshold ? topIntent.label : "";
  }

  /**
   * Combines location tokens detected by NER into separate city and state
   * components. This method processes city, state, and combined city-state
   * entities, returning an object with `city` and `state` fields.
   *
   * Handles the following entity types:
   * - B-CITY, I-CITY: Identifies city tokens.
   * - B-STATE, I-STATE: Identifies state tokens.
   * - B-CITYSTATE, I-CITYSTATE: Identifies tokens that represent a combined
   *   city and state.
   *
   * @param {EngineRunResponse} nerResult
   *   The NER results containing tokens and their corresponding entity labels.
   * @param {number} nerThreshold
   *   The confidence threshold for including entities. Tokens with a confidence
   *   score at or below this threshold will be ignored.
   * @returns {{city: ?string, state: ?string}}
   *   The detected city and state, each null when nothing was detected.
   */
  #combineLocations(nerResult, nerThreshold) {
    let cityTokens = [];
    let stateTokens = [];
    const cityStateTokens = [];

    const bucketByEntity = new Map([
      ["B-CITY", cityTokens],
      ["I-CITY", cityTokens],
      ["B-STATE", stateTokens],
      ["I-STATE", stateTokens],
      ["B-CITYSTATE", cityStateTokens],
      ["I-CITYSTATE", cityStateTokens],
    ]);

    for (const token of Array.from(nerResult)) {
      const bucket = bucketByEntity.get(token.entity);
      if (bucket) {
        this.#processNERToken(token, bucket, nerThreshold);
      }
    }

    // Handle city_state as combined and split into city and state. Only used
    // when no standalone city or state was detected; anything after a second
    // comma is ignored.
    if (cityStateTokens.length && !cityTokens.length && !stateTokens.length) {
      const [cityPart = "", statePart = ""] = cityStateTokens
        .join(" ")
        .split(",")
        .map(part => part.trim());
      cityTokens = cityPart ? [cityPart] : [];
      stateTokens = statePart ? [statePart] : [];
    }

    // Remove trailing punctuation from the last element if present
    this.#removePunctFromEndIfPresent(cityTokens);
    this.#removePunctFromEndIfPresent(stateTokens);

    // Return city and state as separate components if detected
    return {
      city: cityTokens.join(" ").trim() || null,
      state: stateTokens.join(" ").trim() || null,
    };
  }

  /**
   * Processes a token from the NER results, appending it to the provided result
   * array while handling wordpieces (e.g., "##"), punctuation, and
   * multi-token entities.
   *
   * - Appends wordpieces (starting with "##") to the last token in the array.
   * - Handles punctuation tokens like ".", "-", or "'".
   * - Ensures continuity for entities split across multiple tokens.
   *
   * @param {MLEntry} res
   *   The NER result token to process. Should include:
   *   - {string} word: The word or token from the NER output.
   *   - {number} score: The confidence score for the token.
   *   - {string} entity: The entity type label (e.g., "B-CITY", "I-STATE").
   * @param {string[]} resultArray
   *   The array to append the processed token. Typically `cityResult`,
   *   `stateResult`, or `cityStateResult`.
   * @param {number} nerThreshold
   *   The confidence threshold for including tokens. Tokens with a score at or
   *   below this threshold will be ignored.
   */
  #processNERToken(res, resultArray, nerThreshold) {
    // Skip low-confidence tokens
    if (res.score <= nerThreshold) {
      return;
    }

    const { word } = res;
    const lastTokenIndex = resultArray.length - 1;

    // Nothing to attach to yet: the token starts the entity as-is.
    if (lastTokenIndex < 0) {
      resultArray.push(word);
      return;
    }

    // "##" prefix indicates that a token is continuation of a word
    // rather than a start of a new word.
    // reference -> https://github.com/google-research/bert/blob/master/tokenization.py#L314-L316
    if (word.startsWith("##")) {
      resultArray[lastTokenIndex] += word.slice(2);
      return;
    }

    // Special handling for punctuation like ".", "-", or "'": the punctuation
    // itself, and whatever follows a "-" or "'", stays in the same token.
    const previousEndsWithJoiner = NAME_PUNCTUATION_EXCEPT_DOT.includes(
      resultArray[lastTokenIndex].slice(-1)
    );
    if (NAME_PUNCTUATION.includes(word) || previousEndsWithJoiner) {
      resultArray[lastTokenIndex] += word;
      return;
    }

    resultArray.push(word);
  }

  /**
   * Removes trailing punctuation from the last element in the result array
   * if the last character matches any punctuation in `NAME_PUNCTUATION`.
   * At most one character is removed.
   *
   * This method is useful for cleaning up city or state tokens that may
   * contain unwanted punctuation after processing NER results.
   *
   * @param {string[]} resultArray
   *   An array of strings representing detected entities (e.g., cities or states).
   *   The array is modified in place if the last element ends with punctuation.
   */
  #removePunctFromEndIfPresent(resultArray) {
    if (!resultArray.length) {
      return;
    }

    const lastTokenIndex = resultArray.length - 1;
    const lastToken = resultArray[lastTokenIndex];
    if (NAME_PUNCTUATION.includes(lastToken.slice(-1))) {
      resultArray[lastTokenIndex] = lastToken.slice(0, -1);
    }
  }

  /**
   * Finds the subject from the query, removing the city and location words.
   *
   * Both the query and the location values are normalized by turning every
   * run of characters that are not letters, marks, digits or "_" into a
   * single space. The classes are Unicode-aware so that non-ASCII words such
   * as "café" are kept intact.
   *
   * @param {string} query
   *   The user's input query.
   * @param {{city: ?string, state: ?string}} location
   *   The detected location.
   * @returns {string}
   *   The query unchanged when there is no location; otherwise the normalized
   *   query without the location and without trailing prepositions.
   */
  #findSubjectFromQuery(query, location) {
    // If location is null or no city/state, return the entire query
    if (!location || (!location.city && !location.state)) {
      return query;
    }

    const WORD_CHARS = "\\p{L}\\p{M}\\p{N}_";
    const nonWordRun = new RegExp(`[^${WORD_CHARS}]+`, "gu");

    // Normalized city and state values. Trimming matters: a value with a
    // leading or trailing space could never satisfy the boundary checks below.
    const locValues = Object.values(location)
      .map(loc => loc?.replace(nonWordRun, " ").trim())
      .filter(Boolean);

    let normalizedQuery = query.replace(nonWordRun, " ");

    if (locValues.length) {
      // Regular expression to remove locations
      // This handles single & multi-worded cities/states. Lookarounds are used
      // as word boundaries because `\b` only knows about ASCII word
      // characters. The values contain no regex metacharacters after the
      // normalization above.
      const locPattern = locValues
        .map(loc => `(?<![${WORD_CHARS}])${loc}(?![${WORD_CHARS}])`)
        .join("|");
      normalizedQuery = normalizedQuery.replace(
        new RegExp(locPattern, "gu"),
        ""
      );
    }

    // Only word characters and spaces are left at this point.
    const words = normalizedQuery.split(" ").filter(Boolean);

    return this.#cleanSubject(words).join(" ");
  }

  /**
   * Remove trailing prepositions from the list of words
   *
   * @param {string[]} words
   *   The words of the subject; modified in place.
   * @returns {string[]}
   *   The same array, without trailing prepositions.
   */
  #cleanSubject(words) {
    while (words.length && PREPOSITIONS.includes(words.at(-1))) {
      words.pop();
    }
    return words;
  }
}

// Export the singleton instance
export const MLSuggest = new _MLSuggest();