tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_ma.cc (38538B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/modular/encoding/enc_ma.h"
      7 
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <numeric>
#include <queue>
#include <vector>
     14 
     15 #include "lib/jxl/modular/encoding/ma_common.h"
     16 
     17 #undef HWY_TARGET_INCLUDE
     18 #define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
     19 #include <hwy/foreach_target.h>
     20 #include <hwy/highway.h>
     21 
     22 #include "lib/jxl/base/fast_math-inl.h"
     23 #include "lib/jxl/base/random.h"
     24 #include "lib/jxl/enc_ans.h"
     25 #include "lib/jxl/modular/encoding/context_predict.h"
     26 #include "lib/jxl/modular/options.h"
     27 #include "lib/jxl/pack_signed.h"
     28 HWY_BEFORE_NAMESPACE();
     29 namespace jxl {
     30 namespace HWY_NAMESPACE {
     31 
     32 // These templates are not found via ADL.
     33 using hwy::HWY_NAMESPACE::Eq;
     34 using hwy::HWY_NAMESPACE::IfThenElse;
     35 using hwy::HWY_NAMESPACE::Lt;
     36 using hwy::HWY_NAMESPACE::Max;
     37 
// Full-width SIMD descriptors used by the entropy estimation below.
const HWY_FULL(float) df;
const HWY_FULL(int32_t) di;
// Rounds x up to a whole number of SIMD vector lanes.
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
     41 
// Compute entropy of the histogram, taking into account the minimum probability
// for symbols with non-zero counts.
// `num_symbols` is expected to be padded to a whole number of SIMD lanes (see
// Padded()); callers allocate zero-initialized histograms of that size, so the
// padding entries contribute nothing to the sum.
float EstimateBits(const int32_t *counts, size_t num_symbols) {
 int32_t total = std::accumulate(counts, counts + num_symbols, 0);
 const auto zero = Zero(df);
 // Probabilities are clamped below at the minimum representable by the ANS
 // table, 1/ANS_TAB_SIZE.
 const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
 const auto inv_total = Set(df, 1.0f / total);
 auto bits_lanes = Zero(df);
 auto total_v = Set(di, total);
 for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
   const auto counts_iv = LoadU(di, &counts[i]);
   const auto counts_fv = ConvertTo(df, counts_iv);
   const auto probs = Mul(counts_fv, inv_total);
   const auto mprobs = Max(probs, minprob);
   // If a single symbol holds all the mass (count == total, p == 1), use
   // exactly zero bits instead of FastLog2f's approximation of log2(1).
   const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
                                BitCast(di, FastLog2f(df, mprobs)));
   // Accumulate -count * log2(p) per lane.
   bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
 }
 return GetLane(SumOfLanes(df, bits_lanes));
}
     62 
     63 void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
     64                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
     65  // Note that the tree splits on *strictly greater*.
     66  (*tree)[pos].lchild = tree->size();
     67  (*tree)[pos].rchild = tree->size() + 1;
     68  (*tree)[pos].splitval = splitval;
     69  (*tree)[pos].property = property;
     70  tree->emplace_back();
     71  tree->back().property = -1;
     72  tree->back().predictor = rpred;
     73  tree->back().predictor_offset = roff;
     74  tree->back().multiplier = 1;
     75  tree->emplace_back();
     76  tree->back().property = -1;
     77  tree->back().predictor = lpred;
     78  tree->back().predictor_offset = loff;
     79  tree->back().multiplier = 1;
     80 }
     81 
     82 enum class IntersectionType { kNone, kPartial, kInside };
     83 IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
     84                               uint32_t &partial_axis, uint32_t &partial_val) {
     85  bool partial = false;
     86  for (size_t i = 0; i < kNumStaticProperties; i++) {
     87    if (haystack[i][0] >= needle[i][1]) {
     88      return IntersectionType::kNone;
     89    }
     90    if (haystack[i][1] <= needle[i][0]) {
     91      return IntersectionType::kNone;
     92    }
     93    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
     94      continue;
     95    }
     96    partial = true;
     97    partial_axis = i;
     98    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
     99      partial_val = haystack[i][0] - 1;
    100    } else {
    101      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
    102                  haystack[i][1] < needle[i][1]);
    103      partial_val = haystack[i][1] - 1;
    104    }
    105  }
    106  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
    107 }
    108 
