tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_encoding.cc (30517B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include <jxl/memory_manager.h>
      7 
      8 #include <algorithm>
      9 #include <array>
     10 #include <cstddef>
     11 #include <cstdint>
     12 #include <cstdlib>
     13 #include <limits>
     14 #include <queue>
     15 #include <utility>
     16 #include <vector>
     17 
     18 #include "lib/jxl/base/common.h"
     19 #include "lib/jxl/base/printf_macros.h"
     20 #include "lib/jxl/base/status.h"
     21 #include "lib/jxl/enc_ans.h"
     22 #include "lib/jxl/enc_aux_out.h"
     23 #include "lib/jxl/enc_bit_writer.h"
     24 #include "lib/jxl/enc_fields.h"
     25 #include "lib/jxl/fields.h"
     26 #include "lib/jxl/image_ops.h"
     27 #include "lib/jxl/modular/encoding/context_predict.h"
     28 #include "lib/jxl/modular/encoding/enc_ma.h"
     29 #include "lib/jxl/modular/encoding/encoding.h"
     30 #include "lib/jxl/modular/encoding/ma_common.h"
     31 #include "lib/jxl/modular/options.h"
     32 #include "lib/jxl/pack_signed.h"
     33 
     34 namespace jxl {
     35 
     36 namespace {
     37 // Plot tree (if enabled) and predictor usage map.
     38 constexpr bool kWantDebug = true;
     39 // constexpr bool kPrintTree = false;
     40 
     41 inline std::array<uint8_t, 3> PredictorColor(Predictor p) {
     42  switch (p) {
     43    case Predictor::Zero:
     44      return {{0, 0, 0}};
     45    case Predictor::Left:
     46      return {{255, 0, 0}};
     47    case Predictor::Top:
     48      return {{0, 255, 0}};
     49    case Predictor::Average0:
     50      return {{0, 0, 255}};
     51    case Predictor::Average4:
     52      return {{192, 128, 128}};
     53    case Predictor::Select:
     54      return {{255, 255, 0}};
     55    case Predictor::Gradient:
     56      return {{255, 0, 255}};
     57    case Predictor::Weighted:
     58      return {{0, 255, 255}};
     59      // TODO(jon)
     60    default:
     61      return {{255, 255, 255}};
     62  };
     63 }
     64 
     65 // `cutoffs` must be sorted.
     66 Tree MakeFixedTree(int property, const std::vector<int32_t> &cutoffs,
     67                   Predictor pred, size_t num_pixels, int bitdepth) {
     68  size_t log_px = CeilLog2Nonzero(num_pixels);
     69  size_t min_gap = 0;
     70  // Reduce fixed tree height when encoding small images.
     71  if (log_px < 14) {
     72    min_gap = 8 * (14 - log_px);
     73  }
     74  const int shift = bitdepth > 11 ? std::min(4, bitdepth - 11) : 0;
     75  const int mul = 1 << shift;
     76  Tree tree;
     77  struct NodeInfo {
     78    size_t begin, end, pos;
     79  };
     80  std::queue<NodeInfo> q;
     81  // Leaf IDs will be set by roundtrip decoding the tree.
     82  tree.push_back(PropertyDecisionNode::Leaf(pred));
     83  q.push(NodeInfo{0, cutoffs.size(), 0});
     84  while (!q.empty()) {
     85    NodeInfo info = q.front();
     86    q.pop();
     87    if (info.begin + min_gap >= info.end) continue;
     88    uint32_t split = (info.begin + info.end) / 2;
     89    int32_t cutoff = cutoffs[split] * mul;
     90    tree[info.pos] = PropertyDecisionNode::Split(property, cutoff, tree.size());
     91    q.push(NodeInfo{split + 1, info.end, tree.size()});
     92    tree.push_back(PropertyDecisionNode::Leaf(pred));
     93    q.push(NodeInfo{info.begin, split, tree.size()});
     94    tree.push_back(PropertyDecisionNode::Leaf(pred));
     95  }
     96  return tree;
     97 }
     98 
     99 }  // namespace
    100 
// Collects MA-tree training samples for channel `chan` of `image`.
//
// Walks the channel in raster order. For every pixel it computes the
// properties and the prediction(s): for all predictors of `tree_samples` when
// more than one is configured, otherwise only for the single one. A
// pseudo-randomly selected subset of these pixels is added to `tree_samples`.
// `*total_pixels` is incremented once for every visited pixel, whether or not
// it is sampled.
//
// The sampling fraction is options.nb_repeats capped at 1. When positive, it
// is raised so that at least min(1024, channel.w * channel.h) samples are
// expected. The random generator is an xorshift128+ with a fixed seed, so the
// selection is deterministic.
//
// Returns an error only if allocating the scratch `references` channel fails.
Status GatherTreeData(const Image &image, pixel_type chan, size_t group_id,
                      const weighted::Header &wp_header,
                      const ModularOptions &options, TreeSamples &tree_samples,
                      size_t *total_pixels) {
  const Channel &channel = image.channel[chan];
  JxlMemoryManager *memory_manager = channel.memory_manager();

  JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w,
              channel.h, chan);

  // Static properties shared by every sample: channel index and group id.
  std::array<pixel_type, kNumStaticProperties> static_props = {
      {chan, static_cast<int>(group_id)}};
  Properties properties(kNumNonrefProperties +
                        kExtraPropsPerChannel * options.max_properties);
  double pixel_fraction = std::min(1.0f, options.nb_repeats);
  // a fraction of 0 is used to disable learning entirely.
  if (pixel_fraction > 0) {
    // Expect at least ~1024 samples (or the whole channel if it is smaller).
    pixel_fraction = std::max(pixel_fraction,
                              std::min(1.0, 1024.0 / (channel.w * channel.h)));
  }
  // A pixel is sampled when the upper 32 bits of the next random value are
  // <= threshold, which happens with probability ~pixel_fraction.
  // NOTE(review): with pixel_fraction == 0 the threshold is 0, and `<=` still
  // accepts draws whose upper 32 bits are all zero (~2^-32 per pixel), so
  // learning is only approximately disabled in that case.
  uint64_t threshold =
      (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction;
  uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull),
                   static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)};
  // Xorshift128+ adapted from xorshift128+-inl.h
  // Advances the generator once per call; returns whether to keep the sample.
  auto use_sample = [&]() {
    auto s1 = s[0];
    const auto s0 = s[1];
    const auto bits = s1 + s0;  // b, c
    s[0] = s0;
    s1 ^= s1 << 23;
    s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
    s[1] = s1;
    return (bits >> 32) <= threshold;
  };

  const intptr_t onerow = channel.plane.PixelsPerRow();
  // Scratch channel filled per row by PrecomputeReferences. Its first
  // dimension is the number of properties beyond kNumNonrefProperties, so
  // `references.w == 0` below is read as "no reference properties".
  // Assumes Channel::Create takes (width, height); confirm.
  JXL_ASSIGN_OR_RETURN(
      Channel references,
      Channel::Create(memory_manager, properties.size() - kNumNonrefProperties,
                      channel.w));
  weighted::State wp_state(wp_header, channel.w, channel.h);
  // Reserve room for the expected number of samples plus some slack.
  tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64);
  const bool multiple_predictors = tree_samples.NumPredictors() != 1;
  // Generic per-pixel step, used for every pixel that is not on the interior
  // fast path below:
  //  1. compute properties and prediction(s) for pixel (x, y) of row `p`;
  //  2. count the pixel and add it as a sample if it is selected;
  //  3. update the weighted-predictor error state.
  auto compute_sample = [&](const pixel_type *p, size_t x, size_t y) {
    pixel_type_w pred[kNumModularPredictors];
    if (multiple_predictors) {
      PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references,
                      &wp_state, pred);
    } else {
      pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
          PredictLearn(&properties, channel.w, p + x, onerow, x, y,
                       tree_samples.PredictorFromIndex(0), references,
                       &wp_state)
              .guess;
    }
    (*total_pixels)++;
    if (use_sample()) {
      tree_samples.AddSample(p[x], properties, pred);
    }
    wp_state.UpdateErrors(p[x], x, y, channel.w);
  };

  for (size_t y = 0; y < channel.h; y++) {
    const pixel_type *JXL_RESTRICT p = channel.Row(y);
    PrecomputeReferences(channel, y, image, chan, &references);
    InitPropsRow(&properties, static_props, y);

    // TODO(veluca): avoid computing WP if we don't use its property or
    // predictions.
    // Fast path: on rows y > 1 of channels wider than 8 pixels with no
    // reference properties, pixels x in [2, w - 2) use the *NEC predictors
    // (presumably "no edge cases", i.e. without border handling). The two
    // leftmost and two rightmost pixels still go through compute_sample. The
    // inner loop mirrors compute_sample step for step and must stay in sync
    // with it.
    if (y > 1 && channel.w > 8 && references.w == 0) {
      for (size_t x = 0; x < 2; x++) {
        compute_sample(p, x, y);
      }
      for (size_t x = 2; x < channel.w - 2; x++) {
        pixel_type_w pred[kNumModularPredictors];
        if (multiple_predictors) {
          PredictLearnAllNEC(&properties, channel.w, p + x, onerow, x, y,
                             references, &wp_state, pred);
        } else {
          pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
              PredictLearnNEC(&properties, channel.w, p + x, onerow, x, y,
                              tree_samples.PredictorFromIndex(0), references,
                              &wp_state)
                  .guess;
        }
        (*total_pixels)++;
        if (use_sample()) {
          tree_samples.AddSample(p[x], properties, pred);
        }
        wp_state.UpdateErrors(p[x], x, y, channel.w);
      }
      for (size_t x = channel.w - 2; x < channel.w; x++) {
        compute_sample(p, x, y);
      }
    } else {
      for (size_t x = 0; x < channel.w; x++) {
        compute_sample(p, x, y);
      }
    }
  }
  return true;
}
    204 
    205 Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels,
    206                    int bitdepth, int prevprop) {
    207  switch (tree_kind) {
    208    case ModularOptions::TreeKind::kJpegTranscodeACMeta:
    209      // All the data is 0, so no need for a fancy tree.
    210      return {PropertyDecisionNode::Leaf(Predictor::Zero)};
    211    case ModularOptions::TreeKind::kTrivialTreeNoPredictor:
    212      // All the data is 0, so no need for a fancy tree.
    213      return {PropertyDecisionNode::Leaf(Predictor::Zero)};
    214    case ModularOptions::TreeKind::kFalconACMeta:
    215      // All the data is 0 except the quant field. TODO(veluca): make that 0
    216      // too.
    217      return {PropertyDecisionNode::Leaf(Predictor::Left)};
    218    case ModularOptions::TreeKind::kACMeta: {
    219      // Small image.
    220      if (total_pixels < 1024) {
    221        return {PropertyDecisionNode::Leaf(Predictor::Left)};
    222      }
    223      Tree tree;
    224      // 0: c > 1
    225      tree.push_back(PropertyDecisionNode::Split(0, 1, 1));
    226      // 1: c > 2
    227      tree.push_back(PropertyDecisionNode::Split(0, 2, 3));
    228      // 2: c > 0
    229      tree.push_back(PropertyDecisionNode::Split(0, 0, 5));
    230      // 3: EPF control field (all 0 or 4), top > 3
    231      tree.push_back(PropertyDecisionNode::Split(6, 3, 21));
    232      // 4: ACS+QF, y > 0
    233      tree.push_back(PropertyDecisionNode::Split(2, 0, 7));
    234      // 5: CfL x
    235      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
    236      // 6: CfL b
    237      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
    238      // 7: QF: split according to the left quant value.
    239      tree.push_back(PropertyDecisionNode::Split(7, 5, 9));
    240      // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large
    241      // rectangular 6-11, 8x8 12+), according to previous ACS value.
    242      tree.push_back(PropertyDecisionNode::Split(7, 5, 15));
    243      // QF
    244      tree.push_back(PropertyDecisionNode::Split(7, 11, 11));
    245      tree.push_back(PropertyDecisionNode::Split(7, 3, 13));
    246      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    247      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    248      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    249      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    250      // ACS
    251      tree.push_back(PropertyDecisionNode::Split(7, 11, 17));
    252      tree.push_back(PropertyDecisionNode::Split(7, 3, 19));
    253      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    254      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    255      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    256      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    257      // EPF, left > 3
    258      tree.push_back(PropertyDecisionNode::Split(7, 3, 23));
    259      tree.push_back(PropertyDecisionNode::Split(7, 3, 25));
    260      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    261      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    262      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    263      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    264      return tree;
    265    }
    266    case ModularOptions::TreeKind::kWPFixedDC: {
    267      std::vector<int32_t> cutoffs = {
    268          -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
    269          -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
    270          15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
    271      return MakeFixedTree(kWPProp, cutoffs, Predictor::Weighted, total_pixels,
    272                           bitdepth);
    273    }
    274    case ModularOptions::TreeKind::kGradientFixedDC: {
    275      std::vector<int32_t> cutoffs = {
    276          -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
    277          -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
    278          15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
    279      return MakeFixedTree(
    280          prevprop > 0 ? kNumNonrefProperties + 2 : kGradientProp, cutoffs,
    281          Predictor::Gradient, total_pixels, bitdepth);
    282    }
    283    case ModularOptions::TreeKind::kLearn: {
    284      JXL_DEBUG_ABORT("internal: kLearn is not predefined tree");
    285      return {};
    286    }
    287  }
    288  JXL_DEBUG_ABORT("internal: unexpected TreeKind: %d",
    289                  static_cast<int>(tree_kind));
    290  return {};
    291 }
    292 
    293 StatusOr<Tree> LearnTree(
    294    TreeSamples &&tree_samples, size_t total_pixels,
    295    const ModularOptions &options,
    296    const std::vector<ModularMultiplierInfo> &multiplier_info = {},
    297    StaticPropRange static_prop_range = {}) {
    298  Tree tree;
    299  for (size_t i = 0; i < kNumStaticProperties; i++) {
    300    if (static_prop_range[i][1] == 0) {
    301      static_prop_range[i][1] = std::numeric_limits<uint32_t>::max();
    302    }
    303  }
    304  if (!tree_samples.HasSamples()) {
    305    tree.emplace_back();
    306    tree.back().predictor = tree_samples.PredictorFromIndex(0);
    307    tree.back().property = -1;
    308    tree.back().predictor_offset = 0;
    309    tree.back().multiplier = 1;
    310    return tree;
    311  }
    312  float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels;
    313  float required_cost = pixel_fraction * 0.9 + 0.1;
    314  tree_samples.AllSamplesDone();
    315  JXL_RETURN_IF_ERROR(ComputeBestTree(
    316      tree_samples, options.splitting_heuristics_node_threshold * required_cost,
    317      multiplier_info, static_prop_range, options.fast_decode_multiplier,
    318      &tree));
    319  return tree;
    320 }
    321 
    322 Status EncodeModularChannelMAANS(const Image &image, pixel_type chan,
    323                                 const weighted::Header &wp_header,
    324                                 const Tree &global_tree, Token **tokenpp,
    325                                 AuxOut *aux_out, size_t group_id,
    326                                 bool skip_encoder_fast_path) {
    327  const Channel &channel = image.channel[chan];
    328  JxlMemoryManager *memory_manager = channel.memory_manager();
    329  Token *tokenp = *tokenpp;
    330  JXL_ENSURE(channel.w != 0 && channel.h != 0);
    331 
    332  Image3F predictor_img;
    333  if (kWantDebug) {
    334    JXL_ASSIGN_OR_RETURN(predictor_img,
    335                         Image3F::Create(memory_manager, channel.w, channel.h));
    336  }
    337 
    338  JXL_DEBUG_V(6,
    339              "Encoding %" PRIuS "x%" PRIuS
    340              " channel %d, "
    341              "(shift=%i,%i)",
    342              channel.w, channel.h, chan, channel.hshift, channel.vshift);
    343 
    344  std::array<pixel_type, kNumStaticProperties> static_props = {
    345      {chan, static_cast<int>(group_id)}};
    346  bool use_wp;
    347  bool is_wp_only;
    348  bool is_gradient_only;
    349  size_t num_props;
    350  FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp,
    351                             &is_wp_only, &is_gradient_only);
    352  Properties properties(num_props);
    353  MATreeLookup tree_lookup(tree);
    354  JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size());
    355 
    356  // Check if this tree is a WP-only tree with a small enough property value
    357  // range.
    358  // Initialized to avoid clang-tidy complaining.
    359  auto tree_lut = jxl::make_unique<TreeLut<uint16_t, false, false>>();
    360  if (is_wp_only) {
    361    is_wp_only = TreeToLookupTable(tree, *tree_lut);
    362  }
    363  if (is_gradient_only) {
    364    is_gradient_only = TreeToLookupTable(tree, *tree_lut);
    365  }
    366 
    367  if (is_wp_only && !skip_encoder_fast_path) {
    368    for (size_t c = 0; c < 3; c++) {
    369      FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]),
    370                &predictor_img.Plane(c));
    371    }
    372    const intptr_t onerow = channel.plane.PixelsPerRow();
    373    weighted::State wp_state(wp_header, channel.w, channel.h);
    374    Properties properties(1);
    375    for (size_t y = 0; y < channel.h; y++) {
    376      const pixel_type *JXL_RESTRICT r = channel.Row(y);
    377      for (size_t x = 0; x < channel.w; x++) {
    378        size_t offset = 0;
    379        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    380        pixel_type_w top = (y ? *(r + x - onerow) : left);
    381        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    382        pixel_type_w topright =
    383            (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
    384        pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
    385        int32_t guess = wp_state.Predict</*compute_properties=*/true>(
    386            x, y, channel.w, top, left, topright, topleft, toptop, &properties,
    387            offset);
    388        uint32_t pos =
    389            kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
    390                                      kPropRangeFast - 1);
    391        uint32_t ctx_id = tree_lut->context_lookup[pos];
    392        int32_t residual = r[x] - guess;
    393        *tokenp++ = Token(ctx_id, PackSigned(residual));
    394        wp_state.UpdateErrors(r[x], x, y, channel.w);
    395      }
    396    }
    397  } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient &&
    398             tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
    399             !skip_encoder_fast_path) {
    400    for (size_t c = 0; c < 3; c++) {
    401      FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
    402                &predictor_img.Plane(c));
    403    }
    404    const intptr_t onerow = channel.plane.PixelsPerRow();
    405    for (size_t y = 0; y < channel.h; y++) {
    406      const pixel_type *JXL_RESTRICT r = channel.Row(y);
    407      for (size_t x = 0; x < channel.w; x++) {
    408        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    409        pixel_type_w top = (y ? *(r + x - onerow) : left);
    410        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    411        int32_t guess = ClampedGradient(top, left, topleft);
    412        int32_t residual = r[x] - guess;
    413        *tokenp++ = Token(tree[0].childID, PackSigned(residual));
    414      }
    415    }
    416  } else if (is_gradient_only && !skip_encoder_fast_path) {
    417    for (size_t c = 0; c < 3; c++) {
    418      FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
    419                &predictor_img.Plane(c));
    420    }
    421    const intptr_t onerow = channel.plane.PixelsPerRow();
    422    for (size_t y = 0; y < channel.h; y++) {
    423      const pixel_type *JXL_RESTRICT r = channel.Row(y);
    424      for (size_t x = 0; x < channel.w; x++) {
    425        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    426        pixel_type_w top = (y ? *(r + x - onerow) : left);
    427        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    428        int32_t guess = ClampedGradient(top, left, topleft);
    429        uint32_t pos =
    430            kPropRangeFast +
    431            std::min<pixel_type_w>(
    432                std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
    433                kPropRangeFast - 1);
    434        uint32_t ctx_id = tree_lut->context_lookup[pos];
    435        int32_t residual = r[x] - guess;
    436        *tokenp++ = Token(ctx_id, PackSigned(residual));
    437      }
    438    }
    439  } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero &&
    440             tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
    441             !skip_encoder_fast_path) {
    442    for (size_t c = 0; c < 3; c++) {
    443      FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]),
    444                &predictor_img.Plane(c));
    445    }
    446    for (size_t y = 0; y < channel.h; y++) {
    447      const pixel_type *JXL_RESTRICT p = channel.Row(y);
    448      for (size_t x = 0; x < channel.w; x++) {
    449        *tokenp++ = Token(tree[0].childID, PackSigned(p[x]));
    450      }
    451    }
    452  } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted &&
    453             (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 &&
    454             tree[0].predictor_offset == 0 && !skip_encoder_fast_path) {
    455    // multiplier is a power of 2.
    456    for (size_t c = 0; c < 3; c++) {
    457      FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]),
    458                &predictor_img.Plane(c));
    459    }
    460    uint32_t mul_shift =
    461        FloorLog2Nonzero(static_cast<uint32_t>(tree[0].multiplier));
    462    const intptr_t onerow = channel.plane.PixelsPerRow();
    463    for (size_t y = 0; y < channel.h; y++) {
    464      const pixel_type *JXL_RESTRICT r = channel.Row(y);
    465      for (size_t x = 0; x < channel.w; x++) {
    466        PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x,
    467                                                  y, tree[0].predictor);
    468        pixel_type_w residual = r[x] - pred.guess;
    469        JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual);
    470        *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift));
    471      }
    472    }
    473 
    474  } else if (!use_wp && !skip_encoder_fast_path) {
    475    const intptr_t onerow = channel.plane.PixelsPerRow();
    476    JXL_ASSIGN_OR_RETURN(
    477        Channel references,
    478        Channel::Create(memory_manager,
    479                        properties.size() - kNumNonrefProperties, channel.w));
    480    for (size_t y = 0; y < channel.h; y++) {
    481      const pixel_type *JXL_RESTRICT p = channel.Row(y);
    482      PrecomputeReferences(channel, y, image, chan, &references);
    483      float *pred_img_row[3];
    484      if (kWantDebug) {
    485        for (size_t c = 0; c < 3; c++) {
    486          pred_img_row[c] = predictor_img.PlaneRow(c, y);
    487        }
    488      }
    489      InitPropsRow(&properties, static_props, y);
    490      for (size_t x = 0; x < channel.w; x++) {
    491        PredictionResult res =
    492            PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
    493                            tree_lookup, references);
    494        if (kWantDebug) {
    495          for (size_t i = 0; i < 3; i++) {
    496            pred_img_row[i][x] = PredictorColor(res.predictor)[i];
    497          }
    498        }
    499        pixel_type_w residual = p[x] - res.guess;
    500        JXL_DASSERT(residual % res.multiplier == 0);
    501        *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
    502      }
    503    }
    504  } else {
    505    const intptr_t onerow = channel.plane.PixelsPerRow();
    506    JXL_ASSIGN_OR_RETURN(
    507        Channel references,
    508        Channel::Create(memory_manager,
    509                        properties.size() - kNumNonrefProperties, channel.w));
    510    weighted::State wp_state(wp_header, channel.w, channel.h);
    511    for (size_t y = 0; y < channel.h; y++) {
    512      const pixel_type *JXL_RESTRICT p = channel.Row(y);
    513      PrecomputeReferences(channel, y, image, chan, &references);
    514      float *pred_img_row[3];
    515      if (kWantDebug) {
    516        for (size_t c = 0; c < 3; c++) {
    517          pred_img_row[c] = predictor_img.PlaneRow(c, y);
    518        }
    519      }
    520      InitPropsRow(&properties, static_props, y);
    521      for (size_t x = 0; x < channel.w; x++) {
    522        PredictionResult res =
    523            PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
    524                          tree_lookup, references, &wp_state);
    525        if (kWantDebug) {
    526          for (size_t i = 0; i < 3; i++) {
    527            pred_img_row[i][x] = PredictorColor(res.predictor)[i];
    528          }
    529        }
    530        pixel_type_w residual = p[x] - res.guess;
    531        JXL_DASSERT(residual % res.multiplier == 0);
    532        *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
    533        wp_state.UpdateErrors(p[x], x, y, channel.w);
    534      }
    535    }
    536  }
    537  /* TODO(szabadka): Add cparams to the call stack here.
    538  if (kWantDebug && WantDebugOutput(cparams)) {
    539    DumpImage(
    540        cparams,
    541        ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(),
    542        predictor_img);
    543  }
    544  */
    545  *tokenpp = tokenp;
    546  return true;
    547 }
    548 
    549 Status ModularEncode(const Image &image, const ModularOptions &options,
    550                     BitWriter *writer, AuxOut *aux_out, LayerType layer,
    551                     size_t group_id, TreeSamples *tree_samples,
    552                     size_t *total_pixels, const Tree *tree,
    553                     GroupHeader *header, std::vector<Token> *tokens,
    554                     size_t *width) {
    555  if (image.error) return JXL_FAILURE("Invalid image");
    556  JxlMemoryManager *memory_manager = image.memory_manager();
    557  size_t nb_channels = image.channel.size();
    558  JXL_DEBUG_V(
    559      2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
    560      nb_channels, image.bitdepth, image.w, image.h);
    561 
    562  if (nb_channels < 1) {
    563    return true;  // is there any use for a zero-channel image?
    564  }
    565 
    566  // encode transforms
    567  GroupHeader header_storage;
    568  if (header == nullptr) header = &header_storage;
    569  Bundle::Init(header);
    570  if (options.predictor == Predictor::Weighted) {
    571    weighted::PredictorMode(options.wp_mode, &header->wp_header);
    572  }
    573  header->transforms = image.transform;
    574  // This doesn't actually work
    575  if (tree != nullptr) {
    576    header->use_global_tree = true;
    577  }
    578  if (tree_samples == nullptr && tree == nullptr) {
    579    JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out));
    580  }
    581 
    582  TreeSamples tree_samples_storage;
    583  size_t total_pixels_storage = 0;
    584  if (!total_pixels) total_pixels = &total_pixels_storage;
    585  if (*total_pixels == 0) {
    586    for (size_t i = 0; i < nb_channels; i++) {
    587      if (i >= image.nb_meta_channels &&
    588          (image.channel[i].w > options.max_chan_size ||
    589           image.channel[i].h > options.max_chan_size)) {
    590        break;
    591      }
    592      *total_pixels += image.channel[i].w * image.channel[i].h;
    593    }
    594    *total_pixels = std::max<size_t>(*total_pixels, 1);
    595  }
    596  // If there's no tree, compute one (or gather data to).
    597  if (tree == nullptr &&
    598      options.tree_kind == ModularOptions::TreeKind::kLearn) {
    599    bool gather_data = tree_samples != nullptr;
    600    if (tree_samples == nullptr) {
    601      JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor(
    602          options.predictor, options.wp_tree_mode));
    603      JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties(
    604          options.splitting_heuristics_properties, options.wp_tree_mode));
    605      std::vector<pixel_type> pixel_samples;
    606      std::vector<pixel_type> diff_samples;
    607      std::vector<uint32_t> group_pixel_count;
    608      std::vector<uint32_t> channel_pixel_count;
    609      CollectPixelSamples(image, options, 0, group_pixel_count,
    610                          channel_pixel_count, pixel_samples, diff_samples);
    611      std::vector<ModularMultiplierInfo> placeholder_multiplier_info;
    612      StaticPropRange range;
    613      tree_samples_storage.PreQuantizeProperties(
    614          range, placeholder_multiplier_info, group_pixel_count,
    615          channel_pixel_count, pixel_samples, diff_samples,
    616          options.max_property_values);
    617    }
    618    for (size_t i = 0; i < nb_channels; i++) {
    619      if (!image.channel[i].w || !image.channel[i].h) {
    620        continue;  // skip empty channels
    621      }
    622      if (i >= image.nb_meta_channels &&
    623          (image.channel[i].w > options.max_chan_size ||
    624           image.channel[i].h > options.max_chan_size)) {
    625        break;
    626      }
    627      JXL_RETURN_IF_ERROR(GatherTreeData(
    628          image, i, group_id, header->wp_header, options,
    629          gather_data ? *tree_samples : tree_samples_storage, total_pixels));
    630    }
    631    if (gather_data) return true;
    632  }
    633 
    634  JXL_ENSURE((tree == nullptr) == (tokens == nullptr));
    635 
    636  Tree tree_storage;
    637  std::vector<std::vector<Token>> tokens_storage(1);
    638  // Compute tree.
    639  if (tree == nullptr) {
    640    EntropyEncodingData code;
    641    std::vector<uint8_t> context_map;
    642 
    643    std::vector<std::vector<Token>> tree_tokens(1);
    644    if (options.tree_kind == ModularOptions::TreeKind::kLearn) {
    645      JXL_ASSIGN_OR_RETURN(
    646          tree_storage,
    647          LearnTree(std::move(tree_samples_storage), *total_pixels, options));
    648    } else {
    649      tree_storage = PredefinedTree(options.tree_kind, *total_pixels,
    650                                    image.bitdepth, options.max_properties);
    651    }
    652    tree = &tree_storage;
    653    tokens = tokens_storage.data();
    654 
    655    Tree decoded_tree;
    656    JXL_RETURN_IF_ERROR(TokenizeTree(*tree, tree_tokens.data(), &decoded_tree));
    657    JXL_ENSURE(tree->size() == decoded_tree.size());
    658    tree_storage = std::move(decoded_tree);
    659 
    660    /* TODO(szabadka) Add text output callback
    661    if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) {
    662      PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id));
    663    } */
    664 
    665    // Write tree
    666    JXL_ASSIGN_OR_RETURN(size_t cost,
    667                         BuildAndEncodeHistograms(
    668                             memory_manager, options.histogram_params,
    669                             kNumTreeContexts, tree_tokens, &code, &context_map,
    670                             writer, LayerType::ModularTree, aux_out));
    671    (void)cost;
    672    JXL_RETURN_IF_ERROR(WriteTokens(tree_tokens[0], code, context_map, 0,
    673                                    writer, LayerType::ModularTree, aux_out));
    674  }
    675 
    676  size_t image_width = 0;
    677  size_t total_tokens = 0;
    678  for (size_t i = 0; i < nb_channels; i++) {
    679    if (i >= image.nb_meta_channels &&
    680        (image.channel[i].w > options.max_chan_size ||
    681         image.channel[i].h > options.max_chan_size)) {
    682      break;
    683    }
    684    if (image.channel[i].w > image_width) image_width = image.channel[i].w;
    685    total_tokens += image.channel[i].w * image.channel[i].h;
    686  }
    687  if (options.zero_tokens) {
    688    tokens->resize(tokens->size() + total_tokens, {0, 0});
    689  } else {
    690    // Do one big allocation for all the tokens we'll need,
    691    // to avoid reallocs that might require copying.
    692    size_t pos = tokens->size();
    693    tokens->resize(pos + total_tokens);
    694    Token *tokenp = tokens->data() + pos;
    695    for (size_t i = 0; i < nb_channels; i++) {
    696      if (!image.channel[i].w || !image.channel[i].h) {
    697        continue;  // skip empty channels
    698      }
    699      if (i >= image.nb_meta_channels &&
    700          (image.channel[i].w > options.max_chan_size ||
    701           image.channel[i].h > options.max_chan_size)) {
    702        break;
    703      }
    704      JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS(
    705          image, i, header->wp_header, *tree, &tokenp, aux_out, group_id,
    706          options.skip_encoder_fast_path));
    707    }
    708    // Make sure we actually wrote all tokens
    709    JXL_ENSURE(tokenp == tokens->data() + tokens->size());
    710  }
    711 
    712  // Write data if not using a global tree/ANS stream.
    713  if (!header->use_global_tree) {
    714    EntropyEncodingData code;
    715    std::vector<uint8_t> context_map;
    716    HistogramParams histo_params = options.histogram_params;
    717    histo_params.image_widths.push_back(image_width);
    718    JXL_ASSIGN_OR_RETURN(
    719        size_t cost,
    720        BuildAndEncodeHistograms(memory_manager, histo_params,
    721                                 (tree->size() + 1) / 2, tokens_storage, &code,
    722                                 &context_map, writer, layer, aux_out));
    723    (void)cost;
    724    JXL_RETURN_IF_ERROR(WriteTokens(tokens_storage[0], code, context_map, 0,
    725                                    writer, layer, aux_out));
    726  } else {
    727    *width = image_width;
    728  }
    729  return true;
    730 }
    731 
    732 Status ModularGenericCompress(Image &image, const ModularOptions &opts,
    733                              BitWriter *writer, AuxOut *aux_out,
    734                              LayerType layer, size_t group_id,
    735                              TreeSamples *tree_samples, size_t *total_pixels,
    736                              const Tree *tree, GroupHeader *header,
    737                              std::vector<Token> *tokens, size_t *width) {
    738  if (image.w == 0 || image.h == 0) return true;
    739  ModularOptions options = opts;  // Make a copy to modify it.
    740 
    741  if (options.predictor == kUndefinedPredictor) {
    742    options.predictor = Predictor::Gradient;
    743  }
    744 
    745  size_t bits = writer ? writer->BitsWritten() : 0;
    746  JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer,
    747                                    group_id, tree_samples, total_pixels, tree,
    748                                    header, tokens, width));
    749  bits = writer ? writer->BitsWritten() - bits : 0;
    750  if (writer) {
    751    JXL_DEBUG_V(4,
    752                "Modular-encoded a %" PRIuS "x%" PRIuS
    753                " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes",
    754                image.w, image.h, image.bitdepth, image.channel.size(),
    755                bits / 8);
    756  }
    757  (void)bits;
    758  return true;
    759 }
    760 
    761 }  // namespace jxl