tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dec_modular.cc (33601B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_modular.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <cstdint>
     11 #include <vector>
     12 
     13 #include "lib/jxl/frame_header.h"
     14 
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jxl/base/compiler_specific.h"
     21 #include "lib/jxl/base/printf_macros.h"
     22 #include "lib/jxl/base/rect.h"
     23 #include "lib/jxl/base/status.h"
     24 #include "lib/jxl/compressed_dc.h"
     25 #include "lib/jxl/epf.h"
     26 #include "lib/jxl/modular/encoding/encoding.h"
     27 #include "lib/jxl/modular/modular_image.h"
     28 #include "lib/jxl/modular/transform/transform.h"
     29 
     30 HWY_BEFORE_NAMESPACE();
     31 namespace jxl {
     32 namespace HWY_NAMESPACE {
     33 
     34 // These templates are not found via ADL.
     35 using hwy::HWY_NAMESPACE::Add;
     36 using hwy::HWY_NAMESPACE::Mul;
     37 using hwy::HWY_NAMESPACE::Rebind;
     38 
     39 void MultiplySum(const size_t xsize,
     40                 const pixel_type* const JXL_RESTRICT row_in,
     41                 const pixel_type* const JXL_RESTRICT row_in_Y,
     42                 const float factor, float* const JXL_RESTRICT row_out) {
     43  const HWY_FULL(float) df;
     44  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
     45  const auto factor_v = Set(df, factor);
     46  for (size_t x = 0; x < xsize; x += Lanes(di)) {
     47    const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x));
     48    const auto out = Mul(ConvertTo(df, in), factor_v);
     49    Store(out, df, row_out + x);
     50  }
     51 }
     52 
     53 void RgbFromSingle(const size_t xsize,
     54                   const pixel_type* const JXL_RESTRICT row_in,
     55                   const float factor, float* out_r, float* out_g,
     56                   float* out_b) {
     57  const HWY_FULL(float) df;
     58  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
     59 
     60  const auto factor_v = Set(df, factor);
     61  for (size_t x = 0; x < xsize; x += Lanes(di)) {
     62    const auto in = Load(di, row_in + x);
     63    const auto out = Mul(ConvertTo(df, in), factor_v);
     64    Store(out, df, out_r + x);
     65    Store(out, df, out_g + x);
     66    Store(out, df, out_b + x);
     67  }
     68 }
     69 
     70 void SingleFromSingle(const size_t xsize,
     71                      const pixel_type* const JXL_RESTRICT row_in,
     72                      const float factor, float* row_out) {
     73  const HWY_FULL(float) df;
     74  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
     75 
     76  const auto factor_v = Set(df, factor);
     77  for (size_t x = 0; x < xsize; x += Lanes(di)) {
     78    const auto in = Load(di, row_in + x);
     79    const auto out = Mul(ConvertTo(df, in), factor_v);
     80    Store(out, df, row_out + x);
     81  }
     82 }
     83 // NOLINTNEXTLINE(google-readability-namespace-comments)
     84 }  // namespace HWY_NAMESPACE
     85 }  // namespace jxl
     86 HWY_AFTER_NAMESPACE();
     87 
     88 #if HWY_ONCE
     89 namespace jxl {
     90 HWY_EXPORT(MultiplySum);       // Local function
     91 HWY_EXPORT(RgbFromSingle);     // Local function
     92 HWY_EXPORT(SingleFromSingle);  // Local function
     93 
     94 // Slow conversion using double precision multiplication, only
     95 // needed when the bit depth is too high for single precision
     96 void SingleFromSingleAccurate(const size_t xsize,
     97                              const pixel_type* const JXL_RESTRICT row_in,
     98                              const double factor, float* row_out) {
     99  for (size_t x = 0; x < xsize; x++) {
    100    row_out[x] = row_in[x] * factor;
    101  }
    102 }
    103 
    104 // convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
    105 // back to binary32 float
    106 Status int_to_float(const pixel_type* const JXL_RESTRICT row_in,
    107                    float* const JXL_RESTRICT row_out, const size_t xsize,
    108                    const int bits, const int exp_bits) {
    109  static_assert(sizeof(pixel_type) == sizeof(float));
    110  if (bits == 32) {
    111    JXL_ENSURE(exp_bits == 8);
    112    memcpy(row_out, row_in, xsize * sizeof(float));
    113    return true;
    114  }
    115  int exp_bias = (1 << (exp_bits - 1)) - 1;
    116  int sign_shift = bits - 1;
    117  int mant_bits = bits - exp_bits - 1;
    118  int mant_shift = 23 - mant_bits;
    119  for (size_t x = 0; x < xsize; ++x) {
    120    uint32_t f;
    121    memcpy(&f, &row_in[x], 4);
    122    int signbit = (f >> sign_shift);
    123    f &= (1 << sign_shift) - 1;
    124    if (f == 0) {
    125      row_out[x] = (signbit ? -0.f : 0.f);
    126      continue;
    127    }
    128    int exp = (f >> mant_bits);
    129    int mantissa = (f & ((1 << mant_bits) - 1));
    130    mantissa <<= mant_shift;
    131    // Try to normalize only if there is space for maneuver.
    132    if (exp == 0 && exp_bits < 8) {
    133      // subnormal number
    134      while ((mantissa & 0x800000) == 0) {
    135        mantissa <<= 1;
    136        exp--;
    137      }
    138      exp++;
    139      // remove leading 1 because it is implicit now
    140      mantissa &= 0x7fffff;
    141    }
    142    exp -= exp_bias;
    143    // broke up the arbitrary float into its parts, now reassemble into
    144    // binary32
    145    exp += 127;
    146    JXL_ENSURE(exp >= 0);
    147    f = (signbit ? 0x80000000 : 0);
    148    f |= (exp << 23);
    149    f |= mantissa;
    150    memcpy(&row_out[x], &f, 4);
    151  }
    152  return true;
    153 }
    154 
    155 #if JXL_DEBUG_V_LEVEL >= 1
    156 std::string ModularStreamId::DebugString() const {
    157  std::ostringstream os;
    158  os << (kind == GlobalData   ? "ModularGlobal"
    159         : kind == VarDCTDC   ? "VarDCTDC"
    160         : kind == ModularDC  ? "ModularDC"
    161         : kind == ACMetadata ? "ACMeta"
    162         : kind == QuantTable ? "QuantTable"
    163         : kind == ModularAC  ? "ModularAC"
    164                              : "");
    165  if (kind == VarDCTDC || kind == ModularDC || kind == ACMetadata ||
    166      kind == ModularAC) {
    167    os << " group " << group_id;
    168  }
    169  if (kind == ModularAC) {
    170    os << " pass " << pass_id;
    171  }
    172  if (kind == QuantTable) {
    173    os << " " << quant_table_id;
    174  }
    175  return os.str();
    176 }
    177 #endif
    178 
        // Decodes the global modular section of a frame.
        //
        // Reads the optional shared MA tree and its histograms, creates an image
        // with one channel per (decoded) color channel plus one per extra channel,
        // applies chroma-subsampling / extra-channel-upsampling shifts to the
        // channel sizes, and runs ModularGenericDecompress on it without undoing
        // transforms. The result is stored in full_image.
        //
        // reader: bit reader positioned at the global modular section.
        // frame_header: header of the frame being decoded.
        // allow_truncated_group: when true, non-fatal decode errors are not turned
        //   into an early return; they are handed back through the return value.
        //
        // Returns the status of ModularGenericDecompress, or a failure for
        // unsupported bit depths, tree/histogram errors, or a fatal decode error.
    179 Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
    180                                             const FrameHeader& frame_header,
    181                                             bool allow_truncated_group) {
    182  JxlMemoryManager* memory_manager = this->memory_manager();
         // Color channels live in the modular image only for modular-encoded
         // frames; otherwise only extra channels are stored here.
    183  bool decode_color = frame_header.encoding == FrameEncoding::kModular;
    184  const auto& metadata = frame_header.nonserialized_metadata->m;
    185  bool is_gray = metadata.color_encoding.IsGray();
         // Three color channels, except grayscale without a color transform.
    186  size_t nb_chans = 3;
    187  if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
    188    nb_chans = 1;
    189  }
    190  do_color = decode_color;
    191  size_t nb_extra = metadata.extra_channel_info.size();
    192  bool has_tree = static_cast<bool>(reader->ReadBits(1));
         // When truncation is allowed, only attempt the tree if the reader has not
         // already consumed every available bit.
    193  if (!allow_truncated_group ||
    194      reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) {
    195    if (has_tree) {
             // Cap the tree size: at most 2^22, and otherwise proportional to
             // the number of samples (pixels * channels / 16) plus 1024.
    196      size_t tree_size_limit =
    197          std::min(static_cast<size_t>(1 << 22),
    198                   1024 + frame_dim.xsize * frame_dim.ysize *
    199                              (nb_chans + nb_extra) / 16);
    200      JXL_RETURN_IF_ERROR(
    201          DecodeTree(memory_manager, reader, &tree, tree_size_limit));
             // NOTE(review): (tree.size() + 1) / 2 is presumably the number of
             // leaves (contexts) of the tree - confirm.
    202      JXL_RETURN_IF_ERROR(DecodeHistograms(
    203          memory_manager, reader, (tree.size() + 1) / 2, &code, &context_map));
    204    }
    205  }
         // From here on nb_chans counts only the color channels actually stored
         // in the modular image.
    206  if (!do_color) nb_chans = 0;
    207 
    208  bool fp = metadata.bit_depth.floating_point_sample;
    209 
    210  // bits_per_sample is just metadata for XYB images.
    211  if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
    212      frame_header.color_transform != ColorTransform::kXYB) {
    213    if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
    214      return JXL_FAILURE("uint32_t not supported in dec_modular");
    215    } else if (metadata.bit_depth.bits_per_sample > 32) {
    216      return JXL_FAILURE("bits_per_sample > 32 not supported");
    217    }
    218  }
    219 
    220  JXL_ASSIGN_OR_RETURN(
    221      Image gi,
    222      Image::Create(memory_manager, frame_dim.xsize, frame_dim.ysize,
    223                    metadata.bit_depth.bits_per_sample, nb_chans + nb_extra));
    224 
         // all_same_shift stays true only if every channel has the same
         // hshift/vshift as channel 0.
    225  all_same_shift = true;
    226  if (frame_header.color_transform == ColorTransform::kYCbCr) {
           // Apply chroma subsampling: shrink each color channel according to its
           // horizontal/vertical shift.
    227    for (size_t c = 0; c < nb_chans; c++) {
    228      gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c);
    229      gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c);
    230      size_t xsize_shifted =
    231          DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift);
    232      size_t ysize_shifted =
    233          DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift);
    234      JXL_RETURN_IF_ERROR(gi.channel[c].shrink(xsize_shifted, ysize_shifted));
    235      if (gi.channel[c].hshift != gi.channel[0].hshift ||
    236          gi.channel[c].vshift != gi.channel[0].vshift)
    237        all_same_shift = false;
    238    }
    239  }
    240 
         // Extra channels follow the color channels. Each is sized from the
         // upsampled frame dimensions divided by its own upsampling factor, and
         // its shift is the log2 difference between that factor and the frame's.
    241  for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
    242    size_t ecups = frame_header.extra_channel_upsampling[ec];
    243    JXL_RETURN_IF_ERROR(
    244        gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups),
    245                             DivCeil(frame_dim.ysize_upsampled, ecups)));
    246    gi.channel[c].hshift = gi.channel[c].vshift =
    247        CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
    248    if (gi.channel[c].hshift != gi.channel[0].hshift ||
    249        gi.channel[c].vshift != gi.channel[0].vshift)
    250      all_same_shift = false;
    251  }
    252 
    253  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s",
    254              gi.DebugString().c_str());
    255  ModularOptions options;
    256  options.max_chan_size = frame_dim.group_dim;
    257  options.group_dim = frame_dim.group_dim;
         // Transforms are not undone here (undo_transforms=false).
    258  Status dec_status = ModularGenericDecompress(
    259      reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
    260      &options,
    261      /*undo_transforms=*/false, &tree, &code, &context_map,
    262      allow_truncated_group);
         // Without truncation any error aborts; with truncation only a fatal
         // error does, and other statuses are returned at the end.
    263  if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
    264  if (dec_status.IsFatalError()) {
    265    return JXL_FAILURE("Failed to decode global modular info");
    266  }
    267 
    268  // TODO(eustas): are we sure this can be done after partial decode?
         // have_something: at least one non-meta channel fits within
         // group_dim x group_dim.
    269  have_something = false;
    270  for (size_t c = 0; c < gi.channel.size(); c++) {
    271    Channel& gic = gi.channel[c];
    272    if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim &&
    273        gic.h <= frame_dim.group_dim)
    274      have_something = true;
    275  }
    276  // move global transforms to groups if possible
    277  if (!have_something && all_same_shift) {
           // Only the case of exactly one RCT transform is handled: it is moved
           // to global_transform and removed from the image's transform list.
    278    if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) {
    279      global_transform = gi.transform;
    280      gi.transform.clear();
    281      // TODO(jon): also move no-delta-palette out (trickier though)
    282    }
    283  }
    284  full_image = std::move(gi);
    285  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s",
    286              full_image.DebugString().c_str());
    287  return dec_status;
    288 }
    289 
    290 void ModularFrameDecoder::MaybeDropFullImage() {
    291  if (full_image.transform.empty() && !have_something && all_same_shift) {
    292    use_full_image = false;
    293    JXL_DEBUG_V(6, "Dropping full image");
    294    for (auto& ch : full_image.channel) {
    295      // keep metadata on channels around, but dealloc their planes
    296      ch.plane = Plane<pixel_type>();
    297    }
    298  }
    299 }
    300 
        // Decodes (or zero-fills) one modular group.
        //
        // Only channels of full_image that are larger than a group and whose
        // min(hshift, vshift) lies in [minShift, maxShift] take part. For those
        // channels the part covered by `rect` is either:
        //  - zeroed directly in full_image (zerofill && use_full_image), or
        //  - decoded into a temporary image and then either copied into
        //    full_image (use_full_image) or converted straight into
        //    render_pipeline_input (full image dropped).
        //
        // stream must be a ModularDC or ModularAC stream id.
        // should_run_pipeline (optional) is set to false when nothing was decoded
        // although color or extra channels are expected.
        // With allow_truncated, only fatal decode errors abort.
    301 Status ModularFrameDecoder::DecodeGroup(
    302    const FrameHeader& frame_header, const Rect& rect, BitReader* reader,
    303    int minShift, int maxShift, const ModularStreamId& stream, bool zerofill,
    304    PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input,
    305    bool allow_truncated, bool* should_run_pipeline) {
    306  JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s",
    307              stream.DebugString().c_str(), Description(rect).c_str(), minShift,
    308              maxShift, zerofill ? "using zerofill" : "");
    309  JXL_ENSURE(stream.kind == ModularStreamId::Kind::ModularDC ||
    310             stream.kind == ModularStreamId::Kind::ModularAC);
    311  const size_t xsize = rect.xsize();
    312  const size_t ysize = rect.ysize();
         // Temporary group image; starts with zero channels, which are appended
         // below.
    313  JXL_ASSIGN_OR_RETURN(Image gi, Image::Create(memory_manager_, xsize, ysize,
    314                                               full_image.bitdepth, 0));
    315  // start at the first bigger-than-groupsize non-metachannel
    316  size_t c = full_image.nb_meta_channels;
    317  for (; c < full_image.channel.size(); c++) {
    318    Channel& fc = full_image.channel[c];
    319    if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
    320  }
    321  size_t beginc = c;
         // Select the channels for this shift bracket and prepare their storage.
    322  for (; c < full_image.channel.size(); c++) {
    323    Channel& fc = full_image.channel[c];
    324    int shift = std::min(fc.hshift, fc.vshift);
    325    if (shift > maxShift) continue;
    326    if (shift < minShift) continue;
           // Group rectangle in this channel's (shifted) coordinates.
           // NOTE(review): the trailing fc.w / fc.h arguments presumably clamp
           // the rect to the channel bounds - confirm in Rect.
    327    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
    328           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    329    if (r.xsize() == 0 || r.ysize() == 0) continue;
    330    if (zerofill && use_full_image) {
             // Zero the covered region in place; no temporary channel needed.
    331      for (size_t y = 0; y < r.ysize(); ++y) {
    332        pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
    333        memset(row_out, 0, r.xsize() * sizeof(*row_out));
    334      }
    335    } else {
    336      JXL_ASSIGN_OR_RETURN(
    337          Channel gc, Channel::Create(memory_manager_, r.xsize(), r.ysize()));
    338      if (zerofill) ZeroFillImage(&gc.plane);
    339      gc.hshift = fc.hshift;
    340      gc.vshift = fc.vshift;
    341      gi.channel.emplace_back(std::move(gc));
    342    }
    343  }
    344  if (zerofill && use_full_image) return true;
    345  // Return early if there's nothing to decode. Otherwise there might be
    346  // problems later (in ModularImageToDecodedRect).
    347  if (gi.channel.empty()) {
    348    if (dec_state && should_run_pipeline) {
    349      const auto* metadata = frame_header.nonserialized_metadata;
    350      if (do_color || metadata->m.num_extra_channels > 0) {
    351        // Signal to FrameDecoder that we do not have some of the required input
    352        // for the render pipeline.
    353        *should_run_pipeline = false;
    354      }
    355    }
    356    JXL_DEBUG_V(6, "Nothing to decode, returning early.");
    357    return true;
    358  }
    359  ModularOptions options;
         // With zerofill the (already zeroed) temporary channels are used as-is.
    360  if (!zerofill) {
    361    auto status = ModularGenericDecompress(
    362        reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
    363        /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated);
           // Non-fatal errors are tolerated only when truncation is allowed.
    364    if (!allow_truncated) JXL_RETURN_IF_ERROR(status);
    365    if (status.IsFatalError()) return status;
    366  }
    367  // Undo global transforms that have been pushed to the group level
    368  if (!use_full_image) {
    369    JXL_ENSURE(render_pipeline_input);
    370    for (const auto& t : global_transform) {
    371      JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
    372    }
           // No full image to copy into: hand the group to the render pipeline.
    373    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
    374        frame_header, gi, dec_state, nullptr, *render_pipeline_input,
    375        Rect(0, 0, gi.w, gi.h)));
    376    return true;
    377  }
         // Copy the decoded group channels back into full_image. This loop applies
         // the same filters as the selection loop above, so gic walks gi.channel
         // in the order the channels were appended.
    378  int gic = 0;
    379  for (c = beginc; c < full_image.channel.size(); c++) {
    380    Channel& fc = full_image.channel[c];
    381    int shift = std::min(fc.hshift, fc.vshift);
    382    if (shift > maxShift) continue;
    383    if (shift < minShift) continue;
    384    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
    385           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    386    if (r.xsize() == 0 || r.ysize() == 0) continue;
    387    JXL_ENSURE(use_full_image);
    388    JXL_RETURN_IF_ERROR(
    389        CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()),
    390                    /*from=*/gi.channel[gic].plane,
    391                    /*rect_to=*/r, /*to=*/&fc.plane));
    392    gic++;
    393  }
    394  return true;
    395 }
    396 
    397 Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header,
    398                                           size_t group_id, BitReader* reader,
    399                                           PassesDecoderState* dec_state) {
    400  JxlMemoryManager* memory_manager = dec_state->memory_manager();
    401  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
    402  JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str());
    403  // TODO(eustas): investigate if we could reduce the impact of
    404  //               EvalRationalPolynomial; generally speaking, the limit is
    405  //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
    406  //               3 comes from XybToRgb that cubes the values, and "magic" is
    407  //               the sum of all other contributions. 2**18 is known to lead
    408  //               to NaN on input found by fuzzing (see commit message).
    409  JXL_ASSIGN_OR_RETURN(Image image,
    410                       Image::Create(memory_manager, r.xsize(), r.ysize(),
    411                                     full_image.bitdepth, 3));
    412  size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
    413  reader->Refill();
    414  size_t extra_precision = reader->ReadFixedBits<2>();
    415  float mul = 1.0f / (1 << extra_precision);
    416  ModularOptions options;
    417  for (size_t c = 0; c < 3; c++) {
    418    Channel& ch = image.channel[c < 2 ? c ^ 1 : c];
    419    ch.w >>= frame_header.chroma_subsampling.HShift(c);
    420    ch.h >>= frame_header.chroma_subsampling.VShift(c);
    421    JXL_RETURN_IF_ERROR(ch.shrink());
    422  }
    423  if (!ModularGenericDecompress(
    424          reader, image, /*header=*/nullptr, stream_id, &options,
    425          /*undo_transforms=*/true, &tree, &code, &context_map)) {
    426    return JXL_FAILURE("Failed to decode VarDCT DC group (DC group id %d)",
    427                       static_cast<int>(group_id));
    428  }
    429  DequantDC(r, &dec_state->shared_storage.dc_storage,
    430            &dec_state->shared_storage.quant_dc, image,
    431            dec_state->shared->quantizer.MulDC(), mul,
    432            dec_state->shared->cmap.base().DCFactors(),
    433            frame_header.chroma_subsampling, dec_state->shared->block_ctx_map);
    434  return true;
    435 }
    436 
        // Decodes the AC metadata of one DC group and stores it in dec_state.
        //
        // A 4-channel modular image is decoded:
        //   channel 0, 1: per-color-tile maps, written to cmap.ytox_map and
        //                 cmap.ytob_map;
        //   channel 2:    `count` columns x 2 rows; row 0 holds raw AC strategies,
        //                 row 1 the matching raw quant-field values;
        //   channel 3:    per-block EPF sharpness.
        // Every strategy is validated (known value, compatible with chroma
        // subsampling, not crossing an AC group or the image bounds) before it is
        // written to ac_strategy. Returns a failure on any invalid value.
    437 Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header,
    438                                             size_t group_id, BitReader* reader,
    439                                             PassesDecoderState* dec_state) {
    440  JxlMemoryManager* memory_manager = dec_state->memory_manager();
    441  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
    442  JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str());
         // The number of (strategy, quant) entries is stored minus one, using just
         // enough bits for the number of blocks in the group.
    443  size_t upper_bound = r.xsize() * r.ysize();
    444  reader->Refill();
    445  size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
    446  size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
    447  // YToX, YToB, ACS + QF, EPF
    448  JXL_ASSIGN_OR_RETURN(Image image,
    449                       Image::Create(memory_manager, r.xsize(), r.ysize(),
    450                                     full_image.bitdepth, 4));
    451  static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
         // Group rect in color-tile units (8x8 blocks), rounding the size up.
    452  Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
         // Channels 0-2 get their own sizes; channel 3 keeps the size given to
         // Image::Create above.
    453  JXL_ASSIGN_OR_RETURN(
    454      image.channel[0],
    455      Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3));
    456  JXL_ASSIGN_OR_RETURN(
    457      image.channel[1],
    458      Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3));
    459  JXL_ASSIGN_OR_RETURN(image.channel[2],
    460                       Channel::Create(memory_manager, count, 2, 0, 0));
    461  ModularOptions options;
    462  if (!ModularGenericDecompress(
    463          reader, image, /*header=*/nullptr, stream_id, &options,
    464          /*undo_transforms=*/true, &tree, &code, &context_map)) {
    465    return JXL_FAILURE("Failed to decode AC metadata");
    466  }
    467  JXL_RETURN_IF_ERROR(
    468      ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane,
    469                           cr, &dec_state->shared_storage.cmap.ytox_map));
    470  JXL_RETURN_IF_ERROR(
    471      ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane,
    472                           cr, &dec_state->shared_storage.cmap.ytob_map));
         // num: index of the next unread (strategy, quant) entry in channel 2.
    473  size_t num = 0;
    474  bool is444 = frame_header.chroma_subsampling.Is444();
    475  auto& ac_strategy = dec_state->shared_storage.ac_strategy;
    476  size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
    477  size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
         // Bitmask of raw strategies seen in this group; merged into
         // dec_state->used_acs after the loops.
    478  uint32_t local_used_acs = 0;
    479  for (size_t iy = 0; iy < r.ysize(); iy++) {
    480    size_t y = r.y0() + iy;
    481    int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
    482    uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
           // row_in_1: raw AC strategies, row_in_2: raw quant values (both
           // indexed by num), row_in_3: sharpness for this block row.
    483    int32_t* row_in_1 = image.channel[2].plane.Row(0);
    484    int32_t* row_in_2 = image.channel[2].plane.Row(1);
    485    int32_t* row_in_3 = image.channel[3].plane.Row(iy);
    486    for (size_t ix = 0; ix < r.xsize(); ix++) {
    487      size_t x = r.x0() + ix;
    488      int sharpness = row_in_3[ix];
    489      if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
    490        return JXL_FAILURE("Corrupted sharpness field");
    491      }
    492      row_epf[ix] = sharpness;
             // NOTE(review): a valid entry here presumably means the block is
             // already covered by a previously placed (multi-block) strategy,
             // so no new entry is consumed - confirm in AcStrategyImage.
    493      if (ac_strategy.IsValid(x, y)) {
    494        continue;
    495      }
    496 
    497      if (num >= count) return JXL_FAILURE("Corrupted stream");
    498 
    499      if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) {
    500        return JXL_FAILURE("Invalid AC strategy");
    501      }
    502      local_used_acs |= 1u << row_in_1[num];
    503      AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]);
             // Strategies covering more than one block require 4:4:4.
    504      if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
    505          !is444) {
    506        return JXL_FAILURE(
    507            "AC strategy not compatible with chroma subsampling");
    508      }
    509      // Ensure that blocks do not overflow *AC* groups.
    510      size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
    511      size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
    512      size_t next_x_dct_block = x + acs.covered_blocks_x();
    513      size_t next_y_dct_block = y + acs.covered_blocks_y();
    514      if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) {
    515        return JXL_FAILURE("Invalid AC strategy, x overflow");
    516      }
    517      if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) {
    518        return JXL_FAILURE("Invalid AC strategy, y overflow");
    519      }
             // Bounds were verified just above, hence the unchecked setter.
    520      JXL_RETURN_IF_ERROR(
    521          ac_strategy.SetNoBoundsCheck(x, y, AcStrategyType(row_in_1[num])));
             // Stored quant value is 1 + raw, with raw clamped to
             // [0, kQuantMax - 1].
    522      row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1,
    523                                                     row_in_2[num]));
    524      num++;
    525    }
    526  }
    527  dec_state->used_acs |= local_used_acs;
    528  if (frame_header.loop_filter.epf_iters > 0) {
    529    JXL_RETURN_IF_ERROR(ComputeSigma(frame_header.loop_filter, r, dec_state));
    530  }
    531  return true;
    532 }
    533 
    534 Status ModularFrameDecoder::ModularImageToDecodedRect(
    535    const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state,
    536    jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input,
    537    Rect modular_rect) const {
    538  const auto* metadata = frame_header.nonserialized_metadata;
    539  JXL_ENSURE(gi.transform.empty());
    540 
    541  auto get_row = [&](size_t c, size_t y) {
    542    const auto& buffer = render_pipeline_input.GetBuffer(c);
    543    return buffer.second.Row(buffer.first, y);
    544  };
    545 
    546  size_t c = 0;
    547  if (do_color) {
    548    const bool rgb_from_gray =
    549        metadata->m.color_encoding.IsGray() &&
    550        frame_header.color_transform == ColorTransform::kNone;
    551    const bool fp = metadata->m.bit_depth.floating_point_sample &&
    552                    frame_header.color_transform != ColorTransform::kXYB;
    553    for (; c < 3; c++) {
    554      double factor = full_image.bitdepth < 32
    555                          ? 1.0 / ((1u << full_image.bitdepth) - 1)
    556                          : 0;
    557      size_t c_in = c;
    558      if (frame_header.color_transform == ColorTransform::kXYB) {
    559        factor = dec_state->shared->matrices.DCQuants()[c];
    560        // XYB is encoded as YX(B-Y)
    561        if (c < 2) c_in = 1 - c;
    562      } else if (rgb_from_gray) {
    563        c_in = 0;
    564      }
    565      JXL_ENSURE(c_in < gi.channel.size());
    566      Channel& ch_in = gi.channel[c_in];
    567      // TODO(eustas): could we detect it on earlier stage?
    568      if (ch_in.w == 0 || ch_in.h == 0) {
    569        return JXL_FAILURE("Empty image");
    570      }
    571      JXL_ENSURE(ch_in.hshift <= 3 && ch_in.vshift <= 3);
    572      Rect r = render_pipeline_input.GetBuffer(c).second;
    573      Rect mr(modular_rect.x0() >> ch_in.hshift,
    574              modular_rect.y0() >> ch_in.vshift,
    575              DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
    576              DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
    577      mr = mr.Crop(ch_in.plane);
    578      size_t xsize_shifted = r.xsize();
    579      size_t ysize_shifted = r.ysize();
    580      if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
    581        return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
    582                           "x%" PRIuS
    583                           " modular channel into "
    584                           "a %" PRIuS "x%" PRIuS " rect",
    585                           mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
    586      }
    587      if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
    588        JXL_ENSURE(!fp);
    589        const auto process_row = [&](const uint32_t task,
    590                                     size_t /* thread */) -> Status {
    591          const size_t y = task;
    592          const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
    593          const pixel_type* const JXL_RESTRICT row_in_Y =
    594              mr.Row(&gi.channel[0].plane, y);
    595          float* const JXL_RESTRICT row_out = get_row(c, y);
    596          HWY_DYNAMIC_DISPATCH(MultiplySum)
    597          (xsize_shifted, row_in, row_in_Y, factor, row_out);
    598          return true;
    599        };
    600        JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted,
    601                                      ThreadPool::NoInit, process_row,
    602                                      "ModularIntToFloat"));
    603      } else if (fp) {
    604        int bits = metadata->m.bit_depth.bits_per_sample;
    605        int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
    606        const auto process_row = [&](const uint32_t task,
    607                                     size_t /* thread */) -> Status {
    608          const size_t y = task;
    609          const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
    610          if (rgb_from_gray) {
    611            for (size_t cc = 0; cc < 3; cc++) {
    612              float* const JXL_RESTRICT row_out = get_row(cc, y);
    613              JXL_RETURN_IF_ERROR(
    614                  int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits));
    615            }
    616          } else {
    617            float* const JXL_RESTRICT row_out = get_row(c, y);
    618            JXL_RETURN_IF_ERROR(
    619                int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits));
    620          }
    621          return true;
    622        };
    623        JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted,
    624                                      ThreadPool::NoInit, process_row,
    625                                      "ModularIntToFloat_losslessfloat"));
    626      } else {
    627        const auto process_row = [&](const uint32_t task,
    628                                     size_t /* thread */) -> Status {
    629          const size_t y = task;
    630          const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
    631          if (rgb_from_gray) {
    632            if (full_image.bitdepth < 23) {
    633              HWY_DYNAMIC_DISPATCH(RgbFromSingle)
    634              (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y),
    635               get_row(2, y));
    636            } else {
    637              SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    638                                       get_row(0, y));
    639              SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    640                                       get_row(1, y));
    641              SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    642                                       get_row(2, y));
    643            }
    644          } else {
    645            float* const JXL_RESTRICT row_out = get_row(c, y);
    646            if (full_image.bitdepth < 23) {
    647              HWY_DYNAMIC_DISPATCH(SingleFromSingle)
    648              (xsize_shifted, row_in, factor, row_out);
    649            } else {
    650              SingleFromSingleAccurate(xsize_shifted, row_in, factor, row_out);
    651            }
    652          }
    653          return true;
    654        };
    655        JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted,
    656                                      ThreadPool::NoInit, process_row,
    657                                      "ModularIntToFloat"));
    658      }
    659      if (rgb_from_gray) {
    660        break;
    661      }
    662    }
    663    if (rgb_from_gray) {
    664      c = 1;
    665    }
    666  }
    667  size_t num_extra_channels = metadata->m.num_extra_channels;
    668  for (size_t ec = 0; ec < num_extra_channels; ec++, c++) {
    669    const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec];
    670    int bits = eci.bit_depth.bits_per_sample;
    671    int exp_bits = eci.bit_depth.exponent_bits_per_sample;
    672    bool fp = eci.bit_depth.floating_point_sample;
    673    JXL_ENSURE(fp || bits < 32);
    674    const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1));
    675    JXL_ENSURE(c < gi.channel.size());
    676    Channel& ch_in = gi.channel[c];
    677    const auto& buffer = render_pipeline_input.GetBuffer(3 + ec);
    678    Rect r = buffer.second;
    679    Rect mr(modular_rect.x0() >> ch_in.hshift,
    680            modular_rect.y0() >> ch_in.vshift,
    681            DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
    682            DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
    683    mr = mr.Crop(ch_in.plane);
    684    if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
    685      return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
    686                         "x%" PRIuS
    687                         " modular channel into "
    688                         "a %" PRIuS "x%" PRIuS " rect",
    689                         mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
    690    }
    691    for (size_t y = 0; y < r.ysize(); ++y) {
    692      float* const JXL_RESTRICT row_out = r.Row(buffer.first, y);
    693      const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
    694      if (fp) {
    695        JXL_RETURN_IF_ERROR(
    696            int_to_float(row_in, row_out, r.xsize(), bits, exp_bits));
    697      } else {
    698        if (full_image.bitdepth < 23) {
    699          HWY_DYNAMIC_DISPATCH(SingleFromSingle)
    700          (r.xsize(), row_in, factor, row_out);
    701        } else {
    702          SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out);
    703        }
    704      }
    705    }
    706  }
    707  return true;
    708 }
    709 
    710 Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header,
    711                                             PassesDecoderState* dec_state,
    712                                             jxl::ThreadPool* pool,
    713                                             bool inplace) {
    714  if (!use_full_image) return true;
    715  JxlMemoryManager* memory_manager = dec_state->memory_manager();
    716  Image gi{memory_manager};
    717  if (inplace) {
    718    gi = std::move(full_image);
    719  } else {
    720    JXL_ASSIGN_OR_RETURN(gi, Image::Clone(full_image));
    721  }
    722  size_t xsize = gi.w;
    723  size_t ysize = gi.h;
    724 
    725  JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s",
    726              gi.DebugString().c_str());
    727 
    728  // Don't use threads if total image size is smaller than a group
    729  if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;
    730 
    731  // Undo the global transforms
    732  gi.undo_transforms(global_header.wp_header, pool);
    733  JXL_ENSURE(global_transform.empty());
    734  if (gi.error) return JXL_FAILURE("Undoing transforms failed");
    735 
    736  for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) {
    737    dec_state->render_pipeline->ClearDone(i);
    738  }
    739 
    740  const auto init = [&](size_t num_threads) -> Status {
    741    bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT ||
    742                          (frame_header.flags & FrameHeader::kNoise));
    743    JXL_RETURN_IF_ERROR(dec_state->render_pipeline->PrepareForThreads(
    744        num_threads, use_group_ids));
    745    return true;
    746  };
    747  const auto process_group = [&](const uint32_t group,
    748                                 size_t thread_id) -> Status {
    749    RenderPipelineInput input =
    750        dec_state->render_pipeline->GetInputBuffers(group, thread_id);
    751    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
    752        frame_header, gi, dec_state, nullptr, input,
    753        dec_state->shared->frame_dim.GroupRect(group)));
    754    JXL_RETURN_IF_ERROR(input.Done());
    755    return true;
    756  };
    757  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0,
    758                                dec_state->shared->frame_dim.num_groups, init,
    759                                process_group, "ModularToRect"));
    760  return true;
    761 }
    762 
    763 static constexpr const float kAlmostZero = 1e-8f;
    764 
    765 Status ModularFrameDecoder::DecodeQuantTable(
    766    JxlMemoryManager* memory_manager, size_t required_size_x,
    767    size_t required_size_y, BitReader* br, QuantEncoding* encoding, size_t idx,
    768    ModularFrameDecoder* modular_frame_decoder) {
    769  JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den));
    770  if (encoding->qraw.qtable_den < kAlmostZero) {
    771    // qtable[] values are already checked for <= 0 so the denominator may not
    772    // be negative.
    773    return JXL_FAILURE("Invalid qtable_den: value too small");
    774  }
    775  JXL_ASSIGN_OR_RETURN(
    776      Image image,
    777      Image::Create(memory_manager, required_size_x, required_size_y, 8, 3));
    778  ModularOptions options;
    779  if (modular_frame_decoder) {
    780    JXL_ASSIGN_OR_RETURN(ModularStreamId qt, ModularStreamId::QuantTable(idx));
    781    JXL_RETURN_IF_ERROR(ModularGenericDecompress(
    782        br, image, /*header=*/nullptr, qt.ID(modular_frame_decoder->frame_dim),
    783        &options, /*undo_transforms=*/true, &modular_frame_decoder->tree,
    784        &modular_frame_decoder->code, &modular_frame_decoder->context_map));
    785  } else {
    786    JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
    787                                                 0, &options,
    788                                                 /*undo_transforms=*/true));
    789  }
    790  if (!encoding->qraw.qtable) {
    791    encoding->qraw.qtable =
    792        new std::vector<int>(required_size_x * required_size_y * 3);
    793  } else {
    794    JXL_ENSURE(encoding->qraw.qtable->size() ==
    795               required_size_x * required_size_y * 3);
    796  }
    797  int* qtable = encoding->qraw.qtable->data();
    798  for (size_t c = 0; c < 3; c++) {
    799    for (size_t y = 0; y < required_size_y; y++) {
    800      int32_t* JXL_RESTRICT row = image.channel[c].Row(y);
    801      for (size_t x = 0; x < required_size_x; x++) {
    802        qtable[c * required_size_x * required_size_y + y * required_size_x +
    803               x] = row[x];
    804        if (row[x] <= 0) {
    805          return JXL_FAILURE("Invalid raw quantization table");
    806        }
    807      }
    808    }
    809  }
    810  return true;
    811 }
    812 
    813 }  // namespace jxl
    814 #endif  // HWY_ONCE