tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_heuristics.cc (48256B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_heuristics.h"
      7 
      8 #include <jxl/cms_interface.h>
      9 #include <jxl/memory_manager.h>
     10 
     11 #include <algorithm>
     12 #include <cstddef>
     13 #include <cstdint>
     14 #include <cstdlib>
     15 #include <limits>
     16 #include <memory>
     17 #include <numeric>
     18 #include <string>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "lib/jxl/ac_context.h"
     23 #include "lib/jxl/ac_strategy.h"
     24 #include "lib/jxl/base/common.h"
     25 #include "lib/jxl/base/compiler_specific.h"
     26 #include "lib/jxl/base/data_parallel.h"
     27 #include "lib/jxl/base/override.h"
     28 #include "lib/jxl/base/rect.h"
     29 #include "lib/jxl/base/status.h"
     30 #include "lib/jxl/butteraugli/butteraugli.h"
     31 #include "lib/jxl/chroma_from_luma.h"
     32 #include "lib/jxl/coeff_order.h"
     33 #include "lib/jxl/coeff_order_fwd.h"
     34 #include "lib/jxl/common.h"
     35 #include "lib/jxl/dec_cache.h"
     36 #include "lib/jxl/dec_group.h"
     37 #include "lib/jxl/dec_noise.h"
     38 #include "lib/jxl/dec_xyb.h"
     39 #include "lib/jxl/enc_ac_strategy.h"
     40 #include "lib/jxl/enc_adaptive_quantization.h"
     41 #include "lib/jxl/enc_cache.h"
     42 #include "lib/jxl/enc_chroma_from_luma.h"
     43 #include "lib/jxl/enc_gaborish.h"
     44 #include "lib/jxl/enc_modular.h"
     45 #include "lib/jxl/enc_noise.h"
     46 #include "lib/jxl/enc_params.h"
     47 #include "lib/jxl/enc_patch_dictionary.h"
     48 #include "lib/jxl/enc_quant_weights.h"
     49 #include "lib/jxl/enc_splines.h"
     50 #include "lib/jxl/epf.h"
     51 #include "lib/jxl/frame_dimensions.h"
     52 #include "lib/jxl/frame_header.h"
     53 #include "lib/jxl/image.h"
     54 #include "lib/jxl/image_metadata.h"
     55 #include "lib/jxl/image_ops.h"
     56 #include "lib/jxl/memory_manager_internal.h"
     57 #include "lib/jxl/passes_state.h"
     58 #include "lib/jxl/quant_weights.h"
     59 
     60 namespace jxl {
     61 
     62 struct AuxOut;
     63 
     64 void FindBestBlockEntropyModel(const CompressParams& cparams, const ImageI& rqf,
     65                               const AcStrategyImage& ac_strategy,
     66                               BlockCtxMap* block_ctx_map) {
     67  if (cparams.decoding_speed_tier >= 1) {
     68    static constexpr uint8_t kSimpleCtxMap[] = {
     69        // Cluster all blocks together
     70        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  //
     71        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
     72        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
     73    };
     74    static_assert(
     75        3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap,
     76        "Update simple context map");
     77 
     78    auto bcm = *block_ctx_map;
     79    bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap));
     80    bcm.num_ctxs = 2;
     81    bcm.num_dc_ctxs = 1;
     82    return;
     83  }
     84  if (cparams.speed_tier >= SpeedTier::kFalcon) {
     85    return;
     86  }
     87  // No need to change context modeling for small images.
     88  size_t tot = rqf.xsize() * rqf.ysize();
     89  size_t size_for_ctx_model = (1 << 10) * cparams.butteraugli_distance;
     90  if (tot < size_for_ctx_model) return;
     91 
     92  struct OccCounters {
     93    // count the occurrences of each qf value and each strategy type.
     94    OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) {
     95      for (size_t y = 0; y < rqf.ysize(); y++) {
     96        const int32_t* qf_row = rqf.Row(y);
     97        AcStrategyRow acs_row = ac_strategy.ConstRow(y);
     98        for (size_t x = 0; x < rqf.xsize(); x++) {
     99          int ord = kStrategyOrder[acs_row[x].RawStrategy()];
    100          int qf = qf_row[x] - 1;
    101          qf_counts[qf]++;
    102          qf_ord_counts[ord][qf]++;
    103          ord_counts[ord]++;
    104        }
    105      }
    106    }
    107 
    108    size_t qf_counts[256] = {};
    109    size_t qf_ord_counts[kNumOrders][256] = {};
    110    size_t ord_counts[kNumOrders] = {};
    111  };
    112  // The OccCounters struct is too big to allocate on the stack.
    113  std::unique_ptr<OccCounters> counters(new OccCounters(rqf, ac_strategy));
    114 
    115  // Splitting the context model according to the quantization field seems to
    116  // mostly benefit only large images.
    117  size_t size_for_qf_split = (1 << 13) * cparams.butteraugli_distance;
    118  size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2;
    119  std::vector<uint32_t>& qft = block_ctx_map->qf_thresholds;
    120  qft.clear();
    121  // Divide the quant field in up to num_qf_segments segments.
    122  size_t cumsum = 0;
    123  size_t next = 1;
    124  size_t last_cut = 256;
    125  size_t cut = tot * next / num_qf_segments;
    126  for (uint32_t j = 0; j < 256; j++) {
    127    cumsum += counters->qf_counts[j];
    128    if (cumsum > cut) {
    129      if (j != 0) {
    130        qft.push_back(j);
    131      }
    132      last_cut = j;
    133      while (cumsum > cut) {
    134        next++;
    135        cut = tot * next / num_qf_segments;
    136      }
    137    } else if (next > qft.size() + 1) {
    138      if (j - 1 == last_cut && j != 0) {
    139        qft.push_back(j);
    140      }
    141    }
    142  }
    143 
    144  // Count the occurrences of each segment.
    145  std::vector<size_t> counts(kNumOrders * (qft.size() + 1));
    146  size_t qft_pos = 0;
    147  for (size_t j = 0; j < 256; j++) {
    148    if (qft_pos < qft.size() && j == qft[qft_pos]) {
    149      qft_pos++;
    150    }
    151    for (size_t i = 0; i < kNumOrders; i++) {
    152      counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j];
    153    }
    154  }
    155 
    156  // Repeatedly merge the lowest-count pair.
    157  std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders);
    158  std::iota(remap.begin(), remap.end(), 0);
    159  std::vector<uint8_t> clusters(remap);
    160  size_t nb_clusters =
    161      Clamp1(static_cast<int>(tot / size_for_ctx_model / 2), 2, 9);
    162  size_t nb_clusters_chroma =
    163      Clamp1(static_cast<int>(tot / size_for_ctx_model / 3), 1, 5);
    164  // This is O(n^2 log n), but n is small.
    165  while (clusters.size() > nb_clusters) {
    166    std::sort(clusters.begin(), clusters.end(),
    167              [&](int a, int b) { return counts[a] > counts[b]; });
    168    counts[clusters[clusters.size() - 2]] += counts[clusters.back()];
    169    counts[clusters.back()] = 0;
    170    remap[clusters.back()] = clusters[clusters.size() - 2];
    171    clusters.pop_back();
    172  }
    173  for (size_t i = 0; i < remap.size(); i++) {
    174    while (remap[remap[i]] != remap[i]) {
    175      remap[i] = remap[remap[i]];
    176    }
    177  }
    178  // Relabel starting from 0.
    179  std::vector<uint8_t> remap_remap(remap.size(), remap.size());
    180  size_t num = 0;
    181  for (size_t i = 0; i < remap.size(); i++) {
    182    if (remap_remap[remap[i]] == remap.size()) {
    183      remap_remap[remap[i]] = num++;
    184    }
    185    remap[i] = remap_remap[remap[i]];
    186  }
    187  // Write the block context map.
    188  auto& ctx_map = block_ctx_map->ctx_map;
    189  ctx_map = remap;
    190  ctx_map.resize(remap.size() * 3);
    191  // for chroma, only use up to nb_clusters_chroma separate block contexts
    192  // (those for the biggest clusters)
    193  for (size_t i = remap.size(); i < remap.size() * 3; i++) {
    194    ctx_map[i] = num + Clamp1(static_cast<int>(remap[i % remap.size()]), 0,
    195                              static_cast<int>(nb_clusters_chroma) - 1);
    196  }
    197  block_ctx_map->num_ctxs =
    198      *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
    199 }
    200 
    201 namespace {
    202 
    203 Status FindBestDequantMatrices(JxlMemoryManager* memory_manager,
    204                               const CompressParams& cparams,
    205                               ModularFrameEncoder* modular_frame_encoder,
    206                               DequantMatrices* dequant_matrices) {
    207  // TODO(veluca): quant matrices for no-gaborish.
    208  // TODO(veluca): heuristics for in-bitstream quant tables.
    209  *dequant_matrices = DequantMatrices();
    210  if (cparams.max_error_mode || cparams.disable_perceptual_optimizations) {
    211    constexpr float kMSEWeights[3] = {0.001, 0.001, 0.001};
    212    const float* wp = cparams.disable_perceptual_optimizations
    213                          ? kMSEWeights
    214                          : cparams.max_error;
    215    // Set numerators of all quantization matrices to constant values.
    216    float weights[3][1] = {{1.0f / wp[0]}, {1.0f / wp[1]}, {1.0f / wp[2]}};
    217    DctQuantWeightParams dct_params(weights);
    218    std::vector<QuantEncoding> encodings(kNumQuantTables,
    219                                         QuantEncoding::DCT(dct_params));
    220    JXL_RETURN_IF_ERROR(DequantMatricesSetCustom(dequant_matrices, encodings,
    221                                                 modular_frame_encoder));
    222    float dc_weights[3] = {1.0f / wp[0], 1.0f / wp[1], 1.0f / wp[2]};
    223    JXL_RETURN_IF_ERROR(DequantMatricesSetCustomDC(
    224        memory_manager, dequant_matrices, dc_weights));
    225  }
    226  return true;
    227 }
    228 
    229 void StoreMin2(const float v, float& min1, float& min2) {
    230  if (v < min2) {
    231    if (v < min1) {
    232      min2 = min1;
    233      min1 = v;
    234    } else {
    235      min2 = v;
    236    }
    237  }
    238 }
    239 
    240 void CreateMask(const ImageF& image, ImageF& mask) {
    241  for (size_t y = 0; y < image.ysize(); y++) {
    242    const auto* row_n = y > 0 ? image.Row(y - 1) : image.Row(y);
    243    const auto* row_in = image.Row(y);
    244    const auto* row_s = y + 1 < image.ysize() ? image.Row(y + 1) : image.Row(y);
    245    auto* row_out = mask.Row(y);
    246    for (size_t x = 0; x < image.xsize(); x++) {
    247      // Center, west, east, north, south values and their absolute difference
    248      float c = row_in[x];
    249      float w = x > 0 ? row_in[x - 1] : row_in[x];
    250      float e = x + 1 < image.xsize() ? row_in[x + 1] : row_in[x];
    251      float n = row_n[x];
    252      float s = row_s[x];
    253      float dw = std::abs(c - w);
    254      float de = std::abs(c - e);
    255      float dn = std::abs(c - n);
    256      float ds = std::abs(c - s);
    257      float min = std::numeric_limits<float>::max();
    258      float min2 = std::numeric_limits<float>::max();
    259      StoreMin2(dw, min, min2);
    260      StoreMin2(de, min, min2);
    261      StoreMin2(dn, min, min2);
    262      StoreMin2(ds, min, min2);
    263      row_out[x] = min2;
    264    }
    265  }
    266 }
    267 
    268 // Downsamples the image by a factor of 2 with a kernel that's sharper than
    269 // the standard 2x2 box kernel used by DownsampleImage.
    270 // The kernel is optimized against the result of the 2x2 upsampling kernel used
    271 // by the decoder. Ringing is slightly reduced by clamping the values of the
    272 // resulting pixels within certain bounds of a small region in the original
    273 // image.
    274 Status DownsampleImage2_Sharper(const ImageF& input, ImageF* output) {
    275  const int64_t kernelx = 12;
    276  const int64_t kernely = 12;
    277  JxlMemoryManager* memory_manager = input.memory_manager();
    278 
    279  static const float kernel[144] = {
    280      -0.000314256996835, -0.000314256996835, -0.000897597057705,
    281      -0.000562751488849, -0.000176807273646, 0.001864627368902,
    282      0.001864627368902,  -0.000176807273646, -0.000562751488849,
    283      -0.000897597057705, -0.000314256996835, -0.000314256996835,
    284      -0.000314256996835, -0.001527942804748, -0.000121760530512,
    285      0.000191123989093,  0.010193185932466,  0.058637519197110,
    286      0.058637519197110,  0.010193185932466,  0.000191123989093,
    287      -0.000121760530512, -0.001527942804748, -0.000314256996835,
    288      -0.000897597057705, -0.000121760530512, 0.000946363683751,
    289      0.007113577630288,  0.000437956841058,  -0.000372823835211,
    290      -0.000372823835211, 0.000437956841058,  0.007113577630288,
    291      0.000946363683751,  -0.000121760530512, -0.000897597057705,
    292      -0.000562751488849, 0.000191123989093,  0.007113577630288,
    293      0.044592622228814,  0.000222278879007,  -0.162864473015945,
    294      -0.162864473015945, 0.000222278879007,  0.044592622228814,
    295      0.007113577630288,  0.000191123989093,  -0.000562751488849,
    296      -0.000176807273646, 0.010193185932466,  0.000437956841058,
    297      0.000222278879007,  -0.000913092543974, -0.017071696107902,
    298      -0.017071696107902, -0.000913092543974, 0.000222278879007,
    299      0.000437956841058,  0.010193185932466,  -0.000176807273646,
    300      0.001864627368902,  0.058637519197110,  -0.000372823835211,
    301      -0.162864473015945, -0.017071696107902, 0.414660099370354,
    302      0.414660099370354,  -0.017071696107902, -0.162864473015945,
    303      -0.000372823835211, 0.058637519197110,  0.001864627368902,
    304      0.001864627368902,  0.058637519197110,  -0.000372823835211,
    305      -0.162864473015945, -0.017071696107902, 0.414660099370354,
    306      0.414660099370354,  -0.017071696107902, -0.162864473015945,
    307      -0.000372823835211, 0.058637519197110,  0.001864627368902,
    308      -0.000176807273646, 0.010193185932466,  0.000437956841058,
    309      0.000222278879007,  -0.000913092543974, -0.017071696107902,
    310      -0.017071696107902, -0.000913092543974, 0.000222278879007,
    311      0.000437956841058,  0.010193185932466,  -0.000176807273646,
    312      -0.000562751488849, 0.000191123989093,  0.007113577630288,
    313      0.044592622228814,  0.000222278879007,  -0.162864473015945,
    314      -0.162864473015945, 0.000222278879007,  0.044592622228814,
    315      0.007113577630288,  0.000191123989093,  -0.000562751488849,
    316      -0.000897597057705, -0.000121760530512, 0.000946363683751,
    317      0.007113577630288,  0.000437956841058,  -0.000372823835211,
    318      -0.000372823835211, 0.000437956841058,  0.007113577630288,
    319      0.000946363683751,  -0.000121760530512, -0.000897597057705,
    320      -0.000314256996835, -0.001527942804748, -0.000121760530512,
    321      0.000191123989093,  0.010193185932466,  0.058637519197110,
    322      0.058637519197110,  0.010193185932466,  0.000191123989093,
    323      -0.000121760530512, -0.001527942804748, -0.000314256996835,
    324      -0.000314256996835, -0.000314256996835, -0.000897597057705,
    325      -0.000562751488849, -0.000176807273646, 0.001864627368902,
    326      0.001864627368902,  -0.000176807273646, -0.000562751488849,
    327      -0.000897597057705, -0.000314256996835, -0.000314256996835};
    328 
    329  int64_t xsize = input.xsize();
    330  int64_t ysize = input.ysize();
    331 
    332  JXL_ASSIGN_OR_RETURN(ImageF box_downsample,
    333                       ImageF::Create(memory_manager, xsize, ysize));
    334  JXL_RETURN_IF_ERROR(CopyImageTo(input, &box_downsample));
    335  JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
    336 
    337  JXL_ASSIGN_OR_RETURN(ImageF mask,
    338                       ImageF::Create(memory_manager, box_downsample.xsize(),
    339                                      box_downsample.ysize()));
    340  CreateMask(box_downsample, mask);
    341 
    342  for (size_t y = 0; y < output->ysize(); y++) {
    343    float* row_out = output->Row(y);
    344    const float* row_in[kernely];
    345    const float* row_mask = mask.Row(y);
    346    // get the rows in the support
    347    for (size_t ky = 0; ky < kernely; ky++) {
    348      int64_t iy = y * 2 + ky - (kernely - 1) / 2;
    349      if (iy < 0) iy = 0;
    350      if (iy >= ysize) iy = ysize - 1;
    351      row_in[ky] = input.Row(iy);
    352    }
    353 
    354    for (size_t x = 0; x < output->xsize(); x++) {
    355      // get min and max values of the original image in the support
    356      float min = std::numeric_limits<float>::max();
    357      float max = std::numeric_limits<float>::min();
    358      // kernelx - R and kernely - R are the radius of a rectangular region in
    359      // which the values of a pixel are bounded to reduce ringing.
    360      static constexpr int64_t R = 5;
    361      for (int64_t ky = R; ky + R < kernely; ky++) {
    362        for (int64_t kx = R; kx + R < kernelx; kx++) {
    363          int64_t ix = x * 2 + kx - (kernelx - 1) / 2;
    364          if (ix < 0) ix = 0;
    365          if (ix >= xsize) ix = xsize - 1;
    366          min = std::min<float>(min, row_in[ky][ix]);
    367          max = std::max<float>(max, row_in[ky][ix]);
    368        }
    369      }
    370 
    371      float sum = 0;
    372      for (int64_t ky = 0; ky < kernely; ky++) {
    373        for (int64_t kx = 0; kx < kernelx; kx++) {
    374          int64_t ix = x * 2 + kx - (kernelx - 1) / 2;
    375          if (ix < 0) ix = 0;
    376          if (ix >= xsize) ix = xsize - 1;
    377          sum += row_in[ky][ix] * kernel[ky * kernelx + kx];
    378        }
    379      }
    380 
    381      row_out[x] = sum;
    382 
    383      // Clamp the pixel within the value  of a small area to prevent ringning.
    384      // The mask determines how much to clamp, clamp more to reduce more
    385      // ringing in smooth areas, clamp less in noisy areas to get more
    386      // sharpness. Higher mask_multiplier gives less clamping, so less
    387      // ringing reduction.
    388      const constexpr float mask_multiplier = 1;
    389      float a = row_mask[x] * mask_multiplier;
    390      float clip_min = min - a;
    391      float clip_max = max + a;
    392      if (row_out[x] < clip_min) {
    393        row_out[x] = clip_min;
    394      } else if (row_out[x] > clip_max) {
    395        row_out[x] = clip_max;
    396      }
    397    }
    398  }
    399  return true;
    400 }
    401 
    402 }  // namespace
    403 
    404 Status DownsampleImage2_Sharper(Image3F* opsin) {
    405  // Allocate extra space to avoid a reallocation when padding.
    406  JxlMemoryManager* memory_manager = opsin->memory_manager();
    407  JXL_ASSIGN_OR_RETURN(
    408      Image3F downsampled,
    409      Image3F::Create(memory_manager, DivCeil(opsin->xsize(), 2) + kBlockDim,
    410                      DivCeil(opsin->ysize(), 2) + kBlockDim));
    411  JXL_RETURN_IF_ERROR(downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
    412                                           downsampled.ysize() - kBlockDim));
    413 
    414  for (size_t c = 0; c < 3; c++) {
    415    JXL_RETURN_IF_ERROR(
    416        DownsampleImage2_Sharper(opsin->Plane(c), &downsampled.Plane(c)));
    417  }
    418  *opsin = std::move(downsampled);
    419  return true;
    420 }
    421 
    422 namespace {
    423 
// The default upsampling kernels used by Upsampler in the decoder.
// Each kernel holds kSize x kSize weights stored row-major
// (index ky * kSize + kx).
const constexpr int64_t kSize = 5;

// Kernel for output pixels with even x and even y. UpsampleImage and
// UpsamplerDeriv select kernelXY by the parities (x & 1, y & 1).
const float kernel00[25] = {
    -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
    -0.03452303f, 0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f,
    -0.04022174f, 0.28896755f,  0.56661550f,  0.03777607f,  -0.01986694f,
    -0.02921014f, 0.00278718f,  0.03777607f,  -0.03144731f, -0.01185068f,
    -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f,
};
// kernel00 with its rows in reverse order (vertical mirror); used for even x,
// odd y.
const float kernel01[25] = {
    -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f,
    -0.02921014f, 0.00278718f,  0.03777607f,  -0.03144731f, -0.01185068f,
    -0.04022174f, 0.28896755f,  0.56661550f,  0.03777607f,  -0.01986694f,
    -0.03452303f, 0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f,
    -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
};
// kernel00 with each row reversed (horizontal mirror); used for odd x, even y.
const float kernel10[25] = {
    -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f,
    -0.01610267f, 0.00278718f,  0.28896755f,  0.14111091f,  -0.03452303f,
    -0.01986694f, 0.03777607f,  0.56661550f,  0.28896755f,  -0.04022174f,
    -0.01185068f, -0.03144731f, 0.03777607f,  0.00278718f,  -0.02921014f,
    -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f,
};
// kernel00 mirrored both horizontally and vertically; used for odd x, odd y.
const float kernel11[25] = {
    -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f,
    -0.01185068f, -0.03144731f, 0.03777607f,  0.00278718f,  -0.02921014f,
    -0.01986694f, 0.03777607f,  0.56661550f,  0.28896755f,  -0.04022174f,
    -0.01610267f, 0.00278718f,  0.28896755f,  0.14111091f,  -0.03452303f,
    -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f,
};
    455 
    456 // Does exactly the same as the Upsampler in dec_upsampler for 2x2 pixels, with
    457 // default CustomTransformData.
    458 // TODO(lode): use Upsampler instead. However, it requires pre-initialization
    459 // and padding on the left side of the image which requires refactoring the
    460 // other code using this.
    461 void UpsampleImage(const ImageF& input, ImageF* output) {
    462  int64_t xsize = input.xsize();
    463  int64_t ysize = input.ysize();
    464  int64_t xsize2 = output->xsize();
    465  int64_t ysize2 = output->ysize();
    466  for (int64_t y = 0; y < ysize2; y++) {
    467    for (int64_t x = 0; x < xsize2; x++) {
    468      const auto* kernel = kernel00;
    469      if ((x & 1) && (y & 1)) {
    470        kernel = kernel11;
    471      } else if (x & 1) {
    472        kernel = kernel10;
    473      } else if (y & 1) {
    474        kernel = kernel01;
    475      }
    476      float sum = 0;
    477      int64_t x2 = x / 2;
    478      int64_t y2 = y / 2;
    479 
    480      // get min and max values of the original image in the support
    481      float min = std::numeric_limits<float>::max();
    482      float max = std::numeric_limits<float>::min();
    483 
    484      for (int64_t ky = 0; ky < kSize; ky++) {
    485        for (int64_t kx = 0; kx < kSize; kx++) {
    486          int64_t xi = x2 - kSize / 2 + kx;
    487          int64_t yi = y2 - kSize / 2 + ky;
    488          if (xi < 0) xi = 0;
    489          if (xi >= xsize) xi = input.xsize() - 1;
    490          if (yi < 0) yi = 0;
    491          if (yi >= ysize) yi = input.ysize() - 1;
    492          min = std::min<float>(min, input.Row(yi)[xi]);
    493          max = std::max<float>(max, input.Row(yi)[xi]);
    494        }
    495      }
    496 
    497      for (int64_t ky = 0; ky < kSize; ky++) {
    498        for (int64_t kx = 0; kx < kSize; kx++) {
    499          int64_t xi = x2 - kSize / 2 + kx;
    500          int64_t yi = y2 - kSize / 2 + ky;
    501          if (xi < 0) xi = 0;
    502          if (xi >= xsize) xi = input.xsize() - 1;
    503          if (yi < 0) yi = 0;
    504          if (yi >= ysize) yi = input.ysize() - 1;
    505          sum += input.Row(yi)[xi] * kernel[ky * kSize + kx];
    506        }
    507      }
    508      output->Row(y)[x] = sum;
    509      if (output->Row(y)[x] < min) output->Row(y)[x] = min;
    510      if (output->Row(y)[x] > max) output->Row(y)[x] = max;
    511    }
    512  }
    513 }
    514 
    515 // Returns the derivative of Upsampler, with respect to input pixel x2, y2, to
    516 // output pixel x, y (ignoring the clamping).
    517 float UpsamplerDeriv(int64_t x2, int64_t y2, int64_t x, int64_t y) {
    518  const auto* kernel = kernel00;
    519  if ((x & 1) && (y & 1)) {
    520    kernel = kernel11;
    521  } else if (x & 1) {
    522    kernel = kernel10;
    523  } else if (y & 1) {
    524    kernel = kernel01;
    525  }
    526 
    527  int64_t ix = x / 2;
    528  int64_t iy = y / 2;
    529  int64_t kx = x2 - ix + kSize / 2;
    530  int64_t ky = y2 - iy + kSize / 2;
    531 
    532  // This should not happen.
    533  if (kx < 0 || kx >= kSize || ky < 0 || ky >= kSize) return 0;
    534 
    535  return kernel[ky * kSize + kx];
    536 }
    537 
    538 // Apply the derivative of the Upsampler to the input, reversing the effect of
    539 // its coefficients. The output image is 2x2 times smaller than the input.
    540 void AntiUpsample(const ImageF& input, ImageF* d) {
    541  int64_t xsize = input.xsize();
    542  int64_t ysize = input.ysize();
    543  int64_t xsize2 = d->xsize();
    544  int64_t ysize2 = d->ysize();
    545  int64_t k0 = kSize - 1;
    546  int64_t k1 = kSize;
    547  for (int64_t y2 = 0; y2 < ysize2; ++y2) {
    548    auto* row = d->Row(y2);
    549    for (int64_t x2 = 0; x2 < xsize2; ++x2) {
    550      int64_t x0 = x2 * 2 - k0;
    551      if (x0 < 0) x0 = 0;
    552      int64_t x1 = x2 * 2 + k1 + 1;
    553      if (x1 > xsize) x1 = xsize;
    554      int64_t y0 = y2 * 2 - k0;
    555      if (y0 < 0) y0 = 0;
    556      int64_t y1 = y2 * 2 + k1 + 1;
    557      if (y1 > ysize) y1 = ysize;
    558 
    559      float sum = 0;
    560      for (int64_t y = y0; y < y1; ++y) {
    561        const auto* row_in = input.Row(y);
    562        for (int64_t x = x0; x < x1; ++x) {
    563          double deriv = UpsamplerDeriv(x2, y2, x, y);
    564          sum += deriv * row_in[x];
    565        }
    566      }
    567      row[x2] = sum;
    568    }
    569  }
    570 }
    571 
    572 // Element-wise multiplies two images.
    573 template <typename T>
    574 Status ElwiseMul(const Plane<T>& image1, const Plane<T>& image2,
    575                 Plane<T>* out) {
    576  const size_t xsize = image1.xsize();
    577  const size_t ysize = image1.ysize();
    578  JXL_ENSURE(xsize == image2.xsize());
    579  JXL_ENSURE(ysize == image2.ysize());
    580  JXL_ENSURE(xsize == out->xsize());
    581  JXL_ENSURE(ysize == out->ysize());
    582  for (size_t y = 0; y < ysize; ++y) {
    583    const T* const JXL_RESTRICT row1 = image1.Row(y);
    584    const T* const JXL_RESTRICT row2 = image2.Row(y);
    585    T* const JXL_RESTRICT row_out = out->Row(y);
    586    for (size_t x = 0; x < xsize; ++x) {
    587      row_out[x] = row1[x] * row2[x];
    588    }
    589  }
    590  return true;
    591 }
    592 
    593 // Element-wise divides two images.
    594 template <typename T>
    595 Status ElwiseDiv(const Plane<T>& image1, const Plane<T>& image2,
    596                 Plane<T>* out) {
    597  const size_t xsize = image1.xsize();
    598  const size_t ysize = image1.ysize();
    599  JXL_ENSURE(xsize == image2.xsize());
    600  JXL_ENSURE(ysize == image2.ysize());
    601  JXL_ENSURE(xsize == out->xsize());
    602  JXL_ENSURE(ysize == out->ysize());
    603  for (size_t y = 0; y < ysize; ++y) {
    604    const T* const JXL_RESTRICT row1 = image1.Row(y);
    605    const T* const JXL_RESTRICT row2 = image2.Row(y);
    606    T* const JXL_RESTRICT row_out = out->Row(y);
    607    for (size_t x = 0; x < xsize; ++x) {
    608      row_out[x] = row1[x] / row2[x];
    609    }
    610  }
    611  return true;
    612 }
    613 
    614 void ReduceRinging(const ImageF& initial, const ImageF& mask, ImageF& down) {
    615  int64_t xsize2 = down.xsize();
    616  int64_t ysize2 = down.ysize();
    617 
    618  for (size_t y = 0; y < down.ysize(); y++) {
    619    const float* row_mask = mask.Row(y);
    620    float* row_out = down.Row(y);
    621    for (size_t x = 0; x < down.xsize(); x++) {
    622      float v = down.Row(y)[x];
    623      float min = initial.Row(y)[x];
    624      float max = initial.Row(y)[x];
    625      for (int64_t yi = -1; yi < 2; yi++) {
    626        for (int64_t xi = -1; xi < 2; xi++) {
    627          int64_t x2 = static_cast<int64_t>(x) + xi;
    628          int64_t y2 = static_cast<int64_t>(y) + yi;
    629          if (x2 < 0 || y2 < 0 || x2 >= xsize2 || y2 >= ysize2) continue;
    630          min = std::min<float>(min, initial.Row(y2)[x2]);
    631          max = std::max<float>(max, initial.Row(y2)[x2]);
    632        }
    633      }
    634 
    635      row_out[x] = v;
    636 
    637      // Clamp the pixel within the value  of a small area to prevent ringning.
    638      // The mask determines how much to clamp, clamp more to reduce more
    639      // ringing in smooth areas, clamp less in noisy areas to get more
    640      // sharpness. Higher mask_multiplier gives less clamping, so less
    641      // ringing reduction.
    642      const constexpr float mask_multiplier = 2;
    643      float a = row_mask[x] * mask_multiplier;
    644      float clip_min = min - a;
    645      float clip_max = max + a;
    646      if (row_out[x] < clip_min) row_out[x] = clip_min;
    647      if (row_out[x] > clip_max) row_out[x] = clip_max;
    648    }
    649  }
    650 }
    651 
    652 // TODO(lode): move this to a separate file enc_downsample.cc
    653 Status DownsampleImage2_Iterative(const ImageF& orig, ImageF* output) {
    654  int64_t xsize = orig.xsize();
    655  int64_t ysize = orig.ysize();
    656  int64_t xsize2 = DivCeil(orig.xsize(), 2);
    657  int64_t ysize2 = DivCeil(orig.ysize(), 2);
    658  JxlMemoryManager* memory_manager = orig.memory_manager();
    659 
    660  JXL_ASSIGN_OR_RETURN(ImageF box_downsample,
    661                       ImageF::Create(memory_manager, xsize, ysize));
    662  JXL_RETURN_IF_ERROR(CopyImageTo(orig, &box_downsample));
    663  JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
    664  JXL_ASSIGN_OR_RETURN(ImageF mask,
    665                       ImageF::Create(memory_manager, box_downsample.xsize(),
    666                                      box_downsample.ysize()));
    667  CreateMask(box_downsample, mask);
    668 
    669  JXL_RETURN_IF_ERROR(output->ShrinkTo(xsize2, ysize2));
    670 
    671  // Initial result image using the sharper downsampling.
    672  // Allocate extra space to avoid a reallocation when padding.
    673  JXL_ASSIGN_OR_RETURN(
    674      ImageF initial,
    675      ImageF::Create(memory_manager, DivCeil(orig.xsize(), 2) + kBlockDim,
    676                     DivCeil(orig.ysize(), 2) + kBlockDim));
    677  JXL_RETURN_IF_ERROR(initial.ShrinkTo(initial.xsize() - kBlockDim,
    678                                       initial.ysize() - kBlockDim));
    679  JXL_RETURN_IF_ERROR(DownsampleImage2_Sharper(orig, &initial));
    680 
    681  JXL_ASSIGN_OR_RETURN(
    682      ImageF down,
    683      ImageF::Create(memory_manager, initial.xsize(), initial.ysize()));
    684  JXL_RETURN_IF_ERROR(CopyImageTo(initial, &down));
    685  JXL_ASSIGN_OR_RETURN(ImageF up, ImageF::Create(memory_manager, xsize, ysize));
    686  JXL_ASSIGN_OR_RETURN(ImageF corr,
    687                       ImageF::Create(memory_manager, xsize, ysize));
    688  JXL_ASSIGN_OR_RETURN(ImageF corr2,
    689                       ImageF::Create(memory_manager, xsize2, ysize2));
    690 
    691  // In the weights map, relatively higher values will allow less ringing but
    692  // also less sharpness. With all constant values, it optimizes equally
    693  // everywhere. Even in this case, the weights2 computed from
    694  // this is still used and differs at the borders of the image.
    695  // TODO(lode): Make use of the weights field for anti-ringing and clamping,
    696  // the values are all set to 1 for now, but it is intended to be used for
    697  // reducing ringing based on the mask, and taking clamping into account.
    698  JXL_ASSIGN_OR_RETURN(ImageF weights,
    699                       ImageF::Create(memory_manager, xsize, ysize));
    700  for (size_t y = 0; y < weights.ysize(); y++) {
    701    auto* row = weights.Row(y);
    702    for (size_t x = 0; x < weights.xsize(); x++) {
    703      row[x] = 1;
    704    }
    705  }
    706  JXL_ASSIGN_OR_RETURN(ImageF weights2,
    707                       ImageF::Create(memory_manager, xsize2, ysize2));
    708  AntiUpsample(weights, &weights2);
    709 
    710  const size_t num_it = 3;
    711  for (size_t it = 0; it < num_it; ++it) {
    712    UpsampleImage(down, &up);
    713    JXL_ASSIGN_OR_RETURN(corr, LinComb<float>(1, orig, -1, up));
    714    JXL_RETURN_IF_ERROR(ElwiseMul(corr, weights, &corr));
    715    AntiUpsample(corr, &corr2);
    716    JXL_RETURN_IF_ERROR(ElwiseDiv(corr2, weights2, &corr2));
    717 
    718    JXL_ASSIGN_OR_RETURN(down, LinComb<float>(1, down, 1, corr2));
    719  }
    720 
    721  ReduceRinging(initial, mask, down);
    722 
    723  // can't just use CopyImage, because the output image was prepared with
    724  // padding.
    725  for (size_t y = 0; y < down.ysize(); y++) {
    726    for (size_t x = 0; x < down.xsize(); x++) {
    727      float v = down.Row(y)[x];
    728      output->Row(y)[x] = v;
    729    }
    730  }
    731  return true;
    732 }
    733 
    734 }  // namespace
    735 
    736 Status DownsampleImage2_Iterative(Image3F* opsin) {
    737  JxlMemoryManager* memory_manager = opsin->memory_manager();
    738  // Allocate extra space to avoid a reallocation when padding.
    739  JXL_ASSIGN_OR_RETURN(
    740      Image3F downsampled,
    741      Image3F::Create(memory_manager, DivCeil(opsin->xsize(), 2) + kBlockDim,
    742                      DivCeil(opsin->ysize(), 2) + kBlockDim));
    743  JXL_RETURN_IF_ERROR(downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
    744                                           downsampled.ysize() - kBlockDim));
    745 
    746  JXL_ASSIGN_OR_RETURN(
    747      Image3F rgb,
    748      Image3F::Create(memory_manager, opsin->xsize(), opsin->ysize()));
    749  OpsinParams opsin_params;  // TODO(user): use the ones that are actually used
    750  opsin_params.Init(kDefaultIntensityTarget);
    751  JXL_RETURN_IF_ERROR(
    752      OpsinToLinear(*opsin, Rect(rgb), nullptr, &rgb, opsin_params));
    753 
    754  JXL_ASSIGN_OR_RETURN(
    755      ImageF mask,
    756      ImageF::Create(memory_manager, opsin->xsize(), opsin->ysize()));
    757  ButteraugliParams butter_params;
    758  JXL_ASSIGN_OR_RETURN(std::unique_ptr<ButteraugliComparator> butter,
    759                       ButteraugliComparator::Make(rgb, butter_params));
    760  JXL_RETURN_IF_ERROR(butter->Mask(&mask));
    761  JXL_ASSIGN_OR_RETURN(
    762      ImageF mask_fuzzy,
    763      ImageF::Create(memory_manager, opsin->xsize(), opsin->ysize()));
    764 
    765  for (size_t c = 0; c < 3; c++) {
    766    JXL_RETURN_IF_ERROR(
    767        DownsampleImage2_Iterative(opsin->Plane(c), &downsampled.Plane(c)));
    768  }
    769  *opsin = std::move(downsampled);
    770  return true;
    771 }
    772 
