enc_ma.cc (38538B)
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/modular/encoding/enc_ma.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <numeric>
#include <queue>
#include <vector>

#include "lib/jxl/modular/encoding/ma_common.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/base/fast_math-inl.h"
#include "lib/jxl/base/random.h"
#include "lib/jxl/enc_ans.h"
#include "lib/jxl/modular/encoding/context_predict.h"
#include "lib/jxl/modular/options.h"
#include "lib/jxl/pack_signed.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Eq;
using hwy::HWY_NAMESPACE::IfThenElse;
using hwy::HWY_NAMESPACE::Lt;
using hwy::HWY_NAMESPACE::Max;

const HWY_FULL(float) df;
const HWY_FULL(int32_t) di;
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }

// Compute entropy of the histogram, taking into account the minimum probability
// for symbols with non-zero counts.
// (See comment above: estimates the cost, in bits, of encoding this
// histogram.) `num_symbols` must be a multiple of Lanes(df) -- callers pad it
// via Padded() -- since the loop below loads full vectors from `counts`.
float EstimateBits(const int32_t *counts, size_t num_symbols) {
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  const auto zero = Zero(df);
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
  const auto inv_total = Set(df, 1.0f / total);
  auto bits_lanes = Zero(df);
  auto total_v = Set(di, total);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    const auto counts_iv = LoadU(di, &counts[i]);
    const auto counts_fv = ConvertTo(df, counts_iv);
    const auto probs = Mul(counts_fv, inv_total);
    // Clamp each probability to the minimum the ANS encoder will assign to a
    // symbol; symbols with zero counts contribute 0 bits (count * log = 0).
    const auto mprobs = Max(probs, minprob);
    // If a single symbol holds all the counts, its cost is exactly 0 bits;
    // the IfThenElse avoids log2(~1.0) rounding noise in that case.
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
                                 BitCast(di, FastLog2f(df, mprobs)));
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
  }
  return GetLane(SumOfLanes(df, bits_lanes));
}

// Turns node `pos` of `tree` into a decision node on (property > splitval)
// and appends its two children, which start out as leaves (property == -1)
// with multiplier 1.
void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
  // Note that the tree splits on *strictly greater*.
  (*tree)[pos].lchild = tree->size();
  (*tree)[pos].rchild = tree->size() + 1;
  (*tree)[pos].splitval = splitval;
  (*tree)[pos].property = property;
  // The first node appended below gets index lchild and receives the
  // `rpred`/`roff` settings; the second (rchild) receives `lpred`/`loff`.
  // This matches FindBestSplit, where the rchild handles the samples with
  // property <= splitval and the lchild those with property > splitval.
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = rpred;
  tree->back().predictor_offset = roff;
  tree->back().multiplier = 1;
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = lpred;
  tree->back().predictor_offset = loff;
  tree->back().multiplier = 1;
}

enum class IntersectionType { kNone, kPartial, kInside };

// Classifies how the box `needle` relates to the box `haystack` (both are
// per-static-property [lo, hi) ranges): kNone if they are disjoint on some
// axis, kInside if needle is contained in haystack on every axis, kPartial
// otherwise. In the kPartial case, `partial_axis`/`partial_val` receive an
// axis and a threshold (for a split of the form value > partial_val) that
// separates the overlapping part; they are left untouched otherwise.
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
                               uint32_t &partial_axis, uint32_t &partial_val) {
  bool partial = false;
  for (size_t i = 0; i < kNumStaticProperties; i++) {
    if (haystack[i][0] >= needle[i][1]) {
      return IntersectionType::kNone;
    }
    if (haystack[i][1] <= needle[i][0]) {
      return IntersectionType::kNone;
    }
    // Needle fully covered by haystack along this axis.
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
      continue;
    }
    partial = true;
    partial_axis = i;
    // Pick whichever haystack boundary falls strictly inside the needle
    // range; the -1 accounts for the tree splitting on *strictly greater*.
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
      partial_val = haystack[i][0] - 1;
    } else {
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
                  haystack[i][1] < needle[i][1]);
      partial_val = haystack[i][1] - 1;
    }
  }
  return partial ?
IntersectionType::kPartial : IntersectionType::kInside;
}

