enc_adaptive_quantization.cc (53125B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_adaptive_quantization.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <algorithm> 11 #include <atomic> 12 #include <cmath> 13 #include <cstddef> 14 #include <cstdlib> 15 #include <string> 16 #include <vector> 17 18 #include "lib/jxl/memory_manager_internal.h" 19 20 #undef HWY_TARGET_INCLUDE 21 #define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc" 22 #include <hwy/foreach_target.h> 23 #include <hwy/highway.h> 24 25 #include "lib/jxl/ac_strategy.h" 26 #include "lib/jxl/base/common.h" 27 #include "lib/jxl/base/compiler_specific.h" 28 #include "lib/jxl/base/data_parallel.h" 29 #include "lib/jxl/base/fast_math-inl.h" 30 #include "lib/jxl/base/rect.h" 31 #include "lib/jxl/base/status.h" 32 #include "lib/jxl/butteraugli/butteraugli.h" 33 #include "lib/jxl/convolve.h" 34 #include "lib/jxl/dec_cache.h" 35 #include "lib/jxl/dec_group.h" 36 #include "lib/jxl/enc_aux_out.h" 37 #include "lib/jxl/enc_butteraugli_comparator.h" 38 #include "lib/jxl/enc_cache.h" 39 #include "lib/jxl/enc_debug_image.h" 40 #include "lib/jxl/enc_group.h" 41 #include "lib/jxl/enc_modular.h" 42 #include "lib/jxl/enc_params.h" 43 #include "lib/jxl/enc_transforms-inl.h" 44 #include "lib/jxl/epf.h" 45 #include "lib/jxl/frame_dimensions.h" 46 #include "lib/jxl/image.h" 47 #include "lib/jxl/image_bundle.h" 48 #include "lib/jxl/image_ops.h" 49 #include "lib/jxl/quant_weights.h" 50 51 // Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging. 52 #ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION 53 #define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0 54 #endif 55 56 HWY_BEFORE_NAMESPACE(); 57 namespace jxl { 58 namespace HWY_NAMESPACE { 59 namespace { 60 61 // These templates are not found via ADL. 
using hwy::HWY_NAMESPACE::AbsDiff;
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::And;
using hwy::HWY_NAMESPACE::Gt;
using hwy::HWY_NAMESPACE::IfThenElseZero;
using hwy::HWY_NAMESPACE::Max;
using hwy::HWY_NAMESPACE::Min;
using hwy::HWY_NAMESPACE::Rebind;
using hwy::HWY_NAMESPACE::Sqrt;
using hwy::HWY_NAMESPACE::ZeroIfNegative;

// The following functions modulate an exponent (out_val) and return the updated
// value. Their descriptor is limited to 8 lanes for 8x8 blocks.

// Hack for mask estimation. Eventually replace this code with butteraugli's
// masking.
// Maps a (positive) quantization-field value to a mask value via a shifted
// reciprocal; the 0.001 offset avoids division by zero for tiny inputs.
float ComputeMaskForAcStrategyUse(const float out_val) {
  const float kMul = 1.0f;
  const float kOffset = 0.001f;
  return kMul / (out_val + kOffset);
}

// Masking modulation: maps out_val through a sum of three tuned reciprocal
// terms (plus a constant base). All constants are hand-tuned; do not alter
// without re-running encoder quality benchmarks.
template <class D, class V>
V ComputeMask(const D d, const V out_val) {
  const auto kBase = Set(d, -0.7647f);
  const auto kMul4 = Set(d, 9.4708735624378946f);
  const auto kMul2 = Set(d, 17.35036561631863f);
  const auto kOffset2 = Set(d, 302.59587815579727f);
  const auto kMul3 = Set(d, 6.7943250517376494f);
  const auto kOffset3 = Set(d, 3.7179635626140772f);
  const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3);
  const auto kMul0 = Set(d, 0.80061762862741759f);
  const auto k1 = Set(d, 1.0f);

  // Avoid division by zero.
  const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f));
  const auto v2 = Div(k1, Add(v1, kOffset2));
  const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3));
  const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4));
  // TODO(jyrki):
  // A log or two here could make sense. In butteraugli we have effectively
  // log(log(x + C)) for this kind of use, as a single log is used in
  // saturating visual masking and here the modulation values are exponential,
  // another log would counter that.
  return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3))));
}

// mul and mul2 represent a scaling difference between jxl and butteraugli.
const float kSGmul = 226.77216153508914f;
const float kSGmul2 = 1.0f / 73.377132366608819f;
const float kLog2 = 0.693147181f;
// Includes correction factor for std::log -> log2.
const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
const float kSGVOffset = 7.7825991679894591f;

// Ratio of derivatives of jxl's cubic-root opsin transfer vs. butteraugli's
// SimpleGamma, evaluated as a rational function of v. With invert=true the
// reciprocal ratio is returned instead (selected at compile time).
template <bool invert, typename D, typename V>
V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
  // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
  // is related to the number of photons.
  //
  // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
  // This ratio allows quantization to move from jxl's opsin space to
  // butteraugli's log-gamma space.
  float kEpsilon = 1e-2;
  // Negative photons don't exist; clamp to zero before squaring.
  v = ZeroIfNegative(v);
  const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul);
  const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon);
  const auto kDenMul = Set(d, kLog2 * kSGmul);

  const auto v2 = Mul(v, v);

  const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon));
  const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset);
  return invert ? Div(num, den) : Div(den, num);
}

// Scalar convenience wrapper around the single-lane SIMD implementation above.
template <bool invert = false>
float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
  using DScalar = HWY_CAPPED(float, 1);
  auto vscalar = Load(DScalar(), &v);
  return GetLane(
      RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
}

// TODO(veluca): this function computes an approximation of the derivative of
// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
// exact derivatives. For reference, SimpleGamma was:
/*
template <typename D, typename V>
V SimpleGamma(const D d, V v) {
  // A simple HDR compatible gamma function.
  const auto mul = Set(d, kSGmul);
  const auto kRetMul = Set(d, kSGRetMul);
  const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f);
  const auto kVOffset = Set(d, kSGVOffset);

  v *= mul;

  // This should happen rarely, but may lead to a NaN, which is rather
  // undesirable. Since negative photons don't exist we solve the NaNs by
  // clamping here.
  // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
  v = ZeroIfNegative(v);
  return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
}
*/

// Adjusts out_val by the average gamma-derivative ratio over the 8x8 block at
// (x, y) in rect, using biased combinations of the X and Y opsin channels.
template <class D, class V>
V GammaModulation(const D d, const size_t x, const size_t y,
                  const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect,
                  const V out_val) {
  const float kBias = 0.16f;
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]);
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]);
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]);
  auto overall_ratio = Zero(d);
  auto bias = Set(d, kBias);
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy);
    const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy);
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
      const auto iny = Add(Load(d, row_in_y + x + dx), bias);
      const auto inx = Load(d, row_in_x + x + dx);

      // r and g approximate the two opponent channels (Y-X and Y+X).
      const auto r = Sub(iny, inx);
      const auto ratio_r =
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r);
      overall_ratio = Add(overall_ratio, ratio_r);

      const auto g = Add(iny, inx);
      const auto ratio_g =
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g);
      overall_ratio = Add(overall_ratio, ratio_g);
    }
  }
  // Average over 64 pixels x 2 accumulated channels => 0.5f / 64.
  overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 0.5f / 64));
  // ideally -1.0, but likely optimal correction adds some entropy, so slightly
  // less than that.
  const auto kGamma = Set(d, 0.1005613337192697f);
  return MulAdd(kGamma, FastLog2f(d, overall_ratio), out_val);
}

