tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

SmartTabGrouping.sys.mjs (58799B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this
      3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      4 
      5 import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
      6 import { createEngine } from "chrome://global/content/ml/EngineProcess.sys.mjs";
      7 import {
      8  cosSim,
      9  KeywordExtractor,
     10 } from "chrome://global/content/ml/NLPUtils.sys.mjs";
     11 
     12 import {
     13  computeCentroidFrom2DArray,
     14  computeRandScore,
     15  euclideanDistance,
     16  getAccuracyStats,
     17  kmeansPlusPlus,
     18  silhouetteCoefficients,
     19 } from "chrome://global/content/ml/ClusterAlgos.sys.mjs";
     20 
// Holder for lazily resolved dependencies: the ES module getters below and the
// preference getters defined further down are all installed on this object.
const lazy = {};

ChromeUtils.defineESModuleGetters(lazy, {
  NLP: "resource://gre/modules/NLP.sys.mjs",
  MLEngineParent: "resource://gre/actors/MLEngineParent.sys.mjs",
  MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs",
  Progress: "chrome://global/content/ml/Utils.sys.mjs",
});

// NOTE(review): not referenced in the visible part of this file; presumably the
// revision used when no model-revision pref is set — confirm.
const LATEST_MODEL_REVISION = "latest";

// Methods for suggesting tabs that are similar to the current tab.
// smartTabGroupingForGroup picks one based on the suggestOtherTabsMethod pref
// below; any value not listed here falls back to NEAREST_NEIGHBOR.
export const SUGGEST_OTHER_TABS_METHODS = {
  KMEANS_WITH_ANCHOR: "KMEANS_WITH_ANCHOR",
  NEAREST_NEIGHBOR: "NEAREST_NEIGHBOR",
  LOGISTIC_REGRESSION: "LOGISTIC_REGRESSION",
};

// Selects the SUGGEST_OTHER_TABS_METHODS value used by smartTabGroupingForGroup.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "suggestOtherTabsMethod",
  "browser.tabs.groups.smart.suggestOtherTabsMethod"
);

// Model revision prefs. NOTE(review): not read in the visible part of this
// file; presumably consumed when the ML engines are created — confirm.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "topicModelRevision",
  "browser.tabs.groups.smart.topicModelRevision"
);

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "embeddingModelRevision",
  "browser.tabs.groups.smart.embeddingModelRevision"
);

// Nearest-neighbor similarity threshold as an integer in thousandths;
// findNearestNeighbors uses it as its default and compares against value / 1000.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "nearestNeighborThresholdInt",
  "browser.tabs.groups.smart.nearestNeighborThresholdInt"
);
     62 
// Field of the prepared tab data whose text is embedded.
const EMBED_TEXT_KEY = "combined_text";
// Supported clustering implementations (validated by setClusteringMethod).
export const CLUSTER_METHODS = {
  KMEANS: "KMEANS",
};

// Methods for finding similar items for an existing cluster
export const ANCHOR_METHODS = {
  DRIFT: "DRIFT", // We let k-means clustering run, and find the cluster with the most anchor items
  FIXED: "FIXED", // We always group with the anchor items in the 0 cluster, and never let them be reassigned
};

// Methods for handling tabs that already belong to other groups during clustering.
// NOTE(review): the meanings below are inferred from the names — confirm
// against the clustering code (generateClusters).
export const PREGROUPED_HANDLING_METHODS = {
  EXCLUDE: "EXCLUDE", // presumably: tabs already in other groups are left out of clustering
  IGNORE: "IGNORE", // presumably: existing group membership is ignored and those tabs are clustered normally
};

const EXPECTED_TOPIC_MODEL_OBJECTS = 6;
const EXPECTED_EMBEDDING_MODEL_OBJECTS = 4;

const MAX_NON_SUMMARIZED_SEARCH_LENGTH = 26;

// Currently empty: no dimension-reduction method is implemented
// (clustering.dimReductionMethod is marked "Not completed").
export const DIM_REDUCTION_METHODS = {};
const MISSING_ANCHOR_IN_CLUSTER_PENALTY = 0.2;
// Only the first MAX_GROUPED_TABS tabs of the anchor group are used for comparisons.
const MAX_GROUPED_TABS = 3;
// Upper bound on the number of tabs smartTabGroupingForGroup returns.
const MAX_SUGGESTED_TABS = 10;
// limit number of tabs to be processed so inference process doesn't crash
const MAX_TABS_TO_PROCESS = 300;

// NOTE(review): presumably generated topic labels that must not become group
// names — confirm in the topic-generation code.
const DISSIMILAR_TAB_LABEL = "none";
const ADULT_TAB_LABEL = "adult content";
const LABELS_TO_EXCLUDE = [DISSIMILAR_TAB_LABEL, ADULT_TAB_LABEL];

// ML task names used for the embedding and topic-generation engines.
const ML_TASK_FEATURE_EXTRACTION = "feature-extraction";
const ML_TASK_TEXT2TEXT = "text2text-generation";

const LABEL_REASONS = {
  DEFAULT: "DEFAULT",
  LOW_CONFIDENCE: "LOW_CONFIDENCE",
  EXCLUDE: "EXCLUDE",
  ERROR: "ERROR",
};
    105 
/**
 * Default configuration for SmartTabGroupingManager. When no config is passed
 * in, the constructor deep-clones this object, so the setter methods only ever
 * mutate that copy.
 */
export const SMART_TAB_GROUPING_CONFIG = {
  // Embedding engine settings (passed to _createMLEngine by initEmbeddingEngine).
  embedding: {
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000, // 2 minutes
    taskName: ML_TASK_FEATURE_EXTRACTION,
    featureId: "smart-tab-embedding",
    backend: "onnx-native",
    fallbackBackend: "onnx", // NOTE(review): presumably used when `backend` is unavailable — confirm in _createMLEngine
  },
  // Topic-generation (text2text) engine settings.
  topicGeneration: {
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000, // 2 minutes
    taskName: ML_TASK_TEXT2TEXT,
    featureId: "smart-tab-topic",
    backend: "onnx-native",
    fallbackBackend: "onnx",
  },
  // Names of the tab fields _prepareTabData reads for the title and description.
  dataConfig: {
    titleKey: "label",
    descriptionKey: "description",
  },
  clustering: {
    dimReductionMethod: null, // Not completed.
    clusterImplementation: CLUSTER_METHODS.KMEANS,
    clusteringTriesPerK: 3,
    anchorMethod: ANCHOR_METHODS.FIXED,
    pregroupedHandlingMethod: PREGROUPED_HANDLING_METHODS.EXCLUDE,
    pregroupedSilhouetteBoost: 2, // Relative weight of the cluster's score versus all other clusters' scores combined
    // NOTE(review): smartTabGroupingForGroup chooses its method from the
    // browser.tabs.groups.smart.suggestOtherTabsMethod pref, not from this field.
    suggestOtherTabsMethod: SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR,
  },
};
    137 
// these parameters were generated by training a logistic regression
// model on synthetic data. see https://github.com/mozilla/smart-tab-grouping
// and https://github.com/mozilla/smart-tab-grouping/pull/12 for more info
//
// calculateProbability evaluates
//   sigmoid(s_gc * GROUP + s_tt_max * TITLE + s_dd * DOMAIN + INTERCEPT)
// where s_gc is the cosine similarity between the group-name embedding and the
// candidate title, s_tt_max is the highest cosine similarity between any anchor
// title and the candidate title (both rescaled from [-1, 1] to [0, 1] first),
// and s_dd is the fraction of anchors sharing the candidate's base domain.
// findSimilarTabsLogisticRegression keeps candidates whose probability is
// >= the THRESHOLD of the variant in use.
const LOGISTIC_REGRESSION_PARAMS = {
  // Logistic WITH group name
  // Features: s_gc, s_tt_max, s_dd in [0, 1]
  TITLE_WITH_GROUP_NAME: {
    GROUP_SIMILARITY_WEIGHT: 0.10249,
    TITLE_SIMILARITY_WEIGHT: 0.54897,
    DOMAIN_SIMILARITY_WEIGHT: 0.34854,
    INTERCEPT: -0.07397,
    THRESHOLD: 0.59,
  },
  // Logistic WITHOUT group name
  // Features: s_tt_max, s_dd in [0, 1]
  TITLE_ONLY: {
    GROUP_SIMILARITY_WEIGHT: 0, // unused in this variant
    TITLE_SIMILARITY_WEIGHT: 0.92513,
    DOMAIN_SIMILARITY_WEIGHT: 0.07487,
    INTERCEPT: -2.58574,
    THRESHOLD: 0.123,
  },
};

// Browser-internal pages that getTabsToSuggest never offers as suggestions.
const TAB_URLS_TO_EXCLUDE = [
  "about:newtab",
  "about:home",
  "about:privatebrowsing",
  "chrome://browser/content/blanktab.html",
  "about:firefoxview",
  "about:opentabs",
];

