dec_ans.h (19460B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #ifndef LIB_JXL_DEC_ANS_H_ 7 #define LIB_JXL_DEC_ANS_H_ 8 9 // Library to decode the ANS population counts from the bit-stream and build a 10 // decoding table from them. 11 12 #include <jxl/memory_manager.h> 13 #include <jxl/types.h> 14 15 #include <algorithm> 16 #include <cstddef> 17 #include <cstdint> 18 #include <cstring> 19 #include <vector> 20 21 #include "lib/jxl/ans_common.h" 22 #include "lib/jxl/ans_params.h" 23 #include "lib/jxl/base/bits.h" 24 #include "lib/jxl/base/compiler_specific.h" 25 #include "lib/jxl/base/status.h" 26 #include "lib/jxl/dec_bit_reader.h" 27 #include "lib/jxl/dec_huffman.h" 28 #include "lib/jxl/field_encodings.h" 29 #include "lib/jxl/memory_manager_internal.h" 30 31 namespace jxl { 32 33 class ANSSymbolReader; 34 35 // Experiments show that best performance is typically achieved for a 36 // split-exponent of 3 or 4. Trend seems to be that '4' is better 37 // for large-ish pictures, and '3' better for rather small-ish pictures. 38 // This is plausible - the more special symbols we have, the better 39 // statistics we need to get a benefit out of them. 40 41 // Our hybrid-encoding scheme has dedicated tokens for the smallest 42 // (1 << split_exponents) numbers, and for the rest 43 // encodes (number of bits) + (msb_in_token sub-leading binary digits) + 44 // (lsb_in_token lowest binary digits) in the token, with the remaining bits 45 // then being encoded as data. 46 // 47 // Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0. 48 // 49 // Numbers N in [0 .. 15]: 50 // These get represented as (token=N, bits=''). 51 // Numbers N >= 16: 52 // If n is such that 2**n <= N < 2**(n+1), 53 // and m = N - 2**n is the 'mantissa', 54 // these get represented as: 55 // (token=split_token + 56 // ((n - split_exponent) * 4) + 57 // (m >> (n - msb_in_token)), 58 // bits=m & (1 << (n - msb_in_token)) - 1) 59 // Specifically, we would get: 60 // N = 0 - 15: (token=N, nbits=0, bits='') 61 // N = 16 (10000): (token=16, nbits=2, bits='00') 62 // N = 17 (10001): (token=16, nbits=2, bits='01') 63 // N = 20 (10100): (token=17, nbits=2, bits='00') 64 // N = 24 (11000): (token=18, nbits=2, bits='00') 65 // N = 28 (11100): (token=19, nbits=2, bits='00') 66 // N = 32 (100000): (token=20, nbits=3, bits='000') 67 // N = 65535: (token=63, nbits=13, bits='1111111111111') 68 struct HybridUintConfig { 69 uint32_t split_exponent; 70 uint32_t split_token; 71 uint32_t msb_in_token; 72 uint32_t lsb_in_token; 73 JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token, 74 uint32_t* JXL_RESTRICT nbits, 75 uint32_t* JXL_RESTRICT bits) const { 76 if (value < split_token) { 77 *token = value; 78 *nbits = 0; 79 *bits = 0; 80 } else { 81 uint32_t n = FloorLog2Nonzero(value); 82 uint32_t m = value - (1 << n); 83 *token = split_token + 84 ((n - split_exponent) << (msb_in_token + lsb_in_token)) + 85 ((m >> (n - msb_in_token)) << lsb_in_token) + 86 (m & ((1 << lsb_in_token) - 1)); 87 *nbits = n - msb_in_token - lsb_in_token; 88 *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1); 89 } 90 } 91 92 explicit HybridUintConfig(uint32_t split_exponent = 4, 93 uint32_t msb_in_token = 2, 94 uint32_t lsb_in_token = 0) 95 : split_exponent(split_exponent), 96 split_token(1 << split_exponent), 97 msb_in_token(msb_in_token), 98 lsb_in_token(lsb_in_token) { 99 JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token); 100 } 101 }; 102 103 struct LZ77Params : public Fields { 104 LZ77Params(); 105 JXL_FIELDS_NAME(LZ77Params) 106 Status VisitFields(Visitor* JXL_RESTRICT visitor) override; 107 bool enabled; 108 109 // Symbols above min_symbol use a special hybrid uint encoding and 110 // represent a length, to be added to min_length. 111 uint32_t min_symbol; 112 uint32_t min_length; 113 114 // Not serialized by VisitFields. 115 HybridUintConfig length_uint_config{0, 0, 0}; 116 117 size_t nonserialized_distance_context; 118 }; 119 120 static constexpr size_t kWindowSize = 1 << 20; 121 static constexpr size_t kNumSpecialDistances = 120; 122 // Table of special distance codes from WebP lossless. 123 static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = { 124 {0, 1}, {1, 0}, {1, 1}, {-1, 1}, {0, 2}, {2, 0}, {1, 2}, {-1, 2}, 125 {2, 1}, {-2, 1}, {2, 2}, {-2, 2}, {0, 3}, {3, 0}, {1, 3}, {-1, 3}, 126 {3, 1}, {-3, 1}, {2, 3}, {-2, 3}, {3, 2}, {-3, 2}, {0, 4}, {4, 0}, 127 {1, 4}, {-1, 4}, {4, 1}, {-4, 1}, {3, 3}, {-3, 3}, {2, 4}, {-2, 4}, 128 {4, 2}, {-4, 2}, {0, 5}, {3, 4}, {-3, 4}, {4, 3}, {-4, 3}, {5, 0}, 129 {1, 5}, {-1, 5}, {5, 1}, {-5, 1}, {2, 5}, {-2, 5}, {5, 2}, {-5, 2}, 130 {4, 4}, {-4, 4}, {3, 5}, {-3, 5}, {5, 3}, {-5, 3}, {0, 6}, {6, 0}, 131 {1, 6}, {-1, 6}, {6, 1}, {-6, 1}, {2, 6}, {-2, 6}, {6, 2}, {-6, 2}, 132 {4, 5}, {-4, 5}, {5, 4}, {-5, 4}, {3, 6}, {-3, 6}, {6, 3}, {-6, 3}, 133 {0, 7}, {7, 0}, {1, 7}, {-1, 7}, {5, 5}, {-5, 5}, {7, 1}, {-7, 1}, 134 {4, 6}, {-4, 6}, {6, 4}, {-6, 4}, {2, 7}, {-2, 7}, {7, 2}, {-7, 2}, 135 {3, 7}, {-3, 7}, {7, 3}, {-7, 3}, {5, 6}, {-5, 6}, {6, 5}, {-6, 5}, 136 {8, 0}, {4, 7}, {-4, 7}, {7, 4}, {-7, 4}, {8, 1}, {8, 2}, {6, 6}, 137 {-6, 6}, {8, 3}, {5, 7}, {-5, 7}, {7, 5}, {-7, 5}, {8, 4}, {6, 7}, 138 {-6, 7}, {7, 6}, {-7, 6}, {8, 5}, {7, 7}, {-7, 7}, {8, 6}, {8, 7}}; 139 static JXL_INLINE int SpecialDistance(size_t index, int multiplier) { 140 int dist = kSpecialDistances[index][0] + 141 static_cast<int>(multiplier) * kSpecialDistances[index][1]; 142 return (dist > 1) ? dist : 1; 143 } 144 145 struct ANSCode { 146 AlignedMemory alias_tables; 147 std::vector<HuffmanDecodingData> huffman_data; 148 std::vector<HybridUintConfig> uint_config; 149 std::vector<int> degenerate_symbols; 150 bool use_prefix_code; 151 uint8_t log_alpha_size; // for ANS. 152 LZ77Params lz77; 153 // Maximum number of bits necessary to represent the result of a 154 // ReadHybridUint call done with this ANSCode. 155 size_t max_num_bits = 0; 156 JxlMemoryManager* memory_manager; 157 void UpdateMaxNumBits(size_t ctx, size_t symbol); 158 }; 159 160 class ANSSymbolReader { 161 public: 162 // Invalid symbol reader, to be overwritten. 163 ANSSymbolReader() = default; 164 static StatusOr<ANSSymbolReader> Create(const ANSCode* code, 165 BitReader* JXL_RESTRICT br, 166 size_t distance_multiplier = 0); 167 168 JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx, 169 BitReader* JXL_RESTRICT br) { 170 const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); 171 172 const AliasTable::Entry* table = 173 &alias_tables_[histo_idx << log_alpha_size_]; 174 const AliasTable::Symbol symbol = 175 AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); 176 state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset; 177 178 #if JXL_TRUE 179 // Branchless version is about equally fast on SKX. 180 const uint32_t new_state = 181 (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>()); 182 const bool normalize = state_ < (1u << 16u); 183 state_ = normalize ? new_state : state_; 184 br->Consume(normalize ? 16 : 0); 185 #else 186 if (JXL_UNLIKELY(state_ < (1u << 16u))) { 187 state_ = (state_ << 16u) | br->PeekFixedBits<16>(); 188 br->Consume(16); 189 } 190 #endif 191 const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u); 192 AliasTable::Prefetch(table, next_res, log_entry_size_); 193 194 return symbol.value; 195 } 196 197 JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx, 198 BitReader* JXL_RESTRICT br) { 199 return huffman_data_[histo_idx].ReadSymbol(br); 200 } 201 202 JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx, 203 BitReader* JXL_RESTRICT br) { 204 // TODO(veluca): hoist if in hotter loops. 205 if (JXL_UNLIKELY(use_prefix_code_)) { 206 return ReadSymbolHuffWithoutRefill(histo_idx, br); 207 } 208 return ReadSymbolANSWithoutRefill(histo_idx, br); 209 } 210 211 JXL_INLINE size_t ReadSymbol(const size_t histo_idx, 212 BitReader* JXL_RESTRICT br) { 213 br->Refill(); 214 return ReadSymbolWithoutRefill(histo_idx, br); 215 } 216 217 #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION 218 bool CheckANSFinalState() const { return true; } 219 #else 220 bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); } 221 #endif 222 223 template <typename BitReader> 224 static JXL_INLINE uint32_t ReadHybridUintConfig( 225 const HybridUintConfig& config, size_t token, BitReader* br) { 226 size_t split_token = config.split_token; 227 size_t msb_in_token = config.msb_in_token; 228 size_t lsb_in_token = config.lsb_in_token; 229 size_t split_exponent = config.split_exponent; 230 // Fast-track version of hybrid integer decoding. 231 if (token < split_token) return token; 232 uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) + 233 ((token - split_token) >> (msb_in_token + lsb_in_token)); 234 // Max amount of bits for ReadBits is 32 and max valid left shift is 29 235 // bits. However, for speed no error is propagated here, instead limit the 236 // nbits size. If nbits > 29, the code stream is invalid, but no error is 237 // returned. 238 // Note that in most cases we will emit an error if the histogram allows 239 // representing numbers that would cause invalid shifts, but we need to 240 // keep this check as when LZ77 is enabled it might make sense to have an 241 // histogram that could in principle cause invalid shifts. 242 nbits &= 31u; 243 uint32_t low = token & ((1 << lsb_in_token) - 1); 244 token >>= lsb_in_token; 245 const size_t bits = br->PeekBits(nbits); 246 br->Consume(nbits); 247 size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1))) 248 << nbits) | 249 bits) 250 << lsb_in_token) | 251 low; 252 // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not 253 // fit uint32_t 254 return static_cast<uint32_t>(ret); 255 } 256 257 // Takes a *clustered* idx. Can only use if HuffRleOnly() is true. 258 JXL_INLINE void ReadHybridUintClusteredHuffRleOnly(size_t ctx, 259 BitReader* JXL_RESTRICT br, 260 uint32_t* value, 261 uint32_t* run) { 262 JXL_DASSERT(IsHuffRleOnly()); 263 br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits 264 size_t token = ReadSymbolHuffWithoutRefill(ctx, br); 265 if (JXL_UNLIKELY(token >= lz77_threshold_)) { 266 *run = 267 ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + 268 lz77_min_length_ - 1; 269 return; 270 } 271 *value = ReadHybridUintConfig(configs[ctx], token, br); 272 } 273 bool IsHuffRleOnly() const { 274 if (lz77_window_ == nullptr) return false; 275 if (!use_prefix_code_) return false; 276 for (size_t i = 0; i < kHuffmanTableBits; i++) { 277 if (huffman_data_[lz77_ctx_].table_[i].bits) return false; 278 if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false; 279 } 280 if (configs[lz77_ctx_].split_token > 1) return false; 281 return true; 282 } 283 bool UsesLZ77() { return lz77_window_ != nullptr; } 284 285 // Takes a *clustered* idx. Inlined, for use in hot paths. 286 template <bool uses_lz77> 287 JXL_INLINE size_t ReadHybridUintClusteredInlined(size_t ctx, 288 BitReader* JXL_RESTRICT br) { 289 if (uses_lz77) { 290 if (JXL_UNLIKELY(num_to_copy_ > 0)) { 291 size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; 292 num_to_copy_--; 293 lz77_window_[(num_decoded_++) & kWindowMask] = ret; 294 return ret; 295 } 296 } 297 298 br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits 299 size_t token = ReadSymbolWithoutRefill(ctx, br); 300 if (uses_lz77) { 301 if (JXL_UNLIKELY(token >= lz77_threshold_)) { 302 num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_, 303 token - lz77_threshold_, br) + 304 lz77_min_length_; 305 br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits 306 // Distance code. 307 size_t token = ReadSymbolWithoutRefill(lz77_ctx_, br); 308 size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], token, br); 309 if (JXL_LIKELY(distance < num_special_distances_)) { 310 distance = special_distances_[distance]; 311 } else { 312 distance = distance + 1 - num_special_distances_; 313 } 314 if (JXL_UNLIKELY(distance > num_decoded_)) { 315 distance = num_decoded_; 316 } 317 if (JXL_UNLIKELY(distance > kWindowSize)) { 318 distance = kWindowSize; 319 } 320 copy_pos_ = num_decoded_ - distance; 321 if (JXL_UNLIKELY(distance == 0)) { 322 JXL_DASSERT(lz77_window_ != nullptr); 323 // distance 0 -> num_decoded_ == copy_pos_ == 0 324 size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize); 325 memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); 326 } 327 // TODO(eustas): overflow; mark BitReader as unhealthy 328 if (num_to_copy_ < lz77_min_length_) return 0; 329 // the code below is the same as doing this: 330 // return ReadHybridUintClustered<uses_lz77>(ctx, br); 331 // but gcc doesn't like recursive inlining 332 333 size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; 334 num_to_copy_--; 335 lz77_window_[(num_decoded_++) & kWindowMask] = ret; 336 return ret; 337 } 338 } 339 size_t ret = ReadHybridUintConfig(configs[ctx], token, br); 340 if (uses_lz77 && lz77_window_) 341 lz77_window_[(num_decoded_++) & kWindowMask] = ret; 342 return ret; 343 } 344 345 // same but not inlined 346 template <bool uses_lz77> 347 size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { 348 return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); 349 } 350 351 // inlined only in the no-lz77 case 352 template <bool uses_lz77> 353 JXL_INLINE size_t 354 ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { 355 if (uses_lz77) { 356 return ReadHybridUintClustered<uses_lz77>(ctx, br); 357 } else { 358 return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); 359 } 360 } 361 362 // inlined, for use in hot paths 363 template <bool uses_lz77> 364 JXL_INLINE size_t 365 ReadHybridUintInlined(size_t ctx, BitReader* JXL_RESTRICT br, 366 const std::vector<uint8_t>& context_map) { 367 return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br); 368 } 369 370 // not inlined, for use in non-hot paths 371 size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br, 372 const std::vector<uint8_t>& context_map) { 373 return ReadHybridUintClustered</*uses_lz77=*/true>(context_map[ctx], br); 374 } 375 376 // ctx is a *clustered* context! 377 // This function will modify the ANS state as if `count` symbols have been 378 // decoded. 379 bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) { 380 // TODO(veluca): No optimization for Huffman mode yet. 381 if (use_prefix_code_) return false; 382 // TODO(eustas): propagate "degenerate_symbol" to simplify this method. 383 const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); 384 const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_]; 385 AliasTable::Symbol symbol = 386 AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); 387 if (symbol.freq != ANS_TAB_SIZE) return false; 388 if (configs[ctx].split_token <= symbol.value) return false; 389 if (symbol.value >= lz77_threshold_) return false; 390 *value = symbol.value; 391 if (lz77_window_) { 392 for (size_t i = 0; i < count; i++) { 393 lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value; 394 } 395 } 396 return true; 397 } 398 399 static constexpr size_t kMaxCheckpointInterval = 512; 400 struct Checkpoint { 401 uint32_t state; 402 uint32_t num_to_copy; 403 uint32_t copy_pos; 404 uint32_t num_decoded; 405 uint32_t lz77_window[kMaxCheckpointInterval]; 406 }; 407 void Save(Checkpoint* checkpoint) { 408 checkpoint->state = state_; 409 checkpoint->num_decoded = num_decoded_; 410 checkpoint->num_to_copy = num_to_copy_; 411 checkpoint->copy_pos = copy_pos_; 412 if (lz77_window_) { 413 size_t win_start = num_decoded_ & kWindowMask; 414 size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; 415 if (win_end > win_start) { 416 memcpy(checkpoint->lz77_window, lz77_window_ + win_start, 417 (win_end - win_start) * sizeof(*lz77_window_)); 418 } else { 419 memcpy(checkpoint->lz77_window, lz77_window_ + win_start, 420 (kWindowSize - win_start) * sizeof(*lz77_window_)); 421 memcpy(checkpoint->lz77_window + (kWindowSize - win_start), 422 lz77_window_, win_end * sizeof(*lz77_window_)); 423 } 424 } 425 } 426 void Restore(const Checkpoint& checkpoint) { 427 state_ = checkpoint.state; 428 JXL_DASSERT(num_decoded_ <= 429 checkpoint.num_decoded + kMaxCheckpointInterval); 430 num_decoded_ = checkpoint.num_decoded; 431 num_to_copy_ = checkpoint.num_to_copy; 432 copy_pos_ = checkpoint.copy_pos; 433 if (lz77_window_) { 434 size_t win_start = num_decoded_ & kWindowMask; 435 size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; 436 if (win_end > win_start) { 437 memcpy(lz77_window_ + win_start, checkpoint.lz77_window, 438 (win_end - win_start) * sizeof(*lz77_window_)); 439 } else { 440 memcpy(lz77_window_ + win_start, checkpoint.lz77_window, 441 (kWindowSize - win_start) * sizeof(*lz77_window_)); 442 memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start), 443 win_end * sizeof(*lz77_window_)); 444 } 445 } 446 } 447 448 private: 449 ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, 450 size_t distance_multiplier, 451 AlignedMemory&& lz77_window_storage); 452 453 const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned 454 const HuffmanDecodingData* huffman_data_; 455 bool use_prefix_code_; 456 uint32_t state_ = ANS_SIGNATURE << 16u; 457 const HybridUintConfig* JXL_RESTRICT configs; 458 uint32_t log_alpha_size_{}; 459 uint32_t log_entry_size_{}; 460 uint32_t entry_size_minus_1_{}; 461 462 // LZ77 structures and constants. 463 static constexpr size_t kWindowMask = kWindowSize - 1; 464 // a std::vector incurs unacceptable decoding speed loss because of 465 // initialization. 466 AlignedMemory lz77_window_storage_; 467 uint32_t* lz77_window_ = nullptr; 468 uint32_t num_decoded_ = 0; 469 uint32_t num_to_copy_ = 0; 470 uint32_t copy_pos_ = 0; 471 uint32_t lz77_ctx_ = 0; 472 uint32_t lz77_min_length_ = 0; 473 uint32_t lz77_threshold_ = 1 << 20; // bigger than any symbol. 474 HybridUintConfig lz77_length_uint_; 475 uint32_t special_distances_[kNumSpecialDistances]{}; 476 uint32_t num_special_distances_{}; 477 }; 478 479 Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br, 480 size_t num_contexts, ANSCode* code, 481 std::vector<uint8_t>* context_map, 482 bool disallow_lz77 = false); 483 484 // Exposed for tests. 485 Status DecodeUintConfigs(size_t log_alpha_size, 486 std::vector<HybridUintConfig>* uint_config, 487 BitReader* br); 488 489 } // namespace jxl 490 491 #endif // LIB_JXL_DEC_ANS_H_