tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

idct.cc (26909B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jpegli/idct.h"
      7 
      8 #include <cmath>
      9 
     10 #include "lib/jpegli/decode_internal.h"
     11 #include "lib/jxl/base/compiler_specific.h"
     12 #include "lib/jxl/base/status.h"
     13 
     14 #undef HWY_TARGET_INCLUDE
     15 #define HWY_TARGET_INCLUDE "lib/jpegli/idct.cc"
     16 #include <hwy/foreach_target.h>
     17 #include <hwy/highway.h>
     18 
     19 #include "lib/jpegli/transpose-inl.h"
     20 
     21 HWY_BEFORE_NAMESPACE();
     22 namespace jpegli {
     23 namespace HWY_NAMESPACE {
     24 
     25 // These templates are not found via ADL.
     26 using hwy::HWY_NAMESPACE::Abs;
     27 using hwy::HWY_NAMESPACE::Add;
     28 using hwy::HWY_NAMESPACE::Gt;
     29 using hwy::HWY_NAMESPACE::IfThenElseZero;
     30 using hwy::HWY_NAMESPACE::Mul;
     31 using hwy::HWY_NAMESPACE::MulAdd;
     32 using hwy::HWY_NAMESPACE::NegMulAdd;
     33 using hwy::HWY_NAMESPACE::Rebind;
     34 using hwy::HWY_NAMESPACE::Sub;
     35 using hwy::HWY_NAMESPACE::Vec;
     36 using hwy::HWY_NAMESPACE::Xor;
     37 
// Full-width vector descriptors (Highway "tags") for float and int32_t lanes.
// They select the vector type and lane count for the Load/Store/convert calls
// in DequantBlock.
using D = HWY_FULL(float);
using DI = HWY_FULL(int32_t);
constexpr D d;
constexpr DI di;

// Float descriptor capped at 8 lanes, so a single vector never covers more
// than one 8-float row of a block; used by all the IDCT helpers below.
using D8 = HWY_CAPPED(float, 8);
constexpr D8 d8;
     45 
     46 void DequantBlock(const int16_t* JXL_RESTRICT qblock,
     47                  const float* JXL_RESTRICT dequant,
     48                  const float* JXL_RESTRICT biases, float* JXL_RESTRICT block) {
     49  for (size_t k = 0; k < 64; k += Lanes(d)) {
     50    const auto mul = Load(d, dequant + k);
     51    const auto bias = Load(d, biases + k);
     52    const Rebind<int16_t, DI> di16;
     53    const Vec<DI> quant_i = PromoteTo(di, Load(di16, qblock + k));
     54    const Rebind<float, DI> df;
     55    const auto quant = ConvertTo(df, quant_i);
     56    const auto abs_quant = Abs(quant);
     57    const auto not_0 = Gt(abs_quant, Zero(df));
     58    const auto sign_quant = Xor(quant, abs_quant);
     59    const auto biased_quant = Sub(quant, Xor(bias, sign_quant));
     60    const auto dequant = IfThenElseZero(not_0, Mul(biased_quant, mul));
     61    Store(dequant, d, block + k);
     62  }
     63 }
     64 
     65 template <size_t N>
     66 void ForwardEvenOdd(const float* JXL_RESTRICT a_in, size_t a_in_stride,
     67                    float* JXL_RESTRICT a_out) {
     68  for (size_t i = 0; i < N / 2; i++) {
     69    auto in1 = LoadU(d8, a_in + 2 * i * a_in_stride);
     70    Store(in1, d8, a_out + i * 8);
     71  }
     72  for (size_t i = N / 2; i < N; i++) {
     73    auto in1 = LoadU(d8, a_in + (2 * (i - N / 2) + 1) * a_in_stride);
     74    Store(in1, d8, a_out + i * 8);
     75  }
     76 }
     77 
     78 template <size_t N>
     79 void BTranspose(float* JXL_RESTRICT coeff) {
     80  for (size_t i = N - 1; i > 0; i--) {
     81    auto in1 = Load(d8, coeff + i * 8);
     82    auto in2 = Load(d8, coeff + (i - 1) * 8);
     83    Store(Add(in1, in2), d8, coeff + i * 8);
     84  }
     85  constexpr float kSqrt2 = 1.41421356237f;
     86  auto sqrt2 = Set(d8, kSqrt2);
     87  auto in1 = Load(d8, coeff);
     88  Store(Mul(in1, sqrt2), d8, coeff);
     89 }
     90 
// Constants for DCT implementation, consumed by MultiplyAndAdd<N>:
//   kMultipliers[i] = 1 / (2 * cos((i + 0.5) * pi / N)),  i in [0, N / 2).
// Generated by the following Python snippet:
// for i in range(N // 2):
//    print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ")
template <size_t N>
struct WcMultipliers;

// N = 4: two multipliers.
template <>
struct WcMultipliers<4> {
  static constexpr float kMultipliers[] = {
      0.541196100146197,
      1.3065629648763764,
  };
};

// N = 8: four multipliers.
template <>
struct WcMultipliers<8> {
  static constexpr float kMultipliers[] = {
      0.5097955791041592,
      0.6013448869350453,
      0.8999762231364156,
      2.5629154477415055,
  };
};

// Before C++17, static constexpr data members are not implicitly inline, so
// they need an out-of-class definition when they are odr-used.
#if JXL_CXX_LANG < JXL_CXX_17
constexpr float WcMultipliers<4>::kMultipliers[];
constexpr float WcMultipliers<8>::kMultipliers[];
#endif
    119 
    120 template <size_t N>
    121 void MultiplyAndAdd(const float* JXL_RESTRICT coeff, float* JXL_RESTRICT out,
    122                    size_t out_stride) {
    123  for (size_t i = 0; i < N / 2; i++) {
    124    auto mul = Set(d8, WcMultipliers<N>::kMultipliers[i]);
    125    auto in1 = Load(d8, coeff + i * 8);
    126    auto in2 = Load(d8, coeff + (N / 2 + i) * 8);
    127    auto out1 = MulAdd(mul, in2, in1);
    128    auto out2 = NegMulAdd(mul, in2, in1);
    129    StoreU(out1, d8, out + i * out_stride);
    130    StoreU(out2, d8, out + (N - i - 1) * out_stride);
    131  }
    132 }
    133 
    134 template <size_t N>
    135 struct IDCT1DImpl;
    136 
    137 template <>
    138 struct IDCT1DImpl<1> {
    139  JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
    140                             size_t to_stride) {
    141    StoreU(LoadU(d8, from), d8, to);
    142  }
    143 };
    144 
    145 template <>
    146 struct IDCT1DImpl<2> {
    147  JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
    148                             size_t to_stride) {
    149    JXL_DASSERT(from_stride >= 8);
    150    JXL_DASSERT(to_stride >= 8);
    151    auto in1 = LoadU(d8, from);
    152    auto in2 = LoadU(d8, from + from_stride);
    153    StoreU(Add(in1, in2), d8, to);
    154    StoreU(Sub(in1, in2), d8, to + to_stride);
    155  }
    156 };
    157 
    158 template <size_t N>
    159 struct IDCT1DImpl {
    160  void operator()(const float* from, size_t from_stride, float* to,
    161                  size_t to_stride) {
    162    JXL_DASSERT(from_stride >= 8);
    163    JXL_DASSERT(to_stride >= 8);
    164    HWY_ALIGN float tmp[64];
    165    ForwardEvenOdd<N>(from, from_stride, tmp);
    166    IDCT1DImpl<N / 2>()(tmp, 8, tmp, 8);
    167    BTranspose<N / 2>(tmp + N * 4);
    168    IDCT1DImpl<N / 2>()(tmp + N * 4, 8, tmp + N * 4, 8);
    169    MultiplyAndAdd<N>(tmp, to, to_stride);
    170  }
    171 };
    172 
    173 template <size_t N>
    174 void IDCT1D(float* JXL_RESTRICT from, float* JXL_RESTRICT output,
    175            size_t output_stride) {
    176  for (size_t i = 0; i < 8; i += Lanes(d8)) {
    177    IDCT1DImpl<N>()(from + i, 8, output + i, output_stride);
    178  }
    179 }
    180 
    181 void ComputeScaledIDCT(float* JXL_RESTRICT block0, float* JXL_RESTRICT block1,
    182                       float* JXL_RESTRICT output, size_t output_stride) {
    183  Transpose8x8Block(block0, block1);
    184  IDCT1D<8>(block1, block0, 8);
    185  Transpose8x8Block(block0, block1);
    186  IDCT1D<8>(block1, output, output_stride);
    187 }
    188 
    189 void InverseTransformBlock8x8(const int16_t* JXL_RESTRICT qblock,
    190                              const float* JXL_RESTRICT dequant,
    191                              const float* JXL_RESTRICT biases,
    192                              float* JXL_RESTRICT scratch_space,
    193                              float* JXL_RESTRICT output, size_t output_stride,
    194                              size_t dctsize) {
    195  float* JXL_RESTRICT block0 = scratch_space;
    196  float* JXL_RESTRICT block1 = scratch_space + DCTSIZE2;
    197  DequantBlock(qblock, dequant, biases, block0);
    198  ComputeScaledIDCT(block0, block1, output, output_stride);
    199 }
    200 
// Computes the N-point IDCT of in[], and stores the result in out[]. The in[]
// array is at most 8 values long, values in[8:N-1] are assumed to be 0.
//
// Supported sizes are N = 3, 5, 6, 7 and 9..16. Any other N reaches the
// default branch, which only calls JXL_DEBUG_ABORT and writes nothing to
// out[].
//
// Each case has its own table with kCN[k] = sqrt(2) * cos(k * pi / (2 * N));
// in[0] always enters with weight 1. The outputs are produced in mirrored
// pairs: for each i in the first half, an "even" sum (built from in[0], in[2],
// in[4], in[6]) and an "odd" sum (built from in[1], in[3], in[5], in[7]) are
// formed, then out[i] = even_i + odd_i and out[N - 1 - i] = even_i - odd_i.
// For odd N the middle output has no odd part. Terms whose cosine factor is
// zero are simply omitted from the sums.
void Compute1dIDCT(const float* in, float* out, size_t N) {
  switch (N) {
    case 3: {
      static constexpr float kC3[3] = {
          1.414213562373,
          1.224744871392,
          0.707106781187,
      };
      float even0 = in[0] + kC3[2] * in[2];
      float even1 = in[0] - kC3[0] * in[2];
      float odd0 = kC3[1] * in[1];
      out[0] = even0 + odd0;
      out[2] = even0 - odd0;
      out[1] = even1;
      break;
    }
    case 5: {
      static constexpr float kC5[5] = {
          1.414213562373, 1.344997023928, 1.144122805635,
          0.831253875555, 0.437016024449,
      };
      float even0 = in[0] + kC5[2] * in[2] + kC5[4] * in[4];
      float even1 = in[0] - kC5[4] * in[2] - kC5[2] * in[4];
      float even2 = in[0] - kC5[0] * in[2] + kC5[0] * in[4];
      float odd0 = kC5[1] * in[1] + kC5[3] * in[3];
      float odd1 = kC5[3] * in[1] - kC5[1] * in[3];
      out[0] = even0 + odd0;
      out[4] = even0 - odd0;
      out[1] = even1 + odd1;
      out[3] = even1 - odd1;
      out[2] = even2;
      break;
    }
    case 6: {
      static constexpr float kC6[6] = {
          1.414213562373, 1.366025403784, 1.224744871392,
          1.000000000000, 0.707106781187, 0.366025403784,
      };
      float even0 = in[0] + kC6[2] * in[2] + kC6[4] * in[4];
      float even1 = in[0] - kC6[0] * in[4];
      float even2 = in[0] - kC6[2] * in[2] + kC6[4] * in[4];
      float odd0 = kC6[1] * in[1] + kC6[3] * in[3] + kC6[5] * in[5];
      float odd1 = kC6[3] * in[1] - kC6[3] * in[3] - kC6[3] * in[5];
      float odd2 = kC6[5] * in[1] - kC6[3] * in[3] + kC6[1] * in[5];
      out[0] = even0 + odd0;
      out[5] = even0 - odd0;
      out[1] = even1 + odd1;
      out[4] = even1 - odd1;
      out[2] = even2 + odd2;
      out[3] = even2 - odd2;
      break;
    }
    case 7: {
      static constexpr float kC7[7] = {
          1.414213562373, 1.378756275744, 1.274162392264, 1.105676685997,
          0.881747733790, 0.613604268353, 0.314692122713,
      };
      float even0 = in[0] + kC7[2] * in[2] + kC7[4] * in[4] + kC7[6] * in[6];
      float even1 = in[0] + kC7[6] * in[2] - kC7[2] * in[4] - kC7[4] * in[6];
      float even2 = in[0] - kC7[4] * in[2] - kC7[6] * in[4] + kC7[2] * in[6];
      float even3 = in[0] - kC7[0] * in[2] + kC7[0] * in[4] - kC7[0] * in[6];
      float odd0 = kC7[1] * in[1] + kC7[3] * in[3] + kC7[5] * in[5];
      float odd1 = kC7[3] * in[1] - kC7[5] * in[3] - kC7[1] * in[5];
      float odd2 = kC7[5] * in[1] - kC7[1] * in[3] + kC7[3] * in[5];
      out[0] = even0 + odd0;
      out[6] = even0 - odd0;
      out[1] = even1 + odd1;
      out[5] = even1 - odd1;
      out[2] = even2 + odd2;
      out[4] = even2 - odd2;
      out[3] = even3;
      break;
    }
    // From here on N > 8, so only in[0..7] are read (see the note above).
    case 9: {
      static constexpr float kC9[9] = {
          1.414213562373, 1.392728480640, 1.328926048777,
          1.224744871392, 1.083350440839, 0.909038955344,
          0.707106781187, 0.483689525296, 0.245575607938,
      };
      float even0 = in[0] + kC9[2] * in[2] + kC9[4] * in[4] + kC9[6] * in[6];
      float even1 = in[0] + kC9[6] * in[2] - kC9[6] * in[4] - kC9[0] * in[6];
      float even2 = in[0] - kC9[8] * in[2] - kC9[2] * in[4] + kC9[6] * in[6];
      float even3 = in[0] - kC9[4] * in[2] + kC9[8] * in[4] + kC9[6] * in[6];
      float even4 = in[0] - kC9[0] * in[2] + kC9[0] * in[4] - kC9[0] * in[6];
      float odd0 =
          kC9[1] * in[1] + kC9[3] * in[3] + kC9[5] * in[5] + kC9[7] * in[7];
      float odd1 = kC9[3] * in[1] - kC9[3] * in[5] - kC9[3] * in[7];
      float odd2 =
          kC9[5] * in[1] - kC9[3] * in[3] - kC9[7] * in[5] + kC9[1] * in[7];
      float odd3 =
          kC9[7] * in[1] - kC9[3] * in[3] + kC9[1] * in[5] - kC9[5] * in[7];
      out[0] = even0 + odd0;
      out[8] = even0 - odd0;
      out[1] = even1 + odd1;
      out[7] = even1 - odd1;
      out[2] = even2 + odd2;
      out[6] = even2 - odd2;
      out[3] = even3 + odd3;
      out[5] = even3 - odd3;
      out[4] = even4;
      break;
    }
    case 10: {
      static constexpr float kC10[10] = {
          1.414213562373, 1.396802246667, 1.344997023928, 1.260073510670,
          1.144122805635, 1.000000000000, 0.831253875555, 0.642039521920,
          0.437016024449, 0.221231742082,
      };
      float even0 = in[0] + kC10[2] * in[2] + kC10[4] * in[4] + kC10[6] * in[6];
      float even1 = in[0] + kC10[6] * in[2] - kC10[8] * in[4] - kC10[2] * in[6];
      float even2 = in[0] - kC10[0] * in[4];
      float even3 = in[0] - kC10[6] * in[2] - kC10[8] * in[4] + kC10[2] * in[6];
      float even4 = in[0] - kC10[2] * in[2] + kC10[4] * in[4] - kC10[6] * in[6];
      float odd0 =
          kC10[1] * in[1] + kC10[3] * in[3] + kC10[5] * in[5] + kC10[7] * in[7];
      float odd1 =
          kC10[3] * in[1] + kC10[9] * in[3] - kC10[5] * in[5] - kC10[1] * in[7];
      float odd2 =
          kC10[5] * in[1] - kC10[5] * in[3] - kC10[5] * in[5] + kC10[5] * in[7];
      float odd3 =
          kC10[7] * in[1] - kC10[1] * in[3] + kC10[5] * in[5] + kC10[9] * in[7];
      float odd4 =
          kC10[9] * in[1] - kC10[7] * in[3] + kC10[5] * in[5] - kC10[3] * in[7];
      out[0] = even0 + odd0;
      out[9] = even0 - odd0;
      out[1] = even1 + odd1;
      out[8] = even1 - odd1;
      out[2] = even2 + odd2;
      out[7] = even2 - odd2;
      out[3] = even3 + odd3;
      out[6] = even3 - odd3;
      out[4] = even4 + odd4;
      out[5] = even4 - odd4;
      break;
    }
    case 11: {
      static constexpr float kC11[11] = {
          1.414213562373, 1.399818907436, 1.356927976287, 1.286413904599,
          1.189712155524, 1.068791297809, 0.926112931411, 0.764581576418,
          0.587485545401, 0.398430002847, 0.201263574413,
      };
      float even0 = in[0] + kC11[2] * in[2] + kC11[4] * in[4] + kC11[6] * in[6];
      float even1 =
          in[0] + kC11[6] * in[2] - kC11[10] * in[4] - kC11[4] * in[6];
      float even2 =
          in[0] + kC11[10] * in[2] - kC11[2] * in[4] - kC11[8] * in[6];
      float even3 = in[0] - kC11[8] * in[2] - kC11[6] * in[4] + kC11[2] * in[6];
      float even4 =
          in[0] - kC11[4] * in[2] + kC11[8] * in[4] + kC11[10] * in[6];
      float even5 = in[0] - kC11[0] * in[2] + kC11[0] * in[4] - kC11[0] * in[6];
      float odd0 =
          kC11[1] * in[1] + kC11[3] * in[3] + kC11[5] * in[5] + kC11[7] * in[7];
      float odd1 =
          kC11[3] * in[1] + kC11[9] * in[3] - kC11[7] * in[5] - kC11[1] * in[7];
      float odd2 =
          kC11[5] * in[1] - kC11[7] * in[3] - kC11[3] * in[5] + kC11[9] * in[7];
      float odd3 =
          kC11[7] * in[1] - kC11[1] * in[3] + kC11[9] * in[5] + kC11[5] * in[7];
      float odd4 =
          kC11[9] * in[1] - kC11[5] * in[3] + kC11[1] * in[5] - kC11[3] * in[7];
      out[0] = even0 + odd0;
      out[10] = even0 - odd0;
      out[1] = even1 + odd1;
      out[9] = even1 - odd1;
      out[2] = even2 + odd2;
      out[8] = even2 - odd2;
      out[3] = even3 + odd3;
      out[7] = even3 - odd3;
      out[4] = even4 + odd4;
      out[6] = even4 - odd4;
      out[5] = even5;
      break;
    }
    case 12: {
      static constexpr float kC12[12] = {
          1.414213562373, 1.402114769300, 1.366025403784, 1.306562964876,
          1.224744871392, 1.121971053594, 1.000000000000, 0.860918669154,
          0.707106781187, 0.541196100146, 0.366025403784, 0.184591911283,
      };
      float even0 = in[0] + kC12[2] * in[2] + kC12[4] * in[4] + kC12[6] * in[6];
      float even1 = in[0] + kC12[6] * in[2] - kC12[6] * in[6];
      float even2 =
          in[0] + kC12[10] * in[2] - kC12[4] * in[4] - kC12[6] * in[6];
      float even3 =
          in[0] - kC12[10] * in[2] - kC12[4] * in[4] + kC12[6] * in[6];
      float even4 = in[0] - kC12[6] * in[2] + kC12[6] * in[6];
      float even5 = in[0] - kC12[2] * in[2] + kC12[4] * in[4] - kC12[6] * in[6];
      float odd0 =
          kC12[1] * in[1] + kC12[3] * in[3] + kC12[5] * in[5] + kC12[7] * in[7];
      float odd1 =
          kC12[3] * in[1] + kC12[9] * in[3] - kC12[9] * in[5] - kC12[3] * in[7];
      float odd2 = kC12[5] * in[1] - kC12[9] * in[3] - kC12[1] * in[5] -
                   kC12[11] * in[7];
      float odd3 = kC12[7] * in[1] - kC12[3] * in[3] - kC12[11] * in[5] +
                   kC12[1] * in[7];
      float odd4 =
          kC12[9] * in[1] - kC12[3] * in[3] + kC12[3] * in[5] - kC12[9] * in[7];
      float odd5 = kC12[11] * in[1] - kC12[9] * in[3] + kC12[7] * in[5] -
                   kC12[5] * in[7];
      out[0] = even0 + odd0;
      out[11] = even0 - odd0;
      out[1] = even1 + odd1;
      out[10] = even1 - odd1;
      out[2] = even2 + odd2;
      out[9] = even2 - odd2;
      out[3] = even3 + odd3;
      out[8] = even3 - odd3;
      out[4] = even4 + odd4;
      out[7] = even4 - odd4;
      out[5] = even5 + odd5;
      out[6] = even5 - odd5;
      break;
    }
    case 13: {
      static constexpr float kC13[13] = {
          1.414213562373, 1.403902353238, 1.373119086479, 1.322312651445,
          1.252223920364, 1.163874944761, 1.058554051646, 0.937797056801,
          0.803364869133, 0.657217812653, 0.501487040539, 0.338443458124,
          0.170464607981,
      };
      float even0 = in[0] + kC13[2] * in[2] + kC13[4] * in[4] + kC13[6] * in[6];
      float even1 =
          in[0] + kC13[6] * in[2] + kC13[12] * in[4] - kC13[8] * in[6];
      float even2 =
          in[0] + kC13[10] * in[2] - kC13[6] * in[4] - kC13[4] * in[6];
      float even3 =
          in[0] - kC13[12] * in[2] - kC13[2] * in[4] + kC13[10] * in[6];
      float even4 =
          in[0] - kC13[8] * in[2] - kC13[10] * in[4] + kC13[2] * in[6];
      float even5 =
          in[0] - kC13[4] * in[2] + kC13[8] * in[4] - kC13[12] * in[6];
      float even6 = in[0] - kC13[0] * in[2] + kC13[0] * in[4] - kC13[0] * in[6];
      float odd0 =
          kC13[1] * in[1] + kC13[3] * in[3] + kC13[5] * in[5] + kC13[7] * in[7];
      float odd1 = kC13[3] * in[1] + kC13[9] * in[3] - kC13[11] * in[5] -
                   kC13[5] * in[7];
      float odd2 = kC13[5] * in[1] - kC13[11] * in[3] - kC13[1] * in[5] -
                   kC13[9] * in[7];
      float odd3 =
          kC13[7] * in[1] - kC13[5] * in[3] - kC13[9] * in[5] + kC13[3] * in[7];
      float odd4 = kC13[9] * in[1] - kC13[1] * in[3] + kC13[7] * in[5] +
                   kC13[11] * in[7];
      float odd5 = kC13[11] * in[1] - kC13[7] * in[3] + kC13[3] * in[5] -
                   kC13[1] * in[7];
      out[0] = even0 + odd0;
      out[12] = even0 - odd0;
      out[1] = even1 + odd1;
      out[11] = even1 - odd1;
      out[2] = even2 + odd2;
      out[10] = even2 - odd2;
      out[3] = even3 + odd3;
      out[9] = even3 - odd3;
      out[4] = even4 + odd4;
      out[8] = even4 - odd4;
      out[5] = even5 + odd5;
      out[7] = even5 - odd5;
      out[6] = even6;
      break;
    }
    case 14: {
      static constexpr float kC14[14] = {
          1.414213562373, 1.405321284327, 1.378756275744, 1.334852607020,
          1.274162392264, 1.197448846138, 1.105676685997, 1.000000000000,
          0.881747733790, 0.752406978226, 0.613604268353, 0.467085128785,
          0.314692122713, 0.158341680609,
      };
      float even0 = in[0] + kC14[2] * in[2] + kC14[4] * in[4] + kC14[6] * in[6];
      float even1 =
          in[0] + kC14[6] * in[2] + kC14[12] * in[4] - kC14[10] * in[6];
      float even2 =
          in[0] + kC14[10] * in[2] - kC14[8] * in[4] - kC14[2] * in[6];
      float even3 = in[0] - kC14[0] * in[4];
      float even4 =
          in[0] - kC14[10] * in[2] - kC14[8] * in[4] + kC14[2] * in[6];
      float even5 =
          in[0] - kC14[6] * in[2] + kC14[12] * in[4] + kC14[10] * in[6];
      float even6 = in[0] - kC14[2] * in[2] + kC14[4] * in[4] - kC14[6] * in[6];
      float odd0 =
          kC14[1] * in[1] + kC14[3] * in[3] + kC14[5] * in[5] + kC14[7] * in[7];
      float odd1 = kC14[3] * in[1] + kC14[9] * in[3] - kC14[13] * in[5] -
                   kC14[7] * in[7];
      float odd2 = kC14[5] * in[1] - kC14[13] * in[3] - kC14[3] * in[5] -
                   kC14[7] * in[7];
      float odd3 =
          kC14[7] * in[1] - kC14[7] * in[3] - kC14[7] * in[5] + kC14[7] * in[7];
      float odd4 = kC14[9] * in[1] - kC14[1] * in[3] + kC14[11] * in[5] +
                   kC14[7] * in[7];
      float odd5 = kC14[11] * in[1] - kC14[5] * in[3] + kC14[1] * in[5] -
                   kC14[7] * in[7];
      float odd6 = kC14[13] * in[1] - kC14[11] * in[3] + kC14[9] * in[5] -
                   kC14[7] * in[7];
      out[0] = even0 + odd0;
      out[13] = even0 - odd0;
      out[1] = even1 + odd1;
      out[12] = even1 - odd1;
      out[2] = even2 + odd2;
      out[11] = even2 - odd2;
      out[3] = even3 + odd3;
      out[10] = even3 - odd3;
      out[4] = even4 + odd4;
      out[9] = even4 - odd4;
      out[5] = even5 + odd5;
      out[8] = even5 - odd5;
      out[6] = even6 + odd6;
      out[7] = even6 - odd6;
      break;
    }
    case 15: {
      static constexpr float kC15[15] = {
          1.414213562373, 1.406466352507, 1.383309602960, 1.344997023928,
          1.291948376043, 1.224744871392, 1.144122805635, 1.050965490998,
          0.946293578512, 0.831253875555, 0.707106781187, 0.575212476952,
          0.437016024449, 0.294031532930, 0.147825570407,
      };
      float even0 = in[0] + kC15[2] * in[2] + kC15[4] * in[4] + kC15[6] * in[6];
      float even1 =
          in[0] + kC15[6] * in[2] + kC15[12] * in[4] - kC15[12] * in[6];
      float even2 =
          in[0] + kC15[10] * in[2] - kC15[10] * in[4] - kC15[0] * in[6];
      float even3 =
          in[0] + kC15[14] * in[2] - kC15[2] * in[4] - kC15[12] * in[6];
      float even4 =
          in[0] - kC15[12] * in[2] - kC15[6] * in[4] + kC15[6] * in[6];
      float even5 =
          in[0] - kC15[8] * in[2] - kC15[14] * in[4] + kC15[6] * in[6];
      float even6 =
          in[0] - kC15[4] * in[2] + kC15[8] * in[4] - kC15[12] * in[6];
      float even7 = in[0] - kC15[0] * in[2] + kC15[0] * in[4] - kC15[0] * in[6];
      float odd0 =
          kC15[1] * in[1] + kC15[3] * in[3] + kC15[5] * in[5] + kC15[7] * in[7];
      float odd1 = kC15[3] * in[1] + kC15[9] * in[3] - kC15[9] * in[7];
      float odd2 = kC15[5] * in[1] - kC15[5] * in[5] - kC15[5] * in[7];
      float odd3 = kC15[7] * in[1] - kC15[9] * in[3] - kC15[5] * in[5] +
                   kC15[11] * in[7];
      float odd4 = kC15[9] * in[1] - kC15[3] * in[3] + kC15[3] * in[7];
      float odd5 = kC15[11] * in[1] - kC15[3] * in[3] + kC15[5] * in[5] -
                   kC15[13] * in[7];
      float odd6 = kC15[13] * in[1] - kC15[9] * in[3] + kC15[5] * in[5] -
                   kC15[1] * in[7];
      out[0] = even0 + odd0;
      out[14] = even0 - odd0;
      out[1] = even1 + odd1;
      out[13] = even1 - odd1;
      out[2] = even2 + odd2;
      out[12] = even2 - odd2;
      out[3] = even3 + odd3;
      out[11] = even3 - odd3;
      out[4] = even4 + odd4;
      out[10] = even4 - odd4;
      out[5] = even5 + odd5;
      out[9] = even5 - odd5;
      out[6] = even6 + odd6;
      out[8] = even6 - odd6;
      out[7] = even7;
      break;
    }
    case 16: {
      static constexpr float kC16[16] = {
          1.414213562373, 1.407403737526, 1.387039845322, 1.353318001174,
          1.306562964876, 1.247225012987, 1.175875602419, 1.093201867002,
          1.000000000000, 0.897167586343, 0.785694958387, 0.666655658478,
          0.541196100146, 0.410524527522, 0.275899379283, 0.138617169199,
      };
      float even0 = in[0] + kC16[2] * in[2] + kC16[4] * in[4] + kC16[6] * in[6];
      float even1 =
          in[0] + kC16[6] * in[2] + kC16[12] * in[4] - kC16[14] * in[6];
      float even2 =
          in[0] + kC16[10] * in[2] - kC16[12] * in[4] - kC16[2] * in[6];
      float even3 =
          in[0] + kC16[14] * in[2] - kC16[4] * in[4] - kC16[10] * in[6];
      float even4 =
          in[0] - kC16[14] * in[2] - kC16[4] * in[4] + kC16[10] * in[6];
      float even5 =
          in[0] - kC16[10] * in[2] - kC16[12] * in[4] + kC16[2] * in[6];
      float even6 =
          in[0] - kC16[6] * in[2] + kC16[12] * in[4] + kC16[14] * in[6];
      float even7 = in[0] - kC16[2] * in[2] + kC16[4] * in[4] - kC16[6] * in[6];
      float odd0 = (kC16[1] * in[1] + kC16[3] * in[3] + kC16[5] * in[5] +
                    kC16[7] * in[7]);
      float odd1 = (kC16[3] * in[1] + kC16[9] * in[3] + kC16[15] * in[5] -
                    kC16[11] * in[7]);
      float odd2 = (kC16[5] * in[1] + kC16[15] * in[3] - kC16[7] * in[5] -
                    kC16[3] * in[7]);
      float odd3 = (kC16[7] * in[1] - kC16[11] * in[3] - kC16[3] * in[5] +
                    kC16[15] * in[7]);
      float odd4 = (kC16[9] * in[1] - kC16[5] * in[3] - kC16[13] * in[5] +
                    kC16[1] * in[7]);
      float odd5 = (kC16[11] * in[1] - kC16[1] * in[3] + kC16[9] * in[5] +
                    kC16[13] * in[7]);
      float odd6 = (kC16[13] * in[1] - kC16[7] * in[3] + kC16[1] * in[5] -
                    kC16[5] * in[7]);
      float odd7 = (kC16[15] * in[1] - kC16[13] * in[3] + kC16[11] * in[5] -
                    kC16[9] * in[7]);
      out[0] = even0 + odd0;
      out[15] = even0 - odd0;
      out[1] = even1 + odd1;
      out[14] = even1 - odd1;
      out[2] = even2 + odd2;
      out[13] = even2 - odd2;
      out[3] = even3 + odd3;
      out[12] = even3 - odd3;
      out[4] = even4 + odd4;
      out[11] = even4 - odd4;
      out[5] = even5 + odd5;
      out[10] = even5 - odd5;
      out[6] = even6 + odd6;
      out[9] = even6 - odd6;
      out[7] = even7 + odd7;
      out[8] = even7 - odd7;
      break;
    }
    default:
      // Sizes 1, 2, 4 and 8 have no case here; InverseTransformBlockGeneric
      // handles 1, 2 and 4 before calling this function.
      JXL_DEBUG_ABORT("Unreachable");
      break;
  }
}
    619 
    620 void InverseTransformBlockGeneric(const int16_t* JXL_RESTRICT qblock,
    621                                  const float* JXL_RESTRICT dequant,
    622                                  const float* JXL_RESTRICT biases,
    623                                  float* JXL_RESTRICT scratch_space,
    624                                  float* JXL_RESTRICT output,
    625                                  size_t output_stride, size_t dctsize) {
    626  float* JXL_RESTRICT block0 = scratch_space;
    627  float* JXL_RESTRICT block1 = scratch_space + DCTSIZE2;
    628  DequantBlock(qblock, dequant, biases, block0);
    629  if (dctsize == 1) {
    630    *output = *block0;
    631  } else if (dctsize == 2 || dctsize == 4) {
    632    float* JXL_RESTRICT block2 = scratch_space + 2 * DCTSIZE2;
    633    ComputeScaledIDCT(block0, block1, block2, 8);
    634    if (dctsize == 4) {
    635      for (size_t iy = 0; iy < 4; ++iy) {
    636        for (size_t ix = 0; ix < 4; ++ix) {
    637          float* block = &block2[16 * iy + 2 * ix];
    638          output[iy * output_stride + ix] =
    639              0.25f * (block[0] + block[1] + block[8] + block[9]);
    640        }
    641      }
    642    } else {
    643      for (size_t iy = 0; iy < 2; ++iy) {
    644        for (size_t ix = 0; ix < 2; ++ix) {
    645          float* block = &block2[32 * iy + 4 * ix];
    646          output[iy * output_stride + ix] =
    647              0.0625f *
    648              (block[0] + block[1] + block[2] + block[3] + block[8] + block[9] +
    649               block[10] + block[11] + block[16] + block[17] + block[18] +
    650               block[19] + block[24] + block[25] + block[26] + block[27]);
    651        }
    652      }
    653    }
    654  } else {
    655    float dctin[DCTSIZE];
    656    float dctout[DCTSIZE * 2];
    657    size_t insize = std::min<size_t>(dctsize, DCTSIZE);
    658    for (size_t ix = 0; ix < insize; ++ix) {
    659      for (size_t iy = 0; iy < insize; ++iy) {
    660        dctin[iy] = block0[iy * DCTSIZE + ix];
    661      }
    662      Compute1dIDCT(dctin, dctout, dctsize);
    663      for (size_t iy = 0; iy < dctsize; ++iy) {
    664        block1[iy * dctsize + ix] = dctout[iy];
    665      }
    666    }
    667    for (size_t iy = 0; iy < dctsize; ++iy) {
    668      Compute1dIDCT(block1 + iy * dctsize, output + iy * output_stride,
    669                    dctsize);
    670    }
    671  }
    672 }
    673 
    674 // NOLINTNEXTLINE(google-readability-namespace-comments)
    675 }  // namespace HWY_NAMESPACE
    676 }  // namespace jpegli
    677 HWY_AFTER_NAMESPACE();
    678 
    679 #if HWY_ONCE
    680 namespace jpegli {
    681 
// Per-target function tables for the two entry points; they are looked up via
// HWY_DYNAMIC_DISPATCH in ChooseInverseTransform.
HWY_EXPORT(InverseTransformBlock8x8);
HWY_EXPORT(InverseTransformBlockGeneric);
    684 
    685 jxl::Status ChooseInverseTransform(j_decompress_ptr cinfo) {
    686  jpeg_decomp_master* m = cinfo->master;
    687  for (int c = 0; c < cinfo->num_components; ++c) {
    688    int dct_size = m->scaled_dct_size[c];
    689    if (dct_size < 1 || dct_size > 16) {
    690      return JXL_FAILURE("Compute1dIDCT does not support N=%d", dct_size);
    691    }
    692    if (dct_size == DCTSIZE) {
    693      m->inverse_transform[c] = HWY_DYNAMIC_DISPATCH(InverseTransformBlock8x8);
    694    } else {
    695      m->inverse_transform[c] =
    696          HWY_DYNAMIC_DISPATCH(InverseTransformBlockGeneric);
    697    }
    698  }
    699  return true;
    700 }
    701 
    702 }  // namespace jpegli
    703 #endif  // HWY_ONCE