context_predict.h (25841B)
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "lib/jxl/base/bits.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/field_encodings.h"
#include "lib/jxl/fields.h"
#include "lib/jxl/image_ops.h"
#include "lib/jxl/modular/modular_image.h"
#include "lib/jxl/modular/options.h"

namespace jxl {

// Self-correcting "weighted" predictor: blends kNumPredictors sub-predictors,
// weighting each one by (an approximation of) the inverse of its recent
// prediction error in the causal neighborhood.
namespace weighted {
constexpr static size_t kNumPredictors = 4;
// Internal predictions are kept in fixed point with kPredExtraBits fractional
// bits (see State::AddBits); kPredictionRound is the rounding bias applied
// when shifting back down to integer pixel values.
constexpr static int64_t kPredExtraBits = 3;
constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
// The weighted predictor also produces one MA-tree property (the neighboring
// error with the largest magnitude; see State::Predict).
constexpr static size_t kNumProperties = 1;

// Serialized tuning parameters of the weighted predictor: seven 5-bit
// per-sub-predictor coefficients and four 4-bit initial error weights.
struct Header : public Fields {
  JXL_FIELDS_NAME(WeightedPredictorHeader)
  // TODO(janwas): move to cc file, avoid including fields.h.
  Header() { Bundle::Init(this); }

  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
    if (visitor->AllDefault(*this, &all_default)) {
      // Overwrite all serialized fields, but not any nonserialized_*.
      visitor->SetDefault(this);
      return true;
    }
    // Visits one 5-bit coefficient; `val` is its default value. Goes through
    // a uint32_t temporary because Visitor::Bits works on unsigned fields.
    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
      uint32_t up = *p;
      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
      *p = up;
      return Status(true);
    };
    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
    // 4-bit maximum error weights, one per sub-predictor.
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
    return true;
  }

  bool all_default;
  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
  uint32_t w[kNumPredictors] = {};
};

// Rolling per-channel state of the weighted predictor. `error` and each
// `pred_errors[i]` hold two rows of (xsize + 2) entries; Predict/UpdateErrors
// alternate between the two halves based on the parity of y.
struct State {
  // Fixed-point outputs of the individual sub-predictors for the last pixel.
  pixel_type_w prediction[kNumPredictors] = {};
  pixel_type_w pred = 0;  // *before* removing the added bits.
  std::vector<uint32_t> pred_errors[kNumPredictors];
  std::vector<int32_t> error;
  const Header &header;

  // Allows to approximate division by a number from 1 to 64.
  // for (int i = 0; i < 64; i++) divlookup[i] = (1 << 24) / (i + 1);

  const uint32_t divlookup[64] = {
      16777216, 8388608, 5592405, 4194304, 3355443, 2796202, 2396745, 2097152,
      1864135,  1677721, 1525201, 1398101, 1290555, 1198372, 1118481, 1048576,
      986895,   932067,  883011,  838860,  798915,  762600,  729444,  699050,
      671088,   645277,  621378,  599186,  578524,  559240,  541200,  524288,
      508400,   493447,  479349,  466033,  453438,  441505,  430185,  419430,
      409200,   399457,  390167,  381300,  372827,  364722,  356962,  349525,
      342392,   335544,  328965,  322638,  316551,  310689,  305040,  299593,
      294337,   289262,  284359,  279620,  275036,  270600,  266305,  262144};

  // Converts an integer pixel value to the internal fixed-point domain.
  constexpr static pixel_type_w AddBits(pixel_type_w x) {
    return static_cast<uint64_t>(x) << kPredExtraBits;
  }

  // ysize is accepted for symmetry but not used: only two rows are kept.
  State(const Header &header, size_t xsize, size_t ysize) : header(header) {
    // Extra margin to avoid out-of-bounds writes.
    // All have space for two rows of data.
    for (auto &pred_error : pred_errors) {
      pred_error.resize((xsize + 2) * 2);
    }
    error.resize((xsize + 2) * 2);
  }

