tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dec_ans.cc (15861B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_ans.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <cstdint>
     11 #include <vector>
     12 
     13 #include "lib/jxl/ans_common.h"
     14 #include "lib/jxl/ans_params.h"
     15 #include "lib/jxl/base/bits.h"
     16 #include "lib/jxl/base/printf_macros.h"
     17 #include "lib/jxl/base/status.h"
     18 #include "lib/jxl/dec_context_map.h"
     19 #include "lib/jxl/fields.h"
     20 #include "lib/jxl/memory_manager_internal.h"
     21 
     22 namespace jxl {
     23 namespace {
     24 
     25 // Decodes a number in the range [0..255], by reading 1 - 11 bits.
     26 inline int DecodeVarLenUint8(BitReader* input) {
     27  if (input->ReadFixedBits<1>()) {
     28    int nbits = static_cast<int>(input->ReadFixedBits<3>());
     29    if (nbits == 0) {
     30      return 1;
     31    } else {
     32      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
     33    }
     34  }
     35  return 0;
     36 }
     37 
     38 // Decodes a number in the range [0..65535], by reading 1 - 21 bits.
     39 inline int DecodeVarLenUint16(BitReader* input) {
     40  if (input->ReadFixedBits<1>()) {
     41    int nbits = static_cast<int>(input->ReadFixedBits<4>());
     42    if (nbits == 0) {
     43      return 1;
     44    } else {
     45      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
     46    }
     47  }
     48  return 0;
     49 }
     50 
     51 Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts,
     52                     BitReader* input) {
     53  int range = 1 << precision_bits;
     54  int simple_code = input->ReadBits(1);
     55  if (simple_code == 1) {
     56    int i;
     57    int symbols[2] = {0};
     58    int max_symbol = 0;
     59    const int num_symbols = input->ReadBits(1) + 1;
     60    for (i = 0; i < num_symbols; ++i) {
     61      symbols[i] = DecodeVarLenUint8(input);
     62      if (symbols[i] > max_symbol) max_symbol = symbols[i];
     63    }
     64    counts->resize(max_symbol + 1);
     65    if (num_symbols == 1) {
     66      (*counts)[symbols[0]] = range;
     67    } else {
     68      if (symbols[0] == symbols[1]) {  // corrupt data
     69        return false;
     70      }
     71      (*counts)[symbols[0]] = input->ReadBits(precision_bits);
     72      (*counts)[symbols[1]] = range - (*counts)[symbols[0]];
     73    }
     74  } else {
     75    int is_flat = input->ReadBits(1);
     76    if (is_flat == 1) {
     77      int alphabet_size = DecodeVarLenUint8(input) + 1;
     78      JXL_ENSURE(alphabet_size <= range);
     79      *counts = CreateFlatHistogram(alphabet_size, range);
     80      return true;
     81    }
     82 
     83    uint32_t shift;
     84    {
     85      // TODO(veluca): speed up reading with table lookups.
     86      int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
     87      int log = 0;
     88      for (; log < upper_bound_log; log++) {
     89        if (input->ReadFixedBits<1>() == 0) break;
     90      }
     91      shift = (input->ReadBits(log) | (1 << log)) - 1;
     92      if (shift > ANS_LOG_TAB_SIZE + 1) {
     93        return JXL_FAILURE("Invalid shift value");
     94      }
     95    }
     96 
     97    int length = DecodeVarLenUint8(input) + 3;
     98    counts->resize(length);
     99    int total_count = 0;
    100 
    101    static const uint8_t huff[128][2] = {
    102        {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    103        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    104        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    105        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    106        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    107        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    108        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    109        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    110        {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    111        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    112        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    113        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    114        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    115        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    116        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    117        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    118    };
    119 
    120    std::vector<int> logcounts(counts->size());
    121    int omit_log = -1;
    122    int omit_pos = -1;
    123    // This array remembers which symbols have an RLE length.
    124    std::vector<int> same(counts->size(), 0);
    125    for (size_t i = 0; i < logcounts.size(); ++i) {
    126      input->Refill();  // for PeekFixedBits + Advance
    127      int idx = input->PeekFixedBits<7>();
    128      input->Consume(huff[idx][0]);
    129      logcounts[i] = huff[idx][1];
    130      // The RLE symbol.
    131      if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) {
    132        int rle_length = DecodeVarLenUint8(input);
    133        same[i] = rle_length + 5;
    134        i += rle_length + 3;
    135        continue;
    136      }
    137      if (logcounts[i] > omit_log) {
    138        omit_log = logcounts[i];
    139        omit_pos = i;
    140      }
    141    }
    142    // Invalid input, e.g. due to invalid usage of RLE.
    143    if (omit_pos < 0) return JXL_FAILURE("Invalid histogram.");
    144    if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() &&
    145        logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) {
    146      return JXL_FAILURE("Invalid histogram.");
    147    }
    148    int prev = 0;
    149    int numsame = 0;
    150    for (size_t i = 0; i < logcounts.size(); ++i) {
    151      if (same[i]) {
    152        // RLE sequence, let this loop output the same count for the next
    153        // iterations.
    154        numsame = same[i] - 1;
    155        prev = i > 0 ? (*counts)[i - 1] : 0;
    156      }
    157      if (numsame > 0) {
    158        (*counts)[i] = prev;
    159        numsame--;
    160      } else {
    161        unsigned int code = logcounts[i];
    162        // omit_pos may not be negative at this point (checked before).
    163        if (i == static_cast<size_t>(omit_pos)) {
    164          continue;
    165        } else if (code == 0) {
    166          continue;
    167        } else if (code == 1) {
    168          (*counts)[i] = 1;
    169        } else {
    170          int bitcount = GetPopulationCountPrecision(code - 1, shift);
    171          (*counts)[i] = (1u << (code - 1)) +
    172                         (input->ReadBits(bitcount) << (code - 1 - bitcount));
    173        }
    174      }
    175      total_count += (*counts)[i];
    176    }
    177    (*counts)[omit_pos] = range - total_count;
    178    if ((*counts)[omit_pos] <= 0) {
    179      // The histogram we've read sums to more than total_count (including at
    180      // least 1 for the omitted value).
    181      return JXL_FAILURE("Invalid histogram count.");
    182    }
    183  }
    184  return true;
    185 }
    186 
    187 }  // namespace
    188 
    189 Status DecodeANSCodes(JxlMemoryManager* memory_manager,
    190                      const size_t num_histograms,
    191                      const size_t max_alphabet_size, BitReader* in,
    192                      ANSCode* result) {
    193  result->memory_manager = memory_manager;
    194  result->degenerate_symbols.resize(num_histograms, -1);
    195  if (result->use_prefix_code) {
    196    JXL_ENSURE(max_alphabet_size <= 1 << PREFIX_MAX_BITS);
    197    result->huffman_data.resize(num_histograms);
    198    std::vector<uint16_t> alphabet_sizes(num_histograms);
    199    for (size_t c = 0; c < num_histograms; c++) {
    200      alphabet_sizes[c] = DecodeVarLenUint16(in) + 1;
    201      if (alphabet_sizes[c] > max_alphabet_size) {
    202        return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]);
    203      }
    204    }
    205    for (size_t c = 0; c < num_histograms; c++) {
    206      if (alphabet_sizes[c] > 1) {
    207        if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) {
    208          if (!in->AllReadsWithinBounds()) {
    209            return JXL_STATUS(StatusCode::kNotEnoughBytes,
    210                              "Not enough bytes for huffman code");
    211          }
    212          return JXL_FAILURE("Invalid huffman tree number %" PRIuS
    213                             ", alphabet size %u",
    214                             c, alphabet_sizes[c]);
    215        }
    216      } else {
    217        // 0-bit codes does not require extension tables.
    218        result->huffman_data[c].table_.clear();
    219        result->huffman_data[c].table_.resize(1u << kHuffmanTableBits);
    220      }
    221      for (const auto& h : result->huffman_data[c].table_) {
    222        if (h.bits <= kHuffmanTableBits) {
    223          result->UpdateMaxNumBits(c, h.value);
    224        }
    225      }
    226    }
    227  } else {
    228    JXL_ENSURE(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE);
    229    size_t alloc_size = num_histograms * (1 << result->log_alpha_size) *
    230                        sizeof(AliasTable::Entry);
    231    JXL_ASSIGN_OR_RETURN(result->alias_tables,
    232                         AlignedMemory::Create(memory_manager, alloc_size));
    233    AliasTable::Entry* alias_tables =
    234        result->alias_tables.address<AliasTable::Entry>();
    235    for (size_t c = 0; c < num_histograms; ++c) {
    236      std::vector<int32_t> counts;
    237      if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) {
    238        return JXL_FAILURE("Invalid histogram bitstream.");
    239      }
    240      if (counts.size() > max_alphabet_size) {
    241        return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size());
    242      }
    243      while (!counts.empty() && counts.back() == 0) {
    244        counts.pop_back();
    245      }
    246      for (size_t s = 0; s < counts.size(); s++) {
    247        if (counts[s] != 0) {
    248          result->UpdateMaxNumBits(c, s);
    249        }
    250      }
    251      // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol.
    252      int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1);
    253      for (int s = 0; s < degenerate_symbol; ++s) {
    254        if (counts[s] != 0) {
    255          degenerate_symbol = -1;
    256          break;
    257        }
    258      }
    259      result->degenerate_symbols[c] = degenerate_symbol;
    260      JXL_RETURN_IF_ERROR(
    261          InitAliasTable(counts, ANS_LOG_TAB_SIZE, result->log_alpha_size,
    262                         alias_tables + c * (1 << result->log_alpha_size)));
    263    }
    264  }
    265  return true;
    266 }
    267 Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config,
    268                        BitReader* br) {
    269  br->Refill();
    270  size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1));
    271  size_t msb_in_token = 0;
    272  size_t lsb_in_token = 0;
    273  if (split_exponent != log_alpha_size) {
    274    // otherwise, msb/lsb don't matter.
    275    size_t nbits = CeilLog2Nonzero(split_exponent + 1);
    276    msb_in_token = br->ReadBits(nbits);
    277    if (msb_in_token > split_exponent) {
    278      // This could be invalid here already and we need to check this before
    279      // we use its value to read more bits.
    280      return JXL_FAILURE("Invalid HybridUintConfig");
    281    }
    282    nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1);
    283    lsb_in_token = br->ReadBits(nbits);
    284  }
    285  if (lsb_in_token + msb_in_token > split_exponent) {
    286    return JXL_FAILURE("Invalid HybridUintConfig");
    287  }
    288  *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token);
    289  return true;
    290 }
    291 
    292 Status DecodeUintConfigs(size_t log_alpha_size,
    293                         std::vector<HybridUintConfig>* uint_config,
    294                         BitReader* br) {
    295  // TODO(veluca): RLE?
    296  for (auto& cfg : *uint_config) {
    297    JXL_RETURN_IF_ERROR(DecodeUintConfig(log_alpha_size, &cfg, br));
    298  }
    299  return true;
    300 }
    301 
    302 LZ77Params::LZ77Params() { Bundle::Init(this); }
    303 Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) {
    304  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled));
    305  if (!visitor->Conditional(enabled)) return true;
    306  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096),
    307                                         BitsOffset(15, 8), 224, &min_symbol));
    308  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5),
    309                                         BitsOffset(8, 9), 3, &min_length));
    310  return true;
    311 }
    312 
    313 void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) {
    314  HybridUintConfig* cfg = &uint_config[ctx];
    315  // LZ77 symbols use a different uint config.
    316  if (lz77.enabled && lz77.nonserialized_distance_context != ctx &&
    317      symbol >= lz77.min_symbol) {
    318    symbol -= lz77.min_symbol;
    319    cfg = &lz77.length_uint_config;
    320  }
    321  size_t split_token = cfg->split_token;
    322  size_t msb_in_token = cfg->msb_in_token;
    323  size_t lsb_in_token = cfg->lsb_in_token;
    324  size_t split_exponent = cfg->split_exponent;
    325  if (symbol < split_token) {
    326    max_num_bits = std::max(max_num_bits, split_exponent);
    327    return;
    328  }
    329  uint32_t n_extra_bits =
    330      split_exponent - (msb_in_token + lsb_in_token) +
    331      ((symbol - split_token) >> (msb_in_token + lsb_in_token));
    332  size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1;
    333  max_num_bits = std::max(max_num_bits, total_bits);
    334 }
    335 
    336 Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br,
    337                        size_t num_contexts, ANSCode* code,
    338                        std::vector<uint8_t>* context_map, bool disallow_lz77) {
    339  JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77));
    340  if (code->lz77.enabled) {
    341    num_contexts++;
    342    JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8,
    343                                         &code->lz77.length_uint_config, br));
    344  }
    345  if (code->lz77.enabled && disallow_lz77) {
    346    return JXL_FAILURE("Using LZ77 when explicitly disallowed");
    347  }
    348  size_t num_histograms = 1;
    349  context_map->resize(num_contexts);
    350  if (num_contexts > 1) {
    351    JXL_RETURN_IF_ERROR(
    352        DecodeContextMap(memory_manager, context_map, &num_histograms, br));
    353  }
    354  JXL_DEBUG_V(
    355      4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms",
    356      num_contexts, num_histograms);
    357  code->lz77.nonserialized_distance_context = context_map->back();
    358  code->use_prefix_code = static_cast<bool>(br->ReadFixedBits<1>());
    359  if (code->use_prefix_code) {
    360    code->log_alpha_size = PREFIX_MAX_BITS;
    361  } else {
    362    code->log_alpha_size = br->ReadFixedBits<2>() + 5;
    363  }
    364  code->uint_config.resize(num_histograms);
    365  JXL_RETURN_IF_ERROR(
    366      DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br));
    367  const size_t max_alphabet_size = 1 << code->log_alpha_size;
    368  JXL_RETURN_IF_ERROR(DecodeANSCodes(memory_manager, num_histograms,
    369                                     max_alphabet_size, br, code));
    370  return true;
    371 }
    372 
    373 StatusOr<ANSSymbolReader> ANSSymbolReader::Create(const ANSCode* code,
    374                                                  BitReader* JXL_RESTRICT br,
    375                                                  size_t distance_multiplier) {
    376  AlignedMemory lz77_window_storage;
    377  if (code->lz77.enabled) {
    378    JxlMemoryManager* memory_manager = code->memory_manager;
    379    JXL_ASSIGN_OR_RETURN(
    380        lz77_window_storage,
    381        AlignedMemory::Create(memory_manager, kWindowSize * sizeof(uint32_t)));
    382  }
    383  return ANSSymbolReader(code, br, distance_multiplier,
    384                         std::move(lz77_window_storage));
    385 }
    386 
    387 ANSSymbolReader::ANSSymbolReader(const ANSCode* code,
    388                                 BitReader* JXL_RESTRICT br,
    389                                 size_t distance_multiplier,
    390                                 AlignedMemory&& lz77_window_storage)
    391    : alias_tables_(code->alias_tables.address<AliasTable::Entry>()),
    392      huffman_data_(code->huffman_data.data()),
    393      use_prefix_code_(code->use_prefix_code),
    394      configs(code->uint_config.data()),
    395      lz77_window_storage_(std::move(lz77_window_storage)) {
    396  if (!use_prefix_code_) {
    397    state_ = static_cast<uint32_t>(br->ReadFixedBits<32>());
    398    log_alpha_size_ = code->log_alpha_size;
    399    log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size;
    400    entry_size_minus_1_ = (1 << log_entry_size_) - 1;
    401  } else {
    402    state_ = (ANS_SIGNATURE << 16u);
    403  }
    404  if (!code->lz77.enabled) return;
    405  lz77_window_ = lz77_window_storage_.address<uint32_t>();
    406  lz77_ctx_ = code->lz77.nonserialized_distance_context;
    407  lz77_length_uint_ = code->lz77.length_uint_config;
    408  lz77_threshold_ = code->lz77.min_symbol;
    409  lz77_min_length_ = code->lz77.min_length;
    410  num_special_distances_ = distance_multiplier == 0 ? 0 : kNumSpecialDistances;
    411  for (size_t i = 0; i < num_special_distances_; i++) {
    412    special_distances_[i] = SpecialDistance(i, distance_multiplier);
    413  }
    414 }
    415 
    416 }  // namespace jxl