tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dec_group.cc (33239B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_group.h"
      7 
      8 #include <algorithm>
      9 #include <cstdint>
     10 #include <cstring>
     11 #include <memory>
     12 #include <utility>
     13 
     14 #include "lib/jxl/chroma_from_luma.h"
     15 #include "lib/jxl/frame_header.h"
     16 
     17 #undef HWY_TARGET_INCLUDE
     18 #define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
     19 #include <hwy/foreach_target.h>
     20 #include <hwy/highway.h>
     21 
     22 #include "lib/jxl/ac_context.h"
     23 #include "lib/jxl/ac_strategy.h"
     24 #include "lib/jxl/base/bits.h"
     25 #include "lib/jxl/base/common.h"
     26 #include "lib/jxl/base/printf_macros.h"
     27 #include "lib/jxl/base/rect.h"
     28 #include "lib/jxl/base/status.h"
     29 #include "lib/jxl/coeff_order.h"
     30 #include "lib/jxl/common.h"  // kMaxNumPasses
     31 #include "lib/jxl/dec_cache.h"
     32 #include "lib/jxl/dec_transforms-inl.h"
     33 #include "lib/jxl/dec_xyb.h"
     34 #include "lib/jxl/entropy_coder.h"
     35 #include "lib/jxl/quant_weights.h"
     36 #include "lib/jxl/quantizer-inl.h"
     37 #include "lib/jxl/quantizer.h"
     38 
     39 #ifndef LIB_JXL_DEC_GROUP_CC
     40 #define LIB_JXL_DEC_GROUP_CC
     41 namespace jxl {
     42 
     43 struct AuxOut;
     44 
     45 // Interface for reading groups for DecodeGroupImpl.
     46 class GetBlock {
     47 public:
     48  virtual void StartRow(size_t by) = 0;
     49  virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
     50                           size_t size, size_t log2_covered_blocks,
     51                           ACPtr block[3], ACType ac_type) = 0;
     52  virtual ~GetBlock() {}
     53 };
     54 
     55 // Controls whether DecodeGroupImpl renders to pixels or not.
     56 enum DrawMode {
     57  // Render to pixels.
     58  kDraw = 0,
     59  // Don't render to pixels.
     60  kDontDraw = 1,
     61 };
     62 
     63 }  // namespace jxl
     64 #endif  // LIB_JXL_DEC_GROUP_CC
     65 
     66 HWY_BEFORE_NAMESPACE();
     67 namespace jxl {
     68 namespace HWY_NAMESPACE {
     69 
     70 // These templates are not found via ADL.
     71 using hwy::HWY_NAMESPACE::AllFalse;
     72 using hwy::HWY_NAMESPACE::Gt;
     73 using hwy::HWY_NAMESPACE::Le;
     74 using hwy::HWY_NAMESPACE::MaskFromVec;
     75 using hwy::HWY_NAMESPACE::Or;
     76 using hwy::HWY_NAMESPACE::Rebind;
     77 using hwy::HWY_NAMESPACE::ShiftRight;
     78 
     79 using D = HWY_FULL(float);
     80 using DU = HWY_FULL(uint32_t);
     81 using DI = HWY_FULL(int32_t);
     82 using DI16 = Rebind<int16_t, DI>;
     83 using DI16_FULL = HWY_CAPPED(int16_t, kDCTBlockSize);
     84 constexpr D d;
     85 constexpr DI di;
     86 constexpr DI16 di16;
     87 constexpr DI16_FULL di16_full;
     88 
     89 // TODO(veluca): consider SIMDfying.
     90 void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
     91  for (size_t x = 0; x < 8; x++) {
     92    for (size_t y = x + 1; y < 8; y++) {
     93      std::swap(block[y * 8 + x], block[x * 8 + y]);
     94    }
     95  }
     96 }
     97 
     98 template <ACType ac_type>
     99 void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
    100                 Vec<D> scaled_dequant_b,
    101                 const float* JXL_RESTRICT dequant_matrices, size_t size,
    102                 size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
    103                 const float* JXL_RESTRICT biases, ACPtr qblock[3],
    104                 float* JXL_RESTRICT block) {
    105  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
    106  const auto y_mul =
    107      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
    108  const auto b_mul =
    109      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
    110 
    111  Vec<DI> quantized_x_int;
    112  Vec<DI> quantized_y_int;
    113  Vec<DI> quantized_b_int;
    114  if (ac_type == ACType::k16) {
    115    Rebind<int16_t, DI> di16;
    116    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
    117    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
    118    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
    119  } else {
    120    quantized_x_int = Load(di, qblock[0].ptr32 + k);
    121    quantized_y_int = Load(di, qblock[1].ptr32 + k);
    122    quantized_b_int = Load(di, qblock[2].ptr32 + k);
    123  }
    124 
    125  const auto dequant_x_cc =
    126      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
    127  const auto dequant_y =
    128      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
    129  const auto dequant_b_cc =
    130      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
    131 
    132  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
    133  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
    134  Store(dequant_x, d, block + k);
    135  Store(dequant_y, d, block + size + k);
    136  Store(dequant_b, d, block + 2 * size + k);
    137 }
    138 
    139 template <ACType ac_type>
    140 void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant,
    141                  float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul,
    142                  Vec<D> b_cc_mul, AcStrategyType kind, size_t size,
    143                  const Quantizer& quantizer, size_t covered_blocks,
    144                  const size_t* sbx,
    145                  const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
    146                  size_t dc_stride, const float* JXL_RESTRICT biases,
    147                  ACPtr qblock[3], float* JXL_RESTRICT block,
    148                  float* JXL_RESTRICT scratch) {
    149  const auto scaled_dequant_s = inv_global_scale / quant;
    150 
    151  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
    152  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
    153  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
    154 
    155  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
    156 
    157  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
    158    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
    159                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
    160                         qblock, block);
    161  }
    162  for (size_t c = 0; c < 3; c++) {
    163    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
    164                            block + c * size, scratch);
    165  }
    166 }
    167 
    168 Status DecodeGroupImpl(const FrameHeader& frame_header,
    169                       GetBlock* JXL_RESTRICT get_block,
    170                       GroupDecCache* JXL_RESTRICT group_dec_cache,
    171                       PassesDecoderState* JXL_RESTRICT dec_state,
    172                       size_t thread, size_t group_idx,
    173                       RenderPipelineInput& render_pipeline_input,
    174                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
    175  // TODO(veluca): investigate cache usage in this function.
    176  const Rect block_rect =
    177      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
    178  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
    179 
    180  const size_t xsize_blocks = block_rect.xsize();
    181  const size_t ysize_blocks = block_rect.ysize();
    182 
    183  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
    184 
    185  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
    186 
    187  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
    188 
    189  const auto kJpegDctMin = Set(di16_full, -4095);
    190  const auto kJpegDctMax = Set(di16_full, 4095);
    191 
    192  size_t idct_stride[3];
    193  for (size_t c = 0; c < 3; c++) {
    194    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
    195  }
    196 
    197  HWY_ALIGN int32_t scaled_qtable[64 * 3];
    198 
    199  ACType ac_type = dec_state->coefficients->Type();
    200  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
    201                                              : DequantBlock<ACType::k32>;
    202  // Whether or not coefficients should be stored for future usage, and/or read
    203  // from past usage.
    204  bool accumulate = !dec_state->coefficients->IsEmpty();
    205  // Offset of the current block in the group.
    206  size_t offset = 0;
    207 
    208  std::array<int, 3> jpeg_c_map;
    209  bool jpeg_is_gray = false;
    210  std::array<int, 3> dcoff = {};
    211 
    212  // TODO(veluca): all of this should be done only once per image.
    213  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
    214  if (jpeg_data) {
    215    if (!color_correlation.IsJPEGCompatible()) {
    216      return JXL_FAILURE("The CfL map is not JPEG-compatible");
    217    }
    218    jpeg_is_gray = (jpeg_data->components.size() == 1);
    219    JXL_ENSURE(frame_header.color_transform != ColorTransform::kXYB);
    220    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
    221    const std::vector<QuantEncoding>& qe =
    222        dec_state->shared->matrices.encodings();
    223    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
    224        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
    225      return JXL_FAILURE(
    226          "Quantization table is not a JPEG quantization table.");
    227    }
    228    JXL_ENSURE(qe[0].qraw.qtable->size() == 3 * 8 * 8);
    229    int* qtable = qe[0].qraw.qtable->data();
    230    for (size_t c = 0; c < 3; c++) {
    231      if (frame_header.color_transform == ColorTransform::kNone) {
    232        dcoff[c] = 1024 / qtable[64 * c];
    233      }
    234      for (size_t i = 0; i < 64; i++) {
    235        // Transpose the matrix, as it will be used on the transposed block.
    236        int n = qtable[64 + i];
    237        int d = qtable[64 * c + i];
    238        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
    239          return JXL_FAILURE("Invalid JPEG quantization table");
    240        }
    241        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
    242            (1 << kCFLFixedPointPrecision) * n / d;
    243      }
    244    }
    245  }
    246 
    247  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
    248  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
    249  Rect r[3];
    250  for (size_t i = 0; i < 3; i++) {
    251    r[i] =
    252        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
    253             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
    254    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
    255                        dec_state->shared->dc->Plane(i).ysize()})) {
    256      return JXL_FAILURE("Frame dimensions are too big for the image.");
    257    }
    258  }
    259 
    260  for (size_t by = 0; by < ysize_blocks; ++by) {
    261    get_block->StartRow(by);
    262    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
    263 
    264    const int32_t* JXL_RESTRICT row_quant =
    265        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
    266 
    267    const float* JXL_RESTRICT dc_rows[3] = {
    268        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
    269        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
    270        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
    271    };
    272 
    273    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
    274    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
    275 
    276    const int8_t* JXL_RESTRICT row_cmap[3] = {
    277        dec_state->shared->cmap.ytox_map.ConstRow(ty),
    278        nullptr,
    279        dec_state->shared->cmap.ytob_map.ConstRow(ty),
    280    };
    281 
    282    float* JXL_RESTRICT idct_row[3];
    283    int16_t* JXL_RESTRICT jpeg_row[3];
    284    for (size_t c = 0; c < 3; c++) {
    285      const auto& buffer = render_pipeline_input.GetBuffer(c);
    286      idct_row[c] = buffer.second.Row(buffer.first, sby[c] * kBlockDim);
    287      if (jpeg_data) {
    288        auto& component = jpeg_data->components[jpeg_c_map[c]];
    289        jpeg_row[c] =
    290            component.coeffs.data() +
    291            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
    292                kDCTBlockSize;
    293      }
    294    }
    295 
    296    size_t bx = 0;
    297    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
    298         tx++) {
    299      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
    300      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
    301      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
    302      // Increment bx by llf_x because those iterations would otherwise
    303      // immediately continue (!IsFirstBlock). Reduces mispredictions.
    304      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
    305        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
    306        AcStrategy acs = acs_row[bx];
    307        const size_t llf_x = acs.covered_blocks_x();
    308 
    309        // Can only happen in the second or lower rows of a varblock.
    310        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
    311          bx += llf_x;
    312          continue;
    313        }
    314        const size_t log2_covered_blocks = acs.log2_covered_blocks();
    315 
    316        const size_t covered_blocks = 1 << log2_covered_blocks;
    317        const size_t size = covered_blocks * kDCTBlockSize;
    318 
    319        ACPtr qblock[3];
    320        if (accumulate) {
    321          for (size_t c = 0; c < 3; c++) {
    322            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
    323          }
    324        } else {
    325          // No point in reading from bitstream without accumulating and not
    326          // drawing.
    327          JXL_ENSURE(draw == kDraw);
    328          if (ac_type == ACType::k16) {
    329            memset(group_dec_cache->dec_group_qblock16, 0,
    330                   size * 3 * sizeof(int16_t));
    331            for (size_t c = 0; c < 3; c++) {
    332              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
    333            }
    334          } else {
    335            memset(group_dec_cache->dec_group_qblock, 0,
    336                   size * 3 * sizeof(int32_t));
    337            for (size_t c = 0; c < 3; c++) {
    338              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
    339            }
    340          }
    341        }
    342        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
    343            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
    344        offset += size;
    345        if (draw == kDontDraw) {
    346          bx += llf_x;
    347          continue;
    348        }
    349 
    350        if (JXL_UNLIKELY(jpeg_data)) {
    351          if (acs.Strategy() != AcStrategyType::DCT) {
    352            return JXL_FAILURE(
    353                "Can only decode to JPEG if only DCT-8 is used.");
    354          }
    355 
    356          HWY_ALIGN int32_t transposed_dct_y[64];
    357          for (size_t c : {1, 0, 2}) {
    358            // Propagate only Y for grayscale.
    359            if (jpeg_is_gray && c != 1) {
    360              continue;
    361            }
    362            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
    363              continue;
    364            }
    365            int16_t* JXL_RESTRICT jpeg_pos =
    366                jpeg_row[c] + sbx[c] * kDCTBlockSize;
    367            // JPEG XL is transposed, JPEG is not.
    368            auto* transposed_dct = qblock[c].ptr32;
    369            Transpose8x8InPlace(transposed_dct);
    370            // No CfL - no need to store the y block converted to integers.
    371            if (!cs.Is444() ||
    372                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
    373              for (size_t i = 0; i < 64; i += Lanes(d)) {
    374                const auto ini = Load(di, transposed_dct + i);
    375                const auto ini16 = DemoteTo(di16, ini);
    376                StoreU(ini16, di16, jpeg_pos + i);
    377              }
    378            } else if (c == 1) {
    379              // Y channel: save for restoring X/B, but nothing else to do.
    380              for (size_t i = 0; i < 64; i += Lanes(d)) {
    381                const auto ini = Load(di, transposed_dct + i);
    382                Store(ini, di, transposed_dct_y + i);
    383                const auto ini16 = DemoteTo(di16, ini);
    384                StoreU(ini16, di16, jpeg_pos + i);
    385              }
    386            } else {
    387              // transposed_dct_y contains the y channel block, transposed.
    388              const auto scale =
    389                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
    390              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
    391              for (int i = 0; i < 64; i += Lanes(d)) {
    392                auto in = Load(di, transposed_dct + i);
    393                auto in_y = Load(di, transposed_dct_y + i);
    394                auto qt = Load(di, scaled_qtable + c * size + i);
    395                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
    396                    Add(Mul(qt, scale), round));
    397                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
    398                    Add(Mul(in_y, coeff_scale), round));
    399                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
    400              }
    401            }
    402            jpeg_pos[0] =
    403                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
    404            auto overflow = MaskFromVec(Set(di16_full, 0));
    405            auto underflow = MaskFromVec(Set(di16_full, 0));
    406            for (int i = 0; i < 64; i += Lanes(di16_full)) {
    407              auto in = LoadU(di16_full, jpeg_pos + i);
    408              overflow = Or(overflow, Gt(in, kJpegDctMax));
    409              underflow = Or(underflow, Lt(in, kJpegDctMin));
    410            }
    411            if (!AllFalse(di16_full, Or(overflow, underflow))) {
    412              return JXL_FAILURE("JPEG DCT coefficients out of range");
    413            }
    414          }
    415        } else {
    416          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
    417          // Dequantize and add predictions.
    418          dequant_block(
    419              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
    420              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.Strategy(),
    421              size, dec_state->shared->quantizer,
    422              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
    423              dc_stride,
    424              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
    425              block, group_dec_cache->scratch_space);
    426 
    427          for (size_t c : {1, 0, 2}) {
    428            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
    429              continue;
    430            }
    431            // IDCT
    432            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
    433            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
    434                              idct_stride[c], group_dec_cache->scratch_space);
    435          }
    436        }
    437        bx += llf_x;
    438      }
    439    }
    440  }
    441  return true;
    442 }
    443 
    444 // NOLINTNEXTLINE(google-readability-namespace-comments)
    445 }  // namespace HWY_NAMESPACE
    446 }  // namespace jxl
    447 HWY_AFTER_NAMESPACE();
    448 
    449 #if HWY_ONCE
    450 namespace jxl {
    451 namespace {
    452 // Decode quantized AC coefficients of DCT blocks.
    453 // LLF components in the output block will not be modified.
    454 template <ACType ac_type, bool uses_lz77>
    455 Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
    456                        int32_t* JXL_RESTRICT row_nzeros,
    457                        const int32_t* JXL_RESTRICT row_nzeros_top,
    458                        size_t nzeros_stride, size_t c, size_t bx, size_t by,
    459                        size_t lbx, AcStrategy acs,
    460                        const coeff_order_t* JXL_RESTRICT coeff_order,
    461                        BitReader* JXL_RESTRICT br,
    462                        ANSSymbolReader* JXL_RESTRICT decoder,
    463                        const std::vector<uint8_t>& context_map,
    464                        const uint8_t* qdc_row, const int32_t* qf_row,
    465                        const BlockCtxMap& block_ctx_map, ACPtr block,
    466                        size_t shift = 0) {
    467  // Equal to number of LLF coefficients.
    468  const size_t covered_blocks = 1 << log2_covered_blocks;
    469  const size_t size = covered_blocks * kDCTBlockSize;
    470  int32_t predicted_nzeros =
    471      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
    472 
    473  size_t ord = kStrategyOrder[acs.RawStrategy()];
    474  const coeff_order_t* JXL_RESTRICT order =
    475      &coeff_order[CoeffOrderOffset(ord, c)];
    476 
    477  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
    478  const int32_t nzero_ctx =
    479      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
    480 
    481  size_t nzeros =
    482      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
    483  if (nzeros > size - covered_blocks) {
    484    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
    485                       " 8x8 blocks",
    486                       nzeros, covered_blocks);
    487  }
    488  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
    489    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
    490      row_nzeros[bx + x + y * nzeros_stride] =
    491          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
    492    }
    493  }
    494 
    495  const size_t histo_offset =
    496      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
    497 
    498  size_t prev = (nzeros > size / 16 ? 0 : 1);
    499  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
    500    const size_t ctx =
    501        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
    502                                          log2_covered_blocks, prev);
    503    const size_t u_coeff =
    504        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
    505    // Hand-rolled version of UnpackSigned, shifting before the conversion to
    506    // signed integer to avoid undefined behavior of shifting negative numbers.
    507    const size_t magnitude = u_coeff >> 1;
    508    const size_t neg_sign = (~u_coeff) & 1;
    509    const intptr_t coeff =
    510        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
    511    if (ac_type == ACType::k16) {
    512      block.ptr16[order[k]] += coeff;
    513    } else {
    514      block.ptr32[order[k]] += coeff;
    515    }
    516    prev = static_cast<size_t>(u_coeff != 0);
    517    nzeros -= prev;
    518  }
    519  if (JXL_UNLIKELY(nzeros != 0)) {
    520    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
    521                       ", should be 0. Block (%" PRIuS ", %" PRIuS
    522                       "), channel %" PRIuS,
    523                       nzeros, bx, by, c);
    524  }
    525 
    526  return true;
    527 }
    528 
    529 // Structs used by DecodeGroupImpl to get a quantized block.
    530 // GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
    531 // pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
    532 // image provided by the encoder.
    533 
    534 struct GetBlockFromBitstream : public GetBlock {
    535  void StartRow(size_t by) override {
    536    qf_row = rect.ConstRow(*qf, by);
    537    for (size_t c = 0; c < 3; c++) {
    538      size_t sby = by >> vshift[c];
    539      quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
    540      for (size_t i = 0; i < num_passes; i++) {
    541        row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
    542        row_nzeros_top[i][c] =
    543            sby == 0
    544                ? nullptr
    545                : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
    546      }
    547    }
    548  }
    549 
    550  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
    551                   size_t log2_covered_blocks, ACPtr block[3],
    552                   ACType ac_type) override {
    553    ;
    554    for (size_t c : {1, 0, 2}) {
    555      size_t sbx = bx >> hshift[c];
    556      size_t sby = by >> vshift[c];
    557      if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
    558        continue;
    559      }
    560 
    561      for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
    562        auto decode_ac_varblock =
    563            decoders[pass].UsesLZ77()
    564                ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1>
    565                                          : DecodeACVarBlock<ACType::k32, 1>)
    566                : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0>
    567                                          : DecodeACVarBlock<ACType::k32, 0>);
    568        JXL_RETURN_IF_ERROR(decode_ac_varblock(
    569            ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
    570            row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
    571            &coeff_orders[pass * coeff_order_size], readers[pass],
    572            &decoders[pass], context_map[pass], quant_dc_row, qf_row,
    573            *block_ctx_map, block[c], shift_for_pass[pass]));
    574      }
    575    }
    576    return true;
    577  }
    578 
    579  Status Init(const FrameHeader& frame_header,
    580              BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes,
    581              size_t group_idx, size_t histo_selector_bits, const Rect& rect,
    582              GroupDecCache* JXL_RESTRICT group_dec_cache,
    583              PassesDecoderState* dec_state, size_t first_pass) {
    584    for (size_t i = 0; i < 3; i++) {
    585      hshift[i] = frame_header.chroma_subsampling.HShift(i);
    586      vshift[i] = frame_header.chroma_subsampling.VShift(i);
    587    }
    588    this->coeff_order_size = dec_state->shared->coeff_order_size;
    589    this->coeff_orders =
    590        dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
    591    this->context_map = dec_state->context_map.data() + first_pass;
    592    this->readers = readers;
    593    this->num_passes = num_passes;
    594    this->shift_for_pass = frame_header.passes.shift + first_pass;
    595    this->group_dec_cache = group_dec_cache;
    596    this->rect = rect;
    597    block_ctx_map = &dec_state->shared->block_ctx_map;
    598    qf = &dec_state->shared->raw_quant_field;
    599    quant_dc = &dec_state->shared->quant_dc;
    600 
    601    for (size_t pass = 0; pass < num_passes; pass++) {
    602      // Select which histogram set to use among those of the current pass.
    603      size_t cur_histogram = 0;
    604      if (histo_selector_bits != 0) {
    605        cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
    606      }
    607      if (cur_histogram >= dec_state->shared->num_histograms) {
    608        return JXL_FAILURE("Invalid histogram selector");
    609      }
    610      ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();
    611 
    612      JXL_ASSIGN_OR_RETURN(
    613          decoders[pass],
    614          ANSSymbolReader::Create(&dec_state->code[pass + first_pass],
    615                                  readers[pass]));
    616    }
    617    nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
    618    for (size_t i = 0; i < num_passes; i++) {
    619      JXL_ENSURE(
    620          nzeros_stride ==
    621          static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
    622    }
    623    return true;
    624  }
    625 
    626  const uint32_t* shift_for_pass = nullptr;  // not owned
    627  const coeff_order_t* JXL_RESTRICT coeff_orders;
    628  size_t coeff_order_size;
    629  const std::vector<uint8_t>* JXL_RESTRICT context_map;
    630  ANSSymbolReader decoders[kMaxNumPasses];
    631  BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
    632  size_t num_passes;
    633  size_t ctx_offset[kMaxNumPasses];
    634  size_t nzeros_stride;
    635  int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
    636  const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
    637  GroupDecCache* JXL_RESTRICT group_dec_cache;
    638  const BlockCtxMap* block_ctx_map;
    639  const ImageI* qf;
    640  const ImageB* quant_dc;
    641  const int32_t* qf_row;
    642  const uint8_t* quant_dc_row;
    643  Rect rect;
    644  size_t hshift[3], vshift[3];
    645 };
    646 
    647 struct GetBlockFromEncoder : public GetBlock {
    648  void StartRow(size_t by) override {}
    649 
    650  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
    651                   size_t log2_covered_blocks, ACPtr block[3],
    652                   ACType ac_type) override {
    653    JXL_ENSURE(ac_type == ACType::k32);
    654    for (size_t c = 0; c < 3; c++) {
    655      // for each pass
    656      for (size_t i = 0; i < quantized_ac->size(); i++) {
    657        for (size_t k = 0; k < size; k++) {
    658          // TODO(veluca): SIMD.
    659          block[c].ptr32[k] +=
    660              rows[i][c][offset + k] * (1 << shift_for_pass[i]);
    661        }
    662      }
    663    }
    664    offset += size;
    665    return true;
    666  }
    667 
    668  static StatusOr<GetBlockFromEncoder> Create(
    669      const std::vector<std::unique_ptr<ACImage>>& ac, size_t group_idx,
    670      const uint32_t* shift_for_pass) {
    671    GetBlockFromEncoder result(ac, group_idx, shift_for_pass);
    672    // TODO(veluca): not supported with chroma subsampling.
    673    for (size_t i = 0; i < ac.size(); i++) {
    674      JXL_ENSURE(ac[i]->Type() == ACType::k32);
    675      for (size_t c = 0; c < 3; c++) {
    676        result.rows[i][c] = ac[i]->PlaneRow(c, group_idx, 0).ptr32;
    677      }
    678    }
    679    return result;
    680  }
    681 
    682  const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
    683  size_t offset = 0;
    684  const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
    685  const uint32_t* shift_for_pass = nullptr;  // not owned
    686 
    687 private:
    688  GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
    689                      size_t group_idx, const uint32_t* shift_for_pass)
    690      : quantized_ac(&ac), shift_for_pass(shift_for_pass) {}
    691 };
    692 
