enc_ans.cc (70545B)
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/enc_ans.h"

#include <jxl/memory_manager.h>
#include <jxl/types.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "lib/jxl/ans_common.h"
#include "lib/jxl/base/bits.h"
#include "lib/jxl/base/fast_math-inl.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_ans.h"
#include "lib/jxl/enc_ans_params.h"
#include "lib/jxl/enc_aux_out.h"
#include "lib/jxl/enc_cluster.h"
#include "lib/jxl/enc_context_map.h"
#include "lib/jxl/enc_fields.h"
#include "lib/jxl/enc_huffman.h"
#include "lib/jxl/enc_params.h"
#include "lib/jxl/fields.h"

namespace jxl {

namespace {

// Debug-only switch: in non-debug builds this is a compile-time `false`
// (constexpr), so the simplified "fuzzer friendly" code paths elsewhere in
// this file are only reachable in debug builds.
#if (!JXL_IS_DEBUG_BUILD)
constexpr
#endif
    bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` array filled by NormalizeCounts: the indices of
// the first kMaxNumSymbolsForSmallCode non-zero symbols are recorded there.
const int kMaxNumSymbolsForSmallCode = 4;

// Fills `info` for each symbol from the normalized `counts` and the alias
// `table`:
//   - freq_ (and ifreq_ when USE_MULT_BY_RECIPROCAL is defined);
//   - reverse_map_, which maps the k-th occurrence of a symbol in the table
//     to its table slot index.
// When alphabet_size == 0, a single entry with frequency ANS_TAB_SIZE is
// written to info[0], so that empty streams still get a valid table.
void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
                       size_t alphabet_size, size_t log_alpha_size,
                       ANSEncSymbolInfo* info) {
  size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
  size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
  // create valid alias table for empty streams.
  for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
    const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
    info[s].freq_ = static_cast<uint16_t>(freq);
#ifdef USE_MULT_BY_RECIPROCAL
    if (freq != 0) {
      info[s].ifreq_ =
          ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
    } else {
      info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
    }
#endif
    info[s].reverse_map_.resize(freq);
  }
  // Walk every table slot and record it in the reverse map of the symbol the
  // alias table assigns to that slot.
  for (int i = 0; i < ANS_TAB_SIZE; i++) {
    AliasTable::Symbol s =
        AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
    info[s.value].reverse_map_[s.offset] = i;
  }
}

// Estimates the bits needed to code the symbols counted in `histogram` under
// the normalized distribution `counts`. Each occurrence of symbol i costs
// max(0, ANS_LOG_TAB_SIZE - log2(counts[i])) bits. For a non-empty histogram,
// `counts` must sum to ANS_TAB_SIZE (checked in debug builds only).
float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
                       size_t len) {
  float sum = 0.0f;
  int total_histogram = 0;
  int total_counts = 0;
  for (size_t i = 0; i < len; ++i) {
    total_histogram += histogram[i];
    total_counts += counts[i];
    if (histogram[i] > 0) {
      JXL_DASSERT(counts[i] > 0);
      // += histogram[i] * -log(counts[i]/total_counts)
      sum += histogram[i] *
             std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
    }
  }
  if (total_histogram > 0) {
    // Used only in assert.
    (void)total_counts;
    JXL_DASSERT(total_counts == ANS_TAB_SIZE);
  }
  return sum;
}

// Estimates the bits needed to code the symbols counted in `histogram` under
// a uniform distribution over `len` symbols: log2(len) bits per occurrence,
// clamped at 0.
float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
  const float flat_bits = std::max(FastLog2f(len), 0.0f);
  float total_histogram = 0;
  for (size_t i = 0; i < len; ++i) {
    total_histogram += histogram[i];
  }
  return total_histogram * flat_bits;
}

// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// sequence.
// Entry i holds the code length (kLogCountBitLengths) and code bits
// (kLogCountSymbols) written for logcount value i. Index ANS_LOG_TAB_SIZE + 1
// is the RLE symbol used by EncodeCounts.
const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
};
const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
};

// Returns the difference between the largest count that can be represented
// and is smaller than "count" and the smallest representable count larger
// than "count". The result is 1 << (floor(log2(count)) -
// GetPopulationCountPrecision(floor(log2(count)), shift)), or 1 when that
// exponent is negative (including count == 0).
int SmallestIncrement(uint32_t count, uint32_t shift) {
  int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
  int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
  return drop_bits < 0 ? 1 : (1 << drop_bits);
}

// Turns real-valued `targets` (NormalizeCounts passes values summing to
// `table_size`) into integer `counts` summing exactly to `table_size`.
//   - Targets in (0, 1) become 1. The other targets are scaled by
//     `discount_ratio` to make room for that rounding-up.
//   - Targets >= 1 are truncated and rounded to a multiple of
//     SmallestIncrement(count, shift). When minimize_error_of_sum is false,
//     rounding goes toward the target itself; otherwise it keeps the running
//     sum close to the running sum of targets.
//   - The symbol with the largest floor(log2(count)) absorbs the remaining
//     difference. Its index is stored in *omit_pos.
// Returns false if that adjusted count ends up non-positive. Fails via
// JXL_ENSURE if discount_ratio is outside (0, 1].
template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  int sum = 0;
  float sum_nonrounded = 0.0;
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      counts[n] = 1;
      sum_nonrounded += targets[n];
      sum += counts[n];
    }
  }
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ENSURE(discount_ratio > 0);
  JXL_ENSURE(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
      counts[n] =
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
      const int target = minimize_error_of_sum
                             ? (static_cast<int>(sum_nonrounded) - sum)
                             : static_cast<int>(targets[n]);
      if (counts[n] == 0 ||
          (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) {
        counts[n] += inc;
      }
      sum += counts[n];
      // Track the symbol with the largest magnitude; it absorbs the residual.
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_pos = n;
        remainder_log = count_log;
      }
    }
  }
  JXL_ENSURE(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
}

// Normalizes `counts[0, length)` in place so that they sum to
// 1 << precision_bits.
//   - *num_symbols receives the number of non-zero entries.
//   - `symbols` receives the indices of the first kMaxNumSymbolsForSmallCode
//     of them.
//   - With zero symbols, counts are left unchanged. With one symbol, that
//     symbol gets the whole table.
//   - Otherwise RebalanceHistogram<false> is tried first, and
//     RebalanceHistogram<true> is used as a fallback.
// Fails if there are more non-zero symbols than table entries, or if neither
// rebalancing succeeds.
Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
                       const int precision_bits, uint32_t shift,
                       int* num_symbols, int* symbols) {
  const int32_t table_size = 1 << precision_bits;  // target sum / table size
  uint64_t total = 0;
  int max_symbol = 0;
  int symbol_count = 0;
  for (int n = 0; n < length; ++n) {
    total += counts[n];
    if (counts[n] > 0) {
      if (symbol_count < kMaxNumSymbolsForSmallCode) {
        symbols[symbol_count] = n;
      }
      ++symbol_count;
      max_symbol = n + 1;
    }
  }
  *num_symbols = symbol_count;
  if (symbol_count == 0) {
    return true;
  }
  if (symbol_count == 1) {
    counts[symbols[0]] = table_size;
    return true;
  }
  if (symbol_count > table_size)
    return JXL_FAILURE("Too many entries in an ANS histogram");

  const float norm = 1.f * table_size / total;
  std::vector<float> targets(max_symbol);
  for (size_t n = 0; n < targets.size(); ++n) {
    targets[n] = norm * counts[n];
  }
  if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift,
                                 omit_pos, counts)) {
    // Use an alternative rebalancing mechanism if the one above failed
    // to create a histogram that is positive wherever the original one was.
    if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift,
                                  omit_pos, counts)) {
      return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
    }
  }
  return true;
}

// Writer stand-in used for size estimation: Write(num, bits) only adds the
// bit count `num` to `size` and discards the value.
struct SizeWriter {
  size_t size = 0;
  void Write(size_t num, size_t bits) { size += num; }
};

// Writes n (0..255). n == 0 is a single 0 bit. Otherwise the output is a 1
// bit, floor(log2(n)) in 3 bits, then n - 2^floor(log2(n)) in that many bits.
template <typename Writer>
void StoreVarLenUint8(size_t n, Writer* writer) {
  JXL_DASSERT(n <= 255);
  if (n == 0) {
    writer->Write(1, 0);
  } else {
    writer->Write(1, 1);
    size_t nbits = FloorLog2Nonzero(n);
    writer->Write(3, nbits);
    writer->Write(nbits, n - (1ULL << nbits));
  }
}

// Same as StoreVarLenUint8 for n in 0..65535, with a 4-bit exponent field.
template <typename Writer>
void StoreVarLenUint16(size_t n, Writer* writer) {
  JXL_DASSERT(n <= 65535);
  if (n == 0) {
    writer->Write(1, 0);
  } else {
    writer->Write(1, 1);
    size_t nbits = FloorLog2Nonzero(n);
    writer->Write(4, nbits);
    writer->Write(nbits, n - (1ULL << nbits));
  }
}

// Writes the header describing the normalized `counts`.
//   - With at most 2 symbols, the small-tree form is used: the symbol indices
//     and, for 2 symbols, the first symbol's count.
//   - Otherwise the output is, in order: the non-small and non-flat markers,
//     `shift`, `length - 3`, every logcount (static Huffman code above, with
//     RLE runs), and finally the retained precision bits of each count.
//   - The precision bits of the count at `omit_pos` are never written; only
//     its logcount (omit_log) is.
// Returns false if `length - 3` exceeds 255. In that case 255 is written
// anyway and the caller decides whether that matters.
template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  bool ok = true;
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    writer->Write(1, 1);
    if (num_symbols == 0) {
      writer->Write(1, 0);
      StoreVarLenUint8(0, writer);
    } else {
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
      }
    }
    if (num_symbols == 2) {
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    }
  } else {
    // Mark non-small tree.
    writer->Write(1, 0);
    // Mark non-flat histogram.
    writer->Write(1, 0);

    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    std::vector<uint32_t> same(alphabet_size, 0);
    int last = 0;
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
        last = i + 1;
      }
    }

    // logcounts[i] = floor(log2(counts[i])) + 1 for non-zero counts (0
    // otherwise). `length` is one past the last non-zero or omitted symbol.
    // omit_log is chosen from the logcounts of the other symbols.
    int length = 0;
    std::vector<int> logcounts(alphabet_size);
    int omit_log = 0;
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ENSURE(counts[i] <= ANS_TAB_SIZE);
      JXL_ENSURE(counts[i] >= 0);
      if (i == omit_pos) {
        length = i + 1;
      } else if (counts[i] > 0) {
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
        length = i + 1;
        if (i < omit_pos) {
          omit_log = std::max(omit_log, logcounts[i] + 1);
        } else {
          omit_log = std::max(omit_log, logcounts[i]);
        }
      }
    }
    logcounts[omit_pos] = omit_log;

    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    writer->Write(log, ((1 << log) - 1) & (shift + 1));

    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    // length - 3.
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      ok = false;
    } else {
      StoreVarLenUint8(length - 3, writer);
    }

    // The logcount values are encoded with a static Huffman code.
    static const size_t kMinReps = 4;
    size_t rep = ANS_LOG_TAB_SIZE + 1;
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        i += same[i - 1] - 2;
        continue;
      }
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    }
    // Second pass: the low-order (precision) bits of each count, skipping the
    // same RLE runs as above.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
        continue;
      }
      if (logcounts[i] > 1 && i != omit_pos) {
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        JXL_ENSURE((counts[i] & ((1 << drop_bits) - 1)) == 0);
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
      }
    }
  }
  return ok;
}

// Writes the header of a uniform histogram: the non-small and uniform
// markers, then `alphabet_size - 1` as a VarLenUint8.
void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
  // Mark non-small tree.
  writer->Write(1, 0);
  // Mark uniform histogram.
  writer->Write(1, 1);
  JXL_DASSERT(alphabet_size > 0);
  // Encode alphabet size.
  StoreVarLenUint8(alphabet_size - 1, writer);
}

// Estimated total bits (header + data) for coding `histogram` with `method`.
//   - method 0 is a flat histogram.
//   - Otherwise the counts are normalized with shift = method - 1, and the
//     header size is measured by running EncodeCounts on a SizeWriter.
StatusOr<float> ComputeHistoAndDataCost(const ANSHistBin* histogram,
                                        size_t alphabet_size, uint32_t method) {
  if (method == 0) {  // Flat code
    return ANS_LOG_TAB_SIZE + 2 +
           EstimateDataBitsFlat(histogram, alphabet_size);
  }
  // Non-flat: shift = method-1.
  uint32_t shift = method - 1;
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
  int omit_pos = 0;
  int num_symbols;
  int symbols[kMaxNumSymbolsForSmallCode] = {};
  JXL_RETURN_IF_ERROR(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
                                      ANS_LOG_TAB_SIZE, shift, &num_symbols,
                                      symbols));
  SizeWriter writer;
  // Ignore the correctness, no real encoding happens at this stage.
  (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
                     symbols, &writer);
  return writer.size +
         EstimateDataBits(histogram, counts.data(), alphabet_size);
}

// Compares the flat method against a strategy-dependent set of shifts:
//   - kPrecise: every shift in [0, ANS_LOG_TAB_SIZE];
//   - kApproximate: every other shift;
//   - kFast: 0, ANS_LOG_TAB_SIZE / 2 and ANS_LOG_TAB_SIZE.
// Returns the cheapest method (0 = flat, otherwise shift + 1) and stores its
// estimated cost in *cost. On ties the earlier candidate is kept.
StatusOr<uint32_t> ComputeBestMethod(
    const ANSHistBin* histogram, size_t alphabet_size, float* cost,
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
  uint32_t method = 0;
  JXL_ASSIGN_OR_RETURN(float fcost,
                       ComputeHistoAndDataCost(histogram, alphabet_size, 0));
  auto try_shift = [&](size_t shift) -> Status {
    JXL_ASSIGN_OR_RETURN(
        float c, ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1));
    if (c < fcost) {
      method = shift + 1;
      fcost = c;
    }
    return true;
  };
  switch (ans_histogram_strategy) {
    case HistogramParams::ANSHistogramStrategy::kPrecise: {
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
        JXL_RETURN_IF_ERROR(try_shift(shift));
      }
      break;
    }
    case HistogramParams::ANSHistogramStrategy::kApproximate: {
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
        JXL_RETURN_IF_ERROR(try_shift(shift));
      }
      break;
    }
    case HistogramParams::ANSHistogramStrategy::kFast: {
      JXL_RETURN_IF_ERROR(try_shift(0));
      JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE / 2));
      JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE));
      break;
    }
  };
  *cost = fcost;
  return method;
}

}  // namespace

// Returns an estimate of the cost of encoding this histogram and the
corresponding data. 442 StatusOr<size_t> BuildAndStoreANSEncodingData( 443 JxlMemoryManager* memory_manager, 444 HistogramParams::ANSHistogramStrategy ans_histogram_strategy, 445 const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size, 446 bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) { 447 if (use_prefix_code) { 448 size_t cost = 0; 449 if (alphabet_size <= 1) return 0; 450 std::vector<uint32_t> histo(alphabet_size); 451 for (size_t i = 0; i < alphabet_size; i++) { 452 histo[i] = histogram[i]; 453 JXL_ENSURE(histogram[i] >= 0); 454 } 455 { 456 std::vector<uint8_t> depths(alphabet_size); 457 std::vector<uint16_t> bits(alphabet_size); 458 if (writer == nullptr) { 459 BitWriter tmp_writer{memory_manager}; 460 JXL_RETURN_IF_ERROR(tmp_writer.WithMaxBits( 461 8 * alphabet_size + 8, // safe upper bound 462 LayerType::Header, /*aux_out=*/nullptr, [&] { 463 return BuildAndStoreHuffmanTree(histo.data(), alphabet_size, 464 depths.data(), bits.data(), 465 &tmp_writer); 466 })); 467 cost = tmp_writer.BitsWritten(); 468 } else { 469 size_t start = writer->BitsWritten(); 470 JXL_RETURN_IF_ERROR(BuildAndStoreHuffmanTree( 471 histo.data(), alphabet_size, depths.data(), bits.data(), writer)); 472 cost = writer->BitsWritten() - start; 473 } 474 for (size_t i = 0; i < alphabet_size; i++) { 475 info[i].bits = depths[i] == 0 ? 0 : bits[i]; 476 info[i].depth = depths[i]; 477 } 478 } 479 // Estimate data cost. 
480 for (size_t i = 0; i < alphabet_size; i++) { 481 cost += histogram[i] * info[i].depth; 482 } 483 return cost; 484 } 485 JXL_ENSURE(alphabet_size <= ANS_TAB_SIZE); 486 float fcost; 487 JXL_ASSIGN_OR_RETURN(uint32_t method, 488 ComputeBestMethod(histogram, alphabet_size, &fcost, 489 ans_histogram_strategy)); 490 JXL_ENSURE(fcost >= 0); 491 int num_symbols; 492 int symbols[kMaxNumSymbolsForSmallCode] = {}; 493 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size); 494 if (!counts.empty()) { 495 size_t sum = 0; 496 for (int count : counts) { 497 sum += count; 498 } 499 if (sum == 0) { 500 counts[0] = ANS_TAB_SIZE; 501 } 502 } 503 int omit_pos = 0; 504 uint32_t shift = method - 1; 505 if (method == 0) { 506 JXL_ENSURE(alphabet_size > 0); 507 counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE); 508 } else { 509 JXL_RETURN_IF_ERROR(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, 510 ANS_LOG_TAB_SIZE, shift, &num_symbols, 511 symbols)); 512 } 513 AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE]; 514 JXL_RETURN_IF_ERROR( 515 InitAliasTable(counts, ANS_LOG_TAB_SIZE, log_alpha_size, a)); 516 ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info); 517 if (writer != nullptr) { 518 if (method == 0) { 519 JXL_ENSURE(alphabet_size > 0); 520 EncodeFlatHistogram(alphabet_size, writer); 521 } else { 522 if (!EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, 523 method - 1, symbols, writer)) { 524 return JXL_FAILURE("EncodeCounts failed"); 525 } 526 } 527 } 528 return static_cast<size_t>(fcost); 529 } 530 531 StatusOr<float> ANSPopulationCost(const ANSHistBin* data, 532 size_t alphabet_size) { 533 float cost = 0.0f; 534 if (ANS_MAX_ALPHABET_SIZE < alphabet_size) { 535 return std::numeric_limits<float>::max(); 536 } 537 JXL_ASSIGN_OR_RETURN( 538 uint32_t method, 539 ComputeBestMethod(data, alphabet_size, &cost, 540 HistogramParams::ANSHistogramStrategy::kFast)); 541 (void)method; 542 return cost; 543 } 544 545 template 
<typename Writer> 546 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer, 547 size_t log_alpha_size) { 548 writer->Write(CeilLog2Nonzero(log_alpha_size + 1), 549 uint_config.split_exponent); 550 if (uint_config.split_exponent == log_alpha_size) { 551 return; // msb/lsb don't matter. 552 } 553 size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1); 554 writer->Write(nbits, uint_config.msb_in_token); 555 nbits = CeilLog2Nonzero(uint_config.split_exponent - 556 uint_config.msb_in_token + 1); 557 writer->Write(nbits, uint_config.lsb_in_token); 558 } 559 template <typename Writer> 560 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config, 561 Writer* writer, size_t log_alpha_size) { 562 // TODO(veluca): RLE? 563 for (const auto& cfg : uint_config) { 564 EncodeUintConfig(cfg, writer, log_alpha_size); 565 } 566 } 567 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&, 568 BitWriter*, size_t); 569 570 namespace { 571 572 Status ChooseUintConfigs(const HistogramParams& params, 573 const std::vector<std::vector<Token>>& tokens, 574 const std::vector<uint8_t>& context_map, 575 std::vector<Histogram>* clustered_histograms, 576 EntropyEncodingData* codes, size_t* log_alpha_size) { 577 codes->uint_config.resize(clustered_histograms->size()); 578 if (params.uint_method == HistogramParams::HybridUintMethod::kNone) { 579 return true; 580 } 581 if (params.uint_method == HistogramParams::HybridUintMethod::k000) { 582 codes->uint_config.clear(); 583 codes->uint_config.resize(clustered_histograms->size(), 584 HybridUintConfig(0, 0, 0)); 585 return true; 586 } 587 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) { 588 codes->uint_config.clear(); 589 codes->uint_config.resize(clustered_histograms->size(), 590 HybridUintConfig(2, 0, 1)); 591 return true; 592 } 593 594 // If the uint config is adaptive, just stick with the default in streaming 595 // mode. 
596 if (params.streaming_mode) { 597 return true; 598 } 599 600 // Brute-force method that tries a few options. 601 std::vector<HybridUintConfig> configs; 602 if (params.uint_method == HistogramParams::HybridUintMethod::kBest) { 603 configs = { 604 HybridUintConfig(4, 2, 0), // default 605 HybridUintConfig(4, 1, 0), // less precise 606 HybridUintConfig(4, 2, 1), // add sign 607 HybridUintConfig(4, 2, 2), // add sign+parity 608 HybridUintConfig(4, 1, 2), // add parity but less msb 609 // Same as above, but more direct coding. 610 HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0), 611 HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2), 612 HybridUintConfig(5, 1, 2), 613 // Same as above, but less direct coding. 614 HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0), 615 HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2), 616 // For near-lossless. 617 HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4), 618 HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5), 619 HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0), 620 // Other 621 HybridUintConfig(0, 0, 0), // varlenuint 622 HybridUintConfig(2, 0, 1), // works well for ctx map 623 HybridUintConfig(7, 0, 0), // direct coding 624 HybridUintConfig(8, 0, 0), // direct coding 625 HybridUintConfig(9, 0, 0), // direct coding 626 HybridUintConfig(10, 0, 0), // direct coding 627 HybridUintConfig(11, 0, 0), // direct coding 628 HybridUintConfig(12, 0, 0), // direct coding 629 }; 630 } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) { 631 configs = { 632 HybridUintConfig(4, 2, 0), // default 633 HybridUintConfig(4, 1, 2), // add parity but less msb 634 HybridUintConfig(0, 0, 0), // smallest histograms 635 HybridUintConfig(2, 0, 1), // works well for ctx map 636 }; 637 } 638 639 std::vector<float> costs(clustered_histograms->size(), 640 std::numeric_limits<float>::max()); 641 std::vector<uint32_t> extra_bits(clustered_histograms->size()); 642 std::vector<uint8_t> 
is_valid(clustered_histograms->size()); 643 size_t max_alpha = 644 codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE; 645 for (HybridUintConfig cfg : configs) { 646 std::fill(is_valid.begin(), is_valid.end(), true); 647 std::fill(extra_bits.begin(), extra_bits.end(), 0); 648 649 for (auto& histo : *clustered_histograms) { 650 histo.Clear(); 651 } 652 for (const auto& stream : tokens) { 653 for (const auto& token : stream) { 654 // TODO(veluca): do not ignore lz77 commands. 655 if (token.is_lz77_length) continue; 656 size_t histo = context_map[token.context]; 657 uint32_t tok, nbits, bits; 658 cfg.Encode(token.value, &tok, &nbits, &bits); 659 if (tok >= max_alpha || 660 (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) { 661 is_valid[histo] = JXL_FALSE; 662 continue; 663 } 664 extra_bits[histo] += nbits; 665 (*clustered_histograms)[histo].Add(tok); 666 } 667 } 668 669 for (size_t i = 0; i < clustered_histograms->size(); i++) { 670 if (!is_valid[i]) continue; 671 JXL_ASSIGN_OR_RETURN(float cost, 672 (*clustered_histograms)[i].PopulationCost()); 673 cost += extra_bits[i]; 674 // add signaling cost of the hybriduintconfig itself 675 cost += CeilLog2Nonzero(cfg.split_exponent + 1); 676 cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1); 677 if (cost < costs[i]) { 678 codes->uint_config[i] = cfg; 679 costs[i] = cost; 680 } 681 } 682 } 683 684 // Rebuild histograms. 685 for (auto& histo : *clustered_histograms) { 686 histo.Clear(); 687 } 688 *log_alpha_size = 4; 689 for (const auto& stream : tokens) { 690 for (const auto& token : stream) { 691 uint32_t tok, nbits, bits; 692 size_t histo = context_map[token.context]; 693 (token.is_lz77_length ? codes->lz77.length_uint_config 694 : codes->uint_config[histo]) 695 .Encode(token.value, &tok, &nbits, &bits); 696 tok += token.is_lz77_length ? 
codes->lz77.min_symbol : 0; 697 (*clustered_histograms)[histo].Add(tok); 698 while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++; 699 } 700 } 701 size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8; 702 JXL_ENSURE(*log_alpha_size <= max_log_alpha_size); 703 return true; 704 } 705 706 Histogram HistogramFromSymbolInfo( 707 const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) { 708 Histogram histo; 709 histo.data_.resize(DivCeil(encoding_info.size(), Histogram::kRounding) * 710 Histogram::kRounding); 711 histo.total_count_ = 0; 712 for (size_t i = 0; i < encoding_info.size(); ++i) { 713 const ANSEncSymbolInfo& info = encoding_info[i]; 714 int count = use_prefix_code 715 ? (info.depth ? (1u << (PREFIX_MAX_BITS - info.depth)) : 0) 716 : info.freq_; 717 histo.data_[i] = count; 718 histo.total_count_ += count; 719 } 720 return histo; 721 } 722 723 class HistogramBuilder { 724 public: 725 explicit HistogramBuilder(const size_t num_contexts) 726 : histograms_(num_contexts) {} 727 728 void VisitSymbol(int symbol, size_t histo_idx) { 729 JXL_DASSERT(histo_idx < histograms_.size()); 730 histograms_[histo_idx].Add(symbol); 731 } 732 733 // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge. 734 // Returns cost (in bits). 
735 StatusOr<size_t> BuildAndStoreEntropyCodes( 736 JxlMemoryManager* memory_manager, const HistogramParams& params, 737 const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes, 738 std::vector<uint8_t>* context_map, BitWriter* writer, LayerType layer, 739 AuxOut* aux_out) const { 740 const size_t prev_histograms = codes->encoding_info.size(); 741 std::vector<Histogram> clustered_histograms; 742 for (size_t i = 0; i < prev_histograms; ++i) { 743 clustered_histograms.push_back(HistogramFromSymbolInfo( 744 codes->encoding_info[i], codes->use_prefix_code)); 745 } 746 size_t context_offset = context_map->size(); 747 context_map->resize(context_offset + histograms_.size()); 748 if (histograms_.size() > 1) { 749 if (!ans_fuzzer_friendly_) { 750 std::vector<uint32_t> histogram_symbols; 751 JXL_RETURN_IF_ERROR( 752 ClusterHistograms(params, histograms_, kClustersLimit, 753 &clustered_histograms, &histogram_symbols)); 754 for (size_t c = 0; c < histograms_.size(); ++c) { 755 (*context_map)[context_offset + c] = 756 static_cast<uint8_t>(histogram_symbols[c]); 757 } 758 } else { 759 JXL_ENSURE(codes->encoding_info.empty()); 760 fill(context_map->begin(), context_map->end(), 0); 761 size_t max_symbol = 0; 762 for (const Histogram& h : histograms_) { 763 max_symbol = std::max(h.data_.size(), max_symbol); 764 } 765 size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1); 766 clustered_histograms.resize(1); 767 clustered_histograms[0].Clear(); 768 for (size_t i = 0; i < num_symbols; i++) { 769 clustered_histograms[0].Add(i); 770 } 771 } 772 if (writer != nullptr) { 773 JXL_RETURN_IF_ERROR(EncodeContextMap( 774 *context_map, clustered_histograms.size(), writer, layer, aux_out)); 775 } 776 } else { 777 JXL_ENSURE(codes->encoding_info.empty()); 778 clustered_histograms.push_back(histograms_[0]); 779 } 780 if (aux_out != nullptr) { 781 for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) { 782 aux_out->layer(layer).clustered_entropy += 783 
clustered_histograms[i].ShannonEntropy(); 784 } 785 } 786 size_t log_alpha_size = codes->lz77.enabled ? 8 : 7; // Sane default. 787 if (ans_fuzzer_friendly_) { 788 codes->uint_config.clear(); 789 codes->uint_config.resize(1, HybridUintConfig(7, 0, 0)); 790 } else { 791 JXL_RETURN_IF_ERROR(ChooseUintConfigs(params, tokens, *context_map, 792 &clustered_histograms, codes, 793 &log_alpha_size)); 794 } 795 if (log_alpha_size < 5) log_alpha_size = 5; 796 if (params.streaming_mode) { 797 // TODO(szabadka) Figure out if we can use lower values here. 798 log_alpha_size = 8; 799 } 800 SizeWriter size_writer; // Used if writer == nullptr to estimate costs. 801 size_t cost = 1; 802 if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code)); 803 804 if (codes->use_prefix_code) { 805 log_alpha_size = PREFIX_MAX_BITS; 806 } else { 807 cost += 2; 808 } 809 if (writer == nullptr) { 810 EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size); 811 } else { 812 if (!codes->use_prefix_code) writer->Write(2, log_alpha_size - 5); 813 EncodeUintConfigs(codes->uint_config, writer, log_alpha_size); 814 } 815 if (codes->use_prefix_code) { 816 for (const auto& histo : clustered_histograms) { 817 size_t alphabet_size = histo.alphabet_size(); 818 if (writer) { 819 StoreVarLenUint16(alphabet_size - 1, writer); 820 } else { 821 StoreVarLenUint16(alphabet_size - 1, &size_writer); 822 } 823 } 824 } 825 cost += size_writer.size; 826 for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) { 827 size_t alphabet_size = clustered_histograms[c].alphabet_size(); 828 codes->encoding_info.emplace_back(); 829 codes->encoding_info.back().resize(alphabet_size); 830 BitWriter* histo_writer = writer; 831 if (params.streaming_mode) { 832 codes->encoded_histograms.emplace_back(memory_manager); 833 histo_writer = &codes->encoded_histograms.back(); 834 } 835 const auto& body = [&]() -> Status { 836 JXL_ASSIGN_OR_RETURN( 837 size_t ans_cost, 838 BuildAndStoreANSEncodingData( 839 
                memory_manager, params.ans_histogram_strategy,
                clustered_histograms[c].data_.data(), alphabet_size,
                log_alpha_size, codes->use_prefix_code,
                codes->encoding_info.back().data(), histo_writer));
        cost += ans_cost;
        return true;
      };
      if (histo_writer) {
        JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
            256 + alphabet_size * 24, layer, aux_out, body,
            /*finished_histogram=*/true));
      } else {
        JXL_RETURN_IF_ERROR(body());
      }
      if (params.streaming_mode) {
        JXL_RETURN_IF_ERROR(writer->AppendUnaligned(*histo_writer));
      }
    }
    return cost;
  }

  // Read-only access to the histogram of context `i` (no bounds check).
  const Histogram& Histo(size_t i) const { return histograms_[i]; }

 private:
  std::vector<Histogram> histograms_;
};

// Static per-context bit-cost model used by the LZ77 heuristics below.
//
// The constructor builds one histogram per context from `tokens`. LZ77 length
// tokens are mapped with `lz77.length_uint_config` and offset by
// `lz77.min_symbol`; all other tokens use a default HybridUintConfig. Each
// symbol is then assigned an estimated cost of -log2(count / total) bits in its
// context, with these special cases:
//   - count == 0: ANS_LOG_TAB_SIZE bits (the highest cost used here);
//   - count == total (the only symbol of its context): 0 bits;
//   - force_huffman: costs are rounded up to whole bits.
class SymbolCostEstimator {
 public:
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
                      const std::vector<std::vector<Token>>& tokens,
                      const LZ77Params& lz77) {
    HistogramBuilder builder(num_contexts);
    // Build histograms for estimating lz77 savings.
    HybridUintConfig uint_config;
    for (const auto& stream : tokens) {
      for (const auto& token : stream) {
        uint32_t tok, nbits, bits;
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
            .Encode(token.value, &tok, &nbits, &bits);
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
        builder.VisitSymbol(tok, token.context);
      }
    }
    // All contexts share one row stride: the largest histogram size seen.
    max_alphabet_size_ = 0;
    for (size_t i = 0; i < num_contexts; i++) {
      max_alphabet_size_ =
          std::max(max_alphabet_size_, builder.Histo(i).data_.size());
    }
    // Row-major [context][symbol] table. resize() zero-fills, so symbols past a
    // context's own histogram size (but below max_alphabet_size_) keep cost 0.
    bits_.resize(num_contexts * max_alphabet_size_);
    // TODO(veluca): SIMD?
    add_symbol_cost_.resize(num_contexts);
    for (size_t i = 0; i < num_contexts; i++) {
      // The 1e-8f avoids a division by zero for contexts with no symbols.
      float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
      float total_cost = 0;
      for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
        size_t cnt = builder.Histo(i).data_[j];
        float cost = 0;
        if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
          cost = -FastLog2f(cnt * inv_total);
          if (force_huffman) cost = std::ceil(cost);
        } else if (cnt == 0) {
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
        }
        bits_[i * max_alphabet_size_ + j] = cost;
        total_cost += cost * builder.Histo(i).data_[j];
      }
      // Penalty for adding a lz77 symbol to this context (only used for static
      // cost model). Higher penalty for contexts that have a very low
      // per-symbol entropy: max(0, 6 - average bits per symbol).
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
    }
  }
  // Estimated cost in bits of token `sym` in context `ctx`.
  // NOTE(review): no bounds check. `sym` is assumed to be < max_alphabet_size_;
  // a larger value reads the next context's row. Symbols never seen in `ctx`
  // but below max_alphabet_size_ cost 0 rather than ANS_LOG_TAB_SIZE (see the
  // zero-fill above). Confirm this is intended for the cost model.
  float Bits(size_t ctx, size_t sym) const {
    return bits_[ctx * max_alphabet_size_ + sym];
  }
  // Cost of an LZ77 length token in context `ctx`. `len` is the value to
  // encode; callers in this file pass (match length - min_length). The result
  // is the raw extra bits plus the entropy-coded token cost.
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
    uint32_t nbits, bits, tok;
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
    tok += lz77.min_symbol;
    return nbits + Bits(ctx, tok);
  }
  // Cost of a distance symbol (the parameter is named `len` but receives a
  // distance symbol). It is encoded with the default HybridUintConfig in
  // context `lz77.nonserialized_distance_context`.
  float DistCost(size_t len, const LZ77Params& lz77) const {
    uint32_t nbits, bits, tok;
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
  }
  // Static penalty for introducing LZ77 length tokens into context `idx`.
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }

 private:
  size_t max_alphabet_size_;
  std::vector<float> bits_;
  std::vector<float> add_symbol_cost_;
};

// Run-length-only LZ77 pass.
//
// For each stream, a run of tokens equal in value to the token just before
// the run is replaced by two tokens:
//   - an LZ77 length token (value run_length - lz77.min_length) in the run's
//     first context;
//   - a distance token in `lz77.nonserialized_distance_context`. Its symbol is
//     0, or special-distance symbol 1 when the stream has an image width.
// A run is replaced only when it is at least `lz77.min_length` long and the
// estimated literal cost is strictly larger than the estimated length-token
// cost. Otherwise the run's tokens are copied through unchanged.
//
// Sets `lz77.enabled` when the total estimated saving exceeds
// 0.2 bits per input token + 16 bits. It never clears the flag.
void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
                   const std::vector<std::vector<Token>>& tokens,
                   LZ77Params& lz77,
                   std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  std::vector<float> sym_cost;
  HybridUintConfig uint_config;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs: sym_cost[k] is the estimated cost of
    // in[0..k). sym_cost[0] is never written and stays 0 from the first resize.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }
    out.reserve(in.size());
    for (size_t i = 0; i < in.size(); i++) {
      size_t num_to_copy = 0;
      size_t distance_symbol = 0;  // 1 for RLE.
      if (distance_multiplier != 0) {
        distance_symbol = 1;  // Special distance 1 if enabled.
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
      }
      // Length of the run starting at i that repeats the value of in[i - 1].
      if (i > 0) {
        for (; i + num_to_copy < in.size(); num_to_copy++) {
          if (in[i + num_to_copy].value != in[i - 1].value) {
            break;
          }
        }
      }
      if (num_to_copy == 0) {
        out.push_back(in[i]);
        continue;
      }
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
      // This subtraction wraps around when num_to_copy < lz77.min_length. That
      // is harmless: lz77_len is only used after that case is rejected below.
      size_t lz77_len = num_to_copy - lz77.min_length;
      // Rough length-token cost; the distance token is not counted here.
      float lz77_cost = num_to_copy >= lz77.min_length
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
                            : 0;
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
        // Not worth it: copy the run through as literals.
        for (size_t j = 0; j < num_to_copy; j++) {
          out.push_back(in[i + j]);
        }
        i += num_to_copy - 1;
        continue;
      }
      // Output the LZ77 length
      out.emplace_back(in[i].context, lz77_len);
      out.back().is_lz77_length = true;
      i += num_to_copy - 1;
      bit_decrease += cost - lz77_cost;
      // Output the LZ77 copy distance.
      out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}

// Hash chain for LZ77 matching over the token values of one stream.
//
// Positions are inserted with Update() and queried with FindMatches() or
// FindMatch(). The structure keeps two chains:
//   - head / chain / val: a chained hash table keyed on a hash of three
//     consecutive values (see GetHash());
//   - headz / chainz / zeros: a second chain keyed on the length of the zero
//     run starting at each position, used to speed up long runs of zeros.
// Per-position arrays are indexed by `pos & window_mask_`. This assumes
// window_size is a power of two; callers in this file round it up to one.
struct HashChain {
  size_t size_;
  std::vector<uint32_t> data_;

  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;

  // head[hash]: most recent window position with that hash (-1 if none).
  std::vector<int> head;
  // chain[wpos]: previous window position with the same hash.
  std::vector<uint32_t> chain;
  // val[wpos]: hash stored for wpos, used to detect outdated chain entries.
  std::vector<int> val;

  // Speed up repetitions of zero
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  // Zero-run length at the position most recently passed to Update().
  uint32_t numzeros = 0;

  size_t window_size_;
  size_t window_mask_;
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes: distance -> special distance symbol.
  // Only populated when a nonzero distance multiplier is given.
  std::unordered_map<int, int> special_dist_table_;
  size_t num_special_distances_ = 0;

  // Maximum number of chain entries visited per FindMatches() call.
  uint32_t maxchainlength = 256;  // window_size_ to allow all

  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    // Only token values take part in matching; contexts are ignored.
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    // Zero-run lengths are capped at window_size_, hence window_size_ + 1.
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  // Hash of data_[pos], data_[pos + 1] and data_[pos + 2], masked to
  // hash_num_values_ buckets. Positions with fewer than three values left hash
  // to 0.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of last 2 bytes, the length 2 is too short.
      return 0;
    }
    return result & hash_mask_;
  }

  // Length of the zero run starting at `pos`, capped at window_size_ and at
  // the end of the data. `prevzeros` is the run length at pos - 1 (0 when
  // unknown or reset). When it is nonzero, the result is derived from it
  // instead of rescanning:
  //   - kept unchanged if the run already spans (almost) the whole window and
  //     the window starting at `pos` still ends in a zero;
  //   - otherwise decremented by one.
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  // Inserts `pos` into both chains. The zero-run count is carried over from
  // the previous call, so positions are expected to be inserted consecutively
  // (as all callers in this file do).
  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    val[wpos] = static_cast<int>(hashval);
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  // Inserts positions [pos, pos + len).
  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  // Walks the chain for `pos` and reports each candidate match through
  // `found_match(len, dist_symbol)`.
  //
  // A candidate is reported when its length is >= min_length_ and at most 2
  // shorter than the best length found so far. Its dist_symbol is the special
  // distance code if `dist` is in special_dist_table_, otherwise
  // num_special_distances_ + dist - 1.
  //
  // `pos` should be the position most recently passed to Update(), because
  // the zero-run state `numzeros` comes from that call.
  //
  // The walk stops when:
  //   - the distance stops increasing (the chain wrapped past the window);
  //   - maxchainlength entries have been visited;
  //   - the chain ends;
  //   - a stale hash value is found.
  //
  // NOTE(review): `max_dist` is not used; distances are bounded only by the
  // window.
  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      // Distance from the candidate to pos, modulo the window size.
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        // Inside a zero run, skip the prefix known to be zero on both sides.
        if (numzeros > 3) {
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      // Follow the zero-run chain when the match extends past the zero run;
      // otherwise follow the hash chain.
      if (numzeros >= 3 && len > numzeros) {
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != static_cast<int>(hashval)) {
          // outdated hash value
          break;
        }
      }
    }
  }
  // Longest match at `pos`; ties are broken by the smaller distance symbol.
  // If no match is reported, *result_len is 1 and *result_dist_symbol is 0.
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};

// Fixed cost estimate, in bits, for an LZ77 length value.
// It is a per-token table entry under HybridUintConfig(1, 0, 0) plus the raw
// extra bits. Tokens past the end of the table use its last entry. The table
// values are presumably fitted empirically; TODO confirm their provenance.
float LenCost(size_t len) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      2.797667318563126, 3.213177690381199, 2.5706009246743737,
      2.408392498667534, 2.829649191872326, 3.3923087753324577,
      4.029267451554331, 4.415576699706408, 4.509357574741465,
      9.21481543803004, 10.020590190114898, 11.858671627804766,
      12.45853300490526, 11.713105831990857, 12.561996324849314,
      13.775477692278367, 13.174027068768641,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}

// Fixed cost estimate, in bits, for a distance symbol.
// It is a per-token table entry under HybridUintConfig(7, 0, 0) plus the raw
// extra bits. Tokens past the end of the table use its last entry.
// TODO(veluca): this does not take into account usage or non-usage of distance
// multipliers.
float DistCost(size_t dist) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      6.368282626312716, 5.680793277090298, 8.347404197105247,
      7.641619201599141, 6.914328374119438, 7.959808291537444,
      8.70023120759855, 8.71378518934703, 9.379132523982769,
      9.110472749092708, 9.159029569270908, 9.430936766731973,
      7.278284055315169, 7.8278514904267755, 10.026641158289236,
      9.976049229827066, 9.64351607048908, 9.563403863480442,
      10.171474111762747, 10.45950155077234, 9.994813912104219,
      10.322524683741156, 8.465808729388186, 8.756254166066853,
      10.160930174662234, 10.247329273413435, 10.04090403724809,
      10.129398517544082, 9.342311691539546, 9.07608009102374,
      10.104799540677513, 10.378079384990906, 10.165828974075072,
      10.337595322341553, 7.940557464567944, 10.575665823319431,
      11.023344321751955, 10.736144698831827, 11.118277044595054,
      7.468468230648442, 10.738305230932939, 10.906980780216568,
      10.163468216353817, 10.17805759656433, 11.167283670483565,
      11.147050200274544, 10.517921919244333, 10.651764778156886,
      10.17074446448919, 11.217636876224745, 11.261630721139484,
      11.403140815247259, 10.892472096873417, 11.1859607804481,
      8.017346947551262, 7.895143720278828, 11.036577113822025,
      11.170562110315794, 10.326988722591086, 10.40872184751056,
      11.213498225466386, 11.30580635516863, 10.672272515665442,
      10.768069466228063, 11.145257364153565, 11.64668307145549,
      10.593156194627339, 11.207499484844943, 10.767517766396908,
      10.826629811407042, 10.737764794499988, 10.6200448518045,
      10.191315385198092, 8.468384171390085, 11.731295299170432,
      11.824619886654398, 10.41518844301179, 10.16310536548649,
      10.539423685097576, 10.495136599328031, 10.469112847728267,
      11.72057686174922, 10.910326337834674, 11.378921834673758,
      11.847759036098536, 11.92071647623854, 10.810628276345282,
      11.008601085273893, 11.910326337834674, 11.949212023423133,
      11.298614839104337, 11.611603659010392, 10.472930394619985,
      11.835564720850282, 11.523267392285337, 12.01055816679611,
      8.413029688994023, 11.895784139536406, 11.984679534970505,
      11.220654278717394, 11.716311684833672, 10.61036646226114,
      10.89849965960364, 10.203762898863669, 10.997560826267238,
      11.484217379438984, 11.792836176993665, 12.24310468755171,
      11.464858097919262, 12.212747017409377, 11.425595666074955,
      11.572048533398757, 12.742093965163013, 11.381874288645637,
      12.191870445817015, 11.683156920035426, 11.152442115262197,
      11.90303691580457, 11.653292787169159, 11.938615382266098,
      16.970641701570223, 16.853602280380002, 17.26240782594733,
      16.644655390108507, 17.14310889757499, 16.910935455445955,
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
      3.494587244462329, 3.5258600986408344, 3.4959806589517095,
      3.098390886949687, 3.343454654302911, 3.588847442290287,
      4.14614790111827, 5.152948641990529, 7.433696808092598,
      9.716311684833672,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}

// Greedy LZ77 pass with one-step lazy matching.
//
// At each position the longest HashChain match is found. If that match is
// shorter than max_lazy_match_len, the match at the next position is also
// tried and used instead when strictly longer; the current token is then
// emitted as a literal.
//
// A match replaces the copied tokens when
//   LenCost + DistCost + the context's AddSymbolCost
// is at most their estimated literal cost. The length token reuses the
// context of the first copied token.
//
// Sets `lz77.enabled` under the same threshold as ApplyLZ77_RLE and never
// clears it.
void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
                    const std::vector<std::vector<Token>>& tokens,
                    LZ77Params& lz77,
                    std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs: sym_cost[k] is the estimated cost of
    // in[0..k).
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    // Passed to FindMatch() as max_dist, which FindMatches() does not read.
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_DASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size (capped at kWindowSize).
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len;
    size_t dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      // Tentatively emit in[i] as a literal; it becomes the length token if a
      // match is accepted below.
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Add literal, and use the next length starting
            // from the next byte.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          // Turn the tentative literal into the length token, then add the
          // distance token.
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        // Insert the remaining covered positions so the chain stays
        // consecutive (position i + 1 may already be in from lazy matching).
        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}

// Cost-optimal LZ77 parse via dynamic programming.
//
// First runs ApplyLZ77_LZ77. If that does not enable LZ77, it returns with
// `tokens_lz77` as left by the greedy pass. Otherwise a SymbolCostEstimator
// is built from the greedy output. It uses num_contexts + 1 contexts, so that
// the distance context is included.
//
// Per stream, prefix_costs[k] holds the cheapest way found to encode in[0..k),
// either as a literal step or an LZ77 step. The parse is then reconstructed
// backwards from the end.
//
// Inside long RLE runs, match search is skipped for most positions to avoid
// quadratic work.
void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
                       const std::vector<std::vector<Token>>& tokens,
                       LZ77Params& lz77,
                       std::vector<std::vector<Token>>& tokens_lz77) {
  std::vector<std::vector<Token>> tokens_for_cost_estimate;
  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
  // run the optimal matching.
  if (!lz77.enabled) return;
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
                          tokens_for_cost_estimate, lz77);
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  std::vector<uint32_t> dist_symbols;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    // Cumulative sum of bit costs: sym_cost[k] is the estimated cost of
    // in[0..k).
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_DASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size (capped at kWindowSize).
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);

    // Last step of the best parse ending at a given prefix length.
    struct MatchInfo {
      uint32_t len;          // Tokens covered by this step (1 for a literal).
      uint32_t dist_symbol;  // 0 for a literal, else distance symbol + 1.
      uint32_t ctx;          // Context of the emitted literal / length token.
      float total_cost = std::numeric_limits<float>::max();
    };
    // Total cost to encode the first N symbols.
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
    prefix_costs[0].total_cost = 0;

    size_t rle_length = 0;
    size_t skip_lz77 = 0;
    for (size_t i = 0; i < in.size(); i++) {
      chain.Update(i);
      // Relax the literal step i -> i + 1.
      float lit_cost =
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
      if (prefix_costs[i + 1].total_cost > lit_cost) {
        prefix_costs[i + 1].dist_symbol = 0;
        prefix_costs[i + 1].len = 1;
        prefix_costs[i + 1].ctx = in[i].context;
        prefix_costs[i + 1].total_cost = lit_cost;
      }
      if (skip_lz77 > 0) {
        skip_lz77--;
        continue;
      }
      // dist_symbols[len]: smallest distance symbol reported for a match of
      // length len. Growing the vector fills the new (shorter) lengths with
      // the same symbol, since a match can be truncated.
      dist_symbols.clear();
      chain.FindMatches(i, max_distance,
                        [&dist_symbols](size_t len, size_t dist_symbol) {
                          if (dist_symbols.size() <= len) {
                            dist_symbols.resize(len + 1, dist_symbol);
                          }
                          if (dist_symbol < dist_symbols[len]) {
                            dist_symbols[len] = dist_symbol;
                          }
                        });
      if (dist_symbols.size() <= min_length) continue;
      {
        // Suffix minimum: afterwards dist_symbols[j] is the smallest symbol of
        // any match of length >= j. `best_cost` holds a distance symbol.
        size_t best_cost = dist_symbols.back();
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
          if (dist_symbols[j] < best_cost) {
            best_cost = dist_symbols[j];
          }
          dist_symbols[j] = best_cost;
        }
      }
      // Relax every LZ77 step i -> i + j of length j >= min_length.
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
        // Cost model that uses results from lazy LZ77.
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
                          sce.DistCost(dist_symbols[j], lz77);
        float cost = prefix_costs[i].total_cost + lz77_cost;
        if (prefix_costs[i + j].total_cost > cost) {
          prefix_costs[i + j].len = j;
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
          prefix_costs[i + j].ctx = in[i].context;
          prefix_costs[i + j].total_cost = cost;
        }
      }
      // We are in a RLE sequence: skip all the symbols except the first 8 and
      // the last 8. This avoids quadratic costs for sequences with long runs
      // of the same symbol.
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
        rle_length++;
      } else {
        rle_length = 0;
      }
      if (rle_length >= 8 && dist_symbols.size() > 9) {
        skip_lz77 = dist_symbols.size() - 10;
        rle_length = 0;
      }
    }
    // Backtrack from the end. Tokens are emitted in reverse (distance before
    // length), and the whole stream is reversed afterwards.
    size_t pos = in.size();
    while (pos > 0) {
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
      if (is_lz77_length) {
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
      }
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
                                  : in[pos - 1].value;
      out.emplace_back(prefix_costs[pos].ctx, val);
      out.back().is_lz77_length = is_lz77_length;
      pos -= prefix_costs[pos].len;
    }
    std::reverse(out.begin(), out.end());
  }
}

// Runs the LZ77 method selected by `params.lz77_method`.
// - Resets `lz77.enabled` only when initializing global state.
// - Sets `lz77.min_symbol`: min(PREFIX_MAX_ALPHABET_SIZE - 32, 512) when
//   force_huffman is set, 224 otherwise.
void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
               const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
               std::vector<std::vector<Token>>& tokens_lz77) {
  if (params.initialize_global_state) {
    lz77.enabled = false;
  }
  if (params.force_huffman) {
    lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
  } else {
    lz77.min_symbol = 224;
  }
  switch (params.lz77_method) {
    case HistogramParams::LZ77Method::kNone:
      return;
    case HistogramParams::LZ77Method::kRLE:
      ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
      return;
    case HistogramParams::LZ77Method::kLZ77:
      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
      return;
    case HistogramParams::LZ77Method::kOptimal:
      ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
      return;
  }
}
}  // namespace

// Writes the entropy-code header for already-built `codes`, in this order:
//   1. LZ77 params, plus the length uint config when LZ77 is enabled;
//   2. the context map;
//   3. the prefix-code flag;
//   4. for ANS only, the log alphabet size (always 8 here);
//   5. the per-cluster uint configs;
//   6. for prefix codes only, the alphabet sizes;
//   7. the pre-encoded histograms.
Status EncodeHistograms(const std::vector<uint8_t>& context_map,
                        const EntropyEncodingData& codes, BitWriter* writer,
                        LayerType layer, AuxOut* aux_out) {
  return writer->WithMaxBits(
      128 + kClustersLimit * 136, layer, aux_out,
      [&]() -> Status {
        JXL_RETURN_IF_ERROR(Bundle::Write(codes.lz77, writer, layer, aux_out));
        if (codes.lz77.enabled) {
          EncodeUintConfig(codes.lz77.length_uint_config, writer,
                           /*log_alpha_size=*/8);
        }
        JXL_RETURN_IF_ERROR(EncodeContextMap(
            context_map, codes.encoding_info.size(), writer, layer, aux_out));
        writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
        size_t log_alpha_size = 8;
        if (codes.use_prefix_code) {
          log_alpha_size = PREFIX_MAX_BITS;
        } else {
          log_alpha_size = 8;  // streaming_mode
          writer->Write(2, log_alpha_size - 5);
        }
        EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);
        if (codes.use_prefix_code) {
          for (const auto& info : codes.encoding_info) {
            StoreVarLenUint16(info.size() - 1, writer);
          }
        }
        for (const auto& histo_writer : codes.encoded_histograms) {
          JXL_RETURN_IF_ERROR(writer->AppendUnaligned(histo_writer));
        }
        return true;
      },
      /*finished_histogram=*/true);
}

// Applies LZ77, builds per-context histograms, and builds/stores the clustered
// entropy codes into `codes` and `context_map`.
//
// Prefix vs ANS: decided only when initializing global state. Prefix codes are
// used if any of these hold:
//   - force_huffman is set;
//   - there are fewer than 100 tokens;
//   - clustering is kFastest;
//   - fuzzer-friendly mode is on;
//   - every context has Shannon entropy < 1e-5.
//
// Writing: when `writer` is null, nothing is written and the parts are only
// costed.
//
// Returns the sum of the costs accumulated below, in bits.
//
// NOTE: when LZ77 ends up enabled, the caller's `tokens` are replaced by the
// LZ77-transformed streams (std::move inside `body`).
StatusOr<size_t> BuildAndEncodeHistograms(
    JxlMemoryManager* memory_manager, const HistogramParams& params,
    size_t num_contexts, std::vector<std::vector<Token>>& tokens,
    EntropyEncodingData* codes, std::vector<uint8_t>* context_map,
    BitWriter* writer, LayerType layer, AuxOut* aux_out) {
  size_t cost = 0;
  // The LZ77 distance tokens get one extra context past the regular ones.
  codes->lz77.nonserialized_distance_context = num_contexts;
  std::vector<std::vector<Token>> tokens_lz77;
  ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
  if (ans_fuzzer_friendly_) {
    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
    codes->lz77.min_symbol = 2048;
  }

  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
  // NOTE(review): `body` increments num_contexts and moves from tokens_lz77,
  // so it assumes WithMaxBits() invokes it exactly once.
  const auto& body = [&]() -> Status {
    if (writer) {
      JXL_RETURN_IF_ERROR(Bundle::Write(codes->lz77, writer, layer, aux_out));
    } else {
      size_t ebits, bits;
      JXL_RETURN_IF_ERROR(Bundle::CanEncode(codes->lz77, &ebits, &bits));
      cost += bits;
    }
    if (codes->lz77.enabled) {
      if (writer) {
        size_t b = writer->BitsWritten();
        EncodeUintConfig(codes->lz77.length_uint_config, writer,
                         /*log_alpha_size=*/8);
        cost += writer->BitsWritten() - b;
      } else {
        SizeWriter size_writer;
        EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
                         /*log_alpha_size=*/8);
        cost += size_writer.size;
      }
      // Account for the distance context and switch to the LZ77 tokens.
      num_contexts += 1;
      tokens = std::move(tokens_lz77);
    }
    size_t total_tokens = 0;
    // Build histograms.
    HistogramBuilder builder(num_contexts);
    HybridUintConfig uint_config;  //  Default config for clustering.
    // Unless we are using the kContextMap histogram option.
    if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
      uint_config = HybridUintConfig(2, 0, 1);
    }
    if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
      uint_config = HybridUintConfig(0, 0, 0);
    }
    if (ans_fuzzer_friendly_) {
      uint_config = HybridUintConfig(10, 0, 0);
    }
    for (const auto& stream : tokens) {
      if (codes->lz77.enabled) {
        for (const auto& token : stream) {
          total_tokens++;
          uint32_t tok, nbits, bits;
          (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
              .Encode(token.value, &tok, &nbits, &bits);
          tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
          builder.VisitSymbol(tok, token.context);
        }
      } else if (num_contexts == 1) {
        for (const auto& token : stream) {
          total_tokens++;
          uint32_t tok, nbits, bits;
          uint_config.Encode(token.value, &tok, &nbits, &bits);
          builder.VisitSymbol(tok, /*token.context=*/0);
        }
      } else {
        for (const auto& token : stream) {
          total_tokens++;
          uint32_t tok, nbits, bits;
          uint_config.Encode(token.value, &tok, &nbits, &bits);
          builder.VisitSymbol(tok, token.context);
        }
      }
    }

    // Make every symbol below ANS_MAX_ALPHABET_SIZE present in every context.
    if (params.add_missing_symbols) {
      for (size_t c = 0; c < num_contexts; ++c) {
        for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
          builder.VisitSymbol(symbol, c);
        }
      }
    }

    if (params.initialize_global_state) {
      bool use_prefix_code =
          params.force_huffman || total_tokens < 100 ||
          params.clustering == HistogramParams::ClusteringType::kFastest ||
          ans_fuzzer_friendly_;
      if (!use_prefix_code) {
        bool all_singleton = true;
        for (size_t i = 0; i < num_contexts; i++) {
          if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
            all_singleton = false;
          }
        }
        if (all_singleton) {
          use_prefix_code = true;
        }
      }
      codes->use_prefix_code = use_prefix_code;
    }

    // Append an extra code built from a flat histogram over
    // ANS_MAX_ALPHABET_SIZE symbols.
    if (params.add_fixed_histograms) {
      // TODO(szabadka) Add more fixed histograms.
      // TODO(szabadka) Reduce alphabet size by choosing a non-default
      // uint_config.
      const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
      const size_t log_alpha_size = 8;
      JXL_ENSURE(alphabet_size == 1u << log_alpha_size);
      static_assert(ANS_MAX_ALPHABET_SIZE <= ANS_TAB_SIZE);
      std::vector<int32_t> counts =
          CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
      codes->encoding_info.emplace_back();
      codes->encoding_info.back().resize(alphabet_size);
      codes->encoded_histograms.emplace_back(memory_manager);
      BitWriter* histo_writer = &codes->encoded_histograms.back();
      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
          256 + alphabet_size * 24, LayerType::Header, nullptr,
          [&]() -> Status {
            JXL_ASSIGN_OR_RETURN(
                size_t ans_cost,
                BuildAndStoreANSEncodingData(
                    memory_manager, params.ans_histogram_strategy,
                    counts.data(), alphabet_size, log_alpha_size,
                    codes->use_prefix_code, codes->encoding_info.back().data(),
                    histo_writer));
            (void)ans_cost;
            return true;
          }));
    }

    // Encode histograms.
    JXL_ASSIGN_OR_RETURN(
        size_t entropy_bits,
        builder.BuildAndStoreEntropyCodes(memory_manager, params, tokens, codes,
                                          context_map, writer, layer, aux_out));
    cost += entropy_bits;
    return true;
  };
  if (writer) {
    JXL_RETURN_IF_ERROR(writer->WithMaxBits(
        128 + num_contexts * 40 + max_contexts * 96, layer, aux_out, body,
        /*finished_histogram=*/true));
  } else {
    JXL_RETURN_IF_ERROR(body());
  }

  if (aux_out != nullptr) {
    aux_out->layer(layer).num_clustered_histograms +=
        codes->encoding_info.size();
  }
  return cost;
}

// Writes `tokens` using `codes`. The histogram for each token is
// context_map[context_offset + token.context]. Returns the number of raw extra
// (non-entropy-coded) bits written.
//
// Prefix codes: tokens are written in order, one Write() per token.
//
// ANS:
//   - Tokens are encoded from last to first.
//   - Their bits are buffered into chunks of at most
//     BitWriter::kMaxBitsPerCall bits.
//   - The final 32-bit ANS state is written first, followed by the chunks in
//     reverse order of accumulation. This is presumably so that the decoder
//     reads tokens in forward order.
size_t WriteTokens(const std::vector<Token>& tokens,
                   const EntropyEncodingData& codes,
                   const std::vector<uint8_t>& context_map,
                   size_t context_offset, BitWriter* writer) {
  size_t num_extra_bits = 0;
  if (codes.use_prefix_code) {
    for (const auto& token : tokens) {
      uint32_t tok, nbits, bits;
      size_t histo = context_map[context_offset + token.context];
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      // Combine two calls to the BitWriter. Equivalent to:
      // writer->Write(codes.encoding_info[histo][tok].depth,
      //               codes.encoding_info[histo][tok].bits);
      // writer->Write(nbits, bits);
      uint64_t data = codes.encoding_info[histo][tok].bits;
      data |= static_cast<uint64_t>(bits)
              << codes.encoding_info[histo][tok].depth;
      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
      num_extra_bits += nbits;
    }
    return num_extra_bits;
  }
  // Completed chunks and their bit counts; `allbits` is the chunk in progress.
  std::vector<uint64_t> out;
  std::vector<uint8_t> out_nbits;
  out.reserve(tokens.size());
  out_nbits.reserve(tokens.size());
  uint64_t allbits = 0;
  size_t numallbits = 0;
  // Writes in *reversed* order.
  auto addbits = [&](size_t bits, size_t nbits) {
    if (JXL_UNLIKELY(nbits)) {
      JXL_DASSERT(bits >> nbits == 0);
      if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
        out.push_back(allbits);
        out_nbits.push_back(numallbits);
        numallbits = allbits = 0;
      }
      allbits <<= nbits;
      allbits |= bits;
      numallbits += nbits;
    }
  };
  const int end = tokens.size();
  ANSCoder ans;
  if (codes.lz77.enabled || context_map.size() > 1) {
    for (int i = end - 1; i >= 0; --i) {
      const Token token = tokens[i];
      const uint8_t histo = context_map[context_offset + token.context];
      uint32_t tok, nbits, bits;
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(tokens[i].value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
      JXL_DASSERT(info.freq_ > 0);
      // Extra bits first as this is reversed.
      addbits(bits, nbits);
      num_extra_bits += nbits;
      uint8_t ans_nbits = 0;
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
      addbits(ans_bits, ans_nbits);
    }
  } else {
    // Fast path: no LZ77 and at most one context, so every token uses
    // histogram 0.
    for (int i = end - 1; i >= 0; --i) {
      uint32_t tok, nbits, bits;
      codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
      const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
      // Extra bits first as this is reversed.
      addbits(bits, nbits);
      num_extra_bits += nbits;
      uint8_t ans_nbits = 0;
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
      addbits(ans_bits, ans_nbits);
    }
  }
  const uint32_t state = ans.GetState();
  writer->Write(32, state);
  writer->Write(numallbits, allbits);
  for (int i = out.size(); i > 0; --i) {
    writer->Write(out_nbits[i - 1], out[i - 1]);
  }
  return num_extra_bits;
}

// Same as above, but reserves a bit budget on `writer` and adds the number of
// extra bits to `aux_out` (if non-null).
Status WriteTokens(const std::vector<Token>& tokens,
                   const EntropyEncodingData& codes,
                   const std::vector<uint8_t>& context_map,
                   size_t context_offset, BitWriter* writer, LayerType layer,
                   AuxOut* aux_out) {
  // Theoretically, we could have 15 prefix code bits + 31 extra bits.
  return writer->WithMaxBits(
      46 * tokens.size() + 32 * 1024 * 4, layer, aux_out, [&] {
        size_t num_extra_bits =
            WriteTokens(tokens, codes, context_map, context_offset, writer);
        if (aux_out != nullptr) {
          aux_out->layer(layer).extra_bits += num_extra_bits;
        }
        return true;
      });
}

// Only has an effect in debug builds. In other builds ans_fuzzer_friendly_ is
// declared constexpr (false), and this function does nothing.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#endif
}

// Chooses histogram / LZ77 parameters for modular-mode streams from
// `cparams.speed_tier` and `cparams.decoding_speed_tier`.
HistogramParams HistogramParams::ForModular(
    const CompressParams& cparams,
    const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
  HistogramParams params;
  params.streaming_mode = streaming_mode;
  if (cparams.speed_tier > SpeedTier::kKitten) {
    params.clustering = HistogramParams::ClusteringType::kFast;
    params.ans_histogram_strategy =
        cparams.speed_tier > SpeedTier::kThunder
            ? HistogramParams::ANSHistogramStrategy::kFast
            : HistogramParams::ANSHistogramStrategy::kApproximate;
    // In these tiers, LZ77 is used only when decoding_speed_tier >= 3 and
    // modular mode is on.
    params.lz77_method =
        cparams.decoding_speed_tier >= 3 && cparams.modular_mode
            ? (cparams.speed_tier >= SpeedTier::kFalcon
                   ? HistogramParams::LZ77Method::kRLE
                   : HistogramParams::LZ77Method::kLZ77)
            : HistogramParams::LZ77Method::kNone;
    // Near-lossless DC, as well as modular mode, require choosing hybrid uint
    // more carefully.
    if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
        (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
      params.uint_method = HistogramParams::HybridUintMethod::kFast;
    } else {
      params.uint_method = HistogramParams::HybridUintMethod::kNone;
    }
  } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
    params.lz77_method = HistogramParams::LZ77Method::kOptimal;
  } else {
    params.lz77_method = HistogramParams::LZ77Method::kLZ77;
  }
  if (cparams.decoding_speed_tier >= 1) {
    params.max_histograms = 12;
  }
  // Responsive mode overrides the LZ77 method chosen above.
  if (cparams.decoding_speed_tier >= 1 && cparams.responsive) {
    params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah
                             ? HistogramParams::LZ77Method::kRLE
                         : cparams.speed_tier >= SpeedTier::kKitten
                             ? HistogramParams::LZ77Method::kLZ77
                             : HistogramParams::LZ77Method::kOptimal;
  }
  if (cparams.decoding_speed_tier >= 2 && cparams.responsive) {
    params.uint_method = HistogramParams::HybridUintMethod::k000;
    params.force_huffman = true;
  }
  return params;
}
}  // namespace jxl