  // Approximates 4+(maxweight<<24)/(x+1), avoiding division
  JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
    // Scale x down so it indexes into the 64-entry divlookup table.
    int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
    if (shift < 0) shift = 0;
    return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
  }

  // Approximates the weighted average of the input values with the given
  // weights, avoiding division. Weights must sum to at least 16.
  JXL_INLINE pixel_type_w
  WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
                  std::array<uint32_t, kNumPredictors> w) const {
    uint32_t weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      weight_sum += w[i];
    }
    JXL_DASSERT(weight_sum > 15);
    uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
    // Scale the weights down so weight_sum fits the divlookup table (<= 64).
    weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      w[i] >>= log_weight - 4;
      weight_sum += w[i];
    }
    // for rounding.
    pixel_type_w sum = (weight_sum >> 1) - 1;
    for (size_t i = 0; i < kNumPredictors; i++) {
      sum += p[i] * w[i];
    }
    return (sum * divlookup[weight_sum - 1]) >> 24;
  }

  // Computes the weighted prediction for pixel (x, y) from its causal
  // neighbors (N, W, NE, NW, NN given as integer pixel values). If
  // compute_properties is set, also writes the max-magnitude neighboring
  // error into (*properties)[offset]. Returns the integer prediction;
  // the fixed-point value is kept in `pred` for UpdateErrors.
  template <bool compute_properties>
  JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
                                  pixel_type_w N, pixel_type_w W,
                                  pixel_type_w NE, pixel_type_w NW,
                                  pixel_type_w NN, Properties *properties,
                                  size_t offset) {
    // Select which half of the two-row buffers is "current" vs "previous".
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    size_t pos_N = prev_row + x;
    // Clamp NE/NW positions at the image borders.
    size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
    size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
    std::array<uint32_t, kNumPredictors> weights;
    for (size_t i = 0; i < kNumPredictors; i++) {
      // pred_errors[pos_N] also contains the error of pixel W.
      // pred_errors[pos_NW] also contains the error of pixel WW.
      weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
                   pred_errors[i][pos_NW];
      weights[i] = ErrorWeight(weights[i], header.w[i]);
    }

    N = AddBits(N);
    W = AddBits(W);
    NE = AddBits(NE);
    NW = AddBits(NW);
    NN = AddBits(NN);

    // Signed prediction errors ("true errors") of the neighboring pixels.
    pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
    pixel_type_w teN = error[pos_N];
    pixel_type_w teNW = error[pos_NW];
    pixel_type_w sumWN = teN + teW;
    pixel_type_w teNE = error[pos_NE];

    if (compute_properties) {
      // Property: the neighboring error with the largest magnitude.
      pixel_type_w p = teW;
      if (std::abs(teN) > std::abs(p)) p = teN;
      if (std::abs(teNW) > std::abs(p)) p = teNW;
      if (std::abs(teNE) > std::abs(p)) p = teNE;
      (*properties)[offset++] = p;
    }

    // The four sub-predictors, each bias-corrected by recent errors scaled
    // with the header coefficients (>> 5 matches their 5-bit encoding).
    prediction[0] = W + NE - N;
    prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
    prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
    prediction[3] =
        N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
              (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
             5);

    pred = WeightedAverage(prediction, weights);

    // If all three have the same sign, skip clamping.
    if (((teN ^ teW) | (teN ^ teNW)) > 0) {
      return (pred + kPredictionRound) >> kPredExtraBits;
    }

    // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
    pixel_type_w mx = std::max(W, std::max(NE, N));
    pixel_type_w mn = std::min(W, std::min(NE, N));
    pred = std::max(mn, std::min(mx, pred));
    return (pred + kPredictionRound) >> kPredExtraBits;
  }

  // Records the actual value of pixel (x, y) (integer domain), updating the
  // signed error and the per-sub-predictor absolute errors used as weights.
  // Must be called after Predict() for the same pixel (uses `pred` and
  // `prediction[]`).
  JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
                               size_t xsize) {
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    val = AddBits(val);
    error[cur_row + x] = pred - val;
    for (size_t i = 0; i < kNumPredictors; i++) {
      pixel_type_w err =
          (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
      // For predicting in the next row.
      pred_errors[i][cur_row + x] = err;
      // Add the error on this pixel to the error on the NE pixel. This has the
      // effect of adding the error on this pixel to the E and EE pixels.
      pred_errors[i][prev_row + x + 1] += err;
    }
  }
};

