tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

context_predict.h (25841B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
      7 #define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
      8 
      9 #include <algorithm>
     10 #include <array>
     11 #include <cmath>
     12 #include <cstddef>
     13 #include <cstdint>
     14 #include <vector>
     15 
     16 #include "lib/jxl/base/bits.h"
     17 #include "lib/jxl/base/compiler_specific.h"
     18 #include "lib/jxl/base/status.h"
     19 #include "lib/jxl/field_encodings.h"
     20 #include "lib/jxl/fields.h"
     21 #include "lib/jxl/image_ops.h"
     22 #include "lib/jxl/modular/modular_image.h"
     23 #include "lib/jxl/modular/options.h"
     24 
     25 namespace jxl {
     26 
     27 namespace weighted {
// Number of sub-predictors blended by the weighted predictor; sizes the
// per-predictor arrays in Header and State.
constexpr static size_t kNumPredictors = 4;
// Number of extra low-order bits carried by internal predictions (see
// State::AddBits); they are shifted out again before a prediction is returned.
constexpr static int64_t kPredExtraBits = 3;
// Rounding term added before shifting out kPredExtraBits; evaluates to 3.
constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
// Number of property slots written by State::Predict when properties are
// requested.
constexpr static size_t kNumProperties = 1;
     32 
     33 struct Header : public Fields {
     34  JXL_FIELDS_NAME(WeightedPredictorHeader)
     35  // TODO(janwas): move to cc file, avoid including fields.h.
     36  Header() { Bundle::Init(this); }
     37 
     38  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
     39    if (visitor->AllDefault(*this, &all_default)) {
     40      // Overwrite all serialized fields, but not any nonserialized_*.
     41      visitor->SetDefault(this);
     42      return true;
     43    }
     44    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
     45      uint32_t up = *p;
     46      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
     47      *p = up;
     48      return Status(true);
     49    };
     50    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
     51    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
     52    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
     53    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
     54    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
     55    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
     56    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
     57    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
     58    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
     59    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
     60    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
     61    return true;
     62  }
     63 
     64  bool all_default;
     65  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
     66  uint32_t w[kNumPredictors] = {};
     67 };
     68 
// Running state of the weighted predictor for one image plane. Holds the most
// recent sub-predictions and two rows of error history: Predict() reads the
// history, UpdateErrors() writes it once the true pixel value is known.
struct State {
  // Latest output of each sub-predictor, still carrying kPredExtraBits.
  pixel_type_w prediction[kNumPredictors] = {};
  pixel_type_w pred = 0;  // *before* removing the added bits.
  // Accumulated absolute errors of each sub-predictor; sized for two rows of
  // xsize + 2 entries by the constructor.
  std::vector<uint32_t> pred_errors[kNumPredictors];
  // Signed error of the combined prediction (pred minus the true value with
  // extra bits); same two-row layout as pred_errors.
  std::vector<int32_t> error;
  // Predictor parameters; held by reference, so the Header must outlive this
  // State.
  const Header &header;

  // Allows to approximate division by a number from 1 to 64.
  //  for (int i = 0; i < 64; i++) divlookup[i] = (1 << 24) / (i + 1);

  const uint32_t divlookup[64] = {
      16777216, 8388608, 5592405, 4194304, 3355443, 2796202, 2396745, 2097152,
      1864135,  1677721, 1525201, 1398101, 1290555, 1198372, 1118481, 1048576,
      986895,   932067,  883011,  838860,  798915,  762600,  729444,  699050,
      671088,   645277,  621378,  599186,  578524,  559240,  541200,  524288,
      508400,   493447,  479349,  466033,  453438,  441505,  430185,  419430,
      409200,   399457,  390167,  381300,  372827,  364722,  356962,  349525,
      342392,   335544,  328965,  322638,  316551,  310689,  305040,  299593,
      294337,   289262,  284359,  279620,  275036,  270600,  266305,  262144};

  // Scales x up by kPredExtraBits. The shift is performed on an unsigned
  // value, presumably to avoid left-shifting a negative signed integer.
  constexpr static pixel_type_w AddBits(pixel_type_w x) {
    return static_cast<uint64_t>(x) << kPredExtraBits;
  }

  // Binds the header and allocates the error buffers for rows of `xsize`
  // pixels. `ysize` is not used here.
  State(const Header &header, size_t xsize, size_t ysize) : header(header) {
    // Extra margin to avoid out-of-bounds writes.
    // All have space for two rows of data.
    for (auto &pred_error : pred_errors) {
      pred_error.resize((xsize + 2) * 2);
    }
    error.resize((xsize + 2) * 2);
  }

  // Approximates 4+(maxweight<<24)/(x+1), avoiding division
  JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
    // For x + 1 >= 64 the shift drops low bits so that x >> shift stays below
    // 64 and can index divlookup; for smaller x the shift is 0.
    int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
    if (shift < 0) shift = 0;
    return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
  }

  // Approximates the weighted average of the input values with the given
  // weights, avoiding division. Weights must sum to at least 16.
  // `w` is taken by value because it is rescaled in place below.
  JXL_INLINE pixel_type_w
  WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
                  std::array<uint32_t, kNumPredictors> w) const {
    uint32_t weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      weight_sum += w[i];
    }
    JXL_DASSERT(weight_sum > 15);
    uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
    // Rescale the weights so that their new sum is below 32 and
    // weight_sum - 1 is a valid divlookup index.
    weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      w[i] >>= log_weight - 4;
      weight_sum += w[i];
    }
    // for rounding.
    pixel_type_w sum = (weight_sum >> 1) - 1;
    for (size_t i = 0; i < kNumPredictors; i++) {
      sum += p[i] * w[i];
    }
    // Multiply by the 24-bit fixed-point reciprocal of weight_sum.
    return (sum * divlookup[weight_sum - 1]) >> 24;
  }

  // Computes the weighted prediction for pixel (x, y) of a row of `xsize`
  // pixels from its neighbours N, W, NE, NW and NN (passed without extra
  // bits). Stores the sub-predictions in `prediction` and the combined value
  // in `pred` for a later UpdateErrors() call, and returns the combined
  // prediction with the extra bits removed. When compute_properties is true,
  // one property is written to (*properties)[offset].
  template <bool compute_properties>
  JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
                                  pixel_type_w N, pixel_type_w W,
                                  pixel_type_w NE, pixel_type_w NW,
                                  pixel_type_w NN, Properties *properties,
                                  size_t offset) {
    // The two buffer halves swap the roles of current/previous row with the
    // parity of y.
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    size_t pos_N = prev_row + x;
    // At the left/right image border NW/NE fall back to the N position.
    size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
    size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
    std::array<uint32_t, kNumPredictors> weights;
    for (size_t i = 0; i < kNumPredictors; i++) {
      // pred_errors[pos_N] also contains the error of pixel W.
      // pred_errors[pos_NW] also contains the error of pixel WW.
      weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
                   pred_errors[i][pos_NW];
      weights[i] = ErrorWeight(weights[i], header.w[i]);
    }

    N = AddBits(N);
    W = AddBits(W);
    NE = AddBits(NE);
    NW = AddBits(NW);
    NN = AddBits(NN);

    // Signed errors of the combined prediction at the neighbouring pixels.
    pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
    pixel_type_w teN = error[pos_N];
    pixel_type_w teNW = error[pos_NW];
    pixel_type_w sumWN = teN + teW;
    pixel_type_w teNE = error[pos_NE];

    if (compute_properties) {
      // The property is the neighbouring error with the largest magnitude;
      // on ties the earlier candidate (W, N, NW, NE order) is kept.
      pixel_type_w p = teW;
      if (std::abs(teN) > std::abs(p)) p = teN;
      if (std::abs(teNW) > std::abs(p)) p = teNW;
      if (std::abs(teNE) > std::abs(p)) p = teNE;
      (*properties)[offset++] = p;
    }

    // The four sub-predictors; the header coefficients are scaled by 1/32.
    prediction[0] = W + NE - N;
    prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
    prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
    prediction[3] =
        N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
              (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
             5);

    pred = WeightedAverage(prediction, weights);

    // If all three have the same sign, skip clamping.
    // NOTE: the test is strictly "> 0", so when teN, teW and teNW are all
    // equal (both XORs are zero) the clamping path below is taken.
    if (((teN ^ teW) | (teN ^ teNW)) > 0) {
      return (pred + kPredictionRound) >> kPredExtraBits;
    }

    // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
    pixel_type_w mx = std::max(W, std::max(NE, N));
    pixel_type_w mn = std::min(W, std::min(NE, N));
    pred = std::max(mn, std::min(mx, pred));
    return (pred + kPredictionRound) >> kPredExtraBits;
  }

  // Records the errors made at pixel (x, y) once its true value `val` is
  // known. Uses `pred` and `prediction` as left behind by the preceding
  // Predict() call.
  JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
                               size_t xsize) {
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    val = AddBits(val);
    error[cur_row + x] = pred - val;
    for (size_t i = 0; i < kNumPredictors; i++) {
      // Absolute sub-predictor error, rounded back to pixel precision.
      pixel_type_w err =
          (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
      // For predicting in the next row.
      pred_errors[i][cur_row + x] = err;
      // Add the error on this pixel to the error on the NE pixel. This has the
      // effect of adding the error on this pixel to the E and EE pixels.
      pred_errors[i][prev_row + x + 1] += err;
    }
  }
};
    212 
    213 // Encoder helper function to set the parameters to some presets.
    214 inline void PredictorMode(int i, Header *header) {
    215  switch (i) {
    216    case 0:
    217      // ~ lossless16 predictor
    218      header->w[0] = 0xd;
    219      header->w[1] = 0xc;
    220      header->w[2] = 0xc;
    221      header->w[3] = 0xc;
    222      header->p1C = 16;
    223      header->p2C = 10;
    224      header->p3Ca = 7;
    225      header->p3Cb = 7;
    226      header->p3Cc = 7;
    227      header->p3Cd = 0;
    228      header->p3Ce = 0;
    229      break;
    230    case 1:
    231      // ~ default lossless8 predictor
    232      header->w[0] = 0xd;
    233      header->w[1] = 0xc;
    234      header->w[2] = 0xc;
    235      header->w[3] = 0xb;
    236      header->p1C = 8;
    237      header->p2C = 8;
    238      header->p3Ca = 4;
    239      header->p3Cb = 0;
    240      header->p3Cc = 3;
    241      header->p3Cd = 23;
    242      header->p3Ce = 2;
    243      break;
    244    case 2:
    245      // ~ west lossless8 predictor
    246      header->w[0] = 0xd;
    247      header->w[1] = 0xc;
    248      header->w[2] = 0xd;
    249      header->w[3] = 0xc;
    250      header->p1C = 10;
    251      header->p2C = 9;
    252      header->p3Ca = 7;
    253      header->p3Cb = 0;
    254      header->p3Cc = 0;
    255      header->p3Cd = 16;
    256      header->p3Ce = 9;
    257      break;
    258    case 3:
    259      // ~ north lossless8 predictor
    260      header->w[0] = 0xd;
    261      header->w[1] = 0xd;
    262      header->w[2] = 0xc;
    263      header->w[3] = 0xc;
    264      header->p1C = 16;
    265      header->p2C = 8;
    266      header->p3Ca = 0;
    267      header->p3Cb = 16;
    268      header->p3Cc = 0;
    269      header->p3Cd = 23;
    270      header->p3Ce = 0;
    271      break;
    272    case 4:
    273    default:
    274      // something else, because why not
    275      header->w[0] = 0xd;
    276      header->w[1] = 0xc;
    277      header->w[2] = 0xc;
    278      header->w[3] = 0xc;
    279      header->p1C = 10;
    280      header->p2C = 10;
    281      header->p3Ca = 5;
    282      header->p3Cb = 5;
    283      header->p3Cc = 5;
    284      header->p3Cd = 12;
    285      header->p3Ce = 4;
    286      break;
    287  }
    288 }
    289 }  // namespace weighted
    290 
// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
//
// Each union holds either the inner-node view or the leaf view of the same
// storage. MATreeLookup::Lookup uses `property0 < 0` to decide which view to
// read: a leaf is read through childID/predictor/predictor_offset/multiplier,
// an inner node through splitval0/splitvals/properties/childID.
struct FlatDecisionNode {
  // Property + splitval of the top node.
  int32_t property0;  // -1 if leaf.
  union {
    PropertyVal splitval0;  // Inner node: threshold for property0.
    Predictor predictor;    // Leaf: predictor to use.
  };
  // Property+splitval of the two child nodes.
  union {
    PropertyVal splitvals[2];  // Inner node: thresholds of the two children.
    int32_t multiplier;        // Leaf: multiplier reported by the lookup.
  };
  // Inner node: index of the first of the four grandchildren; Lookup adds an
  // offset in [0, 3] to it.
  uint32_t childID;  // childID is ctx id if leaf.
  union {
    int16_t properties[2];     // Inner node: properties tested by the children.
    int32_t predictor_offset;  // Leaf: offset reported by the lookup.
  };
};
// A flattened decision tree; Lookup starts at element 0.
using FlatTree = std::vector<FlatDecisionNode>;
    312 
    313 class MATreeLookup {
    314 public:
    315  explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
    316  struct LookupResult {
    317    uint32_t context;
    318    Predictor predictor;
    319    int32_t offset;
    320    int32_t multiplier;
    321  };
    322  JXL_INLINE LookupResult Lookup(const Properties &properties) const {
    323    uint32_t pos = 0;
    324    while (true) {
    325 #define TRAVERSE_THE_TREE                                                      \
    326  {                                                                            \
    327    const FlatDecisionNode &node = nodes_[pos];                                \
    328    if (node.property0 < 0) {                                                  \
    329      return {node.childID, node.predictor, node.predictor_offset,             \
    330              node.multiplier};                                                \
    331    }                                                                          \
    332    bool p0 = properties[node.property0] <= node.splitval0;                    \
    333    uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];       \
    334    uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]); \
    335    pos = node.childID + (p0 ? off1 : off0);                                   \
    336  }
    337 
    338      TRAVERSE_THE_TREE;
    339      TRAVERSE_THE_TREE;
    340    }
    341  }
    342 
    343 private:
    344  const FlatTree &nodes_;
    345 };
    346 
