tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

compressed_dc.cc (11319B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/compressed_dc.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <algorithm>
     11 #include <cstdint>
     12 #include <cstdlib>
     13 #include <cstring>
     14 #include <vector>
     15 
     16 #undef HWY_TARGET_INCLUDE
     17 #define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc"
     18 #include <hwy/foreach_target.h>
     19 #include <hwy/highway.h>
     20 
     21 #include "lib/jxl/base/compiler_specific.h"
     22 #include "lib/jxl/base/data_parallel.h"
     23 #include "lib/jxl/base/rect.h"
     24 #include "lib/jxl/base/status.h"
     25 #include "lib/jxl/image.h"
     26 HWY_BEFORE_NAMESPACE();
     27 namespace jxl {
     28 namespace HWY_NAMESPACE {
     29 
     30 using D = HWY_FULL(float);
     31 using DScalar = HWY_CAPPED(float, 1);
     32 
     33 // These templates are not found via ADL.
     34 using hwy::HWY_NAMESPACE::Abs;
     35 using hwy::HWY_NAMESPACE::Add;
     36 using hwy::HWY_NAMESPACE::Div;
     37 using hwy::HWY_NAMESPACE::Max;
     38 using hwy::HWY_NAMESPACE::Mul;
     39 using hwy::HWY_NAMESPACE::MulAdd;
     40 using hwy::HWY_NAMESPACE::Rebind;
     41 using hwy::HWY_NAMESPACE::Sub;
     42 using hwy::HWY_NAMESPACE::Vec;
     43 using hwy::HWY_NAMESPACE::ZeroIfNegative;
     44 
     45 // TODO(veluca): optimize constants.
     46 const float w1 = 0.20345139757231578f;
     47 const float w2 = 0.0334829185968739f;
     48 const float w0 = 1.0f - 4.0f * (w1 + w2);
     49 
     50 template <class V>
     51 V MaxWorkaround(V a, V b) {
     52 #if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800
     53  // Prevents "Do not know how to split the result of this operator" error
     54  return IfThenElse(a > b, a, b);
     55 #else
     56  return Max(a, b);
     57 #endif
     58 }
     59 
     60 template <typename D>
     61 JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor,
     62                                    const float* JXL_RESTRICT row_top,
     63                                    const float* JXL_RESTRICT row,
     64                                    const float* JXL_RESTRICT row_bottom,
     65                                    Vec<D>* JXL_RESTRICT mc,
     66                                    Vec<D>* JXL_RESTRICT sm,
     67                                    Vec<D>* JXL_RESTRICT gap, size_t x) {
     68  const auto tl = LoadU(d, row_top + x - 1);
     69  const auto tc = Load(d, row_top + x);
     70  const auto tr = LoadU(d, row_top + x + 1);
     71 
     72  const auto ml = LoadU(d, row + x - 1);
     73  *mc = Load(d, row + x);
     74  const auto mr = LoadU(d, row + x + 1);
     75 
     76  const auto bl = LoadU(d, row_bottom + x - 1);
     77  const auto bc = Load(d, row_bottom + x);
     78  const auto br = LoadU(d, row_bottom + x + 1);
     79 
     80  const auto w_center = Set(d, w0);
     81  const auto w_side = Set(d, w1);
     82  const auto w_corner = Set(d, w2);
     83 
     84  const auto corner = Add(Add(tl, tr), Add(bl, br));
     85  const auto side = Add(Add(ml, mr), Add(tc, bc));
     86  *sm = MulAdd(corner, w_corner, MulAdd(side, w_side, Mul(*mc, w_center)));
     87 
     88  const auto dc_quant = Set(d, dc_factor);
     89  *gap = MaxWorkaround(*gap, Abs(Div(Sub(*mc, *sm), dc_quant)));
     90 }
     91 
     92 template <typename D>
     93 JXL_INLINE void ComputePixel(
     94    const float* JXL_RESTRICT dc_factors,
     95    const float* JXL_RESTRICT* JXL_RESTRICT rows_top,
     96    const float* JXL_RESTRICT* JXL_RESTRICT rows,
     97    const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom,
     98    float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) {
     99  const D d;
    100  auto mc_x = Undefined(d);
    101  auto mc_y = Undefined(d);
    102  auto mc_b = Undefined(d);
    103  auto sm_x = Undefined(d);
    104  auto sm_y = Undefined(d);
    105  auto sm_b = Undefined(d);
    106  auto gap = Set(d, 0.5f);
    107  ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0],
    108                      &mc_x, &sm_x, &gap, x);
    109  ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1],
    110                      &mc_y, &sm_y, &gap, x);
    111  ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2],
    112                      &mc_b, &sm_b, &gap, x);
    113  auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f));
    114  factor = ZeroIfNegative(factor);
    115 
    116  auto out = MulAdd(Sub(sm_x, mc_x), factor, mc_x);
    117  Store(out, d, out_rows[0] + x);
    118  out = MulAdd(Sub(sm_y, mc_y), factor, mc_y);
    119  Store(out, d, out_rows[1] + x);
    120  out = MulAdd(Sub(sm_b, mc_b), factor, mc_b);
    121  Store(out, d, out_rows[2] + x);
    122 }
    123 
    124 Status AdaptiveDCSmoothing(JxlMemoryManager* memory_manager,
    125                           const float* dc_factors, Image3F* dc,
    126                           ThreadPool* pool) {
    127  const size_t xsize = dc->xsize();
    128  const size_t ysize = dc->ysize();
    129  if (ysize <= 2 || xsize <= 2) return true;
    130 
    131  // TODO(veluca): use tile-based processing?
    132  // TODO(veluca): decide if changes to the y channel should be propagated to
    133  // the x and b channels through color correlation.
    134  JXL_ENSURE(w1 + w2 < 0.25f);
    135 
    136  JXL_ASSIGN_OR_RETURN(Image3F smoothed,
    137                       Image3F::Create(memory_manager, xsize, ysize));
    138  // Fill in borders that the loop below will not. First and last are unused.
    139  for (size_t c = 0; c < 3; c++) {
    140    for (size_t y : {static_cast<size_t>(0), ysize - 1}) {
    141      memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y),
    142             xsize * sizeof(float));
    143    }
    144  }
    145  auto process_row = [&](const uint32_t y, size_t /*thread*/) -> Status {
    146    const float* JXL_RESTRICT rows_top[3]{
    147        dc->ConstPlaneRow(0, y - 1),
    148        dc->ConstPlaneRow(1, y - 1),
    149        dc->ConstPlaneRow(2, y - 1),
    150    };
    151    const float* JXL_RESTRICT rows[3] = {
    152        dc->ConstPlaneRow(0, y),
    153        dc->ConstPlaneRow(1, y),
    154        dc->ConstPlaneRow(2, y),
    155    };
    156    const float* JXL_RESTRICT rows_bottom[3] = {
    157        dc->ConstPlaneRow(0, y + 1),
    158        dc->ConstPlaneRow(1, y + 1),
    159        dc->ConstPlaneRow(2, y + 1),
    160    };
    161    float* JXL_RESTRICT rows_out[3] = {
    162        smoothed.PlaneRow(0, y),
    163        smoothed.PlaneRow(1, y),
    164        smoothed.PlaneRow(2, y),
    165    };
    166    for (size_t x : {static_cast<size_t>(0), xsize - 1}) {
    167      for (size_t c = 0; c < 3; c++) {
    168        rows_out[c][x] = rows[c][x];
    169      }
    170    }
    171 
    172    size_t x = 1;
    173    // First pixels
    174    const size_t N = Lanes(D());
    175    for (; x < std::min(N, xsize - 1); x++) {
    176      ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
    177                            x);
    178    }
    179    // Full vectors.
    180    for (; x + N <= xsize - 1; x += N) {
    181      ComputePixel<D>(dc_factors, rows_top, rows, rows_bottom, rows_out, x);
    182    }
    183    // Last pixels.
    184    for (; x < xsize - 1; x++) {
    185      ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
    186                            x);
    187    }
    188    return true;
    189  };
    190  JXL_RETURN_IF_ERROR(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit,
    191                                process_row, "DCSmoothingRow"));
    192  dc->Swap(smoothed);
    193  return true;
    194 }
    195 
    196 // DC dequantization.
    197 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
    198               const float* dc_factors, float mul, const float* cfl_factors,
    199               const YCbCrChromaSubsampling& chroma_subsampling,
    200               const BlockCtxMap& bctx) {
    201  const HWY_FULL(float) df;
    202  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
    203  if (chroma_subsampling.Is444()) {
    204    const auto fac_x = Set(df, dc_factors[0] * mul);
    205    const auto fac_y = Set(df, dc_factors[1] * mul);
    206    const auto fac_b = Set(df, dc_factors[2] * mul);
    207    const auto cfl_fac_x = Set(df, cfl_factors[0]);
    208    const auto cfl_fac_b = Set(df, cfl_factors[2]);
    209    for (size_t y = 0; y < r.ysize(); y++) {
    210      float* dec_row_x = r.PlaneRow(dc, 0, y);
    211      float* dec_row_y = r.PlaneRow(dc, 1, y);
    212      float* dec_row_b = r.PlaneRow(dc, 2, y);
    213      const int32_t* quant_row_x = in.channel[1].plane.Row(y);
    214      const int32_t* quant_row_y = in.channel[0].plane.Row(y);
    215      const int32_t* quant_row_b = in.channel[2].plane.Row(y);
    216      for (size_t x = 0; x < r.xsize(); x += Lanes(di)) {
    217        const auto in_q_x = Load(di, quant_row_x + x);
    218        const auto in_q_y = Load(di, quant_row_y + x);
    219        const auto in_q_b = Load(di, quant_row_b + x);
    220        const auto in_x = Mul(ConvertTo(df, in_q_x), fac_x);
    221        const auto in_y = Mul(ConvertTo(df, in_q_y), fac_y);
    222        const auto in_b = Mul(ConvertTo(df, in_q_b), fac_b);
    223        Store(in_y, df, dec_row_y + x);
    224        Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x);
    225        Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x);
    226      }
    227    }
    228  } else {
    229    for (size_t c : {1, 0, 2}) {
    230      Rect rect(r.x0() >> chroma_subsampling.HShift(c),
    231                r.y0() >> chroma_subsampling.VShift(c),
    232                r.xsize() >> chroma_subsampling.HShift(c),
    233                r.ysize() >> chroma_subsampling.VShift(c));
    234      const auto fac = Set(df, dc_factors[c] * mul);
    235      const Channel& ch = in.channel[c < 2 ? c ^ 1 : c];
    236      for (size_t y = 0; y < rect.ysize(); y++) {
    237        const int32_t* quant_row = ch.plane.Row(y);
    238        float* row = rect.PlaneRow(dc, c, y);
    239        for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) {
    240          const auto in_q = Load(di, quant_row + x);
    241          const auto in = Mul(ConvertTo(df, in_q), fac);
    242          Store(in, df, row + x);
    243        }
    244      }
    245    }
    246  }
    247  if (bctx.num_dc_ctxs <= 1) {
    248    for (size_t y = 0; y < r.ysize(); y++) {
    249      uint8_t* qdc_row = r.Row(quant_dc, y);
    250      memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize());
    251    }
    252  } else {
    253    for (size_t y = 0; y < r.ysize(); y++) {
    254      uint8_t* qdc_row_val = r.Row(quant_dc, y);
    255      const int32_t* quant_row_x =
    256          in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0));
    257      const int32_t* quant_row_y =
    258          in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1));
    259      const int32_t* quant_row_b =
    260          in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2));
    261      for (size_t x = 0; x < r.xsize(); x++) {
    262        int bucket_x = 0;
    263        int bucket_y = 0;
    264        int bucket_b = 0;
    265        for (int t : bctx.dc_thresholds[0]) {
    266          if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++;
    267        }
    268        for (int t : bctx.dc_thresholds[1]) {
    269          if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++;
    270        }
    271        for (int t : bctx.dc_thresholds[2]) {
    272          if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++;
    273        }
    274        int bucket = bucket_x;
    275        bucket *= bctx.dc_thresholds[2].size() + 1;
    276        bucket += bucket_b;
    277        bucket *= bctx.dc_thresholds[1].size() + 1;
    278        bucket += bucket_y;
    279        qdc_row_val[x] = bucket;
    280      }
    281    }
    282  }
    283 }
    284 
    285 // NOLINTNEXTLINE(google-readability-namespace-comments)
    286 }  // namespace HWY_NAMESPACE
    287 }  // namespace jxl
    288 HWY_AFTER_NAMESPACE();
    289 
    290 #if HWY_ONCE
    291 namespace jxl {
    292 
    293 HWY_EXPORT(DequantDC);
    294 HWY_EXPORT(AdaptiveDCSmoothing);
    295 Status AdaptiveDCSmoothing(JxlMemoryManager* memory_manager,
    296                           const float* dc_factors, Image3F* dc,
    297                           ThreadPool* pool) {
    298  return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(memory_manager, dc_factors,
    299                                                   dc, pool);
    300 }
    301 
    302 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
    303               const float* dc_factors, float mul, const float* cfl_factors,
    304               const YCbCrChromaSubsampling& chroma_subsampling,
    305               const BlockCtxMap& bctx) {
    306  HWY_DYNAMIC_DISPATCH(DequantDC)
    307  (r, dc, quant_dc, in, dc_factors, mul, cfl_factors, chroma_subsampling, bctx);
    308 }
    309 
    310 }  // namespace jxl
    311 #endif  // HWY_ONCE