// Change precision in 8x8 blocks that have significant amounts of blue
// content (but are not close to solid blue).
// This is based on the idea that M and L cone activations saturate the
// S (blue) receptors, and the S reception becomes more important when
// both M and L levels are low. In that case M and L receptors may be
// observing S-spectra instead and viewing them with higher spatial
// accuracy, justifying spending more bits here.
template <class D, class V>
V BlueModulation(const D d, const size_t x, const size_t y,
                 const ImageF& planex, const ImageF& planey,
                 const ImageF& planeb, const Rect& rect, const V out_val) {
  auto sum = Zero(d);
  static const float kLimit = 0.027121074570634722;
  static const float kOffset = 0.084381641171960495;
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* JXL_RESTRICT row_in_x = rect.ConstRow(planex, y + dy) + x;
    const float* JXL_RESTRICT row_in_y = rect.ConstRow(planey, y + dy) + x;
    const float* JXL_RESTRICT row_in_b = rect.ConstRow(planeb, y + dy) + x;
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
      const auto p_x = Load(d, row_in_x + dx);
      const auto p_b = Load(d, row_in_b + dx);
      const auto p_y_raw = Add(Load(d, row_in_y + dx), Set(d, kOffset));
      const auto p_y_effective = Add(p_y_raw, Abs(p_x));
      // Only count pixels whose B exceeds the effective luma, capped at
      // kLimit per pixel.
      sum = Add(sum,
                IfThenElseZero(Gt(p_b, p_y_effective),
                               Min(Sub(p_b, p_y_effective), Set(d, kLimit))));
    }
  }
  static const float kMul = 0.14207000358439159;
  sum = SumOfLanes(d, sum);
  float scalar_sum = GetLane(sum);
  // If it is all blue, don't boost the quantization.
  // All blue likely means low frequency blue. Let's not make the most
  // perfect sky ever.
  if (scalar_sum >= 32 * kLimit) {
    // Ramp back down as the block approaches fully-saturated blue
    // (maximum possible sum is 64 * kLimit for a full 8x8 block).
    scalar_sum = 64 * kLimit - scalar_sum;
  }
  static const float kMaxLimit = 15.398788439047934f;
  if (scalar_sum >= kMaxLimit * kLimit) {
    scalar_sum = kMaxLimit * kLimit;
  }
  scalar_sum *= kMul;
  return Add(Set(d, scalar_sum), out_val);
}

// Change precision in 8x8 blocks that have high frequency content.
template <class D, class V>
V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb_y,
               const Rect& rect, const V out_val) {
  // Zero out the invalid differences for the rightmost value per row.
  // (Lane 7 of kMaskRight is 0 so the out-of-block horizontal delta is
  // discarded.)
  const Rebind<uint32_t, D> du;
  HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u,
                                                        ~0u, ~0u, ~0u, 0};

  // Sums of deltas of y and x components between (approximate)
  // 4-connected pixels.
  auto sum_y = Zero(d);
  static const float valmin_y = 0.0206;
  auto valminv_y = Set(d, valmin_y);
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy) + x;
    // Last row of the block reuses itself as "next" row (vertical delta 0).
    const float* JXL_RESTRICT row_in_y_next =
        dy == 7 ? row_in_y : rect.ConstRow(xyb_y, y + dy + 1) + x;

    // In SCALAR, there is no guarantee of having extra row padding.
    // Hence, we need to ensure we don't access pixels outside the row itself.
    // In SIMD modes, however, rows are padded, so it's safe to access one
    // garbage value after the row. The vector then gets masked with kMaskRight
    // to remove the influence of that value.
#if HWY_TARGET != HWY_SCALAR
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
#else
    for (size_t dx = 0; dx < 7; dx += Lanes(d)) {
#endif
      const auto mask = BitCast(d, Load(du, kMaskRight + dx));
      {
        const auto p_y = Load(d, row_in_y + dx);
        const auto pr_y = LoadU(d, row_in_y + dx + 1);
        // Saturate each delta at valmin_y before summing.
        sum_y = Add(sum_y, And(mask, Min(valminv_y, AbsDiff(p_y, pr_y))));
        const auto pd_y = Load(d, row_in_y_next + dx);
        sum_y = Add(sum_y, Min(valminv_y, AbsDiff(p_y, pd_y)));
      }
    }
#if HWY_TARGET == HWY_SCALAR
    // Handle the vertical delta of the last column separately in scalar mode.
    const auto p_y = Load(d, row_in_y + 7);
    const auto pd_y = Load(d, row_in_y_next + 7);
    sum_y = Add(sum_y, Min(valminv_y, AbsDiff(p_y, pd_y)));
#endif
  }
  static const float kMul_y = -0.38;
  sum_y = SumOfLanes(d, sum_y);

  float scalar_sum_y = GetLane(sum_y);
  scalar_sum_y *= kMul_y;

  // higher value -> more bpp
  float kOffset = 0.42;
  scalar_sum_y += kOffset;

  return Add(Set(d, scalar_sum_y), out_val);
}

// Applies all per-8x8-block modulations (masking, HF, gamma, blue) to the
// quantization field in rect_out, then converts the accumulated exponent to a
// multiplicative value. For high butteraugli targets the field is dampened
// towards a flat base_level (linear ramp between kDampenRampStart/End).
void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x,
                         const ImageF& xyb_y, const ImageF& xyb_b,
                         const Rect& rect_in, const float scale,
                         const Rect& rect_out, ImageF* out) {
  float base_level = 0.48f * scale;
  float kDampenRampStart = 2.0f;
  float kDampenRampEnd = 14.0f;
  float dampen = 1.0f;
  if (butteraugli_target >= kDampenRampStart) {
    dampen = 1.0f - ((butteraugli_target - kDampenRampStart) /
                     (kDampenRampEnd - kDampenRampStart));
    if (dampen < 0) {
      dampen = 0;
    }
  }
  const float mul = scale * dampen;
  const float add = (1.0f - dampen) * base_level;
  for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) {
    const size_t y = iy * 8;
    float* const JXL_RESTRICT row_out = out->Row(iy);
    const HWY_CAPPED(float, kBlockDim) df;
    for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) {
      size_t x = ix * 8;
      auto out_val = Set(df, row_out[ix]);
      out_val = ComputeMask(df, out_val);
      out_val = HfModulation(df, x, y, xyb_y, rect_in, out_val);
      out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val);
      out_val = BlueModulation(df, x, y, xyb_x, xyb_y, xyb_b, rect_in, out_val);
      // We want multiplicative quantization field, so everything
      // until this point has been modulating the exponent.
      row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
    }
  }
}