// Partially sorts the sample range [begin, end) on property `prop` so that
// every sample before `pos` compares <= every sample at/after `pos`: a
// quickselect-style three-way partition around random pivots, iterating only
// into the side that still contains `pos`.
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                      size_t end, size_t prop) {
  // Three-way comparison of quantized property values (small, so the int32
  // subtraction cannot overflow meaningfully).
  auto cmp = [&](size_t a, size_t b) {
    return static_cast<int32_t>(tree_samples.Property(prop, a)) -
           static_cast<int32_t>(tree_samples.Property(prop, b));
  };
  Rng rng(0);
  while (end > begin + 1) {
    {
      // Pick a random pivot and move it to the front of the range.
      size_t pivot = rng.UniformU(begin, end);
      tree_samples.Swap(begin, pivot);
    }
    // Loop invariant: [begin, pivot_begin) < pivot,
    // [pivot_begin, pivot_end) == pivot, [pivot_end, i) > pivot.
    size_t pivot_begin = begin;
    size_t pivot_end = pivot_begin + 1;
    for (size_t i = begin + 1; i < end; i++) {
      JXL_DASSERT(i >= pivot_end);
      JXL_DASSERT(pivot_end > pivot_begin);
      int32_t cmp_result = cmp(i, pivot_begin);
      if (cmp_result < 0) {  // i < pivot, move pivot forward and put i before
                             // the pivot.
        tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
        pivot_begin++;
        pivot_end++;
      } else if (cmp_result == 0) {
        tree_samples.Swap(pivot_end, i);
        pivot_end++;
      }
    }
    JXL_DASSERT(pivot_begin >= begin);
    JXL_DASSERT(pivot_end > pivot_begin);
    JXL_DASSERT(pivot_end <= end);
    for (size_t i = begin; i < pivot_begin; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) < 0);
    }
    for (size_t i = pivot_end; i < end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) > 0);
    }
    for (size_t i = pivot_begin; i < pivot_end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) == 0);
    }
    // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
    // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
    // If pos falls in the first or the last interval, we continue in that
    // interval; otherwise, we are done.