// Reorders tree_samples within [begin, end) along property `prop` so that
// `pos` becomes a valid partition point: every sample in [begin, pos) has a
// property value no larger than every sample in [pos, end). Implemented as an
// iterative quickselect with random pivots and three-way partitioning.
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                     size_t end, size_t prop) {
 // Three-way comparison of the quantized values of property `prop` for
 // samples a and b.
 auto cmp = [&](size_t a, size_t b) {
   return static_cast<int32_t>(tree_samples.Property(prop, a)) -
          static_cast<int32_t>(tree_samples.Property(prop, b));
 };
 // Fixed seed: pivot choice (and thus encoder output) is deterministic.
 Rng rng(0);
 while (end > begin + 1) {
   {
     // Move a randomly chosen pivot sample to the front of the range.
     size_t pivot = rng.UniformU(begin, end);
     tree_samples.Swap(begin, pivot);
   }
   // Loop invariant: [begin, pivot_begin) < pivot,
   // [pivot_begin, pivot_end) == pivot, [pivot_end, i) > pivot.
   size_t pivot_begin = begin;
   size_t pivot_end = pivot_begin + 1;
   for (size_t i = begin + 1; i < end; i++) {
     JXL_DASSERT(i >= pivot_end);
     JXL_DASSERT(pivot_end > pivot_begin);
     int32_t cmp_result = cmp(i, pivot_begin);
     if (cmp_result < 0) {  // i < pivot, move pivot forward and put i before
                            // the pivot.
       tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
       pivot_begin++;
       pivot_end++;
     } else if (cmp_result == 0) {
       tree_samples.Swap(pivot_end, i);
       pivot_end++;
     }
   }
   JXL_DASSERT(pivot_begin >= begin);
   JXL_DASSERT(pivot_end > pivot_begin);
   JXL_DASSERT(pivot_end <= end);
   for (size_t i = begin; i < pivot_begin; i++) {
     JXL_DASSERT(cmp(i, pivot_begin) < 0);
   }
   for (size_t i = pivot_end; i < end; i++) {
     JXL_DASSERT(cmp(i, pivot_begin) > 0);
   }
   for (size_t i = pivot_begin; i < pivot_end; i++) {
     JXL_DASSERT(cmp(i, pivot_begin) == 0);
   }
   // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
   // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
   // If pos falls in the first or the last interval, we continue in that
   // interval; otherwise, we are done.
   if (pivot_begin > pos) {
     end = pivot_begin;
   } else if (pivot_end < pos) {
     begin = pivot_end;
   } else {
     break;
   }
 }
}
    162 