// NOTE(review): not referenced in the visible part of this file; presumably
// the separators used to split page titles into segments — confirm.
const TITLE_DELIMETER_SET = new Set(["-", "|", "—"]);
    172 
    173 /**
    174 * For a given set of clusters represented by indices, returns the index of the cluster
    175 * that has the most anchor items inside it.
    176 *
    177 * An anhor item is an index that represents the index to a tab that is already grouped and in
    178 * the cluster we're interested in finding more items for.
    179 *
    180 * @param {number[][]} groupIndices - Array of clusters represented as arrays of indices.
    181 * @param {number[]} anchorItems - Array of anchor item indices.
    182 * @returns {{anchorClusterIndex: number, numAnchorItemsInCluster: number}} Index of best cluster and the number of anchor items.
    183 */
    184 export function getBestAnchorClusterInfo(groupIndices, anchorItems) {
    185  const anchorItemSet = new Set(anchorItems);
    186  const numItemsList = groupIndices.map(g =>
    187    g.reduce(
    188      (cur, itemIndex) => (anchorItemSet.has(itemIndex) ? cur + 1 : cur),
    189      0
    190    )
    191  );
    192  const anchorClusterIndex = numItemsList.indexOf(Math.max(...numItemsList));
    193  const numAnchorItemsInCluster = numItemsList[anchorClusterIndex];
    194  return { anchorClusterIndex, numAnchorItemsInCluster };
    195 }
    196 
    197 /**
    198 * Check tab to see if it's a search page
    199 *
    200 * @param {object} tab
    201 * @returns {boolean} Returns true if the tab is a web search from the Firefox search UI and the user is still on the original page.
    202 * Changes in search query after search is made is supported.
    203 * Returns false if user started from a hompepage of a site rather than the New Tab / browser UI
    204 */
    205 export function isSearchTab(tab) {
    206  const linkedBrowser = tab?.linkedBrowser;
    207  if (!linkedBrowser) {
    208    return false;
    209  }
    210  const searchURL = linkedBrowser.getAttribute("triggeringSearchEngineURL");
    211  const curURL = linkedBrowser.currentURI?.spec;
    212  if (!searchURL) {
    213    return false;
    214  }
    215  const queryFieldsMarker = searchURL.indexOf("?");
    216 
    217  if (
    218    queryFieldsMarker > 0 &&
    219    searchURL.substring(0, queryFieldsMarker) ===
    220      curURL.substring(0, queryFieldsMarker)
    221  ) {
    222    return true;
    223  }
    224  return false;
    225 }
    226 
    227 export class SmartTabGroupingManager {
    228  /**
    229   * Creates the SmartTabGroupingManager object.
    230   *
    231   * @param {object} config configuration options
    232   */
    233  constructor(config) {
    234    this.config = config || structuredClone(SMART_TAB_GROUPING_CONFIG);
    235  }
    236 
    237  /**
    238   *
    239   * @param {MLEngine} engine the engine to check
    240   * @return {boolean} true if the engine has not been initialized or closed
    241   */
    242  static isEngineClosed(engine) {
    243    return !engine || engine?.engineStatus === "closed";
    244  }
    245 
    246  /**
    247   * Initializes the embedding engine by running a test request
    248   * This helps remove the init latency
    249   */
    250  async initEmbeddingEngine() {
    251    if (!SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) {
    252      return;
    253    }
    254    try {
    255      this.embeddingEngine = await this._createMLEngine(this.config.embedding);
    256      const request = {
    257        args: ["Test"],
    258        options: { pooling: "mean", normalize: true },
    259      };
    260      this.embeddingEngine.run(request);
    261    } catch (e) {}
    262  }
    263 
    264  /**
    265   * Generates tabs to process with a limit. First MAX_GROUPED_TABS are tabs that are
    266   * present in the group of the anchor tab. The remaining "ungrouped" tabs fill the
    267   * slots up to MAX_TABS_TO_PROCESS
    268   *
    269   * @param {Array} tabsInGroup active tabs in anchor group we are adding tabs to
    270   * @param {Array} allTabs list of tabs from gbrowser, some of which may be grouped in other groups
    271   * @param {number} max_limit_to_process max number of tabs we want to process as part of the flow
    272   * @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned.
    273   */
    274  getTabsToProcess(
    275    tabsInGroup,
    276    allTabs,
    277    max_limit_to_process = MAX_TABS_TO_PROCESS
    278  ) {
    279    const seen = new Set();
    280    let tabsToProcess = [];
    281 
    282    const shouldInclude = tab => {
    283      if (tab.pinned) {
    284        return false;
    285      }
    286      if (!tab?.linkedBrowser?.currentURI?.spec) {
    287        return false;
    288      }
    289      return true;
    290    };
    291 
    292    // include tabs in the anchor group first
    293    for (const tab of tabsInGroup) {
    294      if (!shouldInclude(tab)) {
    295        continue;
    296      }
    297      if (!seen.has(tab)) {
    298        // make sure we have "seen" all the
    299        // tabs already in the current group
    300        seen.add(tab);
    301        tabsToProcess.push(tab);
    302      }
    303    }
    304 
    305    // when generating embeddings, we only look at the first MAX_GROUPED_TABS
    306    // so use that limit here
    307    tabsToProcess = tabsToProcess.slice(0, MAX_GROUPED_TABS);
    308    // fill remaining slots with ungrouped tabs from the window
    309    for (const tab of allTabs) {
    310      if (tabsToProcess.length >= max_limit_to_process) {
    311        break;
    312      }
    313      if (!shouldInclude(tab)) {
    314        continue;
    315      }
    316      if (!seen.has(tab)) {
    317        seen.add(tab);
    318        tabsToProcess.push(tab);
    319      }
    320    }
    321    return tabsToProcess;
    322  }
    323 
    324  /**
    325   * Generates suggested tabs for an existing or provisional group
    326   *
    327   * @param {object} group active group we are adding tabs to
    328   * @param {Array} tabs list of tabs from gbrowser, some of which may be grouped in other groups
    329   * @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned.
    330   */
    331  async smartTabGroupingForGroup(group, tabs) {
    332    // Add tabs to suggested group
    333    const groupTabs = group.tabs;
    334    const allTabs = this.getTabsToProcess(groupTabs, tabs, MAX_TABS_TO_PROCESS);
    335    // first (1 up to MAX_GROUPED_TABS) are tabs in the group
    336    const groupIndices = [];
    337    for (let i = 0; i < MAX_GROUPED_TABS; i++) {
    338      if (groupTabs.includes(allTabs[i])) {
    339        groupIndices.push(i);
    340      }
    341    }
    342 
    343    // find tabs that are part of other groups
    344    const alreadyGroupedIndices = allTabs
    345      .map((t, i) => (t.group ? i : -1))
    346      .filter(a => a >= 0);
    347 
    348    let suggestedTabs;
    349    switch (lazy.suggestOtherTabsMethod) {
    350      case SUGGEST_OTHER_TABS_METHODS.KMEANS_WITH_ANCHOR:
    351        suggestedTabs = await this.generateClusters(
    352          allTabs,
    353          null,
    354          null,
    355          null,
    356          groupIndices,
    357          alreadyGroupedIndices
    358        ).then(clusters => {
    359          if (!clusters) {
    360            return [];
    361          }
    362          const targetCluster = clusters.clusterRepresentations.find(c =>
    363            groupTabs.some(g => c.tabs.includes(g))
    364          );
    365          if (targetCluster) {
    366            // Return only tabs not already grouped
    367            return targetCluster.tabs.filter(t => !t.group);
    368          }
    369          return [];
    370        });
    371        break;
    372      case SUGGEST_OTHER_TABS_METHODS.LOGISTIC_REGRESSION:
    373        suggestedTabs = await this.findSimilarTabsLogisticRegression({
    374          allTabs,
    375          groupedIndices: groupIndices,
    376          alreadyGroupedIndices,
    377          groupLabel: group?.label,
    378        });
    379        break;
    380      case SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR:
    381      default:
    382        // find nearest neighbors to current group
    383        suggestedTabs = await this.findNearestNeighbors({
    384          allTabs,
    385          groupedIndices: groupIndices,
    386          alreadyGroupedIndices,
    387          groupLabel: group?.label,
    388        });
    389    }
    390    return suggestedTabs.slice(0, MAX_SUGGESTED_TABS);
    391  }
    392 
    393  /**
    394   * Get tabs that need to be included in suggestions
    395   *
    396   * @param {Array} allTabs all tabs that are part of the window
    397   * @param {Array} groupedIndices indices of tabs that are already part of the group
    398   * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups
    399   * @returns {Array} tabs indices to be considered for suggestions
    400   */
    401  getTabsToSuggest(allTabs, groupedIndices, alreadyGroupedIndices) {
    402    // tabs to be excluded
    403    // indices of all tabs that should be excluded (with duplicates)
    404    const tabURLIndicesToExclude = allTabs
    405      .map((at, index) => (TAB_URLS_TO_EXCLUDE.includes(at.url) ? index : -1))
    406      .filter(index => index !== -1);
    407    const excludedTabIndices = [
    408      ...groupedIndices,
    409      ...alreadyGroupedIndices,
    410      ...tabURLIndicesToExclude,
    411    ];
    412 
    413    // tabs to be included
    414    return allTabs
    415      .map((_, index) => index)
    416      .filter(i => !excludedTabIndices.includes(i));
    417  }
    418 
    419  /**
    420   * Generates similar tabs a grouped list of tabs
    421   *
    422   * @param {Array} allTabs all tabs that are part of the window
    423   * @param {Array} groupedIndices indices of tabs that are already part of the group
    424   * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups
    425   * @param {string} groupLabel name of group if present
    426   * @param {number} threshold for nearest neighbor similarity
    427   * @returns a list of suggested tabs that are similar to the groupedIndices tabs
    428   */
    429  async findNearestNeighbors({
    430    allTabs,
    431    groupedIndices,
    432    alreadyGroupedIndices,
    433    groupLabel = "",
    434    thresholdMills = lazy.nearestNeighborThresholdInt,
    435    precomputedEmbeddings = [],
    436    depth = 0,
    437  }) {
    438    // get embeddings for all the tabs
    439    const tabData = await this._prepareTabData(allTabs);
    440    let embeddings = precomputedEmbeddings;
    441    if (precomputedEmbeddings.length === 0) {
    442      embeddings = await this._generateEmbeddings(
    443        tabData.map((td, index) => {
    444          let text = SmartTabGroupingManager.preprocessText(td[EMBED_TEXT_KEY]);
    445          // augment with group name if it's present
    446          if (groupLabel && groupedIndices.includes(index)) {
    447            text = `${groupLabel.slice(0, 100)}. ${text}`;
    448          }
    449          return text;
    450        })
    451      );
    452    }
    453 
    454    // tabs that need to be assigned after filtering
    455    const tabsToAssignIndices = this.getTabsToSuggest(
    456      tabData,
    457      groupedIndices,
    458      alreadyGroupedIndices
    459    );
    460 
    461    let closestTabs = [];
    462    const similarTabsIndices = [];
    463    for (let i = 0; i < tabsToAssignIndices.length; i++) {
    464      let closestScore = null;
    465      for (
    466        let j = 0;
    467        j < Math.min(groupedIndices.length, MAX_GROUPED_TABS);
    468        j++
    469      ) {
    470        const cosineSim = cosSim(
    471          embeddings[tabsToAssignIndices[i]],
    472          embeddings[groupedIndices[j]]
    473        );
    474        if (!closestScore || cosineSim > closestScore) {
    475          closestScore = cosineSim;
    476        }
    477      }
    478      // threshold could also be set via a nimbus experiment, in which case
    479      // it will be an int <= 1000
    480      if (closestScore > thresholdMills / 1000) {
    481        closestTabs.push([allTabs[tabsToAssignIndices[i]], closestScore]);
    482        similarTabsIndices.push(tabsToAssignIndices[i]);
    483      }
    484    }
    485    closestTabs.sort((a, b) => b[1] - a[1]);
    486    closestTabs = closestTabs.map(t => t[0]);
    487    // recurse once if the initial call only had a single tab
    488    // and we found at least 1 similar tab - this improves recall
    489    if (groupedIndices.length === 1 && !!closestTabs.length && depth === 1) {
    490      const recurseSimilarTabs = await this.findNearestNeighbors({
    491        allTabs,
    492        groupedIndices: similarTabsIndices,
    493        alreadyGroupedIndices: alreadyGroupedIndices.concat(groupedIndices),
    494        groupLabel,
    495        thresholdMills,
    496        precomputedEmbeddings: embeddings,
    497        depth: depth - 1,
    498      });
    499      closestTabs = closestTabs.concat(recurseSimilarTabs);
    500    }
    501    return closestTabs;
    502  }
    503 
    504  /**
    505   * Calculates the average similarity between the anchor embeddings and the candidate embeddings
    506   *
    507   * @param {number[]} anchorEmbeddings title embeddings for the anchor tabs
    508   * @param {number[]} candidateEmbeddings title embeddings for the candidate tabs
    509   */
    510  getAverageSimilarity(anchorEmbeddings, candidateEmbeddings) {
    511    let averageSimilarities = [];
    512    for (let candidate_embedding of candidateEmbeddings) {
    513      let averageSimilarity = 0;
    514      for (let anchor_embedding of anchorEmbeddings) {
    515        averageSimilarity += cosSim(candidate_embedding, anchor_embedding);
    516      }
    517      averageSimilarities.push(averageSimilarity / anchorEmbeddings.length);
    518    }
    519    return averageSimilarities;
    520  }
    521 
    522  /**
    523   * Calculates the max similarity between the anchor embeddings and the candidate embeddings
    524   * (used for s_tt_max).
    525   *
    526   * @param {number[]} anchorEmbeddings title embeddings for the anchor tabs
    527   * @param {number[]} candidateEmbeddings title embeddings for the candidate tabs
    528   */
    529  getMaxSimilarity(anchorEmbeddings, candidateEmbeddings) {
    530    let maxSimilarities = [];
    531    for (let candidate_embedding of candidateEmbeddings) {
    532      let maxSimilarity = -1;
    533      for (let anchor_embedding of anchorEmbeddings) {
    534        const sim = cosSim(candidate_embedding, anchor_embedding);
    535        if (sim > maxSimilarity) {
    536          maxSimilarity = sim;
    537        }
    538      }
    539      maxSimilarities.push(maxSimilarity);
    540    }
    541    return maxSimilarities;
    542  }
    543 
    544  /**
    545   * Extract base domain from a URL with error handling
    546   *
    547   * @param {string} url
    548   * @return {string}
    549   */
    550  static getBaseDomain(url) {
    551    if (!url) {
    552      return "";
    553    }
    554 
    555    let hostname;
    556    try {
    557      ({ hostname } = new URL(url));
    558    } catch (_e) {
    559      // invalid URL
    560      return "";
    561    }
    562 
    563    if (!hostname) {
    564      return "";
    565    }
    566 
    567    try {
    568      // additionalParts = 1 → one label above the registrable domain
    569      // then remove 'www'
    570      // https://www.example.com -> www.example.com -> example.com
    571      // https://www.docs.google.com -> docs.google.com
    572      // https://localhost -> error
    573      return Services.eTLD
    574        .getBaseDomain(Services.io.newURI(url.toLowerCase()), 1)
    575        .replace(/^www\./, "");
    576    } catch (_e) {
    577      // localhost, IPs, internal hosts, etc.
    578      // bucket by the hostname.
    579      return hostname.toLowerCase();
    580    }
    581  }
    582 
    583  /**
    584   * For each candidate tab, compute s_dd = fraction of anchors whose base domain
    585   * matches the candidate's base domain.
    586   *
    587   * @param {Array} anchorTabsPrep  output of _prepareTabData for anchor tabs
    588   * @param {Array} candidateTabsPrep output of _prepareTabData for candidate tabs
    589   * @return {number[]} array of s_dd values in [0, 1]
    590   */
    591  getDomainMatchFractions(anchorTabsPrep, candidateTabsPrep) {
    592    const anchorDomains = anchorTabsPrep.map(t =>
    593      SmartTabGroupingManager.getBaseDomain(t.url)
    594    );
    595    const numAnchors = anchorDomains.length || 1;
    596 
    597    return candidateTabsPrep.map(tab => {
    598      const candDomain = SmartTabGroupingManager.getBaseDomain(tab.url);
    599      if (!candDomain) {
    600        return 0;
    601      }
    602      let same = 0;
    603      for (const ad of anchorDomains) {
    604        if (ad && ad === candDomain) {
    605          same++;
    606        }
    607      }
    608      return same / numAnchors;
    609    });
    610  }
    611 
    612  /**
    613   * Calculates the sigmoid value of the input
    614   *
    615   * @param {number} z
    616   * @return {number}
    617   */
    618  sigmoid(z) {
    619    return 1 / (1 + Math.exp(-z));
    620  }
    621 
    622  /**
    623   * Calculates the probability using the linear combination of the parameters
    624   *
    625   * @param {number} groupSimilarity s_gc in [0,1]
    626   * @param {number} titleSimilarity s_tt_max in [0,1]
    627   * @param {number} domainSimilarity s_dd in [0,1]
    628   * @param {object} params the logistic regression weights assigned to each parameter
    629   * @return {number}
    630   */
    631  calculateProbability(
    632    groupSimilarity,
    633    titleSimilarity,
    634    domainSimilarity,
    635    params
    636  ) {
    637    const wGroup = params.GROUP_SIMILARITY_WEIGHT || 0;
    638    const wTitle = params.TITLE_SIMILARITY_WEIGHT || 0;
    639    const wDomain = params.DOMAIN_SIMILARITY_WEIGHT || 0;
    640    const z =
    641      groupSimilarity * wGroup +
    642      titleSimilarity * wTitle +
    643      domainSimilarity * wDomain +
    644      params.INTERCEPT;
    645    return this.sigmoid(z);
    646  }
    647 
    648  /**
    649   * Calculates the probabilities given similarity lists (cosine) and domain fractions.
    650   *
    651   * @param {number[]|null} groupSimilaritiesCos cosine(group, candidate) in [-1,1] or null
    652   * @param {number[]} titleSimilaritiesCos max cosine(anchor, candidate) in [-1,1]
    653   * @param {number[]} domainSimilarities s_dd in [0,1]
    654   * @return {number[]} probabilities for each candidate tab
    655   */
    656  calculateAllProbabilities(
    657    groupSimilaritiesCos,
    658    titleSimilaritiesCos,
    659    domainSimilarities
    660  ) {
    661    const hasGroupSimilarity =
    662      Array.isArray(groupSimilaritiesCos) && groupSimilaritiesCos.length;
    663    const useDomain =
    664      Array.isArray(domainSimilarities) && domainSimilarities.length;
    665 
    666    const probabilities = [];
    667    for (let i = 0; i < titleSimilaritiesCos.length; i++) {
    668      // groupTitleSim and titleSim are (cos + 1)/2 -> [0,1]
    669      const groupTitleSim = hasGroupSimilarity
    670        ? 0.5 * (groupSimilaritiesCos[i] + 1)
    671        : 0;
    672      const titleSim = 0.5 * (titleSimilaritiesCos[i] + 1);
    673      const domainSim = useDomain ? domainSimilarities[i] : 0;
    674 
    675      const params = hasGroupSimilarity
    676        ? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME
    677        : LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY;
    678 
    679      probabilities.push(
    680        this.calculateProbability(groupTitleSim, titleSim, domainSim, params)
    681      );
    682    }
    683    return probabilities;
    684  }
    685 
    686  /**
    687   * Generates similar tabs to a grouped list of tabs using a logistic regression "model"
    688   *
    689   * @param {Array} allTabs all tabs that are part of the window
    690   * @param {Array} groupedIndices indices of tabs that are already part of the group
    691   * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups
    692   * @param {string} groupLabel name of group if present
    693   */
    694  async findSimilarTabsLogisticRegression({
    695    allTabs,
    696    groupedIndices,
    697    alreadyGroupedIndices,
    698    groupLabel = "",
    699  }) {
    700    const tabData = await this._prepareTabData(allTabs);
    701    const candidateIndices = this.getTabsToSuggest(
    702      tabData,
    703      groupedIndices,
    704      alreadyGroupedIndices
    705    );
    706 
    707    const candidateTabsData = candidateIndices.map(ci => allTabs[ci]);
    708    const candidateTabsPrep = await this._prepareTabData(candidateTabsData);
    709 
    710    const anchorTabsPrep = groupedIndices
    711      .map(gi => tabData[gi])
    712      .slice(0, MAX_GROUPED_TABS);
    713 
    714    // generate embeddings for both anchor and candidate titles
    715    const titleEmbeddings = await this._generateEmbeddings(
    716      anchorTabsPrep
    717        .concat(candidateTabsPrep)
    718        .map(tab => SmartTabGroupingManager.preprocessText(tab[EMBED_TEXT_KEY]))
    719    );
    720 
    721    let groupEmbedding;
    722    let groupSimilaritiesCos = null;
    723    if (groupLabel) {
    724      groupEmbedding = await this._generateEmbeddings([groupLabel]);
    725      // cosine(group, candidate_title) in [-1,1]
    726      groupSimilaritiesCos = this.getAverageSimilarity(
    727        groupEmbedding,
    728        titleEmbeddings.slice(anchorTabsPrep.length)
    729      );
    730    }
    731 
    732    // s_tt_max: max cosine(anchor_title, candidate_title) in [-1,1]
    733    const titleSimilaritiesCos = this.getMaxSimilarity(
    734      titleEmbeddings.slice(0, anchorTabsPrep.length),
    735      titleEmbeddings.slice(anchorTabsPrep.length)
    736    );
    737 
    738    // s_dd: fraction of anchors sharing the candidate's base domain
    739    const domainSimilarities = this.getDomainMatchFractions(
    740      anchorTabsPrep,
    741      candidateTabsPrep
    742    );
    743 
    744    const candidateProbabilities = this.calculateAllProbabilities(
    745      groupSimilaritiesCos,
    746      titleSimilaritiesCos,
    747      domainSimilarities
    748    );
    749 
    750    // get matching params depending on the group name availability
    751    const probabilityThreshold = groupEmbedding
    752      ? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME.THRESHOLD
    753      : LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY.THRESHOLD;
    754 
    755    return (
    756      candidateTabsData
    757        // combine candidate tabs with corresponding probabilities
    758        .map((ct, index) => ({
    759          ct,
    760          prob: candidateProbabilities[index],
    761        }))
    762        // only keep those that are within the probability threshold
    763        .filter(item => item.prob >= probabilityThreshold)
    764        // ensure the highest probability candidates come first in the list
    765        .sort((a, b) => b.prob - a.prob)
    766        // keep the tabs only
    767        .map(item => item.ct)
    768    );
    769  }
    770 
    771  /**
    772   * This function will terminate a grouping or label generation in progress
    773   * It is currently not implemented.
    774   */
    775  terminateProcess() {
    776    // TODO - teminate AI processes, This method will be
    777    // called when tab grouping panel is closed.
    778  }
    779 
    780  /**
    781   * Changes the clustering method. Must be one of supported methods.
    782   *
    783   * @param {string} method Name of method
    784   */
    785  setClusteringMethod(method) {
    786    if (!(method in CLUSTER_METHODS)) {
    787      throw new Error(`Clustering method ${method} not supported`);
    788    }
    789    this.config.clustering.clusterImplementation = method;
    790  }
    791 
    792  /**
    793   * Set the technique for clustering when certain tabs are already assigned to groups
    794   *
    795   * @param {string} method which is one of ANCHOR_METHODS
    796   */
    797  setAnchorMethod(method) {
    798    if (!(method in ANCHOR_METHODS)) {
    799      throw new Error(`Clustering anchor method ${method} not supported`);
    800    }
    801    this.config.clustering.anchorMethod = method;
    802  }
    803 
    804  setSilBoost(boost) {
    805    this.config.clustering.pregroupedSilhouetteBoost = boost;
    806  }
    807 
    808  /**
    809   * Sets method to reduce dimensionality of embeddings prior to clustering
    810   *
    811   * @param {string} method Name of method
    812   */
    813  setDimensionReductionMethod(method) {
    814    if (method && !(method in DIM_REDUCTION_METHODS)) {
    815      throw new Error(`Dimension reduction method ${method} not supported`);
    816    }
    817    this.config.clustering.dimReductionMethod = method;
    818  }
    819 
    820  /**
    821   * Sets the field name of the title of a page to be used when clustering or generating embeddings
    822   * This is useful when clustering test data that is not a tab object
    823   *
    824   * @param {string} titleKey KEY FOR THE TITLE
    825   */
    826  setDataTitleKey(titleKey) {
    827    this.config.dataConfig.titleKey = titleKey;
    828  }
    829 
    830  /**
    831   * Logs to the appropriate place for debugging. Console for now
    832   *
    833   * @param {string} msg Message to log
    834   * @param {boolean} useDescription Whether to add description to the final text
    835   */
    836  log(_msg) {}
    837 
    838  /**
    839   * Prepares data to be used by the ml models
    840   *
    841   * @param {object[]} tabList list of tabs in the current window
    842   * @param {boolean} useDescription whether we should combined the title and description
    843   * @return {Promise<*[Object]>}
    844   * @private
    845   */
    846  async _prepareTabData(tabList, useDescription = false) {
    847    const titleKey = this.config.dataConfig.titleKey;
    848    const descriptionKey = this.config.dataConfig.descriptionKey;
    849    const structuredData = [];
    850    for (let tab of tabList) {
    851      const description =
    852        useDescription && descriptionKey && tab[descriptionKey];
    853 
    854      let textToEmbed;
    855      if (description) {
    856        textToEmbed = tab[titleKey] + " " + description;
    857      } else {
    858        textToEmbed = tab[titleKey] || "Unknown";
    859      }
    860 
    861      structuredData.push({
    862        [EMBED_TEXT_KEY]: textToEmbed,
    863        title: tab[titleKey],
    864        description,
    865        url: tab?.linkedBrowser?.currentURI?.spec,
    866      });
    867    }
    868    return structuredData;
    869  }
    870 
    871  /**
    872   * Get updated config for the ml engine
    873   *
    874   * @param {object} initData
    875   * @param {string} featureId
    876   * @return {*}
    877   */
    878  static getUpdatedInitData(initData, featureId) {
    879    // we're setting a specific modelRevision through about:config or Nimbus
    880    if (
    881      featureId === SMART_TAB_GROUPING_CONFIG.topicGeneration.featureId &&
    882      lazy.topicModelRevision !== LATEST_MODEL_REVISION
    883    ) {
    884      initData.modelRevision = lazy.topicModelRevision;
    885    } else if (
    886      featureId === SMART_TAB_GROUPING_CONFIG.embedding.featureId &&
    887      lazy.embeddingModelRevision !== LATEST_MODEL_REVISION
    888    ) {
    889      initData.modelRevision = lazy.embeddingModelRevision;
    890    }
    891    return initData;
    892  }
    893 
    894  /**
    895   * Creates an ML engine for a given config.
    896   *
    897   * @param {*} engineConfig
    898   * @param {function} progressCallback
    899   * @returns MLEngine
    900   */
    901  async _createMLEngine(engineConfig, progressCallback) {
    902    const {
    903      featureId,
    904      engineId,
    905      dtype,
    906      taskName,
    907      timeoutMS,
    908      modelId,
    909      modelRevision,
    910      backend,
    911      fallbackBackend,
    912    } = engineConfig;
    913    let initData = {
    914      featureId,
    915      engineId,
    916      dtype,
    917      taskName,
    918      timeoutMS,
    919      modelId,
    920      modelRevision,
    921      backend,
    922    };
    923    initData = SmartTabGroupingManager.getUpdatedInitData(initData, featureId);
    924    let engine;
    925    try {
    926      engine = await createEngine(initData, progressCallback);
    927      this.backend = backend;
    928    } catch (e) {
    929      engine = await createEngine(
    930        {
    931          ...initData,
    932          backend: fallbackBackend,
    933        },
    934        progressCallback
    935      );
    936      this.backend = fallbackBackend;
    937    }
    938    return engine;
    939  }
    940 
    941  /**
    942   * Generates embeddings from a list of tab data structures
    943   *
    944   * @param  tabList List of tabs with label (title) and description keys
    945   * @returns {Promise<*[]>} List of embeddings (2d array)
    946   * @private
    947   */
    948  async _generateEmbeddings(textToEmbedList) {
    949    const inputData = {
    950      inputArgs: textToEmbedList,
    951      runOptions: {
    952        pooling: "mean",
    953        normalize: true,
    954      },
    955    };
    956 
    957    if (SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) {
    958      this.embeddingEngine = await this._createMLEngine(this.config.embedding);
    959    }
    960    const request = {
    961      args: [inputData.inputArgs],
    962      options: inputData.runOptions,
    963    };
    964    return await this.embeddingEngine.run(request);
    965  }
    966 
    967  /**
    968   * Clusters in desired methods
    969   * based on the config of the class
    970   *
    971   * @param tabList List of tabs as array
    972   * @param docEmbeddings Precomputed embeddings for the Tab as two dimensional array
    973   * @param k Desired number of clusters. Tries a range of sizes if 0.
    974   * @param {function} randomFunc Optional seeded random number generator for testing
    975   * @returns {SmartTabGroupingResult}
    976   * @private
    977   */
    978  _clusterEmbeddings({
    979    tabs,
    980    embeddings,
    981    k,
    982    randomFunc,
    983    anchorIndices,
    984    alreadyGroupedIndices = [],
    985  }) {
    986    let allItems;
    987 
    988    const freezeAnchorsInZeroCluster =
    989      anchorIndices &&
    990      this.config.clustering.anchorMethod == ANCHOR_METHODS.FIXED;
    991 
    992    const dimReductionMethod = this.config.clustering.dimReductionMethod;
    993    switch (dimReductionMethod) {
    994      default:
    995        //  Dimensionality reduction support is landing very soon.
    996        break;
    997    }
    998    k = k || 0;
    999    let startK = k;
   1000    let endK = k + 1;
   1001    if (!k) {
   1002      startK = 2;
   1003      // Find a reasonable max # of clusters
   1004      endK =
   1005        Math.min(
   1006          Math.floor(Math.log(embeddings.length) * 2.0),
   1007          embeddings.length
   1008        ) + 1;
   1009    }
   1010    let bestResult;
   1011    let bestResultSilScore = -100.0;
   1012    let bestResultCenterCluster = 0;
   1013 
   1014    const clusteringMethod = this.config.clustering.clusterImplementation;
   1015    const clusteringTriesPerK = this.config.clustering.clusteringTriesPerK;
   1016    for (let curK = startK; curK < endK; curK++) {
   1017      let bestItemsForK;
   1018      let bestInertiaForK = 500000000000;
   1019      for (let j = 0; j < clusteringTriesPerK; j++) {
   1020        switch (clusteringMethod) {
   1021          case CLUSTER_METHODS.KMEANS:
   1022            allItems = kmeansPlusPlus({
   1023              data: embeddings,
   1024              k: curK,
   1025              maxIterations: 0,
   1026              randomFunc,
   1027              anchorIndices,
   1028              preassignedIndices:
   1029                this.config.clustering.pregroupedHandlingMethod ===
   1030                PREGROUPED_HANDLING_METHODS.EXCLUDE
   1031                  ? alreadyGroupedIndices
   1032                  : [],
   1033              freezeAnchorsInZeroCluster,
   1034            });
   1035            break;
   1036          default:
   1037            throw Error("Clustering implementation not supported");
   1038        }
   1039        const tempResult = new SmartTabGroupingResult({
   1040          indices: allItems,
   1041          embeddings,
   1042          config: this.config,
   1043        });
   1044        const inertia = tempResult.getCentroidInertia();
   1045        if (inertia < bestInertiaForK) {
   1046          bestInertiaForK = inertia;
   1047          bestItemsForK = tempResult;
   1048        }
   1049      }
   1050      const silScores = silhouetteCoefficients(
   1051        embeddings,
   1052        bestItemsForK.indices
   1053      );
   1054 
   1055      if (
   1056        freezeAnchorsInZeroCluster &&
   1057        this.config.clustering.pregroupedSilhouetteBoost > 0
   1058      ) {
   1059        // Boost silhouette score of target cluster when we are grouping around an existing cluster
   1060        // pregroupedSilhouetteBoost indicates the relative weight of the cluster's score and all other cluster's combined
   1061        silScores[0] *= this.config.clustering.pregroupedSilhouetteBoost;
   1062      }
   1063 
   1064      let avgSil = silScores.reduce((p, c) => p + c, 0) / silScores.length;
   1065      let curAnchorCluster = 0;
   1066      if (anchorIndices && !freezeAnchorsInZeroCluster) {
   1067        const { anchorClusterIndex, numAnchorItemsInCluster } =
   1068          getBestAnchorClusterInfo(bestItemsForK.indices, anchorIndices);
   1069        curAnchorCluster = anchorClusterIndex;
   1070        const penalty =
   1071          (MISSING_ANCHOR_IN_CLUSTER_PENALTY *
   1072            (anchorIndices.length - numAnchorItemsInCluster)) /
   1073          anchorIndices.length;
   1074        avgSil -= penalty;
   1075      }
   1076      if (avgSil > bestResultSilScore) {
   1077        bestResultSilScore = avgSil;
   1078        bestResult = bestItemsForK.indices;
   1079        bestResultCenterCluster = curAnchorCluster;
   1080      }
   1081    }
   1082    const result = new SmartTabGroupingResult({
   1083      indices: bestResult,
   1084      tabs,
   1085      embeddings,
   1086      config: this.config,
   1087    });
   1088    if (anchorIndices) {
   1089      result.setAnchorClusterIndex(
   1090        freezeAnchorsInZeroCluster ? 0 : bestResultCenterCluster
   1091      ); // In our k-means clustering implementation anchor cluster is always first
   1092      if (!freezeAnchorsInZeroCluster) {
   1093        result.adjustClusterForAnchors(anchorIndices);
   1094      }
   1095    }
   1096    return result;
   1097  }
   1098 
   1099  /**
   1100   * Generate a label for tabs in a group created by the user
   1101   *
   1102   * @param tabs tabs that are currently in the group
   1103   * @param otherTabs tabs in the window not part of the group
   1104   * @return {Promise<null|string|string|*>}
   1105   */
   1106  async getPredictedLabelForGroup(tabs, otherTabs) {
   1107    const clusters = this.createStaticCluster(tabs);
   1108    const otherClusters = this.createStaticCluster(otherTabs);
   1109    let predictedLabel;
   1110    try {
   1111      // function below modifies "clusters" object
   1112      await this.generateGroupLabels(clusters, otherClusters);
   1113      predictedLabel = clusters.clusterRepresentations[0].predictedTopicLabel;
   1114    } catch (e) {
   1115      this.labelReason = LABEL_REASONS.ERROR;
   1116      predictedLabel = "";
   1117    }
   1118    return predictedLabel;
   1119  }
   1120 
   1121  /**
   1122   * Generates clusters for a given list of tabs using precomputed embeddings or newly generated ones.
   1123   *
   1124   * @param {object[]} tabList - List of tab objects to be clustered.
   1125   * @param {number[][]} [precomputedEmbeddings] - Precomputed embeddings for tab titles and descriptions.
   1126   * @param {number} numClusters - Number of clusters to form.
   1127   * @param {Function} randFunc - Random function used for clustering initialization.
   1128   * @param {number[]} [anchorIndices=[]] - Indices of anchor tabs that should be prioritized in clustering.
   1129   * @param {number[]} [alreadyGroupedIndices=[]] - Indices of tabs that are already assigned to groups.
   1130   * @returns {SmartTabGroupingResult} - The best clustering result based on centroid inertia.
   1131   */
   1132  async generateClusters(
   1133    tabList,
   1134    precomputedEmbeddings,
   1135    numClusters,
   1136    randFunc,
   1137    anchorIndices = [],
   1138    alreadyGroupedIndices = []
   1139  ) {
   1140    numClusters = numClusters ?? 0;
   1141    const structuredData = await this._prepareTabData(tabList);
   1142 
   1143    // embeddings for title and description
   1144    if (precomputedEmbeddings) {
   1145      this.docEmbeddings = precomputedEmbeddings;
   1146    } else {
   1147      this.docEmbeddings = await this._generateEmbeddings(
   1148        structuredData.map(a => a[EMBED_TEXT_KEY])
   1149      );
   1150    }
   1151    let bestResultCluster;
   1152    let bestResultDistance = 50000000.0;
   1153 
   1154    const NUM_RUNS = 1;
   1155    for (let i = 0; i < NUM_RUNS; i++) {
   1156      const curResult = this._clusterEmbeddings({
   1157        tabs: tabList,
   1158        embeddings: this.docEmbeddings,
   1159        k: numClusters,
   1160        randomFunc: randFunc,
   1161        anchorIndices,
   1162        alreadyGroupedIndices,
   1163      });
   1164      const distance = curResult.getCentroidInertia();
   1165      if (distance < bestResultDistance) {
   1166        bestResultDistance = distance;
   1167        bestResultCluster = curResult;
   1168      }
   1169    }
   1170    return bestResultCluster;
   1171  }
   1172 
   1173  /**
   1174   * Create static cluster from a list of tabs. A single tab is Ok. Returns null for 0 tabs
   1175   *
   1176   * @param tabs
   1177   * @returns {SmartTabGroupingResult} groupingResult
   1178   */
   1179  createStaticCluster(tabs) {
   1180    if (!tabs) {
   1181      return null;
   1182    }
   1183 
   1184    return new SmartTabGroupingResult({
   1185      indices: [Array.from({ length: tabs.length }, (_, i) => i)],
   1186      tabs,
   1187      config: this.config,
   1188    });
   1189  }
   1190 
   1191  /**
   1192   * Utility function that loads all required engines for Smart Tab Grouping and any dependent models
   1193   *
   1194   * @param {(progress: { percentage: number }) => void} progressCallback callback function to call.
   1195   * Callback passes a dict with percentage indicating best effort 0.0-100.0 progress in model download.
   1196   */
   1197  async preloadAllModels(progressCallback) {
   1198    let previousProgress = -1;
   1199    const expectedObjects =
   1200      EXPECTED_TOPIC_MODEL_OBJECTS + EXPECTED_EMBEDDING_MODEL_OBJECTS;
   1201    // TODO - Find a way to get these fields. Add as a transformers js callback or within remotesettings
   1202 
   1203    const UPDATE_THRESHOLD_PERCENTAGE = 0.5;
   1204    const ONE_MB = 1024 * 1024;
   1205    const START_THRESHOLD_BYTES = ONE_MB * 0.2;
   1206 
   1207    const mutliProgressAggregator = new lazy.MultiProgressAggregator({
   1208      progressCallback: ({ progress, totalLoaded, metadata }) => {
   1209        if (totalLoaded < START_THRESHOLD_BYTES) {
   1210          progress = 0.0;
   1211        } else {
   1212          const numObjSeen = metadata.totalObjectsSeen || 0;
   1213          if (numObjSeen > 0 && numObjSeen < expectedObjects) {
   1214            // When starting to download we may still be getting configs and not have all the data
   1215            progress *= numObjSeen / expectedObjects;
   1216          }
   1217          if (progress > 100) {
   1218            progress = 100;
   1219          }
   1220        }
   1221        if (
   1222          Math.abs(previousProgress - progress) > UPDATE_THRESHOLD_PERCENTAGE
   1223        ) {
   1224          // Update only once changes are above a threshold to avoid throttling the UI with events.
   1225          progressCallback({
   1226            percentage: progress,
   1227          });
   1228          previousProgress = progress;
   1229        }
   1230      },
   1231      watchedTypes: [
   1232        lazy.Progress.ProgressType.DOWNLOAD,
   1233        lazy.Progress.ProgressType.LOAD_FROM_CACHE,
   1234      ],
   1235    });
   1236 
   1237    const [topicEngine, embeddingEngine] = await Promise.all([
   1238      this._createMLEngine(
   1239        this.config.topicGeneration,
   1240        mutliProgressAggregator?.aggregateCallback.bind(
   1241          mutliProgressAggregator
   1242        ) || null
   1243      ),
   1244      this._createMLEngine(
   1245        this.config.embedding,
   1246        mutliProgressAggregator?.aggregateCallback.bind(
   1247          mutliProgressAggregator
   1248        ) || null
   1249      ),
   1250    ]);
   1251    this.topicEngine = topicEngine;
   1252    this.embeddingEngine = embeddingEngine;
   1253  }
   1254 
   1255  /**
   1256   * Generate model input from keywords and documents
   1257   *
   1258   * @param {string []} keywords
   1259   * @param {string []} documents
   1260   */
   1261  createModelInput(keywords, documents) {
   1262    if (!keywords || keywords.length === 0) {
   1263      return `Topic from keywords: titles: \n${documents.join(" \n")}`;
   1264    }
   1265    return `Topic from keywords: ${keywords.join(", ")}. titles: \n${documents.join(" \n")}`;
   1266  }
   1267 
   1268  /**
   1269   * One artifact of the LLM output is that sometimes words are duplicated
   1270   * This function cuts the phrase when it sees the first duplicate word.
   1271   * Handles simple singluar / plural duplicates (-s only).
   1272   *
   1273   * @param {string} phrase Input phrase
   1274   * @returns {string} phrase cut before any duplicate word
   1275   */
   1276  static cutAtDuplicateWords(phrase) {
   1277    if (!phrase.length) {
   1278      return phrase;
   1279    }
   1280    const wordsSet = new Set();
   1281    const wordList = phrase.split(" ");
   1282    for (let i = 0; i < wordList.length; i++) {
   1283      let baseWord = wordList[i].toLowerCase();
   1284      if (baseWord.length > 3) {
   1285        if (baseWord.slice(-1) === "s") {
   1286          baseWord = baseWord.slice(0, -1);
   1287        }
   1288      }
   1289      if (wordsSet.has(baseWord)) {
   1290        // We are seeing a baseWord word. Exit with just the words so far and don't
   1291        // add any new words
   1292        return wordList.slice(0, i).join(" ");
   1293      }
   1294      wordsSet.add(baseWord);
   1295    }
   1296    return phrase; // return original phrase
   1297  }
   1298 
   1299  /**
   1300   * Removes trailing domain-related text such as '... - Mail' or '... | News'
   1301   * If there's not enough information remaining after, we keep the text as is
   1302   *
   1303   * @param {string} text tab title with potential domain information
   1304   * @return {string}
   1305   */
   1306  static preprocessText(text) {
   1307    // Matches 'xyz - Domain' or 'xyz | Domain'
   1308    // with a space before and after delimiter
   1309    // or if there are multiple delimiters next to each other
   1310    const delimiters = /(?<=\s)[|–-]+(?=\s)/;
   1311    const splitText = text.split(delimiters);
   1312 
   1313    // ensure there's enough info without the last element
   1314    const hasEnoughInfo =
   1315      !!splitText.length && splitText.slice(0, -1).join(" ").length > 5;
   1316 
   1317    // domain related texts are usually shorter, this takes care of the most common cases
   1318    const isPotentialDomainInfo =
   1319      splitText.length > 1 && splitText[splitText.length - 1].length < 20;
   1320 
   1321    // If both conditions are met, remove the last chunk, filter out empty strings,
   1322    // join on space, trim, and lowercase
   1323    if (hasEnoughInfo && isPotentialDomainInfo) {
   1324      return splitText
   1325        .slice(0, -1) // everything except the last element
   1326        .map(t => t.trim())
   1327        .filter(Boolean) // remove empty strings
   1328        .join(" ") // join with spaces
   1329        .trim(); // remove leading/trailing spaces
   1330    }
   1331 
   1332    // Otherwise, just return the text
   1333    return text;
   1334  }
   1335 
   1336  /**
   1337   * Postprocessing of raw output from Topic Model ML Engine
   1338   *
   1339   * @param {string | undefined} topic Raw topic phrase from topic model or undefined in case of an error
   1340   */
   1341  processTopicModelResult(topic) {
   1342    let basicResult = (topic || "").trim();
   1343    if (!basicResult) {
   1344      this.labelReason = LABEL_REASONS.LOW_CONFIDENCE;
   1345    }
   1346    if (LABELS_TO_EXCLUDE.includes(basicResult.toLowerCase())) {
   1347      this.labelReason = LABEL_REASONS.EXCLUDE;
   1348      return "";
   1349    }
   1350    return SmartTabGroupingManager.cutAtDuplicateWords(basicResult);
   1351  }
   1352 
   1353  /**
   1354   * Add titles to a cluster in a SmartTabGroupingResult using generative tehniques
   1355   * Currently this function only works with a single target group, and a separate
   1356   * item that represents all other ungrouped tabs.
   1357   *
   1358   * In the future this may be updated to more generally find labels for a set of clusters.
   1359   *
   1360   * @param {SmartTabGroupingResult} groupingResult The cluster we are generating the label for
   1361   * @param {SmartTabGroupingResult} otherGroupingResult A 'made up' cluster representing all other tabs in the window
   1362   */
   1363  async generateGroupLabels(groupingResult, otherGroupingResult = null) {
   1364    // Special case for a search page
   1365    const searchTopicSpecialCase = Services.prefs.getBoolPref(
   1366      "browser.tabs.groups.smart.searchTopicEnabled",
   1367      true
   1368    );
   1369    if (
   1370      searchTopicSpecialCase &&
   1371      groupingResult.clusterRepresentations.length == 1 &&
   1372      groupingResult.clusterRepresentations[0].isSingleTabSearch
   1373    ) {
   1374      if (groupingResult.clusterRepresentations[0].setSingleTabSearchLabel()) {
   1375        return;
   1376      }
   1377    }
   1378 
   1379    const { keywords, documents } =
   1380      groupingResult.getRepresentativeDocsAndKeywords(
   1381        otherGroupingResult
   1382          ? otherGroupingResult.getRepresentativeDocuments()
   1383          : []
   1384      );
   1385    const inputArgs = this.createModelInput(
   1386      keywords ? keywords[0] : [],
   1387      documents
   1388    );
   1389    const requestInfo = {
   1390      inputArgs,
   1391      runOptions: {
   1392        max_length: 6,
   1393      },
   1394    };
   1395 
   1396    if (SmartTabGroupingManager.isEngineClosed(this.topicEngine)) {
   1397      this.topicEngine = await this._createMLEngine(
   1398        this.config.topicGeneration
   1399      );
   1400    }
   1401    const request = {
   1402      args: [requestInfo.inputArgs],
   1403      options: requestInfo.runOptions,
   1404    };
   1405    const genLabelResults = await this.topicEngine.run(request);
   1406    genLabelResults.forEach((genResult, genResultIndex) => {
   1407      groupingResult.clusterRepresentations[
   1408        genResultIndex
   1409      ].predictedTopicLabel = this.processTopicModelResult(
   1410        genResult.generated_text
   1411      );
   1412    });
   1413  }
   1414 
   1415  getLabelReason() {
   1416    return this.labelReason || LABEL_REASONS.DEFAULT;
   1417  }
   1418 
   1419  /**
   1420   * Generates glean metrics for ml smart tab label / topic.
   1421   * This is currently called when the user saves or cancels the "suggest label" flow.
   1422   *
   1423   * @param {string} action "save" or "cancel"
   1424   * @param {number} numTabsInGroup Number of tabs used to generate the label
   1425   * @param {string} mlLabel ML generated label for the tab group
   1426   * @param {string} userLabel User saved label for the tab group
   1427   * @param {string} id The id of the group
   1428   */
   1429  async handleLabelTelemetry({
   1430    action,
   1431    numTabsInGroup,
   1432    mlLabel,
   1433    userLabel,
   1434    id = "",
   1435  }) {
   1436    const { [ML_TASK_TEXT2TEXT]: topicEngineConfig } =
   1437      await this.getEngineConfigs();
   1438    const labelReason = this.getLabelReason();
   1439    Glean.tabgroup.smartTabTopic.record({
   1440      action,
   1441      tabs_in_group: numTabsInGroup,
   1442      ml_label_length: (mlLabel || "").length,
   1443      user_label_length: (userLabel || "").length,
   1444      levenshtein_distance: lazy.NLP.levenshtein(
   1445        userLabel || "",
   1446        mlLabel || ""
   1447      ),
   1448      model_revision: topicEngineConfig.modelRevision || "",
   1449      id,
   1450      label_reason: labelReason,
   1451      backend: this.backend || "onnx-native",
   1452    });
   1453    this.labelReason = LABEL_REASONS.DEFAULT;
   1454  }
   1455 
   1456  /**
   1457   * Generates glean metrics for ml smart tab label / topic.
   1458   * This is currently called when the user saves or cancels the "suggest other tabs" flow
   1459   *
   1460   * @param {string} action "save" or "cancel"
   1461   * @param {number} numTabsInWindow Number of tabs in the current window
   1462   * @param {number} numTabsInGroup Number of tabs in the current group
   1463   * @param {number} numTabsSuggested Number of tabs suggested by the model
   1464   * @param {number} numTabsApproved Number of tabs approved by the user
   1465   * @param {number} numTabsRemoved Number of tabs removed by the user
   1466   * @param {string} id The id of the group
   1467   */
   1468  async handleSuggestTelemetry({
   1469    action,
   1470    numTabsInWindow,
   1471    numTabsInGroup,
   1472    numTabsSuggested,
   1473    numTabsApproved,
   1474    numTabsRemoved,
   1475    id = "",
   1476  }) {
   1477    const { [ML_TASK_FEATURE_EXTRACTION]: embeddingEngineConfig } =
   1478      await this.getEngineConfigs();
   1479    Glean.tabgroup.smartTabSuggest.record({
   1480      action,
   1481      tabs_in_window: numTabsInWindow,
   1482      tabs_in_group: numTabsInGroup,
   1483      tabs_suggested: numTabsSuggested,
   1484      tabs_approved: numTabsApproved,
   1485      tabs_removed: numTabsRemoved,
   1486      model_revision: embeddingEngineConfig.modelRevision || "",
   1487      id,
   1488      backend: this.backend || "onnx-native",
   1489    });
   1490  }
   1491 
   1492  /**
   1493   * Gets config that engine was initialized with
   1494   *
   1495   * @return {Promise<{"[ML_TASK_TEXT2TEXT]", "[ML_TASK_FEATURE_EXTRACTION]"}>}
   1496   */
   1497  async getEngineConfigs() {
   1498    if (!this.topicEngineConfig) {
   1499      this.topicEngineConfig = await lazy.MLEngineParent.getInferenceOptions(
   1500        this.config.topicGeneration.featureId,
   1501        this.config.topicGeneration.taskName
   1502      );
   1503    }
   1504    if (!this.embeddingEngineConfig) {
   1505      this.embeddingEngineConfig =
   1506        await lazy.MLEngineParent.getInferenceOptions(
   1507          this.config.embedding.featureId,
   1508          this.config.embedding.taskName
   1509        );
   1510    }
   1511    return {
   1512      [ML_TASK_TEXT2TEXT]: this.topicEngineConfig,
   1513      [ML_TASK_FEATURE_EXTRACTION]: this.embeddingEngineConfig,
   1514    };
   1515  }
   1516 }
   1517 
   1518 export class SmartTabGroupingResult {
   1519  #anchorClusterIndex = -1; // Index of cluster that has original items we're building clustering around, when building around an existing item.
   1520 
   1521  /**
   1522   * Creates a result from indices and complete tab and embedding lists.
   1523   * This may create some extra data for management later
   1524   *
   1525   * @param indices indices of clusters (eg [[2,4], [1], [3]]_
   1526   * @param tabItems 1D array of tabs
   1527   * @param embeddingItems Two dimensional array of embeddings
   1528   * @param config Cluster config
   1529   */
   1530  constructor({ indices = [], tabs, embeddings, config }) {
   1531    this.embeddingItems = embeddings;
   1532    this.config = config;
   1533    this.indices = indices.filter(subArray => !!subArray.length); // Cleanup any empty clusters
   1534    this.tabItems = tabs;
   1535    this._buildClusterRepresentations();
   1536  }
   1537 
   1538  /**
   1539   * Builds list of ClusterRepresentations
   1540   */
   1541  _buildClusterRepresentations() {
   1542    this.clusterRepresentations = this.indices.map(subClusterIndices => {
   1543      const tabItemsMapped =
   1544        this.tabItems && subClusterIndices.map(idx => this.tabItems[idx]);
   1545      const embeddingItemsMapped =
   1546        this.embeddingItems &&
   1547        subClusterIndices.map(idx => this.embeddingItems[idx]);
   1548      return new ClusterRepresentation({
   1549        tabs: tabItemsMapped,
   1550        embeddings: embeddingItemsMapped,
   1551        config: this.config,
   1552      });
   1553    });
   1554  }
   1555 
   1556  /**
   1557   * Returns a list of documents for each cluster. Currently it is a list of documents picked
   1558   * in no particular order.
   1559   *
   1560   * @return {[strings]} Title and description that represent the cluster. (If no docs are in the class, then titles are returned)
   1561   */
   1562  getRepresentativeDocuments() {
   1563    if (!this.documents) {
   1564      this.documents = this.tabItems.map(
   1565        t => t[this.config.dataConfig.titleKey]
   1566      );
   1567    }
   1568    // set a limit of 10 for now
   1569    return this.documents.slice(0, 10);
   1570  }
   1571 
   1572  /**
   1573   * Returns the keywords and documents for the cluster, computing if needed
   1574   * Does not return keywods if only one document is passed to the function.
   1575   *
   1576   * @param {string[]} otherDocuments other clusters that we'll compare against
   1577   * @return keywords and documents that represent the cluster
   1578   */
   1579  getRepresentativeDocsAndKeywords(otherDocuments = []) {
   1580    this.documents = this.getRepresentativeDocuments();
   1581    if (!this.keywords) {
   1582      const joinedDocs = this.documents.slice(0, 3).join(" ");
   1583      const otherDocs = otherDocuments.join(" ");
   1584      if (this.documents.length > 1) {
   1585        const keywordExtractor = new KeywordExtractor();
   1586        this.keywords = keywordExtractor.fitTransform([joinedDocs, otherDocs]);
   1587      } else {
   1588        this.keywords = [];
   1589      }
   1590    }
   1591    return { keywords: this.keywords, documents: this.documents };
   1592  }
   1593 
   1594  setAnchorClusterIndex(index) {
   1595    this.#anchorClusterIndex = index;
   1596  }
   1597 
   1598  /**
   1599   * Get the cluster we originally are grouping around (finding additinoal item)
   1600   *
   1601   * @returns ClusterRepresentation
   1602   */
   1603  getAnchorCluster() {
   1604    if (this.#anchorClusterIndex === -1) {
   1605      return null;
   1606    }
   1607    return this.clusterRepresentations[this.#anchorClusterIndex];
   1608  }
   1609 
   1610  /**
   1611   *   Given the indices that we were clustering around, make sure they are are all in the target grouping
   1612   *   Our generic k-means clustering might have them in separate groups
   1613   */
   1614  adjustClusterForAnchors(anchorIndices) {
   1615    if (!anchorIndices.length) {
   1616      return;
   1617    }
   1618    const anchorSet = new Set(anchorIndices);
   1619    for (let i = 0; i < this.indices.length; i++) {
   1620      if (i === this.#anchorClusterIndex) {
   1621        continue;
   1622      }
   1623      this.indices[i] = this.indices[i].filter(item => {
   1624        if (anchorSet.has(item)) {
   1625          this.indices[this.#anchorClusterIndex].push(item);
   1626          return false;
   1627        }
   1628        return true;
   1629      });
   1630    }
   1631    this._buildClusterRepresentations();
   1632  }
   1633 
   1634  /**
   1635   * Prints information about the cluster
   1636   */
   1637  printClusters() {
   1638    for (let cluster of this.clusterRepresentations) {
   1639      cluster.print();
   1640    }
   1641  }
   1642 
   1643  /**
   1644   * Computes the inertia of the cluster which is the sum of square total distance.
   1645   *
   1646   * @returns {number}
   1647   */
   1648  getCentroidInertia() {
   1649    let runningTotalDistance = 0;
   1650    this.clusterRepresentations.forEach(rep => {
   1651      runningTotalDistance += rep.computeTotalSquaredCentroidDistance();
   1652    });
   1653    return runningTotalDistance;
   1654  }
   1655 
   1656  /**
   1657   * Converts a cluster representation to a flat list of tabs, with clusterID key in each
   1658   * tab representing the id of the cluster it was part of.
   1659   *
   1660   * @returns {[object]}
   1661   */
   1662  _flatMapItemsInClusters() {
   1663    return this.clusterRepresentations.reduce((result, clusterRep) => {
   1664      const annotatedTabs = clusterRep.tabs.map(a => {
   1665        let c = {};
   1666        Object.assign(c, a);
   1667        c.clusterID = clusterRep.clusterID;
   1668        return c;
   1669      });
   1670      return result.concat(annotatedTabs);
   1671    }, []);
   1672  }
   1673 
   1674  /**
   1675   * Get rand score which describes the accuracy versus a user labeled
   1676   * annotation on the dataset. Requires the dataset to be labeled.
   1677   *
   1678   * @param labelKey Key in the tabs that represent a unique label ID for the cluster.
   1679   * @returns {number} The rand score.
   1680   */
   1681  getRandScore(labelKey = "annotatedLabel") {
   1682    const combinedItems = this._flatMapItemsInClusters();
   1683    return computeRandScore(combinedItems, "clusterID", labelKey);
   1684  }
   1685 
   1686  /**
   1687   * Get accuracy for a specific cluster
   1688   *
   1689   * @param labelKey Key in the tabs that represent a unique label ID for the cluster.
   1690   * @param clusterValue is the cluster we are comparing
   1691   * @returns {number} The rand score.
   1692   */
   1693  getAccuracyStatsForCluster(labelKey = "annotatedLabel", clusterValue) {
   1694    const combinedItems = this._flatMapItemsInClusters();
   1695 
   1696    let keyClusterId = combinedItems.find(
   1697      a => a[labelKey] === clusterValue
   1698    ).clusterID;
   1699 
   1700    let truePositives = 0,
   1701      trueNegatives = 0,
   1702      falseNegatives = 0,
   1703      falsePositives = 0;
   1704 
   1705    combinedItems.forEach(item => {
   1706      const sameLabel = item[labelKey] === clusterValue;
   1707      const sameCluster = item.clusterID === keyClusterId;
   1708      if (sameLabel && sameCluster) {
   1709        truePositives++;
   1710      }
   1711      if (!sameLabel && !sameCluster) {
   1712        trueNegatives++;
   1713      }
   1714      if (sameLabel && !sameCluster) {
   1715        falseNegatives++;
   1716      }
   1717      if (!sameLabel && sameCluster) {
   1718        falsePositives++;
   1719      }
   1720    });
   1721    return getAccuracyStats({
   1722      truePositives,
   1723      trueNegatives,
   1724      falsePositives,
   1725      falseNegatives,
   1726    });
   1727  }
   1728 }
   1729 
   1730 /**
   1731 * Utility function to generate a random ID string
   1732 *
   1733 * @param len Length of the string
   1734 * @returns {string}
   1735 */
   1736 function genHexString(len) {
   1737  const hex = "0123456789ABCDEF";
   1738  let output = "";
   1739  for (let i = 0; i < len; ++i) {
   1740    output += hex.charAt(Math.floor(Math.random() * hex.length));
   1741  }
   1742  return output;
   1743 }
   1744 
   1745 class EmbeddingCluster {
   1746  constructor({ tabs, embeddings, centroid }) {
   1747    this.embeddings = embeddings;
   1748    this.centroid =
   1749      centroid || (embeddings && computeCentroidFrom2DArray(this.embeddings));
   1750    this.tabs = tabs;
   1751  }
   1752 
   1753  /**
   1754   * @returns total sum euclidan squared distance of each item from cluster's centroid
   1755   */
   1756  computeTotalSquaredCentroidDistance() {
   1757    let totalDistance = 0;
   1758    if (this.embeddings.length === 0) {
   1759      return 0;
   1760    }
   1761    this.embeddings.forEach(embedding => {
   1762      totalDistance += euclideanDistance(this.centroid, embedding, true);
   1763    });
   1764    return totalDistance;
   1765  }
   1766 
   1767  /**
   1768   * Returns number of items in the cluster
   1769   *
   1770   * @returns {int}
   1771   */
   1772  numItems() {
   1773    return this.tabs.length;
   1774  }
   1775 }
   1776 
   1777 /**
   1778 * Represents a single cluster with additional saved metadata
   1779 */
   1780 export class ClusterRepresentation extends EmbeddingCluster {
   1781  constructor({ tabs, embeddings, centroid, config }) {
   1782    super({ tabs, embeddings, centroid });
   1783    this.config = config;
   1784    this.predictedTopicLabel = null;
   1785    this.annotatedTopicLabel = null;
   1786    this.userEditedTopicLabel = null;
   1787    this.representativeText = null;
   1788    this.keywords = null;
   1789    this.documents = null;
   1790    this.clusterID = genHexString(10);
   1791    this.isSingleTabSearch = tabs?.length == 1 && isSearchTab(tabs[0]);
   1792  }
   1793 
   1794  /**
   1795   * For a single tab cluster with a search field, set the predicted topic
   1796   * to be the title of the page
   1797   *
   1798   * @returns {boolean} True if we updated the cluster label successfully
   1799   */
   1800  setSingleTabSearchLabel() {
   1801    if (this.tabs.length !== 1) {
   1802      return false;
   1803    }
   1804    const pageTitle = this.tabs[0][this.config.dataConfig.titleKey] || "";
   1805    for (let i = pageTitle.length - 1; i > 0; i--) {
   1806      if (TITLE_DELIMETER_SET.has(pageTitle[i])) {
   1807        const topicString = pageTitle.substring(0, i).trim();
   1808        if (topicString.length > MAX_NON_SUMMARIZED_SEARCH_LENGTH) {
   1809          return false;
   1810        }
   1811        // Capitalize first character of each word. Regex returns first char of each word
   1812        this.predictedTopicLabel = topicString.replace(/(^|\s)\S/g, t =>
   1813          t.toUpperCase()
   1814        );
   1815        return true;
   1816      }
   1817    }
   1818    return false;
   1819  }
   1820 
   1821  /**
   1822   * Returns the representative text for a cluster, computing it if needed
   1823   */
   1824  getRepresentativeText() {
   1825    if (!this.representativeText) {
   1826      this.representativeText = this._generateRepresentativeText();
   1827    }
   1828    return this.representativeText;
   1829  }
   1830 
   1831  /**
   1832   * Returns representative text for a cluster.
   1833   * For this in initial implementation it simply returns title from a few tabs
   1834   *
   1835   * @returns {string}
   1836   * @private
   1837   */
   1838  _generateRepresentativeText() {
   1839    let text = "";
   1840    const titleKey = this.config.dataConfig.titleKey;
   1841    for (const tab of this.tabs.slice(0, 3)) {
   1842      text += `\n${tab[titleKey]}`;
   1843    }
   1844    return text;
   1845  }
   1846 
   1847  print() {
   1848    // Add console log for debugging
   1849  }
   1850 }