enc_encoding.cc (30517B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include <jxl/memory_manager.h> 7 8 #include <algorithm> 9 #include <array> 10 #include <cstddef> 11 #include <cstdint> 12 #include <cstdlib> 13 #include <limits> 14 #include <queue> 15 #include <utility> 16 #include <vector> 17 18 #include "lib/jxl/base/common.h" 19 #include "lib/jxl/base/printf_macros.h" 20 #include "lib/jxl/base/status.h" 21 #include "lib/jxl/enc_ans.h" 22 #include "lib/jxl/enc_aux_out.h" 23 #include "lib/jxl/enc_bit_writer.h" 24 #include "lib/jxl/enc_fields.h" 25 #include "lib/jxl/fields.h" 26 #include "lib/jxl/image_ops.h" 27 #include "lib/jxl/modular/encoding/context_predict.h" 28 #include "lib/jxl/modular/encoding/enc_ma.h" 29 #include "lib/jxl/modular/encoding/encoding.h" 30 #include "lib/jxl/modular/encoding/ma_common.h" 31 #include "lib/jxl/modular/options.h" 32 #include "lib/jxl/pack_signed.h" 33 34 namespace jxl { 35 36 namespace { 37 // Plot tree (if enabled) and predictor usage map. 38 constexpr bool kWantDebug = true; 39 // constexpr bool kPrintTree = false; 40 41 inline std::array<uint8_t, 3> PredictorColor(Predictor p) { 42 switch (p) { 43 case Predictor::Zero: 44 return {{0, 0, 0}}; 45 case Predictor::Left: 46 return {{255, 0, 0}}; 47 case Predictor::Top: 48 return {{0, 255, 0}}; 49 case Predictor::Average0: 50 return {{0, 0, 255}}; 51 case Predictor::Average4: 52 return {{192, 128, 128}}; 53 case Predictor::Select: 54 return {{255, 255, 0}}; 55 case Predictor::Gradient: 56 return {{255, 0, 255}}; 57 case Predictor::Weighted: 58 return {{0, 255, 255}}; 59 // TODO(jon) 60 default: 61 return {{255, 255, 255}}; 62 }; 63 } 64 65 // `cutoffs` must be sorted. 66 Tree MakeFixedTree(int property, const std::vector<int32_t> &cutoffs, 67 Predictor pred, size_t num_pixels, int bitdepth) { 68 size_t log_px = CeilLog2Nonzero(num_pixels); 69 size_t min_gap = 0; 70 // Reduce fixed tree height when encoding small images. 71 if (log_px < 14) { 72 min_gap = 8 * (14 - log_px); 73 } 74 const int shift = bitdepth > 11 ? std::min(4, bitdepth - 11) : 0; 75 const int mul = 1 << shift; 76 Tree tree; 77 struct NodeInfo { 78 size_t begin, end, pos; 79 }; 80 std::queue<NodeInfo> q; 81 // Leaf IDs will be set by roundtrip decoding the tree. 82 tree.push_back(PropertyDecisionNode::Leaf(pred)); 83 q.push(NodeInfo{0, cutoffs.size(), 0}); 84 while (!q.empty()) { 85 NodeInfo info = q.front(); 86 q.pop(); 87 if (info.begin + min_gap >= info.end) continue; 88 uint32_t split = (info.begin + info.end) / 2; 89 int32_t cutoff = cutoffs[split] * mul; 90 tree[info.pos] = PropertyDecisionNode::Split(property, cutoff, tree.size()); 91 q.push(NodeInfo{split + 1, info.end, tree.size()}); 92 tree.push_back(PropertyDecisionNode::Leaf(pred)); 93 q.push(NodeInfo{info.begin, split, tree.size()}); 94 tree.push_back(PropertyDecisionNode::Leaf(pred)); 95 } 96 return tree; 97 } 98 99 } // namespace 100 101 Status GatherTreeData(const Image &image, pixel_type chan, size_t group_id, 102 const weighted::Header &wp_header, 103 const ModularOptions &options, TreeSamples &tree_samples, 104 size_t *total_pixels) { 105 const Channel &channel = image.channel[chan]; 106 JxlMemoryManager *memory_manager = channel.memory_manager(); 107 108 JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w, 109 channel.h, chan); 110 111 std::array<pixel_type, kNumStaticProperties> static_props = { 112 {chan, static_cast<int>(group_id)}}; 113 Properties properties(kNumNonrefProperties + 114 kExtraPropsPerChannel * options.max_properties); 115 double pixel_fraction = std::min(1.0f, options.nb_repeats); 116 // a fraction of 0 is used to disable learning entirely. 117 if (pixel_fraction > 0) { 118 pixel_fraction = std::max(pixel_fraction, 119 std::min(1.0, 1024.0 / (channel.w * channel.h))); 120 } 121 uint64_t threshold = 122 (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction; 123 uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull), 124 static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)}; 125 // Xorshift128+ adapted from xorshift128+-inl.h 126 auto use_sample = [&]() { 127 auto s1 = s[0]; 128 const auto s0 = s[1]; 129 const auto bits = s1 + s0; // b, c 130 s[0] = s0; 131 s1 ^= s1 << 23; 132 s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); 133 s[1] = s1; 134 return (bits >> 32) <= threshold; 135 }; 136 137 const intptr_t onerow = channel.plane.PixelsPerRow(); 138 JXL_ASSIGN_OR_RETURN( 139 Channel references, 140 Channel::Create(memory_manager, properties.size() - kNumNonrefProperties, 141 channel.w)); 142 weighted::State wp_state(wp_header, channel.w, channel.h); 143 tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64); 144 const bool multiple_predictors = tree_samples.NumPredictors() != 1; 145 auto compute_sample = [&](const pixel_type *p, size_t x, size_t y) { 146 pixel_type_w pred[kNumModularPredictors]; 147 if (multiple_predictors) { 148 PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references, 149 &wp_state, pred); 150 } else { 151 pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] = 152 PredictLearn(&properties, channel.w, p + x, onerow, x, y, 153 tree_samples.PredictorFromIndex(0), references, 154 &wp_state) 155 .guess; 156 } 157 (*total_pixels)++; 158 if (use_sample()) { 159 tree_samples.AddSample(p[x], properties, pred); 160 } 161 wp_state.UpdateErrors(p[x], x, y, channel.w); 162 }; 163 164 for (size_t y = 0; y < channel.h; y++) { 165 const pixel_type *JXL_RESTRICT p = channel.Row(y); 166 PrecomputeReferences(channel, y, image, chan, &references); 167 InitPropsRow(&properties, static_props, y); 168 169 // TODO(veluca): avoid computing WP if we don't use its property or 170 // predictions. 171 if (y > 1 && channel.w > 8 && references.w == 0) { 172 for (size_t x = 0; x < 2; x++) { 173 compute_sample(p, x, y); 174 } 175 for (size_t x = 2; x < channel.w - 2; x++) { 176 pixel_type_w pred[kNumModularPredictors]; 177 if (multiple_predictors) { 178 PredictLearnAllNEC(&properties, channel.w, p + x, onerow, x, y, 179 references, &wp_state, pred); 180 } else { 181 pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] = 182 PredictLearnNEC(&properties, channel.w, p + x, onerow, x, y, 183 tree_samples.PredictorFromIndex(0), references, 184 &wp_state) 185 .guess; 186 } 187 (*total_pixels)++; 188 if (use_sample()) { 189 tree_samples.AddSample(p[x], properties, pred); 190 } 191 wp_state.UpdateErrors(p[x], x, y, channel.w); 192 } 193 for (size_t x = channel.w - 2; x < channel.w; x++) { 194 compute_sample(p, x, y); 195 } 196 } else { 197 for (size_t x = 0; x < channel.w; x++) { 198 compute_sample(p, x, y); 199 } 200 } 201 } 202 return true; 203 } 204 205 Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels, 206 int bitdepth, int prevprop) { 207 switch (tree_kind) { 208 case ModularOptions::TreeKind::kJpegTranscodeACMeta: 209 // All the data is 0, so no need for a fancy tree. 210 return {PropertyDecisionNode::Leaf(Predictor::Zero)}; 211 case ModularOptions::TreeKind::kTrivialTreeNoPredictor: 212 // All the data is 0, so no need for a fancy tree. 213 return {PropertyDecisionNode::Leaf(Predictor::Zero)}; 214 case ModularOptions::TreeKind::kFalconACMeta: 215 // All the data is 0 except the quant field. TODO(veluca): make that 0 216 // too. 217 return {PropertyDecisionNode::Leaf(Predictor::Left)}; 218 case ModularOptions::TreeKind::kACMeta: { 219 // Small image. 220 if (total_pixels < 1024) { 221 return {PropertyDecisionNode::Leaf(Predictor::Left)}; 222 } 223 Tree tree; 224 // 0: c > 1 225 tree.push_back(PropertyDecisionNode::Split(0, 1, 1)); 226 // 1: c > 2 227 tree.push_back(PropertyDecisionNode::Split(0, 2, 3)); 228 // 2: c > 0 229 tree.push_back(PropertyDecisionNode::Split(0, 0, 5)); 230 // 3: EPF control field (all 0 or 4), top > 3 231 tree.push_back(PropertyDecisionNode::Split(6, 3, 21)); 232 // 4: ACS+QF, y > 0 233 tree.push_back(PropertyDecisionNode::Split(2, 0, 7)); 234 // 5: CfL x 235 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); 236 // 6: CfL b 237 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); 238 // 7: QF: split according to the left quant value. 239 tree.push_back(PropertyDecisionNode::Split(7, 5, 9)); 240 // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large 241 // rectangular 6-11, 8x8 12+), according to previous ACS value. 242 tree.push_back(PropertyDecisionNode::Split(7, 5, 15)); 243 // QF 244 tree.push_back(PropertyDecisionNode::Split(7, 11, 11)); 245 tree.push_back(PropertyDecisionNode::Split(7, 3, 13)); 246 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); 247 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); 248 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); 249 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); 250 // ACS 251 tree.push_back(PropertyDecisionNode::Split(7, 11, 17)); 252 tree.push_back(PropertyDecisionNode::Split(7, 3, 19)); 253 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 254 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 255 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 256 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 257 // EPF, left > 3 258 tree.push_back(PropertyDecisionNode::Split(7, 3, 23)); 259 tree.push_back(PropertyDecisionNode::Split(7, 3, 25)); 260 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 261 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 262 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 263 tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); 264 return tree; 265 } 266 case ModularOptions::TreeKind::kWPFixedDC: { 267 std::vector<int32_t> cutoffs = { 268 -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, 269 -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, 270 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; 271 return MakeFixedTree(kWPProp, cutoffs, Predictor::Weighted, total_pixels, 272 bitdepth); 273 } 274 case ModularOptions::TreeKind::kGradientFixedDC: { 275 std::vector<int32_t> cutoffs = { 276 -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, 277 -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, 278 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; 279 return MakeFixedTree( 280 prevprop > 0 ? kNumNonrefProperties + 2 : kGradientProp, cutoffs, 281 Predictor::Gradient, total_pixels, bitdepth); 282 } 283 case ModularOptions::TreeKind::kLearn: { 284 JXL_DEBUG_ABORT("internal: kLearn is not predefined tree"); 285 return {}; 286 } 287 } 288 JXL_DEBUG_ABORT("internal: unexpected TreeKind: %d", 289 static_cast<int>(tree_kind)); 290 return {}; 291 } 292 293 StatusOr<Tree> LearnTree( 294 TreeSamples &&tree_samples, size_t total_pixels, 295 const ModularOptions &options, 296 const std::vector<ModularMultiplierInfo> &multiplier_info = {}, 297 StaticPropRange static_prop_range = {}) { 298 Tree tree; 299 for (size_t i = 0; i < kNumStaticProperties; i++) { 300 if (static_prop_range[i][1] == 0) { 301 static_prop_range[i][1] = std::numeric_limits<uint32_t>::max(); 302 } 303 } 304 if (!tree_samples.HasSamples()) { 305 tree.emplace_back(); 306 tree.back().predictor = tree_samples.PredictorFromIndex(0); 307 tree.back().property = -1; 308 tree.back().predictor_offset = 0; 309 tree.back().multiplier = 1; 310 return tree; 311 } 312 float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels; 313 float required_cost = pixel_fraction * 0.9 + 0.1; 314 tree_samples.AllSamplesDone(); 315 JXL_RETURN_IF_ERROR(ComputeBestTree( 316 tree_samples, options.splitting_heuristics_node_threshold * required_cost, 317 multiplier_info, static_prop_range, options.fast_decode_multiplier, 318 &tree)); 319 return tree; 320 } 321 322 Status EncodeModularChannelMAANS(const Image &image, pixel_type chan, 323 const weighted::Header &wp_header, 324 const Tree &global_tree, Token **tokenpp, 325 AuxOut *aux_out, size_t group_id, 326 bool skip_encoder_fast_path) { 327 const Channel &channel = image.channel[chan]; 328 JxlMemoryManager *memory_manager = channel.memory_manager(); 329 Token *tokenp = *tokenpp; 330 JXL_ENSURE(channel.w != 0 && channel.h != 0); 331 332 Image3F predictor_img; 333 if (kWantDebug) { 334 JXL_ASSIGN_OR_RETURN(predictor_img, 335 Image3F::Create(memory_manager, channel.w, channel.h)); 336 } 337 338 JXL_DEBUG_V(6, 339 "Encoding %" PRIuS "x%" PRIuS 340 " channel %d, " 341 "(shift=%i,%i)", 342 channel.w, channel.h, chan, channel.hshift, channel.vshift); 343 344 std::array<pixel_type, kNumStaticProperties> static_props = { 345 {chan, static_cast<int>(group_id)}}; 346 bool use_wp; 347 bool is_wp_only; 348 bool is_gradient_only; 349 size_t num_props; 350 FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp, 351 &is_wp_only, &is_gradient_only); 352 Properties properties(num_props); 353 MATreeLookup tree_lookup(tree); 354 JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size()); 355 356 // Check if this tree is a WP-only tree with a small enough property value 357 // range. 358 // Initialized to avoid clang-tidy complaining. 359 auto tree_lut = jxl::make_unique<TreeLut<uint16_t, false, false>>(); 360 if (is_wp_only) { 361 is_wp_only = TreeToLookupTable(tree, *tree_lut); 362 } 363 if (is_gradient_only) { 364 is_gradient_only = TreeToLookupTable(tree, *tree_lut); 365 } 366 367 if (is_wp_only && !skip_encoder_fast_path) { 368 for (size_t c = 0; c < 3; c++) { 369 FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]), 370 &predictor_img.Plane(c)); 371 } 372 const intptr_t onerow = channel.plane.PixelsPerRow(); 373 weighted::State wp_state(wp_header, channel.w, channel.h); 374 Properties properties(1); 375 for (size_t y = 0; y < channel.h; y++) { 376 const pixel_type *JXL_RESTRICT r = channel.Row(y); 377 for (size_t x = 0; x < channel.w; x++) { 378 size_t offset = 0; 379 pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); 380 pixel_type_w top = (y ? *(r + x - onerow) : left); 381 pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); 382 pixel_type_w topright = 383 (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); 384 pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); 385 int32_t guess = wp_state.Predict</*compute_properties=*/true>( 386 x, y, channel.w, top, left, topright, topleft, toptop, &properties, 387 offset); 388 uint32_t pos = 389 kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), 390 kPropRangeFast - 1); 391 uint32_t ctx_id = tree_lut->context_lookup[pos]; 392 int32_t residual = r[x] - guess; 393 *tokenp++ = Token(ctx_id, PackSigned(residual)); 394 wp_state.UpdateErrors(r[x], x, y, channel.w); 395 } 396 } 397 } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient && 398 tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && 399 !skip_encoder_fast_path) { 400 for (size_t c = 0; c < 3; c++) { 401 FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), 402 &predictor_img.Plane(c)); 403 } 404 const intptr_t onerow = channel.plane.PixelsPerRow(); 405 for (size_t y = 0; y < channel.h; y++) { 406 const pixel_type *JXL_RESTRICT r = channel.Row(y); 407 for (size_t x = 0; x < channel.w; x++) { 408 pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); 409 pixel_type_w top = (y ? *(r + x - onerow) : left); 410 pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); 411 int32_t guess = ClampedGradient(top, left, topleft); 412 int32_t residual = r[x] - guess; 413 *tokenp++ = Token(tree[0].childID, PackSigned(residual)); 414 } 415 } 416 } else if (is_gradient_only && !skip_encoder_fast_path) { 417 for (size_t c = 0; c < 3; c++) { 418 FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), 419 &predictor_img.Plane(c)); 420 } 421 const intptr_t onerow = channel.plane.PixelsPerRow(); 422 for (size_t y = 0; y < channel.h; y++) { 423 const pixel_type *JXL_RESTRICT r = channel.Row(y); 424 for (size_t x = 0; x < channel.w; x++) { 425 pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); 426 pixel_type_w top = (y ? *(r + x - onerow) : left); 427 pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); 428 int32_t guess = ClampedGradient(top, left, topleft); 429 uint32_t pos = 430 kPropRangeFast + 431 std::min<pixel_type_w>( 432 std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), 433 kPropRangeFast - 1); 434 uint32_t ctx_id = tree_lut->context_lookup[pos]; 435 int32_t residual = r[x] - guess; 436 *tokenp++ = Token(ctx_id, PackSigned(residual)); 437 } 438 } 439 } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero && 440 tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && 441 !skip_encoder_fast_path) { 442 for (size_t c = 0; c < 3; c++) { 443 FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]), 444 &predictor_img.Plane(c)); 445 } 446 for (size_t y = 0; y < channel.h; y++) { 447 const pixel_type *JXL_RESTRICT p = channel.Row(y); 448 for (size_t x = 0; x < channel.w; x++) { 449 *tokenp++ = Token(tree[0].childID, PackSigned(p[x])); 450 } 451 } 452 } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted && 453 (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 && 454 tree[0].predictor_offset == 0 && !skip_encoder_fast_path) { 455 // multiplier is a power of 2. 456 for (size_t c = 0; c < 3; c++) { 457 FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]), 458 &predictor_img.Plane(c)); 459 } 460 uint32_t mul_shift = 461 FloorLog2Nonzero(static_cast<uint32_t>(tree[0].multiplier)); 462 const intptr_t onerow = channel.plane.PixelsPerRow(); 463 for (size_t y = 0; y < channel.h; y++) { 464 const pixel_type *JXL_RESTRICT r = channel.Row(y); 465 for (size_t x = 0; x < channel.w; x++) { 466 PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x, 467 y, tree[0].predictor); 468 pixel_type_w residual = r[x] - pred.guess; 469 JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual); 470 *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift)); 471 } 472 } 473 474 } else if (!use_wp && !skip_encoder_fast_path) { 475 const intptr_t onerow = channel.plane.PixelsPerRow(); 476 JXL_ASSIGN_OR_RETURN( 477 Channel references, 478 Channel::Create(memory_manager, 479 properties.size() - kNumNonrefProperties, channel.w)); 480 for (size_t y = 0; y < channel.h; y++) { 481 const pixel_type *JXL_RESTRICT p = channel.Row(y); 482 PrecomputeReferences(channel, y, image, chan, &references); 483 float *pred_img_row[3]; 484 if (kWantDebug) { 485 for (size_t c = 0; c < 3; c++) { 486 pred_img_row[c] = predictor_img.PlaneRow(c, y); 487 } 488 } 489 InitPropsRow(&properties, static_props, y); 490 for (size_t x = 0; x < channel.w; x++) { 491 PredictionResult res = 492 PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, 493 tree_lookup, references); 494 if (kWantDebug) { 495 for (size_t i = 0; i < 3; i++) { 496 pred_img_row[i][x] = PredictorColor(res.predictor)[i]; 497 } 498 } 499 pixel_type_w residual = p[x] - res.guess; 500 JXL_DASSERT(residual % res.multiplier == 0); 501 *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); 502 } 503 } 504 } else { 505 const intptr_t onerow = channel.plane.PixelsPerRow(); 506 JXL_ASSIGN_OR_RETURN( 507 Channel references, 508 Channel::Create(memory_manager, 509 properties.size() - kNumNonrefProperties, channel.w)); 510 weighted::State wp_state(wp_header, channel.w, channel.h); 511 for (size_t y = 0; y < channel.h; y++) { 512 const pixel_type *JXL_RESTRICT p = channel.Row(y); 513 PrecomputeReferences(channel, y, image, chan, &references); 514 float *pred_img_row[3]; 515 if (kWantDebug) { 516 for (size_t c = 0; c < 3; c++) { 517 pred_img_row[c] = predictor_img.PlaneRow(c, y); 518 } 519 } 520 InitPropsRow(&properties, static_props, y); 521 for (size_t x = 0; x < channel.w; x++) { 522 PredictionResult res = 523 PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, 524 tree_lookup, references, &wp_state); 525 if (kWantDebug) { 526 for (size_t i = 0; i < 3; i++) { 527 pred_img_row[i][x] = PredictorColor(res.predictor)[i]; 528 } 529 } 530 pixel_type_w residual = p[x] - res.guess; 531 JXL_DASSERT(residual % res.multiplier == 0); 532 *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); 533 wp_state.UpdateErrors(p[x], x, y, channel.w); 534 } 535 } 536 } 537 /* TODO(szabadka): Add cparams to the call stack here. 538 if (kWantDebug && WantDebugOutput(cparams)) { 539 DumpImage( 540 cparams, 541 ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(), 542 predictor_img); 543 } 544 */ 545 *tokenpp = tokenp; 546 return true; 547 } 548 549 Status ModularEncode(const Image &image, const ModularOptions &options, 550 BitWriter *writer, AuxOut *aux_out, LayerType layer, 551 size_t group_id, TreeSamples *tree_samples, 552 size_t *total_pixels, const Tree *tree, 553 GroupHeader *header, std::vector<Token> *tokens, 554 size_t *width) { 555 if (image.error) return JXL_FAILURE("Invalid image"); 556 JxlMemoryManager *memory_manager = image.memory_manager(); 557 size_t nb_channels = image.channel.size(); 558 JXL_DEBUG_V( 559 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.", 560 nb_channels, image.bitdepth, image.w, image.h); 561 562 if (nb_channels < 1) { 563 return true; // is there any use for a zero-channel image? 564 } 565 566 // encode transforms 567 GroupHeader header_storage; 568 if (header == nullptr) header = &header_storage; 569 Bundle::Init(header); 570 if (options.predictor == Predictor::Weighted) { 571 weighted::PredictorMode(options.wp_mode, &header->wp_header); 572 } 573 header->transforms = image.transform; 574 // This doesn't actually work 575 if (tree != nullptr) { 576 header->use_global_tree = true; 577 } 578 if (tree_samples == nullptr && tree == nullptr) { 579 JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out)); 580 } 581 582 TreeSamples tree_samples_storage; 583 size_t total_pixels_storage = 0; 584 if (!total_pixels) total_pixels = &total_pixels_storage; 585 if (*total_pixels == 0) { 586 for (size_t i = 0; i < nb_channels; i++) { 587 if (i >= image.nb_meta_channels && 588 (image.channel[i].w > options.max_chan_size || 589 image.channel[i].h > options.max_chan_size)) { 590 break; 591 } 592 *total_pixels += image.channel[i].w * image.channel[i].h; 593 } 594 *total_pixels = std::max<size_t>(*total_pixels, 1); 595 } 596 // If there's no tree, compute one (or gather data to). 597 if (tree == nullptr && 598 options.tree_kind == ModularOptions::TreeKind::kLearn) { 599 bool gather_data = tree_samples != nullptr; 600 if (tree_samples == nullptr) { 601 JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor( 602 options.predictor, options.wp_tree_mode)); 603 JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties( 604 options.splitting_heuristics_properties, options.wp_tree_mode)); 605 std::vector<pixel_type> pixel_samples; 606 std::vector<pixel_type> diff_samples; 607 std::vector<uint32_t> group_pixel_count; 608 std::vector<uint32_t> channel_pixel_count; 609 CollectPixelSamples(image, options, 0, group_pixel_count, 610 channel_pixel_count, pixel_samples, diff_samples); 611 std::vector<ModularMultiplierInfo> placeholder_multiplier_info; 612 StaticPropRange range; 613 tree_samples_storage.PreQuantizeProperties( 614 range, placeholder_multiplier_info, group_pixel_count, 615 channel_pixel_count, pixel_samples, diff_samples, 616 options.max_property_values); 617 } 618 for (size_t i = 0; i < nb_channels; i++) { 619 if (!image.channel[i].w || !image.channel[i].h) { 620 continue; // skip empty channels 621 } 622 if (i >= image.nb_meta_channels && 623 (image.channel[i].w > options.max_chan_size || 624 image.channel[i].h > options.max_chan_size)) { 625 break; 626 } 627 JXL_RETURN_IF_ERROR(GatherTreeData( 628 image, i, group_id, header->wp_header, options, 629 gather_data ? *tree_samples : tree_samples_storage, total_pixels)); 630 } 631 if (gather_data) return true; 632 } 633 634 JXL_ENSURE((tree == nullptr) == (tokens == nullptr)); 635 636 Tree tree_storage; 637 std::vector<std::vector<Token>> tokens_storage(1); 638 // Compute tree. 639 if (tree == nullptr) { 640 EntropyEncodingData code; 641 std::vector<uint8_t> context_map; 642 643 std::vector<std::vector<Token>> tree_tokens(1); 644 if (options.tree_kind == ModularOptions::TreeKind::kLearn) { 645 JXL_ASSIGN_OR_RETURN( 646 tree_storage, 647 LearnTree(std::move(tree_samples_storage), *total_pixels, options)); 648 } else { 649 tree_storage = PredefinedTree(options.tree_kind, *total_pixels, 650 image.bitdepth, options.max_properties); 651 } 652 tree = &tree_storage; 653 tokens = tokens_storage.data(); 654 655 Tree decoded_tree; 656 JXL_RETURN_IF_ERROR(TokenizeTree(*tree, tree_tokens.data(), &decoded_tree)); 657 JXL_ENSURE(tree->size() == decoded_tree.size()); 658 tree_storage = std::move(decoded_tree); 659 660 /* TODO(szabadka) Add text output callback 661 if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) { 662 PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id)); 663 } */ 664 665 // Write tree 666 JXL_ASSIGN_OR_RETURN(size_t cost, 667 BuildAndEncodeHistograms( 668 memory_manager, options.histogram_params, 669 kNumTreeContexts, tree_tokens, &code, &context_map, 670 writer, LayerType::ModularTree, aux_out)); 671 (void)cost; 672 JXL_RETURN_IF_ERROR(WriteTokens(tree_tokens[0], code, context_map, 0, 673 writer, LayerType::ModularTree, aux_out)); 674 } 675 676 size_t image_width = 0; 677 size_t total_tokens = 0; 678 for (size_t i = 0; i < nb_channels; i++) { 679 if (i >= image.nb_meta_channels && 680 (image.channel[i].w > options.max_chan_size || 681 image.channel[i].h > options.max_chan_size)) { 682 break; 683 } 684 if (image.channel[i].w > image_width) image_width = image.channel[i].w; 685 total_tokens += image.channel[i].w * image.channel[i].h; 686 } 687 if (options.zero_tokens) { 688 tokens->resize(tokens->size() + total_tokens, {0, 0}); 689 } else { 690 // Do one big allocation for all the tokens we'll need, 691 // to avoid reallocs that might require copying. 692 size_t pos = tokens->size(); 693 tokens->resize(pos + total_tokens); 694 Token *tokenp = tokens->data() + pos; 695 for (size_t i = 0; i < nb_channels; i++) { 696 if (!image.channel[i].w || !image.channel[i].h) { 697 continue; // skip empty channels 698 } 699 if (i >= image.nb_meta_channels && 700 (image.channel[i].w > options.max_chan_size || 701 image.channel[i].h > options.max_chan_size)) { 702 break; 703 } 704 JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS( 705 image, i, header->wp_header, *tree, &tokenp, aux_out, group_id, 706 options.skip_encoder_fast_path)); 707 } 708 // Make sure we actually wrote all tokens 709 JXL_ENSURE(tokenp == tokens->data() + tokens->size()); 710 } 711 712 // Write data if not using a global tree/ANS stream. 713 if (!header->use_global_tree) { 714 EntropyEncodingData code; 715 std::vector<uint8_t> context_map; 716 HistogramParams histo_params = options.histogram_params; 717 histo_params.image_widths.push_back(image_width); 718 JXL_ASSIGN_OR_RETURN( 719 size_t cost, 720 BuildAndEncodeHistograms(memory_manager, histo_params, 721 (tree->size() + 1) / 2, tokens_storage, &code, 722 &context_map, writer, layer, aux_out)); 723 (void)cost; 724 JXL_RETURN_IF_ERROR(WriteTokens(tokens_storage[0], code, context_map, 0, 725 writer, layer, aux_out)); 726 } else { 727 *width = image_width; 728 } 729 return true; 730 } 731 732 Status ModularGenericCompress(Image &image, const ModularOptions &opts, 733 BitWriter *writer, AuxOut *aux_out, 734 LayerType layer, size_t group_id, 735 TreeSamples *tree_samples, size_t *total_pixels, 736 const Tree *tree, GroupHeader *header, 737 std::vector<Token> *tokens, size_t *width) { 738 if (image.w == 0 || image.h == 0) return true; 739 ModularOptions options = opts; // Make a copy to modify it. 740 741 if (options.predictor == kUndefinedPredictor) { 742 options.predictor = Predictor::Gradient; 743 } 744 745 size_t bits = writer ? writer->BitsWritten() : 0; 746 JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer, 747 group_id, tree_samples, total_pixels, tree, 748 header, tokens, width)); 749 bits = writer ? writer->BitsWritten() - bits : 0; 750 if (writer) { 751 JXL_DEBUG_V(4, 752 "Modular-encoded a %" PRIuS "x%" PRIuS 753 " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes", 754 image.w, image.h, image.bitdepth, image.channel.size(), 755 bits / 8); 756 } 757 (void)bits; 758 return true; 759 } 760 761 } // namespace jxl