tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

encoding.cc (28335B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/modular/encoding/encoding.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <algorithm>
     11 #include <array>
     12 #include <cstddef>
     13 #include <cstdint>
     14 #include <cstdlib>
     15 #include <queue>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "lib/jxl/base/printf_macros.h"
     20 #include "lib/jxl/base/scope_guard.h"
     21 #include "lib/jxl/base/status.h"
     22 #include "lib/jxl/dec_ans.h"
     23 #include "lib/jxl/dec_bit_reader.h"
     24 #include "lib/jxl/frame_dimensions.h"
     25 #include "lib/jxl/image_ops.h"
     26 #include "lib/jxl/modular/encoding/context_predict.h"
     27 #include "lib/jxl/modular/options.h"
     28 #include "lib/jxl/pack_signed.h"
     29 
     30 namespace jxl {
     31 
// Removes all nodes that use a static property (i.e. channel or group ID) from
// the tree and collapses each node on even levels with its two children to
// produce a flatter tree. Also computes whether the resulting tree requires
// using the weighted predictor.
//
// Outputs written through pointers:
//  - num_props: highest property index the flat tree reads, +1, rounded up so
//    that reference-channel properties come in whole per-channel groups (see
//    the rounding at the end of the function).
//  - use_wp: true if any decision node reads the WP property or any leaf uses
//    Predictor::Weighted.
//  - wp_only: true if the WP property/predictor is the *only* non-static one
//    used (enables the WP fast path in the decoder).
//  - gradient_only: true if the only non-static property used is the gradient
//    property and every leaf predicts with Predictor::Gradient.
FlatTree FilterTree(const Tree &global_tree,
                   std::array<pixel_type, kNumStaticProperties> &static_props,
                   size_t *num_props, bool *use_wp, bool *wp_only,
                   bool *gradient_only) {
 *num_props = 0;
 bool has_wp = false;
 bool has_non_wp = false;
 *gradient_only = true;
 // Records how a property index affects the use_wp/wp_only/gradient_only
 // summary flags. Static properties (p < kNumStaticProperties) are ignored:
 // they are resolved right here and never reach the flat tree.
 const auto mark_property = [&](int32_t p) {
   if (p == kWPProp) {
     has_wp = true;
   } else if (p >= kNumStaticProperties) {
     has_non_wp = true;
   }
   if (p >= kNumStaticProperties && p != kGradientProp) {
     *gradient_only = false;
   }
 };
 FlatTree output;
 std::queue<size_t> nodes;
 nodes.push(0);
 // Produces a trimmed and flattened tree by doing a BFS visit of the original
 // tree, ignoring branches that are known to be false and proceeding two
 // levels at a time to collapse nodes in a flatter tree; if an inner parent
 // node has a leaf as a child, the leaf is duplicated and an implicit fake
 // node is added. This allows to reduce the number of branches when traversing
 // the resulting flat tree.
 while (!nodes.empty()) {
   size_t cur = nodes.front();
   nodes.pop();
   // Skip nodes that we can decide now, by jumping directly to their children.
   // Static properties are known (channel / group ID), so the comparison
   // against splitval can be evaluated immediately.
   while (global_tree[cur].property < kNumStaticProperties &&
          global_tree[cur].property != -1) {
     if (static_props[global_tree[cur].property] > global_tree[cur].splitval) {
       cur = global_tree[cur].lchild;
     } else {
       cur = global_tree[cur].rchild;
     }
   }
   FlatDecisionNode flat;
   // property == -1 marks a leaf: childID then holds the context ID
   // (stored in lchild) rather than a child index.
   if (global_tree[cur].property == -1) {
     flat.property0 = -1;
     flat.childID = global_tree[cur].lchild;
     flat.predictor = global_tree[cur].predictor;
     flat.predictor_offset = global_tree[cur].predictor_offset;
     flat.multiplier = global_tree[cur].multiplier;
     *gradient_only &= flat.predictor == Predictor::Gradient;
     has_wp |= flat.predictor == Predictor::Weighted;
     has_non_wp |= flat.predictor != Predictor::Weighted;
     output.push_back(flat);
     continue;
   }
   // BFS invariant: the four grandchildren pushed below will land right after
   // everything already emitted plus everything still queued, so the first of
   // them ends up at this index in `output`.
   flat.childID = output.size() + nodes.size() + 1;

   flat.property0 = global_tree[cur].property;
   *num_props = std::max<size_t>(flat.property0 + 1, *num_props);
   flat.splitval0 = global_tree[cur].splitval;

   for (size_t i = 0; i < 2; i++) {
     size_t cur_child =
         i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild;
     // Skip nodes that we can decide now.
     while (global_tree[cur_child].property < kNumStaticProperties &&
            global_tree[cur_child].property != -1) {
       if (static_props[global_tree[cur_child].property] >
           global_tree[cur_child].splitval) {
         cur_child = global_tree[cur_child].lchild;
       } else {
         cur_child = global_tree[cur_child].rchild;
       }
     }
     // We ended up in a leaf, add a placeholder decision and two copies of the
     // leaf.
     if (global_tree[cur_child].property == -1) {
       flat.properties[i] = 0;
       flat.splitvals[i] = 0;
       nodes.push(cur_child);
       nodes.push(cur_child);
     } else {
       flat.properties[i] = global_tree[cur_child].property;
       flat.splitvals[i] = global_tree[cur_child].splitval;
       nodes.push(global_tree[cur_child].lchild);
       nodes.push(global_tree[cur_child].rchild);
       *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props);
     }
   }

   for (int16_t property : flat.properties) mark_property(property);
   mark_property(flat.property0);
   output.push_back(flat);
 }
 // Round num_props up so that reference-channel properties are allocated in
 // whole kExtraPropsPerChannel-sized groups on top of kNumNonrefProperties.
 if (*num_props > kNumNonrefProperties) {
   *num_props =
       DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) *
           kExtraPropsPerChannel +
       kNumNonrefProperties;
 } else {
   *num_props = kNumNonrefProperties;
 }
 *use_wp = has_wp;
 *wp_only = has_wp && !has_non_wp;

 return output;
}
    140 