// Number of property slots PrecomputeReferences() writes for each earlier
// channel it uses.
static constexpr size_t kExtraPropsPerChannel = 4;
// Number of properties that do not come from reference channels: the static
// ones, 13 position/neighbourhood ones (y plus the twelve that
// detail::Predict writes, going by its "2 static properties + y" comment) and
// the weighted-predictor ones.
static constexpr size_t kNumNonrefProperties =
    kNumStaticProperties + 13 + weighted::kNumProperties;

// Index of the first weighted-predictor property.
constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
// Index of the `left + top - topleft` property written by detail::Predict;
// InitPropsRow resets the same slot to 0.
constexpr size_t kGradientProp = 9;
    353 
    354 // Clamps gradient to the min/max of n, w (and l, implicitly).
    355 static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
    356                                          const int32_t l) {
    357  const int32_t m = std::min(n, w);
    358  const int32_t M = std::max(n, w);
    359  // The end result of this operation doesn't overflow or underflow if the
    360  // result is between m and M, but the intermediate value may overflow, so we
    361  // do the intermediate operations in uint32_t and check later if we had an
    362  // overflow or underflow condition comparing m, M and l directly.
    363  // grad = M + m - l = n + w - l
    364  const int32_t grad =
    365      static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
    366                           static_cast<uint32_t>(l));
    367  // We use two sets of ternary operators to force the evaluation of them in
    368  // any case, allowing the compiler to avoid branches and use cmovl/cmovg in
    369  // x86.
    370  const int32_t grad_clamp_M = (l < m) ? M : grad;
    371  return (l > M) ? m : grad_clamp_M;
    372 }
    373 
    374 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
    375  pixel_type_w p = a + b - c;
    376  pixel_type_w pa = std::abs(p - a);
    377  pixel_type_w pb = std::abs(p - b);
    378  return pa < pb ? a : b;
    379 }
    380 
    381 inline void PrecomputeReferences(const Channel &ch, size_t y,
    382                                 const Image &image, uint32_t i,
    383                                 Channel *references) {
    384  ZeroFillImage(&references->plane);
    385  uint32_t offset = 0;
    386  size_t num_extra_props = references->w;
    387  intptr_t onerow = references->plane.PixelsPerRow();
    388  for (int32_t j = static_cast<int32_t>(i) - 1;
    389       j >= 0 && offset < num_extra_props; j--) {
    390    if (image.channel[j].w != image.channel[i].w ||
    391        image.channel[j].h != image.channel[i].h) {
    392      continue;
    393    }
    394    if (image.channel[j].hshift != image.channel[i].hshift) continue;
    395    if (image.channel[j].vshift != image.channel[i].vshift) continue;
    396    pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    397    const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    398    const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    399    for (size_t x = 0; x < ch.w; x++, rp += onerow) {
    400      pixel_type_w v = rpp[x];
    401      rp[0] = std::abs(v);
    402      rp[1] = v;
    403      pixel_type_w vleft = (x ? rpp[x - 1] : 0);
    404      pixel_type_w vtop = (y ? rpprev[x] : vleft);
    405      pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
    406      pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
    407      rp[2] = std::abs(v - vpredicted);
    408      rp[3] = v - vpredicted;
    409    }
    410 
    411    offset += kExtraPropsPerChannel;
    412  }
    413 }
    414 
    415 struct PredictionResult {
    416  int context = 0;
    417  pixel_type_w guess = 0;
    418  Predictor predictor;
    419  int32_t multiplier;
    420 };
    421 
    422 inline void InitPropsRow(
    423    Properties *p,
    424    const std::array<pixel_type, kNumStaticProperties> &static_props,
    425    const int y) {
    426  for (size_t i = 0; i < kNumStaticProperties; i++) {
    427    (*p)[i] = static_props[i];
    428  }
    429  (*p)[2] = y;
    430  (*p)[9] = 0;  // local gradient.
    431 }
    432 
    433 namespace detail {
    434 enum PredictorMode {
    435  kUseTree = 1,
    436  kUseWP = 2,
    437  kForceComputeProperties = 4,
    438  kAllPredictions = 8,
    439  kNoEdgeCases = 16
    440 };
    441 
    442 JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
    443                                   pixel_type_w top, pixel_type_w toptop,
    444                                   pixel_type_w topleft, pixel_type_w topright,
    445                                   pixel_type_w leftleft,
    446                                   pixel_type_w toprightright,
    447                                   pixel_type_w wp_pred) {
    448  switch (p) {
    449    case Predictor::Zero:
    450      return pixel_type_w{0};
    451    case Predictor::Left:
    452      return left;
    453    case Predictor::Top:
    454      return top;
    455    case Predictor::Select:
    456      return Select(left, top, topleft);
    457    case Predictor::Weighted:
    458      return wp_pred;
    459    case Predictor::Gradient:
    460      return pixel_type_w{ClampedGradient(left, top, topleft)};
    461    case Predictor::TopLeft:
    462      return topleft;
    463    case Predictor::TopRight:
    464      return topright;
    465    case Predictor::LeftLeft:
    466      return leftleft;
    467    case Predictor::Average0:
    468      return (left + top) / 2;
    469    case Predictor::Average1:
    470      return (left + topleft) / 2;
    471    case Predictor::Average2:
    472      return (topleft + top) / 2;
    473    case Predictor::Average3:
    474      return (top + topright) / 2;
    475    case Predictor::Average4:
    476      return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
    477              1 * toprightright + 3 * topright + 8) /
    478             16;
    479    default:
    480      return pixel_type_w{0};
    481  }
    482 }
    483 