// Renders the frame described by `orig_frame_header`, `shared` and the
// quantized AC coefficients `coeffs` back into a 3-channel image by running
// the decoder state / render pipeline (dec_state.PreparePipeline) over every
// group, in parallel on `pool`.
//
// The header is adjusted for this roundtrip: the kPatches / kSplines flags
// follow what `shared.image_features` actually contains, the color transform
// is forced to kNone, and all extra channels are dropped. Linear sRGB (gray
// when the image metadata is gray) is requested as the output encoding via
// MaybeSetColorEncoding.
//
// Returns the decoded color image, or an error Status if any stage fails.
StatusOr<Image3F> ReconstructImage(
    const FrameHeader& orig_frame_header, const PassesSharedState& shared,
    const std::vector<std::unique_ptr<ACImage>>& coeffs, ThreadPool* pool) {
  const FrameDimensions& frame_dim = shared.frame_dim;
  JxlMemoryManager* memory_manager = shared.memory_manager;

  // Work on a copy so the caller's header is not modified.
  FrameHeader frame_header = orig_frame_header;
  frame_header.UpdateFlag(shared.image_features.patches.HasAny(),
                          FrameHeader::kPatches);
  frame_header.UpdateFlag(shared.image_features.splines.HasAny(),
                          FrameHeader::kSplines);
  frame_header.color_transform = ColorTransform::kNone;

  // Local metadata copy with no extra channels. `frame_header` now points at
  // this local, so neither may be used after this function returns.
  CodecMetadata metadata = *frame_header.nonserialized_metadata;
  metadata.m.extra_channel_info.clear();
  metadata.m.num_extra_channels = metadata.m.extra_channel_info.size();
  frame_header.nonserialized_metadata = &metadata;
  frame_header.extra_channel_upsampling.clear();

  const bool is_gray = shared.metadata->m.color_encoding.IsGray();
  PassesDecoderState dec_state(memory_manager);
  JXL_RETURN_IF_ERROR(
      dec_state.output_encoding_info.SetFromMetadata(*shared.metadata));
  JXL_RETURN_IF_ERROR(dec_state.output_encoding_info.MaybeSetColorEncoding(
      ColorEncoding::LinearSRGB(is_gray)));
  dec_state.shared = &shared;
  JXL_RETURN_IF_ERROR(dec_state.Init(frame_header));

  // Destination image bundle, sized to the full frame.
  ImageBundle decoded(memory_manager, &shared.metadata->m);
  decoded.origin = frame_header.frame_origin;
  JXL_ASSIGN_OR_RETURN(
      Image3F tmp,
      Image3F::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
  JXL_RETURN_IF_ERROR(decoded.SetFromImage(
      std::move(tmp), dec_state.output_encoding_info.color_encoding));

  PassesDecoderState::PipelineOptions options;
  options.use_slow_render_pipeline = false;
  options.coalescing = false;
  options.render_spotcolors = false;
  options.render_noise = true;

  JXL_RETURN_IF_ERROR(dec_state.PreparePipeline(
      frame_header, &shared.metadata->m, &decoded, options));

  // Per-thread decode caches; sized by `allocate_storage` once RunOnPool knows
  // the thread count.
  AlignedArray<GroupDecCache> group_dec_caches;
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(
        dec_state.render_pipeline->PrepareForThreads(num_threads,
                                                     /*use_group_ids=*/false));
    JXL_ASSIGN_OR_RETURN(group_dec_caches, AlignedArray<GroupDecCache>::Create(
                                               memory_manager, num_threads));
    return true;
  };
  // Decodes one group: EPF sigma (if EPF is enabled), coefficients, and noise
  // input (if the kNoise flag is set), then hands the buffers to the pipeline.
  const auto process_group = [&](const uint32_t group_index,
                                 const size_t thread) -> Status {
    if (frame_header.loop_filter.epf_iters > 0) {
      JXL_RETURN_IF_ERROR(ComputeSigma(frame_header.loop_filter,
                                       frame_dim.BlockGroupRect(group_index),
                                       &dec_state));
    }
    RenderPipelineInput input =
        dec_state.render_pipeline->GetInputBuffers(group_index, thread);
    JXL_RETURN_IF_ERROR(DecodeGroupForRoundtrip(
        frame_header, coeffs, group_index, &dec_state,
        &group_dec_caches[thread], thread, input, nullptr, nullptr));
    if ((frame_header.flags & FrameHeader::kNoise) != 0) {
      PrepareNoiseInput(dec_state, shared.frame_dim, frame_header, group_index,
                        thread);
    }
    JXL_RETURN_IF_ERROR(input.Done());
    return true;
  };
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_groups, allocate_storage,
                                process_group, "ReconstructImage"));
  return std::move(*decoded.color());
}
    850 
    851 float ComputeBlockL2Distance(const Image3F& a, const Image3F& b,
    852                             const ImageF& mask1x1, size_t by, size_t bx) {
    853  Rect rect(bx * kBlockDim, by * kBlockDim, kBlockDim, kBlockDim, a.xsize(),
    854            a.ysize());
    855  float err2[3] = {0.0f};
    856  for (size_t y = 0; y < rect.ysize(); ++y) {
    857    const float* row_a[3] = {
    858        rect.ConstPlaneRow(a, 0, y),
    859        rect.ConstPlaneRow(a, 1, y),
    860        rect.ConstPlaneRow(a, 2, y),
    861    };
    862    const float* row_b[3] = {
    863        rect.ConstPlaneRow(b, 0, y),
    864        rect.ConstPlaneRow(b, 1, y),
    865        rect.ConstPlaneRow(b, 2, y),
    866    };
    867    const float* row_mask = rect.ConstRow(mask1x1, y);
    868    for (size_t x = 0; x < rect.xsize(); ++x) {
    869      float mask = row_mask[x];
    870      float mask2 = mask * mask;
    871      for (int i = 0; i < 3; ++i) {
    872        float diff = row_a[i][x] - row_b[i][x];
    873        err2[i] += mask2 * diff * diff;
    874      }
    875    }
    876  }
    877  static const double kW[] = {
    878      12.339445295782363,
    879      1.0,
    880      0.2,
    881  };
    882  float retval = kW[0] * err2[0] + kW[1] * err2[1] + kW[2] * err2[2];
    883  return retval;
    884 }
    885 
