tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_chroma_from_luma.cc (16260B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_chroma_from_luma.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <hwy/base.h>  // HWY_ALIGN_MAX
#include <limits>
     15 
     16 #undef HWY_TARGET_INCLUDE
     17 #define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc"
     18 #include <hwy/foreach_target.h>
     19 #include <hwy/highway.h>
     20 
     21 #include "lib/jxl/base/common.h"
     22 #include "lib/jxl/base/rect.h"
     23 #include "lib/jxl/base/status.h"
     24 #include "lib/jxl/cms/opsin_params.h"
     25 #include "lib/jxl/dec_transforms-inl.h"
     26 #include "lib/jxl/enc_aux_out.h"
     27 #include "lib/jxl/enc_params.h"
     28 #include "lib/jxl/enc_transforms-inl.h"
     29 #include "lib/jxl/quantizer.h"
     30 #include "lib/jxl/simd_util.h"
     31 HWY_BEFORE_NAMESPACE();
     32 namespace jxl {
     33 namespace HWY_NAMESPACE {
     34 
// These templates are not found via ADL, so they are brought into scope
// explicitly (presumably because on some Highway targets the vector/mask types
// are not declared in hwy::HWY_NAMESPACE -- TODO confirm).
using hwy::HWY_NAMESPACE::Abs;
using hwy::HWY_NAMESPACE::Ge;
using hwy::HWY_NAMESPACE::GetLane;
using hwy::HWY_NAMESPACE::IfThenElse;
using hwy::HWY_NAMESPACE::Lt;

// Full-width float SIMD descriptor for the current Highway target. Every
// function in this namespace uses it for Set/Load/Store/Lanes.
static HWY_FULL(float) df;
     43 
     44 struct CFLFunction {
     45  static constexpr float kCoeff = 1.f / 3;
     46  static constexpr float kThres = 100.0f;
     47  static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
     48  CFLFunction(const float* values_m, const float* values_s, size_t num,
     49              float base, float distance_mul)
     50      : values_m(values_m),
     51        values_s(values_s),
     52        num(num),
     53        base(base),
     54        distance_mul(distance_mul) {
     55    JXL_DASSERT(num % Lanes(df) == 0);
     56  }
     57 
     58  // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) +
     59  // distance_mul * x^2 * num.
     60  float Compute(float x, float eps, float* fpeps, float* fmeps) const {
     61    float first_derivative = 2 * distance_mul * num * x;
     62    float first_derivative_peps = 2 * distance_mul * num * (x + eps);
     63    float first_derivative_meps = 2 * distance_mul * num * (x - eps);
     64 
     65    const auto inv_color_factor = Set(df, kInvColorFactor);
     66    const auto thres = Set(df, kThres);
     67    const auto coeffx2 = Set(df, kCoeff * 2.0f);
     68    const auto one = Set(df, 1.0f);
     69    const auto zero = Set(df, 0.0f);
     70    const auto base_v = Set(df, base);
     71    const auto x_v = Set(df, x);
     72    const auto xpe_v = Set(df, x + eps);
     73    const auto xme_v = Set(df, x - eps);
     74    auto fd_v = Zero(df);
     75    auto fdpe_v = Zero(df);
     76    auto fdme_v = Zero(df);
     77 
     78    for (size_t i = 0; i < num; i += Lanes(df)) {
     79      // color residual = ax + b
     80      const auto a = Mul(inv_color_factor, Load(df, values_m + i));
     81      const auto b =
     82          Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i));
     83      const auto v = MulAdd(a, x_v, b);
     84      const auto vpe = MulAdd(a, xpe_v, b);
     85      const auto vme = MulAdd(a, xme_v, b);
     86      const auto av = Abs(v);
     87      const auto avpe = Abs(vpe);
     88      const auto avme = Abs(vme);
     89      const auto acoeffx2 = Mul(coeffx2, a);
     90      auto d = Mul(acoeffx2, Add(av, one));
     91      auto dpe = Mul(acoeffx2, Add(avpe, one));
     92      auto dme = Mul(acoeffx2, Add(avme, one));
     93      d = IfThenElse(Lt(v, zero), Sub(zero, d), d);
     94      dpe = IfThenElse(Lt(vpe, zero), Sub(zero, dpe), dpe);
     95      dme = IfThenElse(Lt(vme, zero), Sub(zero, dme), dme);
     96      const auto above = Ge(av, thres);
     97      // TODO(eustas): use IfThenElseZero
     98      fd_v = Add(fd_v, IfThenElse(above, zero, d));
     99      fdpe_v = Add(fdpe_v, IfThenElse(above, zero, dpe));
    100      fdme_v = Add(fdme_v, IfThenElse(above, zero, dme));
    101    }
    102 
    103    *fpeps = first_derivative_peps + GetLane(SumOfLanes(df, fdpe_v));
    104    *fmeps = first_derivative_meps + GetLane(SumOfLanes(df, fdme_v));
    105    return first_derivative + GetLane(SumOfLanes(df, fd_v));
    106  }
    107 
    108  const float* JXL_RESTRICT values_m;
    109  const float* JXL_RESTRICT values_s;
    110  size_t num;
    111  float base;
    112  float distance_mul;
    113 };
    114 
    115 // Chroma-from-luma search, values_m will have luma -- and values_s chroma.
    116 int32_t FindBestMultiplier(const float* values_m, const float* values_s,
    117                           size_t num, float base, float distance_mul,
    118                           bool fast) {
    119  if (num == 0) {
    120    return 0;
    121  }
    122  float x;
    123  if (fast) {
    124    static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
    125    auto ca = Zero(df);
    126    auto cb = Zero(df);
    127    const auto inv_color_factor = Set(df, kInvColorFactor);
    128    const auto base_v = Set(df, base);
    129    for (size_t i = 0; i < num; i += Lanes(df)) {
    130      // color residual = ax + b
    131      const auto a = Mul(inv_color_factor, Load(df, values_m + i));
    132      const auto b =
    133          Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i));
    134      ca = MulAdd(a, a, ca);
    135      cb = MulAdd(a, b, cb);
    136    }
    137    // + distance_mul * x^2 * num
    138    x = -GetLane(SumOfLanes(df, cb)) /
    139        (GetLane(SumOfLanes(df, ca)) + num * distance_mul * 0.5f);
    140  } else {
    141    constexpr float eps = 100;
    142    constexpr float kClamp = 20.0f;
    143    CFLFunction fn(values_m, values_s, num, base, distance_mul);
    144    x = 0;
    145    // Up to 20 Newton iterations, with approximate derivatives.
    146    // Derivatives are approximate due to the high amount of noise in the exact
    147    // derivatives.
    148    for (size_t i = 0; i < 20; i++) {
    149      float dfpeps;
    150      float dfmeps;
    151      float df = fn.Compute(x, eps, &dfpeps, &dfmeps);
    152      float ddf = (dfpeps - dfmeps) / (2 * eps);
    153      float kExperimentalInsignificantStabilizer = 0.85;
    154      float step = df / (ddf + kExperimentalInsignificantStabilizer);
    155      x -= std::min(kClamp, std::max(-kClamp, step));
    156      if (std::abs(step) < 3e-3) break;
    157    }
    158  }
    159  // CFL seems to be tricky for larger transforms for HF components
    160  // close to zero. This heuristic brings the solutions closer to zero
    161  // and reduces red-green oscillations. A better approach would
    162  // look into variance of the multiplier within separate (e.g. 8x8)
    163  // areas and only apply this heuristic where there is a high variance.
    164  // This would give about 1 % more compression density.
    165  float towards_zero = 2.6;
    166  if (x >= towards_zero) {
    167    x -= towards_zero;
    168  } else if (x <= -towards_zero) {
    169    x += towards_zero;
    170  } else {
    171    x = 0;
    172  }
    173  return std::max(-128.0f, std::min(127.0f, roundf(x)));
    174 }
    175 
    176 Status InitDCStorage(JxlMemoryManager* memory_manager, size_t num_blocks,
    177                     ImageF* dc_values) {
    178  // First row: Y channel
    179  // Second row: X channel
    180  // Third row: Y channel
    181  // Fourth row: B channel
    182  JXL_ASSIGN_OR_RETURN(
    183      *dc_values,
    184      ImageF::Create(memory_manager, RoundUpTo(num_blocks, Lanes(df)), 4));
    185 
    186  JXL_ENSURE(dc_values->xsize() != 0);
    187  // Zero-fill the last lanes
    188  for (size_t y = 0; y < 4; y++) {
    189    for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize();
    190         x++) {
    191      dc_values->Row(y)[x] = 0;
    192    }
    193  }
    194  return true;
    195 }
    196 
    197 Status ComputeTile(const Image3F& opsin, const Rect& opsin_rect,
    198                   const DequantMatrices& dequant,
    199                   const AcStrategyImage* ac_strategy,
    200                   const ImageI* raw_quant_field, const Quantizer* quantizer,
    201                   const Rect& rect, bool fast, bool use_dct8, ImageSB* map_x,
    202                   ImageSB* map_b, ImageF* dc_values, float* mem) {
    203  static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks,
    204                "Invalid color tile dim");
    205  size_t xsize_blocks = opsin_rect.xsize() / kBlockDim;
    206  constexpr float kDistanceMultiplierAC = 1e-9f;
    207  const size_t dct_scratch_size =
    208      3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim;
    209 
    210  const size_t y0 = rect.y0();
    211  const size_t x0 = rect.x0();
    212  const size_t x1 = rect.x0() + rect.xsize();
    213  const size_t y1 = rect.y0() + rect.ysize();
    214 
    215  int ty = y0 / kColorTileDimInBlocks;
    216  int tx = x0 / kColorTileDimInBlocks;
    217 
    218  int8_t* JXL_RESTRICT row_out_x = map_x->Row(ty);
    219  int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty);
    220 
    221  float* JXL_RESTRICT dc_values_yx = dc_values->Row(0);
    222  float* JXL_RESTRICT dc_values_x = dc_values->Row(1);
    223  float* JXL_RESTRICT dc_values_yb = dc_values->Row(2);
    224  float* JXL_RESTRICT dc_values_b = dc_values->Row(3);
    225 
    226  // All are aligned.
    227  float* HWY_RESTRICT block_y = mem;
    228  float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea;
    229  float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea;
    230  float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea;
    231  float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim;
    232  float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim;
    233  float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim;
    234  float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim;
    235  float* scratch_space_end =
    236      scratch_space + 2 * AcStrategy::kMaxCoeffArea + dct_scratch_size;
    237  JXL_ENSURE(scratch_space_end == block_y + CfLHeuristics::ItemsPerThread());
    238  (void)scratch_space_end;
    239 
    240  // Small (~256 bytes each)
    241  HWY_ALIGN_MAX float
    242      dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {};
    243  HWY_ALIGN_MAX float
    244      dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {};
    245  HWY_ALIGN_MAX float
    246      dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {};
    247  size_t num_ac = 0;
    248 
    249  for (size_t y = y0; y < y1; ++y) {
    250    const float* JXL_RESTRICT row_y =
    251        opsin_rect.ConstPlaneRow(opsin, 1, y * kBlockDim);
    252    const float* JXL_RESTRICT row_x =
    253        opsin_rect.ConstPlaneRow(opsin, 0, y * kBlockDim);
    254    const float* JXL_RESTRICT row_b =
    255        opsin_rect.ConstPlaneRow(opsin, 2, y * kBlockDim);
    256    size_t stride = opsin.PixelsPerRow();
    257 
    258    for (size_t x = x0; x < x1; x++) {
    259      AcStrategy acs = use_dct8
    260                           ? AcStrategy::FromRawStrategy(AcStrategyType::DCT)
    261                           : ac_strategy->ConstRow(y)[x];
    262      if (!acs.IsFirstBlock()) continue;
    263      size_t xs = acs.covered_blocks_x();
    264      TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride,
    265                          block_y, scratch_space);
    266      DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs);
    267      TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride,
    268                          block_x, scratch_space);
    269      DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs);
    270      TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride,
    271                          block_b, scratch_space);
    272      DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs);
    273      const float* const JXL_RESTRICT qm_x =
    274          dequant.InvMatrix(acs.Strategy(), 0);
    275      const float* const JXL_RESTRICT qm_b =
    276          dequant.InvMatrix(acs.Strategy(), 2);
    277      float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0);
    278      float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2);
    279 
    280      // Copy DCs in dc_values.
    281      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    282        for (size_t ix = 0; ix < xs; ix++) {
    283          dc_values_yx[(iy + y) * xsize_blocks + ix + x] =
    284              dc_y[iy * xs + ix] * q_dc_x;
    285          dc_values_x[(iy + y) * xsize_blocks + ix + x] =
    286              dc_x[iy * xs + ix] * q_dc_x;
    287          dc_values_yb[(iy + y) * xsize_blocks + ix + x] =
    288              dc_y[iy * xs + ix] * q_dc_b;
    289          dc_values_b[(iy + y) * xsize_blocks + ix + x] =
    290              dc_b[iy * xs + ix] * q_dc_b;
    291        }
    292      }
    293 
    294      // Do not use this block for computing AC CfL.
    295      if (acs.covered_blocks_x() + x0 > x1 ||
    296          acs.covered_blocks_y() + y0 > y1) {
    297        continue;
    298      }
    299 
    300      // Copy AC coefficients in the local block. The order in which
    301      // coefficients get stored does not matter.
    302      size_t cx = acs.covered_blocks_x();
    303      size_t cy = acs.covered_blocks_y();
    304      CoefficientLayout(&cy, &cx);
    305      // Zero out LFs. This introduces terms in the optimization loop that
    306      // don't affect the result, as they are all 0, but allow for simpler
    307      // SIMDfication.
    308      for (size_t iy = 0; iy < cy; iy++) {
    309        for (size_t ix = 0; ix < cx; ix++) {
    310          block_y[cx * kBlockDim * iy + ix] = 0;
    311          block_x[cx * kBlockDim * iy + ix] = 0;
    312          block_b[cx * kBlockDim * iy + ix] = 0;
    313        }
    314      }
    315      // Unclear why this is like it is. (This works slightly better
    316      // than the previous approach which was also a hack.)
    317      const float qq =
    318          (raw_quant_field == nullptr) ? 1.0f : raw_quant_field->Row(y)[x];
    319      // Experimentally values 128-130 seem best -- I don't know why we
    320      // need this multiplier.
    321      const float kStrangeMultiplier = 128;
    322      float q = use_dct8 ? 1 : quantizer->Scale() * kStrangeMultiplier * qq;
    323      const auto qv = Set(df, q);
    324      for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) {
    325        const auto b_y = Load(df, block_y + i);
    326        const auto b_x = Load(df, block_x + i);
    327        const auto b_b = Load(df, block_b + i);
    328        const auto qqm_x = Mul(qv, Load(df, qm_x + i));
    329        const auto qqm_b = Mul(qv, Load(df, qm_b + i));
    330        Store(Mul(b_y, qqm_x), df, coeffs_yx + num_ac);
    331        Store(Mul(b_x, qqm_x), df, coeffs_x + num_ac);
    332        Store(Mul(b_y, qqm_b), df, coeffs_yb + num_ac);
    333        Store(Mul(b_b, qqm_b), df, coeffs_b + num_ac);
    334        num_ac += Lanes(df);
    335      }
    336    }
    337  }
    338  JXL_ENSURE(num_ac % Lanes(df) == 0);
    339  row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f,
    340                                     kDistanceMultiplierAC, fast);
    341  row_out_b[tx] =
    342      FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, jxl::cms::kYToBRatio,
    343                         kDistanceMultiplierAC, fast);
    344  return true;
    345 }
    346 
    347 // NOLINTNEXTLINE(google-readability-namespace-comments)
    348 }  // namespace HWY_NAMESPACE
    349 }  // namespace jxl
    350 HWY_AFTER_NAMESPACE();
    351 
    352 #if HWY_ONCE
    353 namespace jxl {
    354 
// Register the per-target SIMD implementations defined above so they can be
// selected at runtime through HWY_DYNAMIC_DISPATCH.
HWY_EXPORT(InitDCStorage);
HWY_EXPORT(ComputeTile);
    357 
    358 Status CfLHeuristics::Init(const Rect& rect) {
    359  size_t xsize_blocks = rect.xsize() / kBlockDim;
    360  size_t ysize_blocks = rect.ysize() / kBlockDim;
    361  return HWY_DYNAMIC_DISPATCH(InitDCStorage)(
    362      memory_manager, xsize_blocks * ysize_blocks, &dc_values);
    363 }
    364 
    365 Status CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin,
    366                                  const Rect& opsin_rect,
    367                                  const DequantMatrices& dequant,
    368                                  const AcStrategyImage* ac_strategy,
    369                                  const ImageI* raw_quant_field,
    370                                  const Quantizer* quantizer, bool fast,
    371                                  size_t thread, ColorCorrelationMap* cmap) {
    372  bool use_dct8 = ac_strategy == nullptr;
    373  return HWY_DYNAMIC_DISPATCH(ComputeTile)(
    374      opsin, opsin_rect, dequant, ac_strategy, raw_quant_field, quantizer, r,
    375      fast, use_dct8, &cmap->ytox_map, &cmap->ytob_map, &dc_values,
    376      mem.address<float>() + thread * ItemsPerThread());
    377 }
    378 
    379 Status ColorCorrelationEncodeDC(const ColorCorrelation& color_correlation,
    380                                BitWriter* writer, LayerType layer,
    381                                AuxOut* aux_out) {
    382  float color_factor = color_correlation.GetColorFactor();
    383  float base_correlation_x = color_correlation.GetBaseCorrelationX();
    384  float base_correlation_b = color_correlation.GetBaseCorrelationB();
    385  int32_t ytox_dc = color_correlation.GetYToXDC();
    386  int32_t ytob_dc = color_correlation.GetYToBDC();
    387 
    388  return writer->WithMaxBits(
    389      1 + 2 * kBitsPerByte + 12 + 32, layer, aux_out, [&]() -> Status {
    390        if (ytox_dc == 0 && ytob_dc == 0 &&
    391            color_factor == kDefaultColorFactor && base_correlation_x == 0.0f &&
    392            base_correlation_b == jxl::cms::kYToBRatio) {
    393          writer->Write(1, 1);
    394          return true;
    395        }
    396        writer->Write(1, 0);
    397        JXL_RETURN_IF_ERROR(
    398            U32Coder::Write(kColorFactorDist, color_factor, writer));
    399        JXL_RETURN_IF_ERROR(F16Coder::Write(base_correlation_x, writer));
    400        JXL_RETURN_IF_ERROR(F16Coder::Write(base_correlation_b, writer));
    401        writer->Write(kBitsPerByte,
    402                      ytox_dc - std::numeric_limits<int8_t>::min());
    403        writer->Write(kBitsPerByte,
    404                      ytob_dc - std::numeric_limits<int8_t>::min());
    405        return true;
    406      });
    407 }
    408 
    409 }  // namespace jxl
    410 #endif  // HWY_ONCE