// Compressive (sqrt-based) response: 0.25 * sqrt(v * sqrt(kMul * 1e8) +
// kLogOffset). Used to soften large local differences before erosion.
template <typename D, typename V>
V MaskingSqrt(const D d, V v) {
  static const float kLogOffset = 27.505837037000106f;
  static const float kMul = 211.66567973503678f;
  const auto mul_v = Set(d, kMul * 1e8);
  const auto offset_v = Set(d, kLogOffset);
  return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v)));
}

// Scalar convenience wrapper around the single-lane SIMD implementation.
float MaskingSqrt(const float v) {
  using DScalar = HWY_CAPPED(float, 1);
  auto vscalar = Load(DScalar(), &v);
  return GetLane(MaskingSqrt(DScalar(), vscalar));
}

// Inserts v into the four smallest values kept so far, maintaining the sorted
// invariant min0 <= min1 <= min2 <= min3.
void StoreMin4(const float v, float& min0, float& min1, float& min2,
               float& min3) {
  if (v < min3) {
    if (v < min0) {
      min3 = min2;
      min2 = min1;
      min1 = min0;
      min0 = v;
    } else if (v < min1) {
      min3 = min2;
      min2 = min1;
      min1 = v;
    } else if (v < min2) {
      min3 = min2;
      min2 = v;
    } else {
      min3 = v;
    }
  }
}

// Look for smooth areas near the area of degradation.
// If the areas are generally smooth, don't do masking.
// Output is downsampled 2x.
// For each pixel, takes a weighted combination of the four smallest values in
// its 3x3 neighbourhood (an approximate "fuzzy" erosion) and accumulates the
// result into a 2x-downsampled output: each output pixel sums the four
// contributing input pixels (assigned on the even row/column, added on the
// others). Weights shift towards pure min as butteraugli_target drops below 2.
Status FuzzyErosion(const float butteraugli_target, const Rect& from_rect,
                    const ImageF& from, const Rect& to_rect, ImageF* to) {
  const size_t xsize = from.xsize();
  const size_t ysize = from.ysize();
  constexpr int kStep = 1;
  static_assert(kStep == 1, "Step must be 1");
  JXL_ENSURE(to_rect.xsize() * 2 == from_rect.xsize());
  JXL_ENSURE(to_rect.ysize() * 2 == from_rect.ysize());
  static const float kMulBase0 = 0.125;
  static const float kMulBase1 = 0.10;
  static const float kMulBase2 = 0.09;
  static const float kMulBase3 = 0.06;
  static const float kMulAdd0 = 0.0;
  static const float kMulAdd1 = -0.10;
  static const float kMulAdd2 = -0.09;
  static const float kMulAdd3 = -0.06;

  float mul = 0.0;
  if (butteraugli_target < 2.0f) {
    mul = (2.0f - butteraugli_target) * (1.0f / 2.0f);
  }
  float kMul0 = kMulBase0 + mul * kMulAdd0;
  float kMul1 = kMulBase1 + mul * kMulAdd1;
  float kMul2 = kMulBase2 + mul * kMulAdd2;
  float kMul3 = kMulBase3 + mul * kMulAdd3;
  // Normalize so the four weights always sum to kTotal.
  static const float kTotal = 0.29959705784054957;
  float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3);
  kMul0 *= norm;
  kMul1 *= norm;
  kMul2 *= norm;
  kMul3 *= norm;

  for (size_t fy = 0; fy < from_rect.ysize(); ++fy) {
    size_t y = fy + from_rect.y0();
    // Clamp neighbour rows at the image borders.
    size_t ym1 = y >= kStep ? y - kStep : y;
    size_t yp1 = y + kStep < ysize ? y + kStep : y;
    const float* rowt = from.Row(ym1);
    const float* row = from.Row(y);
    const float* rowb = from.Row(yp1);
    float* row_out = to_rect.Row(to, fy / 2);
    for (size_t fx = 0; fx < from_rect.xsize(); ++fx) {
      size_t x = fx + from_rect.x0();
      size_t xm1 = x >= kStep ? x - kStep : x;
      size_t xp1 = x + kStep < xsize ? x + kStep : x;
      float min0 = row[x];
      float min1 = row[xm1];
      float min2 = row[xp1];
      float min3 = rowt[xm1];
      // Sort the first four values.
      if (min0 > min1) std::swap(min0, min1);
      if (min0 > min2) std::swap(min0, min2);
      if (min0 > min3) std::swap(min0, min3);
      if (min1 > min2) std::swap(min1, min2);
      if (min1 > min3) std::swap(min1, min3);
      if (min2 > min3) std::swap(min2, min3);
      // The remaining five values of a 3x3 neighbourhood.
      StoreMin4(rowt[x], min0, min1, min2, min3);
      StoreMin4(rowt[xp1], min0, min1, min2, min3);
      StoreMin4(rowb[xm1], min0, min1, min2, min3);
      StoreMin4(rowb[x], min0, min1, min2, min3);
      StoreMin4(rowb[xp1], min0, min1, min2, min3);

      float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3;
      // First of the four 2x2 contributors initializes the output pixel; the
      // other three accumulate into it.
      if (fx % 2 == 0 && fy % 2 == 0) {
        row_out[fx / 2] = v;
      } else {
        row_out[fx / 2] += v;
      }
    }
  }
  return true;
}

// Holds per-thread scratch buffers and the output quantization map while the
// adaptive quantization map is computed tile by tile.
struct AdaptiveQuantizationImpl {
  // Allocates diff_buffer (one row per thread) and grows the per-thread
  // pre_erosion image list up to num_threads entries.
  Status PrepareBuffers(JxlMemoryManager* memory_manager, size_t num_threads) {
    JXL_ASSIGN_OR_RETURN(
        diff_buffer,
        ImageF::Create(memory_manager, kEncTileDim + 8, num_threads));
    for (size_t i = pre_erosion.size(); i < num_threads; i++) {
      JXL_ASSIGN_OR_RETURN(
          ImageF tmp,
          ImageF::Create(memory_manager, kEncTileDimInBlocks * 2 + 2,
                         kEncTileDimInBlocks * 2 + 2));
      pre_erosion.emplace_back(std::move(tmp));
    }
    return true;
  }