// Exports the per-target DecodeGroupImpl variants so that the
// HWY_DYNAMIC_DISPATCH(DecodeGroupImpl) calls below can pick one at runtime.
HWY_EXPORT(DecodeGroupImpl);
    694 
    695 }  // namespace
    696 
    697 Status DecodeGroup(const FrameHeader& frame_header,
    698                   BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
    699                   size_t num_passes, size_t group_idx,
    700                   PassesDecoderState* JXL_RESTRICT dec_state,
    701                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
    702                   RenderPipelineInput& render_pipeline_input,
    703                   jpeg::JPEGData* JXL_RESTRICT jpeg_data, size_t first_pass,
    704                   bool force_draw, bool dc_only, bool* should_run_pipeline) {
    705  JxlMemoryManager* memory_manager = dec_state->memory_manager();
    706  DrawMode draw =
    707      (num_passes + first_pass == frame_header.passes.num_passes) || force_draw
    708          ? kDraw
    709          : kDontDraw;
    710 
    711  if (should_run_pipeline) {
    712    *should_run_pipeline = draw != kDontDraw;
    713  }
    714 
    715  if (draw == kDraw && num_passes == 0 && first_pass == 0) {
    716    JXL_RETURN_IF_ERROR(group_dec_cache->InitDCBufferOnce(memory_manager));
    717    const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
    718    for (size_t c : {0, 1, 2}) {
    719      size_t hs = cs.HShift(c);
    720      size_t vs = cs.VShift(c);
    721      // We reuse filter_input_storage here as it is not currently in use.
    722      const Rect src_rect_precs =
    723          dec_state->shared->frame_dim.BlockGroupRect(group_idx);
    724      const Rect src_rect =
    725          Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs,
    726               src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs);
    727      const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(),
    728                           src_rect.ysize());
    729      JXL_RETURN_IF_ERROR(
    730          CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2,
    731                                 copy_rect, &group_dec_cache->dc_buffer));
    732      // Mirrorpad. Interleaving left and right padding ensures that padding
    733      // works out correctly even for images with DC size of 1.
    734      for (size_t y = 0; y < src_rect.ysize() + 4; y++) {
    735        size_t xend = kRenderPipelineXOffset +
    736                      (dec_state->shared->dc->Plane(c).xsize() >> hs) -
    737                      src_rect.x0();
    738        for (size_t ix = 0; ix < 2; ix++) {
    739          if (src_rect.x0() == 0) {
    740            group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] =
    741                group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix];
    742          }
    743          if (src_rect.x0() + src_rect.xsize() + 2 >=
    744              (dec_state->shared->dc->xsize() >> hs)) {
    745            group_dec_cache->dc_buffer.Row(y)[xend + ix] =
    746                group_dec_cache->dc_buffer.Row(y)[xend - ix - 1];
    747          }
    748        }
    749      }
    750      const auto& buffer = render_pipeline_input.GetBuffer(c);
    751      Rect dst_rect = buffer.second;
    752      ImageF* upsampling_dst = buffer.first;
    753      JXL_ENSURE(dst_rect.IsInside(*upsampling_dst));
    754 
    755      RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5));
    756      RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8));
    757      for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize();
    758           y++) {
    759        for (ssize_t iy = 0; iy < 5; iy++) {
    760          input_rows[0][iy] = group_dec_cache->dc_buffer.Row(
    761              Mirror(static_cast<ssize_t>(y) + iy - 2,
    762                     dec_state->shared->dc->Plane(c).ysize() >> vs) +
    763              2 - src_rect.y0());
    764        }
    765        for (size_t iy = 0; iy < 8; iy++) {
    766          output_rows[0][iy] =
    767              dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) -
    768              kRenderPipelineXOffset;
    769        }
    770        // Arguments set to 0/nullptr are not used.
    771        JXL_RETURN_IF_ERROR(dec_state->upsampler8x->ProcessRow(
    772            input_rows, output_rows,
    773            /*xextra=*/0, src_rect.xsize(), 0, 0, thread));
    774      }
    775    }
    776    return true;
    777  }
    778 
    779  size_t histo_selector_bits = 0;
    780  if (dc_only) {
    781    JXL_ENSURE(num_passes == 0);
    782  } else {
    783    JXL_ENSURE(dec_state->shared->num_histograms > 0);
    784    histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
    785  }
    786 
    787  auto get_block = jxl::make_unique<GetBlockFromBitstream>();
    788  JXL_RETURN_IF_ERROR(get_block->Init(
    789      frame_header, readers, num_passes, group_idx, histo_selector_bits,
    790      dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache,
    791      dec_state, first_pass));
    792 
    793  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
    794      frame_header, get_block.get(), group_dec_cache, dec_state, thread,
    795      group_idx, render_pipeline_input, jpeg_data, draw));
    796 
    797  for (size_t pass = 0; pass < num_passes; pass++) {
    798    if (!get_block->decoders[pass].CheckANSFinalState()) {
    799      return JXL_FAILURE("ANS checksum failure.");
    800    }
    801  }
    802  return true;
    803 }
    804 
    805 Status DecodeGroupForRoundtrip(const FrameHeader& frame_header,
    806                               const std::vector<std::unique_ptr<ACImage>>& ac,
    807                               size_t group_idx,
    808                               PassesDecoderState* JXL_RESTRICT dec_state,
    809                               GroupDecCache* JXL_RESTRICT group_dec_cache,
    810                               size_t thread,
    811                               RenderPipelineInput& render_pipeline_input,
    812                               jpeg::JPEGData* JXL_RESTRICT jpeg_data,
    813                               AuxOut* aux_out) {
    814  JxlMemoryManager* memory_manager = dec_state->memory_manager();
    815  JXL_ASSIGN_OR_RETURN(
    816      GetBlockFromEncoder get_block,
    817      GetBlockFromEncoder::Create(ac, group_idx, frame_header.passes.shift));
    818  JXL_RETURN_IF_ERROR(group_dec_cache->InitOnce(
    819      memory_manager,
    820      /*num_passes=*/0,
    821      /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1));
    822 
    823  return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
    824      frame_header, &get_block, group_dec_cache, dec_state, thread, group_idx,
    825      render_pipeline_input, jpeg_data, kDraw);
    826 }
    827 
    828 }  // namespace jxl
    829 #endif  // HWY_ONCE