// Greedily grows the MA tree: repeatedly pops a (node, sample range) work item
// off an explicit stack, searches for the best (property, value, predictor)
// split, and — whenever the estimated saving exceeds `threshold` bits —
// materializes the split via MakeSplitNode and pushes both children.
void FindBestSplit(TreeSamples &tree_samples, float threshold,
                  const std::vector<ModularMultiplierInfo> &mul_info,
                  StaticPropRange initial_static_prop_range,
                  float fast_decode_multiplier, Tree *tree) {
 // Work item: tree node index `pos` and the sample range [begin, end) it
 // covers, plus the properties already split on and the remaining
 // static-property box.
 struct NodeInfo {
   size_t pos;
   size_t begin;
   size_t end;
   uint64_t used_properties;
   StaticPropRange static_prop_range;
 };
 std::vector<NodeInfo> nodes;
 nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
                          initial_static_prop_range});

 size_t num_predictors = tree_samples.NumPredictors();
 size_t num_properties = tree_samples.NumProperties();

 // TODO(veluca): consider parallelizing the search (processing multiple nodes
 // at a time).
 while (!nodes.empty()) {
   size_t pos = nodes.back().pos;
   size_t begin = nodes.back().begin;
   size_t end = nodes.back().end;
   uint64_t used_properties = nodes.back().used_properties;
   StaticPropRange static_prop_range = nodes.back().static_prop_range;
   nodes.pop_back();
   if (begin == end) continue;

   // Candidate split: property `prop`, cut after quantized value `val`,
   // sample partition point `pos`, and per-side costs/predictors.
   struct SplitInfo {
     size_t prop = 0;
     uint32_t val = 0;
     size_t pos = 0;
     float lcost = std::numeric_limits<float>::max();
     float rcost = std::numeric_limits<float>::max();
     Predictor lpred = Predictor::Zero;
     Predictor rpred = Predictor::Zero;
     float Cost() const { return lcost + rcost; }
   };

   // Best candidates per category, compared against each other at the end.
   SplitInfo best_split_static_constant;
   SplitInfo best_split_static;
   SplitInfo best_split_nonstatic;
   SplitInfo best_split_nowp;

   JXL_DASSERT(begin <= end);
   JXL_DASSERT(end <= tree_samples.NumDistinctSamples());

   // Compute the maximum token in the range.
   size_t max_symbols = 0;
   for (size_t pred = 0; pred < num_predictors; pred++) {
     for (size_t i = begin; i < end; i++) {
       uint32_t tok = tree_samples.Token(pred, i);
       max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
     }
   }
   max_symbols = Padded(max_symbols);
   // Per-predictor token histograms and total extra-bits for this range.
   std::vector<int32_t> counts(max_symbols * num_predictors);
   std::vector<uint32_t> tot_extra_bits(num_predictors);
   for (size_t pred = 0; pred < num_predictors; pred++) {
     for (size_t i = begin; i < end; i++) {
       counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
           tree_samples.Count(i);
       tot_extra_bits[pred] +=
           tree_samples.NBits(pred, i) * tree_samples.Count(i);
     }
   }

   // Cost of leaving this node as a leaf with its current predictor.
   float base_bits;
   {
     size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
     base_bits =
         EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
         tot_extra_bits[pred];
   }

   SplitInfo *best = &best_split_nonstatic;

   SplitInfo forced_split;
   // The multiplier ranges cut halfway through the current ranges of static
   // properties. We do this even if the current node is not a leaf, to
   // minimize the number of nodes in the resulting tree.
   for (const auto &mmi : mul_info) {
     uint32_t axis;
     uint32_t val;
     IntersectionType t =
         BoxIntersects(static_prop_range, mmi.range, axis, val);
     if (t == IntersectionType::kNone) continue;
     if (t == IntersectionType::kInside) {
       (*tree)[pos].multiplier = mmi.multiplier;
       break;
     }
     if (t == IntersectionType::kPartial) {
       // Partial overlap: force a split on the intersecting axis so that
       // each child lies fully inside or outside the multiplier range.
       forced_split.val = tree_samples.QuantizeProperty(axis, val);
       forced_split.prop = axis;
       forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
       forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
       best = &forced_split;
       best->pos = begin;
       JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
       for (size_t x = begin; x < end; x++) {
         if (tree_samples.Property(best->prop, x) <= best->val) {
           best->pos++;
         }
       }
       break;
     }
   }

   if (best != &forced_split) {
     std::vector<int> prop_value_used_count;
     std::vector<int> count_increase;
     std::vector<size_t> extra_bits_increase;
     // For each property, compute which of its values are used, and what
     // tokens correspond to those usages. Then, iterate through the values,
     // and compute the entropy of each side of the split (of the form `prop >
     // threshold`). Finally, find the split that minimizes the cost.
     struct CostInfo {
       float cost = std::numeric_limits<float>::max();
       float extra_cost = 0;
       float Cost() const { return cost + extra_cost; }
       Predictor pred;  // will be uninitialized in some cases, but never used.
     };
     std::vector<CostInfo> costs_l;
     std::vector<CostInfo> costs_r;

     // Scratch histograms for the two sides of the current cut point.
     std::vector<int32_t> counts_above(max_symbols);
     std::vector<int32_t> counts_below(max_symbols);

     // The lower the threshold, the higher the expected noisiness of the
     // estimate. Thus, discourage changing predictors.
     float change_pred_penalty = 800.0f / (100.0f + threshold);
     for (size_t prop = 0; prop < num_properties && base_bits > threshold;
          prop++) {
       costs_l.clear();
       costs_r.clear();
       size_t prop_size = tree_samples.NumPropertyValues(prop);
       if (extra_bits_increase.size() < prop_size) {
         count_increase.resize(prop_size * max_symbols);
         extra_bits_increase.resize(prop_size);
       }
       // Clear prop_value_used_count (which cannot be cleared "on the go")
       prop_value_used_count.clear();
       prop_value_used_count.resize(prop_size);

       size_t first_used = prop_size;
       size_t last_used = 0;

       // TODO(veluca): consider finding multiple splits along a single
       // property at the same time, possibly with a bottom-up approach.
       for (size_t i = begin; i < end; i++) {
         size_t p = tree_samples.Property(prop, i);
         prop_value_used_count[p]++;
         last_used = std::max(last_used, p);
         first_used = std::min(first_used, p);
       }
       costs_l.resize(last_used - first_used);
       costs_r.resize(last_used - first_used);
       // For all predictors, compute the right and left costs of each split.
       for (size_t pred = 0; pred < num_predictors; pred++) {
         // Compute cost and histogram increments for each property value.
         for (size_t i = begin; i < end; i++) {
           size_t p = tree_samples.Property(prop, i);
           size_t cnt = tree_samples.Count(i);
           size_t sym = tree_samples.Token(pred, i);
           count_increase[p * max_symbols + sym] += cnt;
           extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
         }
         // Start with everything "above" the cut; move mass below as the
         // candidate cut value sweeps upward.
         memcpy(counts_above.data(), counts.data() + pred * max_symbols,
                max_symbols * sizeof counts_above[0]);
         memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
         size_t extra_bits_below = 0;
         // Exclude last used: this ensures neither counts_above nor
         // counts_below is empty.
         for (size_t i = first_used; i < last_used; i++) {
           if (!prop_value_used_count[i]) continue;
           extra_bits_below += extra_bits_increase[i];
           // The increase for this property value has been used, and will not
           // be used again: clear it. Also below.
           extra_bits_increase[i] = 0;
           for (size_t sym = 0; sym < max_symbols; sym++) {
             counts_above[sym] -= count_increase[i * max_symbols + sym];
             counts_below[sym] += count_increase[i * max_symbols + sym];
             count_increase[i * max_symbols + sym] = 0;
           }
           float rcost = EstimateBits(counts_above.data(), max_symbols) +
                         tot_extra_bits[pred] - extra_bits_below;
           float lcost = EstimateBits(counts_below.data(), max_symbols) +
                         extra_bits_below;
           JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
           float penalty = 0;
           // Never discourage moving away from the Weighted predictor.
           if (tree_samples.PredictorFromIndex(pred) !=
                   (*tree)[pos].predictor &&
               (*tree)[pos].predictor != Predictor::Weighted) {
             penalty = change_pred_penalty;
           }
           // If everything else is equal, disfavour Weighted (slower) and
           // favour Zero (faster if it's the only predictor used in a
           // group+channel combination)
           if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
             penalty += 1e-8;
           }
           if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
             penalty -= 1e-8;
           }
           if (rcost + penalty < costs_r[i - first_used].Cost()) {
             costs_r[i - first_used].cost = rcost;
             costs_r[i - first_used].extra_cost = penalty;
             costs_r[i - first_used].pred =
                 tree_samples.PredictorFromIndex(pred);
           }
           if (lcost + penalty < costs_l[i - first_used].Cost()) {
             costs_l[i - first_used].cost = lcost;
             costs_l[i - first_used].extra_cost = penalty;
             costs_l[i - first_used].pred =
                 tree_samples.PredictorFromIndex(pred);
           }
         }
       }
       // Iterate through the possible splits and find the one with minimum sum
       // of costs of the two sides.
       size_t split = begin;
       for (size_t i = first_used; i < last_used; i++) {
         if (!prop_value_used_count[i]) continue;
         split += prop_value_used_count[i];
         float rcost = costs_r[i - first_used].cost;
         float lcost = costs_l[i - first_used].cost;
         // WP was not used + we would use the WP property or predictor
         bool adds_wp =
             (tree_samples.PropertyFromIndex(prop) == kWPProp &&
              (used_properties & (1LU << prop)) == 0) ||
             ((costs_l[i - first_used].pred == Predictor::Weighted ||
               costs_r[i - first_used].pred == Predictor::Weighted) &&
              (*tree)[pos].predictor != Predictor::Weighted);
         bool zero_entropy_side = rcost == 0 || lcost == 0;

         // Route the candidate to the appropriate best-split bucket.
         SplitInfo &best =
             prop < kNumStaticProperties
                 ? (zero_entropy_side ? best_split_static_constant
                                      : best_split_static)
                 : (adds_wp ? best_split_nonstatic : best_split_nowp);
         if (lcost + rcost < best.Cost()) {
           best.prop = prop;
           best.val = i;
           best.pos = split;
           best.lcost = lcost;
           best.lpred = costs_l[i - first_used].pred;
           best.rcost = rcost;
           best.rpred = costs_r[i - first_used].pred;
         }
       }
       // Clear extra_bits_increase and cost_increase for last_used.
       extra_bits_increase[last_used] = 0;
       for (size_t sym = 0; sym < max_symbols; sym++) {
         count_increase[last_used * max_symbols + sym] = 0;
       }
     }

     // Try to avoid introducing WP.
     if (best_split_nowp.Cost() + threshold < base_bits &&
         best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
       best = &best_split_nowp;
     }
     // Split along static props if possible and not significantly more
     // expensive.
     if (best_split_static.Cost() + threshold < base_bits &&
         best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
       best = &best_split_static;
     }
     // Split along static props to create constant nodes if possible.
     if (best_split_static_constant.Cost() + threshold < base_bits) {
       best = &best_split_static_constant;
     }
   }

   // Materialize the split only when it saves at least `threshold` bits.
   if (best->Cost() + threshold < base_bits) {
     uint32_t p = tree_samples.PropertyFromIndex(best->prop);
     pixel_type dequant =
         tree_samples.UnquantizeProperty(best->prop, best->val);
     // Split node and try to split children.
     MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
     // "Sort" according to winning property
     SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop);
     if (p >= kNumStaticProperties) {
       // NOTE(review): `1 << best->prop` is int arithmetic while the mask
       // checks above use `1LU << prop` — confirm prop cannot reach 31+.
       used_properties |= 1 << best->prop;
     }
     // Static-property splits narrow the box visible to each child.
     auto new_sp_range = static_prop_range;
     if (p < kNumStaticProperties) {
       JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
       new_sp_range[p][1] = dequant + 1;
       JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
     }
     nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
                              used_properties, new_sp_range});
     new_sp_range = static_prop_range;
     if (p < kNumStaticProperties) {
       JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
       new_sp_range[p][0] = dequant + 1;
       JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
     }
     nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
                              used_properties, new_sp_range});
   }
 }
}
    469 
    470 // NOLINTNEXTLINE(google-readability-namespace-comments)
    471 }  // namespace HWY_NAMESPACE
    472 }  // namespace jxl
    473 HWY_AFTER_NAMESPACE();
    474 
    475 #if HWY_ONCE
    476 namespace jxl {
    477 
    478 HWY_EXPORT(FindBestSplit);  // Local function.
    479 
// Builds the MA tree for the collected samples: initializes a single-leaf
// tree and delegates the greedy split search to the SIMD-dispatched
// FindBestSplit above.
Status ComputeBestTree(TreeSamples &tree_samples, float threshold,
                      const std::vector<ModularMultiplierInfo> &mul_info,
                      StaticPropRange static_prop_range,
                      float fast_decode_multiplier, Tree *tree) {
 // TODO(veluca): take into account that different contexts can have different
 // uint configs.
 //
 // Initialize tree.
 tree->emplace_back();
 tree->back().property = -1;
 tree->back().predictor = tree_samples.PredictorFromIndex(0);
 tree->back().predictor_offset = 0;
 tree->back().multiplier = 1;
 // FindBestSplit tracks used properties in a 64-bit mask.
 JXL_ENSURE(tree_samples.NumProperties() < 64);

 // Sample indices must fit in the 32-bit dedup-table entries.
 JXL_ENSURE(tree_samples.NumDistinctSamples() <=
            std::numeric_limits<uint32_t>::max());
 HWY_DYNAMIC_DISPATCH(FindBestSplit)
 (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
  tree);
 return true;
}
    502 
#if JXL_CXX_LANG < JXL_CXX_17
// Pre-C++17 requires out-of-line definitions for odr-used static constexpr
// data members.
constexpr int32_t TreeSamples::kPropertyRange;
constexpr uint32_t TreeSamples::kDedupEntryUnused;
#endif
    507 
    508 Status TreeSamples::SetPredictor(Predictor predictor,
    509                                 ModularOptions::TreeMode wp_tree_mode) {
    510  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    511    predictors = {Predictor::Weighted};
    512    residuals.resize(1);
    513    return true;
    514  }
    515  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
    516      predictor == Predictor::Weighted) {
    517    return JXL_FAILURE("Invalid predictor settings");
    518  }
    519  if (predictor == Predictor::Variable) {
    520    for (size_t i = 0; i < kNumModularPredictors; i++) {
    521      predictors.push_back(static_cast<Predictor>(i));
    522    }
    523    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    524    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
    525  } else if (predictor == Predictor::Best) {
    526    predictors = {Predictor::Weighted, Predictor::Gradient};
    527  } else {
    528    predictors = {predictor};
    529  }
    530  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    531    auto wp_it =
    532        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
    533    if (wp_it != predictors.end()) {
    534      predictors.erase(wp_it);
    535    }
    536  }
    537  residuals.resize(predictors.size());
    538  return true;
    539 }
    540 
    541 Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
    542                                  ModularOptions::TreeMode wp_tree_mode) {
    543  props_to_use = properties;
    544  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    545    props_to_use = {static_cast<uint32_t>(kWPProp)};
    546  }
    547  if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) {
    548    props_to_use = {static_cast<uint32_t>(kGradientProp)};
    549  }
    550  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    551    auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
    552    if (it != props_to_use.end()) {
    553      props_to_use.erase(it);
    554    }
    555  }
    556  if (props_to_use.empty()) {
    557    return JXL_FAILURE("Invalid property set configuration");
    558  }
    559  props.resize(props_to_use.size());
    560  return true;
    561 }
    562 
    563 void TreeSamples::InitTable(size_t log_size) {
    564  size_t size = 1ULL << log_size;
    565  if (dedup_table_.size() == size) return;
    566  dedup_table_.resize(size, kDedupEntryUnused);
    567  for (size_t i = 0; i < NumDistinctSamples(); i++) {
    568    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
    569      AddToTable(i);
    570    }
    571  }
    572 }
    573 
    574 bool TreeSamples::AddToTableAndMerge(size_t a) {
    575  size_t pos1 = Hash1(a);
    576  size_t pos2 = Hash2(a);
    577  if (dedup_table_[pos1] != kDedupEntryUnused &&
    578      IsSameSample(a, dedup_table_[pos1])) {
    579    JXL_DASSERT(sample_counts[a] == 1);
    580    sample_counts[dedup_table_[pos1]]++;
    581    // Remove from hash table samples that are saturated.
    582    if (sample_counts[dedup_table_[pos1]] ==
    583        std::numeric_limits<uint16_t>::max()) {
    584      dedup_table_[pos1] = kDedupEntryUnused;
    585    }
    586    return true;
    587  }
    588  if (dedup_table_[pos2] != kDedupEntryUnused &&
    589      IsSameSample(a, dedup_table_[pos2])) {
    590    JXL_DASSERT(sample_counts[a] == 1);
    591    sample_counts[dedup_table_[pos2]]++;
    592    // Remove from hash table samples that are saturated.
    593    if (sample_counts[dedup_table_[pos2]] ==
    594        std::numeric_limits<uint16_t>::max()) {
    595      dedup_table_[pos2] = kDedupEntryUnused;
    596    }
    597    return true;
    598  }
    599  AddToTable(a);
    600  return false;
    601 }
    602 
    603 void TreeSamples::AddToTable(size_t a) {
    604  size_t pos1 = Hash1(a);
    605  size_t pos2 = Hash2(a);
    606  if (dedup_table_[pos1] == kDedupEntryUnused) {
    607    dedup_table_[pos1] = a;
    608  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
    609    dedup_table_[pos2] = a;
    610  }
    611 }
    612 
    613 void TreeSamples::PrepareForSamples(size_t num_samples) {
    614  for (auto &res : residuals) {
    615    res.reserve(res.size() + num_samples);
    616  }
    617  for (auto &p : props) {
    618    p.reserve(p.size() + num_samples);
    619  }
    620  size_t total_num_samples = num_samples + sample_counts.size();
    621  size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2);
    622  InitTable(next_size);
    623 }
    624 
    625 size_t TreeSamples::Hash1(size_t a) const {
    626  constexpr uint64_t constant = 0x1e35a7bd;
    627  uint64_t h = constant;
    628  for (const auto &r : residuals) {
    629    h = h * constant + r[a].tok;
    630    h = h * constant + r[a].nbits;
    631  }
    632  for (const auto &p : props) {
    633    h = h * constant + p[a];
    634  }
    635  return (h >> 16) & (dedup_table_.size() - 1);
    636 }
    637 size_t TreeSamples::Hash2(size_t a) const {
    638  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
    639  uint64_t h = constant;
    640  for (const auto &p : props) {
    641    h = h * constant ^ p[a];
    642  }
    643  for (const auto &r : residuals) {
    644    h = h * constant ^ r[a].tok;
    645    h = h * constant ^ r[a].nbits;
    646  }
    647  return (h >> 16) & (dedup_table_.size() - 1);
    648 }
    649 
    650 bool TreeSamples::IsSameSample(size_t a, size_t b) const {
    651  bool ret = true;
    652  for (const auto &r : residuals) {
    653    if (r[a].tok != r[b].tok) {
    654      ret = false;
    655    }
    656    if (r[a].nbits != r[b].nbits) {
    657      ret = false;
    658    }
    659  }
    660  for (const auto &p : props) {
    661    if (p[a] != p[b]) {
    662      ret = false;
    663    }
    664  }
    665  return ret;
    666 }
    667 