  // Computes one tile of the adaptive quantization map (rect_out, in block
  // units) from the opsin image xyb (rect_in, in pixels): fills the 1x1
  // masking image, builds a 4x-subsampled difference image, erodes it into
  // aq_map, derives the AC-strategy mask, and applies per-block modulations.
  Status ComputeTile(float butteraugli_target, float scale, const Image3F& xyb,
                     const Rect& rect_in, const Rect& rect_out,
                     const int thread, ImageF* mask, ImageF* mask1x1) {
    JXL_ENSURE(rect_in.x0() % kBlockDim == 0);
    JXL_ENSURE(rect_in.y0() % kBlockDim == 0);
    const size_t xsize = xyb.xsize();
    const size_t ysize = xyb.ysize();

    // The XYB gamma is 3.0 to be able to decode faster with two muls.
    // Butteraugli's gamma is matching the gamma of human eye, around 2.6.
    // We approximate the gamma difference by adding one cubic root into
    // the adaptive quantization. This gives us a total gamma of 2.6666
    // for quantization uses.
    const float match_gamma_offset = 0.019;

    const HWY_FULL(float) df;

    size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8;
    size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8;

    size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8;
    size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8;

    // Extend the 1x1 region by 2 pixels at tile edges that are interior to
    // the image, so the later blur has valid neighbours.
    if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2;
    if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) {
      x_end_1x1 += 2;
    }
    if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2;
    if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) {
      y_end_1x1 += 2;
    }

    // Computes image (padded to multiple of 8x8) of local pixel differences.
    // Subsample both directions by 4.
    // 1x1 Laplacian of intensity.
    for (size_t y = y_start_1x1; y < y_end_1x1; ++y) {
      const size_t y2 = y + 1 < ysize ? y + 1 : y;
      const size_t y1 = y > 0 ? y - 1 : y;
      const float* row_in = xyb.ConstPlaneRow(1, y);
      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
      float* mask1x1_out = mask1x1->Row(y);
      auto scalar_pixel1x1 = [&](size_t x) {
        const size_t x2 = x + 1 < xsize ? x + 1 : x;
        const size_t x1 = x > 0 ? x - 1 : x;
        // 4-neighbour average; the difference to the centre is a Laplacian.
        const float base =
            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
            row_in[x] + match_gamma_offset);
        float diff = fabs(gammac * (row_in[x] - base));
        static const double kScaler = 1.0;
        diff *= kScaler;
        diff = log1p(diff);
        static const float kMul = 1.0;
        static const float kOffset = 0.01;
        mask1x1_out[x] = kMul / (diff + kOffset);
      };
      for (size_t x = x_start_1x1; x < x_end_1x1; ++x) {
        scalar_pixel1x1(x);
      }
    }

    size_t y_start = rect_in.y0() + rect_out.y0() * 8;
    size_t y_end = y_start + rect_out.ysize() * 8;

    size_t x_start = rect_in.x0() + rect_out.x0() * 8;
    size_t x_end = x_start + rect_out.xsize() * 8;

    // Extend by 4 pixels (one subsampled unit) on interior edges for the
    // pre-erosion image.
    if (x_start != 0) x_start -= 4;
    if (x_end != xsize) x_end += 4;
    if (y_start != 0) y_start -= 4;
    if (y_end != ysize) y_end += 4;
    JXL_RETURN_IF_ERROR(pre_erosion[thread].ShrinkTo((x_end - x_start) / 4,
                                                     (y_end - y_start) / 4));

    static const float limit = 0.2f;
    for (size_t y = y_start; y < y_end; ++y) {
      size_t y2 = y + 1 < ysize ? y + 1 : y;
      size_t y1 = y > 0 ? y - 1 : y;

      const float* row_in = xyb.ConstPlaneRow(1, y);
      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
      float* JXL_RESTRICT row_out = diff_buffer.Row(thread);

      auto scalar_pixel = [&](size_t x) {
        const size_t x2 = x + 1 < xsize ? x + 1 : x;
        const size_t x1 = x > 0 ? x - 1 : x;
        const float base =
            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
            row_in[x] + match_gamma_offset);
        float diff = gammac * (row_in[x] - base);
        diff *= diff;
        if (diff >= limit) {
          diff = limit;
        }
        diff = MaskingSqrt(diff);
        // Rows are accumulated in groups of 4 (vertical subsampling).
        if ((y % 4) != 0) {
          row_out[x - x_start] += diff;
        } else {
          row_out[x - x_start] = diff;
        }
      };

      size_t x = x_start;
      // First pixel of the row.
      if (x_start == 0) {
        scalar_pixel(x_start);
        ++x;
      }
      // SIMD
      const auto match_gamma_offset_v = Set(df, match_gamma_offset);
      const auto quarter = Set(df, 0.25f);
      for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) {
        const auto in = LoadU(df, row_in + x);
        const auto in_r = LoadU(df, row_in + x + 1);
        const auto in_l = LoadU(df, row_in + x - 1);
        const auto in_t = LoadU(df, row_in2 + x);
        const auto in_b = LoadU(df, row_in1 + x);
        auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b)));
        auto gammacv =
            RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
                df, Add(in, match_gamma_offset_v));
        auto diff = Mul(gammacv, Sub(in, base));
        diff = Mul(diff, diff);
        diff = Min(diff, Set(df, limit));
        diff = MaskingSqrt(df, diff);
        if ((y & 3) != 0) {
          diff = Add(diff, LoadU(df, row_out + x - x_start));
        }
        StoreU(diff, df, row_out + x - x_start);
      }
      // Scalar
      for (; x < x_end; ++x) {
        scalar_pixel(x);
      }
      // Every 4th row: average 4x1 groups horizontally into the 4x-subsampled
      // pre-erosion image.
      if (y % 4 == 3) {
        float* row_d_out = pre_erosion[thread].Row((y - y_start) / 4);
        for (size_t x = 0; x < (x_end - x_start) / 4; x++) {
          row_d_out[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
                          row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
                         0.25f;
        }
      }
    }
    JXL_ENSURE(x_start % (kBlockDim / 2) == 0);
    JXL_ENSURE(y_start % (kBlockDim / 2) == 0);
    // Skip the one-unit border in pre_erosion when the tile was extended.
    Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1,
                   rect_out.xsize() * 2, rect_out.ysize() * 2);
    JXL_RETURN_IF_ERROR(FuzzyErosion(butteraugli_target, from_rect,
                                     pre_erosion[thread], rect_out, &aq_map));
    for (size_t y = 0; y < rect_out.ysize(); ++y) {
      const float* aq_map_row = rect_out.ConstRow(aq_map, y);
      float* mask_row = rect_out.Row(mask, y);
      for (size_t x = 0; x < rect_out.xsize(); ++x) {
        mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]);
      }
    }
    PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1),
                        xyb.Plane(2), rect_in, scale, rect_out, &aq_map);
    return true;
  }
  std::vector<ImageF> pre_erosion;  // per-thread 4x-subsampled diff images
  ImageF aq_map;                    // output quantization field (block units)
  ImageF diff_buffer;               // per-thread row accumulator
};

