tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

MLSuggest.sys.mjs (14959B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this
      3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
      4 
      5 /**
      6 * MLSuggest helps with ML based suggestions around intents and location.
      7 */
      8 
      9 import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
     10 
     11 const lazy = XPCOMUtils.declareLazy({
     12  createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
     13  UrlbarPrefs: "moz-src:///browser/components/urlbar/UrlbarPrefs.sys.mjs",
     14 });
     15 
     16 /**
     17 * @import {EngineRunRequest, EngineRunResponse, MLEntry} from "moz-src:///toolkit/components/ml/actors/MLEngineParent.sys.mjs"
     18 * @typedef {Parameters<typeof lazy.createEngine>[0]} MLEngineOptions
     19 * @typedef {Awaited<ReturnType<typeof lazy.createEngine>>} MLEngine
     20 */
     21 
     22 // List of prepositions used in subject cleaning.
     23 const PREPOSITIONS = ["in", "at", "on", "for", "to", "near"];
     24 
     25 const MAX_QUERY_LENGTH = 200;
     26 const NAME_PUNCTUATION = [".", "-", "'"];
     27 const NAME_PUNCTUATION_EXCEPT_DOT = NAME_PUNCTUATION.filter(p => p !== ".");
     28 
     29 /**
     30 * Class for handling ML-based suggestions using intent and NER models.
     31 *
     32 * @class
     33 */
     34 class _MLSuggest {
     35  /**
     36   * @type {Map<string, MLEngine>}
     37   */
     38  #modelEngines = new Map();
     39 
     40  INTENT_OPTIONS = {
     41    taskName: "text-classification",
     42    featureId: "suggest-intent-classification",
     43    timeoutMS: -1,
     44    numThreads: 2,
     45    backend: "onnx-native",
     46  };
     47 
     48  INTENT_OPTIONS_FALLBACK = {
     49    taskName: "text-classification",
     50    featureId: "suggest-intent-classification",
     51    timeoutMS: -1,
     52    numThreads: 2,
     53    backend: "onnx",
     54  };
     55 
     56  NER_OPTIONS = {
     57    taskName: "token-classification",
     58    featureId: "suggest-NER",
     59    timeoutMS: -1,
     60    numThreads: 2,
     61    backend: "onnx-native",
     62  };
     63 
     64  NER_OPTIONS_FALLBACK = {
     65    taskName: "token-classification",
     66    featureId: "suggest-NER",
     67    timeoutMS: -1,
     68    numThreads: 2,
     69    backend: "onnx",
     70  };
     71 
     72  /**
     73   * Helper to wrap createEngine for testing purposes.
     74   *
     75   * @param {MLEngineOptions} options
     76   *   Configuration options for the ML engine.
     77   */
     78  createEngine(options) {
     79    return lazy.createEngine(options);
     80  }
     81 
     82  /**
     83   * Initializes the intent and NER models.
     84   */
     85  async initialize() {
     86    await Promise.all([
     87      this.#initializeModelEngine(
     88        this.INTENT_OPTIONS,
     89        this.INTENT_OPTIONS_FALLBACK
     90      ),
     91      this.#initializeModelEngine(this.NER_OPTIONS, this.NER_OPTIONS_FALLBACK),
     92    ]);
     93  }
     94 
     95  /**
     96   * @typedef {object} MLSuggestResult
     97   * @property {string} intent
     98   *   The predicted intent label of the query. Possible values include:
     99   *   - 'information_intent': For queries seeking general information.
    100   *   - 'yelp_intent': For queries related to local businesses or services.
    101   *   - 'navigation_intent': For queries with navigation-related actions.
    102   *   - 'travel_intent': For queries showing travel-related interests.
    103   *   - 'purchase_intent': For queries with purchase or shopping intent.
    104   *   - 'weather_intent': For queries asking about weather or forecasts.
    105   *   - 'translation_intent': For queries seeking translations.
    106   *   - 'unknown': When the intent cannot be classified with confidence.
    107   *   - '' (empty string): Returned when model probabilities for all intents
    108   *     are below the intent threshold.
    109   * @property {?{city: ?string, state: ?string}} location
    110   *   The detected location from the query.
    111   * @property {string} subject
    112   *   The subject of the query after location is removed.
    113   * @property {{intent: object, ner: object}} metrics
    114   *   The combined metrics from NER model results, representing additional
    115   *   information about the model's performance.
    116   */
    117 
    118  /**
    119   * Generates ML-based suggestions by finding intent, detecting entities, and
    120   * combining locations.
    121   *
    122   * @param {string} query
    123   *   The user's input query.
    124   * @returns {Promise<?MLSuggestResult>}
    125   *   The suggestion result including intent, location, and subject, or null if
    126   *   an error occurs or query length > MAX_QUERY_LENGTH
    127   */
    128  async makeSuggestions(query) {
    129    // avoid bunch of work for very long strings
    130    if (query.length > MAX_QUERY_LENGTH) {
    131      return null;
    132    }
    133 
    134    let intentRes, nerResult;
    135    try {
    136      [intentRes, nerResult] = await Promise.all([
    137        this._findIntent(query),
    138        this._findNER(query),
    139      ]);
    140    } catch (error) {
    141      return null;
    142    }
    143 
    144    if (!intentRes || !nerResult) {
    145      return null;
    146    }
    147 
    148    const locationResVal = this.#combineLocations(
    149      nerResult,
    150      lazy.UrlbarPrefs.get("nerThreshold")
    151    );
    152 
    153    const intentLabel = this.#applyIntentThreshold(
    154      intentRes,
    155      lazy.UrlbarPrefs.get("intentThreshold")
    156    );
    157 
    158    return {
    159      intent: intentLabel,
    160      location: locationResVal,
    161      subject: this.#findSubjectFromQuery(query, locationResVal),
    162      metrics: { intent: intentRes.metrics, ner: nerResult.metrics },
    163    };
    164  }
    165 
    166  /**
    167   * Shuts down all initialized engines.
    168   */
    169  async shutdown() {
    170    for (const [key, engine] of this.#modelEngines.entries()) {
    171      try {
    172        await engine.terminate?.();
    173      } finally {
    174        // Remove each engine after termination
    175        this.#modelEngines.delete(key);
    176      }
    177    }
    178  }
    179 
    180  /**
    181   * Initializes a engine model.
    182   *
    183   * @param {MLEngineOptions} options
    184   *   Configuration options for the ML engine.
    185   * @param {MLEngineOptions} [fallbackOptions]
    186   *   Fallback options if creating with the main options fails.
    187   */
    188  async #initializeModelEngine(options, fallbackOptions = null) {
    189    const featureId = options.featureId;
    190 
    191    // uses cache if engine was used
    192    let engine = this.#modelEngines.get(featureId);
    193    if (engine) {
    194      return engine;
    195    }
    196    try {
    197      engine = await this.createEngine(options);
    198    } catch (e) {
    199      if (fallbackOptions) {
    200        try {
    201          engine = await this.createEngine(fallbackOptions);
    202        } catch (_) {
    203          // do nothing
    204        }
    205      }
    206    }
    207 
    208    // Cache the engine
    209    this.#modelEngines.set(featureId, engine);
    210    return engine;
    211  }
    212 
    213  /**
    214   * Finds the intent of the query using the intent classification model.
    215   * (This has been made public to enable testing)
    216   *
    217   * @param {string} query
    218   *   The user's input query.
    219   * @param {Omit<EngineRunRequest, "args">} [options]
    220   *   The options for the engine pipeline
    221   */
    222  async _findIntent(query, options = {}) {
    223    const engineIntentClassifier = this.#modelEngines.get(
    224      this.INTENT_OPTIONS.featureId
    225    );
    226    if (!engineIntentClassifier) {
    227      return null;
    228    }
    229 
    230    /** @type {EngineRunResponse} */
    231    let res;
    232    try {
    233      res = await engineIntentClassifier.run({
    234        args: [query],
    235        options,
    236      });
    237    } catch (error) {
    238      // engine could timeout or fail, so remove that from cache
    239      // and reinitialize
    240      this.#modelEngines.delete(this.INTENT_OPTIONS.featureId);
    241      this.#initializeModelEngine(
    242        this.INTENT_OPTIONS,
    243        this.INTENT_OPTIONS_FALLBACK
    244      );
    245      return null;
    246    }
    247    return res;
    248  }
    249 
    250  /**
    251   * Finds named entities in the query using the NER model.
    252   * (This has been made public to enable testing)
    253   *
    254   * @param {string} query
    255   *   The user's input query.
    256   * @param {Omit<EngineRunRequest, "args">} options
    257   *   The options for the engine pipeline
    258   */
    259  async _findNER(query, options = {}) {
    260    const engineNER = this.#modelEngines.get(this.NER_OPTIONS.featureId);
    261    try {
    262      return engineNER?.run({ args: [query], options });
    263    } catch (error) {
    264      // engine could timeout or fail, so remove that from cache
    265      // and reinitialize
    266      this.#modelEngines.delete(this.NER_OPTIONS.featureId);
    267      this.#initializeModelEngine(this.NER_OPTIONS, this.NER_OPTIONS_FALLBACK);
    268      return null;
    269    }
    270  }
    271 
    272  /**
    273   * Applies a confidence threshold to determine the intent label.
    274   *
    275   * If the highest-scoring intent in the result exceeds the threshold, its label
    276   * is returned; otherwise, the label defaults to 'unknown'.
    277   *
    278   * @param {EngineRunResponse} intentResult
    279   *   The result of the intent classification model, where each item includes
    280   *   a `label` and `score`.
    281   * @param {number} intentThreshold
    282   *   The confidence threshold for accepting the intent label.
    283   * @returns {string}
    284   *   The determined intent label or 'unknown' if the threshold is not met.
    285   */
    286  #applyIntentThreshold(intentResult, intentThreshold) {
    287    return intentResult[0]?.score > intentThreshold
    288      ? intentResult[0].label
    289      : "";
    290  }
    291 
    292  /**
    293   * Combines location tokens detected by NER into separate city and state
    294   * components. This method processes city, state, and combined city-state
    295   * entities, returning an object with `city` and `state` fields.
    296   *
    297   * Handles the following entity types:
    298   * - B-CITY, I-CITY: Identifies city tokens.
    299   * - B-STATE, I-STATE: Identifies state tokens.
    300   * - B-CITYSTATE, I-CITYSTATE: Identifies tokens that represent a combined
    301   *   city and state.
    302   *
    303   * @param {EngineRunResponse} nerResult
    304   *   The NER results containing tokens and their corresponding entity labels.
    305   * @param {number} nerThreshold
    306   *   The confidence threshold for including entities. Tokens with a confidence
    307   *   score below this threshold will be ignored.
    308   */
    309  #combineLocations(nerResult, nerThreshold) {
    310    let cityResult = [];
    311    let stateResult = [];
    312    let cityStateResult = [];
    313 
    314    for (let i = 0; i < nerResult.length; i++) {
    315      const res = nerResult[i];
    316      if (res.entity === "B-CITY" || res.entity === "I-CITY") {
    317        this.#processNERToken(res, cityResult, nerThreshold);
    318      } else if (res.entity === "B-STATE" || res.entity === "I-STATE") {
    319        this.#processNERToken(res, stateResult, nerThreshold);
    320      } else if (res.entity === "B-CITYSTATE" || res.entity === "I-CITYSTATE") {
    321        this.#processNERToken(res, cityStateResult, nerThreshold);
    322      }
    323    }
    324 
    325    // Handle city_state as combined and split into city and state
    326    if (cityStateResult.length && !cityResult.length && !stateResult.length) {
    327      let cityStateSplit = cityStateResult.join(" ").split(",");
    328      cityResult =
    329        cityStateSplit[0]
    330          ?.trim?.()
    331          .split(",")
    332          .filter(item => item.trim() !== "") || [];
    333      stateResult =
    334        cityStateSplit[1]
    335          ?.trim?.()
    336          .split(",")
    337          .filter(item => item.trim() !== "") || [];
    338    }
    339 
    340    // Remove trailing punctuation from the last cityResult element if present
    341    this.#removePunctFromEndIfPresent(cityResult);
    342    this.#removePunctFromEndIfPresent(stateResult);
    343 
    344    // Return city and state as separate components if detected
    345    return {
    346      city: cityResult.join(" ").trim() || null,
    347      state: stateResult.join(" ").trim() || null,
    348    };
    349  }
    350 
    351  /**
    352   * Processes a token from the NER results, appending it to the provided result
    353   * array while handling wordpieces (e.g., "##"), punctuation, and
    354   * multi-token entities.
    355   *
    356   * - Appends wordpieces (starting with "##") to the last token in the array.
    357   * - Handles punctuation tokens like ".", "-", or "'".
    358   * - Ensures continuity for entities split across multiple tokens.
    359   *
    360   * @param {MLEntry} res
    361   *   The NER result token to process. Should include:
    362   *   - {string} word: The word or token from the NER output.
    363   *   - {number} score: The confidence score for the token.
    364   *   - {string} entity: The entity type label (e.g., "B-CITY", "I-STATE").
    365   * @param {string[]} resultArray
    366   *   The array to append the processed token. Typically `cityResult`,
    367   *   `stateResult`, or `cityStateResult`.
    368   * @param {number} nerThreshold
    369   *   The confidence threshold for including tokens. Tokens with a score below
    370   *   this threshold will be ignored.
    371   */
    372  #processNERToken(res, resultArray, nerThreshold) {
    373    // Skip low-confidence tokens
    374    if (res.score <= nerThreshold) {
    375      return;
    376    }
    377 
    378    const lastTokenIndex = resultArray.length - 1;
    379    // "##" prefix indicates that a token is continuation of a word
    380    // rather than a start of a new word.
    381    // reference -> https://github.com/google-research/bert/blob/master/tokenization.py#L314-L316
    382    if (res.word.startsWith("##") && resultArray.length) {
    383      resultArray[lastTokenIndex] += res.word.slice(2);
    384    } else if (
    385      resultArray.length &&
    386      (NAME_PUNCTUATION.includes(res.word) ||
    387        NAME_PUNCTUATION_EXCEPT_DOT.includes(
    388          resultArray[lastTokenIndex].slice(-1)
    389        ))
    390    ) {
    391      // Special handling for punctuation like ".", "-", or "'"
    392      resultArray[lastTokenIndex] += res.word;
    393    } else {
    394      resultArray.push(res.word);
    395    }
    396  }
    397 
    398  /**
    399   * Removes trailing punctuation from the last element in the result array
    400   * if the last character matches any punctuation in `NAME_PUNCTUATION`.
    401   *
    402   * This method is useful for cleaning up city or state tokens that may
    403   * contain unwanted punctuation after processing NER results.
    404   *
    405   * @param {string[]} resultArray
    406   *   An array of strings representing detected entities (e.g., cities or states).
    407   *   The array is modified in place if the last element ends with punctuation.
    408   */
    409  #removePunctFromEndIfPresent(resultArray) {
    410    const lastTokenIndex = resultArray.length - 1;
    411    if (
    412      resultArray.length &&
    413      NAME_PUNCTUATION.includes(resultArray[lastTokenIndex].slice(-1))
    414    ) {
    415      resultArray[lastTokenIndex] = resultArray[lastTokenIndex].slice(0, -1);
    416    }
    417  }
    418 
    419  /**
    420   * Finds the subject from the query, removing the city and location words.
    421   *
    422   * @param {string} query
    423   * @param {{city: ?string, state: ?string}} location
    424   */
    425  #findSubjectFromQuery(query, location) {
    426    // If location is null or no city/state, return the entire query
    427    if (!location || (!location.city && !location.state)) {
    428      return query;
    429    }
    430 
    431    // Remove the city and state values from the query
    432    let locValues = Object.values(location)
    433      .map(loc => loc?.replace(/\W+/g, " "))
    434      .filter(loc => loc?.trim());
    435 
    436    // Regular expression to remove locations
    437    // This handles single & multi-worded cities/states
    438    let locPattern = locValues.map(loc => `\\b${loc}\\b`).join("|");
    439    let locRegex = new RegExp(locPattern, "g");
    440 
    441    // Remove locations, trim whitespace, and split words
    442    let words = query
    443      .replace(/\W+/g, " ")
    444      .replace(locRegex, "")
    445      .split(/\W+/)
    446      .filter(word => !!word.length);
    447 
    448    let subjectWords = this.#cleanSubject(words);
    449    return subjectWords.join(" ");
    450  }
    451 
    452  /**
    453   * Remove trailing prepositions from the list of words
    454   *
    455   * @param {string[]} words
    456   */
    457  #cleanSubject(words) {
    458    while (words.length && PREPOSITIONS.includes(words[words.length - 1])) {
    459      words.pop();
    460    }
    461    return words;
    462  }
    463 }
    464 
    465 // Export the singleton instance
    466 export var MLSuggest = new _MLSuggest();