// Appends one training sample: per-predictor tokenized residuals, quantized
// property values, and an initial count of 1. If an identical sample is
// already stored, its count is incremented instead and the freshly appended
// copy is removed again.
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                           const pixel_type_w *predictions) {
 for (size_t i = 0; i < predictors.size(); i++) {
   // Tokenize the prediction residual into (token, nbits, bits).
   pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
   uint32_t tok, nbits, bits;
   HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
   // ResidualToken stores both fields as bytes.
   JXL_DASSERT(tok < 256);
   JXL_DASSERT(nbits < 256);
   residuals[i].emplace_back(
       ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
 }
 for (size_t i = 0; i < props_to_use.size(); i++) {
   props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
 }
 sample_counts.push_back(1);
 num_samples++;
 // Deduplicate: if this sample merged into an existing entry, drop the copy
 // that was just appended from all parallel arrays.
 if (AddToTableAndMerge(sample_counts.size() - 1)) {
   for (auto &r : residuals) r.pop_back();
   for (auto &p : props) p.pop_back();
   sample_counts.pop_back();
 }
}
    690 
    691 void TreeSamples::Swap(size_t a, size_t b) {
    692  if (a == b) return;
    693  for (auto &r : residuals) {
    694    std::swap(r[a], r[b]);
    695  }
    696  for (auto &p : props) {
    697    std::swap(p[a], p[b]);
    698  }
    699  std::swap(sample_counts[a], sample_counts[b]);
    700 }
    701 
    702 void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
    703  if (b == c) {
    704    Swap(a, b);
    705    return;
    706  }
    707 
    708  for (auto &r : residuals) {
    709    auto tmp = r[a];
    710    r[a] = r[c];
    711    r[c] = r[b];
    712    r[b] = tmp;
    713  }
    714  for (auto &p : props) {
    715    auto tmp = p[a];
    716    p[a] = p[c];
    717    p[c] = p[b];
    718    p[b] = tmp;
    719  }
    720  auto tmp = sample_counts[a];
    721  sample_counts[a] = sample_counts[c];
    722  sample_counts[c] = sample_counts[b];
    723  sample_counts[b] = tmp;
    724 }
    725 
    726 namespace {
    727 std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
    728                                       size_t num_chunks) {
    729  if (histogram.empty()) return {};
    730  // TODO(veluca): selecting distinct quantiles is likely not the best
    731  // way to go about this.
    732  std::vector<int32_t> thresholds;
    733  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
    734  uint64_t cumsum = 0;
    735  uint64_t threshold = 1;
    736  for (size_t i = 0; i + 1 < histogram.size(); i++) {
    737    cumsum += histogram[i];
    738    if (cumsum >= threshold * sum / num_chunks) {
    739      thresholds.push_back(i);
    740      while (cumsum > threshold * sum / num_chunks) threshold++;
    741    }
    742  }
    743  return thresholds;
    744 }
    745 
    746 std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
    747                                     size_t num_chunks) {
    748  if (samples.empty()) return {};
    749  int min = *std::min_element(samples.begin(), samples.end());
    750  constexpr int kRange = 512;
    751  min = std::min(std::max(min, -kRange), kRange);
    752  std::vector<uint32_t> counts(2 * kRange + 1);
    753  for (int s : samples) {
    754    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
    755    counts[sample_offset]++;
    756  }
    757  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
    758  for (auto &v : thresholds) v += min;
    759  return thresholds;
    760 }
    761 }  // namespace
    762 