// Encoder helper function to set the parameters to some presets.
inline void PredictorMode(int i, Header *header) {
  switch (i) {
    case 0:
      // ~ lossless16 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 16;
      header->p2C = 10;
      header->p3Ca = 7;
      header->p3Cb = 7;
      header->p3Cc = 7;
      header->p3Cd = 0;
      header->p3Ce = 0;
      break;
    case 1:
      // ~ default lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xb;
      header->p1C = 8;
      header->p2C = 8;
      header->p3Ca = 4;
      header->p3Cb = 0;
      header->p3Cc = 3;
      header->p3Cd = 23;
      header->p3Ce = 2;
      break;
    case 2:
      // ~ west lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xd;
      header->w[3] = 0xc;
      header->p1C = 10;
      header->p2C = 9;
      header->p3Ca = 7;
      header->p3Cb = 0;
      header->p3Cc = 0;
      header->p3Cd = 16;
      header->p3Ce = 9;
      break;
    case 3:
      // ~ north lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xd;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 16;
      header->p2C = 8;
      header->p3Ca = 0;
      header->p3Cb = 16;
      header->p3Cc = 0;
      header->p3Cd = 23;
      header->p3Ce = 0;
      break;
    case 4:
    default:
      // something else, because why not
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 10;
      header->p2C = 10;
      header->p3Ca = 5;
      header->p3Cb = 5;
      header->p3Cc = 5;
      header->p3Cd = 12;
      header->p3Ce = 4;
      break;
  }
}
}  // namespace weighted

// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
293 struct FlatDecisionNode { 294 // Property + splitval of the top node. 295 int32_t property0; // -1 if leaf. 296 union { 297 PropertyVal splitval0; 298 Predictor predictor; 299 }; 300 // Property+splitval of the two child nodes. 301 union { 302 PropertyVal splitvals[2]; 303 int32_t multiplier; 304 }; 305 uint32_t childID; // childID is ctx id if leaf. 306 union { 307 int16_t properties[2]; 308 int32_t predictor_offset; 309 }; 310 }; 311 using FlatTree = std::vector<FlatDecisionNode>; 312 313 class MATreeLookup { 314 public: 315 explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} 316 struct LookupResult { 317 uint32_t context; 318 Predictor predictor; 319 int32_t offset; 320 int32_t multiplier; 321 }; 322 JXL_INLINE LookupResult Lookup(const Properties &properties) const { 323 uint32_t pos = 0; 324 while (true) { 325 #define TRAVERSE_THE_TREE \ 326 { \ 327 const FlatDecisionNode &node = nodes_[pos]; \ 328 if (node.property0 < 0) { \ 329 return {node.childID, node.predictor, node.predictor_offset, \ 330 node.multiplier}; \ 331 } \ 332 bool p0 = properties[node.property0] <= node.splitval0; \ 333 uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; \ 334 uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]); \ 335 pos = node.childID + (p0 ? off1 : off0); \ 336 } 337 338 TRAVERSE_THE_TREE; 339 TRAVERSE_THE_TREE; 340 } 341 } 342 343 private: 344 const FlatTree &nodes_; 345 }; 346 347 static constexpr size_t kExtraPropsPerChannel = 4; 348 static constexpr size_t kNumNonrefProperties = 349 kNumStaticProperties + 13 + weighted::kNumProperties; 350 351 constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; 352 constexpr size_t kGradientProp = 9; 353 354 // Clamps gradient to the min/max of n, w (and l, implicitly). 
static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
                                          const int32_t l) {
  const int32_t m = std::min(n, w);
  const int32_t M = std::max(n, w);
  // The end result of this operation doesn't overflow or underflow if the
  // result is between m and M, but the intermediate value may overflow, so we
  // do the intermediate operations in uint32_t and check later if we had an
  // overflow or underflow condition comparing m, M and l directly.
  // grad = M + m - l = n + w - l
  const int32_t grad =
      static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
                           static_cast<uint32_t>(l));
  // We use two sets of ternary operators to force the evaluation of them in
  // any case, allowing the compiler to avoid branches and use cmovl/cmovg in
  // x86.
  const int32_t grad_clamp_M = (l < m) ? M : grad;
  return (l > M) ? m : grad_clamp_M;
}