namespace detail {
// Decodes one channel of a modular image from the ANS/entropy-coded stream.
//
// The MA tree is first specialized for this channel/group (FilterTree), then
// one of several decode tracks is chosen, from fastest to slowest:
//  1. single-leaf tree with Predictor::Zero (optionally a single-symbol
//     histogram, which becomes a plain fill),
//  2. single-leaf Gradient with the fjxl Huffman-RLE-only stream,
//  3. single-leaf Gradient (generic),
//  4. gradient-only / WP-only trees via a small context lookup table,
//  5. full tree traversal without WP, and finally with WP.
//
// uses_lz77 (template): compile-time flag matching reader->UsesLZ77(), so the
//   per-symbol read path is compiled without the unused branch.
// fl_run / fl_v: current run length and last value of the Huffman-RLE fast
//   track; passed by reference so a run may continue across channels.
// tree_lut: scratch lookup table filled by TreeToLookupTable for tracks 4.
template <bool uses_lz77>
Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
                                const std::vector<uint8_t> &context_map,
                                const Tree &global_tree,
                                const weighted::Header &wp_header,
                                pixel_type chan, size_t group_id,
                                TreeLut<uint8_t, false, false> &tree_lut,
                                Image *image, uint32_t &fl_run,
                                uint32_t &fl_v) {
 JxlMemoryManager *memory_manager = image->memory_manager();
 Channel &channel = image->channel[chan];

 // The two static MA-tree properties: channel index and group ID.
 std::array<pixel_type, kNumStaticProperties> static_props = {
     {chan, static_cast<int>(group_id)}};
 // TODO(veluca): filter the tree according to static_props.

 // zero pixel channel? could happen
 if (channel.w == 0 || channel.h == 0) return true;

 bool tree_has_wp_prop_or_pred = false;
 bool is_wp_only = false;
 bool is_gradient_only = false;
 size_t num_props;
 FlatTree tree =
     FilterTree(global_tree, static_props, &num_props,
                &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only);

 // From here on, tree lookup returns a *clustered* context ID.
 // This avoids an extra memory lookup after tree traversal.
 for (auto &node : tree) {
   if (node.property0 == -1) {
     node.childID = context_map[node.childID];
   }
 }

 JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size());

 // MAANS decode
 // Maps a decoded token to a pixel value: unpack the zig-zag sign encoding,
 // then apply the leaf's multiplier and predictor offset.
 const auto make_pixel = [](uint64_t v, pixel_type multiplier,
                            pixel_type_w offset) -> pixel_type {
   JXL_DASSERT((v & 0xFFFFFFFF) == v);
   pixel_type_w val = UnpackSigned(v);
   // if it overflows, it overflows, and we have a problem anyway
   return val * multiplier + offset;
 };

 if (tree.size() == 1) {
   // special optimized case: no meta-adaptation, so no need
   // to compute properties.
   Predictor predictor = tree[0].predictor;
   int64_t offset = tree[0].predictor_offset;
   int32_t multiplier = tree[0].multiplier;
   size_t ctx_id = tree[0].childID;
   if (predictor == Predictor::Zero) {
     uint32_t value;
     if (reader->IsSingleValueAndAdvance(ctx_id, &value,
                                         channel.w * channel.h)) {
       // Special-case: histogram has a single symbol, with no extra bits, and
       // we use ANS mode.
       JXL_DEBUG_V(8, "Fastest track.");
       pixel_type v = make_pixel(value, multiplier, offset);
       for (size_t y = 0; y < channel.h; y++) {
         pixel_type *JXL_RESTRICT r = channel.Row(y);
         std::fill(r, r + channel.w, v);
       }
     } else {
       JXL_DEBUG_V(8, "Fast track.");
       // With multiplier 1 and offset 0 the pixel is just the unpacked token.
       if (multiplier == 1 && offset == 0) {
         for (size_t y = 0; y < channel.h; y++) {
           pixel_type *JXL_RESTRICT r = channel.Row(y);
           for (size_t x = 0; x < channel.w; x++) {
             uint32_t v =
                 reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
             r[x] = UnpackSigned(v);
           }
         }
       } else {
         for (size_t y = 0; y < channel.h; y++) {
           pixel_type *JXL_RESTRICT r = channel.Row(y);
           for (size_t x = 0; x < channel.w; x++) {
             uint32_t v =
                 reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(ctx_id,
                                                                        br);
             r[x] = make_pixel(v, multiplier, offset);
           }
         }
       }
     }
     return true;
   } else if (uses_lz77 && predictor == Predictor::Gradient && offset == 0 &&
              multiplier == 1 && reader->IsHuffRleOnly()) {
     JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track.");
     pixel_type_w sv = UnpackSigned(fl_v);
     for (size_t y = 0; y < channel.h; y++) {
       pixel_type *JXL_RESTRICT r = channel.Row(y);
       // For y == 0 the "top" pointers alias r - 1, so rtop[x]/rtopleft[x]
       // read the already-decoded left neighbors; the x == 0 / y == 0 cases
       // below never dereference them out of bounds.
       const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1);
       const pixel_type *JXL_RESTRICT rtopleft =
           (y ? channel.Row(y - 1) - 1 : r - 1);
       pixel_type_w guess = (y ? rtop[0] : 0);
       // Fetch a new (value, run) pair only when the current run is spent.
       if (fl_run == 0) {
         reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &fl_v,
                                                    &fl_run);
         sv = UnpackSigned(fl_v);
       } else {
         fl_run--;
       }
       r[0] = sv + guess;
       for (size_t x = 1; x < channel.w; x++) {
         pixel_type left = r[x - 1];
         pixel_type top = rtop[x];
         pixel_type topleft = rtopleft[x];
         pixel_type_w guess = ClampedGradient(top, left, topleft);
         if (!fl_run) {
           reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &fl_v,
                                                      &fl_run);
           sv = UnpackSigned(fl_v);
         } else {
           fl_run--;
         }
         r[x] = sv + guess;
       }
     }
     return true;
   } else if (predictor == Predictor::Gradient && offset == 0 &&
              multiplier == 1) {
     JXL_DEBUG_V(8, "Gradient very fast track.");
     const intptr_t onerow = channel.plane.PixelsPerRow();
     for (size_t y = 0; y < channel.h; y++) {
       pixel_type *JXL_RESTRICT r = channel.Row(y);
       for (size_t x = 0; x < channel.w; x++) {
         // Neighbor fetch with border fallbacks (left falls back to top,
         // then to 0 at the very first pixel).
         pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
         pixel_type top = (y ? *(r + x - onerow) : left);
         pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left);
         pixel_type guess = ClampedGradient(top, left, topleft);
         uint64_t v = reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(
             ctx_id, br);
         r[x] = make_pixel(v, 1, guess);
       }
     }
     return true;
   }
 }

 // Check if this tree is a WP-only tree with a small enough property value
 // range.
 if (is_wp_only) {
   is_wp_only = TreeToLookupTable(tree, tree_lut);
 }
 if (is_gradient_only) {
   is_gradient_only = TreeToLookupTable(tree, tree_lut);
 }

 if (is_gradient_only) {
   JXL_DEBUG_V(8, "Gradient fast track.");
   const intptr_t onerow = channel.plane.PixelsPerRow();
   for (size_t y = 0; y < channel.h; y++) {
     pixel_type *JXL_RESTRICT r = channel.Row(y);
     for (size_t x = 0; x < channel.w; x++) {
       pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
       pixel_type_w top = (y ? *(r + x - onerow) : left);
       pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
       int32_t guess = ClampedGradient(top, left, topleft);
       // The context is selected by the *unclamped* gradient value
       // top + left - topleft, clamped into [-kPropRangeFast,
       // kPropRangeFast) and shifted to be a table index.
       uint32_t pos =
           kPropRangeFast +
           std::min<pixel_type_w>(
               std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
               kPropRangeFast - 1);
       uint32_t ctx_id = tree_lut.context_lookup[pos];
       uint64_t v =
           reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(ctx_id, br);
       r[x] = make_pixel(v, 1, guess);
     }
   }
 } else if (!uses_lz77 && is_wp_only && channel.w > 8) {
   JXL_DEBUG_V(8, "WP fast track.");
   weighted::State wp_state(wp_header, channel.w, channel.h);
   Properties properties(1);
   for (size_t y = 0; y < channel.h; y++) {
     pixel_type *JXL_RESTRICT r = channel.Row(y);
     // For y == 0 all the "top" row pointers alias r - 1, so indexing them at
     // x reads the already-decoded left neighbor instead.
     const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1);
     const pixel_type *JXL_RESTRICT rtoptop =
         (y > 1 ? channel.Row(y - 2) : rtop);
     const pixel_type *JXL_RESTRICT rtopleft =
         (y ? channel.Row(y - 1) - 1 : r - 1);
     const pixel_type *JXL_RESTRICT rtopright =
         (y ? channel.Row(y - 1) + 1 : r - 1);
     size_t x = 0;
     // First column: left/topleft fall back to top (or 0 on the first row).
     {
       size_t offset = 0;
       pixel_type_w left = y ? rtop[x] : 0;
       pixel_type_w toptop = y ? rtoptop[x] : 0;
       pixel_type_w topright = (x + 1 < channel.w && y ? rtop[x + 1] : left);
       int32_t guess = wp_state.Predict</*compute_properties=*/true>(
           x, y, channel.w, left, left, topright, left, toptop, &properties,
           offset);
       // properties[0] is the WP property; clamp it into the lookup range.
       uint32_t pos =
           kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
                                     kPropRangeFast - 1);
       uint32_t ctx_id = tree_lut.context_lookup[pos];
       uint64_t v =
           reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
       r[x] = make_pixel(v, 1, guess);
       wp_state.UpdateErrors(r[x], x, y, channel.w);
     }
     // Interior columns: all five neighbors are available directly.
     for (x = 1; x + 1 < channel.w; x++) {
       size_t offset = 0;
       int32_t guess = wp_state.Predict</*compute_properties=*/true>(
           x, y, channel.w, rtop[x], r[x - 1], rtopright[x], rtopleft[x],
           rtoptop[x], &properties, offset);
       uint32_t pos =
           kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
                                     kPropRangeFast - 1);
       uint32_t ctx_id = tree_lut.context_lookup[pos];
       uint64_t v =
           reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
       r[x] = make_pixel(v, 1, guess);
       wp_state.UpdateErrors(r[x], x, y, channel.w);
     }
     // Last column: topright falls back to top (rtop[x]).
     {
       size_t offset = 0;
       int32_t guess = wp_state.Predict</*compute_properties=*/true>(
           x, y, channel.w, rtop[x], r[x - 1], rtop[x], rtopleft[x],
           rtoptop[x], &properties, offset);
       uint32_t pos =
           kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
                                     kPropRangeFast - 1);
       uint32_t ctx_id = tree_lut.context_lookup[pos];
       uint64_t v =
           reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
       r[x] = make_pixel(v, 1, guess);
       wp_state.UpdateErrors(r[x], x, y, channel.w);
     }
   }
 } else if (!tree_has_wp_prop_or_pred) {
   // special optimized case: the weighted predictor and its properties are not
   // used, so no need to compute weights and properties.
   JXL_DEBUG_V(8, "Slow track.");
   MATreeLookup tree_lookup(tree);
   Properties properties = Properties(num_props);
   const intptr_t onerow = channel.plane.PixelsPerRow();
   JXL_ASSIGN_OR_RETURN(
       Channel references,
       Channel::Create(memory_manager,
                       properties.size() - kNumNonrefProperties, channel.w));
   for (size_t y = 0; y < channel.h; y++) {
     pixel_type *JXL_RESTRICT p = channel.Row(y);
     PrecomputeReferences(channel, y, *image, chan, &references);
     InitPropsRow(&properties, static_props, y);
     // Split row into edges + interior so the interior can use the NEC
     // ("no edge cases", presumably — TODO confirm) variant; only valid when
     // y > 1 and there are no reference properties.
     if (y > 1 && channel.w > 8 && references.w == 0) {
       for (size_t x = 0; x < 2; x++) {
         PredictionResult res =
             PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
                             tree_lookup, references);
         uint64_t v =
             reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
       }
       for (size_t x = 2; x < channel.w - 2; x++) {
         PredictionResult res =
             PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y,
                                tree_lookup, references);
         uint64_t v = reader->ReadHybridUintClusteredInlined<uses_lz77>(
             res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
       }
       for (size_t x = channel.w - 2; x < channel.w; x++) {
         PredictionResult res =
             PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
                             tree_lookup, references);
         uint64_t v =
             reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
       }
     } else {
       for (size_t x = 0; x < channel.w; x++) {
         PredictionResult res =
             PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
                             tree_lookup, references);
         uint64_t v = reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(
             res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
       }
     }
   }
 } else {
   // General case: full property computation plus weighted-predictor state.
   JXL_DEBUG_V(8, "Slowest track.");
   MATreeLookup tree_lookup(tree);
   Properties properties = Properties(num_props);
   const intptr_t onerow = channel.plane.PixelsPerRow();
   JXL_ASSIGN_OR_RETURN(
       Channel references,
       Channel::Create(memory_manager,
                       properties.size() - kNumNonrefProperties, channel.w));
   weighted::State wp_state(wp_header, channel.w, channel.h);
   for (size_t y = 0; y < channel.h; y++) {
     pixel_type *JXL_RESTRICT p = channel.Row(y);
     InitPropsRow(&properties, static_props, y);
     PrecomputeReferences(channel, y, *image, chan, &references);
     if (!uses_lz77 && y > 1 && channel.w > 8 && references.w == 0) {
       for (size_t x = 0; x < 2; x++) {
         PredictionResult res =
             PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
                           tree_lookup, references, &wp_state);
         uint64_t v =
             reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
         wp_state.UpdateErrors(p[x], x, y, channel.w);
       }
       for (size_t x = 2; x < channel.w - 2; x++) {
         PredictionResult res =
             PredictTreeWPNEC(&properties, channel.w, p + x, onerow, x, y,
                              tree_lookup, references, &wp_state);
         uint64_t v = reader->ReadHybridUintClusteredInlined<uses_lz77>(
             res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
         wp_state.UpdateErrors(p[x], x, y, channel.w);
       }
       for (size_t x = channel.w - 2; x < channel.w; x++) {
         PredictionResult res =
             PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
                           tree_lookup, references, &wp_state);
         uint64_t v =
             reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
         wp_state.UpdateErrors(p[x], x, y, channel.w);
       }
     } else {
       for (size_t x = 0; x < channel.w; x++) {
         PredictionResult res =
             PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
                           tree_lookup, references, &wp_state);
         uint64_t v =
             reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
         p[x] = make_pixel(v, res.multiplier, res.guess);
         wp_state.UpdateErrors(p[x], x, y, channel.w);
       }
     }
   }
 }
 return true;
}
}  // namespace detail
    484 
    485 Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
    486                                 const std::vector<uint8_t> &context_map,
    487                                 const Tree &global_tree,
    488                                 const weighted::Header &wp_header,
    489                                 pixel_type chan, size_t group_id,
    490                                 TreeLut<uint8_t, false, false> &tree_lut,
    491                                 Image *image, uint32_t &fl_run,
    492                                 uint32_t &fl_v) {
    493  if (reader->UsesLZ77()) {
    494    return detail::DecodeModularChannelMAANS</*uses_lz77=*/true>(
    495        br, reader, context_map, global_tree, wp_header, chan, group_id,
    496        tree_lut, image, fl_run, fl_v);
    497  } else {
    498    return detail::DecodeModularChannelMAANS</*uses_lz77=*/false>(
    499        br, reader, context_map, global_tree, wp_header, chan, group_id,
    500        tree_lut, image, fl_run, fl_v);
    501  }
    502 }
    503 