// Precomputes, for every property in props_to_use, a compact set of split
// thresholds (compact_properties[i]) and a dense lookup table
// (property_mapping[i]) that maps each clamped property value in
// [-kPropertyRange, kPropertyRange] to the index of its threshold bucket.
// NOTE(review): the abs-variant lambdas below destructively overwrite
// pixel_samples / diff_samples with absolute values, so the non-abs
// thresholds must be (and are) computed first.
void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    // A multiplier range boundary that differs from the full static range
    // forces a split exactly at that boundary (threshold T splits on > T,
    // hence the -1).
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  // Sort and deduplicate both forced-threshold lists.
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  // Channel property: use the forced multiplier splits if any, otherwise
  // quantile-split the per-channel pixel-count histogram.
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  // Group-id property: same scheme as quantize_channel, on per-group counts.
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  // Coordinate properties: evenly spaced thresholds over [0, 255].
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  std::vector<int32_t> abs_pixel_thresholds;
  std::vector<int32_t> pixel_thresholds;
  // Lazily computed and cached: thresholds over the sampled pixel values.
  auto quantize_pixel_property = [&]() {
    if (pixel_thresholds.empty()) {
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thresholds;
  };
  // Thresholds over |pixel|; overwrites pixel_samples with absolute values,
  // so the signed thresholds are forced to be computed first.
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thresholds.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thresholds =
          QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thresholds;
  };
  std::vector<int32_t> abs_diff_thresholds;
  std::vector<int32_t> diff_thresholds;
  // Lazily computed and cached: thresholds over sampled neighbor differences.
  auto quantize_diff_property = [&]() {
    if (diff_thresholds.empty()) {
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thresholds;
  };
  // Same destructive caching pattern as quantize_abs_pixel_property, applied
  // to diff_samples.
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thresholds.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thresholds;
  };
  // Weighted-predictor error property: fixed, roughly logarithmic threshold
  // tables, with coarser tables when fewer property values are allowed.
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size());
  // Dispatch on the property index: 0 = channel, 1 = group id, 2/3 =
  // coordinates, 4/5 = abs pixel, 6..8 = pixel, kWPProp = WP error; reference
  // properties (>= kNumNonrefProperties) cycle through pixel/abs variants in
  // groups of 4. Everything else falls back to the diff quantizer.
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    // Build the dense value -> bucket table. Index j represents the clamped
    // property value j - kPropertyRange.
    property_mapping[i].resize(kPropertyRange * 2 + 1);
    size_t mapped = 0;
    for (size_t j = 0; j < property_mapping[i].size(); j++) {
      while (mapped < compact_properties[i].size() &&
             static_cast<int>(j) - kPropertyRange >
                 compact_properties[i][mapped]) {
        mapped++;
      }
      // property_mapping[i] of a value V is `mapped` if
      // compact_properties[i][mapped] <= j and
      // compact_properties[i][mapped-1] > j
      // This is because the decision node in the tree splits on (property) > j,
      // hence everything that is not > of a threshold should be clustered
      // together.
      property_mapping[i][j] = mapped;
    }
  }
}
    917 
