gen_constrained_tokenset.py (4193B)
#!/usr/bin/env python3
##
## Copyright (c) 2016, Alliance for Open Media. All rights reserved.
##
## This source code is subject to the terms of the BSD 2 Clause License and
## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
## was not distributed with this source code in the LICENSE file, you can
## obtain it at www.aomedia.org/license/software. If the Alliance for Open
## Media Patent License 1.0 was not distributed with this source code in the
## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
##
"""Generate the probability model for the constrained token set.

Model obtained from a 2-sided zero-centered distribution derived
from a Pareto distribution. The cdf of the distribution is:
cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]

For a given beta and a given probability of the 1-node, the alpha
is first solved, and then the {alpha, beta} pair is used to generate
the probabilities for the rest of the nodes.
"""

import heapq
import sys
import numpy as np
import scipy.optimize
import scipy.stats

# Upper edges (in |x|) of the token bins. Bin 0 is |x| < 0.5, bin k (k = 1..9)
# covers _BIN_EDGES[k-1] <= |x| < _BIN_EDGES[k], and bin 10 is |x| >= 66.5.
_BIN_EDGES = np.array([0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5])


def cdf_spareto(x, xm, beta):
    """Return the cdf of the symmetric Pareto-derived distribution at x.

    Args:
      x: evaluation point(s); a scalar or a numpy array (evaluated elementwise).
      xm: scale parameter (called alpha in the module docstring).
      beta: shape parameter.
    """
    p = 1 - (xm / (np.abs(x) + xm))**beta
    p = 0.5 + 0.5 * np.sign(x) * p
    return p


def get_spareto(p, beta):
    """Return the 11 two-sided bin probabilities for a given 1-node probability.

    alpha is solved so that P(bin 1 | |x| >= 0.5) == p. Then each bin's
    probability mass 2 * (cdf(hi) - cdf(lo)) is computed over _BIN_EDGES.

    Args:
      p: target conditional probability of bin 1.
      beta: shape parameter of the distribution.

    Returns:
      numpy array of 11 probabilities summing to 1.
    """
    cdf = cdf_spareto

    def func(x):
        # Squared error between the bin-1 conditional probability implied by
        # alpha == x and the target p; minimized by fminbound below.
        return ((cdf(1.5, x, beta) - cdf(0.5, x, beta)) /
                (1 - cdf(0.5, x, beta)) - p)**2

    alpha = scipy.optimize.fminbound(func, 1e-12, 10000, xtol=1e-12)
    # Bracket the edge cdf values with cdf(0) == 0.5 and cdf(+inf) == 1. A
    # single diff then yields every bin, including the first (0 <= |x| < 0.5)
    # and the open-ended last one. The factor 2 folds in the mirror half.
    c = np.concatenate(([0.5], cdf(_BIN_EDGES, alpha, beta), [1.]))
    return 2 * np.diff(c)


def quantize_probs(p, save_first_bin, bits):
    """Quantize probability precisely.

    Quantize probabilities minimizing dH (Kullback-Leibler divergence)
    approximated by: sum (p_i-q_i)^2/p_i.
    References:
      https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
      https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

    Args:
      p: 1-D numpy array of probabilities.
      save_first_bin: if true, q[0] keeps its rounded value and only the
        remaining bins absorb the correction needed to reach 2**bits.
      bits: precision; the returned counts sum to 2**bits.

    Returns:
      Float numpy array of integer-valued counts, each >= 1, summing to
      2**bits.

    Raises:
      ValueError: if the counts cannot be corrected to sum to 2**bits
        (no adjustable bin is left).
    """
    num_sym = p.size
    p = np.clip(p, 1e-16, 1)
    L = 2**bits
    pL = p * L
    ip = 1. / p  # inverse probability
    q = np.clip(np.round(pL), 1, L + 1 - num_sym)
    quant_err = (pL - q)**2 * ip
    sgn = np.sign(L - q.sum())  # direction of correction
    if sgn == 0:
        return q

    # Heap of (adjustment error, index) for each bin that may still move.
    heap = []

    def push_candidate(i):
        # Queue the extra error of moving q[i] one more step in direction
        # sgn, provided the new count stays strictly inside (0, L).
        q_adj = q[i] + sgn
        if 0 < q_adj < L:
            adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
            heapq.heappush(heap, (adj_err, i))

    for i in range(1 if save_first_bin else 0, num_sym):
        push_candidate(i)
    while q.sum() != L:
        if not heap:
            raise ValueError('cannot quantize %d symbols to sum to %d' %
                             (num_sym, L))
        # Apply the lowest-error adjustment, then re-queue that symbol.
        adj_err, i = heapq.heappop(heap)
        quant_err[i] += adj_err
        q[i] += sgn
        push_candidate(i)
    return q


def get_quantized_spareto(p, beta, bits, first_token):
    """Return quantized token counts for one 1-node probability.

    The zero bin is always dropped and the remainder renormalized. When
    first_token > 1 the next bin is dropped and renormalized as well.

    Args:
      p: target conditional probability of bin 1.
      beta: shape parameter of the distribution.
      bits: precision; the counts sum to 2**bits.
      first_token: 1 keeps bin 1 fixed during quantization; > 1 drops it.

    Returns:
      Integer numpy array of counts summing to 2**bits.
    """
    parray = get_spareto(p, beta)
    parray = parray[1:] / (1 - parray[0])
    # CONFIG_NEW_TOKENSET
    if first_token > 1:
        parray = parray[1:] / (1 - parray[0])
    qarray = quantize_probs(parray, first_token == 1, bits)
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # int maps to the same default integer dtype.
    return qarray.astype(int)


def main(bits=15, first_token=1):
    """Print one C-initializer row of quantized counts per q/256, q=1..255."""
    beta = 8
    for q in range(1, 256):
        parray = get_quantized_spareto(q / 256., beta, bits, first_token)
        assert parray.sum() == 2**bits
        print('{', ', '.join('%d' % i for i in parray), '},')


if __name__ == '__main__':
    # Optional positional arguments: bits, then first_token (extras ignored).
    main(*(int(arg) for arg in sys.argv[1:3]))