tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_ac_strategy.cc (48525B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_ac_strategy.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <algorithm>
     11 #include <cmath>
     12 #include <cstdint>
     13 #include <cstdio>
     14 #include <cstring>
     15 
     16 #include "lib/jxl/base/common.h"
     17 #include "lib/jxl/memory_manager_internal.h"
     18 
     19 #undef HWY_TARGET_INCLUDE
     20 #define HWY_TARGET_INCLUDE "lib/jxl/enc_ac_strategy.cc"
     21 #include <hwy/foreach_target.h>
     22 #include <hwy/highway.h>
     23 
     24 #include "lib/jxl/ac_strategy.h"
     25 #include "lib/jxl/base/bits.h"
     26 #include "lib/jxl/base/compiler_specific.h"
     27 #include "lib/jxl/base/fast_math-inl.h"
     28 #include "lib/jxl/base/rect.h"
     29 #include "lib/jxl/base/status.h"
     30 #include "lib/jxl/dec_transforms-inl.h"
     31 #include "lib/jxl/enc_aux_out.h"
     32 #include "lib/jxl/enc_debug_image.h"
     33 #include "lib/jxl/enc_params.h"
     34 #include "lib/jxl/enc_transforms-inl.h"
     35 #include "lib/jxl/simd_util.h"
     36 
     37 // Some of the floating point constants in this file and in other
     38 // files in the libjxl project have been obtained using the
     39 // tools/optimizer/simplex_fork.py tool. It is a variation of
     40 // Nelder-Mead optimization, and we generally try to minimize
     41 // BPP * pnorm aggregate as reported by the benchmark_xl tool,
     42 // but occasionally the values are optimized by using additional
     43 // constraints such as maintaining a certain density, or ratio of
     44 // popularity of integral transforms. Jyrki visually reviews all
     45 // such changes and often makes manual adjustments to maintain good
     46 // visual quality in cases where butteraugli was not sufficiently
     47 // sensitive to some kind of degradation. Unfortunately image quality
     48 // is still more of an art than science.
     49 
     50 // Set JXL_DEBUG_AC_STRATEGY to 1 to enable debugging.
     51 #ifndef JXL_DEBUG_AC_STRATEGY
     52 #define JXL_DEBUG_AC_STRATEGY 0
     53 #endif
     54 
     55 // This must come before the begin/end_target, but HWY_ONCE is only true
     56 // after that, so use an "include guard".
     57 #ifndef LIB_JXL_ENC_AC_STRATEGY_
     58 #define LIB_JXL_ENC_AC_STRATEGY_
     59 // Parameters of the heuristic are marked with an OPTIMIZE comment.
     60 namespace jxl {
     61 namespace {
     62 
     63 // Debugging utilities.
     64 
     65 // Returns a linear sRGB color (as bytes) for each AC strategy.
     66 const uint8_t* TypeColor(uint8_t raw_strategy) {
     67  JXL_DASSERT(AcStrategy::IsRawStrategyValid(raw_strategy));
     68  static_assert(AcStrategy::kNumValidStrategies == 27, "Update colors");
     69  static constexpr uint8_t kColors[AcStrategy::kNumValidStrategies + 1][3] = {
     70      {0xFF, 0xFF, 0x00},  // DCT8       | yellow
     71      {0xFF, 0x80, 0x80},  // HORNUSS    | vivid tangerine
     72      {0xFF, 0x80, 0x80},  // DCT2x2     | vivid tangerine
     73      {0xFF, 0x80, 0x80},  // DCT4x4     | vivid tangerine
     74      {0x80, 0xFF, 0x00},  // DCT16x16   | chartreuse
     75      {0x00, 0xC0, 0x00},  // DCT32x32   | waystone green
     76      {0xC0, 0xFF, 0x00},  // DCT16x8    | lime
     77      {0xC0, 0xFF, 0x00},  // DCT8x16    | lime
     78      {0x00, 0xFF, 0x00},  // DCT32x8    | green
     79      {0x00, 0xFF, 0x00},  // DCT8x32    | green
     80      {0x00, 0xFF, 0x00},  // DCT32x16   | green
     81      {0x00, 0xFF, 0x00},  // DCT16x32   | green
     82      {0xFF, 0x80, 0x00},  // DCT4x8     | orange juice
     83      {0xFF, 0x80, 0x00},  // DCT8x4     | orange juice
     84      {0xFF, 0xFF, 0x80},  // AFV0       | butter
     85      {0xFF, 0xFF, 0x80},  // AFV1       | butter
     86      {0xFF, 0xFF, 0x80},  // AFV2       | butter
     87      {0xFF, 0xFF, 0x80},  // AFV3       | butter
     88      {0x00, 0xC0, 0xFF},  // DCT64x64   | capri
     89      {0x00, 0xFF, 0xFF},  // DCT64x32   | aqua
     90      {0x00, 0xFF, 0xFF},  // DCT32x64   | aqua
     91      {0x00, 0x40, 0xFF},  // DCT128x128 | rare blue
     92      {0x00, 0x80, 0xFF},  // DCT128x64  | magic ink
     93      {0x00, 0x80, 0xFF},  // DCT64x128  | magic ink
     94      {0x00, 0x00, 0xC0},  // DCT256x256 | keese blue
     95      {0x00, 0x00, 0xFF},  // DCT256x128 | blue
     96      {0x00, 0x00, 0xFF},  // DCT128x256 | blue
     97      {0x00, 0x00, 0x00}   // invalid    | black
     98  };
     99  raw_strategy =
    100      Clamp1<uint8_t>(raw_strategy, 0, AcStrategy::kNumValidStrategies);
    101  return kColors[raw_strategy];
    102 }
    103 
    104 const uint8_t* TypeMask(uint8_t raw_strategy) {
    105  JXL_DASSERT(AcStrategy::IsRawStrategyValid(raw_strategy));
    106  static_assert(AcStrategy::kNumValidStrategies == 27, "Update masks");
    107  // implicitly, first row and column is made dark
    108  static constexpr uint8_t kMask[AcStrategy::kNumValidStrategies + 1][64] = {
    109      {
    110          0, 0, 0, 0, 0, 0, 0, 0,  //
    111          0, 0, 0, 0, 0, 0, 0, 0,  //
    112          0, 0, 0, 0, 0, 0, 0, 0,  //
    113          0, 0, 0, 0, 0, 0, 0, 0,  //
    114          0, 0, 0, 0, 0, 0, 0, 0,  //
    115          0, 0, 0, 0, 0, 0, 0, 0,  //
    116          0, 0, 0, 0, 0, 0, 0, 0,  //
    117          0, 0, 0, 0, 0, 0, 0, 0,  //
    118      },                           // DCT8
    119      {
    120          0, 0, 0, 0, 0, 0, 0, 0,  //
    121          0, 0, 0, 0, 0, 0, 0, 0,  //
    122          0, 0, 1, 0, 0, 1, 0, 0,  //
    123          0, 0, 1, 0, 0, 1, 0, 0,  //
    124          0, 0, 1, 1, 1, 1, 0, 0,  //
    125          0, 0, 1, 0, 0, 1, 0, 0,  //
    126          0, 0, 1, 0, 0, 1, 0, 0,  //
    127          0, 0, 0, 0, 0, 0, 0, 0,  //
    128      },                           // HORNUSS
    129      {
    130          1, 1, 1, 1, 1, 1, 1, 1,  //
    131          1, 0, 1, 0, 1, 0, 1, 0,  //
    132          1, 1, 1, 1, 1, 1, 1, 1,  //
    133          1, 0, 1, 0, 1, 0, 1, 0,  //
    134          1, 1, 1, 1, 1, 1, 1, 1,  //
    135          1, 0, 1, 0, 1, 0, 1, 0,  //
    136          1, 1, 1, 1, 1, 1, 1, 1,  //
    137          1, 0, 1, 0, 1, 0, 1, 0,  //
    138      },                           // 2x2
    139      {
    140          0, 0, 0, 0, 1, 0, 0, 0,  //
    141          0, 0, 0, 0, 1, 0, 0, 0,  //
    142          0, 0, 0, 0, 1, 0, 0, 0,  //
    143          0, 0, 0, 0, 1, 0, 0, 0,  //
    144          1, 1, 1, 1, 1, 1, 1, 1,  //
    145          0, 0, 0, 0, 1, 0, 0, 0,  //
    146          0, 0, 0, 0, 1, 0, 0, 0,  //
    147          0, 0, 0, 0, 1, 0, 0, 0,  //
    148      },                           // 4x4
    149      {},                          // DCT16x16 (unused)
    150      {},                          // DCT32x32 (unused)
    151      {},                          // DCT16x8 (unused)
    152      {},                          // DCT8x16 (unused)
    153      {},                          // DCT32x8 (unused)
    154      {},                          // DCT8x32 (unused)
    155      {},                          // DCT32x16 (unused)
    156      {},                          // DCT16x32 (unused)
    157      {
    158          0, 0, 0, 0, 0, 0, 0, 0,  //
    159          0, 0, 0, 0, 0, 0, 0, 0,  //
    160          0, 0, 0, 0, 0, 0, 0, 0,  //
    161          0, 0, 0, 0, 0, 0, 0, 0,  //
    162          1, 1, 1, 1, 1, 1, 1, 1,  //
    163          0, 0, 0, 0, 0, 0, 0, 0,  //
    164          0, 0, 0, 0, 0, 0, 0, 0,  //
    165          0, 0, 0, 0, 0, 0, 0, 0,  //
    166      },                           // DCT4x8
    167      {
    168          0, 0, 0, 0, 1, 0, 0, 0,  //
    169          0, 0, 0, 0, 1, 0, 0, 0,  //
    170          0, 0, 0, 0, 1, 0, 0, 0,  //
    171          0, 0, 0, 0, 1, 0, 0, 0,  //
    172          0, 0, 0, 0, 1, 0, 0, 0,  //
    173          0, 0, 0, 0, 1, 0, 0, 0,  //
    174          0, 0, 0, 0, 1, 0, 0, 0,  //
    175          0, 0, 0, 0, 1, 0, 0, 0,  //
    176      },                           // DCT8x4
    177      {
    178          1, 1, 1, 1, 1, 0, 0, 0,  //
    179          1, 1, 1, 1, 0, 0, 0, 0,  //
    180          1, 1, 1, 0, 0, 0, 0, 0,  //
    181          1, 1, 0, 0, 0, 0, 0, 0,  //
    182          1, 0, 0, 0, 0, 0, 0, 0,  //
    183          0, 0, 0, 0, 0, 0, 0, 0,  //
    184          0, 0, 0, 0, 0, 0, 0, 0,  //
    185          0, 0, 0, 0, 0, 0, 0, 0,  //
    186      },                           // AFV0
    187      {
    188          0, 0, 0, 0, 1, 1, 1, 1,  //
    189          0, 0, 0, 0, 0, 1, 1, 1,  //
    190          0, 0, 0, 0, 0, 0, 1, 1,  //
    191          0, 0, 0, 0, 0, 0, 0, 1,  //
    192          0, 0, 0, 0, 0, 0, 0, 0,  //
    193          0, 0, 0, 0, 0, 0, 0, 0,  //
    194          0, 0, 0, 0, 0, 0, 0, 0,  //
    195          0, 0, 0, 0, 0, 0, 0, 0,  //
    196      },                           // AFV1
    197      {
    198          0, 0, 0, 0, 0, 0, 0, 0,  //
    199          0, 0, 0, 0, 0, 0, 0, 0,  //
    200          0, 0, 0, 0, 0, 0, 0, 0,  //
    201          0, 0, 0, 0, 0, 0, 0, 0,  //
    202          1, 0, 0, 0, 0, 0, 0, 0,  //
    203          1, 1, 0, 0, 0, 0, 0, 0,  //
    204          1, 1, 1, 0, 0, 0, 0, 0,  //
    205          1, 1, 1, 1, 0, 0, 0, 0,  //
    206      },                           // AFV2
    207      {
    208          0, 0, 0, 0, 0, 0, 0, 0,  //
    209          0, 0, 0, 0, 0, 0, 0, 0,  //
    210          0, 0, 0, 0, 0, 0, 0, 0,  //
    211          0, 0, 0, 0, 0, 0, 0, 0,  //
    212          0, 0, 0, 0, 0, 0, 0, 0,  //
    213          0, 0, 0, 0, 0, 0, 0, 1,  //
    214          0, 0, 0, 0, 0, 0, 1, 1,  //
    215          0, 0, 0, 0, 0, 1, 1, 1,  //
    216      },                           // AFV3
    217      {}                           // invalid
    218  };
    219  raw_strategy =
    220      Clamp1<uint8_t>(raw_strategy, 0, AcStrategy::kNumValidStrategies);
    221  return kMask[raw_strategy];
    222 }
    223 
    224 Status DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize,
    225                      size_t ysize, const char* tag, AuxOut* aux_out,
    226                      const CompressParams& cparams) {
    227  JxlMemoryManager* memory_manager = ac_strategy.memory_manager();
    228  JXL_ASSIGN_OR_RETURN(Image3F color_acs,
    229                       Image3F::Create(memory_manager, xsize, ysize));
    230  for (size_t y = 0; y < ysize; y++) {
    231    float* JXL_RESTRICT rows[3] = {
    232        color_acs.PlaneRow(0, y),
    233        color_acs.PlaneRow(1, y),
    234        color_acs.PlaneRow(2, y),
    235    };
    236    const AcStrategyRow acs_row = ac_strategy.ConstRow(y / kBlockDim);
    237    for (size_t x = 0; x < xsize; x++) {
    238      AcStrategy acs = acs_row[x / kBlockDim];
    239      const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy());
    240      for (size_t c = 0; c < 3; c++) {
    241        rows[c][x] = color[c] / 255.f;
    242      }
    243    }
    244  }
    245  size_t stride = color_acs.PixelsPerRow();
    246  for (size_t c = 0; c < 3; c++) {
    247    for (size_t by = 0; by < DivCeil(ysize, kBlockDim); by++) {
    248      float* JXL_RESTRICT row = color_acs.PlaneRow(c, by * kBlockDim);
    249      const AcStrategyRow acs_row = ac_strategy.ConstRow(by);
    250      for (size_t bx = 0; bx < DivCeil(xsize, kBlockDim); bx++) {
    251        AcStrategy acs = acs_row[bx];
    252        if (!acs.IsFirstBlock()) continue;
    253        const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy());
    254        const uint8_t* JXL_RESTRICT mask = TypeMask(acs.RawStrategy());
    255        if (acs.covered_blocks_x() == 1 && acs.covered_blocks_y() == 1) {
    256          for (size_t iy = 0; iy < kBlockDim && by * kBlockDim + iy < ysize;
    257               iy++) {
    258            for (size_t ix = 0; ix < kBlockDim && bx * kBlockDim + ix < xsize;
    259                 ix++) {
    260              if (mask[iy * kBlockDim + ix]) {
    261                row[iy * stride + bx * kBlockDim + ix] = color[c] / 800.f;
    262              }
    263            }
    264          }
    265        }
    266        // draw block edges
    267        for (size_t ix = 0; ix < kBlockDim * acs.covered_blocks_x() &&
    268                            bx * kBlockDim + ix < xsize;
    269             ix++) {
    270          row[0 * stride + bx * kBlockDim + ix] = color[c] / 350.f;
    271        }
    272        for (size_t iy = 0; iy < kBlockDim * acs.covered_blocks_y() &&
    273                            by * kBlockDim + iy < ysize;
    274             iy++) {
    275          row[iy * stride + bx * kBlockDim + 0] = color[c] / 350.f;
    276        }
    277      }
    278    }
    279  }
    280  return DumpImage(cparams, tag, color_acs);
    281 }
    282 
    283 }  // namespace
    284 }  // namespace jxl
    285 #endif  // LIB_JXL_ENC_AC_STRATEGY_
    286 
    287 HWY_BEFORE_NAMESPACE();
    288 namespace jxl {
    289 namespace HWY_NAMESPACE {
    290 
    291 // These templates are not found via ADL.
    292 using hwy::HWY_NAMESPACE::AbsDiff;
    293 using hwy::HWY_NAMESPACE::Eq;
    294 using hwy::HWY_NAMESPACE::IfThenElseZero;
    295 using hwy::HWY_NAMESPACE::IfThenZeroElse;
    296 using hwy::HWY_NAMESPACE::Round;
    297 using hwy::HWY_NAMESPACE::Sqrt;
    298 
    299 bool MultiBlockTransformCrossesHorizontalBoundary(
    300    const AcStrategyImage& ac_strategy, size_t start_x, size_t y,
    301    size_t end_x) {
    302  if (start_x >= ac_strategy.xsize() || y >= ac_strategy.ysize()) {
    303    return false;
    304  }
    305  if (y % 8 == 0) {
    306    // Nothing crosses 64x64 boundaries, and the memory on the other side
    307    // of the 64x64 block may still uninitialized.
    308    return false;
    309  }
    310  end_x = std::min(end_x, ac_strategy.xsize());
    311  // The first multiblock might be before the start_x, let's adjust it
    312  // to point to the first IsFirstBlock() == true block we find by backward
    313  // tracing.
    314  AcStrategyRow row = ac_strategy.ConstRow(y);
    315  const size_t start_x_limit = start_x & ~7;
    316  while (start_x != start_x_limit && !row[start_x].IsFirstBlock()) {
    317    --start_x;
    318  }
    319  for (size_t x = start_x; x < end_x;) {
    320    if (row[x].IsFirstBlock()) {
    321      x += row[x].covered_blocks_x();
    322    } else {
    323      return true;
    324    }
    325  }
    326  return false;
    327 }
    328 
    329 bool MultiBlockTransformCrossesVerticalBoundary(
    330    const AcStrategyImage& ac_strategy, size_t x, size_t start_y,
    331    size_t end_y) {
    332  if (x >= ac_strategy.xsize() || start_y >= ac_strategy.ysize()) {
    333    return false;
    334  }
    335  if (x % 8 == 0) {
    336    // Nothing crosses 64x64 boundaries, and the memory on the other side
    337    // of the 64x64 block may still uninitialized.
    338    return false;
    339  }
    340  end_y = std::min(end_y, ac_strategy.ysize());
    341  // The first multiblock might be before the start_y, let's adjust it
    342  // to point to the first IsFirstBlock() == true block we find by backward
    343  // tracing.
    344  const size_t start_y_limit = start_y & ~7;
    345  while (start_y != start_y_limit &&
    346         !ac_strategy.ConstRow(start_y)[x].IsFirstBlock()) {
    347    --start_y;
    348  }
    349 
    350  for (size_t y = start_y; y < end_y;) {
    351    AcStrategyRow row = ac_strategy.ConstRow(y);
    352    if (row[x].IsFirstBlock()) {
    353      y += row[x].covered_blocks_y();
    354    } else {
    355      return true;
    356    }
    357  }
    358  return false;
    359 }
    360 
    361 Status EstimateEntropy(const AcStrategy& acs, float entropy_mul, size_t x,
    362                       size_t y, const ACSConfig& config,
    363                       const float* JXL_RESTRICT cmap_factors, float* block,
    364                       float* full_scratch_space, uint32_t* quantized,
    365                       float& entropy) {
    366  entropy = 0.0f;
    367  float* mem = full_scratch_space;
    368  float* scratch_space = full_scratch_space + AcStrategy::kMaxCoeffArea;
    369  const size_t size = (1 << acs.log2_covered_blocks()) * kDCTBlockSize;
    370 
    371  // Apply transform.
    372  for (size_t c = 0; c < 3; c++) {
    373    float* JXL_RESTRICT block_c = block + size * c;
    374    TransformFromPixels(acs.Strategy(), &config.Pixel(c, x, y),
    375                        config.src_stride, block_c, scratch_space);
    376  }
    377  HWY_FULL(float) df;
    378 
    379  const size_t num_blocks = acs.covered_blocks_x() * acs.covered_blocks_y();
    380  // avoid large blocks when there is a lot going on in red-green.
    381  float quant_norm16 = 0;
    382  if (num_blocks == 1) {
    383    // When it is only one 8x8, we don't need aggregation of values.
    384    quant_norm16 = config.Quant(x / 8, y / 8);
    385  } else if (num_blocks == 2) {
    386    // Taking max instead of 8th norm seems to work
    387    // better for smallest blocks up to 16x8. Jyrki couldn't get
    388    // improvements in trying the same for 16x16 blocks.
    389    if (acs.covered_blocks_y() == 2) {
    390      quant_norm16 =
    391          std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8, y / 8 + 1));
    392    } else {
    393      quant_norm16 =
    394          std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8 + 1, y / 8));
    395    }
    396  } else {
    397    // Load QF value, calculate empirical heuristic on masking field
    398    // for weighting the information loss. Information loss manifests
    399    // itself as ringing, and masking could hide it.
    400    for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    401      for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    402        float qval = config.Quant(x / 8 + ix, y / 8 + iy);
    403        qval *= qval;
    404        qval *= qval;
    405        qval *= qval;
    406        quant_norm16 += qval * qval;
    407      }
    408    }
    409    quant_norm16 /= num_blocks;
    410    quant_norm16 = FastPowf(quant_norm16, 1.0f / 16.0f);
    411  }
    412  const auto quant = Set(df, quant_norm16);
    413 
    414  // Compute entropy.
    415  const HWY_CAPPED(float, 8) df8;
    416 
    417  auto loss = Zero(df8);
    418  for (size_t c = 0; c < 3; c++) {
    419    const float* inv_matrix = config.dequant->InvMatrix(acs.Strategy(), c);
    420    const float* matrix = config.dequant->Matrix(acs.Strategy(), c);
    421    const auto cmap_factor = Set(df, cmap_factors[c]);
    422 
    423    auto entropy_v = Zero(df);
    424    auto nzeros_v = Zero(df);
    425    for (size_t i = 0; i < num_blocks * kDCTBlockSize; i += Lanes(df)) {
    426      const auto in = Load(df, block + c * size + i);
    427      const auto in_y = Mul(Load(df, block + size + i), cmap_factor);
    428      const auto im = Load(df, inv_matrix + i);
    429      const auto val = Mul(Sub(in, in_y), Mul(im, quant));
    430      const auto rval = Round(val);
    431      const auto diff = Sub(val, rval);
    432      const auto m = Load(df, matrix + i);
    433      Store(Mul(m, diff), df, &mem[i]);
    434      const auto q = Abs(rval);
    435      const auto q_is_zero = Eq(q, Zero(df));
    436      // We used to have q * C here, but that cost model seems to
    437      // be punishing large values more than necessary. Sqrt tries
    438      // to avoid large values less aggressively.
    439      entropy_v = Add(Sqrt(q), entropy_v);
    440      nzeros_v = Add(nzeros_v, IfThenZeroElse(q_is_zero, Set(df, 1.0f)));
    441    }
    442 
    443    {
    444      auto lossc = Zero(df8);
    445      TransformToPixels(acs.Strategy(), &mem[0], block,
    446                        acs.covered_blocks_x() * 8, scratch_space);
    447 
    448      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    449        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    450          for (size_t dy = 0; dy < kBlockDim; ++dy) {
    451            for (size_t dx = 0; dx < kBlockDim; dx += Lanes(df8)) {
    452              auto in = Load(df8, block +
    453                                      (iy * kBlockDim + dy) *
    454                                          (acs.covered_blocks_x() * kBlockDim) +
    455                                      ix * kBlockDim + dx);
    456              if (x + ix * 8 + dx + Lanes(df8) <= config.mask1x1_xsize) {
    457                auto masku =
    458                    Abs(Load(df8, config.MaskingPtr1x1(x + ix * 8 + dx,
    459                                                       y + iy * 8 + dy)));
    460                in = Mul(masku, in);
    461                in = Mul(in, in);
    462                in = Mul(in, in);
    463                in = Mul(in, in);
    464                lossc = Add(lossc, in);
    465              }
    466            }
    467          }
    468        }
    469      }
    470      static const double kChannelMul[3] = {
    471          pow(10.2, 8.0),
    472          pow(1.0, 8.0),
    473          pow(1.03, 8.0),
    474      };
    475      lossc = Mul(Set(df8, kChannelMul[c]), lossc);
    476      loss = Add(loss, lossc);
    477    }
    478    entropy += config.cost_delta * GetLane(SumOfLanes(df, entropy_v));
    479    size_t num_nzeros = GetLane(SumOfLanes(df, nzeros_v));
    480    // Add #bit of num_nonzeros, as an estimate of the cost for encoding the
    481    // number of non-zeros of the block.
    482    size_t nbits = CeilLog2Nonzero(num_nzeros + 1) + 1;
    483    // Also add #bit of #bit of num_nonzeros, to estimate the ANS cost, with a
    484    // bias.
    485    entropy += config.zeros_mul * (CeilLog2Nonzero(nbits + 17) + nbits);
    486  }
    487  float loss_scalar =
    488      pow(GetLane(SumOfLanes(df8, loss)) / (num_blocks * kDCTBlockSize),
    489          1.0 / 8.0) *
    490      (num_blocks * kDCTBlockSize) / quant_norm16;
    491  entropy *= entropy_mul;
    492  entropy += config.info_loss_multiplier * loss_scalar;
    493  return true;
    494 }
    495 
    496 Status FindBest8x8Transform(size_t x, size_t y, int encoding_speed_tier,
    497                            float butteraugli_target, const ACSConfig& config,
    498                            const float* JXL_RESTRICT cmap_factors,
    499                            AcStrategyImage* JXL_RESTRICT ac_strategy,
    500                            float* block, float* scratch_space,
    501                            uint32_t* quantized, float* entropy_out,
    502                            AcStrategyType& best_tx) {
    503  struct TransformTry8x8 {
    504    AcStrategyType type;
    505    int encoding_speed_tier_max_limit;
    506    double entropy_mul;
    507  };
    508  static const TransformTry8x8 kTransforms8x8[] = {
    509      {
    510          AcStrategyType::DCT,
    511          9,
    512          0.8,
    513      },
    514      {
    515          AcStrategyType::DCT4X4,
    516          5,
    517          1.08,
    518      },
    519      {
    520          AcStrategyType::DCT2X2,
    521          5,
    522          0.95,
    523      },
    524      {
    525          AcStrategyType::DCT4X8,
    526          4,
    527          0.85931637428340035,
    528      },
    529      {
    530          AcStrategyType::DCT8X4,
    531          4,
    532          0.85931637428340035,
    533      },
    534      {
    535          AcStrategyType::IDENTITY,
    536          5,
    537          1.0427542510634957,
    538      },
    539      {
    540          AcStrategyType::AFV0,
    541          4,
    542          0.81779489591359944,
    543      },
    544      {
    545          AcStrategyType::AFV1,
    546          4,
    547          0.81779489591359944,
    548      },
    549      {
    550          AcStrategyType::AFV2,
    551          4,
    552          0.81779489591359944,
    553      },
    554      {
    555          AcStrategyType::AFV3,
    556          4,
    557          0.81779489591359944,
    558      },
    559  };
    560  double best = 1e30;
    561  best_tx = kTransforms8x8[0].type;
    562  for (auto tx : kTransforms8x8) {
    563    if (tx.encoding_speed_tier_max_limit < encoding_speed_tier) {
    564      continue;
    565    }
    566    AcStrategy acs = AcStrategy::FromRawStrategy(tx.type);
    567    float entropy_mul = tx.entropy_mul / kTransforms8x8[0].entropy_mul;
    568    if ((tx.type == AcStrategyType::DCT2X2 ||
    569         tx.type == AcStrategyType::IDENTITY) &&
    570        butteraugli_target < 5.0) {
    571      static const float kFavor2X2AtHighQuality = 0.4;
    572      float weight = pow((5.0f - butteraugli_target) / 5.0f, 2.0);
    573      entropy_mul -= kFavor2X2AtHighQuality * weight;
    574    }
    575    if ((tx.type != AcStrategyType::DCT && tx.type != AcStrategyType::DCT2X2 &&
    576         tx.type != AcStrategyType::IDENTITY) &&
    577        butteraugli_target > 4.0) {
    578      static const float kAvoidEntropyOfTransforms = 0.5;
    579      float mul = 1.0;
    580      if (butteraugli_target < 12.0) {
    581        mul *= (12.0 - 4.0) / (butteraugli_target - 4.0);
    582      }
    583      entropy_mul += kAvoidEntropyOfTransforms * mul;
    584    }
    585    float entropy;
    586    JXL_RETURN_IF_ERROR(EstimateEntropy(acs, entropy_mul, x, y, config,
    587                                        cmap_factors, block, scratch_space,
    588                                        quantized, entropy));
    589    if (entropy < best) {
    590      best_tx = tx.type;
    591      best = entropy;
    592    }
    593  }
    594  *entropy_out = best;
    595  return true;
    596 }
    597 
    598 // bx, by addresses the 64x64 block at 8x8 subresolution
    599 // cx, cy addresses the left, upper 8x8 block position of the candidate
    600 // transform.
    601 Status TryMergeAcs(AcStrategyType acs_raw, size_t bx, size_t by, size_t cx,
    602                   size_t cy, const ACSConfig& config,
    603                   const float* JXL_RESTRICT cmap_factors,
    604                   AcStrategyImage* JXL_RESTRICT ac_strategy,
    605                   const float entropy_mul, const uint8_t candidate_priority,
    606                   uint8_t* priority, float* JXL_RESTRICT entropy_estimate,
    607                   float* block, float* scratch_space, uint32_t* quantized) {
    608  AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw);
    609  float entropy_current = 0;
    610  for (size_t iy = 0; iy < acs.covered_blocks_y(); ++iy) {
    611    for (size_t ix = 0; ix < acs.covered_blocks_x(); ++ix) {
    612      if (priority[(cy + iy) * 8 + (cx + ix)] >= candidate_priority) {
    613        // Transform would reuse already allocated blocks and
    614        // lead to invalid overlaps, for example DCT64X32 vs.
    615        // DCT32X64.
    616        return true;
    617      }
    618      entropy_current += entropy_estimate[(cy + iy) * 8 + (cx + ix)];
    619    }
    620  }
    621  float entropy_candidate;
    622  JXL_RETURN_IF_ERROR(EstimateEntropy(
    623      acs, entropy_mul, (bx + cx) * 8, (by + cy) * 8, config, cmap_factors,
    624      block, scratch_space, quantized, entropy_candidate));
    625  if (entropy_candidate >= entropy_current) return true;
    626  // Accept the candidate.
    627  for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    628    for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    629      entropy_estimate[(cy + iy) * 8 + cx + ix] = 0;
    630      priority[(cy + iy) * 8 + cx + ix] = candidate_priority;
    631    }
    632  }
    633  JXL_RETURN_IF_ERROR(ac_strategy->Set(bx + cx, by + cy, acs_raw));
    634  entropy_estimate[cy * 8 + cx] = entropy_candidate;
    635  return true;
    636 }
    637 
    638 static void SetEntropyForTransform(size_t cx, size_t cy,
    639                                   const AcStrategyType acs_raw, float entropy,
    640                                   float* JXL_RESTRICT entropy_estimate) {
    641  const AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw);
    642  for (size_t dy = 0; dy < acs.covered_blocks_y(); ++dy) {
    643    for (size_t dx = 0; dx < acs.covered_blocks_x(); ++dx) {
    644      entropy_estimate[(cy + dy) * 8 + cx + dx] = 0.0;
    645    }
    646  }
    647  entropy_estimate[cy * 8 + cx] = entropy;
    648 }
    649 
    650 AcStrategyType AcsSquare(size_t blocks) {
    651  if (blocks == 2) {
    652    return AcStrategyType::DCT16X16;
    653  } else if (blocks == 4) {
    654    return AcStrategyType::DCT32X32;
    655  } else {
    656    return AcStrategyType::DCT64X64;
    657  }
    658 }
    659 
    660 AcStrategyType AcsVerticalSplit(size_t blocks) {
    661  if (blocks == 2) {
    662    return AcStrategyType::DCT16X8;
    663  } else if (blocks == 4) {
    664    return AcStrategyType::DCT32X16;
    665  } else {
    666    return AcStrategyType::DCT64X32;
    667  }
    668 }
    669 
    670 AcStrategyType AcsHorizontalSplit(size_t blocks) {
    671  if (blocks == 2) {
    672    return AcStrategyType::DCT8X16;
    673  } else if (blocks == 4) {
    674    return AcStrategyType::DCT16X32;
    675  } else {
    676    return AcStrategyType::DCT32X64;
    677  }
    678 }
    679 
    680 // The following function tries to merge smaller transforms into
    681 // squares and the rectangles originating from a single middle division
    682 // (horizontal or vertical) fairly.
    683 //
    684 // This is now generalized to handle squares
    685 // of blocks x blocks size, where a block is 8x8 pixels.
    686 Status FindBestFirstLevelDivisionForSquare(
    687    size_t blocks, bool allow_square_transform, size_t bx, size_t by, size_t cx,
    688    size_t cy, const ACSConfig& config, const float* JXL_RESTRICT cmap_factors,
    689    AcStrategyImage* JXL_RESTRICT ac_strategy, const float entropy_mul_JXK,
    690    const float entropy_mul_JXJ, float* JXL_RESTRICT entropy_estimate,
    691    float* block, float* scratch_space, uint32_t* quantized) {
    692  // We denote J for the larger dimension here, and K for the smaller.
    693  // For example, for 32x32 block splitting, J would be 32, K 16.
    694  const size_t blocks_half = blocks / 2;
    695  const AcStrategyType acs_rawJXK = AcsVerticalSplit(blocks);
    696  const AcStrategyType acs_rawKXJ = AcsHorizontalSplit(blocks);
    697  const AcStrategyType acs_rawJXJ = AcsSquare(blocks);
    698  const AcStrategy acsJXK = AcStrategy::FromRawStrategy(acs_rawJXK);
    699  const AcStrategy acsKXJ = AcStrategy::FromRawStrategy(acs_rawKXJ);
    700  const AcStrategy acsJXJ = AcStrategy::FromRawStrategy(acs_rawJXJ);
    701  AcStrategyRow row0 = ac_strategy->ConstRow(by + cy + 0);
    702  AcStrategyRow row1 = ac_strategy->ConstRow(by + cy + blocks_half);
    703  // Let's check if we can consider a JXJ block here at all.
    704  // This is not necessary in the basic use of hierarchically merging
    705  // blocks in the simplest possible way, but is needed when we try other
    706  // 'floating' options of merging, possibly after a simple hierarchical
    707  // merge has been explored.
    708  if (MultiBlockTransformCrossesHorizontalBoundary(*ac_strategy, bx + cx,
    709                                                   by + cy, bx + cx + blocks) ||
    710      MultiBlockTransformCrossesHorizontalBoundary(
    711          *ac_strategy, bx + cx, by + cy + blocks, bx + cx + blocks) ||
    712      MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx, by + cy,
    713                                                 by + cy + blocks) ||
    714      MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx + blocks,
    715                                                 by + cy, by + cy + blocks)) {
    716    return true;  // not suitable for JxJ analysis, some transforms leak out.
    717  }
    718  // For floating transforms there may be
    719  // already blocks selected that make either or both JXK and
    720  // KXJ not feasible for this location.
    721  const bool allow_JXK = !MultiBlockTransformCrossesVerticalBoundary(
    722      *ac_strategy, bx + cx + blocks_half, by + cy, by + cy + blocks);
    723  const bool allow_KXJ = !MultiBlockTransformCrossesHorizontalBoundary(
    724      *ac_strategy, bx + cx, by + cy + blocks_half, bx + cx + blocks);
    725  // Current entropies aggregated on NxN resolution.
    726  float entropy[2][2] = {};
    727  for (size_t dy = 0; dy < blocks; ++dy) {
    728    for (size_t dx = 0; dx < blocks; ++dx) {
    729      entropy[dy / blocks_half][dx / blocks_half] +=
    730          entropy_estimate[(cy + dy) * 8 + (cx + dx)];
    731    }
    732  }
    733  float entropy_JXK_left = std::numeric_limits<float>::max();
    734  float entropy_JXK_right = std::numeric_limits<float>::max();
    735  float entropy_KXJ_top = std::numeric_limits<float>::max();
    736  float entropy_KXJ_bottom = std::numeric_limits<float>::max();
    737  float entropy_JXJ = std::numeric_limits<float>::max();
    738  if (allow_JXK) {
    739    if (row0[bx + cx + 0].Strategy() != acs_rawJXK) {
    740      JXL_RETURN_IF_ERROR(EstimateEntropy(
    741          acsJXK, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
    742          cmap_factors, block, scratch_space, quantized, entropy_JXK_left));
    743    }
    744    if (row0[bx + cx + blocks_half].Strategy() != acs_rawJXK) {
    745      JXL_RETURN_IF_ERROR(
    746          EstimateEntropy(acsJXK, entropy_mul_JXK, (bx + cx + blocks_half) * 8,
    747                          (by + cy + 0) * 8, config, cmap_factors, block,
    748                          scratch_space, quantized, entropy_JXK_right));
    749    }
    750  }
    751  if (allow_KXJ) {
    752    if (row0[bx + cx].Strategy() != acs_rawKXJ) {
    753      JXL_RETURN_IF_ERROR(EstimateEntropy(
    754          acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
    755          cmap_factors, block, scratch_space, quantized, entropy_KXJ_top));
    756    }
    757    if (row1[bx + cx].Strategy() != acs_rawKXJ) {
    758      JXL_RETURN_IF_ERROR(
    759          EstimateEntropy(acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8,
    760                          (by + cy + blocks_half) * 8, config, cmap_factors,
    761                          block, scratch_space, quantized, entropy_KXJ_bottom));
    762    }
    763  }
    764  if (allow_square_transform) {
    765    // We control the exploration of the square transform separately so that
    766    // we can turn it off at high decoding speeds for 32x32, but still allow
    767    // exploring 16x32 and 32x16.
    768    JXL_RETURN_IF_ERROR(EstimateEntropy(
    769        acsJXJ, entropy_mul_JXJ, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
    770        cmap_factors, block, scratch_space, quantized, entropy_JXJ));
    771  }
    772 
    773  // Test if this block should have JXK or KXJ transforms,
    774  // because it can have only one or the other.
    775  float costJxN = std::min(entropy_JXK_left, entropy[0][0] + entropy[1][0]) +
    776                  std::min(entropy_JXK_right, entropy[0][1] + entropy[1][1]);
    777  float costNxJ = std::min(entropy_KXJ_top, entropy[0][0] + entropy[0][1]) +
    778                  std::min(entropy_KXJ_bottom, entropy[1][0] + entropy[1][1]);
    779  if (entropy_JXJ < costJxN && entropy_JXJ < costNxJ) {
    780    JXL_RETURN_IF_ERROR(ac_strategy->Set(bx + cx, by + cy, acs_rawJXJ));
    781    SetEntropyForTransform(cx, cy, acs_rawJXJ, entropy_JXJ, entropy_estimate);
    782  } else if (costJxN < costNxJ) {
    783    if (entropy_JXK_left < entropy[0][0] + entropy[1][0]) {
    784      JXL_RETURN_IF_ERROR(ac_strategy->Set(bx + cx, by + cy, acs_rawJXK));
    785      SetEntropyForTransform(cx, cy, acs_rawJXK, entropy_JXK_left,
    786                             entropy_estimate);
    787    }
    788    if (entropy_JXK_right < entropy[0][1] + entropy[1][1]) {
    789      JXL_RETURN_IF_ERROR(
    790          ac_strategy->Set(bx + cx + blocks_half, by + cy, acs_rawJXK));
    791      SetEntropyForTransform(cx + blocks_half, cy, acs_rawJXK,
    792                             entropy_JXK_right, entropy_estimate);
    793    }
    794  } else {
    795    if (entropy_KXJ_top < entropy[0][0] + entropy[0][1]) {
    796      JXL_RETURN_IF_ERROR(ac_strategy->Set(bx + cx, by + cy, acs_rawKXJ));
    797      SetEntropyForTransform(cx, cy, acs_rawKXJ, entropy_KXJ_top,
    798                             entropy_estimate);
    799    }
    800    if (entropy_KXJ_bottom < entropy[1][0] + entropy[1][1]) {
    801      JXL_RETURN_IF_ERROR(
    802          ac_strategy->Set(bx + cx, by + cy + blocks_half, acs_rawKXJ));
    803      SetEntropyForTransform(cx, cy + blocks_half, acs_rawKXJ,
    804                             entropy_KXJ_bottom, entropy_estimate);
    805    }
    806  }
    807  return true;
    808 }
    809 
    810 Status ProcessRectACS(const CompressParams& cparams, const ACSConfig& config,
    811                      const Rect& rect, const ColorCorrelationMap& cmap,
    812                      float* JXL_RESTRICT block,
    813                      uint32_t* JXL_RESTRICT quantized,
    814                      AcStrategyImage* ac_strategy) {
    815  // Main philosophy here:
    816  // 1. First find best 8x8 transform for each area.
    817  // 2. Merging them into larger transforms where possibly, but
    818  // starting from the smallest transforms (16x8 and 8x16).
    819  // Additional complication: 16x8 and 8x16 are considered
    820  // simultaneously and fairly against each other.
    821  // We are looking at 64x64 squares since the Y-to-X and Y-to-B
    822  // maps happen to be at that resolution, and having
    823  // integral transforms cross these boundaries leads to
    824  // additional complications.
    825  const float butteraugli_target = cparams.butteraugli_distance;
    826  float* JXL_RESTRICT scratch_space = block + 3 * AcStrategy::kMaxCoeffArea;
    827  size_t bx = rect.x0();
    828  size_t by = rect.y0();
    829  JXL_ENSURE(rect.xsize() <= 8);
    830  JXL_ENSURE(rect.ysize() <= 8);
    831  size_t tx = bx / kColorTileDimInBlocks;
    832  size_t ty = by / kColorTileDimInBlocks;
    833  const float cmap_factors[3] = {
    834      cmap.base().YtoXRatio(cmap.ytox_map.ConstRow(ty)[tx]),
    835      0.0f,
    836      cmap.base().YtoBRatio(cmap.ytob_map.ConstRow(ty)[tx]),
    837  };
    838  if (cparams.speed_tier > SpeedTier::kHare) return true;
    839  // First compute the best 8x8 transform for each square. Later, we do not
    840  // experiment with different combinations, but only use the best of the 8x8s
    841  // when DCT8X8 is specified in the tree search.
    842  // 8x8 transforms have 10 variants, but every larger transform is just a DCT.
    843  float entropy_estimate[64] = {};
    844  // Favor all 8x8 transforms (against 16x8 and larger transforms)) at
    845  // low butteraugli_target distances.
    846  static const float k8x8mul1 = -0.4;
    847  static const float k8x8mul2 = 1.0;
    848  static const float k8x8base = 1.4;
    849  const float mul8x8 = k8x8mul2 + k8x8mul1 / (butteraugli_target + k8x8base);
    850  for (size_t iy = 0; iy < rect.ysize(); iy++) {
    851    for (size_t ix = 0; ix < rect.xsize(); ix++) {
    852      float entropy = 0.0;
    853      AcStrategyType best_of_8x8s;
    854      JXL_RETURN_IF_ERROR(FindBest8x8Transform(
    855          8 * (bx + ix), 8 * (by + iy), static_cast<int>(cparams.speed_tier),
    856          butteraugli_target, config, cmap_factors, ac_strategy, block,
    857          scratch_space, quantized, &entropy, best_of_8x8s));
    858      JXL_RETURN_IF_ERROR(ac_strategy->Set(bx + ix, by + iy, best_of_8x8s));
    859      entropy_estimate[iy * 8 + ix] = entropy * mul8x8;
    860    }
    861  }
    862  // Merge when a larger transform is better than the previously
    863  // searched best combination of 8x8 transforms.
    864  struct MergeTry {
    865    AcStrategyType type;
    866    uint8_t priority;
    867    uint8_t decoding_speed_tier_max_limit;
    868    uint8_t encoding_speed_tier_max_limit;
    869    float entropy_mul;
    870  };
    871  // These numbers need to be figured out manually and looking at
    872  // ringing next to sky etc. Optimization will find smaller numbers
    873  // and produce more ringing than is ideal. Larger numbers will
    874  // help stop ringing.
    875  const float entropy_mul16X8 = 1.25;
    876  const float entropy_mul16X16 = 1.35;
    877  const float entropy_mul16X32 = 1.5;
    878  const float entropy_mul32X32 = 1.5;
    879  const float entropy_mul64X32 = 2.26;
    880  const float entropy_mul64X64 = 2.26;
    881  // TODO(jyrki): Consider this feedback in further changes:
    882  // Also effectively when the multipliers for smaller blocks are
    883  // below 1, this raises the bar for the bigger blocks even higher
    884  // in that sense these constants are not independent (e.g. changing
    885  // the constant for DCT16x32 by -5% (making it more likely) also
    886  // means that DCT32x32 becomes harder to do when starting from
    887  // two DCT16x32s). It might be better to make them more independent,
    888  // e.g. by not applying the multiplier when storing the new entropy
    889  // estimates in TryMergeToACSCandidate().
    890  const MergeTry kTransformsForMerge[9] = {
    891      {AcStrategyType::DCT16X8, 2, 4, 5, entropy_mul16X8},
    892      {AcStrategyType::DCT8X16, 2, 4, 5, entropy_mul16X8},
    893      // FindBestFirstLevelDivisionForSquare looks for DCT16X16 and its
    894      // subdivisions. {AcStrategyType::DCT16X16, 3, entropy_mul16X16},
    895      {AcStrategyType::DCT16X32, 4, 4, 4, entropy_mul16X32},
    896      {AcStrategyType::DCT32X16, 4, 4, 4, entropy_mul16X32},
    897      // FindBestFirstLevelDivisionForSquare looks for DCT32X32 and its
    898      // subdivisions. {AcStrategyType::DCT32X32, 5, 1, 5,
    899      // 0.9822994906548809f},
    900      {AcStrategyType::DCT64X32, 6, 1, 3, entropy_mul64X32},
    901      {AcStrategyType::DCT32X64, 6, 1, 3, entropy_mul64X32},
    902      // {AcStrategyType::DCT64X64, 8, 1, 3, 2.0846542128012948f},
    903  };
    904  /*
    905  These sizes not yet included in merge heuristic:
    906  set(AcStrategyType::DCT32X8, 0.0f, 2.261390410971102f);
    907  set(AcStrategyType::DCT8X32, 0.0f, 2.261390410971102f);
    908  set(AcStrategyType::DCT128X128, 0.0f, 1.0f);
    909  set(AcStrategyType::DCT128X64, 0.0f, 0.73f);
    910  set(AcStrategyType::DCT64X128, 0.0f, 0.73f);
    911  set(AcStrategyType::DCT256X256, 0.0f, 1.0f);
    912  set(AcStrategyType::DCT256X128, 0.0f, 0.73f);
    913  set(AcStrategyType::DCT128X256, 0.0f, 0.73f);
    914  */
    915 
    916  // Priority is a tricky kludge to avoid collisions so that transforms
    917  // don't overlap.
    918  uint8_t priority[64] = {};
    919  bool enable_32x32 = cparams.decoding_speed_tier < 4;
    920  for (auto tx : kTransformsForMerge) {
    921    if (tx.decoding_speed_tier_max_limit < cparams.decoding_speed_tier) {
    922      continue;
    923    }
    924    AcStrategy acs = AcStrategy::FromRawStrategy(tx.type);
    925 
    926    for (size_t cy = 0; cy + acs.covered_blocks_y() - 1 < rect.ysize();
    927         cy += acs.covered_blocks_y()) {
    928      for (size_t cx = 0; cx + acs.covered_blocks_x() - 1 < rect.xsize();
    929           cx += acs.covered_blocks_x()) {
    930        if (cy + 7 < rect.ysize() && cx + 7 < rect.xsize()) {
    931          if (cparams.decoding_speed_tier < 4 &&
    932              tx.type == AcStrategyType::DCT32X64) {
    933            // We handle both DCT8X16 and DCT16X8 at the same time.
    934            if ((cy | cx) % 8 == 0) {
    935              JXL_RETURN_IF_ERROR(FindBestFirstLevelDivisionForSquare(
    936                  8, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
    937                  tx.entropy_mul, entropy_mul64X64, entropy_estimate, block,
    938                  scratch_space, quantized));
    939            }
    940            continue;
    941          } else if (tx.type == AcStrategyType::DCT32X16) {
    942            // We handled both DCT8X16 and DCT16X8 at the same time,
    943            // and that is above. The last column and last row,
    944            // when the last column or last row is odd numbered,
    945            // are still handled by TryMergeAcs.
    946            continue;
    947          }
    948        }
    949        if ((tx.type == AcStrategyType::DCT16X32 && cy % 4 != 0) ||
    950            (tx.type == AcStrategyType::DCT32X16 && cx % 4 != 0)) {
    951          // already covered by FindBest32X32
    952          continue;
    953        }
    954 
    955        if (cy + 3 < rect.ysize() && cx + 3 < rect.xsize()) {
    956          if (tx.type == AcStrategyType::DCT16X32) {
    957            // We handle both DCT8X16 and DCT16X8 at the same time.
    958            if ((cy | cx) % 4 == 0) {
    959              JXL_RETURN_IF_ERROR(FindBestFirstLevelDivisionForSquare(
    960                  4, enable_32x32, bx, by, cx, cy, config, cmap_factors,
    961                  ac_strategy, tx.entropy_mul, entropy_mul32X32,
    962                  entropy_estimate, block, scratch_space, quantized));
    963            }
    964            continue;
    965          } else if (tx.type == AcStrategyType::DCT32X16) {
    966            // We handled both DCT8X16 and DCT16X8 at the same time,
    967            // and that is above. The last column and last row,
    968            // when the last column or last row is odd numbered,
    969            // are still handled by TryMergeAcs.
    970            continue;
    971          }
    972        }
    973        if ((tx.type == AcStrategyType::DCT16X32 && cy % 4 != 0) ||
    974            (tx.type == AcStrategyType::DCT32X16 && cx % 4 != 0)) {
    975          // already covered by FindBest32X32
    976          continue;
    977        }
    978        if (cy + 1 < rect.ysize() && cx + 1 < rect.xsize()) {
    979          if (tx.type == AcStrategyType::DCT8X16) {
    980            // We handle both DCT8X16 and DCT16X8 at the same time.
    981            if ((cy | cx) % 2 == 0) {
    982              JXL_RETURN_IF_ERROR(FindBestFirstLevelDivisionForSquare(
    983                  2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
    984                  tx.entropy_mul, entropy_mul16X16, entropy_estimate, block,
    985                  scratch_space, quantized));
    986            }
    987            continue;
    988          } else if (tx.type == AcStrategyType::DCT16X8) {
    989            // We handled both DCT8X16 and DCT16X8 at the same time,
    990            // and that is above. The last column and last row,
    991            // when the last column or last row is odd numbered,
    992            // are still handled by TryMergeAcs.
    993            continue;
    994          }
    995        }
    996        if ((tx.type == AcStrategyType::DCT8X16 && cy % 2 == 1) ||
    997            (tx.type == AcStrategyType::DCT16X8 && cx % 2 == 1)) {
    998          // already covered by FindBestFirstLevelDivisionForSquare
    999          continue;
   1000        }
   1001        // All other merge sizes are handled here.
   1002        // Some of the DCT16X8s and DCT8X16s will still leak through here
   1003        // when there is an odd number of 8x8 blocks, then the last row
   1004        // and column will get their DCT16X8s and DCT8X16s through the
   1005        // normal integral transform merging process.
   1006        JXL_RETURN_IF_ERROR(
   1007            TryMergeAcs(tx.type, bx, by, cx, cy, config, cmap_factors,
   1008                        ac_strategy, tx.entropy_mul, tx.priority, &priority[0],
   1009                        entropy_estimate, block, scratch_space, quantized));
   1010      }
   1011    }
   1012  }
   1013  if (cparams.speed_tier >= SpeedTier::kHare) {
   1014    return true;
   1015  }
   1016  // Here we still try to do some non-aligned matching, find a few more
   1017  // 16X8, 8X16 and 16X16s between the non-2-aligned blocks.
   1018  for (size_t cy = 0; cy + 1 < rect.ysize(); ++cy) {
   1019    for (size_t cx = 0; cx + 1 < rect.xsize(); ++cx) {
   1020      if ((cy | cx) % 2 != 0) {
   1021        JXL_RETURN_IF_ERROR(FindBestFirstLevelDivisionForSquare(
   1022            2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
   1023            entropy_mul16X8, entropy_mul16X16, entropy_estimate, block,
   1024            scratch_space, quantized));
   1025      }
   1026    }
   1027  }
   1028  // Non-aligned matching for 32X32, 16X32 and 32X16.
   1029  size_t step = cparams.speed_tier >= SpeedTier::kTortoise ? 2 : 1;
   1030  for (size_t cy = 0; cy + 3 < rect.ysize(); cy += step) {
   1031    for (size_t cx = 0; cx + 3 < rect.xsize(); cx += step) {
   1032      if ((cy | cx) % 4 == 0) {
   1033        continue;  // Already tried with loop above (DCT16X32 case).
   1034      }
   1035      JXL_RETURN_IF_ERROR(FindBestFirstLevelDivisionForSquare(
   1036          4, enable_32x32, bx, by, cx, cy, config, cmap_factors, ac_strategy,
   1037          entropy_mul16X32, entropy_mul32X32, entropy_estimate, block,
   1038          scratch_space, quantized));
   1039    }
   1040  }
   1041  return true;
   1042 }
   1043 
   1044 // NOLINTNEXTLINE(google-readability-namespace-comments)
   1045 }  // namespace HWY_NAMESPACE
   1046 }  // namespace jxl
   1047 HWY_AFTER_NAMESPACE();
   1048 
   1049 #if HWY_ONCE
   1050 namespace jxl {
   1051 HWY_EXPORT(ProcessRectACS);
   1052 
   1053 Status AcStrategyHeuristics::Init(const Image3F& src, const Rect& rect_in,
   1054                                  const ImageF& quant_field, const ImageF& mask,
   1055                                  const ImageF& mask1x1,
   1056                                  DequantMatrices* matrices) {
   1057  config.dequant = matrices;
   1058 
   1059  if (cparams.speed_tier >= SpeedTier::kCheetah) {
   1060    JXL_RETURN_IF_ERROR(
   1061        matrices->EnsureComputed(memory_manager, 1));  // DCT8 only
   1062  } else {
   1063    uint32_t acs_mask = 0;
   1064    // All transforms up to 64x64.
   1065    for (size_t i = 0; i < static_cast<size_t>(AcStrategyType::DCT128X128);
   1066         i++) {
   1067      acs_mask |= (1 << i);
   1068    }
   1069    JXL_RETURN_IF_ERROR(matrices->EnsureComputed(memory_manager, acs_mask));
   1070  }
   1071 
   1072  // Image row pointers and strides.
   1073  config.quant_field_row = quant_field.Row(0);
   1074  config.quant_field_stride = quant_field.PixelsPerRow();
   1075  if (mask.xsize() > 0 && mask.ysize() > 0) {
   1076    config.masking_field_row = mask.Row(0);
   1077    config.masking_field_stride = mask.PixelsPerRow();
   1078  }
   1079  config.mask1x1_xsize = mask1x1.xsize();
   1080  if (mask1x1.xsize() > 0 && mask1x1.ysize() > 0) {
   1081    config.masking1x1_field_row = mask1x1.Row(0);
   1082    config.masking1x1_field_stride = mask1x1.PixelsPerRow();
   1083  }
   1084 
   1085  config.src_rows[0] = rect_in.ConstPlaneRow(src, 0, 0);
   1086  config.src_rows[1] = rect_in.ConstPlaneRow(src, 1, 0);
   1087  config.src_rows[2] = rect_in.ConstPlaneRow(src, 2, 0);
   1088  config.src_stride = src.PixelsPerRow();
   1089 
   1090  // Entropy estimate is composed of two factors:
   1091  //  - estimate of the number of bits that will be used by the block
   1092  //  - information loss due to quantization
   1093  // The following constant controls the relative weights of these components.
   1094  config.info_loss_multiplier = 1.2;
   1095  config.zeros_mul = 9.3089059022677905;
   1096  config.cost_delta = 10.833273317067883;
   1097 
   1098  static const float kBias = 0.13731742964354549;
   1099  const float ratio = (cparams.butteraugli_distance + kBias) / (1.0f + kBias);
   1100 
   1101  static const float kPow1 = 0.33677806662454718;
   1102  static const float kPow2 = 0.50990926717963703;
   1103  static const float kPow3 = 0.36702940662370243;
   1104  config.info_loss_multiplier *= std::pow(ratio, kPow1);
   1105  config.zeros_mul *= std::pow(ratio, kPow2);
   1106  config.cost_delta *= std::pow(ratio, kPow3);
   1107  return true;
   1108 }
   1109 
   1110 Status AcStrategyHeuristics::PrepareForThreads(std::size_t num_threads) {
   1111  const size_t dct_scratch_size =
   1112      3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim;
   1113  mem_per_thread = 6 * AcStrategy::kMaxCoeffArea + dct_scratch_size;
   1114  size_t mem_bytes = num_threads * mem_per_thread * sizeof(float);
   1115  JXL_ASSIGN_OR_RETURN(mem, AlignedMemory::Create(memory_manager, mem_bytes));
   1116  qmem_per_thread = AcStrategy::kMaxCoeffArea;
   1117  size_t qmem_bytes = num_threads * qmem_per_thread * sizeof(uint32_t);
   1118  JXL_ASSIGN_OR_RETURN(qmem, AlignedMemory::Create(memory_manager, qmem_bytes));
   1119  return true;
   1120 }
   1121 
   1122 Status AcStrategyHeuristics::ProcessRect(const Rect& rect,
   1123                                         const ColorCorrelationMap& cmap,
   1124                                         AcStrategyImage* ac_strategy,
   1125                                         size_t thread) {
   1126  // In Falcon mode, use DCT8 everywhere and uniform quantization.
   1127  if (cparams.speed_tier >= SpeedTier::kCheetah) {
   1128    ac_strategy->FillDCT8(rect);
   1129    return true;
   1130  }
   1131  return HWY_DYNAMIC_DISPATCH(ProcessRectACS)(
   1132      cparams, config, rect, cmap,
   1133      mem.address<float>() + thread * mem_per_thread,
   1134      qmem.address<uint32_t>() + thread * qmem_per_thread, ac_strategy);
   1135 }
   1136 
   1137 Status AcStrategyHeuristics::Finalize(const FrameDimensions& frame_dim,
   1138                                      const AcStrategyImage& ac_strategy,
   1139                                      AuxOut* aux_out) {
   1140  // Accounting and debug output.
   1141  if (aux_out != nullptr) {
   1142    aux_out->num_small_blocks =
   1143        ac_strategy.CountBlocks(AcStrategyType::IDENTITY) +
   1144        ac_strategy.CountBlocks(AcStrategyType::DCT2X2) +
   1145        ac_strategy.CountBlocks(AcStrategyType::DCT4X4);
   1146    aux_out->num_dct4x8_blocks =
   1147        ac_strategy.CountBlocks(AcStrategyType::DCT4X8) +
   1148        ac_strategy.CountBlocks(AcStrategyType::DCT8X4);
   1149    aux_out->num_afv_blocks = ac_strategy.CountBlocks(AcStrategyType::AFV0) +
   1150                              ac_strategy.CountBlocks(AcStrategyType::AFV1) +
   1151                              ac_strategy.CountBlocks(AcStrategyType::AFV2) +
   1152                              ac_strategy.CountBlocks(AcStrategyType::AFV3);
   1153    aux_out->num_dct8_blocks = ac_strategy.CountBlocks(AcStrategyType::DCT);
   1154    aux_out->num_dct8x16_blocks =
   1155        ac_strategy.CountBlocks(AcStrategyType::DCT8X16) +
   1156        ac_strategy.CountBlocks(AcStrategyType::DCT16X8);
   1157    aux_out->num_dct8x32_blocks =
   1158        ac_strategy.CountBlocks(AcStrategyType::DCT8X32) +
   1159        ac_strategy.CountBlocks(AcStrategyType::DCT32X8);
   1160    aux_out->num_dct16_blocks =
   1161        ac_strategy.CountBlocks(AcStrategyType::DCT16X16);
   1162    aux_out->num_dct16x32_blocks =
   1163        ac_strategy.CountBlocks(AcStrategyType::DCT16X32) +
   1164        ac_strategy.CountBlocks(AcStrategyType::DCT32X16);
   1165    aux_out->num_dct32_blocks =
   1166        ac_strategy.CountBlocks(AcStrategyType::DCT32X32);
   1167    aux_out->num_dct32x64_blocks =
   1168        ac_strategy.CountBlocks(AcStrategyType::DCT32X64) +
   1169        ac_strategy.CountBlocks(AcStrategyType::DCT64X32);
   1170    aux_out->num_dct64_blocks =
   1171        ac_strategy.CountBlocks(AcStrategyType::DCT64X64);
   1172  }
   1173 
   1174  if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(cparams)) {
   1175    JXL_RETURN_IF_ERROR(DumpAcStrategy(ac_strategy, frame_dim.xsize,
   1176                                       frame_dim.ysize, "ac_strategy", aux_out,
   1177                                       cparams));
   1178  }
   1179  return true;
   1180 }
   1181 
   1182 }  // namespace jxl
   1183 #endif  // HWY_ONCE