// Replaces mask1x1 with a blurred version of itself using a normalized
// symmetric 5x5 kernel.
Status Blur1x1Masking(JxlMemoryManager* memory_manager, ThreadPool* pool,
                      ImageF* mask1x1, const Rect& rect) {
  // Blur the mask1x1 to obtain the masking image.
  // Before blurring it contains an image of absolute value of the
  // Laplacian of the intensity channel.
  static const float kFilterMask1x1[5] = {
      static_cast<float>(0.25647067633737227),
      static_cast<float>(0.2050056912354399075),
      static_cast<float>(0.154082048668497307),
      static_cast<float>(0.08149576591362004441),
      static_cast<float>(0.0512750104812308467),
  };
  // Total kernel weight: centre (1.0) plus 4x each off-centre tap class.
  double sum =
      1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] + kFilterMask1x1[2] +
                 kFilterMask1x1[4] + 2 * kFilterMask1x1[3]);
  if (sum < 1e-5) {
    sum = 1e-5;
  }
  const float normalize = static_cast<float>(1.0 / sum);
  const float normalize_mul = normalize;
  WeightsSymmetric5 weights =
      WeightsSymmetric5{{HWY_REP4(normalize)},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[0])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[2])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[1])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[4])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[3])}};
  JXL_ASSIGN_OR_RETURN(
      ImageF temp, ImageF::Create(memory_manager, rect.xsize(), rect.ysize()));
  JXL_RETURN_IF_ERROR(Symmetric5(*mask1x1, rect, weights, pool, &temp));
  *mask1x1 = std::move(temp);
  return true;
}

// Computes the adaptive quantization map for rect of xyb (one value per 8x8
// block), processing kEncTileDimInBlocks-sized tiles in parallel. Also fills
// the AC-strategy mask (block resolution) and the blurred 1x1 masking image.
StatusOr<ImageF> AdaptiveQuantizationMap(const float butteraugli_target,
                                         const Image3F& xyb, const Rect& rect,
                                         float scale, ThreadPool* pool,
                                         ImageF* mask, ImageF* mask1x1) {
  JXL_ENSURE(rect.xsize() % kBlockDim == 0);
  JXL_ENSURE(rect.ysize() % kBlockDim == 0);
  AdaptiveQuantizationImpl impl;
  const size_t xsize_blocks = rect.xsize() / kBlockDim;
  const size_t ysize_blocks = rect.ysize() / kBlockDim;
  JxlMemoryManager* memory_manager = xyb.memory_manager();
  JXL_ASSIGN_OR_RETURN(
      impl.aq_map, ImageF::Create(memory_manager, xsize_blocks, ysize_blocks));
  JXL_ASSIGN_OR_RETURN(
      *mask, ImageF::Create(memory_manager, xsize_blocks, ysize_blocks));
  JXL_ASSIGN_OR_RETURN(
      *mask1x1, ImageF::Create(memory_manager, xyb.xsize(), xyb.ysize()));
  const auto prepare = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(impl.PrepareBuffers(memory_manager, num_threads));
    return true;
  };
  const auto process_tile = [&](const uint32_t tid,
                                const size_t thread) -> Status {
    // Map the linear tile id to a (tx, ty) tile position, clamped to the
    // image's block dimensions.
    size_t n_enc_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks);
    size_t tx = tid % n_enc_tiles;
    size_t ty = tid / n_enc_tiles;
    size_t by0 = ty * kEncTileDimInBlocks;
    size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks);
    size_t bx0 = tx * kEncTileDimInBlocks;
    size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks);
    Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0);
    JXL_RETURN_IF_ERROR(impl.ComputeTile(butteraugli_target, scale, xyb, rect,
                                         rect_out, thread, mask, mask1x1));
    return true;
  };
  size_t num_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks) *
                     DivCeil(ysize_blocks, kEncTileDimInBlocks);
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_tiles, prepare, process_tile,
                                "AQ DiffPrecompute"));

  JXL_RETURN_IF_ERROR(Blur1x1Masking(memory_manager, pool, mask1x1, rect));
  return std::move(impl).aq_map;
}

}  // namespace

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {
HWY_EXPORT(AdaptiveQuantizationMap);

namespace {

// If true, prints the quantization maps at each iteration.
constexpr bool FLAGS_dump_quant_state = false;

// Debug-only: renders image as a heatmap and dumps it to a file named after
// label and the current butteraugli iteration. No-op unless
// JXL_DEBUG_ADAPTIVE_QUANTIZATION is enabled.
Status DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out,
                   const std::string& label, const ImageF& image,
                   float good_threshold, float bad_threshold) {
  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    JXL_ASSIGN_OR_RETURN(
        Image3F heatmap,
        CreateHeatMapImage(image, good_threshold, bad_threshold));
    char filename[200];
    snprintf(filename, sizeof(filename), "%s%05d", label.c_str(),
             aux_out->num_butteraugli_iters);
    JXL_RETURN_IF_ERROR(DumpImage(cparams, filename, heatmap));
  }
  return true;
}

// Debug-only: dumps heatmaps of the inverse quantization field, the tile
// heatmap, and the butteraugli diffmap.
Status DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out,
                    float butteraugli_target, const ImageF& quant_field,
                    const ImageF& tile_heatmap, const ImageF& bt_diffmap) {
  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    JxlMemoryManager* memory_manager = quant_field.memory_manager();
    if (!WantDebugOutput(cparams)) return true;
    JXL_ASSIGN_OR_RETURN(ImageF inv_qmap,
                         ImageF::Create(memory_manager, quant_field.xsize(),
                                        quant_field.ysize()));
    for (size_t y = 0; y < quant_field.ysize(); ++y) {
      const float* JXL_RESTRICT row_q = quant_field.ConstRow(y);
      float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y);
      for (size_t x = 0; x < quant_field.xsize(); ++x) {
        row_inv_q[x] = 1.0f / row_q[x];  // never zero
      }
    }
    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap,
                                    4.0f * butteraugli_target,
                                    6.0f * butteraugli_target));
    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "tile_heatmap",
                                    tile_heatmap, butteraugli_target,
                                    1.5f * butteraugli_target));
    // matches heat maps produced by the command line tool.
    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap,
                                    ButteraugliFuzzyInverse(1.5),
                                    ButteraugliFuzzyInverse(0.5)));
  }
  return true;
}