153 if (pivot_begin > pos) { 154 end = pivot_begin; 155 } else if (pivot_end < pos) { 156 begin = pivot_end; 157 } else { 158 break; 159 } 160 } 161 } 162 163 void FindBestSplit(TreeSamples &tree_samples, float threshold, 164 const std::vector<ModularMultiplierInfo> &mul_info, 165 StaticPropRange initial_static_prop_range, 166 float fast_decode_multiplier, Tree *tree) { 167 struct NodeInfo { 168 size_t pos; 169 size_t begin; 170 size_t end; 171 uint64_t used_properties; 172 StaticPropRange static_prop_range; 173 }; 174 std::vector<NodeInfo> nodes; 175 nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, 176 initial_static_prop_range}); 177 178 size_t num_predictors = tree_samples.NumPredictors(); 179 size_t num_properties = tree_samples.NumProperties(); 180 181 // TODO(veluca): consider parallelizing the search (processing multiple nodes 182 // at a time). 183 while (!nodes.empty()) { 184 size_t pos = nodes.back().pos; 185 size_t begin = nodes.back().begin; 186 size_t end = nodes.back().end; 187 uint64_t used_properties = nodes.back().used_properties; 188 StaticPropRange static_prop_range = nodes.back().static_prop_range; 189 nodes.pop_back(); 190 if (begin == end) continue; 191 192 struct SplitInfo { 193 size_t prop = 0; 194 uint32_t val = 0; 195 size_t pos = 0; 196 float lcost = std::numeric_limits<float>::max(); 197 float rcost = std::numeric_limits<float>::max(); 198 Predictor lpred = Predictor::Zero; 199 Predictor rpred = Predictor::Zero; 200 float Cost() const { return lcost + rcost; } 201 }; 202 203 SplitInfo best_split_static_constant; 204 SplitInfo best_split_static; 205 SplitInfo best_split_nonstatic; 206 SplitInfo best_split_nowp; 207 208 JXL_DASSERT(begin <= end); 209 JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); 210 211 // Compute the maximum token in the range. 
212 size_t max_symbols = 0; 213 for (size_t pred = 0; pred < num_predictors; pred++) { 214 for (size_t i = begin; i < end; i++) { 215 uint32_t tok = tree_samples.Token(pred, i); 216 max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; 217 } 218 } 219 max_symbols = Padded(max_symbols); 220 std::vector<int32_t> counts(max_symbols * num_predictors); 221 std::vector<uint32_t> tot_extra_bits(num_predictors); 222 for (size_t pred = 0; pred < num_predictors; pred++) { 223 for (size_t i = begin; i < end; i++) { 224 counts[pred * max_symbols + tree_samples.Token(pred, i)] += 225 tree_samples.Count(i); 226 tot_extra_bits[pred] += 227 tree_samples.NBits(pred, i) * tree_samples.Count(i); 228 } 229 } 230 231 float base_bits; 232 { 233 size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); 234 base_bits = 235 EstimateBits(counts.data() + pred * max_symbols, max_symbols) + 236 tot_extra_bits[pred]; 237 } 238 239 SplitInfo *best = &best_split_nonstatic; 240 241 SplitInfo forced_split; 242 // The multiplier ranges cut halfway through the current ranges of static 243 // properties. We do this even if the current node is not a leaf, to 244 // minimize the number of nodes in the resulting tree. 
245 for (const auto &mmi : mul_info) { 246 uint32_t axis; 247 uint32_t val; 248 IntersectionType t = 249 BoxIntersects(static_prop_range, mmi.range, axis, val); 250 if (t == IntersectionType::kNone) continue; 251 if (t == IntersectionType::kInside) { 252 (*tree)[pos].multiplier = mmi.multiplier; 253 break; 254 } 255 if (t == IntersectionType::kPartial) { 256 forced_split.val = tree_samples.QuantizeProperty(axis, val); 257 forced_split.prop = axis; 258 forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; 259 forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; 260 best = &forced_split; 261 best->pos = begin; 262 JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); 263 for (size_t x = begin; x < end; x++) { 264 if (tree_samples.Property(best->prop, x) <= best->val) { 265 best->pos++; 266 } 267 } 268 break; 269 } 270 } 271 272 if (best != &forced_split) { 273 std::vector<int> prop_value_used_count; 274 std::vector<int> count_increase; 275 std::vector<size_t> extra_bits_increase; 276 // For each property, compute which of its values are used, and what 277 // tokens correspond to those usages. Then, iterate through the values, 278 // and compute the entropy of each side of the split (of the form `prop > 279 // threshold`). Finally, find the split that minimizes the cost. 280 struct CostInfo { 281 float cost = std::numeric_limits<float>::max(); 282 float extra_cost = 0; 283 float Cost() const { return cost + extra_cost; } 284 Predictor pred; // will be uninitialized in some cases, but never used. 285 }; 286 std::vector<CostInfo> costs_l; 287 std::vector<CostInfo> costs_r; 288 289 std::vector<int32_t> counts_above(max_symbols); 290 std::vector<int32_t> counts_below(max_symbols); 291 292 // The lower the threshold, the higher the expected noisiness of the 293 // estimate. Thus, discourage changing predictors. 
294 float change_pred_penalty = 800.0f / (100.0f + threshold); 295 for (size_t prop = 0; prop < num_properties && base_bits > threshold; 296 prop++) { 297 costs_l.clear(); 298 costs_r.clear(); 299 size_t prop_size = tree_samples.NumPropertyValues(prop); 300 if (extra_bits_increase.size() < prop_size) { 301 count_increase.resize(prop_size * max_symbols); 302 extra_bits_increase.resize(prop_size); 303 } 304 // Clear prop_value_used_count (which cannot be cleared "on the go") 305 prop_value_used_count.clear(); 306 prop_value_used_count.resize(prop_size); 307 308 size_t first_used = prop_size; 309 size_t last_used = 0; 310 311 // TODO(veluca): consider finding multiple splits along a single 312 // property at the same time, possibly with a bottom-up approach. 313 for (size_t i = begin; i < end; i++) { 314 size_t p = tree_samples.Property(prop, i); 315 prop_value_used_count[p]++; 316 last_used = std::max(last_used, p); 317 first_used = std::min(first_used, p); 318 } 319 costs_l.resize(last_used - first_used); 320 costs_r.resize(last_used - first_used); 321 // For all predictors, compute the right and left costs of each split. 322 for (size_t pred = 0; pred < num_predictors; pred++) { 323 // Compute cost and histogram increments for each property value. 324 for (size_t i = begin; i < end; i++) { 325 size_t p = tree_samples.Property(prop, i); 326 size_t cnt = tree_samples.Count(i); 327 size_t sym = tree_samples.Token(pred, i); 328 count_increase[p * max_symbols + sym] += cnt; 329 extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; 330 } 331 memcpy(counts_above.data(), counts.data() + pred * max_symbols, 332 max_symbols * sizeof counts_above[0]); 333 memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); 334 size_t extra_bits_below = 0; 335 // Exclude last used: this ensures neither counts_above nor 336 // counts_below is empty. 
337 for (size_t i = first_used; i < last_used; i++) { 338 if (!prop_value_used_count[i]) continue; 339 extra_bits_below += extra_bits_increase[i]; 340 // The increase for this property value has been used, and will not 341 // be used again: clear it. Also below. 342 extra_bits_increase[i] = 0; 343 for (size_t sym = 0; sym < max_symbols; sym++) { 344 counts_above[sym] -= count_increase[i * max_symbols + sym]; 345 counts_below[sym] += count_increase[i * max_symbols + sym]; 346 count_increase[i * max_symbols + sym] = 0; 347 } 348 float rcost = EstimateBits(counts_above.data(), max_symbols) + 349 tot_extra_bits[pred] - extra_bits_below; 350 float lcost = EstimateBits(counts_below.data(), max_symbols) + 351 extra_bits_below; 352 JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); 353 float penalty = 0; 354 // Never discourage moving away from the Weighted predictor. 355 if (tree_samples.PredictorFromIndex(pred) != 356 (*tree)[pos].predictor && 357 (*tree)[pos].predictor != Predictor::Weighted) { 358 penalty = change_pred_penalty; 359 } 360 // If everything else is equal, disfavour Weighted (slower) and 361 // favour Zero (faster if it's the only predictor used in a 362 // group+channel combination) 363 if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { 364 penalty += 1e-8; 365 } 366 if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { 367 penalty -= 1e-8; 368 } 369 if (rcost + penalty < costs_r[i - first_used].Cost()) { 370 costs_r[i - first_used].cost = rcost; 371 costs_r[i - first_used].extra_cost = penalty; 372 costs_r[i - first_used].pred = 373 tree_samples.PredictorFromIndex(pred); 374 } 375 if (lcost + penalty < costs_l[i - first_used].Cost()) { 376 costs_l[i - first_used].cost = lcost; 377 costs_l[i - first_used].extra_cost = penalty; 378 costs_l[i - first_used].pred = 379 tree_samples.PredictorFromIndex(pred); 380 } 381 } 382 } 383 // Iterate through the possible splits and find the one with minimum sum 384 // of costs of the two 
sides. 385 size_t split = begin; 386 for (size_t i = first_used; i < last_used; i++) { 387 if (!prop_value_used_count[i]) continue; 388 split += prop_value_used_count[i]; 389 float rcost = costs_r[i - first_used].cost; 390 float lcost = costs_l[i - first_used].cost; 391 // WP was not used + we would use the WP property or predictor 392 bool adds_wp = 393 (tree_samples.PropertyFromIndex(prop) == kWPProp && 394 (used_properties & (1LU << prop)) == 0) || 395 ((costs_l[i - first_used].pred == Predictor::Weighted || 396 costs_r[i - first_used].pred == Predictor::Weighted) && 397 (*tree)[pos].predictor != Predictor::Weighted); 398 bool zero_entropy_side = rcost == 0 || lcost == 0; 399 400 SplitInfo &best = 401 prop < kNumStaticProperties 402 ? (zero_entropy_side ? best_split_static_constant 403 : best_split_static) 404 : (adds_wp ? best_split_nonstatic : best_split_nowp); 405 if (lcost + rcost < best.Cost()) { 406 best.prop = prop; 407 best.val = i; 408 best.pos = split; 409 best.lcost = lcost; 410 best.lpred = costs_l[i - first_used].pred; 411 best.rcost = rcost; 412 best.rpred = costs_r[i - first_used].pred; 413 } 414 } 415 // Clear extra_bits_increase and cost_increase for last_used. 416 extra_bits_increase[last_used] = 0; 417 for (size_t sym = 0; sym < max_symbols; sym++) { 418 count_increase[last_used * max_symbols + sym] = 0; 419 } 420 } 421 422 // Try to avoid introducing WP. 423 if (best_split_nowp.Cost() + threshold < base_bits && 424 best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { 425 best = &best_split_nowp; 426 } 427 // Split along static props if possible and not significantly more 428 // expensive. 429 if (best_split_static.Cost() + threshold < base_bits && 430 best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { 431 best = &best_split_static; 432 } 433 // Split along static props to create constant nodes if possible. 
434 if (best_split_static_constant.Cost() + threshold < base_bits) { 435 best = &best_split_static_constant; 436 } 437 } 438 439 if (best->Cost() + threshold < base_bits) { 440 uint32_t p = tree_samples.PropertyFromIndex(best->prop); 441 pixel_type dequant = 442 tree_samples.UnquantizeProperty(best->prop, best->val); 443 // Split node and try to split children. 444 MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); 445 // "Sort" according to winning property 446 SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop); 447 if (p >= kNumStaticProperties) { 448 used_properties |= 1 << best->prop; 449 } 450 auto new_sp_range = static_prop_range; 451 if (p < kNumStaticProperties) { 452 JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); 453 new_sp_range[p][1] = dequant + 1; 454 JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); 455 } 456 nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, 457 used_properties, new_sp_range}); 458 new_sp_range = static_prop_range; 459 if (p < kNumStaticProperties) { 460 JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); 461 new_sp_range[p][0] = dequant + 1; 462 JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); 463 } 464 nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, 465 used_properties, new_sp_range}); 466 } 467 } 468 } 469 470 // NOLINTNEXTLINE(google-readability-namespace-comments) 471 } // namespace HWY_NAMESPACE 472 } // namespace jxl 473 HWY_AFTER_NAMESPACE(); 474 475 #if HWY_ONCE 476 namespace jxl { 477 478 HWY_EXPORT(FindBestSplit); // Local function. 479 480 Status ComputeBestTree(TreeSamples &tree_samples, float threshold, 481 const std::vector<ModularMultiplierInfo> &mul_info, 482 StaticPropRange static_prop_range, 483 float fast_decode_multiplier, Tree *tree) { 484 // TODO(veluca): take into account that different contexts can have different 485 // uint configs. 486 // 487 // Initialize tree. 
  tree->emplace_back();
  tree->back().property = -1;  // the root starts out as a leaf
  tree->back().predictor = tree_samples.PredictorFromIndex(0);
  tree->back().predictor_offset = 0;
  tree->back().multiplier = 1;
  // FindBestSplit tracks used properties in a 64-bit bitset.
  JXL_ENSURE(tree_samples.NumProperties() < 64);

  // Dedup-table entries index samples with 32-bit values.
  JXL_ENSURE(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());
  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
  return true;
}

#if JXL_CXX_LANG < JXL_CXX_17
// Pre-C++17, ODR-used static constexpr members need out-of-line definitions.
constexpr int32_t TreeSamples::kPropertyRange;
constexpr uint32_t TreeSamples::kDedupEntryUnused;
#endif

// Selects the set of candidate predictors to evaluate during tree search,
// honoring the requested TreeMode, and sizes `residuals` to one stream per
// predictor. Fails if kNoWP is combined with the Weighted predictor.
Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    // Move Weighted and Gradient to the front so they are considered first.
    // NOTE(review): assumes the Weighted/Gradient enum values are both >= 2,
    // so the two swaps do not interfere -- verify against the Predictor enum.
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  // kNoWP: drop Weighted if one of the branches above added it.
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto wp_it =
        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
    if (wp_it != predictors.end()) {
      predictors.erase(wp_it);
    }
  }
  residuals.resize(predictors.size());
  return true;
}

// Selects which properties are considered for splitting, honoring the
// requested TreeMode. Fails if the resulting property set is empty.
Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  props_to_use =
properties;
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    props_to_use = {static_cast<uint32_t>(kWPProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) {
    props_to_use = {static_cast<uint32_t>(kGradientProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
    if (it != props_to_use.end()) {
      props_to_use.erase(it);
    }
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }
  props.resize(props_to_use.size());
  return true;
}

// (Re)initializes the dedup hash table with 2^log_size buckets and indexes
// all existing samples whose count is not saturated.
// NOTE(review): when the table grows, entries left at their old positions are
// not cleared; they become stale but appear harmless, since every lookup
// re-checks IsSameSample() before merging.
void TreeSamples::InitTable(size_t log_size) {
  size_t size = 1ULL << log_size;
  if (dedup_table_.size() == size) return;
  dedup_table_.resize(size, kDedupEntryUnused);
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
      AddToTable(i);
    }
  }
}

// If a sample identical to `a` is already indexed (two candidate buckets are
// probed), bumps that sample's count and returns true so the caller can
// discard `a`; otherwise indexes `a` and returns false. Samples whose count
// saturates at uint16 max are dropped from the table so that future
// duplicates start a fresh entry instead of overflowing the counter.
bool TreeSamples::AddToTableAndMerge(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos1])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos1]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos1]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos1] = kDedupEntryUnused;
    }
    return true;
  }
  if (dedup_table_[pos2] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos2])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos2]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos2]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos2] = kDedupEntryUnused;
    }
    return true;
  }
  AddToTable(a);
  return false;
}