// Collects a sparse random sample of pixel values and horizontal pixel
// differences from `image` (used later for property quantization), and
// accumulates per-group and per-channel pixel counts. Channels of width <= 1
// or height 0 are skipped because the difference sample needs a horizontal
// neighbor. Sampling positions are drawn by skipping geometrically
// distributed gaps through the concatenated channel rasters, seeded by
// group_id so results are deterministic per group.
void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         uint32_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  // Grow the output count vectors on demand; existing counts are kept.
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  for (size_t i = 0; i < image.channel.size(); i++) {
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    // Stop (rather than skip) at the first oversized non-meta channel;
    // presumably later channels are at least as large — TODO confirm.
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  // (i, y, x) is a cursor over the concatenated rasters of channel_ids;
  // advance() moves it forward `amount` pixels, wrapping rows and channels.
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  advance(rng.Geometric(dist));
  // Take one sample, then skip a geometric gap (+1 so we always move).
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    // At column 0 use the right neighbor instead of the left one (w > 1 is
    // guaranteed by the channel filter above).
    size_t xp = x == 0 ? 1 : x - 1;
    // Subtract in 64 bits to avoid signed overflow before narrowing.
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}
    981 
    982 // TODO(veluca): very simple encoding scheme. This should be improved.
    983 Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
    984                    Tree *decoder_tree) {
    985  JXL_ENSURE(tree.size() <= kMaxTreeSize);
    986  std::queue<int> q;
    987  q.push(0);
    988  size_t leaf_id = 0;
    989  decoder_tree->clear();
    990  while (!q.empty()) {
    991    int cur = q.front();
    992    q.pop();
    993    JXL_ENSURE(tree[cur].property >= -1);
    994    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    995    if (tree[cur].property == -1) {
    996      tokens->emplace_back(kPredictorContext,
    997                           static_cast<int>(tree[cur].predictor));
    998      tokens->emplace_back(kOffsetContext,
    999                           PackSigned(tree[cur].predictor_offset));
   1000      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
   1001      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
   1002      tokens->emplace_back(kMultiplierLogContext, mul_log);
   1003      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
   1004      JXL_ENSURE(tree[cur].predictor < Predictor::Best);
   1005      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
   1006                                 tree[cur].predictor_offset,
   1007                                 tree[cur].multiplier);
   1008      continue;
   1009    }
   1010    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
   1011                               decoder_tree->size() + q.size() + 1,
   1012                               decoder_tree->size() + q.size() + 2,
   1013                               Predictor::Zero, 0, 1);
   1014    q.push(tree[cur].lchild);
   1015    q.push(tree[cur].rchild);
   1016    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
   1017  }
   1018  return true;
   1019 }
   1020 
   1021 }  // namespace jxl
   1022 #endif  // HWY_ONCE