// Chooses a per-block EPF sharpness value and writes it into
// `enc_state->shared.epf_sharpness`.
//
// When dynamic selection is not worthwhile, the whole field is filled with 4
// and the function returns early. That happens when the distance is below
// kMinButteraugliForDynamicAR, the speed tier is faster than Wombat, or EPF is
// disabled in `frame_header`.
//
// Otherwise a small candidate set is evaluated: {0, 4} for distance > 4.5,
// else {0, 2, 7}. For each candidate, the field is filled uniformly, the frame
// is reconstructed with ReconstructImage, and a per-block error against
// `orig_opsin` is measured with ComputeBlockL2Distance. The selection then
// runs in two passes:
//   1. A greedy pass that prefers to copy the top/left neighbour's value unless
//      the best candidate beats the neighbours' error by a distance-dependent
//      margin, while building per-context histograms of the chosen values.
//   2. A final pass that re-chooses every block, dividing each candidate's
//      error by a term that grows with how often that value was chosen in the
//      same (top, left) context during pass 1.
// NOTE(review): `rect` is not used by this function.
// Returns an error Status if reconstruction or allocation fails.
Status ComputeARHeuristics(const FrameHeader& frame_header,
                           PassesEncoderState* enc_state,
                           const Image3F& orig_opsin, const Rect& rect,
                           ThreadPool* pool) {
  const CompressParams& cparams = enc_state->cparams;
  PassesSharedState& shared = enc_state->shared;
  const FrameDimensions& frame_dim = shared.frame_dim;
  const ImageF& initial_quant_masking1x1 = enc_state->initial_quant_masking1x1;
  ImageB& epf_sharpness = shared.epf_sharpness;
  JxlMemoryManager* memory_manager = enc_state->memory_manager();

  // Distance used by the tuned constants below, capped at 5.
  float clamped_butteraugli = std::min(5.0f, cparams.butteraugli_distance);
  if (cparams.butteraugli_distance < kMinButteraugliForDynamicAR ||
      cparams.speed_tier > SpeedTier::kWombat ||
      frame_header.loop_filter.epf_iters == 0) {
    FillPlane(static_cast<uint8_t>(4), &epf_sharpness, Rect(epf_sharpness));
    return true;
  }

  // Candidate sharpness values to try (each one costs a full reconstruction).
  std::vector<uint8_t> epf_steps;
  if (cparams.butteraugli_distance > 4.5f) {
    epf_steps.push_back(0);
    epf_steps.push_back(4);
  } else {
    epf_steps.push_back(0);
    epf_steps.push_back(2);
    epf_steps.push_back(7);
  }
  // Maps a sharpness value to its index in `epf_steps` (0 for values that are
  // not candidates). With at most 3 candidates the index is <= 2, so the
  // context `lut[top] * 3 + lut[left]` below is < 9.
  static const int kNumEPFVals = 8;
  size_t epf_steps_lut[kNumEPFVals] = {0};
  {
    for (size_t i = 0; i < epf_steps.size(); ++i) {
      epf_steps_lut[epf_steps[i]] = i;
    }
  }
  // error_images[v] holds, per block, the error when the whole frame uses
  // sharpness v; only entries for values in `epf_steps` are allocated.
  std::array<ImageF, kNumEPFVals> error_images;
  for (uint8_t val : epf_steps) {
    FillPlane(val, &epf_sharpness, Rect(epf_sharpness));
    JXL_ASSIGN_OR_RETURN(
        Image3F decoded,
        ReconstructImage(frame_header, shared, enc_state->coeffs, pool));
    JXL_ASSIGN_OR_RETURN(error_images[val],
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
                                        frame_dim.ysize_blocks));
    for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
      float* error_row = error_images[val].Row(by);
      for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
        error_row[bx] = ComputeBlockL2Distance(
            orig_opsin, decoded, initial_quant_masking1x1, by, bx);
      }
    }
  }
  // Pass 1: greedy choice with neighbour inheritance. histo[context][v] counts
  // how often v was chosen in each (top, left) context. Totals start at 1 so
  // the pass-2 division never divides by zero.
  std::vector<std::vector<size_t>> histo(9, std::vector<size_t>(kNumEPFVals));
  std::vector<size_t> totals(9, 1);
  const float c5 = 0.007620386618483585f;
  const float c6 = 0.0083224805679680686f;
  const float c7 = 0.99663939685686753;
  for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
    uint8_t* JXL_RESTRICT out_row = epf_sharpness.Row(by);
    uint8_t* JXL_RESTRICT prev_row = epf_sharpness.Row(by > 0 ? by - 1 : 0);
    for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
      uint8_t best_val = 0;
      float best_error = std::numeric_limits<float>::max();
      // Missing neighbours at the top/left edge are treated as value 0.
      // Otherwise the neighbour values were written earlier in this pass and
      // are therefore members of `epf_steps`.
      uint8_t top_val = by > 0 ? prev_row[bx] : 0;
      uint8_t left_val = bx > 0 ? out_row[bx - 1] : 0;
      float top_error = error_images[top_val].Row(by)[bx];
      float left_error = error_images[left_val].Row(by)[bx];
      for (uint8_t val : epf_steps) {
        float error = error_images[val].Row(by)[bx];
        if (val == 0) {
          // Distance-dependent scale applied only to value 0; it biases the
          // choice toward 0.
          error *= c7 - c5 * clamped_butteraugli;
        }
        if (error < best_error) {
          best_val = val;
          best_error = error;
        }
      }
      // Only deviate from the better neighbour when the best candidate wins by
      // a margin that grows with the distance.
      if (best_error <
          (1.0 - c6 * clamped_butteraugli) * std::min(top_error, left_error)) {
        out_row[bx] = best_val;
      } else if (top_error < left_error) {
        out_row[bx] = top_val;
      } else {
        out_row[bx] = left_val;
      }
      int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
      ++histo[context][out_row[bx]];
      ++totals[context];
    }
  }
  // Pass 2: final choice. Each candidate's error is divided by
  // c4 + log1p(frequency-in-context * context_weight), which favours values
  // that were common in the same context during pass 1. Value 0 is additionally
  // scaled by c3 = 0.97^distance.
  const float c1 = 0.059588212153340203f;
  const float c2 = 0.10599497107315753f;
  const float c3base = 0.97;
  const float c3 = pow(c3base, clamped_butteraugli);
  const float c4 = 1.247544678665836f;
  const float context_weight = c1 + c2 * clamped_butteraugli;
  for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
    uint8_t* JXL_RESTRICT out_row = epf_sharpness.Row(by);
    uint8_t* JXL_RESTRICT prev_row = epf_sharpness.Row(by > 0 ? by - 1 : 0);
    for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
      uint8_t best_val = 0;
      float best_error = std::numeric_limits<float>::max();
      uint8_t top_val = by > 0 ? prev_row[bx] : 0;
      uint8_t left_val = bx > 0 ? out_row[bx - 1] : 0;
      int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
      const auto& ctx_histo = histo[context];
      for (uint8_t val : epf_steps) {
        float error = error_images[val].Row(by)[bx] /
                      (c4 + std::log1p(ctx_histo[val] * context_weight /
                                       totals[context]));
        if (val == 0) {
          error *= c3;
        }
        if (error < best_error) {
          best_val = val;
          best_error = error;
        }
      }
      out_row[bx] = best_val;
    }
  }

  return true;
}
   1010 