// Indexes sample `a` in the first free one of its two candidate buckets; if
// both buckets are occupied the sample is simply not indexed (deduplication
// is best-effort).
void TreeSamples::AddToTable(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] == kDedupEntryUnused) {
    dedup_table_[pos1] = a;
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
    dedup_table_[pos2] = a;
  }
}

// Reserves room for `num_samples` additional samples and grows the dedup
// table to hold ~1.5x the anticipated total.
// NOTE(review): assumes num_samples + sample_counts.size() > 0, since
// CeilLog2Nonzero requires a nonzero argument -- confirm callers guarantee
// this.
void TreeSamples::PrepareForSamples(size_t num_samples) {
  for (auto &res : residuals) {
    res.reserve(res.size() + num_samples);
  }
  for (auto &p : props) {
    p.reserve(p.size() + num_samples);
  }
  size_t total_num_samples = num_samples + sample_counts.size();
  size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2);
  InitTable(next_size);
}

// First bucket hash: multiplicative hash over the sample's residual tokens
// and property values. The mask works because the table size is a power of
// two (see InitTable).
size_t TreeSamples::Hash1(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd;
  uint64_t h = constant;
  for (const auto &r : residuals) {
    h = h * constant + r[a].tok;
    h = h * constant + r[a].nbits;
  }
  for (const auto &p : props) {
    h = h * constant + p[a];
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}
// Second bucket hash: mixes with XOR instead of +, a different constant, and
// visits the fields in the opposite order, to decorrelate from Hash1.
size_t TreeSamples::Hash2(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
  uint64_t h = constant;
  for (const auto &p : props) {
    h = h * constant ^ p[a];
  }
  for (const auto &r : residuals) {
    h = h * constant ^ r[a].tok;
    h = h * constant ^ r[a].nbits;
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}

// True iff samples a and b have identical residual tokens for every predictor
// and identical quantized property values.
bool TreeSamples::IsSameSample(size_t a, size_t b) const {
  bool ret = true;
  for (const auto &r : residuals) {
    if (r[a].tok != r[b].tok) {
      ret = false;
    }
    if (r[a].nbits != r[b].nbits) {
      ret = false;
    }
  }
  for (const auto &p : props) {
    if (p[a] != p[b]) {
      ret = false;
    }
  }
  return ret;
}

void
TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                       const pixel_type_w *predictions) {
  // Record the residual of this pixel, as a hybrid-uint token, under each
  // candidate predictor.
  for (size_t i = 0; i < predictors.size(); i++) {
    pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
    uint32_t tok, nbits, bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
    JXL_DASSERT(tok < 256);  // tok and nbits are stored as uint8_t.
    JXL_DASSERT(nbits < 256);
    residuals[i].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }
  for (size_t i = 0; i < props_to_use.size(); i++) {
    props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
  }
  sample_counts.push_back(1);
  num_samples++;
  // If an identical sample was already indexed, its count has been bumped:
  // drop the copy we just appended.
  if (AddToTableAndMerge(sample_counts.size() - 1)) {
    for (auto &r : residuals) r.pop_back();
    for (auto &p : props) p.pop_back();
    sample_counts.pop_back();
  }
}

// Swaps samples a and b across all parallel arrays.
void TreeSamples::Swap(size_t a, size_t b) {
  if (a == b) return;
  for (auto &r : residuals) {
    std::swap(r[a], r[b]);
  }
  for (auto &p : props) {
    std::swap(p[a], p[b]);
  }
  std::swap(sample_counts[a], sample_counts[b]);
}

// Cyclic three-way shuffle across all parallel arrays: old a moves to b,
// old b moves to c, and old c moves to a. Degenerates to Swap(a, b) when
// b == c.
void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
  if (b == c) {
    Swap(a, b);
    return;
  }

  for (auto &r : residuals) {
    auto tmp = r[a];
    r[a] = r[c];
    r[c] = r[b];
    r[b] = tmp;
  }
  for (auto &p : props) {
    auto tmp = p[a];
    p[a] = p[c];
    p[c] = p[b];
    p[b] = tmp;
  }
  auto tmp = sample_counts[a];
  sample_counts[a] = sample_counts[c];
  sample_counts[c] = sample_counts[b];
  sample_counts[b] = tmp;
}

