dec_context_map.cc (2975B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_context_map.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <algorithm> 11 #include <cstdint> 12 #include <vector> 13 14 #include "lib/jxl/base/status.h" 15 #include "lib/jxl/dec_ans.h" 16 #include "lib/jxl/inverse_mtf-inl.h" 17 18 namespace jxl { 19 20 namespace { 21 22 // Context map uses uint8_t. 23 constexpr size_t kMaxClusters = 256; 24 25 Status VerifyContextMap(const std::vector<uint8_t>& context_map, 26 const size_t num_htrees) { 27 std::vector<bool> have_htree(num_htrees); 28 size_t num_found = 0; 29 for (const uint8_t htree : context_map) { 30 if (htree >= num_htrees) { 31 return JXL_FAILURE("Invalid histogram index in context map."); 32 } 33 if (!have_htree[htree]) { 34 have_htree[htree] = true; 35 ++num_found; 36 } 37 } 38 if (num_found != num_htrees) { 39 return JXL_FAILURE("Incomplete context map."); 40 } 41 return true; 42 } 43 44 } // namespace 45 46 Status DecodeContextMap(JxlMemoryManager* memory_manager, 47 std::vector<uint8_t>* context_map, size_t* num_htrees, 48 BitReader* input) { 49 bool is_simple = static_cast<bool>(input->ReadFixedBits<1>()); 50 if (is_simple) { 51 int bits_per_entry = input->ReadFixedBits<2>(); 52 if (bits_per_entry != 0) { 53 for (uint8_t& entry : *context_map) { 54 entry = input->ReadBits(bits_per_entry); 55 } 56 } else { 57 std::fill(context_map->begin(), context_map->end(), 0); 58 } 59 } else { 60 bool use_mtf = static_cast<bool>(input->ReadFixedBits<1>()); 61 ANSCode code; 62 std::vector<uint8_t> sink_ctx_map; 63 // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't 64 // make sense in non-malicious bitstreams, and could cause a stack overflow 65 // in malicious bitstreams by making every context map require its own 66 // context map. 67 JXL_RETURN_IF_ERROR( 68 DecodeHistograms(memory_manager, input, 1, &code, &sink_ctx_map, 69 /*disallow_lz77=*/context_map->size() <= 2)); 70 JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader, 71 ANSSymbolReader::Create(&code, input)); 72 size_t i = 0; 73 uint32_t maxsym = 0; 74 while (i < context_map->size()) { 75 uint32_t sym = reader.ReadHybridUintInlined</*uses_lz77=*/true>( 76 0, input, sink_ctx_map); 77 maxsym = sym > maxsym ? sym : maxsym; 78 (*context_map)[i] = sym; 79 i++; 80 } 81 if (maxsym >= kMaxClusters) { 82 return JXL_FAILURE("Invalid cluster ID"); 83 } 84 if (!reader.CheckANSFinalState()) { 85 return JXL_FAILURE("Invalid context map"); 86 } 87 if (use_mtf) { 88 InverseMoveToFrontTransform(context_map->data(), context_map->size()); 89 } 90 } 91 *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1; 92 return VerifyContextMap(*context_map, *num_htrees); 93 } 94 95 } // namespace jxl