// Shared implementation behind all Predict* wrappers. `mode` is a bitwise OR
// of PredictorMode flags and selects which of the steps below are compiled
// in:
//  - kUseTree / kForceComputeProperties: fill the property vector `*p`;
//  - kUseWP: run the weighted predictor in `wp_state`;
//  - kUseTree: look up context, offset, multiplier and predictor in `lookup`
//    (this overrides the `predictor` argument);
//  - kAllPredictions: write the output of every predictor to `predictions`;
//  - kNoEdgeCases: read all neighbours unconditionally and skip the extra
//    (reference) properties.
// `pp` points at the current pixel; neighbours are read relative to it, with
// `onerow` as the distance between rows and `w` as the row width. Pointer
// arguments not needed by the selected mode are not dereferenced.
template <int mode>
JXL_INLINE PredictionResult Predict(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    const MATreeLookup *lookup, const Channel *references,
    weighted::State *wp_state, pixel_type_w *predictions) {
  // We start in position 3 because of 2 static properties + y.
  size_t offset = 3;
  constexpr bool compute_properties =
      mode & kUseTree || mode & kForceComputeProperties;
  constexpr bool nec = mode & kNoEdgeCases;
  // Neighbour values. Without kNoEdgeCases, a neighbour outside the image is
  // replaced by another neighbour (ultimately 0 for the top-left pixel).
  pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0));
  pixel_type_w top = (nec || y ? pp[-onerow] : left);
  pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left);
  pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top);
  pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left);
  pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top);
  pixel_type_w toprightright =
      (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright);

  if (compute_properties) {
    // location
    (*p)[offset++] = x;
    // neighbors
    (*p)[offset++] = top > 0 ? top : -top;
    (*p)[offset++] = left > 0 ? left : -left;
    (*p)[offset++] = top;
    (*p)[offset++] = left;

    // local gradient
    // The slot at offset + 1 is read before it is overwritten just below, so
    // it still holds what the previous call (or InitPropsRow) stored there.
    (*p)[offset] = left - (*p)[offset + 1];
    offset++;
    // local gradient
    (*p)[offset++] = left + top - topleft;

    // FFV1 context properties
    (*p)[offset++] = left - topleft;
    (*p)[offset++] = topleft - top;
    (*p)[offset++] = top - topright;
    (*p)[offset++] = top - toptop;
    (*p)[offset++] = left - leftleft;
  }

  pixel_type_w wp_pred = 0;
  if (mode & kUseWP) {
    // Argument order is N, W, NE, NW, NN. `offset` is passed by value, so the
    // slot the weighted predictor fills is skipped explicitly below.
    wp_pred = wp_state->Predict<compute_properties>(
        x, y, w, top, left, topright, topleft, toptop, p, offset);
  }
  if (!nec && compute_properties) {
    offset += weighted::kNumProperties;
    // Extra properties.
    const pixel_type *JXL_RESTRICT rp = references->Row(x);
    for (size_t i = 0; i < references->w; i++) {
      (*p)[offset++] = rp[i];
    }
  }
  PredictionResult result;
  if (mode & kUseTree) {
    MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    result.context = lr.context;
    result.guess = lr.offset;
    result.multiplier = lr.multiplier;
    predictor = lr.predictor;
  }
  if (mode & kAllPredictions) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictions[i] =
          PredictOne(static_cast<Predictor>(i), left, top, toptop, topleft,
                     topright, leftleft, toprightright, wp_pred);
    }
  }
  // The predictor output is added on top of the tree offset (or on top of 0
  // when no tree is used).
  result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
                             leftleft, toprightright, wp_pred);
  result.predictor = predictor;

  return result;
}
    561 }  // namespace detail
    562 
    563 inline PredictionResult PredictNoTreeNoWP(size_t w,
    564                                          const pixel_type *JXL_RESTRICT pp,
    565                                          const intptr_t onerow, const int x,
    566                                          const int y, Predictor predictor) {
    567  return detail::Predict</*mode=*/0>(
    568      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
    569      /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
    570 }
    571 
    572 inline PredictionResult PredictNoTreeWP(size_t w,
    573                                        const pixel_type *JXL_RESTRICT pp,
    574                                        const intptr_t onerow, const int x,
    575                                        const int y, Predictor predictor,
    576                                        weighted::State *wp_state) {
    577  return detail::Predict<detail::kUseWP>(
    578      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
    579      /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
    580 }
    581 
    582 inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
    583                                        const pixel_type *JXL_RESTRICT pp,
    584                                        const intptr_t onerow, const int x,
    585                                        const int y,
    586                                        const MATreeLookup &tree_lookup,
    587                                        const Channel &references) {
    588  return detail::Predict<detail::kUseTree>(
    589      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    590      /*wp_state=*/nullptr, /*predictions=*/nullptr);
    591 }
    592 // Only use for y > 1, x > 1, x < w-2, and empty references
    593 JXL_INLINE PredictionResult
    594 PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    595                   const intptr_t onerow, const int x, const int y,
    596                   const MATreeLookup &tree_lookup, const Channel &references) {
    597  return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>(
    598      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    599      /*wp_state=*/nullptr, /*predictions=*/nullptr);
    600 }
    601 
    602 inline PredictionResult PredictTreeWP(Properties *p, size_t w,
    603                                      const pixel_type *JXL_RESTRICT pp,
    604                                      const intptr_t onerow, const int x,
    605                                      const int y,
    606                                      const MATreeLookup &tree_lookup,
    607                                      const Channel &references,
    608                                      weighted::State *wp_state) {
    609  return detail::Predict<detail::kUseTree | detail::kUseWP>(
    610      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    611      wp_state, /*predictions=*/nullptr);
    612 }
    613 JXL_INLINE PredictionResult PredictTreeWPNEC(Properties *p, size_t w,
    614                                             const pixel_type *JXL_RESTRICT pp,
    615                                             const intptr_t onerow, const int x,
    616                                             const int y,
    617                                             const MATreeLookup &tree_lookup,
    618                                             const Channel &references,
    619                                             weighted::State *wp_state) {
    620  return detail::Predict<detail::kUseTree | detail::kUseWP |
    621                         detail::kNoEdgeCases>(
    622      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    623      wp_state, /*predictions=*/nullptr);
    624 }
    625 
    626 inline PredictionResult PredictLearn(Properties *p, size_t w,
    627                                     const pixel_type *JXL_RESTRICT pp,
    628                                     const intptr_t onerow, const int x,
    629                                     const int y, Predictor predictor,
    630                                     const Channel &references,
    631                                     weighted::State *wp_state) {
    632  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
    633      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
    634      wp_state, /*predictions=*/nullptr);
    635 }
    636 
    637 inline void PredictLearnAll(Properties *p, size_t w,
    638                            const pixel_type *JXL_RESTRICT pp,
    639                            const intptr_t onerow, const int x, const int y,
    640                            const Channel &references,
    641                            weighted::State *wp_state,
    642                            pixel_type_w *predictions) {
    643  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
    644                  detail::kAllPredictions>(
    645      p, w, pp, onerow, x, y, Predictor::Zero,
    646      /*lookup=*/nullptr, &references, wp_state, predictions);
    647 }
    648 inline PredictionResult PredictLearnNEC(Properties *p, size_t w,
    649                                        const pixel_type *JXL_RESTRICT pp,
    650                                        const intptr_t onerow, const int x,
    651                                        const int y, Predictor predictor,
    652                                        const Channel &references,
    653                                        weighted::State *wp_state) {
    654  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
    655                         detail::kNoEdgeCases>(
    656      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
    657      wp_state, /*predictions=*/nullptr);
    658 }
    659 
    660 inline void PredictLearnAllNEC(Properties *p, size_t w,
    661                               const pixel_type *JXL_RESTRICT pp,
    662                               const intptr_t onerow, const int x, const int y,
    663                               const Channel &references,
    664                               weighted::State *wp_state,
    665                               pixel_type_w *predictions) {
    666  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
    667                  detail::kAllPredictions | detail::kNoEdgeCases>(
    668      p, w, pp, onerow, x, y, Predictor::Zero,
    669      /*lookup=*/nullptr, &references, wp_state, predictions);
    670 }
    671 
    672 inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
    673                           const intptr_t onerow, const int x, const int y,
    674                           pixel_type_w *predictions) {
    675  detail::Predict<detail::kAllPredictions>(
    676      /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
    677      /*lookup=*/nullptr,
    678      /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
    679 }
    680 }  // namespace jxl
    681 
    682 #endif  // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_