namespace {
// Returns up to num_chunks - 1 index thresholds that partition `histogram`
// into chunks of roughly equal total weight.
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
                                       size_t num_chunks) {
  if (histogram.empty()) return {};
  // TODO(veluca): selecting distinct quantiles is likely not the best
  // way to go about this.
  std::vector<int32_t> thresholds;
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
  uint64_t cumsum = 0;
  uint64_t threshold = 1;
  // Walk the histogram and emit a threshold each time the cumulative weight
  // crosses the next 1/num_chunks quantile; the final bucket never produces
  // one, so at most num_chunks - 1 thresholds are returned.
  for (size_t i = 0; i + 1 < histogram.size(); i++) {
    cumsum += histogram[i];
    if (cumsum >= threshold * sum / num_chunks) {
      thresholds.push_back(i);
      while (cumsum > threshold * sum / num_chunks) threshold++;
    }
  }
  return thresholds;
}

// Quantizes `samples` (clamped to [-512, 512]) into at most num_chunks
// roughly equally populated buckets and returns the bucket thresholds, in
// sample-value space.
std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
                                     size_t num_chunks) {
  if (samples.empty()) return {};
  int min = *std::min_element(samples.begin(), samples.end());
  constexpr int kRange = 512;
  min = std::min(std::max(min, -kRange), kRange);
  std::vector<uint32_t> counts(2 * kRange + 1);
  for (int s : samples) {
    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
    counts[sample_offset]++;
  }
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
  // Histogram indices are offsets from `min`; shift back to sample values.
  for (auto &v : thresholds) v += min;
  return thresholds;
}
}  // namespace