// Runs the lossy encoder heuristics for one frame and stores the results in
// the shared encoder state (`enc_state->shared`).
//
// Stages, in order:
//   - splines and patches: found (or taken from cparams.custom_splines) and
//     subtracted from `*opsin` when not streaming and the speed tier allows;
//   - the initial quantization field / masking and the global quantizer scale;
//   - inverse Gaborish on `*opsin` when frame_header.loop_filter.gab is set;
//   - dequantization matrices (only when initializing global state);
//   - per-tile AC strategy, adjusted quant field, raw quant field and
//     color-correlation (CfL) map, run in parallel on `pool`;
//   - quantizer refinement via FindBestQuantizer (not in streaming mode and
//     not when perceptual optimizations are disabled);
//   - the block context model (speed tier below Falcon and initializing
//     global state).
//
// `*opsin` is modified in place. `linear` is only forwarded to
// FindBestQuantizer (NOTE(review): presumably the linear-light input; confirm
// whether it may be nullptr). `rect` selects the region of `*opsin` processed
// by the quant-field, Gaborish, CfL and AC-strategy helpers. Returns an error
// Status if any stage fails.
Status LossyFrameHeuristics(const FrameHeader& frame_header,
                            PassesEncoderState* enc_state,
                            ModularFrameEncoder* modular_frame_encoder,
                            const Image3F* linear, Image3F* opsin,
                            const Rect& rect, const JxlCmsInterface& cms,
                            ThreadPool* pool, AuxOut* aux_out) {
  const CompressParams& cparams = enc_state->cparams;
  const bool streaming_mode = enc_state->streaming_mode;
  const bool initialize_global_state = enc_state->initialize_global_state;
  PassesSharedState& shared = enc_state->shared;
  const FrameDimensions& frame_dim = shared.frame_dim;
  ImageFeatures& image_features = shared.image_features;
  DequantMatrices& matrices = shared.matrices;
  Quantizer& quantizer = shared.quantizer;
  ImageF& initial_quant_masking1x1 = enc_state->initial_quant_masking1x1;
  ImageI& raw_quant_field = shared.raw_quant_field;
  ColorCorrelationMap& cmap = shared.cmap;
  AcStrategyImage& ac_strategy = shared.ac_strategy;
  BlockCtxMap& block_ctx_map = shared.block_ctx_map;
  JxlMemoryManager* memory_manager = enc_state->memory_manager();

  // Find and subtract splines. Custom splines are always recorded, but
  // splines are only subtracted from `*opsin` when not streaming and at
  // Squirrel speed or slower.
  if (cparams.custom_splines.HasAny()) {
    image_features.splines = cparams.custom_splines;
  }
  if (!streaming_mode && cparams.speed_tier <= SpeedTier::kSquirrel) {
    if (!cparams.custom_splines.HasAny()) {
      image_features.splines = FindSplines(*opsin);
    }
    JXL_RETURN_IF_ERROR(image_features.splines.InitializeDrawCache(
        opsin->xsize(), opsin->ysize(), cmap.base()));
    image_features.splines.SubtractFrom(opsin);
  }

  // Find and subtract patches/dots.
  if (!streaming_mode &&
      ApplyOverride(cparams.patches,
                    cparams.speed_tier <= SpeedTier::kSquirrel)) {
    JXL_RETURN_IF_ERROR(
        FindBestPatchDictionary(*opsin, enc_state, cms, pool, aux_out));
    JXL_RETURN_IF_ERROR(
        PatchDictionaryEncoder::SubtractFrom(image_features.patches, opsin));
  }

  const float quant_dc = InitialQuantDC(cparams.butteraugli_distance);

  // TODO(veluca): we can now run all the code from here to FindBestQuantizer
  // (excluded) one rect at a time. Do that.

  // Dependency graph:
  //
  // input: either XYB or input image
  //
  // input image -> XYB [optional]
  // XYB -> initial quant field
  // XYB -> Gaborished XYB
  // Gaborished XYB -> CfL1
  // initial quant field, Gaborished XYB, CfL1 -> ACS
  // initial quant field, ACS, Gaborished XYB -> EPF control field
  // initial quant field -> adjusted initial quant field
  // adjusted initial quant field, ACS -> raw quant field
  // raw quant field, ACS, Gaborished XYB -> CfL2
  //
  // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field.

  AcStrategyHeuristics acs_heuristics(memory_manager, cparams);
  CfLHeuristics cfl_heuristics(memory_manager);
  ImageF initial_quant_field;
  ImageF initial_quant_masking;

  // Compute an initial estimate of the quantization field.
  // Call InitialQuantField only in Hare mode or slower (and when perceptual
  // optimizations are enabled). Otherwise fill the quant field and masking
  // with constants derived from the target distance.
  if (cparams.speed_tier > SpeedTier::kHare ||
      cparams.disable_perceptual_optimizations) {
    JXL_ASSIGN_OR_RETURN(initial_quant_field,
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
                                        frame_dim.ysize_blocks));
    JXL_ASSIGN_OR_RETURN(initial_quant_masking,
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
                                        frame_dim.ysize_blocks));
    float q = 0.79 / cparams.butteraugli_distance;
    FillImage(q, &initial_quant_field);
    float masking = 1.0f / (q + 0.001f);
    FillImage(masking, &initial_quant_masking);
    // NOTE(review): the full-resolution masking image is only populated here
    // when perceptual optimizations are disabled; on the plain fast-speed path
    // it is left as-is.
    if (cparams.disable_perceptual_optimizations) {
      JXL_ASSIGN_OR_RETURN(
          initial_quant_masking1x1,
          ImageF::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
      FillImage(masking, &initial_quant_masking1x1);
    }
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
  } else {
    // Call this here, as it relies on pre-gaborish values.
    float butteraugli_distance_for_iqf = cparams.butteraugli_distance;
    if (!frame_header.loop_filter.gab) {
      butteraugli_distance_for_iqf *= 0.62f;
    }
    JXL_ASSIGN_OR_RETURN(
        initial_quant_field,
        InitialQuantField(butteraugli_distance_for_iqf, *opsin, rect, pool,
                          1.0f, &initial_quant_masking,
                          &initial_quant_masking1x1));
    float q = 0.39 / cparams.butteraugli_distance;
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
  }

  // TODO(veluca): do something about animations.

  // Apply inverse-gaborish.
  if (frame_header.loop_filter.gab) {
    // Changing the weight here to 0.99f would help to reduce ringing in
    // generation loss.
    float weight[3] = {
        1.0f,
        1.0f,
        1.0f,
    };
    JXL_RETURN_IF_ERROR(GaborishInverse(opsin, rect, weight, pool));
  }

  if (initialize_global_state) {
    JXL_RETURN_IF_ERROR(FindBestDequantMatrices(
        memory_manager, cparams, modular_frame_encoder, &matrices));
  }

  JXL_RETURN_IF_ERROR(cfl_heuristics.Init(rect));
  JXL_RETURN_IF_ERROR(acs_heuristics.Init(*opsin, rect, initial_quant_field,
                                          initial_quant_masking,
                                          initial_quant_masking1x1, &matrices));

  // Processes one encoder tile (kEncTileDimInBlocks x kEncTileDimInBlocks
  // blocks, clipped at the frame edge); `tid` is a row-major tile index.
  auto process_tile = [&](const uint32_t tid, const size_t thread) -> Status {
    size_t n_enc_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks);
    size_t tx = tid % n_enc_tiles;
    size_t ty = tid / n_enc_tiles;
    size_t by0 = ty * kEncTileDimInBlocks;
    size_t by1 =
        std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks);
    size_t bx0 = tx * kEncTileDimInBlocks;
    size_t bx1 =
        std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks);
    Rect r(bx0, by0, bx1 - bx0, by1 - by0);

    // At Squirrel speed or slower, compute a preliminary color correlation map
    // before the transform types and quant field are known, so that the AC
    // strategy search below can use it. Faster tiers skip this step and only
    // compute the map afterwards (see the Hare check below).
    if (cparams.speed_tier <= SpeedTier::kSquirrel) {
      JXL_RETURN_IF_ERROR(cfl_heuristics.ComputeTile(
          r, *opsin, rect, matrices,
          /*ac_strategy=*/nullptr,
          /*raw_quant_field=*/nullptr,
          /*quantizer=*/nullptr, /*fast=*/false, thread, &cmap));
    }

    // Choose block sizes.
    JXL_RETURN_IF_ERROR(
        acs_heuristics.ProcessRect(r, cmap, &ac_strategy, thread));

    // Always set the initial quant field, so we can compute the CfL map with
    // more accuracy. The initial quant field might change in slower modes, but
    // adjusting the quant field with butteraugli when all the other encoding
    // parameters are fixed is likely a more reliable choice anyway.
    JXL_RETURN_IF_ERROR(AdjustQuantField(
        ac_strategy, r, cparams.butteraugli_distance, &initial_quant_field));
    quantizer.SetQuantFieldRect(initial_quant_field, r, &raw_quant_field);

    // Compute a non-default CfL map if we are at Hare speed, or slower.
    if (cparams.speed_tier <= SpeedTier::kHare) {
      JXL_RETURN_IF_ERROR(cfl_heuristics.ComputeTile(
          r, *opsin, rect, matrices, &ac_strategy, &raw_quant_field, &quantizer,
          /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, &cmap));
    }
    return true;
  };
  size_t num_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) *
                     DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks);
  // Sizes per-thread scratch state once the pool's thread count is known.
  const auto prepare = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(acs_heuristics.PrepareForThreads(num_threads));
    JXL_RETURN_IF_ERROR(cfl_heuristics.PrepareForThreads(num_threads));
    return true;
  };
  JXL_RETURN_IF_ERROR(
      RunOnPool(pool, 0, num_tiles, prepare, process_tile, "Enc Heuristics"));

  JXL_RETURN_IF_ERROR(acs_heuristics.Finalize(frame_dim, ac_strategy, aux_out));

  // Refine quantization levels. The EPF sharpness field is reset to a uniform
  // 4 before the quantizer search.
  if (!streaming_mode && !cparams.disable_perceptual_optimizations) {
    ImageB& epf_sharpness = shared.epf_sharpness;
    FillPlane(static_cast<uint8_t>(4), &epf_sharpness, Rect(epf_sharpness));
    JXL_RETURN_IF_ERROR(FindBestQuantizer(frame_header, linear, *opsin,
                                          initial_quant_field, enc_state, cms,
                                          pool, aux_out));
  }

  // Choose a context model that depends on the amount of quantization for AC.
  if (cparams.speed_tier < SpeedTier::kFalcon && initialize_global_state) {
    FindBestBlockEntropyModel(cparams, raw_quant_field, ac_strategy,
                              &block_ctx_map);
  }
  return true;
}
   1213 
   1214 }  // namespace jxl