dec_ans.cc (15861B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_ans.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <cstdint> 11 #include <vector> 12 13 #include "lib/jxl/ans_common.h" 14 #include "lib/jxl/ans_params.h" 15 #include "lib/jxl/base/bits.h" 16 #include "lib/jxl/base/printf_macros.h" 17 #include "lib/jxl/base/status.h" 18 #include "lib/jxl/dec_context_map.h" 19 #include "lib/jxl/fields.h" 20 #include "lib/jxl/memory_manager_internal.h" 21 22 namespace jxl { 23 namespace { 24 25 // Decodes a number in the range [0..255], by reading 1 - 11 bits. 26 inline int DecodeVarLenUint8(BitReader* input) { 27 if (input->ReadFixedBits<1>()) { 28 int nbits = static_cast<int>(input->ReadFixedBits<3>()); 29 if (nbits == 0) { 30 return 1; 31 } else { 32 return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); 33 } 34 } 35 return 0; 36 } 37 38 // Decodes a number in the range [0..65535], by reading 1 - 21 bits. 39 inline int DecodeVarLenUint16(BitReader* input) { 40 if (input->ReadFixedBits<1>()) { 41 int nbits = static_cast<int>(input->ReadFixedBits<4>()); 42 if (nbits == 0) { 43 return 1; 44 } else { 45 return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); 46 } 47 } 48 return 0; 49 } 50 51 Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts, 52 BitReader* input) { 53 int range = 1 << precision_bits; 54 int simple_code = input->ReadBits(1); 55 if (simple_code == 1) { 56 int i; 57 int symbols[2] = {0}; 58 int max_symbol = 0; 59 const int num_symbols = input->ReadBits(1) + 1; 60 for (i = 0; i < num_symbols; ++i) { 61 symbols[i] = DecodeVarLenUint8(input); 62 if (symbols[i] > max_symbol) max_symbol = symbols[i]; 63 } 64 counts->resize(max_symbol + 1); 65 if (num_symbols == 1) { 66 (*counts)[symbols[0]] = range; 67 } else { 68 if (symbols[0] == symbols[1]) { // corrupt data 69 return false; 70 } 71 (*counts)[symbols[0]] = input->ReadBits(precision_bits); 72 (*counts)[symbols[1]] = range - (*counts)[symbols[0]]; 73 } 74 } else { 75 int is_flat = input->ReadBits(1); 76 if (is_flat == 1) { 77 int alphabet_size = DecodeVarLenUint8(input) + 1; 78 JXL_ENSURE(alphabet_size <= range); 79 *counts = CreateFlatHistogram(alphabet_size, range); 80 return true; 81 } 82 83 uint32_t shift; 84 { 85 // TODO(veluca): speed up reading with table lookups. 86 int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); 87 int log = 0; 88 for (; log < upper_bound_log; log++) { 89 if (input->ReadFixedBits<1>() == 0) break; 90 } 91 shift = (input->ReadBits(log) | (1 << log)) - 1; 92 if (shift > ANS_LOG_TAB_SIZE + 1) { 93 return JXL_FAILURE("Invalid shift value"); 94 } 95 } 96 97 int length = DecodeVarLenUint8(input) + 3; 98 counts->resize(length); 99 int total_count = 0; 100 101 static const uint8_t huff[128][2] = { 102 {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 103 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 104 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 105 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 106 {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 107 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 108 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 109 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 110 {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 111 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 112 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 113 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 114 {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 115 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 116 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 117 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 118 }; 119 120 std::vector<int> logcounts(counts->size()); 121 int omit_log = -1; 122 int omit_pos = -1; 123 // This array remembers which symbols have an RLE length. 124 std::vector<int> same(counts->size(), 0); 125 for (size_t i = 0; i < logcounts.size(); ++i) { 126 input->Refill(); // for PeekFixedBits + Advance 127 int idx = input->PeekFixedBits<7>(); 128 input->Consume(huff[idx][0]); 129 logcounts[i] = huff[idx][1]; 130 // The RLE symbol. 131 if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) { 132 int rle_length = DecodeVarLenUint8(input); 133 same[i] = rle_length + 5; 134 i += rle_length + 3; 135 continue; 136 } 137 if (logcounts[i] > omit_log) { 138 omit_log = logcounts[i]; 139 omit_pos = i; 140 } 141 } 142 // Invalid input, e.g. due to invalid usage of RLE. 143 if (omit_pos < 0) return JXL_FAILURE("Invalid histogram."); 144 if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() && 145 logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) { 146 return JXL_FAILURE("Invalid histogram."); 147 } 148 int prev = 0; 149 int numsame = 0; 150 for (size_t i = 0; i < logcounts.size(); ++i) { 151 if (same[i]) { 152 // RLE sequence, let this loop output the same count for the next 153 // iterations. 154 numsame = same[i] - 1; 155 prev = i > 0 ? (*counts)[i - 1] : 0; 156 } 157 if (numsame > 0) { 158 (*counts)[i] = prev; 159 numsame--; 160 } else { 161 unsigned int code = logcounts[i]; 162 // omit_pos may not be negative at this point (checked before). 163 if (i == static_cast<size_t>(omit_pos)) { 164 continue; 165 } else if (code == 0) { 166 continue; 167 } else if (code == 1) { 168 (*counts)[i] = 1; 169 } else { 170 int bitcount = GetPopulationCountPrecision(code - 1, shift); 171 (*counts)[i] = (1u << (code - 1)) + 172 (input->ReadBits(bitcount) << (code - 1 - bitcount)); 173 } 174 } 175 total_count += (*counts)[i]; 176 } 177 (*counts)[omit_pos] = range - total_count; 178 if ((*counts)[omit_pos] <= 0) { 179 // The histogram we've read sums to more than total_count (including at 180 // least 1 for the omitted value). 181 return JXL_FAILURE("Invalid histogram count."); 182 } 183 } 184 return true; 185 } 186 187 } // namespace 188 189 Status DecodeANSCodes(JxlMemoryManager* memory_manager, 190 const size_t num_histograms, 191 const size_t max_alphabet_size, BitReader* in, 192 ANSCode* result) { 193 result->memory_manager = memory_manager; 194 result->degenerate_symbols.resize(num_histograms, -1); 195 if (result->use_prefix_code) { 196 JXL_ENSURE(max_alphabet_size <= 1 << PREFIX_MAX_BITS); 197 result->huffman_data.resize(num_histograms); 198 std::vector<uint16_t> alphabet_sizes(num_histograms); 199 for (size_t c = 0; c < num_histograms; c++) { 200 alphabet_sizes[c] = DecodeVarLenUint16(in) + 1; 201 if (alphabet_sizes[c] > max_alphabet_size) { 202 return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]); 203 } 204 } 205 for (size_t c = 0; c < num_histograms; c++) { 206 if (alphabet_sizes[c] > 1) { 207 if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) { 208 if (!in->AllReadsWithinBounds()) { 209 return JXL_STATUS(StatusCode::kNotEnoughBytes, 210 "Not enough bytes for huffman code"); 211 } 212 return JXL_FAILURE("Invalid huffman tree number %" PRIuS 213 ", alphabet size %u", 214 c, alphabet_sizes[c]); 215 } 216 } else { 217 // 0-bit codes does not require extension tables. 218 result->huffman_data[c].table_.clear(); 219 result->huffman_data[c].table_.resize(1u << kHuffmanTableBits); 220 } 221 for (const auto& h : result->huffman_data[c].table_) { 222 if (h.bits <= kHuffmanTableBits) { 223 result->UpdateMaxNumBits(c, h.value); 224 } 225 } 226 } 227 } else { 228 JXL_ENSURE(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE); 229 size_t alloc_size = num_histograms * (1 << result->log_alpha_size) * 230 sizeof(AliasTable::Entry); 231 JXL_ASSIGN_OR_RETURN(result->alias_tables, 232 AlignedMemory::Create(memory_manager, alloc_size)); 233 AliasTable::Entry* alias_tables = 234 result->alias_tables.address<AliasTable::Entry>(); 235 for (size_t c = 0; c < num_histograms; ++c) { 236 std::vector<int32_t> counts; 237 if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) { 238 return JXL_FAILURE("Invalid histogram bitstream."); 239 } 240 if (counts.size() > max_alphabet_size) { 241 return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size()); 242 } 243 while (!counts.empty() && counts.back() == 0) { 244 counts.pop_back(); 245 } 246 for (size_t s = 0; s < counts.size(); s++) { 247 if (counts[s] != 0) { 248 result->UpdateMaxNumBits(c, s); 249 } 250 } 251 // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol. 252 int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1); 253 for (int s = 0; s < degenerate_symbol; ++s) { 254 if (counts[s] != 0) { 255 degenerate_symbol = -1; 256 break; 257 } 258 } 259 result->degenerate_symbols[c] = degenerate_symbol; 260 JXL_RETURN_IF_ERROR( 261 InitAliasTable(counts, ANS_LOG_TAB_SIZE, result->log_alpha_size, 262 alias_tables + c * (1 << result->log_alpha_size))); 263 } 264 } 265 return true; 266 } 267 Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config, 268 BitReader* br) { 269 br->Refill(); 270 size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1)); 271 size_t msb_in_token = 0; 272 size_t lsb_in_token = 0; 273 if (split_exponent != log_alpha_size) { 274 // otherwise, msb/lsb don't matter. 275 size_t nbits = CeilLog2Nonzero(split_exponent + 1); 276 msb_in_token = br->ReadBits(nbits); 277 if (msb_in_token > split_exponent) { 278 // This could be invalid here already and we need to check this before 279 // we use its value to read more bits. 280 return JXL_FAILURE("Invalid HybridUintConfig"); 281 } 282 nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1); 283 lsb_in_token = br->ReadBits(nbits); 284 } 285 if (lsb_in_token + msb_in_token > split_exponent) { 286 return JXL_FAILURE("Invalid HybridUintConfig"); 287 } 288 *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token); 289 return true; 290 } 291 292 Status DecodeUintConfigs(size_t log_alpha_size, 293 std::vector<HybridUintConfig>* uint_config, 294 BitReader* br) { 295 // TODO(veluca): RLE? 296 for (auto& cfg : *uint_config) { 297 JXL_RETURN_IF_ERROR(DecodeUintConfig(log_alpha_size, &cfg, br)); 298 } 299 return true; 300 } 301 302 LZ77Params::LZ77Params() { Bundle::Init(this); } 303 Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) { 304 JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled)); 305 if (!visitor->Conditional(enabled)) return true; 306 JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096), 307 BitsOffset(15, 8), 224, &min_symbol)); 308 JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5), 309 BitsOffset(8, 9), 3, &min_length)); 310 return true; 311 } 312 313 void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) { 314 HybridUintConfig* cfg = &uint_config[ctx]; 315 // LZ77 symbols use a different uint config. 316 if (lz77.enabled && lz77.nonserialized_distance_context != ctx && 317 symbol >= lz77.min_symbol) { 318 symbol -= lz77.min_symbol; 319 cfg = &lz77.length_uint_config; 320 } 321 size_t split_token = cfg->split_token; 322 size_t msb_in_token = cfg->msb_in_token; 323 size_t lsb_in_token = cfg->lsb_in_token; 324 size_t split_exponent = cfg->split_exponent; 325 if (symbol < split_token) { 326 max_num_bits = std::max(max_num_bits, split_exponent); 327 return; 328 } 329 uint32_t n_extra_bits = 330 split_exponent - (msb_in_token + lsb_in_token) + 331 ((symbol - split_token) >> (msb_in_token + lsb_in_token)); 332 size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1; 333 max_num_bits = std::max(max_num_bits, total_bits); 334 } 335 336 Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br, 337 size_t num_contexts, ANSCode* code, 338 std::vector<uint8_t>* context_map, bool disallow_lz77) { 339 JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77)); 340 if (code->lz77.enabled) { 341 num_contexts++; 342 JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8, 343 &code->lz77.length_uint_config, br)); 344 } 345 if (code->lz77.enabled && disallow_lz77) { 346 return JXL_FAILURE("Using LZ77 when explicitly disallowed"); 347 } 348 size_t num_histograms = 1; 349 context_map->resize(num_contexts); 350 if (num_contexts > 1) { 351 JXL_RETURN_IF_ERROR( 352 DecodeContextMap(memory_manager, context_map, &num_histograms, br)); 353 } 354 JXL_DEBUG_V( 355 4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms", 356 num_contexts, num_histograms); 357 code->lz77.nonserialized_distance_context = context_map->back(); 358 code->use_prefix_code = static_cast<bool>(br->ReadFixedBits<1>()); 359 if (code->use_prefix_code) { 360 code->log_alpha_size = PREFIX_MAX_BITS; 361 } else { 362 code->log_alpha_size = br->ReadFixedBits<2>() + 5; 363 } 364 code->uint_config.resize(num_histograms); 365 JXL_RETURN_IF_ERROR( 366 DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br)); 367 const size_t max_alphabet_size = 1 << code->log_alpha_size; 368 JXL_RETURN_IF_ERROR(DecodeANSCodes(memory_manager, num_histograms, 369 max_alphabet_size, br, code)); 370 return true; 371 } 372 373 StatusOr<ANSSymbolReader> ANSSymbolReader::Create(const ANSCode* code, 374 BitReader* JXL_RESTRICT br, 375 size_t distance_multiplier) { 376 AlignedMemory lz77_window_storage; 377 if (code->lz77.enabled) { 378 JxlMemoryManager* memory_manager = code->memory_manager; 379 JXL_ASSIGN_OR_RETURN( 380 lz77_window_storage, 381 AlignedMemory::Create(memory_manager, kWindowSize * sizeof(uint32_t))); 382 } 383 return ANSSymbolReader(code, br, distance_multiplier, 384 std::move(lz77_window_storage)); 385 } 386 387 ANSSymbolReader::ANSSymbolReader(const ANSCode* code, 388 BitReader* JXL_RESTRICT br, 389 size_t distance_multiplier, 390 AlignedMemory&& lz77_window_storage) 391 : alias_tables_(code->alias_tables.address<AliasTable::Entry>()), 392 huffman_data_(code->huffman_data.data()), 393 use_prefix_code_(code->use_prefix_code), 394 configs(code->uint_config.data()), 395 lz77_window_storage_(std::move(lz77_window_storage)) { 396 if (!use_prefix_code_) { 397 state_ = static_cast<uint32_t>(br->ReadFixedBits<32>()); 398 log_alpha_size_ = code->log_alpha_size; 399 log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size; 400 entry_size_minus_1_ = (1 << log_entry_size_) - 1; 401 } else { 402 state_ = (ANS_SIGNATURE << 16u); 403 } 404 if (!code->lz77.enabled) return; 405 lz77_window_ = lz77_window_storage_.address<uint32_t>(); 406 lz77_ctx_ = code->lz77.nonserialized_distance_context; 407 lz77_length_uint_ = code->lz77.length_uint_config; 408 lz77_threshold_ = code->lz77.min_symbol; 409 lz77_min_length_ = code->lz77.min_length; 410 num_special_distances_ = distance_multiplier == 0 ? 0 : kNumSpecialDistances; 411 for (size_t i = 0; i < num_special_distances_; i++) { 412 special_distances_[i] = SpecialDistance(i, distance_multiplier); 413 } 414 } 415 416 } // namespace jxl