tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_adaptive_quantization.cc (53125B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_adaptive_quantization.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <algorithm>
     11 #include <atomic>
     12 #include <cmath>
     13 #include <cstddef>
     14 #include <cstdlib>
     15 #include <string>
     16 #include <vector>
     17 
     18 #include "lib/jxl/memory_manager_internal.h"
     19 
     20 #undef HWY_TARGET_INCLUDE
     21 #define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc"
     22 #include <hwy/foreach_target.h>
     23 #include <hwy/highway.h>
     24 
     25 #include "lib/jxl/ac_strategy.h"
     26 #include "lib/jxl/base/common.h"
     27 #include "lib/jxl/base/compiler_specific.h"
     28 #include "lib/jxl/base/data_parallel.h"
     29 #include "lib/jxl/base/fast_math-inl.h"
     30 #include "lib/jxl/base/rect.h"
     31 #include "lib/jxl/base/status.h"
     32 #include "lib/jxl/butteraugli/butteraugli.h"
     33 #include "lib/jxl/convolve.h"
     34 #include "lib/jxl/dec_cache.h"
     35 #include "lib/jxl/dec_group.h"
     36 #include "lib/jxl/enc_aux_out.h"
     37 #include "lib/jxl/enc_butteraugli_comparator.h"
     38 #include "lib/jxl/enc_cache.h"
     39 #include "lib/jxl/enc_debug_image.h"
     40 #include "lib/jxl/enc_group.h"
     41 #include "lib/jxl/enc_modular.h"
     42 #include "lib/jxl/enc_params.h"
     43 #include "lib/jxl/enc_transforms-inl.h"
     44 #include "lib/jxl/epf.h"
     45 #include "lib/jxl/frame_dimensions.h"
     46 #include "lib/jxl/image.h"
     47 #include "lib/jxl/image_bundle.h"
     48 #include "lib/jxl/image_ops.h"
     49 #include "lib/jxl/quant_weights.h"
     50 
     51 // Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging.
     52 #ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION
     53 #define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0
     54 #endif
     55 
     56 HWY_BEFORE_NAMESPACE();
     57 namespace jxl {
     58 namespace HWY_NAMESPACE {
     59 namespace {
     60 
     61 // These templates are not found via ADL.
     62 using hwy::HWY_NAMESPACE::AbsDiff;
     63 using hwy::HWY_NAMESPACE::Add;
     64 using hwy::HWY_NAMESPACE::And;
     65 using hwy::HWY_NAMESPACE::Gt;
     66 using hwy::HWY_NAMESPACE::IfThenElseZero;
     67 using hwy::HWY_NAMESPACE::Max;
     68 using hwy::HWY_NAMESPACE::Min;
     69 using hwy::HWY_NAMESPACE::Rebind;
     70 using hwy::HWY_NAMESPACE::Sqrt;
     71 using hwy::HWY_NAMESPACE::ZeroIfNegative;
     72 
     73 // The following functions modulate an exponent (out_val) and return the updated
     74 // value. Their descriptor is limited to 8 lanes for 8x8 blocks.
     75 
     // Stop-gap mask estimate used for AC strategy selection: the reciprocal of
     // `out_val`, with a small bias added to the denominator so that an input of
     // zero still yields a finite value. Meant to be replaced by butteraugli's
     // masking eventually.
     float ComputeMaskForAcStrategyUse(const float out_val) {
       constexpr float kNumerator = 1.0f;
       constexpr float kDenominatorBias = 0.001f;
       const float denominator = out_val + kDenominatorBias;
       return kNumerator / denominator;
     }
     83 
     // Turns `out_val` into an exponent-domain masking term. With
     // v1 = max(out_val * 0.8006..., 1e-3), the result is
     //   kBase + kMul4 / (v1^2 + kOffset4) + kMul2 / (v1 + kOffset2)
     //         + kMul3 / (v1^2 + kOffset3).
     template <class D, class V>
     V ComputeMask(const D d, const V out_val) {
       const auto base = Set(d, -0.7647f);
       const auto one = Set(d, 1.0f);
       const auto input_scale = Set(d, 0.80061762862741759f);

       // Weights and denominator offsets of the three reciprocal terms.
       const auto weight2 = Set(d, 17.35036561631863f);
       const auto offset2 = Set(d, 302.59587815579727f);
       const auto weight3 = Set(d, 6.7943250517376494f);
       const auto offset3 = Set(d, 3.7179635626140772f);
       const auto weight4 = Set(d, 9.4708735624378946f);
       const auto offset4 = Mul(Set(d, 0.25f), offset3);

       // Clamp from below so none of the divisions can hit zero.
       const auto v1 = Max(Mul(out_val, input_scale), Set(d, 1e-3f));
       const auto term2 = Div(one, Add(v1, offset2));
       const auto term3 = Div(one, MulAdd(v1, v1, offset3));
       const auto term4 = Div(one, MulAdd(v1, v1, offset4));

       // TODO(jyrki):
       // A log or two here could make sense. In butteraugli we have effectively
       // log(log(x + C)) for this kind of use, as a single log is used in
       // saturating visual masking and here the modulation values are exponential,
       // another log would counter that.
       const auto acc3 = Mul(weight3, term3);
       const auto acc2 = MulAdd(weight2, term2, acc3);
       const auto acc4 = MulAdd(weight4, term4, acc2);
       return Add(base, acc4);
     }
    108 
    109 // mul and mul2 represent a scaling difference between jxl and butteraugli.
        // In the reference SimpleGamma kept in the comment further down, kSGmul
        // multiplies the input and kSGmul2 scales both the output multiplier and
        // the output offset.
    110 const float kSGmul = 226.77216153508914f;
    111 const float kSGmul2 = 1.0f / 73.377132366608819f;
        // Natural logarithm of 2.
    112 const float kLog2 = 0.693147181f;
    113 // Includes correction factor for std::log -> log2.
    114 const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
        // Offset added to the scaled input before the logarithm in SimpleGamma.
    115 const float kSGVOffset = 7.7825991679894591f;
    116 
    117 template <bool invert, typename D, typename V>
    118 V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
    119  // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
    120  // is related to the number of photons.
    121  //
    122  // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
    123  // This ratio allows quantization to move from jxl's opsin space to
    124  // butteraugli's log-gamma space.
    125  float kEpsilon = 1e-2;
    126  v = ZeroIfNegative(v);
    127  const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul);
    128  const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon);
    129  const auto kDenMul = Set(d, kLog2 * kSGmul);
    130 
    131  const auto v2 = Mul(v, v);
    132 
    133  const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon));
    134  const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset);
    135  return invert ? Div(num, den) : Div(den, num);
    136 }
    137 
    138 template <bool invert = false>
    139 float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
    140  using DScalar = HWY_CAPPED(float, 1);
    141  auto vscalar = Load(DScalar(), &v);
    142  return GetLane(
    143      RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
    144 }
    145 
    146 // TODO(veluca): this function computes an approximation of the derivative of
    147 // SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
    148 // exact derivatives. For reference, SimpleGamma was:
    149 /*
    150 template <typename D, typename V>
    151 V SimpleGamma(const D d, V v) {
    152  // A simple HDR compatible gamma function.
    153  const auto mul = Set(d, kSGmul);
    154  const auto kRetMul = Set(d, kSGRetMul);
    155  const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f);
    156  const auto kVOffset = Set(d, kSGVOffset);
    157 
    158  v *= mul;
    159 
    160  // This should happen rarely, but may lead to a NaN, which is rather
    161  // undesirable. Since negative photons don't exist we solve the NaNs by
    162  // clamping here.
    163  // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
    164  v = ZeroIfNegative(v);
    165  return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
    166 }
    167 */
    168 
    169 template <class D, class V>
    170 V GammaModulation(const D d, const size_t x, const size_t y,
    171                  const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect,
    172                  const V out_val) {
    173  const float kBias = 0.16f;
    174  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]);
    175  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]);
    176  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]);
    177  auto overall_ratio = Zero(d);
    178  auto bias = Set(d, kBias);
    179  for (size_t dy = 0; dy < 8; ++dy) {
    180    const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy);
    181    const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy);
    182    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
    183      const auto iny = Add(Load(d, row_in_y + x + dx), bias);
    184      const auto inx = Load(d, row_in_x + x + dx);
    185 
    186      const auto r = Sub(iny, inx);
    187      const auto ratio_r =
    188          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r);
    189      overall_ratio = Add(overall_ratio, ratio_r);
    190 
    191      const auto g = Add(iny, inx);
    192      const auto ratio_g =
    193          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g);
    194      overall_ratio = Add(overall_ratio, ratio_g);
    195    }
    196  }
    197  overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 0.5f / 64));
    198  // ideally -1.0, but likely optimal correction adds some entropy, so slightly
    199  // less than that.
    200  const auto kGamma = Set(d, 0.1005613337192697f);
    201  return MulAdd(kGamma, FastLog2f(d, overall_ratio), out_val);
    202 }
    203 
    204 // Change precision in 8x8 blocks that have significant amounts of blue
    205 // content (but are not close to solid blue).
    206 // This is based on the idea that M and L cone activations saturate the
    207 // S (blue) receptors, and the S reception becomes more important when
    208 // both M and L levels are low. In that case M and L receptors may be
    209 // observing S-spectra instead and viewing them with higher spatial
    210 // accuracy, justifying spending more bits here.
    211 template <class D, class V>
    212 V BlueModulation(const D d, const size_t x, const size_t y,
    213                 const ImageF& planex, const ImageF& planey,
    214                 const ImageF& planeb, const Rect& rect, const V out_val) {
    215  auto sum = Zero(d);
    216  static const float kLimit = 0.027121074570634722;
    217  static const float kOffset = 0.084381641171960495;
    218  for (size_t dy = 0; dy < 8; ++dy) {
    219    const float* JXL_RESTRICT row_in_x = rect.ConstRow(planex, y + dy) + x;
    220    const float* JXL_RESTRICT row_in_y = rect.ConstRow(planey, y + dy) + x;
    221    const float* JXL_RESTRICT row_in_b = rect.ConstRow(planeb, y + dy) + x;
    222    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
    223      const auto p_x = Load(d, row_in_x + dx);
    224      const auto p_b = Load(d, row_in_b + dx);
    225      const auto p_y_raw = Add(Load(d, row_in_y + dx), Set(d, kOffset));
    226      const auto p_y_effective = Add(p_y_raw, Abs(p_x));
    227      sum = Add(sum,
    228                IfThenElseZero(Gt(p_b, p_y_effective),
    229                               Min(Sub(p_b, p_y_effective), Set(d, kLimit))));
    230    }
    231  }
    232  static const float kMul = 0.14207000358439159;
    233  sum = SumOfLanes(d, sum);
    234  float scalar_sum = GetLane(sum);
    235  // If it is all blue, don't boost the quantization.
    236  // All blue likely means low frequency blue. Let's not make the most
    237  // perfect sky ever.
    238  if (scalar_sum >= 32 * kLimit) {
    239    scalar_sum = 64 * kLimit - scalar_sum;
    240  }
    241  static const float kMaxLimit = 15.398788439047934f;
    242  if (scalar_sum >= kMaxLimit * kLimit) {
    243    scalar_sum = kMaxLimit * kLimit;
    244  }
    245  scalar_sum *= kMul;
    246  return Add(Set(d, scalar_sum), out_val);
    247 }
    248 
    249 // Change precision in 8x8 blocks that have high frequency content.
        //
        // For the 8x8 block at (x, y) inside `rect`, sums min(valmin_y, |delta|)
        // over the Y-plane differences between each pixel and its right neighbour
        // (columns 0..6 only) and between each pixel and the pixel below it. The
        // sum is multiplied by kMul_y, kOffset is added, and the result is added to
        // every lane of `out_val`.
    250 template <class D, class V>
    251 V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb_y,
    252               const Rect& rect, const V out_val) {
    253  // Zero out the invalid differences for the rightmost value per row.
    254  const Rebind<uint32_t, D> du;
    255  HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u,
    256                                                        ~0u, ~0u, ~0u, 0};
    257 
    258  // Sum of clamped deltas of the Y component between (approximate)
    259  // 4-connected pixels.
    260  auto sum_y = Zero(d);
    261  static const float valmin_y = 0.0206;
    262  auto valminv_y = Set(d, valmin_y);
    263  for (size_t dy = 0; dy < 8; ++dy) {
    264    const float* JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy) + x;
           // The last row has no row below it inside the block: it is compared
           // with itself, so its vertical differences are zero.
    265    const float* JXL_RESTRICT row_in_y_next =
    266        dy == 7 ? row_in_y : rect.ConstRow(xyb_y, y + dy + 1) + x;
    267 
    268    // In SCALAR, there is no guarantee of having extra row padding.
    269    // Hence, we need to ensure we don't access pixels outside the row itself.
    270    // In SIMD modes, however, rows are padded, so it's safe to access one
    271    // garbage value after the row. The vector then gets masked with kMaskRight
    272    // to remove the influence of that value.
    273 #if HWY_TARGET != HWY_SCALAR
    274    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
    275 #else
    276    for (size_t dx = 0; dx < 7; dx += Lanes(d)) {
    277 #endif
    278      const auto mask = BitCast(d, Load(du, kMaskRight + dx));
    279      {
    280        const auto p_y = Load(d, row_in_y + dx);
               // Right neighbour; unaligned because of the +1 offset.
    281        const auto pr_y = LoadU(d, row_in_y + dx + 1);
    282        sum_y = Add(sum_y, And(mask, Min(valminv_y, AbsDiff(p_y, pr_y))));
               // Neighbour below; no mask needed.
    283        const auto pd_y = Load(d, row_in_y_next + dx);
    284        sum_y = Add(sum_y, Min(valminv_y, AbsDiff(p_y, pd_y)));
    285      }
    286    }
    287 #if HWY_TARGET == HWY_SCALAR
           // The scalar loop above stops before column 7; add its vertical
           // difference here (it has no right neighbour inside the block).
    288    const auto p_y = Load(d, row_in_y + 7);
    289    const auto pd_y = Load(d, row_in_y_next + 7);
    290    sum_y = Add(sum_y, Min(valminv_y, AbsDiff(p_y, pd_y)));
    291 #endif
    292  }
    293  static const float kMul_y = -0.38;
    294  sum_y = SumOfLanes(d, sum_y);
    295 
    296  float scalar_sum_y = GetLane(sum_y);
    297  scalar_sum_y *= kMul_y;
    298 
    299  // higher value -> more bpp
    300  float kOffset = 0.42;
    301  scalar_sum_y += kOffset;
    302 
    303  return Add(Set(d, scalar_sum_y), out_val);
    304 }
    305 
    306 void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x,
    307                         const ImageF& xyb_y, const ImageF& xyb_b,
    308                         const Rect& rect_in, const float scale,
    309                         const Rect& rect_out, ImageF* out) {
    310  float base_level = 0.48f * scale;
    311  float kDampenRampStart = 2.0f;
    312  float kDampenRampEnd = 14.0f;
    313  float dampen = 1.0f;
    314  if (butteraugli_target >= kDampenRampStart) {
    315    dampen = 1.0f - ((butteraugli_target - kDampenRampStart) /
    316                     (kDampenRampEnd - kDampenRampStart));
    317    if (dampen < 0) {
    318      dampen = 0;
    319    }
    320  }
    321  const float mul = scale * dampen;
    322  const float add = (1.0f - dampen) * base_level;
    323  for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) {
    324    const size_t y = iy * 8;
    325    float* const JXL_RESTRICT row_out = out->Row(iy);
    326    const HWY_CAPPED(float, kBlockDim) df;
    327    for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) {
    328      size_t x = ix * 8;
    329      auto out_val = Set(df, row_out[ix]);
    330      out_val = ComputeMask(df, out_val);
    331      out_val = HfModulation(df, x, y, xyb_y, rect_in, out_val);
    332      out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val);
    333      out_val = BlueModulation(df, x, y, xyb_x, xyb_y, xyb_b, rect_in, out_val);
    334      // We want multiplicative quantization field, so everything
    335      // until this point has been modulating the exponent.
    336      row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
    337    }
    338  }
    339 }
    340 
    // Per lane: 0.25 * sqrt(v * sqrt(kMul * 1e8) + kLogOffset).
    template <typename D, typename V>
    V MaskingSqrt(const D d, V v) {
      static const float kLogOffset = 27.505837037000106f;
      static const float kMul = 211.66567973503678f;
      const auto slope = Sqrt(Set(d, kMul * 1e8));
      const auto radicand = MulAdd(v, slope, Set(d, kLogOffset));
      return Mul(Set(d, 0.25f), Sqrt(radicand));
    }
    349 
    350 float MaskingSqrt(const float v) {
    351  using DScalar = HWY_CAPPED(float, 1);
    352  auto vscalar = Load(DScalar(), &v);
    353  return GetLane(MaskingSqrt(DScalar(), vscalar));
    354 }
    355 
    // Inserts `v` into the running list of four smallest values
    // (min0 <= min1 <= min2 <= min3 is expected on entry). Values that are not
    // smaller than min3 are ignored; otherwise `v` is placed at the first slot it
    // is smaller than (or at min3), and the larger entries shift up by one.
    void StoreMin4(const float v, float& min0, float& min1, float& min2,
                   float& min3) {
      if (!(v < min3)) {
        return;
      }
      if (v < min0) {
        min3 = min2;
        min2 = min1;
        min1 = min0;
        min0 = v;
        return;
      }
      if (v < min1) {
        min3 = min2;
        min2 = min1;
        min1 = v;
        return;
      }
      if (v < min2) {
        min3 = min2;
        min2 = v;
        return;
      }
      min3 = v;
    }
    376 
    377 // Look for smooth areas near the area of degradation.
    378 // If the areas are generally smooth, don't do masking.
    379 // Output is downsampled 2x.
        //
        // For every pixel of `from_rect` (coordinates are offsets into `from`), the
        // four smallest values of its 3x3 neighbourhood are combined with the weights
        // kMul0..kMul3, and each 2x2 group of such results is summed into one pixel
        // of `to` inside `to_rect`. `to_rect` must be exactly half the size of
        // `from_rect` in both dimensions. Returns true unless one of the JXL_ENSURE
        // checks fails.
    380 Status FuzzyErosion(const float butteraugli_target, const Rect& from_rect,
    381                    const ImageF& from, const Rect& to_rect, ImageF* to) {
    382  const size_t xsize = from.xsize();
    383  const size_t ysize = from.ysize();
    384  constexpr int kStep = 1;
    385  static_assert(kStep == 1, "Step must be 1");
    386  JXL_ENSURE(to_rect.xsize() * 2 == from_rect.xsize());
    387  JXL_ENSURE(to_rect.ysize() * 2 == from_rect.ysize());
         // Base weights for the smallest..4th smallest value, and the amount added
         // to each of them per unit of `mul` below.
    388  static const float kMulBase0 = 0.125;
    389  static const float kMulBase1 = 0.10;
    390  static const float kMulBase2 = 0.09;
    391  static const float kMulBase3 = 0.06;
    392  static const float kMulAdd0 = 0.0;
    393  static const float kMulAdd1 = -0.10;
    394  static const float kMulAdd2 = -0.09;
    395  static const float kMulAdd3 = -0.06;
    396 
         // `mul` is 0 for targets >= 2 and grows linearly to 1 as the target goes
         // to 0; at mul == 1 the weights of the 2nd..4th smallest values reach zero.
    397  float mul = 0.0;
    398  if (butteraugli_target < 2.0f) {
    399    mul = (2.0f - butteraugli_target) * (1.0f / 2.0f);
    400  }
    401  float kMul0 = kMulBase0 + mul * kMulAdd0;
    402  float kMul1 = kMulBase1 + mul * kMulAdd1;
    403  float kMul2 = kMulBase2 + mul * kMulAdd2;
    404  float kMul3 = kMulBase3 + mul * kMulAdd3;
         // Rescale so that the four weights always sum to kTotal.
    405  static const float kTotal = 0.29959705784054957;
    406  float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3);
    407  kMul0 *= norm;
    408  kMul1 *= norm;
    409  kMul2 *= norm;
    410  kMul3 *= norm;
    411 
    412  for (size_t fy = 0; fy < from_rect.ysize(); ++fy) {
    413    size_t y = fy + from_rect.y0();
           // Neighbour rows, clamped to the bounds of `from` (an out-of-range
           // neighbour is replaced by the current row).
    414    size_t ym1 = y >= kStep ? y - kStep : y;
    415    size_t yp1 = y + kStep < ysize ? y + kStep : y;
    416    const float* rowt = from.Row(ym1);
    417    const float* row = from.Row(y);
    418    const float* rowb = from.Row(yp1);
    419    float* row_out = to_rect.Row(to, fy / 2);
    420    for (size_t fx = 0; fx < from_rect.xsize(); ++fx) {
    421      size_t x = fx + from_rect.x0();
             // Neighbour columns, clamped the same way as the rows.
    422      size_t xm1 = x >= kStep ? x - kStep : x;
    423      size_t xp1 = x + kStep < xsize ? x + kStep : x;
    424      float min0 = row[x];
    425      float min1 = row[xm1];
    426      float min2 = row[xp1];
    427      float min3 = rowt[xm1];
    428      // Sort the first four values.
    429      if (min0 > min1) std::swap(min0, min1);
    430      if (min0 > min2) std::swap(min0, min2);
    431      if (min0 > min3) std::swap(min0, min3);
    432      if (min1 > min2) std::swap(min1, min2);
    433      if (min1 > min3) std::swap(min1, min3);
    434      if (min2 > min3) std::swap(min2, min3);
    435      // The remaining five values of a 3x3 neighbourhood.
    436      StoreMin4(rowt[x], min0, min1, min2, min3);
    437      StoreMin4(rowt[xp1], min0, min1, min2, min3);
    438      StoreMin4(rowb[xm1], min0, min1, min2, min3);
    439      StoreMin4(rowb[x], min0, min1, min2, min3);
    440      StoreMin4(rowb[xp1], min0, min1, min2, min3);
    441 
    442      float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3;
             // The top-left pixel of each 2x2 group initializes the output pixel;
             // the other three accumulate into it.
    443      if (fx % 2 == 0 && fy % 2 == 0) {
    444        row_out[fx / 2] = v;
    445      } else {
    446        row_out[fx / 2] += v;
    447      }
    448    }
    449  }
    450  return true;
    451 }
    452 
    // Scratch buffers plus the per-tile worker used to build the adaptive
    // quantization map. `pre_erosion` holds one image per thread and
    // `diff_buffer` one row per thread; `aq_map` is the per-block output that
    // every tile writes its own region of.
    453 struct AdaptiveQuantizationImpl {
         // (Re)allocates `diff_buffer` with one row of kEncTileDim + 8 floats per
         // thread and grows `pre_erosion` to `num_threads` images of
         // (kEncTileDimInBlocks * 2 + 2) squared pixels. Existing `pre_erosion`
         // entries are kept.
    454  Status PrepareBuffers(JxlMemoryManager* memory_manager, size_t num_threads) {
    455    JXL_ASSIGN_OR_RETURN(
    456        diff_buffer,
    457        ImageF::Create(memory_manager, kEncTileDim + 8, num_threads));
    458    for (size_t i = pre_erosion.size(); i < num_threads; i++) {
    459      JXL_ASSIGN_OR_RETURN(
    460          ImageF tmp,
    461          ImageF::Create(memory_manager, kEncTileDimInBlocks * 2 + 2,
    462                         kEncTileDimInBlocks * 2 + 2));
    463      pre_erosion.emplace_back(std::move(tmp));
    464    }
    465    return true;
    466  }
    467 
         // Processes one tile.
         //   rect_in:  pixel rectangle of `xyb` being encoded; its origin must be a
         //             multiple of kBlockDim.
         //   rect_out: the tile, in 8x8-block units relative to rect_in.
         //   thread:   index selecting this thread's row of `diff_buffer` and its
         //             `pre_erosion` image.
         // Writes the tile's pixels of `mask1x1`, and the tile's blocks of `mask`
         // and `aq_map`.
    468  Status ComputeTile(float butteraugli_target, float scale, const Image3F& xyb,
    469                     const Rect& rect_in, const Rect& rect_out,
    470                     const int thread, ImageF* mask, ImageF* mask1x1) {
    471    JXL_ENSURE(rect_in.x0() % kBlockDim == 0);
    472    JXL_ENSURE(rect_in.y0() % kBlockDim == 0);
    473    const size_t xsize = xyb.xsize();
    474    const size_t ysize = xyb.ysize();
    475 
    476    // The XYB gamma is 3.0 to be able to decode faster with two muls.
    477    // Butteraugli's gamma is matching the gamma of human eye, around 2.6.
    478    // We approximate the gamma difference by adding one cubic root into
    479    // the adaptive quantization. This gives us a total gamma of 2.6666
    480    // for quantization uses.
    481    const float match_gamma_offset = 0.019;
    482 
    483    const HWY_FULL(float) df;
    484 
           // Pixel bounds of this tile inside `xyb`.
    485    size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8;
    486    size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8;
    487 
    488    size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8;
    489    size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8;
    490 
           // Where the tile lies on a border of rect_in that is not also the image
           // border, extend the 1x1 mask by two pixels beyond rect_in.
           // NOTE(review): presumably so the later 5x5 blur of mask1x1 has defined
           // neighbours there -- confirm.
    491    if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2;
    492    if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) {
    493      x_end_1x1 += 2;
    494    }
    495    if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2;
    496    if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) {
    497      y_end_1x1 += 2;
    498    }
    499 
    500    // Computes image (padded to multiple of 8x8) of local pixel differences.
    501    // This first pass is the full-resolution (1x1) one; the pass subsampled
    502    // by 4 in both directions follows below. 1x1 Laplacian of intensity.
    503    for (size_t y = y_start_1x1; y < y_end_1x1; ++y) {
             // Neighbour rows, clamped to the image.
    504      const size_t y2 = y + 1 < ysize ? y + 1 : y;
    505      const size_t y1 = y > 0 ? y - 1 : y;
    506      const float* row_in = xyb.ConstPlaneRow(1, y);
    507      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
    508      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
    509      float* mask1x1_out = mask1x1->Row(y);
             // mask1x1[x] = 1 / (log1p(|gammac * (pixel - mean of 4 neighbours)|)
             //                   + 0.01).
    510      auto scalar_pixel1x1 = [&](size_t x) {
    511        const size_t x2 = x + 1 < xsize ? x + 1 : x;
    512        const size_t x1 = x > 0 ? x - 1 : x;
    513        const float base =
    514            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
    515        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
    516            row_in[x] + match_gamma_offset);
    517        float diff = fabs(gammac * (row_in[x] - base));
    518        static const double kScaler = 1.0;
    519        diff *= kScaler;
    520        diff = log1p(diff);
    521        static const float kMul = 1.0;
    522        static const float kOffset = 0.01;
    523        mask1x1_out[x] = kMul / (diff + kOffset);
    524      };
    525      for (size_t x = x_start_1x1; x < x_end_1x1; ++x) {
    526        scalar_pixel1x1(x);
    527      }
    528    }
    529 
           // Pixel bounds for the subsampled pass: the tile plus a 4-pixel margin
           // (one pre-erosion sample) on every side that is not an image border.
    530    size_t y_start = rect_in.y0() + rect_out.y0() * 8;
    531    size_t y_end = y_start + rect_out.ysize() * 8;
    532 
    533    size_t x_start = rect_in.x0() + rect_out.x0() * 8;
    534    size_t x_end = x_start + rect_out.xsize() * 8;
    535 
    536    if (x_start != 0) x_start -= 4;
    537    if (x_end != xsize) x_end += 4;
    538    if (y_start != 0) y_start -= 4;
    539    if (y_end != ysize) y_end += 4;
    540    JXL_RETURN_IF_ERROR(pre_erosion[thread].ShrinkTo((x_end - x_start) / 4,
    541                                                     (y_end - y_start) / 4));
    542 
           // Upper clamp applied to the squared difference before MaskingSqrt.
    543    static const float limit = 0.2f;
    544    for (size_t y = y_start; y < y_end; ++y) {
    545      size_t y2 = y + 1 < ysize ? y + 1 : y;
    546      size_t y1 = y > 0 ? y - 1 : y;
    547 
    548      const float* row_in = xyb.ConstPlaneRow(1, y);
    549      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
    550      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
             // This thread's accumulator row; it collects four image rows at a time.
    551      float* JXL_RESTRICT row_out = diff_buffer.Row(thread);
    552 
             // Scalar version of the per-pixel computation: squared gamma-weighted
             // Laplacian, clamped to `limit`, passed through MaskingSqrt, then
             // stored (when y % 4 == 0) or accumulated into row_out.
    553      auto scalar_pixel = [&](size_t x) {
    554        const size_t x2 = x + 1 < xsize ? x + 1 : x;
    555        const size_t x1 = x > 0 ? x - 1 : x;
    556        const float base =
    557            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
    558        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
    559            row_in[x] + match_gamma_offset);
    560        float diff = gammac * (row_in[x] - base);
    561        diff *= diff;
    562        if (diff >= limit) {
    563          diff = limit;
    564        }
    565        diff = MaskingSqrt(diff);
    566        if ((y % 4) != 0) {
    567          row_out[x - x_start] += diff;
    568        } else {
    569          row_out[x - x_start] = diff;
    570        }
    571      };
    572 
    573      size_t x = x_start;
    574      // First pixel of the row.
             // Handled separately at the image's left edge because the vector loop
             // reads row_in[x - 1].
    575      if (x_start == 0) {
    576        scalar_pixel(x_start);
    577        ++x;
    578      }
    579      // SIMD
             // Same computation as scalar_pixel. The loop bound keeps the
             // right-neighbour load (indices x + 1 .. x + Lanes) below x_end.
    580      const auto match_gamma_offset_v = Set(df, match_gamma_offset);
    581      const auto quarter = Set(df, 0.25f);
    582      for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) {
    583        const auto in = LoadU(df, row_in + x);
    584        const auto in_r = LoadU(df, row_in + x + 1);
    585        const auto in_l = LoadU(df, row_in + x - 1);
    586        const auto in_t = LoadU(df, row_in2 + x);
    587        const auto in_b = LoadU(df, row_in1 + x);
    588        auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b)));
    589        auto gammacv =
    590            RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
    591                df, Add(in, match_gamma_offset_v));
    592        auto diff = Mul(gammacv, Sub(in, base));
    593        diff = Mul(diff, diff);
    594        diff = Min(diff, Set(df, limit));
    595        diff = MaskingSqrt(df, diff);
    596        if ((y & 3) != 0) {
    597          diff = Add(diff, LoadU(df, row_out + x - x_start));
    598        }
    599        StoreU(diff, df, row_out + x - x_start);
    600      }
    601      // Scalar
             // Remaining pixels of the row that the vector loop did not cover.
    602      for (; x < x_end; ++x) {
    603        scalar_pixel(x);
    604      }
             // After the fourth row of a group, row_out holds per-column sums over
             // four rows. Each run of four columns is summed and multiplied by 0.25
             // to produce one pre-erosion sample.
    605      if (y % 4 == 3) {
    606        float* row_d_out = pre_erosion[thread].Row((y - y_start) / 4);
    607        for (size_t x = 0; x < (x_end - x_start) / 4; x++) {
    608          row_d_out[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
    609                          row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
    610                         0.25f;
    611        }
    612      }
    613    }
    614    JXL_ENSURE(x_start % (kBlockDim / 2) == 0);
    615    JXL_ENSURE(y_start % (kBlockDim / 2) == 0);
           // If a 4-pixel margin was added on the left/top, x_start / y_start is no
           // longer a multiple of 8 and the tile's own samples begin at index 1 of
           // pre_erosion; otherwise they begin at index 0.
    616    Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1,
    617                   rect_out.xsize() * 2, rect_out.ysize() * 2);
    618    JXL_RETURN_IF_ERROR(FuzzyErosion(butteraugli_target, from_rect,
    619                                     pre_erosion[thread], rect_out, &aq_map));
           // Derive the AC-strategy mask from the eroded values before aq_map is
           // overwritten in place by PerBlockModulations below.
    620    for (size_t y = 0; y < rect_out.ysize(); ++y) {
    621      const float* aq_map_row = rect_out.ConstRow(aq_map, y);
    622      float* mask_row = rect_out.Row(mask, y);
    623      for (size_t x = 0; x < rect_out.xsize(); ++x) {
    624        mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]);
    625      }
    626    }
    627    PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1),
    628                        xyb.Plane(2), rect_in, scale, rect_out, &aq_map);
    629    return true;
    630  }
         // One image of pre-erosion samples per thread (see PrepareBuffers).
    631  std::vector<ImageF> pre_erosion;
         // Per-block output map; allocated by the caller before tiles are computed.
    632  ImageF aq_map;
         // One accumulator row per thread (see PrepareBuffers).
    633  ImageF diff_buffer;
    634 };
    635 
    636 Status Blur1x1Masking(JxlMemoryManager* memory_manager, ThreadPool* pool,
    637                      ImageF* mask1x1, const Rect& rect) {
    638  // Blur the mask1x1 to obtain the masking image.
    639  // Before blurring it contains an image of absolute value of the
    640  // Laplacian of the intensity channel.
    641  static const float kFilterMask1x1[5] = {
    642      static_cast<float>(0.25647067633737227),
    643      static_cast<float>(0.2050056912354399075),
    644      static_cast<float>(0.154082048668497307),
    645      static_cast<float>(0.08149576591362004441),
    646      static_cast<float>(0.0512750104812308467),
    647  };
    648  double sum =
    649      1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] + kFilterMask1x1[2] +
    650                 kFilterMask1x1[4] + 2 * kFilterMask1x1[3]);
    651  if (sum < 1e-5) {
    652    sum = 1e-5;
    653  }
    654  const float normalize = static_cast<float>(1.0 / sum);
    655  const float normalize_mul = normalize;
    656  WeightsSymmetric5 weights =
    657      WeightsSymmetric5{{HWY_REP4(normalize)},
    658                        {HWY_REP4(normalize_mul * kFilterMask1x1[0])},
    659                        {HWY_REP4(normalize_mul * kFilterMask1x1[2])},
    660                        {HWY_REP4(normalize_mul * kFilterMask1x1[1])},
    661                        {HWY_REP4(normalize_mul * kFilterMask1x1[4])},
    662                        {HWY_REP4(normalize_mul * kFilterMask1x1[3])}};
    663  JXL_ASSIGN_OR_RETURN(
    664      ImageF temp, ImageF::Create(memory_manager, rect.xsize(), rect.ysize()));
    665  JXL_RETURN_IF_ERROR(Symmetric5(*mask1x1, rect, weights, pool, &temp));
    666  *mask1x1 = std::move(temp);
    667  return true;
    668 }
    669 
    670 StatusOr<ImageF> AdaptiveQuantizationMap(const float butteraugli_target,
    671                                         const Image3F& xyb, const Rect& rect,
    672                                         float scale, ThreadPool* pool,
    673                                         ImageF* mask, ImageF* mask1x1) {
    674  JXL_ENSURE(rect.xsize() % kBlockDim == 0);
    675  JXL_ENSURE(rect.ysize() % kBlockDim == 0);
    676  AdaptiveQuantizationImpl impl;
    677  const size_t xsize_blocks = rect.xsize() / kBlockDim;
    678  const size_t ysize_blocks = rect.ysize() / kBlockDim;
    679  JxlMemoryManager* memory_manager = xyb.memory_manager();
    680  JXL_ASSIGN_OR_RETURN(
    681      impl.aq_map, ImageF::Create(memory_manager, xsize_blocks, ysize_blocks));
    682  JXL_ASSIGN_OR_RETURN(
    683      *mask, ImageF::Create(memory_manager, xsize_blocks, ysize_blocks));
    684  JXL_ASSIGN_OR_RETURN(
    685      *mask1x1, ImageF::Create(memory_manager, xyb.xsize(), xyb.ysize()));
    686  const auto prepare = [&](const size_t num_threads) -> Status {
    687    JXL_RETURN_IF_ERROR(impl.PrepareBuffers(memory_manager, num_threads));
    688    return true;
    689  };
    690  const auto process_tile = [&](const uint32_t tid,
    691                                const size_t thread) -> Status {
    692    size_t n_enc_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks);
    693    size_t tx = tid % n_enc_tiles;
    694    size_t ty = tid / n_enc_tiles;
    695    size_t by0 = ty * kEncTileDimInBlocks;
    696    size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks);
    697    size_t bx0 = tx * kEncTileDimInBlocks;
    698    size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks);
    699    Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0);
    700    JXL_RETURN_IF_ERROR(impl.ComputeTile(butteraugli_target, scale, xyb, rect,
    701                                         rect_out, thread, mask, mask1x1));
    702    return true;
    703  };
    704  size_t num_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks) *
    705                     DivCeil(ysize_blocks, kEncTileDimInBlocks);
    706  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_tiles, prepare, process_tile,
    707                                "AQ DiffPrecompute"));
    708 
    709  JXL_RETURN_IF_ERROR(Blur1x1Masking(memory_manager, pool, mask1x1, rect));
    710  return std::move(impl).aq_map;
    711 }
    712 
    713 }  // namespace
    714 
    715 // NOLINTNEXTLINE(google-readability-namespace-comments)
    716 }  // namespace HWY_NAMESPACE
    717 }  // namespace jxl
    718 HWY_AFTER_NAMESPACE();
    719 
    720 #if HWY_ONCE
    721 namespace jxl {
    722 HWY_EXPORT(AdaptiveQuantizationMap);
    723 
    724 namespace {
    725 
    726 // If true, prints the quantization maps at each iteration.
    // FindBestQuantization consults this flag only inside its
    // JXL_DEBUG_ADAPTIVE_QUANTIZATION block, where it triggers
    // Quantizer::DumpQuantizationMap on the raw quant field.
    727 constexpr bool FLAGS_dump_quant_state = false;
    728 
    729 Status DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out,
    730                   const std::string& label, const ImageF& image,
    731                   float good_threshold, float bad_threshold) {
    732  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    733    JXL_ASSIGN_OR_RETURN(
    734        Image3F heatmap,
    735        CreateHeatMapImage(image, good_threshold, bad_threshold));
    736    char filename[200];
    737    snprintf(filename, sizeof(filename), "%s%05d", label.c_str(),
    738             aux_out->num_butteraugli_iters);
    739    JXL_RETURN_IF_ERROR(DumpImage(cparams, filename, heatmap));
    740  }
    741  return true;
    742 }
    743 
    744 Status DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out,
    745                    float butteraugli_target, const ImageF& quant_field,
    746                    const ImageF& tile_heatmap, const ImageF& bt_diffmap) {
    747  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    748    JxlMemoryManager* memory_manager = quant_field.memory_manager();
    749    if (!WantDebugOutput(cparams)) return true;
    750    JXL_ASSIGN_OR_RETURN(ImageF inv_qmap,
    751                         ImageF::Create(memory_manager, quant_field.xsize(),
    752                                        quant_field.ysize()));
    753    for (size_t y = 0; y < quant_field.ysize(); ++y) {
    754      const float* JXL_RESTRICT row_q = quant_field.ConstRow(y);
    755      float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y);
    756      for (size_t x = 0; x < quant_field.xsize(); ++x) {
    757        row_inv_q[x] = 1.0f / row_q[x];  // never zero
    758      }
    759    }
    760    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap,
    761                                    4.0f * butteraugli_target,
    762                                    6.0f * butteraugli_target));
    763    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "tile_heatmap",
    764                                    tile_heatmap, butteraugli_target,
    765                                    1.5f * butteraugli_target));
    766    // matches heat maps produced by the command line tool.
    767    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap,
    768                                    ButteraugliFuzzyInverse(1.5),
    769                                    ButteraugliFuzzyInverse(0.5)));
    770  }
    771  return true;
    772 }
    773 
    774 StatusOr<ImageF> TileDistMap(const ImageF& distmap, int tile_size, int margin,
    775                             const AcStrategyImage& ac_strategy) {
    776  const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size;
    777  const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size;
    778  JxlMemoryManager* memory_manager = distmap.memory_manager();
    779  JXL_ASSIGN_OR_RETURN(ImageF tile_distmap,
    780                       ImageF::Create(memory_manager, tile_xsize, tile_ysize));
    781  size_t distmap_stride = tile_distmap.PixelsPerRow();
    782  for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) {
    783    AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y);
    784    float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y);
    785    for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) {
    786      AcStrategy acs = ac_strategy_row[tile_x];
    787      if (!acs.IsFirstBlock()) continue;
    788      int this_tile_xsize = acs.covered_blocks_x() * tile_size;
    789      int this_tile_ysize = acs.covered_blocks_y() * tile_size;
    790      int y_begin = std::max<int>(0, tile_size * tile_y - margin);
    791      int y_end = std::min<int>(distmap.ysize(),
    792                                tile_size * tile_y + this_tile_ysize + margin);
    793      int x_begin = std::max<int>(0, tile_size * tile_x - margin);
    794      int x_end = std::min<int>(distmap.xsize(),
    795                                tile_size * tile_x + this_tile_xsize + margin);
    796      float dist_norm = 0.0;
    797      double pixels = 0;
    798      for (int y = y_begin; y < y_end; ++y) {
    799        float ymul = 1.0;
    800        constexpr float kBorderMul = 0.98f;
    801        constexpr float kCornerMul = 0.7f;
    802        if (margin != 0 && (y == y_begin || y == y_end - 1)) {
    803          ymul = kBorderMul;
    804        }
    805        const float* const JXL_RESTRICT row = distmap.Row(y);
    806        for (int x = x_begin; x < x_end; ++x) {
    807          float xmul = ymul;
    808          if (margin != 0 && (x == x_begin || x == x_end - 1)) {
    809            if (xmul == 1.0) {
    810              xmul = kBorderMul;
    811            } else {
    812              xmul = kCornerMul;
    813            }
    814          }
    815          float v = row[x];
    816          v *= v;
    817          v *= v;
    818          v *= v;
    819          v *= v;
    820          dist_norm += xmul * v;
    821          pixels += xmul;
    822        }
    823      }
    824      if (pixels == 0) pixels = 1;
    825      // 16th norm is less than the max norm, we reduce the difference
    826      // with this normalization factor.
    827      constexpr float kTileNorm = 1.2f;
    828      const float tile_dist =
    829          kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f);
    830      dist_row[tile_x] = tile_dist;
    831      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    832        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    833          dist_row[tile_x + distmap_stride * iy + ix] = tile_dist;
    834        }
    835      }
    836    }
    837  }
    838  return tile_distmap;
    839 }
    840 
    // Exponent applied to the rescaled Butteraugli target in InitialQuantDC.
    841 const float kDcQuantPow = 0.83f;
    // Numerator of the DC quantizer: InitialQuantDC returns kDcQuant divided by
    // its adjusted DC target (the result is capped at 50).
    842 const float kDcQuant = 1.095924047623553f;
    // Numerator of the AC quantizer scale: InitialQuantField uses
    // kAcQuant / butteraugli_target (times its `rescale` argument).
    843 const float kAcQuant = 0.725f;
    844 
    845 // Computes the decoded image for a given set of compression parameters.
        //
        // Runs the encoder initialization on `opsin` with the current state of
        // `enc_state`, then decodes `enc_state->coeffs` group by group through a
        // temporary PassesDecoderState that shares `enc_state->shared`. Extra
        // channels of the result are filled with zeros.
        //
        // `opsin.ysize()` must be a multiple of kBlockDim. Any special frames that
        // get appended to `enc_state->special_frames` during the call are dropped
        // again before returning.
        //
        // Returns the decoded ImageBundle (same dimensions as `opsin`), or the
        // first error encountered.
    846 StatusOr<ImageBundle> RoundtripImage(const FrameHeader& frame_header,
    847                                     const Image3F& opsin,
    848                                     PassesEncoderState* enc_state,
    849                                     const JxlCmsInterface& cms,
    850                                     ThreadPool* pool) {
    851  JxlMemoryManager* memory_manager = enc_state->memory_manager();
         // Temporary decoder state; it borrows the encoder's shared frame state.
    852  std::unique_ptr<PassesDecoderState> dec_state =
    853      jxl::make_unique<PassesDecoderState>(memory_manager);
    854  JXL_RETURN_IF_ERROR(dec_state->output_encoding_info.SetFromMetadata(
    855      *enc_state->shared.metadata));
    856  dec_state->shared = &enc_state->shared;
    857  JXL_ENSURE(opsin.ysize() % kBlockDim == 0);
    858 
    859  const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
    860  const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
    861  const size_t num_groups = xsize_groups * ysize_groups;
    862 
         // Remembered so that the list can be truncated back at the end.
    863  size_t num_special_frames = enc_state->special_frames.size();
    864  size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
    865  JXL_ASSIGN_OR_RETURN(ModularFrameEncoder modular_frame_encoder,
    866                       ModularFrameEncoder::Create(memory_manager, frame_header,
    867                                                   enc_state->cparams, false));
    868  JXL_RETURN_IF_ERROR(InitializePassesEncoder(frame_header, opsin, Rect(opsin),
    869                                              cms, pool, enc_state,
    870                                              &modular_frame_encoder, nullptr));
    871  JXL_RETURN_IF_ERROR(dec_state->Init(frame_header));
    872  JXL_RETURN_IF_ERROR(dec_state->InitForAC(num_passes, pool));
    873 
         // Output bundle: an opsin-sized image tagged with the decoder's output
         // color encoding.
    874  ImageBundle decoded(memory_manager, &enc_state->shared.metadata->m);
    875  decoded.origin = frame_header.frame_origin;
    876  JXL_ASSIGN_OR_RETURN(
    877      Image3F tmp,
    878      Image3F::Create(memory_manager, opsin.xsize(), opsin.ysize()));
    879  JXL_RETURN_IF_ERROR(decoded.SetFromImage(
    880      std::move(tmp), dec_state->output_encoding_info.color_encoding));
    881 
         // Render pipeline configuration used for the round trip: fast pipeline,
         // with coalescing, spot colors and noise all switched off.
    882  PassesDecoderState::PipelineOptions options;
    883  options.use_slow_render_pipeline = false;
    884  options.coalescing = false;
    885  options.render_spotcolors = false;
    886  options.render_noise = false;
    887 
    888  // Same as frame_header.nonserialized_metadata->m
    889  const ImageMetadata& metadata = *decoded.metadata();
    890 
    891  JXL_RETURN_IF_ERROR(dec_state->PreparePipeline(
    892      frame_header, &enc_state->shared.metadata->m, &decoded, options));
    893 
         // One GroupDecCache per worker thread, allocated by the init callback.
    894  AlignedArray<GroupDecCache> group_dec_caches;
    895  const auto allocate_storage = [&](const size_t num_threads) -> Status {
    896    JXL_RETURN_IF_ERROR(
    897        dec_state->render_pipeline->PrepareForThreads(num_threads,
    898                                                      /*use_group_ids=*/false));
    899    JXL_ASSIGN_OR_RETURN(group_dec_caches, AlignedArray<GroupDecCache>::Create(
    900                                               memory_manager, num_threads));
    901    return true;
    902  };
         // Decodes a single group into the render pipeline's input buffers.
    903  const auto process_group = [&](const uint32_t group_index,
    904                                 const size_t thread) -> Status {
           // ComputeSigma is only needed when the edge-preserving filter runs.
    905    if (frame_header.loop_filter.epf_iters > 0) {
    906      JXL_RETURN_IF_ERROR(
    907          ComputeSigma(frame_header.loop_filter,
    908                       dec_state->shared->frame_dim.BlockGroupRect(group_index),
    909                       dec_state.get()));
    910    }
    911    RenderPipelineInput input =
    912        dec_state->render_pipeline->GetInputBuffers(group_index, thread);
    913    JXL_RETURN_IF_ERROR(DecodeGroupForRoundtrip(
    914        frame_header, enc_state->coeffs, group_index, dec_state.get(),
    915        &group_dec_caches[thread], thread, input, nullptr, nullptr));
           // Buffers 3.. hold the extra channels; they are zero-filled here.
    916    for (size_t c = 0; c < metadata.num_extra_channels; c++) {
    917      std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c);
    918      FillPlane(0.0f, ri.first, ri.second);
    919    }
    920    JXL_RETURN_IF_ERROR(input.Done());
    921    return true;
    922  };
    923  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, allocate_storage,
    924                                process_group, "AQ loop"));
    925 
    926  // Ensure we don't create any new special frames.
    927  enc_state->special_frames.resize(num_special_frames);
    928 
    929  return decoded;
    930 }
    931 
    // Iteration budget of the quantization search loops in this file:
    // FindBestQuantization uses it at the tortoise speed tier (2 otherwise) and
    // FindBestQuantizationMaxError runs kMaxButteraugliIters + 1 rounds.
    932 constexpr int kMaxButteraugliIters = 4;
    933 
    // Iteratively refines `quant_field` against a Butteraugli comparison.
    //
    // Each round installs the current field in the quantizer, round-trips
    // `opsin` through RoundtripImage, compares the result with `linear`, reduces
    // the diffmap to per-block distances (TileDistMap) and raises the quant
    // value of blocks whose distance exceeds
    // cparams.original_butteraugli_distance. Values are kept inside
    // [qf_lower, qf_higher], a range derived from the initial field. The final
    // field is written to the quantizer and `enc_state->shared.raw_quant_field`.
    //
    // Returns early (success, nothing changed) for resampled input unless the
    // distance target is above 4 * resampling. `aux_out` may be null.
    934 Status FindBestQuantization(const FrameHeader& frame_header,
    935                            const Image3F& linear, const Image3F& opsin,
    936                            ImageF& quant_field, PassesEncoderState* enc_state,
    937                            const JxlCmsInterface& cms, ThreadPool* pool,
    938                            AuxOut* aux_out) {
    939  const CompressParams& cparams = enc_state->cparams;
    940  if (cparams.resampling > 1 &&
    941      cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) {
    942    // For downsampled opsin image, the butteraugli based adaptive quantization
    943    // loop would only make the size bigger without improving the distance much,
    944    // so in this case we enable it only for very high butteraugli targets.
    945    return true;
    946  }
    947  JxlMemoryManager* memory_manager = enc_state->memory_manager();
    948  Quantizer& quantizer = enc_state->shared.quantizer;
    949  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;
    950 
    951  const float butteraugli_target = cparams.butteraugli_distance;
    952  const float original_butteraugli = cparams.original_butteraugli_distance;
    953  ButteraugliParams params;
    954  const auto& tf = frame_header.nonserialized_metadata->m.color_encoding.Tf();
         // PQ/HLG content uses the image's own intensity target, anything else 80.
    955  params.intensity_target =
    956      tf.IsPQ() || tf.IsHLG()
    957          ? frame_header.nonserialized_metadata->m.IntensityTarget()
    958          : 80.f;
    959  JxlButteraugliComparator comparator(params, cms);
    960  JXL_RETURN_IF_ERROR(comparator.SetLinearReferenceImage(linear));
    961  bool lower_is_better =
    962      (comparator.GoodQualityScore() < comparator.BadQualityScore());
    963  const float initial_quant_dc = InitialQuantDC(butteraugli_target);
    964  JXL_RETURN_IF_ERROR(AdjustQuantField(enc_state->shared.ac_strategy,
    965                                       Rect(quant_field), original_butteraugli,
    966                                       &quant_field));
    967  ImageF tile_distmap;
         // Snapshot of the starting field; used for the clamp in round
         // kOriginalComparisonRound.
    968  JXL_ASSIGN_OR_RETURN(
    969      ImageF initial_quant_field,
    970      ImageF::Create(memory_manager, quant_field.xsize(), quant_field.ysize()));
    971  JXL_RETURN_IF_ERROR(CopyImageTo(quant_field, &initial_quant_field));
    972 
         // Derive the allowed range [qf_lower, qf_higher] from the initial
         // min/max; the JXL_ENSURE below requires the ratio to stay under 253.
    973  float initial_qf_min;
    974  float initial_qf_max;
    975  ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max);
    976  float initial_qf_ratio = initial_qf_max / initial_qf_min;
    977  float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio);
    978  float asymmetry = 2;
    979  if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low;
    980  float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low);
    981  float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry);
    982 
    983  JXL_ENSURE(qf_higher / qf_lower < 253);
    984 
    985  constexpr int kOriginalComparisonRound = 1;
         // Full iteration budget only at the tortoise speed tier.
    986  int iters = kMaxButteraugliIters;
    987  if (cparams.speed_tier != SpeedTier::kTortoise) {
    988    iters = 2;
    989  }
         // iters + 1 passes: the last one only measures (see the break below).
    990  for (int i = 0; i < iters + 1; ++i) {
    991    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    992      printf("\nQuantization field:\n");
    993      for (size_t y = 0; y < quant_field.ysize(); ++y) {
    994        for (size_t x = 0; x < quant_field.xsize(); ++x) {
    995          printf(" %.5f", quant_field.Row(y)[x]);
    996        }
    997        printf("\n");
    998      }
    999    }
   1000    JXL_RETURN_IF_ERROR(quantizer.SetQuantField(initial_quant_dc, quant_field,
   1001                                                &raw_quant_field));
   1002    JXL_ASSIGN_OR_RETURN(
   1003        ImageBundle dec_linear,
   1004        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
   1005    float score;
   1006    ImageF diffmap;
   1007    JXL_RETURN_IF_ERROR(comparator.CompareWith(dec_linear, &diffmap, &score));
           // Normalize so that smaller values always mean better quality.
   1008    if (!lower_is_better) {
   1009      score = -score;
   1010      ScaleImage(-1.0f, &diffmap);
   1011    }
   1012    JXL_ASSIGN_OR_RETURN(tile_distmap,
   1013                         TileDistMap(diffmap, 8 * cparams.resampling, 0,
   1014                                     enc_state->shared.ac_strategy));
   1015    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) {
   1016      JXL_RETURN_IF_ERROR(DumpImage(cparams, ("dec" + ToString(i)).c_str(),
   1017                                    *dec_linear.color()));
   1018      JXL_RETURN_IF_ERROR(DumpHeatmaps(cparams, aux_out, butteraugli_target,
   1019                                       quant_field, tile_distmap, diffmap));
   1020    }
   1021    if (aux_out != nullptr) ++aux_out->num_butteraugli_iters;
   1022    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
   1023      float minval;
   1024      float maxval;
   1025      ImageMinMax(quant_field, &minval, &maxval);
             // NOTE(review): the total printed here is kMaxButteraugliIters even
             // when `iters` is 2 - confirm whether `iters` was intended.
   1026      printf("\nButteraugli iter: %d/%d\n", i, kMaxButteraugliIters);
   1027      printf("Butteraugli distance: %f  (target = %f)\n", score,
   1028             original_butteraugli);
   1029      printf("quant range: %f ... %f  DC quant: %f\n", minval, maxval,
   1030             initial_quant_dc);
   1031      if (FLAGS_dump_quant_state) {
   1032        quantizer.DumpQuantizationMap(raw_quant_field);
   1033      }
   1034    }
   1035 
           // Final pass: measurement only, the field is not adjusted any more.
   1036    if (i == iters) break;
   1037 
           // Per-round exponent used to lower quant values of blocks that are
           // already better than the target; only rounds 0 and 1 are non-zero.
   1038    double kPow[8] = {
   1039        0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
   1040    };
           // Target-dependent correction of kPow; currently all zeros.
   1041    double kPowMod[8] = {
   1042        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
   1043    };
   1044    if (i == kOriginalComparisonRound) {
   1045      // Don't allow optimization to make the quant field a lot worse than
   1046      // what the initial guess was. This allows the AC field to have enough
   1047      // precision to reduce the oscillations due to the dc reconstruction.
   1048      double kInitMul = 0.6;
   1049      const double kOneMinusInitMul = 1.0 - kInitMul;
   1050      for (size_t y = 0; y < quant_field.ysize(); ++y) {
   1051        float* const JXL_RESTRICT row_q = quant_field.Row(y);
   1052        const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y);
   1053        for (size_t x = 0; x < quant_field.xsize(); ++x) {
                 // Floor: 40% current value + 60% initial value.
   1054          double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x];
   1055          if (row_q[x] < clamp) {
   1056            row_q[x] = clamp;
   1057            if (row_q[x] > qf_higher) row_q[x] = qf_higher;
   1058            if (row_q[x] < qf_lower) row_q[x] = qf_lower;
   1059          }
   1060        }
   1061      }
   1062    }
   1063 
   1064    double cur_pow = 0.0;
   1065    if (i < 7) {
   1066      cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i];
   1067      if (cur_pow < 0) {
   1068        cur_pow = 0;
   1069      }
   1070    }
           // cur_pow == 0: only blocks worse than the target are touched.
   1071    if (cur_pow == 0.0) {
   1072      for (size_t y = 0; y < quant_field.ysize(); ++y) {
   1073        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
   1074        float* const JXL_RESTRICT row_q = quant_field.Row(y);
   1075        for (size_t x = 0; x < quant_field.xsize(); ++x) {
   1076          const float diff = row_dist[x] / original_butteraugli;
   1077          if (diff > 1.0f) {
   1078            float old = row_q[x];
   1079            row_q[x] *= diff;
                   // If the scaled value rounds to the same integer quant step,
                   // force an increase of one quantizer.Scale() instead.
   1080            int qf_old =
   1081                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
   1082            int qf_new = static_cast<int>(
   1083                std::lround(row_q[x] * quantizer.InvGlobalScale()));
   1084            if (qf_old == qf_new) {
   1085              row_q[x] = old + quantizer.Scale();
   1086            }
   1087          }
   1088          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
   1089          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
   1090        }
   1091      }
   1092    } else {
             // cur_pow != 0: same as above, and blocks at or below the target get
             // their quant value multiplied by diff^cur_pow (a factor <= 1).
   1093      for (size_t y = 0; y < quant_field.ysize(); ++y) {
   1094        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
   1095        float* const JXL_RESTRICT row_q = quant_field.Row(y);
   1096        for (size_t x = 0; x < quant_field.xsize(); ++x) {
   1097          const float diff = row_dist[x] / original_butteraugli;
   1098          if (diff <= 1.0f) {
   1099            row_q[x] *= std::pow(diff, cur_pow);
   1100          } else {
   1101            float old = row_q[x];
   1102            row_q[x] *= diff;
   1103            int qf_old =
   1104                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
   1105            int qf_new = static_cast<int>(
   1106                std::lround(row_q[x] * quantizer.InvGlobalScale()));
   1107            if (qf_old == qf_new) {
   1108              row_q[x] = old + quantizer.Scale();
   1109            }
   1110          }
   1111          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
   1112          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
   1113        }
   1114      }
   1115    }
   1116  }
         // Install the final field.
   1117  JXL_RETURN_IF_ERROR(
   1118      quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field));
   1119  return true;
   1120 }
   1121 
   1122 Status FindBestQuantizationMaxError(const FrameHeader& frame_header,
   1123                                    const Image3F& opsin, ImageF& quant_field,
   1124                                    PassesEncoderState* enc_state,
   1125                                    const JxlCmsInterface& cms,
   1126                                    ThreadPool* pool, AuxOut* aux_out) {
   1127  // TODO(szabadka): Make this work for non-opsin color spaces.
   1128  const CompressParams& cparams = enc_state->cparams;
   1129  Quantizer& quantizer = enc_state->shared.quantizer;
   1130  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;
   1131 
   1132  // TODO(veluca): better choice of this value.
   1133  const float initial_quant_dc =
   1134      16 * std::sqrt(0.1f / cparams.butteraugli_distance);
   1135  JXL_RETURN_IF_ERROR(
   1136      AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
   1137                       cparams.original_butteraugli_distance, &quant_field));
   1138 
   1139  const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0],
   1140                                1.0f / enc_state->cparams.max_error[1],
   1141                                1.0f / enc_state->cparams.max_error[2]};
   1142 
   1143  for (int i = 0; i < kMaxButteraugliIters + 1; ++i) {
   1144    JXL_RETURN_IF_ERROR(quantizer.SetQuantField(initial_quant_dc, quant_field,
   1145                                                &raw_quant_field));
   1146    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
   1147      JXL_RETURN_IF_ERROR(
   1148          DumpXybImage(cparams, ("ops" + ToString(i)).c_str(), opsin));
   1149    }
   1150    JXL_ASSIGN_OR_RETURN(
   1151        ImageBundle decoded,
   1152        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
   1153    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
   1154      JXL_RETURN_IF_ERROR(DumpXybImage(cparams, ("dec" + ToString(i)).c_str(),
   1155                                       *decoded.color()));
   1156    }
   1157    for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) {
   1158      AcStrategyRow ac_strategy_row =
   1159          enc_state->shared.ac_strategy.ConstRow(by);
   1160      for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) {
   1161        AcStrategy acs = ac_strategy_row[bx];
   1162        if (!acs.IsFirstBlock()) continue;
   1163        float max_error = 0;
   1164        for (size_t c = 0; c < 3; c++) {
   1165          for (size_t y = by * kBlockDim;
   1166               y < (by + acs.covered_blocks_y()) * kBlockDim; y++) {
   1167            if (y >= decoded.ysize()) continue;
   1168            const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y);
   1169            const float* JXL_RESTRICT dec_row =
   1170                decoded.color()->ConstPlaneRow(c, y);
   1171            for (size_t x = bx * kBlockDim;
   1172                 x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) {
   1173              if (x >= decoded.xsize()) continue;
   1174              max_error = std::max(
   1175                  std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error);
   1176            }
   1177          }
   1178        }
   1179        // Target an error between max_error/2 and max_error.
   1180        // If the error in the varblock is above the target, increase the qf to
   1181        // compensate. If the error is below the target, decrease the qf.
   1182        // However, to avoid an excessive increase of the qf, only do so if the
   1183        // error is less than half the maximum allowed error.
   1184        const float qf_mul = (max_error < 0.5f)   ? max_error * 2.0f
   1185                             : (max_error > 1.0f) ? max_error
   1186                                                  : 1.0f;
   1187        for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) {
   1188          float* JXL_RESTRICT quant_field_row = quant_field.Row(qy);
   1189          for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) {
   1190            quant_field_row[qx] *= qf_mul;
   1191          }
   1192        }
   1193      }
   1194    }
   1195  }
   1196  JXL_RETURN_IF_ERROR(
   1197      quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field));
   1198  return true;
   1199 }
   1200 
   1201 }  // namespace
   1202 
   1203 Status AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
   1204                        float butteraugli_target, ImageF* quant_field) {
   1205  // Replace the whole quant_field in non-8x8 blocks with the maximum of each
   1206  // 8x8 block.
   1207  size_t stride = quant_field->PixelsPerRow();
   1208 
   1209  // At low distances it is great to use max, but mean works better
   1210  // at high distances. We interpolate between them for a distance
   1211  // range.
   1212  float mean_max_mixer = 1.0f;
   1213  {
   1214    static const float kLimit = 1.54138f;
   1215    static const float kMul = 0.56391f;
   1216    static const float kMin = 0.0f;
   1217    if (butteraugli_target > kLimit) {
   1218      mean_max_mixer -= (butteraugli_target - kLimit) * kMul;
   1219      if (mean_max_mixer < kMin) {
   1220        mean_max_mixer = kMin;
   1221      }
   1222    }
   1223  }
   1224  for (size_t y = 0; y < rect.ysize(); ++y) {
   1225    AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y);
   1226    float* JXL_RESTRICT quant_row = rect.Row(quant_field, y);
   1227    for (size_t x = 0; x < rect.xsize(); ++x) {
   1228      AcStrategy acs = ac_strategy_row[x];
   1229      if (!acs.IsFirstBlock()) continue;
   1230      JXL_ENSURE(x + acs.covered_blocks_x() <= quant_field->xsize());
   1231      JXL_ENSURE(y + acs.covered_blocks_y() <= quant_field->ysize());
   1232      float max = quant_row[x];
   1233      float mean = 0.0;
   1234      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
   1235        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
   1236          mean += quant_row[x + ix + iy * stride];
   1237          max = std::max(quant_row[x + ix + iy * stride], max);
   1238        }
   1239      }
   1240      mean /= acs.covered_blocks_y() * acs.covered_blocks_x();
   1241      if (acs.covered_blocks_y() * acs.covered_blocks_x() >= 4) {
   1242        max *= mean_max_mixer;
   1243        max += (1.0f - mean_max_mixer) * mean;
   1244      }
   1245      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
   1246        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
   1247          quant_row[x + ix + iy * stride] = max;
   1248        }
   1249      }
   1250    }
   1251  }
   1252  return true;
   1253 }
   1254 
   1255 float InitialQuantDC(float butteraugli_target) {
   1256  const float kDcMul = 0.3;  // Butteraugli target where non-linearity kicks in.
   1257  const float butteraugli_target_dc = std::max<float>(
   1258      0.5f * butteraugli_target,
   1259      std::min<float>(butteraugli_target,
   1260                      kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target,
   1261                                        kDcQuantPow)));
   1262  // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc.
   1263  // The maximum DC value might not be in the kXybRange because of inverse
   1264  // gaborish, so we add some slack to the maximum theoretical quant obtained
   1265  // this way (64).
   1266  return std::min(kDcQuant / butteraugli_target_dc, 50.f);
   1267 }
   1268 
   1269 StatusOr<ImageF> InitialQuantField(const float butteraugli_target,
   1270                                   const Image3F& opsin, const Rect& rect,
   1271                                   ThreadPool* pool, float rescale,
   1272                                   ImageF* mask, ImageF* mask1x1) {
   1273  const float quant_ac = kAcQuant / butteraugli_target;
   1274  return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)(
   1275      butteraugli_target, opsin, rect, quant_ac * rescale, pool, mask, mask1x1);
   1276 }
   1277 
   1278 Status FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear,
   1279                         const Image3F& opsin, ImageF& quant_field,
   1280                         PassesEncoderState* enc_state,
   1281                         const JxlCmsInterface& cms, ThreadPool* pool,
   1282                         AuxOut* aux_out, double rescale) {
   1283  const CompressParams& cparams = enc_state->cparams;
   1284  if (cparams.max_error_mode) {
   1285    JXL_RETURN_IF_ERROR(FindBestQuantizationMaxError(
   1286        frame_header, opsin, quant_field, enc_state, cms, pool, aux_out));
   1287  } else if (linear && cparams.speed_tier <= SpeedTier::kKitten) {
   1288    // Normal encoding to a butteraugli score.
   1289    JXL_RETURN_IF_ERROR(FindBestQuantization(frame_header, *linear, opsin,
   1290                                             quant_field, enc_state, cms, pool,
   1291                                             aux_out));
   1292  }
   1293  return true;
   1294 }
   1295 
   1296 }  // namespace jxl
   1297 #endif  // HWY_ONCE