// Returns whichever of `a` and `b` is closest to the gradient prediction
// a + b - c.
inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
  pixel_type_w p = a + b - c;
  pixel_type_w pa = std::abs(p - a);
  pixel_type_w pb = std::abs(p - b);
  return pa < pb ? a : b;
}

// Fills `references` with the extra "reference" properties for row `y` of
// channel `i`: for each previously-decoded channel j (most recent first) with
// identical dimensions and shifts, and while fewer than references->w
// properties are filled, stores per pixel: |v|, v, |v - pred|, v - pred,
// where v is channel j's co-located sample and pred is its clamped-gradient
// prediction. Properties for pixel x live at references->Row(x).
inline void PrecomputeReferences(const Channel &ch, size_t y,
                                 const Image &image, uint32_t i,
                                 Channel *references) {
  ZeroFillImage(&references->plane);
  uint32_t offset = 0;
  size_t num_extra_props = references->w;
  intptr_t onerow = references->plane.PixelsPerRow();
  for (int32_t j = static_cast<int32_t>(i) - 1;
       j >= 0 && offset < num_extra_props; j--) {
    // Only channels with matching size and shifts qualify as references.
    if (image.channel[j].w != image.channel[i].w ||
        image.channel[j].h != image.channel[i].h) {
      continue;
    }
    if (image.channel[j].hshift != image.channel[i].hshift) continue;
    if (image.channel[j].vshift != image.channel[i].vshift) continue;
    pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    for (size_t x = 0; x < ch.w; x++, rp += onerow) {
      pixel_type_w v = rpp[x];
      rp[0] = std::abs(v);
      rp[1] = v;
      // Border pixels fall back to the left neighbor (or 0 at the origin).
      pixel_type_w vleft = (x ? rpp[x - 1] : 0);
      pixel_type_w vtop = (y ? rpprev[x] : vleft);
      pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
      pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
      rp[2] = std::abs(v - vpredicted);
      rp[3] = v - vpredicted;
    }

    offset += kExtraPropsPerChannel;
  }
}

// Output of a single pixel prediction: the MA context, the predicted value
// (`guess`, including the tree's predictor offset), the predictor that was
// used, and the residual multiplier from the tree leaf.
struct PredictionResult {
  int context = 0;
  pixel_type_w guess = 0;
  Predictor predictor;
  int32_t multiplier;
};

// Initializes the properties that stay constant along a row: the static
// properties, the y coordinate (slot 2), and the local-gradient slot
// (kGradientProp == 9), which detail::Predict reads back for the previous
// pixel in the row.
inline void InitPropsRow(
    Properties *p,
    const std::array<pixel_type, kNumStaticProperties> &static_props,
    const int y) {
  for (size_t i = 0; i < kNumStaticProperties; i++) {
    (*p)[i] = static_props[i];
  }
  (*p)[2] = y;
  (*p)[9] = 0;  // local gradient.
}

namespace detail {
// Bit flags selecting the work detail::Predict performs.
enum PredictorMode {
  kUseTree = 1,                 // look up context/predictor in the MA tree
  kUseWP = 2,                   // run the weighted predictor
  kForceComputeProperties = 4,  // compute properties even without a tree
  kAllPredictions = 8,          // fill `predictions` with every predictor
  kNoEdgeCases = 16             // caller guarantees all neighbors exist
};

// Evaluates a single predictor from the (already edge-handled) neighborhood.
// `wp_pred` is the precomputed weighted-predictor output.
JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
                                   pixel_type_w top, pixel_type_w toptop,
                                   pixel_type_w topleft, pixel_type_w topright,
                                   pixel_type_w leftleft,
                                   pixel_type_w toprightright,
                                   pixel_type_w wp_pred) {
  switch (p) {
    case Predictor::Zero:
      return pixel_type_w{0};
    case Predictor::Left:
      return left;
    case Predictor::Top:
      return top;
    case Predictor::Select:
      return Select(left, top, topleft);
    case Predictor::Weighted:
      return wp_pred;
    case Predictor::Gradient:
      return pixel_type_w{ClampedGradient(left, top, topleft)};
    case Predictor::TopLeft:
      return topleft;
    case Predictor::TopRight:
      return topright;
    case Predictor::LeftLeft:
      return leftleft;
    case Predictor::Average0:
      return (left + top) / 2;
    case Predictor::Average1:
      return (left + topleft) / 2;
    case Predictor::Average2:
      return (topleft + top) / 2;
    case Predictor::Average3:
      return (top + topright) / 2;
    case Predictor::Average4:
      return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
              1 * toprightright + 3 * topright + 8) /
             16;
    default:
      return pixel_type_w{0};
  }
}

