tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_ans.cc (70545B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_ans.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 #include <jxl/types.h>
     10 
     11 #include <algorithm>
     12 #include <cmath>
     13 #include <cstdint>
     14 #include <limits>
     15 #include <unordered_map>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "lib/jxl/ans_common.h"
     20 #include "lib/jxl/base/bits.h"
     21 #include "lib/jxl/base/fast_math-inl.h"
     22 #include "lib/jxl/base/status.h"
     23 #include "lib/jxl/dec_ans.h"
     24 #include "lib/jxl/enc_ans_params.h"
     25 #include "lib/jxl/enc_aux_out.h"
     26 #include "lib/jxl/enc_cluster.h"
     27 #include "lib/jxl/enc_context_map.h"
     28 #include "lib/jxl/enc_fields.h"
     29 #include "lib/jxl/enc_huffman.h"
     30 #include "lib/jxl/enc_params.h"
     31 #include "lib/jxl/fields.h"
     32 
     33 namespace jxl {
     34 
     35 namespace {
     36 
     // Switch consulted when entropy codes are built: in the code visible in this
     // file, setting it makes the builder skip histogram clustering and use one
     // fixed histogram instead. It is a compile-time constant except in debug
     // builds, where the `constexpr` qualifier is left out.
     #if (!JXL_IS_DEBUG_BUILD)
     constexpr
     #endif
         bool ans_fuzzer_friendly_ = false;

     // Capacity of the `symbols` arrays that record the first few used symbols of
     // a histogram while it is being normalized.
     constexpr int kMaxNumSymbolsForSmallCode = 4;
     43 
     44 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
     45                       size_t alphabet_size, size_t log_alpha_size,
     46                       ANSEncSymbolInfo* info) {
     47  size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
     48  size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
     49  // create valid alias table for empty streams.
     50  for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
     51    const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
     52    info[s].freq_ = static_cast<uint16_t>(freq);
     53 #ifdef USE_MULT_BY_RECIPROCAL
     54    if (freq != 0) {
     55      info[s].ifreq_ =
     56          ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
     57    } else {
     58      info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
     59    }
     60 #endif
     61    info[s].reverse_map_.resize(freq);
     62  }
     63  for (int i = 0; i < ANS_TAB_SIZE; i++) {
     64    AliasTable::Symbol s =
     65        AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
     66    info[s.value].reverse_map_[s.offset] = i;
     67  }
     68 }
     69 
     70 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
     71                       size_t len) {
     72  float sum = 0.0f;
     73  int total_histogram = 0;
     74  int total_counts = 0;
     75  for (size_t i = 0; i < len; ++i) {
     76    total_histogram += histogram[i];
     77    total_counts += counts[i];
     78    if (histogram[i] > 0) {
     79      JXL_DASSERT(counts[i] > 0);
     80      // += histogram[i] * -log(counts[i]/total_counts)
     81      sum += histogram[i] *
     82             std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
     83    }
     84  }
     85  if (total_histogram > 0) {
     86    // Used only in assert.
     87    (void)total_counts;
     88    JXL_DASSERT(total_counts == ANS_TAB_SIZE);
     89  }
     90  return sum;
     91 }
     92 
     93 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
     94  const float flat_bits = std::max(FastLog2f(len), 0.0f);
     95  float total_histogram = 0;
     96  for (size_t i = 0; i < len; ++i) {
     97    total_histogram += histogram[i];
     98  }
     99  return total_histogram * flat_bits;
    100 }
    101 
    102 // Static Huffman code for encoding logcounts. The last symbol is used as RLE
    103 // sequence.
    104 const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    105    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
    106 };
    107 const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    108    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
    109 };
    110 
    111 // Returns the difference between largest count that can be represented and is
    112 // smaller than "count" and smallest representable count larger than "count".
    113 int SmallestIncrement(uint32_t count, uint32_t shift) {
    114  int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
    115  int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
    116  return drop_bits < 0 ? 1 : (1 << drop_bits);
    117 }
    118 
    119 template <bool minimize_error_of_sum>
    120 bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
    121                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
    122  int sum = 0;
    123  float sum_nonrounded = 0.0;
    124  int remainder_pos = 0;  // if all of them are handled in first loop
    125  int remainder_log = -1;
    126  for (int n = 0; n < max_symbol; ++n) {
    127    if (targets[n] > 0 && targets[n] < 1.0f) {
    128      counts[n] = 1;
    129      sum_nonrounded += targets[n];
    130      sum += counts[n];
    131    }
    132  }
    133  const float discount_ratio =
    134      (table_size - sum) / (table_size - sum_nonrounded);
    135  JXL_ENSURE(discount_ratio > 0);
    136  JXL_ENSURE(discount_ratio <= 1.0f);
    137  // Invariant for minimize_error_of_sum == true:
    138  // abs(sum - sum_nonrounded)
    139  //   <= SmallestIncrement(max(targets[])) + max_symbol
    140  for (int n = 0; n < max_symbol; ++n) {
    141    if (targets[n] >= 1.0f) {
    142      sum_nonrounded += targets[n];
    143      counts[n] =
    144          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
    145      if (counts[n] == 0) counts[n] = 1;
    146      if (counts[n] == table_size) counts[n] = table_size - 1;
    147      // Round the count to the closest nonzero multiple of SmallestIncrement
    148      // (when minimize_error_of_sum is false) or one of two closest so as to
    149      // keep the sum as close as possible to sum_nonrounded.
    150      int inc = SmallestIncrement(counts[n], shift);
    151      counts[n] -= counts[n] & (inc - 1);
    152      // TODO(robryk): Should we rescale targets[n]?
    153      const int target = minimize_error_of_sum
    154                             ? (static_cast<int>(sum_nonrounded) - sum)
    155                             : static_cast<int>(targets[n]);
    156      if (counts[n] == 0 ||
    157          (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) {
    158        counts[n] += inc;
    159      }
    160      sum += counts[n];
    161      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
    162      if (count_log > remainder_log) {
    163        remainder_pos = n;
    164        remainder_log = count_log;
    165      }
    166    }
    167  }
    168  JXL_ENSURE(remainder_pos != -1);
    169  // NOTE: This is the only place where counts could go negative. We could
    170  // detect that, return false and make ANSHistBin uint32_t.
    171  counts[remainder_pos] -= sum - table_size;
    172  *omit_pos = remainder_pos;
    173  return counts[remainder_pos] > 0;
    174 }
    175 
    176 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
    177                       const int precision_bits, uint32_t shift,
    178                       int* num_symbols, int* symbols) {
    179  const int32_t table_size = 1 << precision_bits;  // target sum / table size
    180  uint64_t total = 0;
    181  int max_symbol = 0;
    182  int symbol_count = 0;
    183  for (int n = 0; n < length; ++n) {
    184    total += counts[n];
    185    if (counts[n] > 0) {
    186      if (symbol_count < kMaxNumSymbolsForSmallCode) {
    187        symbols[symbol_count] = n;
    188      }
    189      ++symbol_count;
    190      max_symbol = n + 1;
    191    }
    192  }
    193  *num_symbols = symbol_count;
    194  if (symbol_count == 0) {
    195    return true;
    196  }
    197  if (symbol_count == 1) {
    198    counts[symbols[0]] = table_size;
    199    return true;
    200  }
    201  if (symbol_count > table_size)
    202    return JXL_FAILURE("Too many entries in an ANS histogram");
    203 
    204  const float norm = 1.f * table_size / total;
    205  std::vector<float> targets(max_symbol);
    206  for (size_t n = 0; n < targets.size(); ++n) {
    207    targets[n] = norm * counts[n];
    208  }
    209  if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift,
    210                                 omit_pos, counts)) {
    211    // Use an alternative rebalancing mechanism if the one above failed
    212    // to create a histogram that is positive wherever the original one was.
    213    if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift,
    214                                  omit_pos, counts)) {
    215      return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
    216    }
    217  }
    218  return true;
    219 }
    220 
    // Writer stand-in that only tallies how many bits would have been written.
    // Offers the same Write(num, bits) call shape that the encoding templates in
    // this file use; the payload is ignored.
    struct SizeWriter {
      // Running total of the `num` arguments passed to Write().
      size_t size = 0;

      // Adds `num` to the total; the value to be written is discarded.
      void Write(size_t num, size_t /*bits*/) { size = size + num; }
    };
    225 
    226 template <typename Writer>
    227 void StoreVarLenUint8(size_t n, Writer* writer) {
    228  JXL_DASSERT(n <= 255);
    229  if (n == 0) {
    230    writer->Write(1, 0);
    231  } else {
    232    writer->Write(1, 1);
    233    size_t nbits = FloorLog2Nonzero(n);
    234    writer->Write(3, nbits);
    235    writer->Write(nbits, n - (1ULL << nbits));
    236  }
    237 }
    238 
    239 template <typename Writer>
    240 void StoreVarLenUint16(size_t n, Writer* writer) {
    241  JXL_DASSERT(n <= 65535);
    242  if (n == 0) {
    243    writer->Write(1, 0);
    244  } else {
    245    writer->Write(1, 1);
    246    size_t nbits = FloorLog2Nonzero(n);
    247    writer->Write(4, nbits);
    248    writer->Write(nbits, n - (1ULL << nbits));
    249  }
    250 }
    251 
    252 template <typename Writer>
    253 bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
    254                  const int omit_pos, const int num_symbols, uint32_t shift,
    255                  const int* symbols, Writer* writer) {
    256  bool ok = true;
    257  if (num_symbols <= 2) {
    258    // Small tree marker to encode 1-2 symbols.
    259    writer->Write(1, 1);
    260    if (num_symbols == 0) {
    261      writer->Write(1, 0);
    262      StoreVarLenUint8(0, writer);
    263    } else {
    264      writer->Write(1, num_symbols - 1);
    265      for (int i = 0; i < num_symbols; ++i) {
    266        StoreVarLenUint8(symbols[i], writer);
    267      }
    268    }
    269    if (num_symbols == 2) {
    270      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    271    }
    272  } else {
    273    // Mark non-small tree.
    274    writer->Write(1, 0);
    275    // Mark non-flat histogram.
    276    writer->Write(1, 0);
    277 
    278    // Precompute sequences for RLE encoding. Contains the number of identical
    279    // values starting at a given index. Only contains the value at the first
    280    // element of the series.
    281    std::vector<uint32_t> same(alphabet_size, 0);
    282    int last = 0;
    283    for (int i = 1; i < alphabet_size; i++) {
    284      // Store the sequence length once different symbol reached, or we're at
    285      // the end, or the length is longer than we can encode, or we are at
    286      // the omit_pos. We don't support including the omit_pos in an RLE
    287      // sequence because this value may use a different amount of log2 bits
    288      // than standard, it is too complex to handle in the decoder.
    289      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
    290          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
    291        same[last] = (i - last);
    292        last = i + 1;
    293      }
    294    }
    295 
    296    int length = 0;
    297    std::vector<int> logcounts(alphabet_size);
    298    int omit_log = 0;
    299    for (int i = 0; i < alphabet_size; ++i) {
    300      JXL_ENSURE(counts[i] <= ANS_TAB_SIZE);
    301      JXL_ENSURE(counts[i] >= 0);
    302      if (i == omit_pos) {
    303        length = i + 1;
    304      } else if (counts[i] > 0) {
    305        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
    306        length = i + 1;
    307        if (i < omit_pos) {
    308          omit_log = std::max(omit_log, logcounts[i] + 1);
    309        } else {
    310          omit_log = std::max(omit_log, logcounts[i]);
    311        }
    312      }
    313    }
    314    logcounts[omit_pos] = omit_log;
    315 
    316    // Elias gamma-like code for shift. Only difference is that if the number
    317    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    318    // the terminating 0 in unary coding.
    319    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    320    int log = FloorLog2Nonzero(shift + 1);
    321    writer->Write(log, (1 << log) - 1);
    322    if (log != upper_bound_log) writer->Write(1, 0);
    323    writer->Write(log, ((1 << log) - 1) & (shift + 1));
    324 
    325    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    326    // length - 3.
    327    if (length - 3 > 255) {
    328      // Pretend that everything is OK, but complain about correctness later.
    329      StoreVarLenUint8(255, writer);
    330      ok = false;
    331    } else {
    332      StoreVarLenUint8(length - 3, writer);
    333    }
    334 
    335    // The logcount values are encoded with a static Huffman code.
    336    static const size_t kMinReps = 4;
    337    size_t rep = ANS_LOG_TAB_SIZE + 1;
    338    for (int i = 0; i < length; ++i) {
    339      if (i > 0 && same[i - 1] > kMinReps) {
    340        // Encode the RLE symbol and skip the repeated ones.
    341        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
    342        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
    343        i += same[i - 1] - 2;
    344        continue;
    345      }
    346      writer->Write(kLogCountBitLengths[logcounts[i]],
    347                    kLogCountSymbols[logcounts[i]]);
    348    }
    349    for (int i = 0; i < length; ++i) {
    350      if (i > 0 && same[i - 1] > kMinReps) {
    351        // Skip symbols encoded by RLE.
    352        i += same[i - 1] - 2;
    353        continue;
    354      }
    355      if (logcounts[i] > 1 && i != omit_pos) {
    356        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
    357        int drop_bits = logcounts[i] - 1 - bitcount;
    358        JXL_ENSURE((counts[i] & ((1 << drop_bits) - 1)) == 0);
    359        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
    360      }
    361    }
    362  }
    363  return ok;
    364 }
    365 
    366 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
    367  // Mark non-small tree.
    368  writer->Write(1, 0);
    369  // Mark uniform histogram.
    370  writer->Write(1, 1);
    371  JXL_DASSERT(alphabet_size > 0);
    372  // Encode alphabet size.
    373  StoreVarLenUint8(alphabet_size - 1, writer);
    374 }
    375 
    376 StatusOr<float> ComputeHistoAndDataCost(const ANSHistBin* histogram,
    377                                        size_t alphabet_size, uint32_t method) {
    378  if (method == 0) {  // Flat code
    379    return ANS_LOG_TAB_SIZE + 2 +
    380           EstimateDataBitsFlat(histogram, alphabet_size);
    381  }
    382  // Non-flat: shift = method-1.
    383  uint32_t shift = method - 1;
    384  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
    385  int omit_pos = 0;
    386  int num_symbols;
    387  int symbols[kMaxNumSymbolsForSmallCode] = {};
    388  JXL_RETURN_IF_ERROR(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
    389                                      ANS_LOG_TAB_SIZE, shift, &num_symbols,
    390                                      symbols));
    391  SizeWriter writer;
    392  // Ignore the correctness, no real encoding happens at this stage.
    393  (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
    394                     symbols, &writer);
    395  return writer.size +
    396         EstimateDataBits(histogram, counts.data(), alphabet_size);
    397 }
    398 
    399 StatusOr<uint32_t> ComputeBestMethod(
    400    const ANSHistBin* histogram, size_t alphabet_size, float* cost,
    401    HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
    402  uint32_t method = 0;
    403  JXL_ASSIGN_OR_RETURN(float fcost,
    404                       ComputeHistoAndDataCost(histogram, alphabet_size, 0));
    405  auto try_shift = [&](size_t shift) -> Status {
    406    JXL_ASSIGN_OR_RETURN(
    407        float c, ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1));
    408    if (c < fcost) {
    409      method = shift + 1;
    410      fcost = c;
    411    }
    412    return true;
    413  };
    414  switch (ans_histogram_strategy) {
    415    case HistogramParams::ANSHistogramStrategy::kPrecise: {
    416      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
    417        JXL_RETURN_IF_ERROR(try_shift(shift));
    418      }
    419      break;
    420    }
    421    case HistogramParams::ANSHistogramStrategy::kApproximate: {
    422      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
    423        JXL_RETURN_IF_ERROR(try_shift(shift));
    424      }
    425      break;
    426    }
    427    case HistogramParams::ANSHistogramStrategy::kFast: {
    428      JXL_RETURN_IF_ERROR(try_shift(0));
    429      JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE / 2));
    430      JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE));
    431      break;
    432    }
    433  };
    434  *cost = fcost;
    435  return method;
    436 }
    437 
    438 }  // namespace
    439 
    440 // Returns an estimate of the cost of encoding this histogram and the
    441 // corresponding data.
    442 StatusOr<size_t> BuildAndStoreANSEncodingData(
    443    JxlMemoryManager* memory_manager,
    444    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
    445    const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
    446    bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
    447  if (use_prefix_code) {
    448    size_t cost = 0;
    449    if (alphabet_size <= 1) return 0;
    450    std::vector<uint32_t> histo(alphabet_size);
    451    for (size_t i = 0; i < alphabet_size; i++) {
    452      histo[i] = histogram[i];
    453      JXL_ENSURE(histogram[i] >= 0);
    454    }
    455    {
    456      std::vector<uint8_t> depths(alphabet_size);
    457      std::vector<uint16_t> bits(alphabet_size);
    458      if (writer == nullptr) {
    459        BitWriter tmp_writer{memory_manager};
    460        JXL_RETURN_IF_ERROR(tmp_writer.WithMaxBits(
    461            8 * alphabet_size + 8,  // safe upper bound
    462            LayerType::Header, /*aux_out=*/nullptr, [&] {
    463              return BuildAndStoreHuffmanTree(histo.data(), alphabet_size,
    464                                              depths.data(), bits.data(),
    465                                              &tmp_writer);
    466            }));
    467        cost = tmp_writer.BitsWritten();
    468      } else {
    469        size_t start = writer->BitsWritten();
    470        JXL_RETURN_IF_ERROR(BuildAndStoreHuffmanTree(
    471            histo.data(), alphabet_size, depths.data(), bits.data(), writer));
    472        cost = writer->BitsWritten() - start;
    473      }
    474      for (size_t i = 0; i < alphabet_size; i++) {
    475        info[i].bits = depths[i] == 0 ? 0 : bits[i];
    476        info[i].depth = depths[i];
    477      }
    478    }
    479    // Estimate data cost.
    480    for (size_t i = 0; i < alphabet_size; i++) {
    481      cost += histogram[i] * info[i].depth;
    482    }
    483    return cost;
    484  }
    485  JXL_ENSURE(alphabet_size <= ANS_TAB_SIZE);
    486  float fcost;
    487  JXL_ASSIGN_OR_RETURN(uint32_t method,
    488                       ComputeBestMethod(histogram, alphabet_size, &fcost,
    489                                         ans_histogram_strategy));
    490  JXL_ENSURE(fcost >= 0);
    491  int num_symbols;
    492  int symbols[kMaxNumSymbolsForSmallCode] = {};
    493  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
    494  if (!counts.empty()) {
    495    size_t sum = 0;
    496    for (int count : counts) {
    497      sum += count;
    498    }
    499    if (sum == 0) {
    500      counts[0] = ANS_TAB_SIZE;
    501    }
    502  }
    503  int omit_pos = 0;
    504  uint32_t shift = method - 1;
    505  if (method == 0) {
    506    JXL_ENSURE(alphabet_size > 0);
    507    counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
    508  } else {
    509    JXL_RETURN_IF_ERROR(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
    510                                        ANS_LOG_TAB_SIZE, shift, &num_symbols,
    511                                        symbols));
    512  }
    513  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
    514  JXL_RETURN_IF_ERROR(
    515      InitAliasTable(counts, ANS_LOG_TAB_SIZE, log_alpha_size, a));
    516  ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
    517  if (writer != nullptr) {
    518    if (method == 0) {
    519      JXL_ENSURE(alphabet_size > 0);
    520      EncodeFlatHistogram(alphabet_size, writer);
    521    } else {
    522      if (!EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
    523                        method - 1, symbols, writer)) {
    524        return JXL_FAILURE("EncodeCounts failed");
    525      }
    526    }
    527  }
    528  return static_cast<size_t>(fcost);
    529 }
    530 
    531 StatusOr<float> ANSPopulationCost(const ANSHistBin* data,
    532                                  size_t alphabet_size) {
    533  float cost = 0.0f;
    534  if (ANS_MAX_ALPHABET_SIZE < alphabet_size) {
    535    return std::numeric_limits<float>::max();
    536  }
    537  JXL_ASSIGN_OR_RETURN(
    538      uint32_t method,
    539      ComputeBestMethod(data, alphabet_size, &cost,
    540                        HistogramParams::ANSHistogramStrategy::kFast));
    541  (void)method;
    542  return cost;
    543 }
    544 
    545 template <typename Writer>
    546 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
    547                      size_t log_alpha_size) {
    548  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
    549                uint_config.split_exponent);
    550  if (uint_config.split_exponent == log_alpha_size) {
    551    return;  // msb/lsb don't matter.
    552  }
    553  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
    554  writer->Write(nbits, uint_config.msb_in_token);
    555  nbits = CeilLog2Nonzero(uint_config.split_exponent -
    556                          uint_config.msb_in_token + 1);
    557  writer->Write(nbits, uint_config.lsb_in_token);
    558 }
    559 template <typename Writer>
    560 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
    561                       Writer* writer, size_t log_alpha_size) {
    562  // TODO(veluca): RLE?
    563  for (const auto& cfg : uint_config) {
    564    EncodeUintConfig(cfg, writer, log_alpha_size);
    565  }
    566 }
    567 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
    568                                BitWriter*, size_t);
    569 
    570 namespace {
    571 
    572 Status ChooseUintConfigs(const HistogramParams& params,
    573                         const std::vector<std::vector<Token>>& tokens,
    574                         const std::vector<uint8_t>& context_map,
    575                         std::vector<Histogram>* clustered_histograms,
    576                         EntropyEncodingData* codes, size_t* log_alpha_size) {
    577  codes->uint_config.resize(clustered_histograms->size());
    578  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) {
    579    return true;
    580  }
    581  if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
    582    codes->uint_config.clear();
    583    codes->uint_config.resize(clustered_histograms->size(),
    584                              HybridUintConfig(0, 0, 0));
    585    return true;
    586  }
    587  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    588    codes->uint_config.clear();
    589    codes->uint_config.resize(clustered_histograms->size(),
    590                              HybridUintConfig(2, 0, 1));
    591    return true;
    592  }
    593 
    594  // If the uint config is adaptive, just stick with the default in streaming
    595  // mode.
    596  if (params.streaming_mode) {
    597    return true;
    598  }
    599 
    600  // Brute-force method that tries a few options.
    601  std::vector<HybridUintConfig> configs;
    602  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
    603    configs = {
    604        HybridUintConfig(4, 2, 0),  // default
    605        HybridUintConfig(4, 1, 0),  // less precise
    606        HybridUintConfig(4, 2, 1),  // add sign
    607        HybridUintConfig(4, 2, 2),  // add sign+parity
    608        HybridUintConfig(4, 1, 2),  // add parity but less msb
    609        // Same as above, but more direct coding.
    610        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
    611        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
    612        HybridUintConfig(5, 1, 2),
    613        // Same as above, but less direct coding.
    614        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
    615        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
    616        // For near-lossless.
    617        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
    618        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
    619        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
    620        // Other
    621        HybridUintConfig(0, 0, 0),   // varlenuint
    622        HybridUintConfig(2, 0, 1),   // works well for ctx map
    623        HybridUintConfig(7, 0, 0),   // direct coding
    624        HybridUintConfig(8, 0, 0),   // direct coding
    625        HybridUintConfig(9, 0, 0),   // direct coding
    626        HybridUintConfig(10, 0, 0),  // direct coding
    627        HybridUintConfig(11, 0, 0),  // direct coding
    628        HybridUintConfig(12, 0, 0),  // direct coding
    629    };
    630  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
    631    configs = {
    632        HybridUintConfig(4, 2, 0),  // default
    633        HybridUintConfig(4, 1, 2),  // add parity but less msb
    634        HybridUintConfig(0, 0, 0),  // smallest histograms
    635        HybridUintConfig(2, 0, 1),  // works well for ctx map
    636    };
    637  }
    638 
    639  std::vector<float> costs(clustered_histograms->size(),
    640                           std::numeric_limits<float>::max());
    641  std::vector<uint32_t> extra_bits(clustered_histograms->size());
    642  std::vector<uint8_t> is_valid(clustered_histograms->size());
    643  size_t max_alpha =
    644      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
    645  for (HybridUintConfig cfg : configs) {
    646    std::fill(is_valid.begin(), is_valid.end(), true);
    647    std::fill(extra_bits.begin(), extra_bits.end(), 0);
    648 
    649    for (auto& histo : *clustered_histograms) {
    650      histo.Clear();
    651    }
    652    for (const auto& stream : tokens) {
    653      for (const auto& token : stream) {
    654        // TODO(veluca): do not ignore lz77 commands.
    655        if (token.is_lz77_length) continue;
    656        size_t histo = context_map[token.context];
    657        uint32_t tok, nbits, bits;
    658        cfg.Encode(token.value, &tok, &nbits, &bits);
    659        if (tok >= max_alpha ||
    660            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
    661          is_valid[histo] = JXL_FALSE;
    662          continue;
    663        }
    664        extra_bits[histo] += nbits;
    665        (*clustered_histograms)[histo].Add(tok);
    666      }
    667    }
    668 
    669    for (size_t i = 0; i < clustered_histograms->size(); i++) {
    670      if (!is_valid[i]) continue;
    671      JXL_ASSIGN_OR_RETURN(float cost,
    672                           (*clustered_histograms)[i].PopulationCost());
    673      cost += extra_bits[i];
    674      // add signaling cost of the hybriduintconfig itself
    675      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
    676      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
    677      if (cost < costs[i]) {
    678        codes->uint_config[i] = cfg;
    679        costs[i] = cost;
    680      }
    681    }
    682  }
    683 
    684  // Rebuild histograms.
    685  for (auto& histo : *clustered_histograms) {
    686    histo.Clear();
    687  }
    688  *log_alpha_size = 4;
    689  for (const auto& stream : tokens) {
    690    for (const auto& token : stream) {
    691      uint32_t tok, nbits, bits;
    692      size_t histo = context_map[token.context];
    693      (token.is_lz77_length ? codes->lz77.length_uint_config
    694                            : codes->uint_config[histo])
    695          .Encode(token.value, &tok, &nbits, &bits);
    696      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
    697      (*clustered_histograms)[histo].Add(tok);
    698      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
    699    }
    700  }
    701  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
    702  JXL_ENSURE(*log_alpha_size <= max_log_alpha_size);
    703  return true;
    704 }
    705 
    706 Histogram HistogramFromSymbolInfo(
    707    const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) {
    708  Histogram histo;
    709  histo.data_.resize(DivCeil(encoding_info.size(), Histogram::kRounding) *
    710                     Histogram::kRounding);
    711  histo.total_count_ = 0;
    712  for (size_t i = 0; i < encoding_info.size(); ++i) {
    713    const ANSEncSymbolInfo& info = encoding_info[i];
    714    int count = use_prefix_code
    715                    ? (info.depth ? (1u << (PREFIX_MAX_BITS - info.depth)) : 0)
    716                    : info.freq_;
    717    histo.data_[i] = count;
    718    histo.total_count_ += count;
    719  }
    720  return histo;
    721 }
    722 
    723 class HistogramBuilder {
        // Collects one symbol histogram per context, then clusters the
        // histograms, derives the entropy codes and (optionally) serializes
        // everything to a BitWriter.
    724 public:
         // Creates `num_contexts` default-constructed histograms.
    725  explicit HistogramBuilder(const size_t num_contexts)
    726      : histograms_(num_contexts) {}
    727 
         // Counts one occurrence of `symbol` in the histogram for context
         // `histo_idx`. The index is only range-checked by a debug assertion.
    728  void VisitSymbol(int symbol, size_t histo_idx) {
    729    JXL_DASSERT(histo_idx < histograms_.size());
    730    histograms_[histo_idx].Add(symbol);
    731  }
    732 
         // Clusters the collected histograms, appends the resulting context
         // map entries to `*context_map`, chooses the hybrid-uint configs and
         // builds the per-cluster encoding tables in `codes->encoding_info`.
         // When `writer` is non-null everything is also written to it; when it
         // is null a SizeWriter is used to measure the header fields instead.
         // Histograms already described by `codes->encoding_info` on entry are
         // reconstructed and kept as the first clusters.
    733  // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
    734  // Returns cost (in bits).
    735  StatusOr<size_t> BuildAndStoreEntropyCodes(
    736      JxlMemoryManager* memory_manager, const HistogramParams& params,
    737      const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
    738      std::vector<uint8_t>* context_map, BitWriter* writer, LayerType layer,
    739      AuxOut* aux_out) const {
           // Seed the cluster list with histograms rebuilt from codes that
           // already exist in `codes`.
    740    const size_t prev_histograms = codes->encoding_info.size();
    741    std::vector<Histogram> clustered_histograms;
    742    for (size_t i = 0; i < prev_histograms; ++i) {
    743      clustered_histograms.push_back(HistogramFromSymbolInfo(
    744          codes->encoding_info[i], codes->use_prefix_code));
    745    }
           // The new contexts occupy the tail of the context map.
    746    size_t context_offset = context_map->size();
    747    context_map->resize(context_offset + histograms_.size());
    748    if (histograms_.size() > 1) {
    749      if (!ans_fuzzer_friendly_) {
    750        std::vector<uint32_t> histogram_symbols;
    751        JXL_RETURN_IF_ERROR(
    752            ClusterHistograms(params, histograms_, kClustersLimit,
    753                              &clustered_histograms, &histogram_symbols));
    754        for (size_t c = 0; c < histograms_.size(); ++c) {
    755          (*context_map)[context_offset + c] =
    756              static_cast<uint8_t>(histogram_symbols[c]);
    757        }
    758      } else {
               // Fuzzer-friendly mode: map every context to one cluster whose
               // histogram contains each of `num_symbols` symbols exactly once.
    759        JXL_ENSURE(codes->encoding_info.empty());
    760        fill(context_map->begin(), context_map->end(), 0);
    761        size_t max_symbol = 0;
    762        for (const Histogram& h : histograms_) {
    763          max_symbol = std::max(h.data_.size(), max_symbol);
    764        }
    765        size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
    766        clustered_histograms.resize(1);
    767        clustered_histograms[0].Clear();
    768        for (size_t i = 0; i < num_symbols; i++) {
    769          clustered_histograms[0].Add(i);
    770        }
    771      }
    772      if (writer != nullptr) {
    773        JXL_RETURN_IF_ERROR(EncodeContextMap(
    774            *context_map, clustered_histograms.size(), writer, layer, aux_out));
    775      }
    776    } else {
             // At most one histogram: no clustering and no context map is
             // written. NOTE(review): histograms_[0] is read unconditionally,
             // so this path assumes num_contexts >= 1 - confirm with callers.
    777      JXL_ENSURE(codes->encoding_info.empty());
    778      clustered_histograms.push_back(histograms_[0]);
    779    }
           // Statistics only: entropy of the clusters added by this call.
    780    if (aux_out != nullptr) {
    781      for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) {
    782        aux_out->layer(layer).clustered_entropy +=
    783            clustered_histograms[i].ShannonEntropy();
    784      }
    785    }
    786    size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
    787    if (ans_fuzzer_friendly_) {
    788      codes->uint_config.clear();
    789      codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
    790    } else {
    791      JXL_RETURN_IF_ERROR(ChooseUintConfigs(params, tokens, *context_map,
    792                                            &clustered_histograms, codes,
    793                                            &log_alpha_size));
    794    }
           // log_alpha_size is written below as (value - 5) in 2 bits, hence
           // the lower bound of 5.
    795    if (log_alpha_size < 5) log_alpha_size = 5;
    796    if (params.streaming_mode) {
    797      // TODO(szabadka) Figure out if we can use lower values here.
    798      log_alpha_size = 8;
    799    }
    800    SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
           // One bit for the use_prefix_code flag.
    801    size_t cost = 1;
    802    if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code));
    803 
    804    if (codes->use_prefix_code) {
    805      log_alpha_size = PREFIX_MAX_BITS;
    806    } else {
             // Two bits for log_alpha_size (ANS only).
    807      cost += 2;
    808    }
    809    if (writer == nullptr) {
    810      EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
    811    } else {
    812      if (!codes->use_prefix_code) writer->Write(2, log_alpha_size - 5);
    813      EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
    814    }
           // Prefix codes store the alphabet size of every cluster, including
           // the ones that existed before this call (loop starts at 0).
    815    if (codes->use_prefix_code) {
    816      for (const auto& histo : clustered_histograms) {
    817        size_t alphabet_size = histo.alphabet_size();
    818        if (writer) {
    819          StoreVarLenUint16(alphabet_size - 1, writer);
    820        } else {
    821          StoreVarLenUint16(alphabet_size - 1, &size_writer);
    822        }
    823      }
    824    }
           // size_writer only receives data when writer == nullptr, so with a
           // real writer this adds nothing for the fields written above.
    825    cost += size_writer.size;
           // Build (and store) the code for each cluster added by this call.
    826    for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) {
    827      size_t alphabet_size = clustered_histograms[c].alphabet_size();
    828      codes->encoding_info.emplace_back();
    829      codes->encoding_info.back().resize(alphabet_size);
    830      BitWriter* histo_writer = writer;
             // In streaming mode each histogram is first encoded into its own
             // BitWriter kept in codes->encoded_histograms.
    831      if (params.streaming_mode) {
    832        codes->encoded_histograms.emplace_back(memory_manager);
    833        histo_writer = &codes->encoded_histograms.back();
    834      }
    835      const auto& body = [&]() -> Status {
    836        JXL_ASSIGN_OR_RETURN(
    837            size_t ans_cost,
    838            BuildAndStoreANSEncodingData(
    839                memory_manager, params.ans_histogram_strategy,
    840                clustered_histograms[c].data_.data(), alphabet_size,
    841                log_alpha_size, codes->use_prefix_code,
    842                codes->encoding_info.back().data(), histo_writer));
    843        cost += ans_cost;
    844        return true;
    845      };
    846      if (histo_writer) {
    847        JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
    848            256 + alphabet_size * 24, layer, aux_out, body,
    849            /*finished_histogram=*/true));
    850      } else {
    851        JXL_RETURN_IF_ERROR(body());
    852      }
             // NOTE(review): `writer` is dereferenced without a null check
             // here, so streaming mode appears to require a non-null writer -
             // confirm with callers.
    853      if (params.streaming_mode) {
    854        JXL_RETURN_IF_ERROR(writer->AppendUnaligned(*histo_writer));
    855      }
    856    }
    857    return cost;
    858  }
    859 
         // Read-only access to the histogram collected for context `i`.
    860  const Histogram& Histo(size_t i) const { return histograms_[i]; }
    861 
    862 private:
         // One histogram per context, indexed by context id.
    863  std::vector<Histogram> histograms_;
    864 };
    865 
    866 class SymbolCostEstimator {
    867 public:
    868  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
    869                      const std::vector<std::vector<Token>>& tokens,
    870                      const LZ77Params& lz77) {
    871    HistogramBuilder builder(num_contexts);
    872    // Build histograms for estimating lz77 savings.
    873    HybridUintConfig uint_config;
    874    for (const auto& stream : tokens) {
    875      for (const auto& token : stream) {
    876        uint32_t tok, nbits, bits;
    877        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
    878            .Encode(token.value, &tok, &nbits, &bits);
    879        tok += token.is_lz77_length ? lz77.min_symbol : 0;
    880        builder.VisitSymbol(tok, token.context);
    881      }
    882    }
    883    max_alphabet_size_ = 0;
    884    for (size_t i = 0; i < num_contexts; i++) {
    885      max_alphabet_size_ =
    886          std::max(max_alphabet_size_, builder.Histo(i).data_.size());
    887    }
    888    bits_.resize(num_contexts * max_alphabet_size_);
    889    // TODO(veluca): SIMD?
    890    add_symbol_cost_.resize(num_contexts);
    891    for (size_t i = 0; i < num_contexts; i++) {
    892      float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
    893      float total_cost = 0;
    894      for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
    895        size_t cnt = builder.Histo(i).data_[j];
    896        float cost = 0;
    897        if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
    898          cost = -FastLog2f(cnt * inv_total);
    899          if (force_huffman) cost = std::ceil(cost);
    900        } else if (cnt == 0) {
    901          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
    902        }
    903        bits_[i * max_alphabet_size_ + j] = cost;
    904        total_cost += cost * builder.Histo(i).data_[j];
    905      }
    906      // Penalty for adding a lz77 symbol to this contest (only used for static
    907      // cost model). Higher penalty for contexts that have a very low
    908      // per-symbol entropy.
    909      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
    910    }
    911  }
    912  float Bits(size_t ctx, size_t sym) const {
    913    return bits_[ctx * max_alphabet_size_ + sym];
    914  }
    915  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
    916    uint32_t nbits, bits, tok;
    917    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
    918    tok += lz77.min_symbol;
    919    return nbits + Bits(ctx, tok);
    920  }
    921  float DistCost(size_t len, const LZ77Params& lz77) const {
    922    uint32_t nbits, bits, tok;
    923    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
    924    return nbits + Bits(lz77.nonserialized_distance_context, tok);
    925  }
    926  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
    927 
    928 private:
    929  size_t max_alphabet_size_;
    930  std::vector<float> bits_;
    931  std::vector<float> add_symbol_cost_;
    932 };
    933 
    934 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
    935                   const std::vector<std::vector<Token>>& tokens,
    936                   LZ77Params& lz77,
    937                   std::vector<std::vector<Token>>& tokens_lz77) {
    938  // TODO(veluca): tune heuristics here.
    939  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
    940  float bit_decrease = 0;
    941  size_t total_symbols = 0;
    942  tokens_lz77.resize(tokens.size());
    943  std::vector<float> sym_cost;
    944  HybridUintConfig uint_config;
    945  for (size_t stream = 0; stream < tokens.size(); stream++) {
    946    size_t distance_multiplier =
    947        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    948    const auto& in = tokens[stream];
    949    auto& out = tokens_lz77[stream];
    950    total_symbols += in.size();
    951    // Cumulative sum of bit costs.
    952    sym_cost.resize(in.size() + 1);
    953    for (size_t i = 0; i < in.size(); i++) {
    954      uint32_t tok, nbits, unused_bits;
    955      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
    956      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    957    }
    958    out.reserve(in.size());
    959    for (size_t i = 0; i < in.size(); i++) {
    960      size_t num_to_copy = 0;
    961      size_t distance_symbol = 0;  // 1 for RLE.
    962      if (distance_multiplier != 0) {
    963        distance_symbol = 1;  // Special distance 1 if enabled.
    964        JXL_DASSERT(kSpecialDistances[1][0] == 1);
    965        JXL_DASSERT(kSpecialDistances[1][1] == 0);
    966      }
    967      if (i > 0) {
    968        for (; i + num_to_copy < in.size(); num_to_copy++) {
    969          if (in[i + num_to_copy].value != in[i - 1].value) {
    970            break;
    971          }
    972        }
    973      }
    974      if (num_to_copy == 0) {
    975        out.push_back(in[i]);
    976        continue;
    977      }
    978      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
    979      // This subtraction might overflow, but that's OK.
    980      size_t lz77_len = num_to_copy - lz77.min_length;
    981      float lz77_cost = num_to_copy >= lz77.min_length
    982                            ? CeilLog2Nonzero(lz77_len + 1) + 1
    983                            : 0;
    984      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
    985        for (size_t j = 0; j < num_to_copy; j++) {
    986          out.push_back(in[i + j]);
    987        }
    988        i += num_to_copy - 1;
    989        continue;
    990      }
    991      // Output the LZ77 length
    992      out.emplace_back(in[i].context, lz77_len);
    993      out.back().is_lz77_length = true;
    994      i += num_to_copy - 1;
    995      bit_decrease += cost - lz77_cost;
    996      // Output the LZ77 copy distance.
    997      out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
    998    }
    999  }
   1000 
   1001  if (bit_decrease > total_symbols * 0.2 + 16) {
   1002    lz77.enabled = true;
   1003  }
   1004 }
   1005 
   1006 // Hash chain for LZ77 matching
        // Keeps, for a circular window of `window_size_` positions, a chain of
        // earlier positions with the same 3-value hash plus a second chain
        // keyed by the zero-run length at each position. Positions must be
        // inserted with Update() before matches are searched for them.
   1007 struct HashChain {
         // Number of input values and a private copy of the token values.
   1008  size_t size_;
   1009  std::vector<uint32_t> data_;
   1010 
         // Hash table size, mask applied to hashes, and per-value shift used
         // by GetHash().
   1011  unsigned hash_num_values_ = 32768;
   1012  unsigned hash_mask_ = hash_num_values_ - 1;
   1013  unsigned hash_shift_ = 5;
   1014 
         // head[hash]: window position most recently inserted with that hash
         // (-1 if none). chain[wpos]: previous window position with the same
         // hash. val[wpos]: hash stored for that window position (-1 if none).
   1015  std::vector<int> head;
   1016  std::vector<uint32_t> chain;
   1017  std::vector<int> val;
   1018 
   1019  // Speed up repetitions of zero
         // headz/chainz mirror head/chain but are keyed by the zero-run length
         // stored in zeros[wpos]; numzeros is the value for the last Update().
   1020  std::vector<int> headz;
   1021  std::vector<uint32_t> chainz;
   1022  std::vector<uint32_t> zeros;
   1023  uint32_t numzeros = 0;
   1024 
         // window_mask_ is window_size_ - 1 and is used to wrap positions, so
         // window_size_ is expected to be a power of two.
   1025  size_t window_size_;
   1026  size_t window_mask_;
   1027  size_t min_length_;
   1028  size_t max_length_;
   1029 
   1030  // Map of special distance codes.
   1031  std::unordered_map<int, int> special_dist_table_;
   1032  size_t num_special_distances_ = 0;
   1033 
   1034  uint32_t maxchainlength = 256;  // window_size_ to allow all
   1035 
         // Copies the `size` token values from `data` and initializes all
         // tables to the empty state. A non-zero `distance_multiplier` enables
         // the special distance codes.
   1036  HashChain(const Token* data, size_t size, size_t window_size,
   1037            size_t min_length, size_t max_length, size_t distance_multiplier)
   1038      : size_(size),
   1039        window_size_(window_size),
   1040        window_mask_(window_size - 1),
   1041        min_length_(min_length),
   1042        max_length_(max_length) {
   1043    data_.resize(size);
   1044    for (size_t i = 0; i < size; i++) {
   1045      data_[i] = data[i].value;
   1046    }
   1047 
   1048    head.resize(hash_num_values_, -1);
   1049    val.resize(window_size_, -1);
   1050    chain.resize(window_size_);
   1051    for (uint32_t i = 0; i < window_size_; ++i) {
   1052      chain[i] = i;  // same value as index indicates uninitialized
   1053    }
   1054 
   1055    zeros.resize(window_size_);
   1056    headz.resize(window_size_ + 1, -1);
   1057    chainz.resize(window_size_);
   1058    for (uint32_t i = 0; i < window_size_; ++i) {
   1059      chainz[i] = i;
   1060    }
   1061    // Translate distance to special distance code.
   1062    if (distance_multiplier) {
   1063      // Count down, so if due to small distance multiplier multiple distances
   1064      // map to the same code, the smallest code will be used in the end.
   1065      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
   1066        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
   1067      }
   1068      num_special_distances_ = kNumSpecialDistances;
   1069    }
   1070  }
   1071 
         // Hash of the three values starting at `pos`, masked with hash_mask_.
         // Returns 0 when fewer than three values remain.
   1072  uint32_t GetHash(size_t pos) const {
   1073    uint32_t result = 0;
   1074    if (pos + 2 < size_) {
   1075      // TODO(lode): take the MSB's of the uint32_t values into account as well,
   1076      // given that the hash code itself is less than 32 bits.
   1077      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
   1078      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
   1079      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
   1080    } else {
   1081      // No need to compute hash of last 2 bytes, the length 2 is too short.
   1082      return 0;
   1083    }
   1084    return result & hash_mask_;
   1085  }
   1086 
         // Number of consecutive zero values starting at `pos`, looking at
         // most window_size_ values ahead. A non-zero `prevzeros` (the count
         // at the previous position) is reused instead of rescanning: the run
         // is one shorter, unless it already filled the window and the newly
         // visible last value is zero as well.
   1087  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
   1088    size_t end = pos + window_size_;
   1089    if (end > size_) end = size_;
   1090    if (prevzeros > 0) {
   1091      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
   1092          end == pos + window_size_) {
   1093        return prevzeros;
   1094      } else {
   1095        return prevzeros - 1;
   1096      }
   1097    }
   1098    uint32_t num = 0;
   1099    while (pos + num < end && data_[pos + num] == 0) num++;
   1100    return num;
   1101  }
   1102 
         // Inserts position `pos` into the hash chain and the zero-run chain,
         // and refreshes `numzeros` for this position.
   1103  void Update(size_t pos) {
   1104    uint32_t hashval = GetHash(pos);
   1105    uint32_t wpos = pos & window_mask_;
   1106 
   1107    val[wpos] = static_cast<int>(hashval);
   1108    if (head[hashval] != -1) chain[wpos] = head[hashval];
   1109    head[hashval] = wpos;
   1110 
           // A change of value ends any run carried over from pos - 1.
   1111    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
   1112    numzeros = CountZeros(pos, numzeros);
   1113 
   1114    zeros[wpos] = numzeros;
   1115    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
   1116    headz[numzeros] = wpos;
   1117  }
   1118 
         // Inserts the `len` consecutive positions starting at `pos`.
   1119  void Update(size_t pos, size_t len) {
   1120    for (size_t i = 0; i < len; i++) {
   1121      Update(pos + i);
   1122    }
   1123  }
   1124 
         // Walks the chain of earlier positions for `pos` and calls
         // found_match(len, dist_symbol) for every candidate whose match
         // length is at least min_length_ and at most 2 shorter than the best
         // length seen so far. At most `maxchainlength` candidates are
         // examined. `max_dist` is not used by this implementation. Reads
         // `numzeros` and chain[wpos], so it expects Update(pos) to have been
         // the most recent insertion.
   1125  template <typename CB>
   1126  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
   1127    uint32_t wpos = pos & window_mask_;
   1128    uint32_t hashval = GetHash(pos);
   1129    uint32_t hashpos = chain[wpos];
   1130 
   1131    int prev_dist = 0;
   1132    int end = std::min<int>(pos + max_length_, size_);
   1133    uint32_t chainlength = 0;
   1134    uint32_t best_len = 0;
   1135    for (;;) {
             // Distance from the candidate to `pos` inside the circular window.
   1136      int dist = (hashpos <= wpos) ? (wpos - hashpos)
   1137                                   : (wpos - hashpos + window_mask_ + 1);
             // Distances are expected to grow along the chain; a smaller one
             // presumably means the chain wrapped around the window.
   1138      if (dist < prev_dist) break;
   1139      prev_dist = dist;
   1140      uint32_t len = 0;
   1141      if (dist > 0) {
   1142        int i = pos;
   1143        int j = pos - dist;
               // Inside a zero run, skip the part that both positions are
               // already known to share.
   1144        if (numzeros > 3) {
   1145          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
   1146          if (i + r >= end) r = end - i - 1;
   1147          i += r;
   1148          j += r;
   1149        }
   1150        while (i < end && data_[i] == data_[j]) {
   1151          i++;
   1152          j++;
   1153        }
   1154        len = i - pos;
   1155        // This can trigger even if the new length is slightly smaller than the
   1156        // best length, because it is possible for a slightly cheaper distance
   1157        // symbol to occur.
   1158        if (len >= min_length_ && len + 2 >= best_len) {
                 // Special distances use their own code; other distances are
                 // placed after the special codes.
   1159          auto it = special_dist_table_.find(dist);
   1160          int dist_symbol = (it == special_dist_table_.end())
   1161                                ? (num_special_distances_ + dist - 1)
   1162                                : it->second;
   1163          found_match(len, dist_symbol);
   1164          if (len > best_len) best_len = len;
   1165        }
   1166      }
   1167 
   1168      chainlength++;
   1169      if (chainlength >= maxchainlength) break;
   1170 
             // Follow the zero-run chain when the match extended past the
             // current zero run, otherwise follow the regular hash chain.
   1171      if (numzeros >= 3 && len > numzeros) {
   1172        if (hashpos == chainz[hashpos]) break;
   1173        hashpos = chainz[hashpos];
   1174        if (zeros[hashpos] != numzeros) break;
   1175      } else {
   1176        if (hashpos == chain[hashpos]) break;
   1177        hashpos = chain[hashpos];
   1178        if (val[hashpos] != static_cast<int>(hashval)) {
   1179          // outdated hash value
   1180          break;
   1181        }
   1182      }
   1183    }
   1184  }
         // Reports the longest match found by FindMatches(); ties are broken
         // in favor of the smaller distance symbol. When nothing is found the
         // outputs stay at length 1 and distance symbol 0.
   1185  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
   1186                 size_t* result_len) const {
   1187    *result_dist_symbol = 0;
   1188    *result_len = 1;
   1189    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
   1190      if (len > *result_len ||
   1191          (len == *result_len && *result_dist_symbol > dist_symbol)) {
   1192        *result_len = len;
   1193        *result_dist_symbol = dist_symbol;
   1194      }
   1195    });
   1196  }
   1197 };
   1198 
   1199 float LenCost(size_t len) {
   1200  uint32_t nbits, bits, tok;
   1201  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
   1202  constexpr float kCostTable[] = {
   1203      2.797667318563126,  3.213177690381199,  2.5706009246743737,
   1204      2.408392498667534,  2.829649191872326,  3.3923087753324577,
   1205      4.029267451554331,  4.415576699706408,  4.509357574741465,
   1206      9.21481543803004,   10.020590190114898, 11.858671627804766,
   1207      12.45853300490526,  11.713105831990857, 12.561996324849314,
   1208      13.775477692278367, 13.174027068768641,
   1209  };
   1210  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
   1211  if (tok >= table_size) tok = table_size - 1;
   1212  return kCostTable[tok] + nbits;
   1213 }
   1214 
   1215 // TODO(veluca): this does not take into account usage or non-usage of distance
   1216 // multipliers.
   1217 float DistCost(size_t dist) {
   1218  uint32_t nbits, bits, tok;
   1219  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
   1220  constexpr float kCostTable[] = {
   1221      6.368282626312716,  5.680793277090298,  8.347404197105247,
   1222      7.641619201599141,  6.914328374119438,  7.959808291537444,
   1223      8.70023120759855,   8.71378518934703,   9.379132523982769,
   1224      9.110472749092708,  9.159029569270908,  9.430936766731973,
   1225      7.278284055315169,  7.8278514904267755, 10.026641158289236,
   1226      9.976049229827066,  9.64351607048908,   9.563403863480442,
   1227      10.171474111762747, 10.45950155077234,  9.994813912104219,
   1228      10.322524683741156, 8.465808729388186,  8.756254166066853,
   1229      10.160930174662234, 10.247329273413435, 10.04090403724809,
   1230      10.129398517544082, 9.342311691539546,  9.07608009102374,
   1231      10.104799540677513, 10.378079384990906, 10.165828974075072,
   1232      10.337595322341553, 7.940557464567944,  10.575665823319431,
   1233      11.023344321751955, 10.736144698831827, 11.118277044595054,
   1234      7.468468230648442,  10.738305230932939, 10.906980780216568,
   1235      10.163468216353817, 10.17805759656433,  11.167283670483565,
   1236      11.147050200274544, 10.517921919244333, 10.651764778156886,
   1237      10.17074446448919,  11.217636876224745, 11.261630721139484,
   1238      11.403140815247259, 10.892472096873417, 11.1859607804481,
   1239      8.017346947551262,  7.895143720278828,  11.036577113822025,
   1240      11.170562110315794, 10.326988722591086, 10.40872184751056,
   1241      11.213498225466386, 11.30580635516863,  10.672272515665442,
   1242      10.768069466228063, 11.145257364153565, 11.64668307145549,
   1243      10.593156194627339, 11.207499484844943, 10.767517766396908,
   1244      10.826629811407042, 10.737764794499988, 10.6200448518045,
   1245      10.191315385198092, 8.468384171390085,  11.731295299170432,
   1246      11.824619886654398, 10.41518844301179,  10.16310536548649,
   1247      10.539423685097576, 10.495136599328031, 10.469112847728267,
   1248      11.72057686174922,  10.910326337834674, 11.378921834673758,
   1249      11.847759036098536, 11.92071647623854,  10.810628276345282,
   1250      11.008601085273893, 11.910326337834674, 11.949212023423133,
   1251      11.298614839104337, 11.611603659010392, 10.472930394619985,
   1252      11.835564720850282, 11.523267392285337, 12.01055816679611,
   1253      8.413029688994023,  11.895784139536406, 11.984679534970505,
   1254      11.220654278717394, 11.716311684833672, 10.61036646226114,
   1255      10.89849965960364,  10.203762898863669, 10.997560826267238,
   1256      11.484217379438984, 11.792836176993665, 12.24310468755171,
   1257      11.464858097919262, 12.212747017409377, 11.425595666074955,
   1258      11.572048533398757, 12.742093965163013, 11.381874288645637,
   1259      12.191870445817015, 11.683156920035426, 11.152442115262197,
   1260      11.90303691580457,  11.653292787169159, 11.938615382266098,
   1261      16.970641701570223, 16.853602280380002, 17.26240782594733,
   1262      16.644655390108507, 17.14310889757499,  16.910935455445955,
   1263      17.505678976959697, 17.213498225466388, 2.4162310293553024,
   1264      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
   1265      3.098390886949687,  3.343454654302911,  3.588847442290287,
   1266      4.14614790111827,   5.152948641990529,  7.433696808092598,
   1267      9.716311684833672,
   1268  };
   1269  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
   1270  if (tok >= table_size) tok = table_size - 1;
   1271  return kCostTable[tok] + nbits;
   1272 }
   1273 
   1274 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
   1275                    const std::vector<std::vector<Token>>& tokens,
   1276                    LZ77Params& lz77,
   1277                    std::vector<std::vector<Token>>& tokens_lz77) {
         // Greedy LZ77 pass with one-step lazy matching. For every stream the
         // tokens are copied to `tokens_lz77`, replacing a match by a
         // (length, distance) token pair whenever the static cost model says
         // the pair is no more expensive than the literals it covers.
         // `lz77.enabled` is set when the total estimated saving is large
         // enough.
   1278  // TODO(veluca): tune heuristics here.
   1279  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
   1280  float bit_decrease = 0;
   1281  size_t total_symbols = 0;
   1282  tokens_lz77.resize(tokens.size());
   1283  HybridUintConfig uint_config;
   1284  std::vector<float> sym_cost;
   1285  for (size_t stream = 0; stream < tokens.size(); stream++) {
           // A missing image width means no special distance codes (0).
   1286    size_t distance_multiplier =
   1287        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
   1288    const auto& in = tokens[stream];
   1289    auto& out = tokens_lz77[stream];
   1290    total_symbols += in.size();
   1291    // Cumulative sum of bit costs.
           // sym_cost[k] is the estimated literal cost of the first k tokens;
           // only differences of entries are used below.
   1292    sym_cost.resize(in.size() + 1);
   1293    for (size_t i = 0; i < in.size(); i++) {
   1294      uint32_t tok, nbits, unused_bits;
   1295      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
   1296      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
   1297    }
   1298 
   1299    out.reserve(in.size());
   1300    size_t max_distance = in.size();
   1301    size_t min_length = lz77.min_length;
   1302    JXL_DASSERT(min_length >= 3);
   1303    size_t max_length = in.size();
   1304 
   1305    // Use next power of two as window size.
   1306    size_t window_size = 1;
   1307    while (window_size < max_distance && window_size < kWindowSize) {
   1308      window_size <<= 1;
   1309    }
   1310 
   1311    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
   1312                    distance_multiplier);
   1313    size_t len;
   1314    size_t dist_symbol;
   1315 
   1316    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
   1317 
   1318    // Whether the next symbol was already updated (to test lazy matching)
   1319    bool already_updated = false;
   1320    for (size_t i = 0; i < in.size(); i++) {
             // The token is pushed as a literal first; if a match is accepted
             // it is turned into the length token in place.
   1321      out.push_back(in[i]);
   1322      if (!already_updated) chain.Update(i);
   1323      already_updated = false;
   1324      chain.FindMatch(i, max_distance, &dist_symbol, &len);
   1325      if (len >= min_length) {
   1326        if (len < max_lazy_match_len && i + 1 < in.size()) {
   1327          // Try length at next symbol lazy matching
   1328          chain.Update(i + 1);
   1329          already_updated = true;
   1330          size_t len2, dist_symbol2;
   1331          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
   1332          if (len2 > len) {
   1333            // Use the lazy match. Add literal, and use the next length starting
   1334            // from the next byte.
   1335            ++i;
   1336            already_updated = false;
   1337            len = len2;
   1338            dist_symbol = dist_symbol2;
   1339            out.push_back(in[i]);
   1340          }
   1341        }
   1342 
               // Literal cost of the matched range versus the static estimate
               // for the length/distance pair plus the per-context penalty.
   1343        float cost = sym_cost[i + len] - sym_cost[i];
   1344        size_t lz77_len = len - lz77.min_length;
   1345        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
   1346                          sce.AddSymbolCost(out.back().context);
   1347 
   1348        if (lz77_cost <= cost) {
   1349          out.back().value = len - min_length;
   1350          out.back().is_lz77_length = true;
   1351          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
   1352          bit_decrease += cost - lz77_cost;
   1353        } else {
   1354          // LZ77 match ignored, and symbol already pushed. Push all other
   1355          // symbols and skip.
   1356          for (size_t j = 1; j < len; j++) {
   1357            out.push_back(in[i + j]);
   1358          }
   1359        }
   1360 
               // Insert the remaining covered positions into the hash chain;
               // position i + 1 is skipped when the rejected lazy probe has
               // already inserted it.
   1361        if (already_updated) {
   1362          chain.Update(i + 2, len - 2);
   1363          already_updated = false;
   1364        } else {
   1365          chain.Update(i + 1, len - 1);
   1366        }
   1367        i += len - 1;
   1368      } else {
   1369        // Literal, already pushed
   1370      }
   1371    }
   1372  }
   1373 
   1374  if (bit_decrease > total_symbols * 0.2 + 16) {
   1375    lz77.enabled = true;
   1376  }
   1377 }
   1378 
   1379 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
   1380                       const std::vector<std::vector<Token>>& tokens,
   1381                       LZ77Params& lz77,
   1382                       std::vector<std::vector<Token>>& tokens_lz77) {
   1383  std::vector<std::vector<Token>> tokens_for_cost_estimate;
   1384  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
   1385  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
   1386  // run the optimal matching.
   1387  if (!lz77.enabled) return;
   1388  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
   1389                          tokens_for_cost_estimate, lz77);
   1390  tokens_lz77.resize(tokens.size());
   1391  HybridUintConfig uint_config;
   1392  std::vector<float> sym_cost;
   1393  std::vector<uint32_t> dist_symbols;
   1394  for (size_t stream = 0; stream < tokens.size(); stream++) {
   1395    size_t distance_multiplier =
   1396        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
   1397    const auto& in = tokens[stream];
   1398    auto& out = tokens_lz77[stream];
   1399    // Cumulative sum of bit costs.
   1400    sym_cost.resize(in.size() + 1);
   1401    for (size_t i = 0; i < in.size(); i++) {
   1402      uint32_t tok, nbits, unused_bits;
   1403      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
   1404      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
   1405    }
   1406 
   1407    out.reserve(in.size());
   1408    size_t max_distance = in.size();
   1409    size_t min_length = lz77.min_length;
   1410    JXL_DASSERT(min_length >= 3);
   1411    size_t max_length = in.size();
   1412 
   1413    // Use next power of two as window size.
   1414    size_t window_size = 1;
   1415    while (window_size < max_distance && window_size < kWindowSize) {
   1416      window_size <<= 1;
   1417    }
   1418 
   1419    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
   1420                    distance_multiplier);
   1421 
   1422    struct MatchInfo {
   1423      uint32_t len;
   1424      uint32_t dist_symbol;
   1425      uint32_t ctx;
   1426      float total_cost = std::numeric_limits<float>::max();
   1427    };
   1428    // Total cost to encode the first N symbols.
   1429    std::vector<MatchInfo> prefix_costs(in.size() + 1);
   1430    prefix_costs[0].total_cost = 0;
   1431 
   1432    size_t rle_length = 0;
   1433    size_t skip_lz77 = 0;
   1434    for (size_t i = 0; i < in.size(); i++) {
   1435      chain.Update(i);
   1436      float lit_cost =
   1437          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
   1438      if (prefix_costs[i + 1].total_cost > lit_cost) {
   1439        prefix_costs[i + 1].dist_symbol = 0;
   1440        prefix_costs[i + 1].len = 1;
   1441        prefix_costs[i + 1].ctx = in[i].context;
   1442        prefix_costs[i + 1].total_cost = lit_cost;
   1443      }
   1444      if (skip_lz77 > 0) {
   1445        skip_lz77--;
   1446        continue;
   1447      }
   1448      dist_symbols.clear();
   1449      chain.FindMatches(i, max_distance,
   1450                        [&dist_symbols](size_t len, size_t dist_symbol) {
   1451                          if (dist_symbols.size() <= len) {
   1452                            dist_symbols.resize(len + 1, dist_symbol);
   1453                          }
   1454                          if (dist_symbol < dist_symbols[len]) {
   1455                            dist_symbols[len] = dist_symbol;
   1456                          }
   1457                        });
   1458      if (dist_symbols.size() <= min_length) continue;
   1459      {
   1460        size_t best_cost = dist_symbols.back();
   1461        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
   1462          if (dist_symbols[j] < best_cost) {
   1463            best_cost = dist_symbols[j];
   1464          }
   1465          dist_symbols[j] = best_cost;
   1466        }
   1467      }
   1468      for (size_t j = min_length; j < dist_symbols.size(); j++) {
   1469        // Cost model that uses results from lazy LZ77.
   1470        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
   1471                          sce.DistCost(dist_symbols[j], lz77);
   1472        float cost = prefix_costs[i].total_cost + lz77_cost;
   1473        if (prefix_costs[i + j].total_cost > cost) {
   1474          prefix_costs[i + j].len = j;
   1475          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
   1476          prefix_costs[i + j].ctx = in[i].context;
   1477          prefix_costs[i + j].total_cost = cost;
   1478        }
   1479      }
   1480      // We are in an RLE sequence: skip all the symbols except the first 8 and
   1481      // the last 8. This avoids quadratic costs for sequences with long runs of
   1482      // the same symbol.
   1483      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
   1484          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
   1485        rle_length++;
   1486      } else {
   1487        rle_length = 0;
   1488      }
   1489      if (rle_length >= 8 && dist_symbols.size() > 9) {
   1490        skip_lz77 = dist_symbols.size() - 10;
   1491        rle_length = 0;
   1492      }
   1493    }
   1494    size_t pos = in.size();
   1495    while (pos > 0) {
   1496      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
   1497      if (is_lz77_length) {
   1498        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
   1499        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
   1500      }
   1501      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
   1502                                  : in[pos - 1].value;
   1503      out.emplace_back(prefix_costs[pos].ctx, val);
   1504      out.back().is_lz77_length = is_lz77_length;
   1505      pos -= prefix_costs[pos].len;
   1506    }
   1507    std::reverse(out.begin(), out.end());
   1508  }
   1509 }
   1510 
   1511 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
   1512               const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
   1513               std::vector<std::vector<Token>>& tokens_lz77) {
   1514  if (params.initialize_global_state) {
   1515    lz77.enabled = false;
   1516  }
   1517  if (params.force_huffman) {
   1518    lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
   1519  } else {
   1520    lz77.min_symbol = 224;
   1521  }
   1522  switch (params.lz77_method) {
   1523    case HistogramParams::LZ77Method::kNone:
   1524      return;
   1525    case HistogramParams::LZ77Method::kRLE:
   1526      ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
   1527      return;
   1528    case HistogramParams::LZ77Method::kLZ77:
   1529      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
   1530      return;
   1531    case HistogramParams::LZ77Method::kOptimal:
   1532      ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
   1533      return;
   1534  }
   1535 }
   1536 }  // namespace
   1537 
   1538 Status EncodeHistograms(const std::vector<uint8_t>& context_map,
   1539                        const EntropyEncodingData& codes, BitWriter* writer,
   1540                        LayerType layer, AuxOut* aux_out) {
   1541  return writer->WithMaxBits(
   1542      128 + kClustersLimit * 136, layer, aux_out,
   1543      [&]() -> Status {
   1544        JXL_RETURN_IF_ERROR(Bundle::Write(codes.lz77, writer, layer, aux_out));
   1545        if (codes.lz77.enabled) {
   1546          EncodeUintConfig(codes.lz77.length_uint_config, writer,
   1547                           /*log_alpha_size=*/8);
   1548        }
   1549        JXL_RETURN_IF_ERROR(EncodeContextMap(
   1550            context_map, codes.encoding_info.size(), writer, layer, aux_out));
   1551        writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
   1552        size_t log_alpha_size = 8;
   1553        if (codes.use_prefix_code) {
   1554          log_alpha_size = PREFIX_MAX_BITS;
   1555        } else {
   1556          log_alpha_size = 8;  // streaming_mode
   1557          writer->Write(2, log_alpha_size - 5);
   1558        }
   1559        EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);
   1560        if (codes.use_prefix_code) {
   1561          for (const auto& info : codes.encoding_info) {
   1562            StoreVarLenUint16(info.size() - 1, writer);
   1563          }
   1564        }
   1565        for (const auto& histo_writer : codes.encoded_histograms) {
   1566          JXL_RETURN_IF_ERROR(writer->AppendUnaligned(histo_writer));
   1567        }
   1568        return true;
   1569      },
   1570      /*finished_histogram=*/true);
   1571 }
   1572 
   1573 StatusOr<size_t> BuildAndEncodeHistograms(
   1574    JxlMemoryManager* memory_manager, const HistogramParams& params,
   1575    size_t num_contexts, std::vector<std::vector<Token>>& tokens,
   1576    EntropyEncodingData* codes, std::vector<uint8_t>* context_map,
   1577    BitWriter* writer, LayerType layer, AuxOut* aux_out) {
   1578  size_t cost = 0;
   1579  codes->lz77.nonserialized_distance_context = num_contexts;
   1580  std::vector<std::vector<Token>> tokens_lz77;
   1581  ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
   1582  if (ans_fuzzer_friendly_) {
   1583    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
   1584    codes->lz77.min_symbol = 2048;
   1585  }
   1586 
   1587  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
   1588  const auto& body = [&]() -> Status {
   1589    if (writer) {
   1590      JXL_RETURN_IF_ERROR(Bundle::Write(codes->lz77, writer, layer, aux_out));
   1591    } else {
   1592      size_t ebits, bits;
   1593      JXL_RETURN_IF_ERROR(Bundle::CanEncode(codes->lz77, &ebits, &bits));
   1594      cost += bits;
   1595    }
   1596    if (codes->lz77.enabled) {
   1597      if (writer) {
   1598        size_t b = writer->BitsWritten();
   1599        EncodeUintConfig(codes->lz77.length_uint_config, writer,
   1600                         /*log_alpha_size=*/8);
   1601        cost += writer->BitsWritten() - b;
   1602      } else {
   1603        SizeWriter size_writer;
   1604        EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
   1605                         /*log_alpha_size=*/8);
   1606        cost += size_writer.size;
   1607      }
   1608      num_contexts += 1;
   1609      tokens = std::move(tokens_lz77);
   1610    }
   1611    size_t total_tokens = 0;
   1612    // Build histograms.
   1613    HistogramBuilder builder(num_contexts);
   1614    HybridUintConfig uint_config;  //  Default config for clustering.
   1615    // Unless we are using the kContextMap histogram option.
   1616    if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
   1617      uint_config = HybridUintConfig(2, 0, 1);
   1618    }
   1619    if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
   1620      uint_config = HybridUintConfig(0, 0, 0);
   1621    }
   1622    if (ans_fuzzer_friendly_) {
   1623      uint_config = HybridUintConfig(10, 0, 0);
   1624    }
   1625    for (const auto& stream : tokens) {
   1626      if (codes->lz77.enabled) {
   1627        for (const auto& token : stream) {
   1628          total_tokens++;
   1629          uint32_t tok, nbits, bits;
   1630          (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
   1631              .Encode(token.value, &tok, &nbits, &bits);
   1632          tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
   1633          builder.VisitSymbol(tok, token.context);
   1634        }
   1635      } else if (num_contexts == 1) {
   1636        for (const auto& token : stream) {
   1637          total_tokens++;
   1638          uint32_t tok, nbits, bits;
   1639          uint_config.Encode(token.value, &tok, &nbits, &bits);
   1640          builder.VisitSymbol(tok, /*token.context=*/0);
   1641        }
   1642      } else {
   1643        for (const auto& token : stream) {
   1644          total_tokens++;
   1645          uint32_t tok, nbits, bits;
   1646          uint_config.Encode(token.value, &tok, &nbits, &bits);
   1647          builder.VisitSymbol(tok, token.context);
   1648        }
   1649      }
   1650    }
   1651 
   1652    if (params.add_missing_symbols) {
   1653      for (size_t c = 0; c < num_contexts; ++c) {
   1654        for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
   1655          builder.VisitSymbol(symbol, c);
   1656        }
   1657      }
   1658    }
   1659 
   1660    if (params.initialize_global_state) {
   1661      bool use_prefix_code =
   1662          params.force_huffman || total_tokens < 100 ||
   1663          params.clustering == HistogramParams::ClusteringType::kFastest ||
   1664          ans_fuzzer_friendly_;
   1665      if (!use_prefix_code) {
   1666        bool all_singleton = true;
   1667        for (size_t i = 0; i < num_contexts; i++) {
   1668          if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
   1669            all_singleton = false;
   1670          }
   1671        }
   1672        if (all_singleton) {
   1673          use_prefix_code = true;
   1674        }
   1675      }
   1676      codes->use_prefix_code = use_prefix_code;
   1677    }
   1678 
   1679    if (params.add_fixed_histograms) {
   1680      // TODO(szabadka) Add more fixed histograms.
   1681      // TODO(szabadka) Reduce alphabet size by choosing a non-default
   1682      // uint_config.
   1683      const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
   1684      const size_t log_alpha_size = 8;
   1685      JXL_ENSURE(alphabet_size == 1u << log_alpha_size);
   1686      static_assert(ANS_MAX_ALPHABET_SIZE <= ANS_TAB_SIZE);
   1687      std::vector<int32_t> counts =
   1688          CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
   1689      codes->encoding_info.emplace_back();
   1690      codes->encoding_info.back().resize(alphabet_size);
   1691      codes->encoded_histograms.emplace_back(memory_manager);
   1692      BitWriter* histo_writer = &codes->encoded_histograms.back();
   1693      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
   1694          256 + alphabet_size * 24, LayerType::Header, nullptr,
   1695          [&]() -> Status {
   1696            JXL_ASSIGN_OR_RETURN(
   1697                size_t ans_cost,
   1698                BuildAndStoreANSEncodingData(
   1699                    memory_manager, params.ans_histogram_strategy,
   1700                    counts.data(), alphabet_size, log_alpha_size,
   1701                    codes->use_prefix_code, codes->encoding_info.back().data(),
   1702                    histo_writer));
   1703            (void)ans_cost;
   1704            return true;
   1705          }));
   1706    }
   1707 
   1708    // Encode histograms.
   1709    JXL_ASSIGN_OR_RETURN(
   1710        size_t entropy_bits,
   1711        builder.BuildAndStoreEntropyCodes(memory_manager, params, tokens, codes,
   1712                                          context_map, writer, layer, aux_out));
   1713    cost += entropy_bits;
   1714    return true;
   1715  };
   1716  if (writer) {
   1717    JXL_RETURN_IF_ERROR(writer->WithMaxBits(
   1718        128 + num_contexts * 40 + max_contexts * 96, layer, aux_out, body,
   1719        /*finished_histogram=*/true));
   1720  } else {
   1721    JXL_RETURN_IF_ERROR(body());
   1722  }
   1723 
   1724  if (aux_out != nullptr) {
   1725    aux_out->layer(layer).num_clustered_histograms +=
   1726        codes->encoding_info.size();
   1727  }
   1728  return cost;
   1729 }
   1730 
   1731 size_t WriteTokens(const std::vector<Token>& tokens,
   1732                   const EntropyEncodingData& codes,
   1733                   const std::vector<uint8_t>& context_map,
   1734                   size_t context_offset, BitWriter* writer) {
   1735  size_t num_extra_bits = 0;
   1736  if (codes.use_prefix_code) {
   1737    for (const auto& token : tokens) {
   1738      uint32_t tok, nbits, bits;
   1739      size_t histo = context_map[context_offset + token.context];
   1740      (token.is_lz77_length ? codes.lz77.length_uint_config
   1741                            : codes.uint_config[histo])
   1742          .Encode(token.value, &tok, &nbits, &bits);
   1743      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
   1744      // Combine two calls to the BitWriter. Equivalent to:
   1745      // writer->Write(codes.encoding_info[histo][tok].depth,
   1746      //               codes.encoding_info[histo][tok].bits);
   1747      // writer->Write(nbits, bits);
   1748      uint64_t data = codes.encoding_info[histo][tok].bits;
   1749      data |= static_cast<uint64_t>(bits)
   1750              << codes.encoding_info[histo][tok].depth;
   1751      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
   1752      num_extra_bits += nbits;
   1753    }
   1754    return num_extra_bits;
   1755  }
   1756  std::vector<uint64_t> out;
   1757  std::vector<uint8_t> out_nbits;
   1758  out.reserve(tokens.size());
   1759  out_nbits.reserve(tokens.size());
   1760  uint64_t allbits = 0;
   1761  size_t numallbits = 0;
   1762  // Writes in *reversed* order.
   1763  auto addbits = [&](size_t bits, size_t nbits) {
   1764    if (JXL_UNLIKELY(nbits)) {
   1765      JXL_DASSERT(bits >> nbits == 0);
   1766      if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
   1767        out.push_back(allbits);
   1768        out_nbits.push_back(numallbits);
   1769        numallbits = allbits = 0;
   1770      }
   1771      allbits <<= nbits;
   1772      allbits |= bits;
   1773      numallbits += nbits;
   1774    }
   1775  };
   1776  const int end = tokens.size();
   1777  ANSCoder ans;
   1778  if (codes.lz77.enabled || context_map.size() > 1) {
   1779    for (int i = end - 1; i >= 0; --i) {
   1780      const Token token = tokens[i];
   1781      const uint8_t histo = context_map[context_offset + token.context];
   1782      uint32_t tok, nbits, bits;
   1783      (token.is_lz77_length ? codes.lz77.length_uint_config
   1784                            : codes.uint_config[histo])
   1785          .Encode(tokens[i].value, &tok, &nbits, &bits);
   1786      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
   1787      const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
   1788      JXL_DASSERT(info.freq_ > 0);
   1789      // Extra bits first as this is reversed.
   1790      addbits(bits, nbits);
   1791      num_extra_bits += nbits;
   1792      uint8_t ans_nbits = 0;
   1793      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
   1794      addbits(ans_bits, ans_nbits);
   1795    }
   1796  } else {
   1797    for (int i = end - 1; i >= 0; --i) {
   1798      uint32_t tok, nbits, bits;
   1799      codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
   1800      const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
   1801      // Extra bits first as this is reversed.
   1802      addbits(bits, nbits);
   1803      num_extra_bits += nbits;
   1804      uint8_t ans_nbits = 0;
   1805      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
   1806      addbits(ans_bits, ans_nbits);
   1807    }
   1808  }
   1809  const uint32_t state = ans.GetState();
   1810  writer->Write(32, state);
   1811  writer->Write(numallbits, allbits);
   1812  for (int i = out.size(); i > 0; --i) {
   1813    writer->Write(out_nbits[i - 1], out[i - 1]);
   1814  }
   1815  return num_extra_bits;
   1816 }
   1817 
   1818 Status WriteTokens(const std::vector<Token>& tokens,
   1819                   const EntropyEncodingData& codes,
   1820                   const std::vector<uint8_t>& context_map,
   1821                   size_t context_offset, BitWriter* writer, LayerType layer,
   1822                   AuxOut* aux_out) {
   1823  // Theoretically, we could have 15 prefix code bits + 31 extra bits.
   1824  return writer->WithMaxBits(
   1825      46 * tokens.size() + 32 * 1024 * 4, layer, aux_out, [&] {
   1826        size_t num_extra_bits =
   1827            WriteTokens(tokens, codes, context_map, context_offset, writer);
   1828        if (aux_out != nullptr) {
   1829          aux_out->layer(layer).extra_bits += num_extra_bits;
   1830        }
   1831        return true;
   1832      });
   1833 }
   1834 
   1835 void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
   1836 #if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
   1837  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
   1838 #endif
   1839 }
   1840 
   1841 HistogramParams HistogramParams::ForModular(
   1842    const CompressParams& cparams,
   1843    const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
   1844  HistogramParams params;
   1845  params.streaming_mode = streaming_mode;
   1846  if (cparams.speed_tier > SpeedTier::kKitten) {
   1847    params.clustering = HistogramParams::ClusteringType::kFast;
   1848    params.ans_histogram_strategy =
   1849        cparams.speed_tier > SpeedTier::kThunder
   1850            ? HistogramParams::ANSHistogramStrategy::kFast
   1851            : HistogramParams::ANSHistogramStrategy::kApproximate;
   1852    params.lz77_method =
   1853        cparams.decoding_speed_tier >= 3 && cparams.modular_mode
   1854            ? (cparams.speed_tier >= SpeedTier::kFalcon
   1855                   ? HistogramParams::LZ77Method::kRLE
   1856                   : HistogramParams::LZ77Method::kLZ77)
   1857            : HistogramParams::LZ77Method::kNone;
   1858    // Near-lossless DC, as well as modular mode, require choosing hybrid uint
   1859    // more carefully.
   1860    if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
   1861        (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
   1862      params.uint_method = HistogramParams::HybridUintMethod::kFast;
   1863    } else {
   1864      params.uint_method = HistogramParams::HybridUintMethod::kNone;
   1865    }
   1866  } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
   1867    params.lz77_method = HistogramParams::LZ77Method::kOptimal;
   1868  } else {
   1869    params.lz77_method = HistogramParams::LZ77Method::kLZ77;
   1870  }
   1871  if (cparams.decoding_speed_tier >= 1) {
   1872    params.max_histograms = 12;
   1873  }
   1874  if (cparams.decoding_speed_tier >= 1 && cparams.responsive) {
   1875    params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah
   1876                             ? HistogramParams::LZ77Method::kRLE
   1877                         : cparams.speed_tier >= SpeedTier::kKitten
   1878                             ? HistogramParams::LZ77Method::kLZ77
   1879                             : HistogramParams::LZ77Method::kOptimal;
   1880  }
   1881  if (cparams.decoding_speed_tier >= 2 && cparams.responsive) {
   1882    params.uint_method = HistogramParams::HybridUintMethod::k000;
   1883    params.force_huffman = true;
   1884  }
   1885  return params;
   1886 }
   1887 }  // namespace jxl