// Aggregates the distortion map into per-tile values using a 16th-power norm
// over each AC-strategy-covered area (with an optional margin weighted down at
// the border/corners), and broadcasts the result to all blocks of the area.
StatusOr<ImageF> TileDistMap(const ImageF& distmap, int tile_size, int margin,
                             const AcStrategyImage& ac_strategy) {
  const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size;
  const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size;
  JxlMemoryManager* memory_manager = distmap.memory_manager();
  JXL_ASSIGN_OR_RETURN(ImageF tile_distmap,
                       ImageF::Create(memory_manager, tile_xsize, tile_ysize));
  size_t distmap_stride = tile_distmap.PixelsPerRow();
  for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) {
    AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y);
    float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y);
    for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) {
      AcStrategy acs = ac_strategy_row[tile_x];
      // Only process each multi-block strategy once, at its first block.
      if (!acs.IsFirstBlock()) continue;
      int this_tile_xsize = acs.covered_blocks_x() * tile_size;
      int this_tile_ysize = acs.covered_blocks_y() * tile_size;
      int y_begin = std::max<int>(0, tile_size * tile_y - margin);
      int y_end = std::min<int>(distmap.ysize(),
                                tile_size * tile_y + this_tile_ysize + margin);
      int x_begin = std::max<int>(0, tile_size * tile_x - margin);
      int x_end = std::min<int>(distmap.xsize(),
                                tile_size * tile_x + this_tile_xsize + margin);
      float dist_norm = 0.0;
      double pixels = 0;
      for (int y = y_begin; y < y_end; ++y) {
        float ymul = 1.0;
        constexpr float kBorderMul = 0.98f;
        constexpr float kCornerMul = 0.7f;
        if (margin != 0 && (y == y_begin || y == y_end - 1)) {
          ymul = kBorderMul;
        }
        const float* const JXL_RESTRICT row = distmap.Row(y);
        for (int x = x_begin; x < x_end; ++x) {
          float xmul = ymul;
          if (margin != 0 && (x == x_begin || x == x_end - 1)) {
            if (xmul == 1.0) {
              xmul = kBorderMul;
            } else {
              xmul = kCornerMul;
            }
          }
          float v = row[x];
          // v^16 via four successive squarings.
          v *= v;
          v *= v;
          v *= v;
          v *= v;
          dist_norm += xmul * v;
          pixels += xmul;
        }
      }
      if (pixels == 0) pixels = 1;
      // 16th norm is less than the max norm, we reduce the difference
      // with this normalization factor.
      constexpr float kTileNorm = 1.2f;
      const float tile_dist =
          kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f);
      dist_row[tile_x] = tile_dist;
      // Broadcast to every block covered by this AC strategy.
      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
          dist_row[tile_x + distmap_stride * iy + ix] = tile_dist;
        }
      }
    }
  }
  return tile_distmap;
}

const float kDcQuantPow = 0.83f;
const float kDcQuant = 1.095924047623553f;
const float kAcQuant = 0.725f;

// Computes the decoded image for a given set of compression parameters.
// Encodes `opsin` with the encoder state's current quantization settings and
// immediately decodes it again, returning the decoded image so the adaptive
// quantization loop can compare it against the original.
StatusOr<ImageBundle> RoundtripImage(const FrameHeader& frame_header,
                                     const Image3F& opsin,
                                     PassesEncoderState* enc_state,
                                     const JxlCmsInterface& cms,
                                     ThreadPool* pool) {
  JxlMemoryManager* memory_manager = enc_state->memory_manager();
  std::unique_ptr<PassesDecoderState> dec_state =
      jxl::make_unique<PassesDecoderState>(memory_manager);
  JXL_RETURN_IF_ERROR(dec_state->output_encoding_info.SetFromMetadata(
      *enc_state->shared.metadata));
  // The decoder state borrows the encoder's shared state directly.
  dec_state->shared = &enc_state->shared;
  JXL_ENSURE(opsin.ysize() % kBlockDim == 0);

  const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
  const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
  const size_t num_groups = xsize_groups * ysize_groups;

  // Remember how many special frames exist so any created during this
  // roundtrip can be dropped again at the end.
  size_t num_special_frames = enc_state->special_frames.size();
  size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
  JXL_ASSIGN_OR_RETURN(ModularFrameEncoder modular_frame_encoder,
                       ModularFrameEncoder::Create(memory_manager, frame_header,
                                                   enc_state->cparams, false));
  JXL_RETURN_IF_ERROR(InitializePassesEncoder(frame_header, opsin, Rect(opsin),
                                              cms, pool, enc_state,
                                              &modular_frame_encoder, nullptr));
  JXL_RETURN_IF_ERROR(dec_state->Init(frame_header));
  JXL_RETURN_IF_ERROR(dec_state->InitForAC(num_passes, pool));

  ImageBundle decoded(memory_manager, &enc_state->shared.metadata->m);
  decoded.origin = frame_header.frame_origin;
  JXL_ASSIGN_OR_RETURN(
      Image3F tmp,
      Image3F::Create(memory_manager, opsin.xsize(), opsin.ysize()));
  JXL_RETURN_IF_ERROR(decoded.SetFromImage(
      std::move(tmp), dec_state->output_encoding_info.color_encoding));

  // Lightweight pipeline configuration: no slow pipeline, no coalescing,
  // spot colors and noise rendering disabled for this roundtrip.
  PassesDecoderState::PipelineOptions options;
  options.use_slow_render_pipeline = false;
  options.coalescing = false;
  options.render_spotcolors = false;
  options.render_noise = false;

  // Same as frame_header.nonserialized_metadata->m
  const ImageMetadata& metadata = *decoded.metadata();

  JXL_RETURN_IF_ERROR(dec_state->PreparePipeline(
      frame_header, &enc_state->shared.metadata->m, &decoded, options));

  AlignedArray<GroupDecCache> group_dec_caches;
  // Per-worker setup: render pipeline thread storage + one decode cache per
  // thread.
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(
        dec_state->render_pipeline->PrepareForThreads(num_threads,
                                                      /*use_group_ids=*/false));
    JXL_ASSIGN_OR_RETURN(group_dec_caches, AlignedArray<GroupDecCache>::Create(
                                               memory_manager, num_threads));
    return true;
  };
  // Decodes one group; runs on the thread pool.
  const auto process_group = [&](const uint32_t group_index,
                                 const size_t thread) -> Status {
    if (frame_header.loop_filter.epf_iters > 0) {
      JXL_RETURN_IF_ERROR(
          ComputeSigma(frame_header.loop_filter,
                       dec_state->shared->frame_dim.BlockGroupRect(group_index),
                       dec_state.get()));
    }
    RenderPipelineInput input =
        dec_state->render_pipeline->GetInputBuffers(group_index, thread);
    JXL_RETURN_IF_ERROR(DecodeGroupForRoundtrip(
        frame_header, enc_state->coeffs, group_index, dec_state.get(),
        &group_dec_caches[thread], thread, input, nullptr, nullptr));
    // Extra channels are not decoded here; zero-fill their buffers.
    for (size_t c = 0; c < metadata.num_extra_channels; c++) {
      std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c);
      FillPlane(0.0f, ri.first, ri.second);
    }
    JXL_RETURN_IF_ERROR(input.Done());
    return true;
  };
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, allocate_storage,
                                process_group, "AQ loop"));

  // Ensure we don't create any new special frames.
  enc_state->special_frames.resize(num_special_frames);

  return decoded;
}

constexpr int kMaxButteraugliIters = 4;