// Core prediction routine. `pp` points at the current pixel in its row;
// `onerow` is the row stride. Depending on `mode`, computes the context
// properties into *p, runs the weighted predictor, looks up the MA tree,
// and/or fills `predictions` with every predictor's output. With
// kNoEdgeCases the border fallbacks are compiled out, so the caller must
// ensure all neighbors are in bounds (see the *NEC wrappers below).
template <int mode>
JXL_INLINE PredictionResult Predict(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    const MATreeLookup *lookup, const Channel *references,
    weighted::State *wp_state, pixel_type_w *predictions) {
  // We start in position 3 because of 2 static properties + y.
  size_t offset = 3;
  constexpr bool compute_properties =
      mode & kUseTree || mode & kForceComputeProperties;
  constexpr bool nec = mode & kNoEdgeCases;
  // Neighbor fetch with border fallbacks (each missing neighbor degrades to
  // a nearby one, ultimately 0 at the top-left corner).
  pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0));
  pixel_type_w top = (nec || y ? pp[-onerow] : left);
  pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left);
  pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top);
  pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left);
  pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top);
  pixel_type_w toprightright =
      (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright);

  if (compute_properties) {
    // location
    (*p)[offset++] = x;
    // neighbors
    (*p)[offset++] = top > 0 ? top : -top;
    (*p)[offset++] = left > 0 ? left : -left;
    (*p)[offset++] = top;
    (*p)[offset++] = left;

    // local gradient
    // Slot 8 = left minus the value the *previous* pixel stored in slot 9
    // (zeroed at row start by InitPropsRow).
    (*p)[offset] = left - (*p)[offset + 1];
    offset++;
    // local gradient
    (*p)[offset++] = left + top - topleft;

    // FFV1 context properties
    (*p)[offset++] = left - topleft;
    (*p)[offset++] = topleft - top;
    (*p)[offset++] = top - topright;
    (*p)[offset++] = top - toptop;
    (*p)[offset++] = left - leftleft;
  }

  pixel_type_w wp_pred = 0;
  if (mode & kUseWP) {
    // Also writes the WP property at (*p)[offset] when computing properties.
    wp_pred = wp_state->Predict<compute_properties>(
        x, y, w, top, left, topright, topleft, toptop, p, offset);
  }
  if (!nec && compute_properties) {
    // Skip past the WP property slot, then append the reference-channel
    // properties precomputed by PrecomputeReferences.
    offset += weighted::kNumProperties;
    // Extra properties.
    const pixel_type *JXL_RESTRICT rp = references->Row(x);
    for (size_t i = 0; i < references->w; i++) {
      (*p)[offset++] = rp[i];
    }
  }
  PredictionResult result;
  if (mode & kUseTree) {
    // The tree leaf overrides the requested predictor and supplies the
    // context, the prediction offset, and the residual multiplier.
    MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    result.context = lr.context;
    result.guess = lr.offset;
    result.multiplier = lr.multiplier;
    predictor = lr.predictor;
  }
  if (mode & kAllPredictions) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictions[i] =
          PredictOne(static_cast<Predictor>(i), left, top, toptop, topleft,
                     topright, leftleft, toprightright, wp_pred);
    }
  }
  result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
                             leftleft, toprightright, wp_pred);
  result.predictor = predictor;

  return result;
}
}  // namespace detail