// Chooses, for every property in props_to_use, the set of split thresholds
// (compact_properties) and builds the value -> bucket lookup tables
// (property_mapping) used to quantize property values.
void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
772 std::vector<int32_t> group_multiplier_thresholds; 773 std::vector<int32_t> channel_multiplier_thresholds; 774 for (const auto &v : multiplier_info) { 775 if (v.range[0][0] != range[0][0]) { 776 channel_multiplier_thresholds.push_back(v.range[0][0] - 1); 777 } 778 if (v.range[0][1] != range[0][1]) { 779 channel_multiplier_thresholds.push_back(v.range[0][1] - 1); 780 } 781 if (v.range[1][0] != range[1][0]) { 782 group_multiplier_thresholds.push_back(v.range[1][0] - 1); 783 } 784 if (v.range[1][1] != range[1][1]) { 785 group_multiplier_thresholds.push_back(v.range[1][1] - 1); 786 } 787 } 788 std::sort(channel_multiplier_thresholds.begin(), 789 channel_multiplier_thresholds.end()); 790 channel_multiplier_thresholds.resize( 791 std::unique(channel_multiplier_thresholds.begin(), 792 channel_multiplier_thresholds.end()) - 793 channel_multiplier_thresholds.begin()); 794 std::sort(group_multiplier_thresholds.begin(), 795 group_multiplier_thresholds.end()); 796 group_multiplier_thresholds.resize( 797 std::unique(group_multiplier_thresholds.begin(), 798 group_multiplier_thresholds.end()) - 799 group_multiplier_thresholds.begin()); 800 801 compact_properties.resize(props_to_use.size()); 802 auto quantize_channel = [&]() { 803 if (!channel_multiplier_thresholds.empty()) { 804 return channel_multiplier_thresholds; 805 } 806 return QuantizeHistogram(channel_pixel_count, max_property_values); 807 }; 808 auto quantize_group_id = [&]() { 809 if (!group_multiplier_thresholds.empty()) { 810 return group_multiplier_thresholds; 811 } 812 return QuantizeHistogram(group_pixel_count, max_property_values); 813 }; 814 auto quantize_coordinate = [&]() { 815 std::vector<int32_t> quantized; 816 quantized.reserve(max_property_values - 1); 817 for (size_t i = 0; i + 1 < max_property_values; i++) { 818 quantized.push_back((i + 1) * 256 / max_property_values - 1); 819 } 820 return quantized; 821 }; 822 std::vector<int32_t> abs_pixel_thresholds; 823 std::vector<int32_t> pixel_thresholds; 824 
auto quantize_pixel_property = [&]() { 825 if (pixel_thresholds.empty()) { 826 pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values); 827 } 828 return pixel_thresholds; 829 }; 830 auto quantize_abs_pixel_property = [&]() { 831 if (abs_pixel_thresholds.empty()) { 832 quantize_pixel_property(); // Compute the non-abs thresholds. 833 for (auto &v : pixel_samples) v = std::abs(v); 834 abs_pixel_thresholds = 835 QuantizeSamples(pixel_samples, max_property_values); 836 } 837 return abs_pixel_thresholds; 838 }; 839 std::vector<int32_t> abs_diff_thresholds; 840 std::vector<int32_t> diff_thresholds; 841 auto quantize_diff_property = [&]() { 842 if (diff_thresholds.empty()) { 843 diff_thresholds = QuantizeSamples(diff_samples, max_property_values); 844 } 845 return diff_thresholds; 846 }; 847 auto quantize_abs_diff_property = [&]() { 848 if (abs_diff_thresholds.empty()) { 849 quantize_diff_property(); // Compute the non-abs thresholds. 850 for (auto &v : diff_samples) v = std::abs(v); 851 abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values); 852 } 853 return abs_diff_thresholds; 854 }; 855 auto quantize_wp = [&]() { 856 if (max_property_values < 32) { 857 return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0, 858 1, 3, 7, 15, 31, 63, 127}; 859 } 860 if (max_property_values < 64) { 861 return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23, 862 -15, -11, -7, -5, -3, -1, 0, 1, 863 3, 5, 7, 11, 15, 23, 31, 47, 864 63, 95, 127, 191, 255}; 865 } 866 return std::vector<int32_t>{ 867 -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47, 868 -39, -31, -27, -23, -19, -15, -13, -11, -9, -7, -6, 869 -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 870 6, 7, 9, 11, 13, 15, 19, 23, 27, 31, 39, 871 47, 55, 63, 79, 95, 111, 127, 159, 191, 223, 255}; 872 }; 873 874 property_mapping.resize(props_to_use.size()); 875 for (size_t i = 0; i < props_to_use.size(); i++) { 876 if (props_to_use[i] == 0) { 877 compact_properties[i] = 
quantize_channel(); 878 } else if (props_to_use[i] == 1) { 879 compact_properties[i] = quantize_group_id(); 880 } else if (props_to_use[i] == 2 || props_to_use[i] == 3) { 881 compact_properties[i] = quantize_coordinate(); 882 } else if (props_to_use[i] == 6 || props_to_use[i] == 7 || 883 props_to_use[i] == 8 || 884 (props_to_use[i] >= kNumNonrefProperties && 885 (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) { 886 compact_properties[i] = quantize_pixel_property(); 887 } else if (props_to_use[i] == 4 || props_to_use[i] == 5 || 888 (props_to_use[i] >= kNumNonrefProperties && 889 (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) { 890 compact_properties[i] = quantize_abs_pixel_property(); 891 } else if (props_to_use[i] >= kNumNonrefProperties && 892 (props_to_use[i] - kNumNonrefProperties) % 4 == 2) { 893 compact_properties[i] = quantize_abs_diff_property(); 894 } else if (props_to_use[i] == kWPProp) { 895 compact_properties[i] = quantize_wp(); 896 } else { 897 compact_properties[i] = quantize_diff_property(); 898 } 899 property_mapping[i].resize(kPropertyRange * 2 + 1); 900 size_t mapped = 0; 901 for (size_t j = 0; j < property_mapping[i].size(); j++) { 902 while (mapped < compact_properties[i].size() && 903 static_cast<int>(j) - kPropertyRange > 904 compact_properties[i][mapped]) { 905 mapped++; 906 } 907 // property_mapping[i] of a value V is `mapped` if 908 // compact_properties[i][mapped] <= j and 909 // compact_properties[i][mapped-1] > j 910 // This is because the decision node in the tree splits on (property) > j, 911 // hence everything that is not > of a threshold should be clustered 912 // together. 
913 property_mapping[i][j] = mapped; 914 } 915 } 916 } 917 918 void CollectPixelSamples(const Image &image, const ModularOptions &options, 919 uint32_t group_id, 920 std::vector<uint32_t> &group_pixel_count, 921 std::vector<uint32_t> &channel_pixel_count, 922 std::vector<pixel_type> &pixel_samples, 923 std::vector<pixel_type> &diff_samples) { 924 if (options.nb_repeats == 0) return; 925 if (group_pixel_count.size() <= group_id) { 926 group_pixel_count.resize(group_id + 1); 927 } 928 if (channel_pixel_count.size() < image.channel.size()) { 929 channel_pixel_count.resize(image.channel.size()); 930 } 931 Rng rng(group_id); 932 // Sample 10% of the final number of samples for property quantization. 933 float fraction = std::min(options.nb_repeats * 0.1, 0.99); 934 Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction); 935 size_t total_pixels = 0; 936 std::vector<size_t> channel_ids; 937 for (size_t i = 0; i < image.channel.size(); i++) { 938 if (image.channel[i].w <= 1 || image.channel[i].h == 0) { 939 continue; // skip empty or width-1 channels. 940 } 941 if (i >= image.nb_meta_channels && 942 (image.channel[i].w > options.max_chan_size || 943 image.channel[i].h > options.max_chan_size)) { 944 break; 945 } 946 channel_ids.push_back(i); 947 group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h; 948 channel_pixel_count[i] += image.channel[i].w * image.channel[i].h; 949 total_pixels += image.channel[i].w * image.channel[i].h; 950 } 951 if (channel_ids.empty()) return; 952 pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels); 953 diff_samples.reserve(diff_samples.size() + fraction * total_pixels); 954 size_t i = 0; 955 size_t y = 0; 956 size_t x = 0; 957 auto advance = [&](size_t amount) { 958 x += amount; 959 // Detect row overflow (rare). 960 while (x >= image.channel[channel_ids[i]].w) { 961 x -= image.channel[channel_ids[i]].w; 962 y++; 963 // Detect end-of-channel (even rarer). 
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          // Ran past the last channel; the sampling loop below terminates.
          return;
        }
      }
    }
  };
  // Geometrically distributed gaps between consecutive samples approximate
  // sampling each pixel independently with probability `fraction`.
  advance(rng.Geometric(dist));
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    // Difference with the left neighbor; at x == 0 use the right neighbor
    // instead (safe: width-1 channels never enter channel_ids).
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}