// Iteratively refines the AC quantization field so that the decoded image's
// butteraugli distance approaches the requested target. Each iteration does a
// full encode/decode roundtrip and scales the field by the per-tile distance
// ratio.
Status FindBestQuantization(const FrameHeader& frame_header,
                            const Image3F& linear, const Image3F& opsin,
                            ImageF& quant_field, PassesEncoderState* enc_state,
                            const JxlCmsInterface& cms, ThreadPool* pool,
                            AuxOut* aux_out) {
  const CompressParams& cparams = enc_state->cparams;
  if (cparams.resampling > 1 &&
      cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) {
    // For downsampled opsin image, the butteraugli based adaptive quantization
    // loop would only make the size bigger without improving the distance much,
    // so in this case we enable it only for very high butteraugli targets.
    return true;
  }
  JxlMemoryManager* memory_manager = enc_state->memory_manager();
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  const float butteraugli_target = cparams.butteraugli_distance;
  const float original_butteraugli = cparams.original_butteraugli_distance;
  ButteraugliParams params;
  const auto& tf = frame_header.nonserialized_metadata->m.color_encoding.Tf();
  // HDR transfer functions (PQ/HLG) use the declared intensity target;
  // otherwise assume standard 80 nits.
  params.intensity_target =
      tf.IsPQ() || tf.IsHLG()
          ?
frame_header.nonserialized_metadata->m.IntensityTarget()
          : 80.f;
  JxlButteraugliComparator comparator(params, cms);
  JXL_RETURN_IF_ERROR(comparator.SetLinearReferenceImage(linear));
  // Orientation of the comparator's score axis; scores/diffmaps are negated
  // below if higher means better.
  bool lower_is_better =
      (comparator.GoodQualityScore() < comparator.BadQualityScore());
  const float initial_quant_dc = InitialQuantDC(butteraugli_target);
  JXL_RETURN_IF_ERROR(AdjustQuantField(enc_state->shared.ac_strategy,
                                       Rect(quant_field), original_butteraugli,
                                       &quant_field));
  ImageF tile_distmap;
  // Keep a copy of the starting field; iteration i == kOriginalComparisonRound
  // clamps against it to limit drift.
  JXL_ASSIGN_OR_RETURN(
      ImageF initial_quant_field,
      ImageF::Create(memory_manager, quant_field.xsize(), quant_field.ysize()));
  JXL_RETURN_IF_ERROR(CopyImageTo(quant_field, &initial_quant_field));

  // Derive hard lower/upper bounds for the field so that the dynamic range of
  // the integer quant map stays representable (ratio < 253, asserted below).
  float initial_qf_min;
  float initial_qf_max;
  ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max);
  float initial_qf_ratio = initial_qf_max / initial_qf_min;
  float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio);
  float asymmetry = 2;
  if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low;
  float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low);
  float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry);

  JXL_ENSURE(qf_higher / qf_lower < 253);

  constexpr int kOriginalComparisonRound = 1;
  int iters = kMaxButteraugliIters;
  if (cparams.speed_tier != SpeedTier::kTortoise) {
    iters = 2;
  }
  // iters adjustment rounds plus one final measurement-only round.
  for (int i = 0; i < iters + 1; ++i) {
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      printf("\nQuantization field:\n");
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          printf(" %.5f", quant_field.Row(y)[x]);
        }
        printf("\n");
      }
    }
    JXL_RETURN_IF_ERROR(quantizer.SetQuantField(initial_quant_dc, quant_field,
                                                &raw_quant_field));
    JXL_ASSIGN_OR_RETURN(
        ImageBundle dec_linear,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    float score;
    ImageF diffmap;
    JXL_RETURN_IF_ERROR(comparator.CompareWith(dec_linear, &diffmap, &score));
    if (!lower_is_better) {
      // Normalize so that lower always means better from here on.
      score = -score;
      ScaleImage(-1.0f, &diffmap);
    }
    JXL_ASSIGN_OR_RETURN(tile_distmap,
                         TileDistMap(diffmap, 8 * cparams.resampling, 0,
                                     enc_state->shared.ac_strategy));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) {
      JXL_RETURN_IF_ERROR(DumpImage(cparams, ("dec" + ToString(i)).c_str(),
                                    *dec_linear.color()));
      JXL_RETURN_IF_ERROR(DumpHeatmaps(cparams, aux_out, butteraugli_target,
                                       quant_field, tile_distmap, diffmap));
    }
    if (aux_out != nullptr) ++aux_out->num_butteraugli_iters;
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      float minval;
      float maxval;
      ImageMinMax(quant_field, &minval, &maxval);
      printf("\nButteraugli iter: %d/%d\n", i, kMaxButteraugliIters);
      printf("Butteraugli distance: %f (target = %f)\n", score,
             original_butteraugli);
      printf("quant range: %f ... %f DC quant: %f\n", minval, maxval,
             initial_quant_dc);
      if (FLAGS_dump_quant_state) {
        quantizer.DumpQuantizationMap(raw_quant_field);
      }
    }

    // The last round only measures; no further adjustment of the field.
    if (i == iters) break;

    // Per-iteration exponents for the "soften" path below; kPowMod scales
    // with the distance target.
    double kPow[8] = {
        0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    };
    double kPowMod[8] = {
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    };
    if (i == kOriginalComparisonRound) {
      // Don't allow optimization to make the quant field a lot worse than
      // what the initial guess was. This allows the AC field to have enough
      // precision to reduce the oscillations due to the dc reconstruction.
      double kInitMul = 0.6;
      const double kOneMinusInitMul = 1.0 - kInitMul;
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x];
          if (row_q[x] < clamp) {
            row_q[x] = clamp;
            if (row_q[x] > qf_higher) row_q[x] = qf_higher;
            if (row_q[x] < qf_lower) row_q[x] = qf_lower;
          }
        }
      }
    }

    double cur_pow = 0.0;
    if (i < 7) {
      cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i];
      if (cur_pow < 0) {
        cur_pow = 0;
      }
    }
    if (cur_pow == 0.0) {
      // Only tighten: raise qf where a tile exceeds the target distance.
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          const float diff = row_dist[x] / original_butteraugli;
          if (diff > 1.0f) {
            float old = row_q[x];
            row_q[x] *= diff;
            // If rounding to the integer quant map swallowed the increase,
            // force at least one quantization step.
            int qf_old =
                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
            int qf_new = static_cast<int>(
                std::lround(row_q[x] * quantizer.InvGlobalScale()));
            if (qf_old == qf_new) {
              row_q[x] = old + quantizer.Scale();
            }
          }
          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
        }
      }
    } else {
      // Tighten over-target tiles as above; additionally soften under-target
      // tiles by pow(diff, cur_pow).
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          const float diff = row_dist[x] / original_butteraugli;
          if (diff <= 1.0f) {
            row_q[x] *= std::pow(diff, cur_pow);
          } else {
            float old = row_q[x];
            row_q[x] *= diff;
            // Same minimum-step guarantee as in the cur_pow == 0 branch.
            int qf_old =
                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
            int qf_new = static_cast<int>(
                std::lround(row_q[x] * quantizer.InvGlobalScale()));
            if (qf_old == qf_new) {
              row_q[x] = old + quantizer.Scale();
            }
          }
          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
        }
      }
    }
  }
  JXL_RETURN_IF_ERROR(
      quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field));
  return true;
}