// No MA tree, no weighted predictor: single fixed-predictor prediction.
inline PredictionResult PredictNoTreeNoWP(size_t w,
                                          const pixel_type *JXL_RESTRICT pp,
                                          const intptr_t onerow, const int x,
                                          const int y, Predictor predictor) {
  return detail::Predict</*mode=*/0>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
}

// No MA tree, but runs (and updates) the weighted-predictor state.
inline PredictionResult PredictNoTreeWP(size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y, Predictor predictor,
                                        weighted::State *wp_state) {
  return detail::Predict<detail::kUseWP>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
}

// MA-tree lookup without the weighted predictor.
inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y,
                                        const MATreeLookup &tree_lookup,
                                        const Channel &references) {
  return detail::Predict<detail::kUseTree>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      /*wp_state=*/nullptr, /*predictions=*/nullptr);
}
// Only use for y > 1, x > 1, x < w-2, and empty references
JXL_INLINE PredictionResult
PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
                   const intptr_t onerow, const int x, const int y,
                   const MATreeLookup &tree_lookup, const Channel &references) {
  return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      /*wp_state=*/nullptr, /*predictions=*/nullptr);
}

// MA-tree lookup with the weighted predictor.
inline PredictionResult PredictTreeWP(Properties *p, size_t w,
                                      const pixel_type *JXL_RESTRICT pp,
                                      const intptr_t onerow, const int x,
                                      const int y,
                                      const MATreeLookup &tree_lookup,
                                      const Channel &references,
                                      weighted::State *wp_state) {
  return detail::Predict<detail::kUseTree | detail::kUseWP>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      wp_state, /*predictions=*/nullptr);
}
// As PredictTreeWP, but skips border handling; same preconditions as
// PredictTreeNoWPNEC.
JXL_INLINE PredictionResult PredictTreeWPNEC(Properties *p, size_t w,
                                             const pixel_type *JXL_RESTRICT pp,
                                             const intptr_t onerow, const int x,
                                             const int y,
                                             const MATreeLookup &tree_lookup,
                                             const Channel &references,
                                             weighted::State *wp_state) {
  return detail::Predict<detail::kUseTree | detail::kUseWP |
                         detail::kNoEdgeCases>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      wp_state, /*predictions=*/nullptr);
}

// Encoder-side: computes properties (no tree) for learning, with WP.
inline PredictionResult PredictLearn(Properties *p, size_t w,
                                     const pixel_type *JXL_RESTRICT pp,
                                     const intptr_t onerow, const int x,
                                     const int y, Predictor predictor,
                                     const Channel &references,
                                     weighted::State *wp_state) {
  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
      wp_state, /*predictions=*/nullptr);
}

// Encoder-side: like PredictLearn, but also fills `predictions` with the
// output of every predictor.
inline void PredictLearnAll(Properties *p, size_t w,
                            const pixel_type *JXL_RESTRICT pp,
                            const intptr_t onerow, const int x, const int y,
                            const Channel &references,
                            weighted::State *wp_state,
                            pixel_type_w *predictions) {
  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
                  detail::kAllPredictions>(
      p, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr, &references, wp_state, predictions);
}
// As PredictLearn, without border handling (see PredictTreeNoWPNEC).
inline PredictionResult PredictLearnNEC(Properties *p, size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y, Predictor predictor,
                                        const Channel &references,
                                        weighted::State *wp_state) {
  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
                         detail::kNoEdgeCases>(
      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
      wp_state, /*predictions=*/nullptr);
}

// As PredictLearnAll, without border handling (see PredictTreeNoWPNEC).
inline void PredictLearnAllNEC(Properties *p, size_t w,
                               const pixel_type *JXL_RESTRICT pp,
                               const intptr_t onerow, const int x, const int y,
                               const Channel &references,
                               weighted::State *wp_state,
                               pixel_type_w *predictions) {
  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
                  detail::kAllPredictions | detail::kNoEdgeCases>(
      p, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr, &references, wp_state, predictions);
}

// Fills `predictions` with every predictor's output (WP excluded from state,
// so Predictor::Weighted yields 0 here); no properties, no tree.
inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
                           const intptr_t onerow, const int x, const int y,
                           pixel_type_w *predictions) {
  detail::Predict<detail::kAllPredictions>(
      /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
}
}  // namespace jxl

#endif  // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_