ThomSample.mjs (4585B)
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/*
 * This module has utility functions for doing Thompson sampling and ranking
 */

/**
 * Sample from a Normal distribution using the ratio-of-uniforms scheme from
 * the Leva paper: a candidate pair (u, v) is drawn, cheaply accepted or
 * rejected by a quadratic bound, and only borderline candidates reach the
 * logarithmic test. The returned value is the ratio v / u.
 *
 * @param {Function} getRandom
 *   Source of uniform random numbers (defaults to Math.random). It is called
 *   twice per attempt; inject a stub to make tests deterministic.
 * @returns {number} A sampled float from a Normal distribution
 */
export function sampleNormal(getRandom = Math.random) {
  // Magic constants below are straight from Leva paper
  // These are left as variables to better match Leva paper
  const s = 0.449871;
  const t = -0.386595;
  const a = 0.196;
  const b = 0.25472;

  let u;
  let v;
  while (true) {
    u = getRandom();
    v = 1.7156 * (getRandom() - 0.5);
    const x = u - s;
    const y = Math.abs(v) - t;
    const q = Math.pow(x, 2) + y * (a * y - b * x);

    // Quick accept: candidate is inside the inner bound.
    if (q < 0.27597) {
      break;
    }
    // Quick reject: candidate is outside the outer bound.
    if (q > 0.27846) {
      continue;
    }
    // Borderline candidate: fall back to the logarithmic test. When u is 0
    // the right-hand side is NaN, the comparison is false, and we retry.
    if (v * v <= -4 * Math.log(u) * u * u) {
      break;
    }
  }

  return v / u;
}

/**
 * Sample from a Gamma distribution with shape `a` (and unit scale).
 *
 * Shapes a >= 1 use the Marsaglia and Tsang rejection method directly.
 * Shapes 0 < a < 1 use the boost described by the same authors: draw from
 * Gamma(a + 1) and multiply by U^(1/a) with U uniform on (0, 1).
 * A shape that is not a positive number yields NaN.
 *
 * @param {number} a Shape of the Gamma distribution
 * @param {Function} normalSampler
 *   Source of Normal samples; injectable to make tests deterministic
 * @param {Function} uniSampler
 *   Source of uniform samples; injectable to make tests deterministic
 * @returns {number} A sampled float from a Gamma distribution
 */
export function sampleGamma(
  a,
  normalSampler = sampleNormal,
  uniSampler = Math.random
) {
  // Written as a negated comparison so that NaN shapes are caught as well.
  if (!(a > 0)) {
    return NaN;
  }

  if (a < 1) {
    // The rejection method below is specified for a >= 1 only, so boost the
    // shape by one and scale the draw back down.
    const boosted = sampleGamma(a + 1, normalSampler, uniSampler);
    return boosted * Math.pow(uniSampler(), 1 / a);
  }

  // Marsaglia and Tsang method for sampling from gamma
  const d = a - 1 / 3;
  const c = 1 / Math.sqrt(9 * d);
  let uni;
  let v;
  let x;
  do {
    x = normalSampler();
    v = Math.pow(1 + c * x, 3);
    uni = uniSampler();
  } while (
    // v must be strictly positive before taking its logarithm.
    v <= 0 ||
    Math.log(uni) > 0.5 * Math.pow(x, 2) + d - d * v + d * Math.log(v)
  );
  return d * v;
}

/**
 * Sort an array of keys by the values in scores, highest score first.
 * Neither input array is modified.
 *
 * @param {number[]} scores The vector with the values we sort by
 * @param {object[]} keys Array of keys to be sorted; keys[i] belongs to scores[i]
 * @returns {[object[], number[]]} Sorted keys and sorted scores, in that order
 */
export function sortKeysValues(scores, keys) {
  // Pair the values so each key travels with its score
  const paired = scores.map((score, i) => ({ score, key: keys[i] }));

  // Sort by score descending
  paired.sort((left, right) => right.score - left.score);

  // Unzip into separate arrays
  const sortedKeys = paired.map(pair => pair.key);
  const sortedScores = paired.map(pair => pair.score);

  return [sortedKeys, sortedScores];
}

/**
 * Sample from a Beta distribution by normalising two Gamma draws:
 * X / (X + Y) with X drawn for shape a and Y drawn for shape b.
 *
 * @param {number} a Alpha in the Beta distribution
 * @param {number} b Beta in the Beta distribution
 * @returns {number} A sampled float from a Beta distribution
 */
export function sampleBeta(a, b) {
  const ag = sampleGamma(a);
  const bg = sampleGamma(b);
  return ag / (ag + bg);
}

/**
 * Utility function to sort items based on a Thompson Sampling draw.
 *
 * For every item a score (theta) is drawn from
 * Beta(obs_positive[i] + prior_positive[i], obs_negative[i] + prior_negative[i]).
 * The function is async, so callers receive a Promise.
 *
 * @param {object} observationsPriors - An object containing counts and priors
 * @param {object[]} observationsPriors.key_array - Array of items to be ranked
 * @param {number[]} observationsPriors.obs_positive - Array of positive
 *   observation counts (clicks)
 * @param {number[]} observationsPriors.obs_negative - Array of negative
 *   observation counts (described as impressions by the original docs)
 * @param {number[]} [observationsPriors.prior_positive] - Array of priors added
 *   to obs_positive; defaults to all 1s when null or undefined
 * @param {number[]} [observationsPriors.prior_negative] - Array of priors added
 *   to obs_negative; defaults to all 1s when null or undefined
 * @param {boolean} [observationsPriors.do_sort] - When true (the default),
 *   keys and scores are sorted by score descending; when false they are
 *   returned in input order
 * @returns {Promise<[object[], number[]]>} Resolves to the array of keys and
 *   the array of sampled scores, in that order
 */
export async function thompsonSampleSort({
  key_array,
  obs_positive,
  obs_negative,
  prior_positive,
  prior_negative,
  do_sort = true,
}) {
  // If priors are not provided, initialize them to arrays of 1s
  const used_prior_positive = prior_positive ?? obs_positive.map(() => 1);
  const used_prior_negative = prior_negative ?? obs_negative.map(() => 1);

  // sample a theta (score) for each item
  const thetas = key_array.map((_, i) =>
    sampleBeta(
      obs_positive[i] + used_prior_positive[i],
      obs_negative[i] + used_prior_negative[i]
    )
  );

  // Unsorted callers get the keys back untouched, alongside their scores
  if (!do_sort) {
    return [key_array, thetas];
  }

  // sort theta and key_array by theta
  const [final_keys, final_thetas] = sortKeysValues(thetas, key_array);
  return [final_keys, final_thetas];
}