// Adjusts the quantization field so that per-varblock reconstruction error
// stays within cparams.max_error, iterating encode/decode roundtrips.
Status FindBestQuantizationMaxError(const FrameHeader& frame_header,
                                    const Image3F& opsin, ImageF& quant_field,
                                    PassesEncoderState* enc_state,
                                    const JxlCmsInterface& cms,
                                    ThreadPool* pool, AuxOut* aux_out) {
  // TODO(szabadka): Make this work for non-opsin color spaces.
  const CompressParams& cparams = enc_state->cparams;
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  // TODO(veluca): better choice of this value.
  const float initial_quant_dc =
      16 * std::sqrt(0.1f / cparams.butteraugli_distance);
  JXL_RETURN_IF_ERROR(
      AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
                       cparams.original_butteraugli_distance, &quant_field));

  // Reciprocals of the per-channel error budgets, so errors below are
  // normalized to "1.0 == maximum allowed".
  const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0],
                                1.0f / enc_state->cparams.max_error[1],
                                1.0f / enc_state->cparams.max_error[2]};

  for (int i = 0; i < kMaxButteraugliIters + 1; ++i) {
    JXL_RETURN_IF_ERROR(quantizer.SetQuantField(initial_quant_dc, quant_field,
                                                &raw_quant_field));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(
          DumpXybImage(cparams, ("ops" + ToString(i)).c_str(), opsin));
    }
    JXL_ASSIGN_OR_RETURN(
        ImageBundle decoded,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(DumpXybImage(cparams, ("dec" + ToString(i)).c_str(),
                                       *decoded.color()));
    }
    // For each varblock, find the worst normalized per-pixel error over the
    // three channels, then rescale its quant field entries accordingly.
    for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) {
      AcStrategyRow ac_strategy_row =
          enc_state->shared.ac_strategy.ConstRow(by);
      for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) {
        AcStrategy acs = ac_strategy_row[bx];
        if (!acs.IsFirstBlock()) continue;
        float max_error = 0;
        for (size_t c = 0; c < 3; c++) {
          for (size_t y = by * kBlockDim;
               y < (by + acs.covered_blocks_y()) * kBlockDim; y++) {
            // Varblocks may extend past the image; skip out-of-bounds pixels.
            if (y >= decoded.ysize()) continue;
            const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y);
            const float* JXL_RESTRICT dec_row =
                decoded.color()->ConstPlaneRow(c, y);
            for (size_t x = bx * kBlockDim;
                 x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) {
              if (x >= decoded.xsize()) continue;
              max_error = std::max(
                  std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error);
            }
          }
        }
        // Target an error between max_error/2 and max_error.
        // If the error in the varblock is above the target, increase the qf to
        // compensate. If the error is below the target, decrease the qf.
        // However, to avoid an excessive increase of the qf, only do so if the
        // error is less than half the maximum allowed error.
        const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f
                             : (max_error > 1.0f) ? max_error
                                                  : 1.0f;
        for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) {
          float* JXL_RESTRICT quant_field_row = quant_field.Row(qy);
          for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) {
            quant_field_row[qx] *= qf_mul;
          }
        }
      }
    }
  }
  JXL_RETURN_IF_ERROR(
      quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field));
  return true;
}

}  // namespace

Status AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
                        float butteraugli_target, ImageF* quant_field) {
  // Replace the whole quant_field in non-8x8 blocks with the maximum of each
  // 8x8 block.
  size_t stride = quant_field->PixelsPerRow();

  // At low distances it is great to use max, but mean works better
  // at high distances. We interpolate between them for a distance
  // range.
1212 float mean_max_mixer = 1.0f; 1213 { 1214 static const float kLimit = 1.54138f; 1215 static const float kMul = 0.56391f; 1216 static const float kMin = 0.0f; 1217 if (butteraugli_target > kLimit) { 1218 mean_max_mixer -= (butteraugli_target - kLimit) * kMul; 1219 if (mean_max_mixer < kMin) { 1220 mean_max_mixer = kMin; 1221 } 1222 } 1223 } 1224 for (size_t y = 0; y < rect.ysize(); ++y) { 1225 AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y); 1226 float* JXL_RESTRICT quant_row = rect.Row(quant_field, y); 1227 for (size_t x = 0; x < rect.xsize(); ++x) { 1228 AcStrategy acs = ac_strategy_row[x]; 1229 if (!acs.IsFirstBlock()) continue; 1230 JXL_ENSURE(x + acs.covered_blocks_x() <= quant_field->xsize()); 1231 JXL_ENSURE(y + acs.covered_blocks_y() <= quant_field->ysize()); 1232 float max = quant_row[x]; 1233 float mean = 0.0; 1234 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { 1235 for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { 1236 mean += quant_row[x + ix + iy * stride]; 1237 max = std::max(quant_row[x + ix + iy * stride], max); 1238 } 1239 } 1240 mean /= acs.covered_blocks_y() * acs.covered_blocks_x(); 1241 if (acs.covered_blocks_y() * acs.covered_blocks_x() >= 4) { 1242 max *= mean_max_mixer; 1243 max += (1.0f - mean_max_mixer) * mean; 1244 } 1245 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { 1246 for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { 1247 quant_row[x + ix + iy * stride] = max; 1248 } 1249 } 1250 } 1251 } 1252 return true; 1253 } 1254 1255 float InitialQuantDC(float butteraugli_target) { 1256 const float kDcMul = 0.3; // Butteraugli target where non-linearity kicks in. 1257 const float butteraugli_target_dc = std::max<float>( 1258 0.5f * butteraugli_target, 1259 std::min<float>(butteraugli_target, 1260 kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target, 1261 kDcQuantPow))); 1262 // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc. 
1263 // The maximum DC value might not be in the kXybRange because of inverse 1264 // gaborish, so we add some slack to the maximum theoretical quant obtained 1265 // this way (64). 1266 return std::min(kDcQuant / butteraugli_target_dc, 50.f); 1267 } 1268 1269 StatusOr<ImageF> InitialQuantField(const float butteraugli_target, 1270 const Image3F& opsin, const Rect& rect, 1271 ThreadPool* pool, float rescale, 1272 ImageF* mask, ImageF* mask1x1) { 1273 const float quant_ac = kAcQuant / butteraugli_target; 1274 return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)( 1275 butteraugli_target, opsin, rect, quant_ac * rescale, pool, mask, mask1x1); 1276 } 1277 1278 Status FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear, 1279 const Image3F& opsin, ImageF& quant_field, 1280 PassesEncoderState* enc_state, 1281 const JxlCmsInterface& cms, ThreadPool* pool, 1282 AuxOut* aux_out, double rescale) { 1283 const CompressParams& cparams = enc_state->cparams; 1284 if (cparams.max_error_mode) { 1285 JXL_RETURN_IF_ERROR(FindBestQuantizationMaxError( 1286 frame_header, opsin, quant_field, enc_state, cms, pool, aux_out)); 1287 } else if (linear && cparams.speed_tier <= SpeedTier::kKitten) { 1288 // Normal encoding to a butteraugli score. 1289 JXL_RETURN_IF_ERROR(FindBestQuantization(frame_header, *linear, opsin, 1290 quant_field, enc_state, cms, pool, 1291 aux_out)); 1292 } 1293 return true; 1294 } 1295 1296 } // namespace jxl 1297 #endif // HWY_ONCE