// Initializes all fields to their bundle-declared defaults.
GroupHeader::GroupHeader() { Bundle::Init(this); }
    505 
    506 Status ValidateChannelDimensions(const Image &image,
    507                                 const ModularOptions &options) {
    508  size_t nb_channels = image.channel.size();
    509  for (bool is_dc : {true, false}) {
    510    size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1);
    511    size_t c = image.nb_meta_channels;
    512    for (; c < nb_channels; c++) {
    513      const Channel &ch = image.channel[c];
    514      if (ch.w > options.group_dim || ch.h > options.group_dim) break;
    515    }
    516    for (; c < nb_channels; c++) {
    517      const Channel &ch = image.channel[c];
    518      if (ch.w == 0 || ch.h == 0) continue;  // skip empty
    519      bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3;
    520      if (is_dc_channel != is_dc) continue;
    521      size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift);
    522      if (tile_dim == 0) {
    523        return JXL_FAILURE("Inconsistent transforms");
    524      }
    525    }
    526  }
    527  return true;
    528 }
    529 
    530 Status ModularDecode(BitReader *br, Image &image, GroupHeader &header,
    531                     size_t group_id, ModularOptions *options,
    532                     const Tree *global_tree, const ANSCode *global_code,
    533                     const std::vector<uint8_t> *global_ctx_map,
    534                     const bool allow_truncated_group) {
    535  if (image.channel.empty()) return true;
    536  JxlMemoryManager *memory_manager = image.memory_manager();
    537 
    538  // decode transforms
    539  Status status = Bundle::Read(br, &header);
    540  if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status);
    541  if (status.IsFatalError()) return status;
    542  if (!br->AllReadsWithinBounds()) {
    543    // Don't do/undo transforms if header is incomplete.
    544    header.transforms.clear();
    545    image.transform = header.transforms;
    546    for (auto &ch : image.channel) {
    547      ZeroFillImage(&ch.plane);
    548    }
    549    return Status(StatusCode::kNotEnoughBytes);
    550  }
    551 
    552  JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ",
    553              header.transforms.size());
    554  image.transform = header.transforms;
    555  for (Transform &transform : image.transform) {
    556    JXL_RETURN_IF_ERROR(transform.MetaApply(image));
    557  }
    558  if (image.error) {
    559    return JXL_FAILURE("Corrupt file. Aborting.");
    560  }
    561  JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options));
    562 
    563  size_t nb_channels = image.channel.size();
    564 
    565  size_t num_chans = 0;
    566  size_t distance_multiplier = 0;
    567  for (size_t i = 0; i < nb_channels; i++) {
    568    Channel &channel = image.channel[i];
    569    if (!channel.w || !channel.h) {
    570      continue;  // skip empty channels
    571    }
    572    if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
    573                                        channel.h > options->max_chan_size)) {
    574      break;
    575    }
    576    if (channel.w > distance_multiplier) {
    577      distance_multiplier = channel.w;
    578    }
    579    num_chans++;
    580  }
    581  if (num_chans == 0) return true;
    582 
    583  size_t next_channel = 0;
    584  auto scope_guard = MakeScopeGuard([&]() {
    585    for (size_t c = next_channel; c < image.channel.size(); c++) {
    586      ZeroFillImage(&image.channel[c].plane);
    587    }
    588  });
    589  // Do not do anything if truncated groups are not allowed.
    590  if (allow_truncated_group) scope_guard.Disarm();
    591 
    592  // Read tree.
    593  Tree tree_storage;
    594  std::vector<uint8_t> context_map_storage;
    595  ANSCode code_storage;
    596  const Tree *tree = &tree_storage;
    597  const ANSCode *code = &code_storage;
    598  const std::vector<uint8_t> *context_map = &context_map_storage;
    599  if (!header.use_global_tree) {
    600    uint64_t max_tree_size = 1024;
    601    for (size_t i = 0; i < nb_channels; i++) {
    602      Channel &channel = image.channel[i];
    603      if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
    604                                          channel.h > options->max_chan_size)) {
    605        break;
    606      }
    607      uint64_t pixels = channel.w * channel.h;
    608      max_tree_size += pixels;
    609    }
    610    max_tree_size = std::min(static_cast<uint64_t>(1 << 20), max_tree_size);
    611    JXL_RETURN_IF_ERROR(
    612        DecodeTree(memory_manager, br, &tree_storage, max_tree_size));
    613    JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br,
    614                                         (tree_storage.size() + 1) / 2,
    615                                         &code_storage, &context_map_storage));
    616  } else {
    617    if (!global_tree || !global_code || !global_ctx_map ||
    618        global_tree->empty()) {
    619      return JXL_FAILURE("No global tree available but one was requested");
    620    }
    621    tree = global_tree;
    622    code = global_code;
    623    context_map = global_ctx_map;
    624  }
    625 
    626  // Read channels
    627  JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader,
    628                       ANSSymbolReader::Create(code, br, distance_multiplier));
    629  auto tree_lut = jxl::make_unique<TreeLut<uint8_t, false, false>>();
    630  uint32_t fl_run = 0;
    631  uint32_t fl_v = 0;
    632  for (; next_channel < nb_channels; next_channel++) {
    633    Channel &channel = image.channel[next_channel];
    634    if (!channel.w || !channel.h) {
    635      continue;  // skip empty channels
    636    }
    637    if (next_channel >= image.nb_meta_channels &&
    638        (channel.w > options->max_chan_size ||
    639         channel.h > options->max_chan_size)) {
    640      break;
    641    }
    642    JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS(
    643        br, &reader, *context_map, *tree, header.wp_header, next_channel,
    644        group_id, *tree_lut, &image, fl_run, fl_v));
    645 
    646    // Truncated group.
    647    if (!br->AllReadsWithinBounds()) {
    648      if (!allow_truncated_group) return JXL_FAILURE("Truncated input");
    649      return Status(StatusCode::kNotEnoughBytes);
    650    }
    651  }
    652 
    653  // Make sure no zero-filling happens even if next_channel < nb_channels.
    654  scope_guard.Disarm();
    655 
    656  if (!reader.CheckANSFinalState()) {
    657    return JXL_FAILURE("ANS decode final state failed");
    658  }
    659  return true;
    660 }
    661 
