tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ans_test.cc (10109B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include <jxl/memory_manager.h>
      7 
      8 #include <cstddef>
      9 #include <cstdint>
     10 #include <vector>
     11 
     12 #include "lib/jxl/ans_params.h"
     13 #include "lib/jxl/base/random.h"
     14 #include "lib/jxl/base/status.h"
     15 #include "lib/jxl/dec_ans.h"
     16 #include "lib/jxl/dec_bit_reader.h"
     17 #include "lib/jxl/enc_ans.h"
     18 #include "lib/jxl/enc_aux_out.h"
     19 #include "lib/jxl/enc_bit_writer.h"
     20 #include "lib/jxl/test_memory_manager.h"
     21 #include "lib/jxl/test_utils.h"
     22 #include "lib/jxl/testing.h"
     23 
     24 namespace jxl {
     25 namespace {
     26 
     27 void RoundtripTestcase(int n_histograms, int alphabet_size,
     28                       const std::vector<Token>& input_values) {
     29  JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
     30  constexpr uint16_t kMagic1 = 0x9e33;
     31  constexpr uint16_t kMagic2 = 0x8b04;
     32 
     33  BitWriter writer{memory_manager};
     34  // Space for magic bytes.
     35  ASSERT_TRUE(writer.WithMaxBits(16, LayerType::Header, nullptr, [&] {
     36    writer.Write(16, kMagic1);
     37    return true;
     38  }));
     39 
     40  std::vector<uint8_t> context_map;
     41  EntropyEncodingData codes;
     42  std::vector<std::vector<Token>> input_values_vec;
     43  input_values_vec.push_back(input_values);
     44 
     45  JXL_TEST_ASSIGN_OR_DIE(
     46      size_t cost,
     47      BuildAndEncodeHistograms(memory_manager, HistogramParams(), n_histograms,
     48                               input_values_vec, &codes, &context_map, &writer,
     49                               LayerType::Header, nullptr));
     50  (void)cost;
     51  ASSERT_TRUE(WriteTokens(input_values_vec[0], codes, context_map, 0, &writer,
     52                          LayerType::Header, nullptr));
     53 
     54  // Magic bytes + padding
     55  ASSERT_TRUE(writer.WithMaxBits(24, LayerType::Header, nullptr, [&] {
     56    writer.Write(16, kMagic2);
     57    writer.ZeroPadToByte();
     58    return true;
     59  }));
     60 
     61  // We do not truncate the output. Reading past the end reads out zeroes
     62  // anyway.
     63  BitReader br(writer.GetSpan());
     64 
     65  ASSERT_EQ(br.ReadBits(16), kMagic1);
     66 
     67  std::vector<uint8_t> dec_context_map;
     68  ANSCode decoded_codes;
     69  ASSERT_TRUE(DecodeHistograms(memory_manager, &br, n_histograms,
     70                               &decoded_codes, &dec_context_map));
     71  ASSERT_EQ(dec_context_map, context_map);
     72  JXL_TEST_ASSIGN_OR_DIE(ANSSymbolReader reader,
     73                         ANSSymbolReader::Create(&decoded_codes, &br));
     74 
     75  for (const Token& symbol : input_values) {
     76    uint32_t read_symbol =
     77        reader.ReadHybridUint(symbol.context, &br, dec_context_map);
     78    ASSERT_EQ(read_symbol, symbol.value);
     79  }
     80  ASSERT_TRUE(reader.CheckANSFinalState());
     81 
     82  ASSERT_EQ(br.ReadBits(16), kMagic2);
     83  EXPECT_TRUE(br.Close());
     84 }
     85 
     86 TEST(ANSTest, EmptyRoundtrip) {
     87  RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector<Token>());
     88 }
     89 
     90 TEST(ANSTest, SingleSymbolRoundtrip) {
     91  for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) {
     92    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}});
     93  }
     94  for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) {
     95    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE,
     96                      std::vector<Token>(1024, {0, i}));
     97  }
     98 }
     99 
    100 #if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
    101    defined(THREAD_SANITIZER)
    102 constexpr size_t kReps = 3;
    103 #else
    104 constexpr size_t kReps = 10;
    105 #endif
    106 
    107 void RoundtripRandomStream(int alphabet_size, size_t reps = kReps,
    108                           size_t num = 1 << 18) {
    109  constexpr int kNumHistograms = 3;
    110  Rng rng(0);
    111  for (size_t i = 0; i < reps; i++) {
    112    std::vector<Token> symbols;
    113    for (size_t j = 0; j < num; j++) {
    114      int context = rng.UniformI(0, kNumHistograms);
    115      int value = rng.UniformU(0, alphabet_size);
    116      symbols.emplace_back(context, value);
    117    }
    118    RoundtripTestcase(kNumHistograms, alphabet_size, symbols);
    119  }
    120 }
    121 
    122 void RoundtripRandomUnbalancedStream(int alphabet_size) {
    123  constexpr int kNumHistograms = 3;
    124  constexpr int kPrecision = 1 << 10;
    125  Rng rng(0);
    126  for (size_t i = 0; i < kReps; i++) {
    127    std::vector<int> distributions[kNumHistograms] = {};
    128    for (auto& distr : distributions) {
    129      distr.resize(kPrecision);
    130      int symbol = 0;
    131      int remaining = 1;
    132      for (int k = 0; k < kPrecision; k++) {
    133        if (remaining == 0) {
    134          if (symbol < alphabet_size - 1) symbol++;
    135          // There is no meaning behind this distribution: it's anything that
    136          // will create a nonuniform distribution and won't have too few
    137          // symbols usually. Also we want different distributions we get to be
    138          // sufficiently dissimilar.
    139          remaining = rng.UniformU(0, kPrecision - k + 1);
    140        }
    141        distr[k] = symbol;
    142        remaining--;
    143      }
    144    }
    145    std::vector<Token> symbols;
    146    for (int j = 0; j < 1 << 18; j++) {
    147      int context = rng.UniformI(0, kNumHistograms);
    148      int value = rng.UniformU(0, kPrecision);
    149      symbols.emplace_back(context, value);
    150    }
    151    RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols);
    152  }
    153 }
    154 
    155 TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); }
    156 
    157 TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); }
    158 
    159 TEST(ANSTest, RandomStreamRoundtripBig) {
    160  RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE);
    161 }
    162 
    163 TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) {
    164  RoundtripRandomUnbalancedStream(3);
    165 }
    166 
    167 TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) {
    168  RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE);
    169 }
    170 
    171 TEST(ANSTest, UintConfigRoundtrip) {
    172  JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
    173  for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) {
    174    std::vector<HybridUintConfig> uint_config;
    175    std::vector<HybridUintConfig> uint_config_dec;
    176    for (size_t i = 0; i < log_alpha_size; i++) {
    177      for (size_t j = 0; j <= i; j++) {
    178        for (size_t k = 0; k <= i - j; k++) {
    179          uint_config.emplace_back(i, j, k);
    180        }
    181      }
    182    }
    183    uint_config.emplace_back(log_alpha_size, 0, 0);
    184    uint_config_dec.resize(uint_config.size());
    185    BitWriter writer{memory_manager};
    186    ASSERT_TRUE(writer.WithMaxBits(
    187        10 * uint_config.size(), LayerType::Header, nullptr, [&] {
    188          EncodeUintConfigs(uint_config, &writer, log_alpha_size);
    189          return true;
    190        }));
    191    writer.ZeroPadToByte();
    192    BitReader br(writer.GetSpan());
    193    EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br));
    194    EXPECT_TRUE(br.Close());
    195    for (size_t i = 0; i < uint_config.size(); i++) {
    196      EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token);
    197      EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token);
    198      EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token);
    199    }
    200  }
    201 }
    202 
    203 void TestCheckpointing(bool ans, bool lz77) {
    204  JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
    205  std::vector<std::vector<Token>> input_values(1);
    206  for (size_t i = 0; i < 1024; i++) {
    207    input_values[0].emplace_back(0, i % 4);
    208  }
    209  // up to lz77 window size.
    210  for (size_t i = 0; i < (1 << 20) - 1022; i++) {
    211    input_values[0].emplace_back(0, (i % 5) + 4);
    212  }
    213  // Ensure that when the window wraps around, new values are different.
    214  input_values[0].emplace_back(0, 0);
    215  for (size_t i = 0; i < 1024; i++) {
    216    input_values[0].emplace_back(0, i % 4);
    217  }
    218 
    219  std::vector<uint8_t> context_map;
    220  EntropyEncodingData codes;
    221  HistogramParams params;
    222  params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77
    223                            : HistogramParams::LZ77Method::kNone;
    224  params.force_huffman = !ans;
    225 
    226  BitWriter writer{memory_manager};
    227  {
    228    auto input_values_copy = input_values;
    229    JXL_TEST_ASSIGN_OR_DIE(
    230        size_t cost, BuildAndEncodeHistograms(
    231                         memory_manager, params, 1, input_values_copy, &codes,
    232                         &context_map, &writer, LayerType::Header, nullptr));
    233    (void)cost;
    234    ASSERT_TRUE(WriteTokens(input_values_copy[0], codes, context_map, 0,
    235                            &writer, LayerType::Header, nullptr));
    236    writer.ZeroPadToByte();
    237  }
    238 
    239  // We do not truncate the output. Reading past the end reads out zeroes
    240  // anyway.
    241  BitReader br(writer.GetSpan());
    242  Status status = true;
    243  {
    244    BitReaderScopedCloser bc(br, status);
    245 
    246    std::vector<uint8_t> dec_context_map;
    247    ANSCode decoded_codes;
    248    ASSERT_TRUE(DecodeHistograms(memory_manager, &br, 1, &decoded_codes,
    249                                 &dec_context_map));
    250    ASSERT_EQ(dec_context_map, context_map);
    251    JXL_TEST_ASSIGN_OR_DIE(ANSSymbolReader reader,
    252                           ANSSymbolReader::Create(&decoded_codes, &br));
    253 
    254    ANSSymbolReader::Checkpoint checkpoint;
    255    size_t br_pos = 0;
    256    constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2;
    257    for (size_t i = 0; i < input_values[0].size(); i++) {
    258      if (i % kInterval == 0 && i > 0) {
    259        reader.Restore(checkpoint);
    260        ASSERT_TRUE(br.Close());
    261        br = BitReader(writer.GetSpan());
    262        br.SkipBits(br_pos);
    263        for (size_t j = i - kInterval; j < i; j++) {
    264          Token symbol = input_values[0][j];
    265          uint32_t read_symbol =
    266              reader.ReadHybridUint(symbol.context, &br, dec_context_map);
    267          ASSERT_EQ(read_symbol, symbol.value) << "j = " << j;
    268        }
    269      }
    270      if (i % kInterval == 0) {
    271        reader.Save(&checkpoint);
    272        br_pos = br.TotalBitsConsumed();
    273      }
    274      Token symbol = input_values[0][i];
    275      uint32_t read_symbol =
    276          reader.ReadHybridUint(symbol.context, &br, dec_context_map);
    277      ASSERT_EQ(read_symbol, symbol.value) << "i = " << i;
    278    }
    279    ASSERT_TRUE(reader.CheckANSFinalState());
    280  }
    281  EXPECT_TRUE(status);
    282 }
    283 
    284 TEST(ANSTest, TestCheckpointingANS) {
    285  TestCheckpointing(/*ans=*/true, /*lz77=*/false);
    286 }
    287 
    288 TEST(ANSTest, TestCheckpointingPrefix) {
    289  TestCheckpointing(/*ans=*/false, /*lz77=*/false);
    290 }
    291 
    292 TEST(ANSTest, TestCheckpointingANSLZ77) {
    293  TestCheckpointing(/*ans=*/true, /*lz77=*/true);
    294 }
    295 
    296 TEST(ANSTest, TestCheckpointingPrefixLZ77) {
    297  TestCheckpointing(/*ans=*/false, /*lz77=*/true);
    298 }
    299 
    300 }  // namespace
    301 }  // namespace jxl