ans_test.cc (10109B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include <jxl/memory_manager.h> 7 8 #include <cstddef> 9 #include <cstdint> 10 #include <vector> 11 12 #include "lib/jxl/ans_params.h" 13 #include "lib/jxl/base/random.h" 14 #include "lib/jxl/base/status.h" 15 #include "lib/jxl/dec_ans.h" 16 #include "lib/jxl/dec_bit_reader.h" 17 #include "lib/jxl/enc_ans.h" 18 #include "lib/jxl/enc_aux_out.h" 19 #include "lib/jxl/enc_bit_writer.h" 20 #include "lib/jxl/test_memory_manager.h" 21 #include "lib/jxl/test_utils.h" 22 #include "lib/jxl/testing.h" 23 24 namespace jxl { 25 namespace { 26 27 void RoundtripTestcase(int n_histograms, int alphabet_size, 28 const std::vector<Token>& input_values) { 29 JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); 30 constexpr uint16_t kMagic1 = 0x9e33; 31 constexpr uint16_t kMagic2 = 0x8b04; 32 33 BitWriter writer{memory_manager}; 34 // Space for magic bytes. 35 ASSERT_TRUE(writer.WithMaxBits(16, LayerType::Header, nullptr, [&] { 36 writer.Write(16, kMagic1); 37 return true; 38 })); 39 40 std::vector<uint8_t> context_map; 41 EntropyEncodingData codes; 42 std::vector<std::vector<Token>> input_values_vec; 43 input_values_vec.push_back(input_values); 44 45 JXL_TEST_ASSIGN_OR_DIE( 46 size_t cost, 47 BuildAndEncodeHistograms(memory_manager, HistogramParams(), n_histograms, 48 input_values_vec, &codes, &context_map, &writer, 49 LayerType::Header, nullptr)); 50 (void)cost; 51 ASSERT_TRUE(WriteTokens(input_values_vec[0], codes, context_map, 0, &writer, 52 LayerType::Header, nullptr)); 53 54 // Magic bytes + padding 55 ASSERT_TRUE(writer.WithMaxBits(24, LayerType::Header, nullptr, [&] { 56 writer.Write(16, kMagic2); 57 writer.ZeroPadToByte(); 58 return true; 59 })); 60 61 // We do not truncate the output. Reading past the end reads out zeroes 62 // anyway. 63 BitReader br(writer.GetSpan()); 64 65 ASSERT_EQ(br.ReadBits(16), kMagic1); 66 67 std::vector<uint8_t> dec_context_map; 68 ANSCode decoded_codes; 69 ASSERT_TRUE(DecodeHistograms(memory_manager, &br, n_histograms, 70 &decoded_codes, &dec_context_map)); 71 ASSERT_EQ(dec_context_map, context_map); 72 JXL_TEST_ASSIGN_OR_DIE(ANSSymbolReader reader, 73 ANSSymbolReader::Create(&decoded_codes, &br)); 74 75 for (const Token& symbol : input_values) { 76 uint32_t read_symbol = 77 reader.ReadHybridUint(symbol.context, &br, dec_context_map); 78 ASSERT_EQ(read_symbol, symbol.value); 79 } 80 ASSERT_TRUE(reader.CheckANSFinalState()); 81 82 ASSERT_EQ(br.ReadBits(16), kMagic2); 83 EXPECT_TRUE(br.Close()); 84 } 85 86 TEST(ANSTest, EmptyRoundtrip) { 87 RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector<Token>()); 88 } 89 90 TEST(ANSTest, SingleSymbolRoundtrip) { 91 for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { 92 RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}}); 93 } 94 for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { 95 RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, 96 std::vector<Token>(1024, {0, i})); 97 } 98 } 99 100 #if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ 101 defined(THREAD_SANITIZER) 102 constexpr size_t kReps = 3; 103 #else 104 constexpr size_t kReps = 10; 105 #endif 106 107 void RoundtripRandomStream(int alphabet_size, size_t reps = kReps, 108 size_t num = 1 << 18) { 109 constexpr int kNumHistograms = 3; 110 Rng rng(0); 111 for (size_t i = 0; i < reps; i++) { 112 std::vector<Token> symbols; 113 for (size_t j = 0; j < num; j++) { 114 int context = rng.UniformI(0, kNumHistograms); 115 int value = rng.UniformU(0, alphabet_size); 116 symbols.emplace_back(context, value); 117 } 118 RoundtripTestcase(kNumHistograms, alphabet_size, symbols); 119 } 120 } 121 122 void RoundtripRandomUnbalancedStream(int alphabet_size) { 123 constexpr int kNumHistograms = 3; 124 constexpr int kPrecision = 1 << 10; 125 Rng rng(0); 126 for (size_t i = 0; i < kReps; i++) { 127 std::vector<int> distributions[kNumHistograms] = {}; 128 for (auto& distr : distributions) { 129 distr.resize(kPrecision); 130 int symbol = 0; 131 int remaining = 1; 132 for (int k = 0; k < kPrecision; k++) { 133 if (remaining == 0) { 134 if (symbol < alphabet_size - 1) symbol++; 135 // There is no meaning behind this distribution: it's anything that 136 // will create a nonuniform distribution and won't have too few 137 // symbols usually. Also we want different distributions we get to be 138 // sufficiently dissimilar. 139 remaining = rng.UniformU(0, kPrecision - k + 1); 140 } 141 distr[k] = symbol; 142 remaining--; 143 } 144 } 145 std::vector<Token> symbols; 146 for (int j = 0; j < 1 << 18; j++) { 147 int context = rng.UniformI(0, kNumHistograms); 148 int value = rng.UniformU(0, kPrecision); 149 symbols.emplace_back(context, value); 150 } 151 RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols); 152 } 153 } 154 155 TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); } 156 157 TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); } 158 159 TEST(ANSTest, RandomStreamRoundtripBig) { 160 RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE); 161 } 162 163 TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) { 164 RoundtripRandomUnbalancedStream(3); 165 } 166 167 TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) { 168 RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE); 169 } 170 171 TEST(ANSTest, UintConfigRoundtrip) { 172 JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); 173 for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) { 174 std::vector<HybridUintConfig> uint_config; 175 std::vector<HybridUintConfig> uint_config_dec; 176 for (size_t i = 0; i < log_alpha_size; i++) { 177 for (size_t j = 0; j <= i; j++) { 178 for (size_t k = 0; k <= i - j; k++) { 179 uint_config.emplace_back(i, j, k); 180 } 181 } 182 } 183 uint_config.emplace_back(log_alpha_size, 0, 0); 184 uint_config_dec.resize(uint_config.size()); 185 BitWriter writer{memory_manager}; 186 ASSERT_TRUE(writer.WithMaxBits( 187 10 * uint_config.size(), LayerType::Header, nullptr, [&] { 188 EncodeUintConfigs(uint_config, &writer, log_alpha_size); 189 return true; 190 })); 191 writer.ZeroPadToByte(); 192 BitReader br(writer.GetSpan()); 193 EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br)); 194 EXPECT_TRUE(br.Close()); 195 for (size_t i = 0; i < uint_config.size(); i++) { 196 EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token); 197 EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token); 198 EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token); 199 } 200 } 201 } 202 203 void TestCheckpointing(bool ans, bool lz77) { 204 JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); 205 std::vector<std::vector<Token>> input_values(1); 206 for (size_t i = 0; i < 1024; i++) { 207 input_values[0].emplace_back(0, i % 4); 208 } 209 // up to lz77 window size. 210 for (size_t i = 0; i < (1 << 20) - 1022; i++) { 211 input_values[0].emplace_back(0, (i % 5) + 4); 212 } 213 // Ensure that when the window wraps around, new values are different. 214 input_values[0].emplace_back(0, 0); 215 for (size_t i = 0; i < 1024; i++) { 216 input_values[0].emplace_back(0, i % 4); 217 } 218 219 std::vector<uint8_t> context_map; 220 EntropyEncodingData codes; 221 HistogramParams params; 222 params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77 223 : HistogramParams::LZ77Method::kNone; 224 params.force_huffman = !ans; 225 226 BitWriter writer{memory_manager}; 227 { 228 auto input_values_copy = input_values; 229 JXL_TEST_ASSIGN_OR_DIE( 230 size_t cost, BuildAndEncodeHistograms( 231 memory_manager, params, 1, input_values_copy, &codes, 232 &context_map, &writer, LayerType::Header, nullptr)); 233 (void)cost; 234 ASSERT_TRUE(WriteTokens(input_values_copy[0], codes, context_map, 0, 235 &writer, LayerType::Header, nullptr)); 236 writer.ZeroPadToByte(); 237 } 238 239 // We do not truncate the output. Reading past the end reads out zeroes 240 // anyway. 241 BitReader br(writer.GetSpan()); 242 Status status = true; 243 { 244 BitReaderScopedCloser bc(br, status); 245 246 std::vector<uint8_t> dec_context_map; 247 ANSCode decoded_codes; 248 ASSERT_TRUE(DecodeHistograms(memory_manager, &br, 1, &decoded_codes, 249 &dec_context_map)); 250 ASSERT_EQ(dec_context_map, context_map); 251 JXL_TEST_ASSIGN_OR_DIE(ANSSymbolReader reader, 252 ANSSymbolReader::Create(&decoded_codes, &br)); 253 254 ANSSymbolReader::Checkpoint checkpoint; 255 size_t br_pos = 0; 256 constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2; 257 for (size_t i = 0; i < input_values[0].size(); i++) { 258 if (i % kInterval == 0 && i > 0) { 259 reader.Restore(checkpoint); 260 ASSERT_TRUE(br.Close()); 261 br = BitReader(writer.GetSpan()); 262 br.SkipBits(br_pos); 263 for (size_t j = i - kInterval; j < i; j++) { 264 Token symbol = input_values[0][j]; 265 uint32_t read_symbol = 266 reader.ReadHybridUint(symbol.context, &br, dec_context_map); 267 ASSERT_EQ(read_symbol, symbol.value) << "j = " << j; 268 } 269 } 270 if (i % kInterval == 0) { 271 reader.Save(&checkpoint); 272 br_pos = br.TotalBitsConsumed(); 273 } 274 Token symbol = input_values[0][i]; 275 uint32_t read_symbol = 276 reader.ReadHybridUint(symbol.context, &br, dec_context_map); 277 ASSERT_EQ(read_symbol, symbol.value) << "i = " << i; 278 } 279 ASSERT_TRUE(reader.CheckANSFinalState()); 280 } 281 EXPECT_TRUE(status); 282 } 283 284 TEST(ANSTest, TestCheckpointingANS) { 285 TestCheckpointing(/*ans=*/true, /*lz77=*/false); 286 } 287 288 TEST(ANSTest, TestCheckpointingPrefix) { 289 TestCheckpointing(/*ans=*/false, /*lz77=*/false); 290 } 291 292 TEST(ANSTest, TestCheckpointingANSLZ77) { 293 TestCheckpointing(/*ans=*/true, /*lz77=*/true); 294 } 295 296 TEST(ANSTest, TestCheckpointingPrefixLZ77) { 297 TestCheckpointing(/*ans=*/false, /*lz77=*/true); 298 } 299 300 } // namespace 301 } // namespace jxl