// Top-level entry: decodes a modular (sub-)image and optionally undoes its
// transforms. Non-fatal statuses (e.g. kNotEnoughBytes when
// allow_truncated_group is set) are propagated to the caller; fatal ones
// abort immediately. When undo_transforms is set, also verifies that undoing
// restored every channel to its originally requested dimensions.
Status ModularGenericDecompress(BitReader *br, Image &image,
                               GroupHeader *header, size_t group_id,
                               ModularOptions *options, bool undo_transforms,
                               const Tree *tree, const ANSCode *code,
                               const std::vector<uint8_t> *ctx_map,
                               bool allow_truncated_group) {
 // Remember the (w, h) of every channel before decoding: transforms may
 // reshape channels, and undoing them must restore these sizes exactly.
 std::vector<std::pair<uint32_t, uint32_t>> req_sizes;
 req_sizes.reserve(image.channel.size());
 for (const auto &c : image.channel) {
   req_sizes.emplace_back(c.w, c.h);
 }
 // Callers may pass nullptr when they don't need the header back.
 GroupHeader local_header;
 if (header == nullptr) header = &local_header;
 size_t bit_pos = br->TotalBitsConsumed();
 auto dec_status = ModularDecode(br, image, *header, group_id, options, tree,
                                 code, ctx_map, allow_truncated_group);
 if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
 if (dec_status.IsFatalError()) return dec_status;
 if (undo_transforms) image.undo_transforms(header->wp_header);
 if (image.error) return JXL_FAILURE("Corrupt file. Aborting.");
 JXL_DEBUG_V(4,
             "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS
             " image from %" PRIuS " bytes",
             image.w, image.h, image.channel.size(),
             (br->TotalBitsConsumed() - bit_pos) / 8);
 JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str());
 // bit_pos is only read by the debug logging above; silence -Wunused in
 // non-debug builds.
 (void)bit_pos;
 // Check that after applying all transforms we are back to the requested
 // image sizes, otherwise there's a programming error with the
 // transformations.
 if (undo_transforms) {
   JXL_ENSURE(image.channel.size() == req_sizes.size());
   for (size_t c = 0; c < req_sizes.size(); c++) {
     JXL_ENSURE(req_sizes[c].first == image.channel[c].w);
     JXL_ENSURE(req_sizes[c].second == image.channel[c].h);
   }
 }
 // May be a non-OK, non-fatal status (e.g. kNotEnoughBytes).
 return dec_status;
}
    701 
    702 }  // namespace jxl