SmartTabGrouping.sys.mjs (58799B)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this 3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 4 5 import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs"; 6 import { createEngine } from "chrome://global/content/ml/EngineProcess.sys.mjs"; 7 import { 8 cosSim, 9 KeywordExtractor, 10 } from "chrome://global/content/ml/NLPUtils.sys.mjs"; 11 12 import { 13 computeCentroidFrom2DArray, 14 computeRandScore, 15 euclideanDistance, 16 getAccuracyStats, 17 kmeansPlusPlus, 18 silhouetteCoefficients, 19 } from "chrome://global/content/ml/ClusterAlgos.sys.mjs"; 20 21 const lazy = {}; 22 23 ChromeUtils.defineESModuleGetters(lazy, { 24 NLP: "resource://gre/modules/NLP.sys.mjs", 25 MLEngineParent: "resource://gre/actors/MLEngineParent.sys.mjs", 26 MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs", 27 Progress: "chrome://global/content/ml/Utils.sys.mjs", 28 }); 29 30 const LATEST_MODEL_REVISION = "latest"; 31 32 // Methods for suggesting tabs that are similar to current tab 33 export const SUGGEST_OTHER_TABS_METHODS = { 34 KMEANS_WITH_ANCHOR: "KMEANS_WITH_ANCHOR", 35 NEAREST_NEIGHBOR: "NEAREST_NEIGHBOR", 36 LOGISTIC_REGRESSION: "LOGISTIC_REGRESSION", 37 }; 38 39 XPCOMUtils.defineLazyPreferenceGetter( 40 lazy, 41 "suggestOtherTabsMethod", 42 "browser.tabs.groups.smart.suggestOtherTabsMethod" 43 ); 44 45 XPCOMUtils.defineLazyPreferenceGetter( 46 lazy, 47 "topicModelRevision", 48 "browser.tabs.groups.smart.topicModelRevision" 49 ); 50 51 XPCOMUtils.defineLazyPreferenceGetter( 52 lazy, 53 "embeddingModelRevision", 54 "browser.tabs.groups.smart.embeddingModelRevision" 55 ); 56 57 XPCOMUtils.defineLazyPreferenceGetter( 58 lazy, 59 "nearestNeighborThresholdInt", 60 "browser.tabs.groups.smart.nearestNeighborThresholdInt" 61 ); 62 63 const EMBED_TEXT_KEY = "combined_text"; 64 export const CLUSTER_METHODS = { 65 KMEANS: "KMEANS", 66 }; 67 68 // Methods for finding similar items for an existing cluster 69 export const ANCHOR_METHODS = { 70 DRIFT: "DRIFT", // We let k-means clustering run, and find the cluster with the most anchor items 71 FIXED: "FIXED", // We always group with the anchor items in the 0 cluster, and never let them be reassinged 72 }; 73 74 // Methods for finding ignoring other groups that were already grouped 75 export const PREGROUPED_HANDLING_METHODS = { 76 EXCLUDE: "EXCLUDE", // We let k-means clustering run, and find the cluster with the most anchor items 77 IGNORE: "IGNORE", // We always group with the anchor items in the 0 cluster, and never let them be reassinged 78 }; 79 80 const EXPECTED_TOPIC_MODEL_OBJECTS = 6; 81 const EXPECTED_EMBEDDING_MODEL_OBJECTS = 4; 82 83 const MAX_NON_SUMMARIZED_SEARCH_LENGTH = 26; 84 85 export const DIM_REDUCTION_METHODS = {}; 86 const MISSING_ANCHOR_IN_CLUSTER_PENALTY = 0.2; 87 const MAX_GROUPED_TABS = 3; 88 const MAX_SUGGESTED_TABS = 10; 89 // limit number of tabs to be processed so inference process doesn't crash 90 const MAX_TABS_TO_PROCESS = 300; 91 92 const DISSIMILAR_TAB_LABEL = "none"; 93 const ADULT_TAB_LABEL = "adult content"; 94 const LABELS_TO_EXCLUDE = [DISSIMILAR_TAB_LABEL, ADULT_TAB_LABEL]; 95 96 const ML_TASK_FEATURE_EXTRACTION = "feature-extraction"; 97 const ML_TASK_TEXT2TEXT = "text2text-generation"; 98 99 const LABEL_REASONS = { 100 DEFAULT: "DEFAULT", 101 LOW_CONFIDENCE: "LOW_CONFIDENCE", 102 EXCLUDE: "EXCLUDE", 103 ERROR: "ERROR", 104 }; 105 106 export const SMART_TAB_GROUPING_CONFIG = { 107 embedding: { 108 dtype: "q8", 109 timeoutMS: 2 * 60 * 1000, // 2 minutes 110 taskName: ML_TASK_FEATURE_EXTRACTION, 111 featureId: "smart-tab-embedding", 112 backend: "onnx-native", 113 fallbackBackend: "onnx", 114 }, 115 topicGeneration: { 116 dtype: "q8", 117 timeoutMS: 2 * 60 * 1000, // 2 minutes 118 taskName: ML_TASK_TEXT2TEXT, 119 featureId: "smart-tab-topic", 120 backend: "onnx-native", 121 fallbackBackend: "onnx", 122 }, 123 dataConfig: { 124 titleKey: "label", 125 descriptionKey: "description", 126 }, 127 clustering: { 128 dimReductionMethod: null, // Not completed. 129 clusterImplementation: CLUSTER_METHODS.KMEANS, 130 clusteringTriesPerK: 3, 131 anchorMethod: ANCHOR_METHODS.FIXED, 132 pregroupedHandlingMethod: PREGROUPED_HANDLING_METHODS.EXCLUDE, 133 pregroupedSilhouetteBoost: 2, // Relative weight of the cluster's score and all other cluster's combined 134 suggestOtherTabsMethod: SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR, 135 }, 136 }; 137 138 // these parameters were generated by training a logistic regression 139 // model on synthetic data. see https://github.com/mozilla/smart-tab-grouping 140 // and https://github.com/mozilla/smart-tab-grouping/pull/12 for more info 141 const LOGISTIC_REGRESSION_PARAMS = { 142 // Logistic WITH group name 143 // Features: s_gc, s_tt_max, s_dd in [0, 1] 144 TITLE_WITH_GROUP_NAME: { 145 GROUP_SIMILARITY_WEIGHT: 0.10249, 146 TITLE_SIMILARITY_WEIGHT: 0.54897, 147 DOMAIN_SIMILARITY_WEIGHT: 0.34854, 148 INTERCEPT: -0.07397, 149 THRESHOLD: 0.59, 150 }, 151 // Logistic WITHOUT group name 152 // Features: s_tt_max, s_dd in [0, 1] 153 TITLE_ONLY: { 154 GROUP_SIMILARITY_WEIGHT: 0, // unused in this variant 155 TITLE_SIMILARITY_WEIGHT: 0.92513, 156 DOMAIN_SIMILARITY_WEIGHT: 0.07487, 157 INTERCEPT: -2.58574, 158 THRESHOLD: 0.123, 159 }, 160 }; 161 162 const TAB_URLS_TO_EXCLUDE = [ 163 "about:newtab", 164 "about:home", 165 "about:privatebrowsing", 166 "chrome://browser/content/blanktab.html", 167 "about:firefoxview", 168 "about:opentabs", 169 ]; 170 171 const TITLE_DELIMETER_SET = new Set(["-", "|", "—"]); 172 173 /** 174 * For a given set of clusters represented by indices, returns the index of the cluster 175 * that has the most anchor items inside it. 176 * 177 * An anhor item is an index that represents the index to a tab that is already grouped and in 178 * the cluster we're interested in finding more items for. 179 * 180 * @param {number[][]} groupIndices - Array of clusters represented as arrays of indices. 181 * @param {number[]} anchorItems - Array of anchor item indices. 182 * @returns {{anchorClusterIndex: number, numAnchorItemsInCluster: number}} Index of best cluster and the number of anchor items. 183 */ 184 export function getBestAnchorClusterInfo(groupIndices, anchorItems) { 185 const anchorItemSet = new Set(anchorItems); 186 const numItemsList = groupIndices.map(g => 187 g.reduce( 188 (cur, itemIndex) => (anchorItemSet.has(itemIndex) ? cur + 1 : cur), 189 0 190 ) 191 ); 192 const anchorClusterIndex = numItemsList.indexOf(Math.max(...numItemsList)); 193 const numAnchorItemsInCluster = numItemsList[anchorClusterIndex]; 194 return { anchorClusterIndex, numAnchorItemsInCluster }; 195 } 196 197 /** 198 * Check tab to see if it's a search page 199 * 200 * @param {object} tab 201 * @returns {boolean} Returns true if the tab is a web search from the Firefox search UI and the user is still on the original page. 202 * Changes in search query after search is made is supported. 203 * Returns false if user started from a hompepage of a site rather than the New Tab / browser UI 204 */ 205 export function isSearchTab(tab) { 206 const linkedBrowser = tab?.linkedBrowser; 207 if (!linkedBrowser) { 208 return false; 209 } 210 const searchURL = linkedBrowser.getAttribute("triggeringSearchEngineURL"); 211 const curURL = linkedBrowser.currentURI?.spec; 212 if (!searchURL) { 213 return false; 214 } 215 const queryFieldsMarker = searchURL.indexOf("?"); 216 217 if ( 218 queryFieldsMarker > 0 && 219 searchURL.substring(0, queryFieldsMarker) === 220 curURL.substring(0, queryFieldsMarker) 221 ) { 222 return true; 223 } 224 return false; 225 } 226 227 export class SmartTabGroupingManager { 228 /** 229 * Creates the SmartTabGroupingManager object. 230 * 231 * @param {object} config configuration options 232 */ 233 constructor(config) { 234 this.config = config || structuredClone(SMART_TAB_GROUPING_CONFIG); 235 } 236 237 /** 238 * 239 * @param {MLEngine} engine the engine to check 240 * @return {boolean} true if the engine has not been initialized or closed 241 */ 242 static isEngineClosed(engine) { 243 return !engine || engine?.engineStatus === "closed"; 244 } 245 246 /** 247 * Initializes the embedding engine by running a test request 248 * This helps remove the init latency 249 */ 250 async initEmbeddingEngine() { 251 if (!SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) { 252 return; 253 } 254 try { 255 this.embeddingEngine = await this._createMLEngine(this.config.embedding); 256 const request = { 257 args: ["Test"], 258 options: { pooling: "mean", normalize: true }, 259 }; 260 this.embeddingEngine.run(request); 261 } catch (e) {} 262 } 263 264 /** 265 * Generates tabs to process with a limit. First MAX_GROUPED_TABS are tabs that are 266 * present in the group of the anchor tab. The remaining "ungrouped" tabs fill the 267 * slots up to MAX_TABS_TO_PROCESS 268 * 269 * @param {Array} tabsInGroup active tabs in anchor group we are adding tabs to 270 * @param {Array} allTabs list of tabs from gbrowser, some of which may be grouped in other groups 271 * @param {number} max_limit_to_process max number of tabs we want to process as part of the flow 272 * @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned. 273 */ 274 getTabsToProcess( 275 tabsInGroup, 276 allTabs, 277 max_limit_to_process = MAX_TABS_TO_PROCESS 278 ) { 279 const seen = new Set(); 280 let tabsToProcess = []; 281 282 const shouldInclude = tab => { 283 if (tab.pinned) { 284 return false; 285 } 286 if (!tab?.linkedBrowser?.currentURI?.spec) { 287 return false; 288 } 289 return true; 290 }; 291 292 // include tabs in the anchor group first 293 for (const tab of tabsInGroup) { 294 if (!shouldInclude(tab)) { 295 continue; 296 } 297 if (!seen.has(tab)) { 298 // make sure we have "seen" all the 299 // tabs already in the current group 300 seen.add(tab); 301 tabsToProcess.push(tab); 302 } 303 } 304 305 // when generating embeddings, we only look at the first MAX_GROUPED_TABS 306 // so use that limit here 307 tabsToProcess = tabsToProcess.slice(0, MAX_GROUPED_TABS); 308 // fill remaining slots with ungrouped tabs from the window 309 for (const tab of allTabs) { 310 if (tabsToProcess.length >= max_limit_to_process) { 311 break; 312 } 313 if (!shouldInclude(tab)) { 314 continue; 315 } 316 if (!seen.has(tab)) { 317 seen.add(tab); 318 tabsToProcess.push(tab); 319 } 320 } 321 return tabsToProcess; 322 } 323 324 /** 325 * Generates suggested tabs for an existing or provisional group 326 * 327 * @param {object} group active group we are adding tabs to 328 * @param {Array} tabs list of tabs from gbrowser, some of which may be grouped in other groups 329 * @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned. 330 */ 331 async smartTabGroupingForGroup(group, tabs) { 332 // Add tabs to suggested group 333 const groupTabs = group.tabs; 334 const allTabs = this.getTabsToProcess(groupTabs, tabs, MAX_TABS_TO_PROCESS); 335 // first (1 up to MAX_GROUPED_TABS) are tabs in the group 336 const groupIndices = []; 337 for (let i = 0; i < MAX_GROUPED_TABS; i++) { 338 if (groupTabs.includes(allTabs[i])) { 339 groupIndices.push(i); 340 } 341 } 342 343 // find tabs that are part of other groups 344 const alreadyGroupedIndices = allTabs 345 .map((t, i) => (t.group ? i : -1)) 346 .filter(a => a >= 0); 347 348 let suggestedTabs; 349 switch (lazy.suggestOtherTabsMethod) { 350 case SUGGEST_OTHER_TABS_METHODS.KMEANS_WITH_ANCHOR: 351 suggestedTabs = await this.generateClusters( 352 allTabs, 353 null, 354 null, 355 null, 356 groupIndices, 357 alreadyGroupedIndices 358 ).then(clusters => { 359 if (!clusters) { 360 return []; 361 } 362 const targetCluster = clusters.clusterRepresentations.find(c => 363 groupTabs.some(g => c.tabs.includes(g)) 364 ); 365 if (targetCluster) { 366 // Return only tabs not already grouped 367 return targetCluster.tabs.filter(t => !t.group); 368 } 369 return []; 370 }); 371 break; 372 case SUGGEST_OTHER_TABS_METHODS.LOGISTIC_REGRESSION: 373 suggestedTabs = await this.findSimilarTabsLogisticRegression({ 374 allTabs, 375 groupedIndices: groupIndices, 376 alreadyGroupedIndices, 377 groupLabel: group?.label, 378 }); 379 break; 380 case SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR: 381 default: 382 // find nearest neighbors to current group 383 suggestedTabs = await this.findNearestNeighbors({ 384 allTabs, 385 groupedIndices: groupIndices, 386 alreadyGroupedIndices, 387 groupLabel: group?.label, 388 }); 389 } 390 return suggestedTabs.slice(0, MAX_SUGGESTED_TABS); 391 } 392 393 /** 394 * Get tabs that need to be included in suggestions 395 * 396 * @param {Array} allTabs all tabs that are part of the window 397 * @param {Array} groupedIndices indices of tabs that are already part of the group 398 * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups 399 * @returns {Array} tabs indices to be considered for suggestions 400 */ 401 getTabsToSuggest(allTabs, groupedIndices, alreadyGroupedIndices) { 402 // tabs to be excluded 403 // indices of all tabs that should be excluded (with duplicates) 404 const tabURLIndicesToExclude = allTabs 405 .map((at, index) => (TAB_URLS_TO_EXCLUDE.includes(at.url) ? index : -1)) 406 .filter(index => index !== -1); 407 const excludedTabIndices = [ 408 ...groupedIndices, 409 ...alreadyGroupedIndices, 410 ...tabURLIndicesToExclude, 411 ]; 412 413 // tabs to be included 414 return allTabs 415 .map((_, index) => index) 416 .filter(i => !excludedTabIndices.includes(i)); 417 } 418 419 /** 420 * Generates similar tabs a grouped list of tabs 421 * 422 * @param {Array} allTabs all tabs that are part of the window 423 * @param {Array} groupedIndices indices of tabs that are already part of the group 424 * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups 425 * @param {string} groupLabel name of group if present 426 * @param {number} threshold for nearest neighbor similarity 427 * @returns a list of suggested tabs that are similar to the groupedIndices tabs 428 */ 429 async findNearestNeighbors({ 430 allTabs, 431 groupedIndices, 432 alreadyGroupedIndices, 433 groupLabel = "", 434 thresholdMills = lazy.nearestNeighborThresholdInt, 435 precomputedEmbeddings = [], 436 depth = 0, 437 }) { 438 // get embeddings for all the tabs 439 const tabData = await this._prepareTabData(allTabs); 440 let embeddings = precomputedEmbeddings; 441 if (precomputedEmbeddings.length === 0) { 442 embeddings = await this._generateEmbeddings( 443 tabData.map((td, index) => { 444 let text = SmartTabGroupingManager.preprocessText(td[EMBED_TEXT_KEY]); 445 // augment with group name if it's present 446 if (groupLabel && groupedIndices.includes(index)) { 447 text = `${groupLabel.slice(0, 100)}. ${text}`; 448 } 449 return text; 450 }) 451 ); 452 } 453 454 // tabs that need to be assigned after filtering 455 const tabsToAssignIndices = this.getTabsToSuggest( 456 tabData, 457 groupedIndices, 458 alreadyGroupedIndices 459 ); 460 461 let closestTabs = []; 462 const similarTabsIndices = []; 463 for (let i = 0; i < tabsToAssignIndices.length; i++) { 464 let closestScore = null; 465 for ( 466 let j = 0; 467 j < Math.min(groupedIndices.length, MAX_GROUPED_TABS); 468 j++ 469 ) { 470 const cosineSim = cosSim( 471 embeddings[tabsToAssignIndices[i]], 472 embeddings[groupedIndices[j]] 473 ); 474 if (!closestScore || cosineSim > closestScore) { 475 closestScore = cosineSim; 476 } 477 } 478 // threshold could also be set via a nimbus experiment, in which case 479 // it will be an int <= 1000 480 if (closestScore > thresholdMills / 1000) { 481 closestTabs.push([allTabs[tabsToAssignIndices[i]], closestScore]); 482 similarTabsIndices.push(tabsToAssignIndices[i]); 483 } 484 } 485 closestTabs.sort((a, b) => b[1] - a[1]); 486 closestTabs = closestTabs.map(t => t[0]); 487 // recurse once if the initial call only had a single tab 488 // and we found at least 1 similar tab - this improves recall 489 if (groupedIndices.length === 1 && !!closestTabs.length && depth === 1) { 490 const recurseSimilarTabs = await this.findNearestNeighbors({ 491 allTabs, 492 groupedIndices: similarTabsIndices, 493 alreadyGroupedIndices: alreadyGroupedIndices.concat(groupedIndices), 494 groupLabel, 495 thresholdMills, 496 precomputedEmbeddings: embeddings, 497 depth: depth - 1, 498 }); 499 closestTabs = closestTabs.concat(recurseSimilarTabs); 500 } 501 return closestTabs; 502 } 503 504 /** 505 * Calculates the average similarity between the anchor embeddings and the candidate embeddings 506 * 507 * @param {number[]} anchorEmbeddings title embeddings for the anchor tabs 508 * @param {number[]} candidateEmbeddings title embeddings for the candidate tabs 509 */ 510 getAverageSimilarity(anchorEmbeddings, candidateEmbeddings) { 511 let averageSimilarities = []; 512 for (let candidate_embedding of candidateEmbeddings) { 513 let averageSimilarity = 0; 514 for (let anchor_embedding of anchorEmbeddings) { 515 averageSimilarity += cosSim(candidate_embedding, anchor_embedding); 516 } 517 averageSimilarities.push(averageSimilarity / anchorEmbeddings.length); 518 } 519 return averageSimilarities; 520 } 521 522 /** 523 * Calculates the max similarity between the anchor embeddings and the candidate embeddings 524 * (used for s_tt_max). 525 * 526 * @param {number[]} anchorEmbeddings title embeddings for the anchor tabs 527 * @param {number[]} candidateEmbeddings title embeddings for the candidate tabs 528 */ 529 getMaxSimilarity(anchorEmbeddings, candidateEmbeddings) { 530 let maxSimilarities = []; 531 for (let candidate_embedding of candidateEmbeddings) { 532 let maxSimilarity = -1; 533 for (let anchor_embedding of anchorEmbeddings) { 534 const sim = cosSim(candidate_embedding, anchor_embedding); 535 if (sim > maxSimilarity) { 536 maxSimilarity = sim; 537 } 538 } 539 maxSimilarities.push(maxSimilarity); 540 } 541 return maxSimilarities; 542 } 543 544 /** 545 * Extract base domain from a URL with error handling 546 * 547 * @param {string} url 548 * @return {string} 549 */ 550 static getBaseDomain(url) { 551 if (!url) { 552 return ""; 553 } 554 555 let hostname; 556 try { 557 ({ hostname } = new URL(url)); 558 } catch (_e) { 559 // invalid URL 560 return ""; 561 } 562 563 if (!hostname) { 564 return ""; 565 } 566 567 try { 568 // additionalParts = 1 → one label above the registrable domain 569 // then remove 'www' 570 // https://www.example.com -> www.example.com -> example.com 571 // https://www.docs.google.com -> docs.google.com 572 // https://localhost -> error 573 return Services.eTLD 574 .getBaseDomain(Services.io.newURI(url.toLowerCase()), 1) 575 .replace(/^www\./, ""); 576 } catch (_e) { 577 // localhost, IPs, internal hosts, etc. 578 // bucket by the hostname. 579 return hostname.toLowerCase(); 580 } 581 } 582 583 /** 584 * For each candidate tab, compute s_dd = fraction of anchors whose base domain 585 * matches the candidate's base domain. 586 * 587 * @param {Array} anchorTabsPrep output of _prepareTabData for anchor tabs 588 * @param {Array} candidateTabsPrep output of _prepareTabData for candidate tabs 589 * @return {number[]} array of s_dd values in [0, 1] 590 */ 591 getDomainMatchFractions(anchorTabsPrep, candidateTabsPrep) { 592 const anchorDomains = anchorTabsPrep.map(t => 593 SmartTabGroupingManager.getBaseDomain(t.url) 594 ); 595 const numAnchors = anchorDomains.length || 1; 596 597 return candidateTabsPrep.map(tab => { 598 const candDomain = SmartTabGroupingManager.getBaseDomain(tab.url); 599 if (!candDomain) { 600 return 0; 601 } 602 let same = 0; 603 for (const ad of anchorDomains) { 604 if (ad && ad === candDomain) { 605 same++; 606 } 607 } 608 return same / numAnchors; 609 }); 610 } 611 612 /** 613 * Calculates the sigmoid value of the input 614 * 615 * @param {number} z 616 * @return {number} 617 */ 618 sigmoid(z) { 619 return 1 / (1 + Math.exp(-z)); 620 } 621 622 /** 623 * Calculates the probability using the linear combination of the parameters 624 * 625 * @param {number} groupSimilarity s_gc in [0,1] 626 * @param {number} titleSimilarity s_tt_max in [0,1] 627 * @param {number} domainSimilarity s_dd in [0,1] 628 * @param {object} params the logistic regression weights assigned to each parameter 629 * @return {number} 630 */ 631 calculateProbability( 632 groupSimilarity, 633 titleSimilarity, 634 domainSimilarity, 635 params 636 ) { 637 const wGroup = params.GROUP_SIMILARITY_WEIGHT || 0; 638 const wTitle = params.TITLE_SIMILARITY_WEIGHT || 0; 639 const wDomain = params.DOMAIN_SIMILARITY_WEIGHT || 0; 640 const z = 641 groupSimilarity * wGroup + 642 titleSimilarity * wTitle + 643 domainSimilarity * wDomain + 644 params.INTERCEPT; 645 return this.sigmoid(z); 646 } 647 648 /** 649 * Calculates the probabilities given similarity lists (cosine) and domain fractions. 650 * 651 * @param {number[]|null} groupSimilaritiesCos cosine(group, candidate) in [-1,1] or null 652 * @param {number[]} titleSimilaritiesCos max cosine(anchor, candidate) in [-1,1] 653 * @param {number[]} domainSimilarities s_dd in [0,1] 654 * @return {number[]} probabilities for each candidate tab 655 */ 656 calculateAllProbabilities( 657 groupSimilaritiesCos, 658 titleSimilaritiesCos, 659 domainSimilarities 660 ) { 661 const hasGroupSimilarity = 662 Array.isArray(groupSimilaritiesCos) && groupSimilaritiesCos.length; 663 const useDomain = 664 Array.isArray(domainSimilarities) && domainSimilarities.length; 665 666 const probabilities = []; 667 for (let i = 0; i < titleSimilaritiesCos.length; i++) { 668 // groupTitleSim and titleSim are (cos + 1)/2 -> [0,1] 669 const groupTitleSim = hasGroupSimilarity 670 ? 0.5 * (groupSimilaritiesCos[i] + 1) 671 : 0; 672 const titleSim = 0.5 * (titleSimilaritiesCos[i] + 1); 673 const domainSim = useDomain ? domainSimilarities[i] : 0; 674 675 const params = hasGroupSimilarity 676 ? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME 677 : LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY; 678 679 probabilities.push( 680 this.calculateProbability(groupTitleSim, titleSim, domainSim, params) 681 ); 682 } 683 return probabilities; 684 } 685 686 /** 687 * Generates similar tabs to a grouped list of tabs using a logistic regression "model" 688 * 689 * @param {Array} allTabs all tabs that are part of the window 690 * @param {Array} groupedIndices indices of tabs that are already part of the group 691 * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups 692 * @param {string} groupLabel name of group if present 693 */ 694 async findSimilarTabsLogisticRegression({ 695 allTabs, 696 groupedIndices, 697 alreadyGroupedIndices, 698 groupLabel = "", 699 }) { 700 const tabData = await this._prepareTabData(allTabs); 701 const candidateIndices = this.getTabsToSuggest( 702 tabData, 703 groupedIndices, 704 alreadyGroupedIndices 705 ); 706 707 const candidateTabsData = candidateIndices.map(ci => allTabs[ci]); 708 const candidateTabsPrep = await this._prepareTabData(candidateTabsData); 709 710 const anchorTabsPrep = groupedIndices 711 .map(gi => tabData[gi]) 712 .slice(0, MAX_GROUPED_TABS); 713 714 // generate embeddings for both anchor and candidate titles 715 const titleEmbeddings = await this._generateEmbeddings( 716 anchorTabsPrep 717 .concat(candidateTabsPrep) 718 .map(tab => SmartTabGroupingManager.preprocessText(tab[EMBED_TEXT_KEY])) 719 ); 720 721 let groupEmbedding; 722 let groupSimilaritiesCos = null; 723 if (groupLabel) { 724 groupEmbedding = await this._generateEmbeddings([groupLabel]); 725 // cosine(group, candidate_title) in [-1,1] 726 groupSimilaritiesCos = this.getAverageSimilarity( 727 groupEmbedding, 728 titleEmbeddings.slice(anchorTabsPrep.length) 729 ); 730 } 731 732 // s_tt_max: max cosine(anchor_title, candidate_title) in [-1,1] 733 const titleSimilaritiesCos = this.getMaxSimilarity( 734 titleEmbeddings.slice(0, anchorTabsPrep.length), 735 titleEmbeddings.slice(anchorTabsPrep.length) 736 ); 737 738 // s_dd: fraction of anchors sharing the candidate's base domain 739 const domainSimilarities = this.getDomainMatchFractions( 740 anchorTabsPrep, 741 candidateTabsPrep 742 ); 743 744 const candidateProbabilities = this.calculateAllProbabilities( 745 groupSimilaritiesCos, 746 titleSimilaritiesCos, 747 domainSimilarities 748 ); 749 750 // get matching params depending on the group name availability 751 const probabilityThreshold = groupEmbedding 752 ? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME.THRESHOLD 753 : LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY.THRESHOLD; 754 755 return ( 756 candidateTabsData 757 // combine candidate tabs with corresponding probabilities 758 .map((ct, index) => ({ 759 ct, 760 prob: candidateProbabilities[index], 761 })) 762 // only keep those that are within the probability threshold 763 .filter(item => item.prob >= probabilityThreshold) 764 // ensure the highest probability candidates come first in the list 765 .sort((a, b) => b.prob - a.prob) 766 // keep the tabs only 767 .map(item => item.ct) 768 ); 769 } 770 771 /** 772 * This function will terminate a grouping or label generation in progress 773 * It is currently not implemented. 774 */ 775 terminateProcess() { 776 // TODO - teminate AI processes, This method will be 777 // called when tab grouping panel is closed. 778 } 779 780 /** 781 * Changes the clustering method. Must be one of supported methods. 782 * 783 * @param {string} method Name of method 784 */ 785 setClusteringMethod(method) { 786 if (!(method in CLUSTER_METHODS)) { 787 throw new Error(`Clustering method ${method} not supported`); 788 } 789 this.config.clustering.clusterImplementation = method; 790 } 791 792 /** 793 * Set the technique for clustering when certain tabs are already assigned to groups 794 * 795 * @param {string} method which is one of ANCHOR_METHODS 796 */ 797 setAnchorMethod(method) { 798 if (!(method in ANCHOR_METHODS)) { 799 throw new Error(`Clustering anchor method ${method} not supported`); 800 } 801 this.config.clustering.anchorMethod = method; 802 } 803 804 setSilBoost(boost) { 805 this.config.clustering.pregroupedSilhouetteBoost = boost; 806 } 807 808 /** 809 * Sets method to reduce dimensionality of embeddings prior to clustering 810 * 811 * @param {string} method Name of method 812 */ 813 setDimensionReductionMethod(method) { 814 if (method && !(method in DIM_REDUCTION_METHODS)) { 815 throw new Error(`Dimension reduction method ${method} not supported`); 816 } 817 this.config.clustering.dimReductionMethod = method; 818 } 819 820 /** 821 * Sets the field name of the title of a page to be used when clustering or generating embeddings 822 * This is useful when clustering test data that is not a tab object 823 * 824 * @param {string} titleKey KEY FOR THE TITLE 825 */ 826 setDataTitleKey(titleKey) { 827 this.config.dataConfig.titleKey = titleKey; 828 } 829 830 /** 831 * Logs to the appropriate place for debugging. Console for now 832 * 833 * @param {string} msg Message to log 834 * @param {boolean} useDescription Whether to add description to the final text 835 */ 836 log(_msg) {} 837 838 /** 839 * Prepares data to be used by the ml models 840 * 841 * @param {object[]} tabList list of tabs in the current window 842 * @param {boolean} useDescription whether we should combined the title and description 843 * @return {Promise<*[Object]>} 844 * @private 845 */ 846 async _prepareTabData(tabList, useDescription = false) { 847 const titleKey = this.config.dataConfig.titleKey; 848 const descriptionKey = this.config.dataConfig.descriptionKey; 849 const structuredData = []; 850 for (let tab of tabList) { 851 const description = 852 useDescription && descriptionKey && tab[descriptionKey]; 853 854 let textToEmbed; 855 if (description) { 856 textToEmbed = tab[titleKey] + " " + description; 857 } else { 858 textToEmbed = tab[titleKey] || "Unknown"; 859 } 860 861 structuredData.push({ 862 [EMBED_TEXT_KEY]: textToEmbed, 863 title: tab[titleKey], 864 description, 865 url: tab?.linkedBrowser?.currentURI?.spec, 866 }); 867 } 868 return structuredData; 869 } 870 871 /** 872 * Get updated config for the ml engine 873 * 874 * @param {object} initData 875 * @param {string} featureId 876 * @return {*} 877 */ 878 static getUpdatedInitData(initData, featureId) { 879 // we're setting a specific modelRevision through about:config or Nimbus 880 if ( 881 featureId === SMART_TAB_GROUPING_CONFIG.topicGeneration.featureId && 882 lazy.topicModelRevision !== LATEST_MODEL_REVISION 883 ) { 884 initData.modelRevision = lazy.topicModelRevision; 885 } else if ( 886 featureId === SMART_TAB_GROUPING_CONFIG.embedding.featureId && 887 lazy.embeddingModelRevision !== LATEST_MODEL_REVISION 888 ) { 889 initData.modelRevision = lazy.embeddingModelRevision; 890 } 891 return initData; 892 } 893 894 /** 895 * Creates an ML engine for a given config. 896 * 897 * @param {*} engineConfig 898 * @param {function} progressCallback 899 * @returns MLEngine 900 */ 901 async _createMLEngine(engineConfig, progressCallback) { 902 const { 903 featureId, 904 engineId, 905 dtype, 906 taskName, 907 timeoutMS, 908 modelId, 909 modelRevision, 910 backend, 911 fallbackBackend, 912 } = engineConfig; 913 let initData = { 914 featureId, 915 engineId, 916 dtype, 917 taskName, 918 timeoutMS, 919 modelId, 920 modelRevision, 921 backend, 922 }; 923 initData = SmartTabGroupingManager.getUpdatedInitData(initData, featureId); 924 let engine; 925 try { 926 engine = await createEngine(initData, progressCallback); 927 this.backend = backend; 928 } catch (e) { 929 engine = await createEngine( 930 { 931 ...initData, 932 backend: fallbackBackend, 933 }, 934 progressCallback 935 ); 936 this.backend = fallbackBackend; 937 } 938 return engine; 939 } 940 941 /** 942 * Generates embeddings from a list of tab data structures 943 * 944 * @param tabList List of tabs with label (title) and description keys 945 * @returns {Promise<*[]>} List of embeddings (2d array) 946 * @private 947 */ 948 async _generateEmbeddings(textToEmbedList) { 949 const inputData = { 950 inputArgs: textToEmbedList, 951 runOptions: { 952 pooling: "mean", 953 normalize: true, 954 }, 955 }; 956 957 if (SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) { 958 this.embeddingEngine = await this._createMLEngine(this.config.embedding); 959 } 960 const request = { 961 args: [inputData.inputArgs], 962 options: inputData.runOptions, 963 }; 964 return await this.embeddingEngine.run(request); 965 } 966 967 /** 968 * Clusters in desired methods 969 * based on the config of the class 970 * 971 * @param tabList List of tabs as array 972 * @param docEmbeddings Precomputed embeddings for the Tab as two dimensional array 973 * @param k Desired number of clusters. Tries a range of sizes if 0. 974 * @param {function} randomFunc Optional seeded random number generator for testing 975 * @returns {SmartTabGroupingResult} 976 * @private 977 */ 978 _clusterEmbeddings({ 979 tabs, 980 embeddings, 981 k, 982 randomFunc, 983 anchorIndices, 984 alreadyGroupedIndices = [], 985 }) { 986 let allItems; 987 988 const freezeAnchorsInZeroCluster = 989 anchorIndices && 990 this.config.clustering.anchorMethod == ANCHOR_METHODS.FIXED; 991 992 const dimReductionMethod = this.config.clustering.dimReductionMethod; 993 switch (dimReductionMethod) { 994 default: 995 // Dimensionality reduction support is landing very soon. 996 break; 997 } 998 k = k || 0; 999 let startK = k; 1000 let endK = k + 1; 1001 if (!k) { 1002 startK = 2; 1003 // Find a reasonable max # of clusters 1004 endK = 1005 Math.min( 1006 Math.floor(Math.log(embeddings.length) * 2.0), 1007 embeddings.length 1008 ) + 1; 1009 } 1010 let bestResult; 1011 let bestResultSilScore = -100.0; 1012 let bestResultCenterCluster = 0; 1013 1014 const clusteringMethod = this.config.clustering.clusterImplementation; 1015 const clusteringTriesPerK = this.config.clustering.clusteringTriesPerK; 1016 for (let curK = startK; curK < endK; curK++) { 1017 let bestItemsForK; 1018 let bestInertiaForK = 500000000000; 1019 for (let j = 0; j < clusteringTriesPerK; j++) { 1020 switch (clusteringMethod) { 1021 case CLUSTER_METHODS.KMEANS: 1022 allItems = kmeansPlusPlus({ 1023 data: embeddings, 1024 k: curK, 1025 maxIterations: 0, 1026 randomFunc, 1027 anchorIndices, 1028 preassignedIndices: 1029 this.config.clustering.pregroupedHandlingMethod === 1030 PREGROUPED_HANDLING_METHODS.EXCLUDE 1031 ? alreadyGroupedIndices 1032 : [], 1033 freezeAnchorsInZeroCluster, 1034 }); 1035 break; 1036 default: 1037 throw Error("Clustering implementation not supported"); 1038 } 1039 const tempResult = new SmartTabGroupingResult({ 1040 indices: allItems, 1041 embeddings, 1042 config: this.config, 1043 }); 1044 const inertia = tempResult.getCentroidInertia(); 1045 if (inertia < bestInertiaForK) { 1046 bestInertiaForK = inertia; 1047 bestItemsForK = tempResult; 1048 } 1049 } 1050 const silScores = silhouetteCoefficients( 1051 embeddings, 1052 bestItemsForK.indices 1053 ); 1054 1055 if ( 1056 freezeAnchorsInZeroCluster && 1057 this.config.clustering.pregroupedSilhouetteBoost > 0 1058 ) { 1059 // Boost silhouette score of target cluster when we are grouping around an existing cluster 1060 // pregroupedSilhouetteBoost indicates the relative weight of the cluster's score and all other cluster's combined 1061 silScores[0] *= this.config.clustering.pregroupedSilhouetteBoost; 1062 } 1063 1064 let avgSil = silScores.reduce((p, c) => p + c, 0) / silScores.length; 1065 let curAnchorCluster = 0; 1066 if (anchorIndices && !freezeAnchorsInZeroCluster) { 1067 const { anchorClusterIndex, numAnchorItemsInCluster } = 1068 getBestAnchorClusterInfo(bestItemsForK.indices, anchorIndices); 1069 curAnchorCluster = anchorClusterIndex; 1070 const penalty = 1071 (MISSING_ANCHOR_IN_CLUSTER_PENALTY * 1072 (anchorIndices.length - numAnchorItemsInCluster)) / 1073 anchorIndices.length; 1074 avgSil -= penalty; 1075 } 1076 if (avgSil > bestResultSilScore) { 1077 bestResultSilScore = avgSil; 1078 bestResult = bestItemsForK.indices; 1079 bestResultCenterCluster = curAnchorCluster; 1080 } 1081 } 1082 const result = new SmartTabGroupingResult({ 1083 indices: bestResult, 1084 tabs, 1085 embeddings, 1086 config: this.config, 1087 }); 1088 if (anchorIndices) { 1089 result.setAnchorClusterIndex( 1090 freezeAnchorsInZeroCluster ? 0 : bestResultCenterCluster 1091 ); // In our k-means clustering implementation anchor cluster is always first 1092 if (!freezeAnchorsInZeroCluster) { 1093 result.adjustClusterForAnchors(anchorIndices); 1094 } 1095 } 1096 return result; 1097 } 1098 1099 /** 1100 * Generate a label for tabs in a group created by the user 1101 * 1102 * @param tabs tabs that are currently in the group 1103 * @param otherTabs tabs in the window not part of the group 1104 * @return {Promise<null|string|string|*>} 1105 */ 1106 async getPredictedLabelForGroup(tabs, otherTabs) { 1107 const clusters = this.createStaticCluster(tabs); 1108 const otherClusters = this.createStaticCluster(otherTabs); 1109 let predictedLabel; 1110 try { 1111 // function below modifies "clusters" object 1112 await this.generateGroupLabels(clusters, otherClusters); 1113 predictedLabel = clusters.clusterRepresentations[0].predictedTopicLabel; 1114 } catch (e) { 1115 this.labelReason = LABEL_REASONS.ERROR; 1116 predictedLabel = ""; 1117 } 1118 return predictedLabel; 1119 } 1120 1121 /** 1122 * Generates clusters for a given list of tabs using precomputed embeddings or newly generated ones. 1123 * 1124 * @param {object[]} tabList - List of tab objects to be clustered. 1125 * @param {number[][]} [precomputedEmbeddings] - Precomputed embeddings for tab titles and descriptions. 1126 * @param {number} numClusters - Number of clusters to form. 1127 * @param {Function} randFunc - Random function used for clustering initialization. 1128 * @param {number[]} [anchorIndices=[]] - Indices of anchor tabs that should be prioritized in clustering. 1129 * @param {number[]} [alreadyGroupedIndices=[]] - Indices of tabs that are already assigned to groups. 1130 * @returns {SmartTabGroupingResult} - The best clustering result based on centroid inertia. 1131 */ 1132 async generateClusters( 1133 tabList, 1134 precomputedEmbeddings, 1135 numClusters, 1136 randFunc, 1137 anchorIndices = [], 1138 alreadyGroupedIndices = [] 1139 ) { 1140 numClusters = numClusters ?? 0; 1141 const structuredData = await this._prepareTabData(tabList); 1142 1143 // embeddings for title and description 1144 if (precomputedEmbeddings) { 1145 this.docEmbeddings = precomputedEmbeddings; 1146 } else { 1147 this.docEmbeddings = await this._generateEmbeddings( 1148 structuredData.map(a => a[EMBED_TEXT_KEY]) 1149 ); 1150 } 1151 let bestResultCluster; 1152 let bestResultDistance = 50000000.0; 1153 1154 const NUM_RUNS = 1; 1155 for (let i = 0; i < NUM_RUNS; i++) { 1156 const curResult = this._clusterEmbeddings({ 1157 tabs: tabList, 1158 embeddings: this.docEmbeddings, 1159 k: numClusters, 1160 randomFunc: randFunc, 1161 anchorIndices, 1162 alreadyGroupedIndices, 1163 }); 1164 const distance = curResult.getCentroidInertia(); 1165 if (distance < bestResultDistance) { 1166 bestResultDistance = distance; 1167 bestResultCluster = curResult; 1168 } 1169 } 1170 return bestResultCluster; 1171 } 1172 1173 /** 1174 * Create static cluster from a list of tabs. A single tab is Ok. Returns null for 0 tabs 1175 * 1176 * @param tabs 1177 * @returns {SmartTabGroupingResult} groupingResult 1178 */ 1179 createStaticCluster(tabs) { 1180 if (!tabs) { 1181 return null; 1182 } 1183 1184 return new SmartTabGroupingResult({ 1185 indices: [Array.from({ length: tabs.length }, (_, i) => i)], 1186 tabs, 1187 config: this.config, 1188 }); 1189 } 1190 1191 /** 1192 * Utility function that loads all required engines for Smart Tab Grouping and any dependent models 1193 * 1194 * @param {(progress: { percentage: number }) => void} progressCallback callback function to call. 1195 * Callback passes a dict with percentage indicating best effort 0.0-100.0 progress in model download. 1196 */ 1197 async preloadAllModels(progressCallback) { 1198 let previousProgress = -1; 1199 const expectedObjects = 1200 EXPECTED_TOPIC_MODEL_OBJECTS + EXPECTED_EMBEDDING_MODEL_OBJECTS; 1201 // TODO - Find a way to get these fields. Add as a transformers js callback or within remotesettings 1202 1203 const UPDATE_THRESHOLD_PERCENTAGE = 0.5; 1204 const ONE_MB = 1024 * 1024; 1205 const START_THRESHOLD_BYTES = ONE_MB * 0.2; 1206 1207 const mutliProgressAggregator = new lazy.MultiProgressAggregator({ 1208 progressCallback: ({ progress, totalLoaded, metadata }) => { 1209 if (totalLoaded < START_THRESHOLD_BYTES) { 1210 progress = 0.0; 1211 } else { 1212 const numObjSeen = metadata.totalObjectsSeen || 0; 1213 if (numObjSeen > 0 && numObjSeen < expectedObjects) { 1214 // When starting to download we may still be getting configs and not have all the data 1215 progress *= numObjSeen / expectedObjects; 1216 } 1217 if (progress > 100) { 1218 progress = 100; 1219 } 1220 } 1221 if ( 1222 Math.abs(previousProgress - progress) > UPDATE_THRESHOLD_PERCENTAGE 1223 ) { 1224 // Update only once changes are above a threshold to avoid throttling the UI with events. 1225 progressCallback({ 1226 percentage: progress, 1227 }); 1228 previousProgress = progress; 1229 } 1230 }, 1231 watchedTypes: [ 1232 lazy.Progress.ProgressType.DOWNLOAD, 1233 lazy.Progress.ProgressType.LOAD_FROM_CACHE, 1234 ], 1235 }); 1236 1237 const [topicEngine, embeddingEngine] = await Promise.all([ 1238 this._createMLEngine( 1239 this.config.topicGeneration, 1240 mutliProgressAggregator?.aggregateCallback.bind( 1241 mutliProgressAggregator 1242 ) || null 1243 ), 1244 this._createMLEngine( 1245 this.config.embedding, 1246 mutliProgressAggregator?.aggregateCallback.bind( 1247 mutliProgressAggregator 1248 ) || null 1249 ), 1250 ]); 1251 this.topicEngine = topicEngine; 1252 this.embeddingEngine = embeddingEngine; 1253 } 1254 1255 /** 1256 * Generate model input from keywords and documents 1257 * 1258 * @param {string []} keywords 1259 * @param {string []} documents 1260 */ 1261 createModelInput(keywords, documents) { 1262 if (!keywords || keywords.length === 0) { 1263 return `Topic from keywords: titles: \n${documents.join(" \n")}`; 1264 } 1265 return `Topic from keywords: ${keywords.join(", ")}. titles: \n${documents.join(" \n")}`; 1266 } 1267 1268 /** 1269 * One artifact of the LLM output is that sometimes words are duplicated 1270 * This function cuts the phrase when it sees the first duplicate word. 1271 * Handles simple singluar / plural duplicates (-s only). 1272 * 1273 * @param {string} phrase Input phrase 1274 * @returns {string} phrase cut before any duplicate word 1275 */ 1276 static cutAtDuplicateWords(phrase) { 1277 if (!phrase.length) { 1278 return phrase; 1279 } 1280 const wordsSet = new Set(); 1281 const wordList = phrase.split(" "); 1282 for (let i = 0; i < wordList.length; i++) { 1283 let baseWord = wordList[i].toLowerCase(); 1284 if (baseWord.length > 3) { 1285 if (baseWord.slice(-1) === "s") { 1286 baseWord = baseWord.slice(0, -1); 1287 } 1288 } 1289 if (wordsSet.has(baseWord)) { 1290 // We are seeing a baseWord word. Exit with just the words so far and don't 1291 // add any new words 1292 return wordList.slice(0, i).join(" "); 1293 } 1294 wordsSet.add(baseWord); 1295 } 1296 return phrase; // return original phrase 1297 } 1298 1299 /** 1300 * Removes trailing domain-related text such as '... - Mail' or '... | News' 1301 * If there's not enough information remaining after, we keep the text as is 1302 * 1303 * @param {string} text tab title with potential domain information 1304 * @return {string} 1305 */ 1306 static preprocessText(text) { 1307 // Matches 'xyz - Domain' or 'xyz | Domain' 1308 // with a space before and after delimiter 1309 // or if there are multiple delimiters next to each other 1310 const delimiters = /(?<=\s)[|–-]+(?=\s)/; 1311 const splitText = text.split(delimiters); 1312 1313 // ensure there's enough info without the last element 1314 const hasEnoughInfo = 1315 !!splitText.length && splitText.slice(0, -1).join(" ").length > 5; 1316 1317 // domain related texts are usually shorter, this takes care of the most common cases 1318 const isPotentialDomainInfo = 1319 splitText.length > 1 && splitText[splitText.length - 1].length < 20; 1320 1321 // If both conditions are met, remove the last chunk, filter out empty strings, 1322 // join on space, trim, and lowercase 1323 if (hasEnoughInfo && isPotentialDomainInfo) { 1324 return splitText 1325 .slice(0, -1) // everything except the last element 1326 .map(t => t.trim()) 1327 .filter(Boolean) // remove empty strings 1328 .join(" ") // join with spaces 1329 .trim(); // remove leading/trailing spaces 1330 } 1331 1332 // Otherwise, just return the text 1333 return text; 1334 } 1335 1336 /** 1337 * Postprocessing of raw output from Topic Model ML Engine 1338 * 1339 * @param {string | undefined} topic Raw topic phrase from topic model or undefined in case of an error 1340 */ 1341 processTopicModelResult(topic) { 1342 let basicResult = (topic || "").trim(); 1343 if (!basicResult) { 1344 this.labelReason = LABEL_REASONS.LOW_CONFIDENCE; 1345 } 1346 if (LABELS_TO_EXCLUDE.includes(basicResult.toLowerCase())) { 1347 this.labelReason = LABEL_REASONS.EXCLUDE; 1348 return ""; 1349 } 1350 return SmartTabGroupingManager.cutAtDuplicateWords(basicResult); 1351 } 1352 1353 /** 1354 * Add titles to a cluster in a SmartTabGroupingResult using generative tehniques 1355 * Currently this function only works with a single target group, and a separate 1356 * item that represents all other ungrouped tabs. 1357 * 1358 * In the future this may be updated to more generally find labels for a set of clusters. 1359 * 1360 * @param {SmartTabGroupingResult} groupingResult The cluster we are generating the label for 1361 * @param {SmartTabGroupingResult} otherGroupingResult A 'made up' cluster representing all other tabs in the window 1362 */ 1363 async generateGroupLabels(groupingResult, otherGroupingResult = null) { 1364 // Special case for a search page 1365 const searchTopicSpecialCase = Services.prefs.getBoolPref( 1366 "browser.tabs.groups.smart.searchTopicEnabled", 1367 true 1368 ); 1369 if ( 1370 searchTopicSpecialCase && 1371 groupingResult.clusterRepresentations.length == 1 && 1372 groupingResult.clusterRepresentations[0].isSingleTabSearch 1373 ) { 1374 if (groupingResult.clusterRepresentations[0].setSingleTabSearchLabel()) { 1375 return; 1376 } 1377 } 1378 1379 const { keywords, documents } = 1380 groupingResult.getRepresentativeDocsAndKeywords( 1381 otherGroupingResult 1382 ? otherGroupingResult.getRepresentativeDocuments() 1383 : [] 1384 ); 1385 const inputArgs = this.createModelInput( 1386 keywords ? keywords[0] : [], 1387 documents 1388 ); 1389 const requestInfo = { 1390 inputArgs, 1391 runOptions: { 1392 max_length: 6, 1393 }, 1394 }; 1395 1396 if (SmartTabGroupingManager.isEngineClosed(this.topicEngine)) { 1397 this.topicEngine = await this._createMLEngine( 1398 this.config.topicGeneration 1399 ); 1400 } 1401 const request = { 1402 args: [requestInfo.inputArgs], 1403 options: requestInfo.runOptions, 1404 }; 1405 const genLabelResults = await this.topicEngine.run(request); 1406 genLabelResults.forEach((genResult, genResultIndex) => { 1407 groupingResult.clusterRepresentations[ 1408 genResultIndex 1409 ].predictedTopicLabel = this.processTopicModelResult( 1410 genResult.generated_text 1411 ); 1412 }); 1413 } 1414 1415 getLabelReason() { 1416 return this.labelReason || LABEL_REASONS.DEFAULT; 1417 } 1418 1419 /** 1420 * Generates glean metrics for ml smart tab label / topic. 1421 * This is currently called when the user saves or cancels the "suggest label" flow. 1422 * 1423 * @param {string} action "save" or "cancel" 1424 * @param {number} numTabsInGroup Number of tabs used to generate the label 1425 * @param {string} mlLabel ML generated label for the tab group 1426 * @param {string} userLabel User saved label for the tab group 1427 * @param {string} id The id of the group 1428 */ 1429 async handleLabelTelemetry({ 1430 action, 1431 numTabsInGroup, 1432 mlLabel, 1433 userLabel, 1434 id = "", 1435 }) { 1436 const { [ML_TASK_TEXT2TEXT]: topicEngineConfig } = 1437 await this.getEngineConfigs(); 1438 const labelReason = this.getLabelReason(); 1439 Glean.tabgroup.smartTabTopic.record({ 1440 action, 1441 tabs_in_group: numTabsInGroup, 1442 ml_label_length: (mlLabel || "").length, 1443 user_label_length: (userLabel || "").length, 1444 levenshtein_distance: lazy.NLP.levenshtein( 1445 userLabel || "", 1446 mlLabel || "" 1447 ), 1448 model_revision: topicEngineConfig.modelRevision || "", 1449 id, 1450 label_reason: labelReason, 1451 backend: this.backend || "onnx-native", 1452 }); 1453 this.labelReason = LABEL_REASONS.DEFAULT; 1454 } 1455 1456 /** 1457 * Generates glean metrics for ml smart tab label / topic. 1458 * This is currently called when the user saves or cancels the "suggest other tabs" flow 1459 * 1460 * @param {string} action "save" or "cancel" 1461 * @param {number} numTabsInWindow Number of tabs in the current window 1462 * @param {number} numTabsInGroup Number of tabs in the current group 1463 * @param {number} numTabsSuggested Number of tabs suggested by the model 1464 * @param {number} numTabsApproved Number of tabs approved by the user 1465 * @param {number} numTabsRemoved Number of tabs removed by the user 1466 * @param {string} id The id of the group 1467 */ 1468 async handleSuggestTelemetry({ 1469 action, 1470 numTabsInWindow, 1471 numTabsInGroup, 1472 numTabsSuggested, 1473 numTabsApproved, 1474 numTabsRemoved, 1475 id = "", 1476 }) { 1477 const { [ML_TASK_FEATURE_EXTRACTION]: embeddingEngineConfig } = 1478 await this.getEngineConfigs(); 1479 Glean.tabgroup.smartTabSuggest.record({ 1480 action, 1481 tabs_in_window: numTabsInWindow, 1482 tabs_in_group: numTabsInGroup, 1483 tabs_suggested: numTabsSuggested, 1484 tabs_approved: numTabsApproved, 1485 tabs_removed: numTabsRemoved, 1486 model_revision: embeddingEngineConfig.modelRevision || "", 1487 id, 1488 backend: this.backend || "onnx-native", 1489 }); 1490 } 1491 1492 /** 1493 * Gets config that engine was initialized with 1494 * 1495 * @return {Promise<{"[ML_TASK_TEXT2TEXT]", "[ML_TASK_FEATURE_EXTRACTION]"}>} 1496 */ 1497 async getEngineConfigs() { 1498 if (!this.topicEngineConfig) { 1499 this.topicEngineConfig = await lazy.MLEngineParent.getInferenceOptions( 1500 this.config.topicGeneration.featureId, 1501 this.config.topicGeneration.taskName 1502 ); 1503 } 1504 if (!this.embeddingEngineConfig) { 1505 this.embeddingEngineConfig = 1506 await lazy.MLEngineParent.getInferenceOptions( 1507 this.config.embedding.featureId, 1508 this.config.embedding.taskName 1509 ); 1510 } 1511 return { 1512 [ML_TASK_TEXT2TEXT]: this.topicEngineConfig, 1513 [ML_TASK_FEATURE_EXTRACTION]: this.embeddingEngineConfig, 1514 }; 1515 } 1516 } 1517 1518 export class SmartTabGroupingResult { 1519 #anchorClusterIndex = -1; // Index of cluster that has original items we're building clustering around, when building around an existing item. 1520 1521 /** 1522 * Creates a result from indices and complete tab and embedding lists. 1523 * This may create some extra data for management later 1524 * 1525 * @param indices indices of clusters (eg [[2,4], [1], [3]]_ 1526 * @param tabItems 1D array of tabs 1527 * @param embeddingItems Two dimensional array of embeddings 1528 * @param config Cluster config 1529 */ 1530 constructor({ indices = [], tabs, embeddings, config }) { 1531 this.embeddingItems = embeddings; 1532 this.config = config; 1533 this.indices = indices.filter(subArray => !!subArray.length); // Cleanup any empty clusters 1534 this.tabItems = tabs; 1535 this._buildClusterRepresentations(); 1536 } 1537 1538 /** 1539 * Builds list of ClusterRepresentations 1540 */ 1541 _buildClusterRepresentations() { 1542 this.clusterRepresentations = this.indices.map(subClusterIndices => { 1543 const tabItemsMapped = 1544 this.tabItems && subClusterIndices.map(idx => this.tabItems[idx]); 1545 const embeddingItemsMapped = 1546 this.embeddingItems && 1547 subClusterIndices.map(idx => this.embeddingItems[idx]); 1548 return new ClusterRepresentation({ 1549 tabs: tabItemsMapped, 1550 embeddings: embeddingItemsMapped, 1551 config: this.config, 1552 }); 1553 }); 1554 } 1555 1556 /** 1557 * Returns a list of documents for each cluster. Currently it is a list of documents picked 1558 * in no particular order. 1559 * 1560 * @return {[strings]} Title and description that represent the cluster. (If no docs are in the class, then titles are returned) 1561 */ 1562 getRepresentativeDocuments() { 1563 if (!this.documents) { 1564 this.documents = this.tabItems.map( 1565 t => t[this.config.dataConfig.titleKey] 1566 ); 1567 } 1568 // set a limit of 10 for now 1569 return this.documents.slice(0, 10); 1570 } 1571 1572 /** 1573 * Returns the keywords and documents for the cluster, computing if needed 1574 * Does not return keywods if only one document is passed to the function. 1575 * 1576 * @param {string[]} otherDocuments other clusters that we'll compare against 1577 * @return keywords and documents that represent the cluster 1578 */ 1579 getRepresentativeDocsAndKeywords(otherDocuments = []) { 1580 this.documents = this.getRepresentativeDocuments(); 1581 if (!this.keywords) { 1582 const joinedDocs = this.documents.slice(0, 3).join(" "); 1583 const otherDocs = otherDocuments.join(" "); 1584 if (this.documents.length > 1) { 1585 const keywordExtractor = new KeywordExtractor(); 1586 this.keywords = keywordExtractor.fitTransform([joinedDocs, otherDocs]); 1587 } else { 1588 this.keywords = []; 1589 } 1590 } 1591 return { keywords: this.keywords, documents: this.documents }; 1592 } 1593 1594 setAnchorClusterIndex(index) { 1595 this.#anchorClusterIndex = index; 1596 } 1597 1598 /** 1599 * Get the cluster we originally are grouping around (finding additinoal item) 1600 * 1601 * @returns ClusterRepresentation 1602 */ 1603 getAnchorCluster() { 1604 if (this.#anchorClusterIndex === -1) { 1605 return null; 1606 } 1607 return this.clusterRepresentations[this.#anchorClusterIndex]; 1608 } 1609 1610 /** 1611 * Given the indices that we were clustering around, make sure they are are all in the target grouping 1612 * Our generic k-means clustering might have them in separate groups 1613 */ 1614 adjustClusterForAnchors(anchorIndices) { 1615 if (!anchorIndices.length) { 1616 return; 1617 } 1618 const anchorSet = new Set(anchorIndices); 1619 for (let i = 0; i < this.indices.length; i++) { 1620 if (i === this.#anchorClusterIndex) { 1621 continue; 1622 } 1623 this.indices[i] = this.indices[i].filter(item => { 1624 if (anchorSet.has(item)) { 1625 this.indices[this.#anchorClusterIndex].push(item); 1626 return false; 1627 } 1628 return true; 1629 }); 1630 } 1631 this._buildClusterRepresentations(); 1632 } 1633 1634 /** 1635 * Prints information about the cluster 1636 */ 1637 printClusters() { 1638 for (let cluster of this.clusterRepresentations) { 1639 cluster.print(); 1640 } 1641 } 1642 1643 /** 1644 * Computes the inertia of the cluster which is the sum of square total distance. 1645 * 1646 * @returns {number} 1647 */ 1648 getCentroidInertia() { 1649 let runningTotalDistance = 0; 1650 this.clusterRepresentations.forEach(rep => { 1651 runningTotalDistance += rep.computeTotalSquaredCentroidDistance(); 1652 }); 1653 return runningTotalDistance; 1654 } 1655 1656 /** 1657 * Converts a cluster representation to a flat list of tabs, with clusterID key in each 1658 * tab representing the id of the cluster it was part of. 1659 * 1660 * @returns {[object]} 1661 */ 1662 _flatMapItemsInClusters() { 1663 return this.clusterRepresentations.reduce((result, clusterRep) => { 1664 const annotatedTabs = clusterRep.tabs.map(a => { 1665 let c = {}; 1666 Object.assign(c, a); 1667 c.clusterID = clusterRep.clusterID; 1668 return c; 1669 }); 1670 return result.concat(annotatedTabs); 1671 }, []); 1672 } 1673 1674 /** 1675 * Get rand score which describes the accuracy versus a user labeled 1676 * annotation on the dataset. Requires the dataset to be labeled. 1677 * 1678 * @param labelKey Key in the tabs that represent a unique label ID for the cluster. 1679 * @returns {number} The rand score. 1680 */ 1681 getRandScore(labelKey = "annotatedLabel") { 1682 const combinedItems = this._flatMapItemsInClusters(); 1683 return computeRandScore(combinedItems, "clusterID", labelKey); 1684 } 1685 1686 /** 1687 * Get accuracy for a specific cluster 1688 * 1689 * @param labelKey Key in the tabs that represent a unique label ID for the cluster. 1690 * @param clusterValue is the cluster we are comparing 1691 * @returns {number} The rand score. 1692 */ 1693 getAccuracyStatsForCluster(labelKey = "annotatedLabel", clusterValue) { 1694 const combinedItems = this._flatMapItemsInClusters(); 1695 1696 let keyClusterId = combinedItems.find( 1697 a => a[labelKey] === clusterValue 1698 ).clusterID; 1699 1700 let truePositives = 0, 1701 trueNegatives = 0, 1702 falseNegatives = 0, 1703 falsePositives = 0; 1704 1705 combinedItems.forEach(item => { 1706 const sameLabel = item[labelKey] === clusterValue; 1707 const sameCluster = item.clusterID === keyClusterId; 1708 if (sameLabel && sameCluster) { 1709 truePositives++; 1710 } 1711 if (!sameLabel && !sameCluster) { 1712 trueNegatives++; 1713 } 1714 if (sameLabel && !sameCluster) { 1715 falseNegatives++; 1716 } 1717 if (!sameLabel && sameCluster) { 1718 falsePositives++; 1719 } 1720 }); 1721 return getAccuracyStats({ 1722 truePositives, 1723 trueNegatives, 1724 falsePositives, 1725 falseNegatives, 1726 }); 1727 } 1728 } 1729 1730 /** 1731 * Utility function to generate a random ID string 1732 * 1733 * @param len Length of the string 1734 * @returns {string} 1735 */ 1736 function genHexString(len) { 1737 const hex = "0123456789ABCDEF"; 1738 let output = ""; 1739 for (let i = 0; i < len; ++i) { 1740 output += hex.charAt(Math.floor(Math.random() * hex.length)); 1741 } 1742 return output; 1743 } 1744 1745 class EmbeddingCluster { 1746 constructor({ tabs, embeddings, centroid }) { 1747 this.embeddings = embeddings; 1748 this.centroid = 1749 centroid || (embeddings && computeCentroidFrom2DArray(this.embeddings)); 1750 this.tabs = tabs; 1751 } 1752 1753 /** 1754 * @returns total sum euclidan squared distance of each item from cluster's centroid 1755 */ 1756 computeTotalSquaredCentroidDistance() { 1757 let totalDistance = 0; 1758 if (this.embeddings.length === 0) { 1759 return 0; 1760 } 1761 this.embeddings.forEach(embedding => { 1762 totalDistance += euclideanDistance(this.centroid, embedding, true); 1763 }); 1764 return totalDistance; 1765 } 1766 1767 /** 1768 * Returns number of items in the cluster 1769 * 1770 * @returns {int} 1771 */ 1772 numItems() { 1773 return this.tabs.length; 1774 } 1775 } 1776 1777 /** 1778 * Represents a single cluster with additional saved metadata 1779 */ 1780 export class ClusterRepresentation extends EmbeddingCluster { 1781 constructor({ tabs, embeddings, centroid, config }) { 1782 super({ tabs, embeddings, centroid }); 1783 this.config = config; 1784 this.predictedTopicLabel = null; 1785 this.annotatedTopicLabel = null; 1786 this.userEditedTopicLabel = null; 1787 this.representativeText = null; 1788 this.keywords = null; 1789 this.documents = null; 1790 this.clusterID = genHexString(10); 1791 this.isSingleTabSearch = tabs?.length == 1 && isSearchTab(tabs[0]); 1792 } 1793 1794 /** 1795 * For a single tab cluster with a search field, set the predicted topic 1796 * to be the title of the page 1797 * 1798 * @returns {boolean} True if we updated the cluster label successfully 1799 */ 1800 setSingleTabSearchLabel() { 1801 if (this.tabs.length !== 1) { 1802 return false; 1803 } 1804 const pageTitle = this.tabs[0][this.config.dataConfig.titleKey] || ""; 1805 for (let i = pageTitle.length - 1; i > 0; i--) { 1806 if (TITLE_DELIMETER_SET.has(pageTitle[i])) { 1807 const topicString = pageTitle.substring(0, i).trim(); 1808 if (topicString.length > MAX_NON_SUMMARIZED_SEARCH_LENGTH) { 1809 return false; 1810 } 1811 // Capitalize first character of each word. Regex returns first char of each word 1812 this.predictedTopicLabel = topicString.replace(/(^|\s)\S/g, t => 1813 t.toUpperCase() 1814 ); 1815 return true; 1816 } 1817 } 1818 return false; 1819 } 1820 1821 /** 1822 * Returns the representative text for a cluster, computing it if needed 1823 */ 1824 getRepresentativeText() { 1825 if (!this.representativeText) { 1826 this.representativeText = this._generateRepresentativeText(); 1827 } 1828 return this.representativeText; 1829 } 1830 1831 /** 1832 * Returns representative text for a cluster. 1833 * For this in initial implementation it simply returns title from a few tabs 1834 * 1835 * @returns {string} 1836 * @private 1837 */ 1838 _generateRepresentativeText() { 1839 let text = ""; 1840 const titleKey = this.config.dataConfig.titleKey; 1841 for (const tab of this.tabs.slice(0, 3)) { 1842 text += `\n${tab[titleKey]}`; 1843 } 1844 return text; 1845 } 1846 1847 print() { 1848 // Add console log for debugging 1849 } 1850 }