// TODO(veluca): very simple encoding scheme. This should be improved.
// Serializes `tree` breadth-first into `tokens`, and simultaneously builds
// `decoder_tree`: the BFS-ordered equivalent tree with explicit child
// indices and sequential leaf ids, matching what the decoder reconstructs.
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                    Tree *decoder_tree) {
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    JXL_ENSURE(tree[cur].property >= -1);
    // property == -1 marks a leaf; +1 keeps the token value non-negative.
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    if (tree[cur].property == -1) {
      // Leaf: encode predictor, predictor offset, and multiplier.
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(tree[cur].predictor));
      tokens->emplace_back(kOffsetContext,
                           PackSigned(tree[cur].predictor_offset));
      // The multiplier is sent as (trailing-zero count, remaining odd part
      // minus one); multiplier must be nonzero.
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ENSURE(tree[cur].predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
                                 tree[cur].predictor_offset,
                                 tree[cur].multiplier);
      continue;
    }
    // Internal node. In BFS order, this node's children are emitted right
    // after the nodes still waiting in the queue, hence their indices in
    // decoder_tree are size() + q.size() + 1 and size() + q.size() + 2.
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
                               decoder_tree->size() + q.size() + 1,
                               decoder_tree->size() + q.size() + 2,
                               Predictor::Zero, 0, 1);
    q.push(tree[cur].lchild);
    q.push(tree[cur].rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
  }
  return true;
}

}  // namespace jxl
#endif  // HWY_ONCE