tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_frame.cc (107979B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_frame.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <algorithm>
     11 #include <array>
     12 #include <cmath>
     13 #include <cstddef>
     14 #include <cstdint>
     15 #include <memory>
     16 #include <numeric>
     17 #include <utility>
     18 #include <vector>
     19 
     20 #include "lib/jxl/ac_context.h"
     21 #include "lib/jxl/ac_strategy.h"
     22 #include "lib/jxl/base/bits.h"
     23 #include "lib/jxl/base/common.h"
     24 #include "lib/jxl/base/compiler_specific.h"
     25 #include "lib/jxl/base/data_parallel.h"
     26 #include "lib/jxl/base/override.h"
     27 #include "lib/jxl/base/printf_macros.h"
     28 #include "lib/jxl/base/rect.h"
     29 #include "lib/jxl/base/status.h"
     30 #include "lib/jxl/chroma_from_luma.h"
     31 #include "lib/jxl/coeff_order.h"
     32 #include "lib/jxl/coeff_order_fwd.h"
     33 #include "lib/jxl/color_encoding_internal.h"
     34 #include "lib/jxl/common.h"  // kMaxNumPasses
     35 #include "lib/jxl/dct_util.h"
     36 #include "lib/jxl/dec_external_image.h"
     37 #include "lib/jxl/enc_ac_strategy.h"
     38 #include "lib/jxl/enc_adaptive_quantization.h"
     39 #include "lib/jxl/enc_ans.h"
     40 #include "lib/jxl/enc_aux_out.h"
     41 #include "lib/jxl/enc_bit_writer.h"
     42 #include "lib/jxl/enc_cache.h"
     43 #include "lib/jxl/enc_chroma_from_luma.h"
     44 #include "lib/jxl/enc_coeff_order.h"
     45 #include "lib/jxl/enc_context_map.h"
     46 #include "lib/jxl/enc_entropy_coder.h"
     47 #include "lib/jxl/enc_external_image.h"
     48 #include "lib/jxl/enc_fields.h"
     49 #include "lib/jxl/enc_group.h"
     50 #include "lib/jxl/enc_heuristics.h"
     51 #include "lib/jxl/enc_modular.h"
     52 #include "lib/jxl/enc_noise.h"
     53 #include "lib/jxl/enc_params.h"
     54 #include "lib/jxl/enc_patch_dictionary.h"
     55 #include "lib/jxl/enc_photon_noise.h"
     56 #include "lib/jxl/enc_quant_weights.h"
     57 #include "lib/jxl/enc_splines.h"
     58 #include "lib/jxl/enc_toc.h"
     59 #include "lib/jxl/enc_xyb.h"
     60 #include "lib/jxl/fields.h"
     61 #include "lib/jxl/frame_dimensions.h"
     62 #include "lib/jxl/frame_header.h"
     63 #include "lib/jxl/image.h"
     64 #include "lib/jxl/image_bundle.h"
     65 #include "lib/jxl/image_ops.h"
     66 #include "lib/jxl/jpeg/enc_jpeg_data.h"
     67 #include "lib/jxl/loop_filter.h"
     68 #include "lib/jxl/modular/options.h"
     69 #include "lib/jxl/quant_weights.h"
     70 #include "lib/jxl/quantizer.h"
     71 #include "lib/jxl/splines.h"
     72 #include "lib/jxl/toc.h"
     73 
     74 namespace jxl {
     75 
     76 Status ParamsPostInit(CompressParams* p) {
     77  if (!p->manual_noise.empty() &&
     78      p->manual_noise.size() != NoiseParams::kNumNoisePoints) {
     79    return JXL_FAILURE("Invalid number of noise lut entries");
     80  }
     81  if (!p->manual_xyb_factors.empty() && p->manual_xyb_factors.size() != 3) {
     82    return JXL_FAILURE("Invalid number of XYB quantization factors");
     83  }
     84  if (!p->modular_mode && p->butteraugli_distance == 0.0) {
     85    p->butteraugli_distance = kMinButteraugliDistance;
     86  }
     87  if (p->original_butteraugli_distance == -1.0) {
     88    p->original_butteraugli_distance = p->butteraugli_distance;
     89  }
     90  if (p->resampling <= 0) {
     91    p->resampling = 1;
     92    // For very low bit rates, using 2x2 resampling gives better results on
     93    // most photographic images, with an adjusted butteraugli score chosen to
     94    // give roughly the same amount of bits per pixel.
     95    if (!p->already_downsampled && p->butteraugli_distance >= 20) {
     96      p->resampling = 2;
     97      p->butteraugli_distance = 6 + ((p->butteraugli_distance - 20) * 0.25);
     98    }
     99  }
    100  if (p->ec_resampling <= 0) {
    101    p->ec_resampling = p->resampling;
    102  }
    103  return true;
    104 }
    105 
    106 namespace {
    107 
    108 template <typename T>
    109 uint32_t GetBitDepth(JxlBitDepth bit_depth, const T& metadata,
    110                     JxlPixelFormat format) {
    111  if (bit_depth.type == JXL_BIT_DEPTH_FROM_PIXEL_FORMAT) {
    112    return BitsPerChannel(format.data_type);
    113  } else if (bit_depth.type == JXL_BIT_DEPTH_FROM_CODESTREAM) {
    114    return metadata.bit_depth.bits_per_sample;
    115  } else if (bit_depth.type == JXL_BIT_DEPTH_CUSTOM) {
    116    return bit_depth.bits_per_sample;
    117  } else {
    118    return 0;
    119  }
    120 }
    121 
    122 Status CopyColorChannels(JxlChunkedFrameInputSource input, Rect rect,
    123                         const FrameInfo& frame_info,
    124                         const ImageMetadata& metadata, ThreadPool* pool,
    125                         Image3F* color, ImageF* alpha,
    126                         bool* has_interleaved_alpha) {
    127  JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
    128  input.get_color_channels_pixel_format(input.opaque, &format);
    129  *has_interleaved_alpha = format.num_channels == 2 || format.num_channels == 4;
    130  size_t bits_per_sample =
    131      GetBitDepth(frame_info.image_bit_depth, metadata, format);
    132  size_t row_offset;
    133  auto buffer = GetColorBuffer(input, rect.x0(), rect.y0(), rect.xsize(),
    134                               rect.ysize(), &row_offset);
    135  if (!buffer) {
    136    return JXL_FAILURE("no buffer for color channels given");
    137  }
    138  size_t color_channels = frame_info.ib_needs_color_transform
    139                              ? metadata.color_encoding.Channels()
    140                              : 3;
    141  if (format.num_channels < color_channels) {
    142    return JXL_FAILURE("Expected %" PRIuS
    143                       " color channels, received only %u channels",
    144                       color_channels, format.num_channels);
    145  }
    146  const uint8_t* data = reinterpret_cast<const uint8_t*>(buffer.get());
    147  for (size_t c = 0; c < color_channels; ++c) {
    148    JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck(
    149        data, rect.xsize(), rect.ysize(), row_offset, bits_per_sample, format,
    150        c, pool, &color->Plane(c)));
    151  }
    152  if (color_channels == 1) {
    153    JXL_RETURN_IF_ERROR(CopyImageTo(color->Plane(0), &color->Plane(1)));
    154    JXL_RETURN_IF_ERROR(CopyImageTo(color->Plane(0), &color->Plane(2)));
    155  }
    156  if (alpha) {
    157    if (*has_interleaved_alpha) {
    158      JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck(
    159          data, rect.xsize(), rect.ysize(), row_offset, bits_per_sample, format,
    160          format.num_channels - 1, pool, alpha));
    161    } else {
    162      // if alpha is not passed, but it is expected, then assume
    163      // it is all-opaque
    164      FillImage(1.0f, alpha);
    165    }
    166  }
    167  return true;
    168 }
    169 
    170 Status CopyExtraChannels(JxlChunkedFrameInputSource input, Rect rect,
    171                         const FrameInfo& frame_info,
    172                         const ImageMetadata& metadata,
    173                         bool has_interleaved_alpha, ThreadPool* pool,
    174                         std::vector<ImageF>* extra_channels) {
    175  for (size_t ec = 0; ec < metadata.num_extra_channels; ec++) {
    176    if (has_interleaved_alpha &&
    177        metadata.extra_channel_info[ec].type == ExtraChannel::kAlpha) {
    178      // Skip this alpha channel, but still request additional alpha channels
    179      // if they exist.
    180      has_interleaved_alpha = false;
    181      continue;
    182    }
    183    JxlPixelFormat ec_format = {1, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
    184    input.get_extra_channel_pixel_format(input.opaque, ec, &ec_format);
    185    ec_format.num_channels = 1;
    186    size_t row_offset;
    187    auto buffer =
    188        GetExtraChannelBuffer(input, ec, rect.x0(), rect.y0(), rect.xsize(),
    189                              rect.ysize(), &row_offset);
    190    if (!buffer) {
    191      return JXL_FAILURE("no buffer for extra channel given");
    192    }
    193    size_t bits_per_sample = GetBitDepth(
    194        frame_info.image_bit_depth, metadata.extra_channel_info[ec], ec_format);
    195    if (!ConvertFromExternalNoSizeCheck(
    196            reinterpret_cast<const uint8_t*>(buffer.get()), rect.xsize(),
    197            rect.ysize(), row_offset, bits_per_sample, ec_format, 0, pool,
    198            &(*extra_channels)[ec])) {
    199      return JXL_FAILURE("Failed to set buffer for extra channel");
    200    }
    201  }
    202  return true;
    203 }
    204 
    205 void SetProgressiveMode(const CompressParams& cparams,
    206                        ProgressiveSplitter* progressive_splitter) {
    207  constexpr PassDefinition progressive_passes_dc_vlf_lf_full_ac[] = {
    208      {/*num_coefficients=*/2, /*shift=*/0,
    209       /*suitable_for_downsampling_of_at_least=*/4},
    210      {/*num_coefficients=*/3, /*shift=*/0,
    211       /*suitable_for_downsampling_of_at_least=*/2},
    212      {/*num_coefficients=*/8, /*shift=*/0,
    213       /*suitable_for_downsampling_of_at_least=*/0},
    214  };
    215  constexpr PassDefinition progressive_passes_dc_quant_ac_full_ac[] = {
    216      {/*num_coefficients=*/8, /*shift=*/1,
    217       /*suitable_for_downsampling_of_at_least=*/2},
    218      {/*num_coefficients=*/8, /*shift=*/0,
    219       /*suitable_for_downsampling_of_at_least=*/0},
    220  };
    221  bool progressive_mode = ApplyOverride(cparams.progressive_mode, false);
    222  bool qprogressive_mode = ApplyOverride(cparams.qprogressive_mode, false);
    223  if (cparams.custom_progressive_mode) {
    224    progressive_splitter->SetProgressiveMode(*cparams.custom_progressive_mode);
    225  } else if (qprogressive_mode) {
    226    progressive_splitter->SetProgressiveMode(
    227        ProgressiveMode{progressive_passes_dc_quant_ac_full_ac});
    228  } else if (progressive_mode) {
    229    progressive_splitter->SetProgressiveMode(
    230        ProgressiveMode{progressive_passes_dc_vlf_lf_full_ac});
    231  }
    232 }
    233 
    234 uint64_t FrameFlagsFromParams(const CompressParams& cparams) {
    235  uint64_t flags = 0;
    236 
    237  const float dist = cparams.butteraugli_distance;
    238 
    239  // We don't add noise at low butteraugli distances because the original
    240  // noise is stored within the compressed image and adding noise makes things
    241  // worse.
    242  if (ApplyOverride(cparams.noise, dist >= kMinButteraugliForNoise) ||
    243      cparams.photon_noise_iso > 0 ||
    244      cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) {
    245    flags |= FrameHeader::kNoise;
    246  }
    247 
    248  if (cparams.progressive_dc > 0 && cparams.modular_mode == false) {
    249    flags |= FrameHeader::kUseDcFrame;
    250  }
    251 
    252  return flags;
    253 }
    254 
    255 Status LoopFilterFromParams(const CompressParams& cparams, bool streaming_mode,
    256                            FrameHeader* JXL_RESTRICT frame_header) {
    257  LoopFilter* loop_filter = &frame_header->loop_filter;
    258 
    259  // Gaborish defaults to enabled in Hare or slower.
    260  loop_filter->gab = ApplyOverride(
    261      cparams.gaborish, cparams.speed_tier <= SpeedTier::kHare &&
    262                            frame_header->encoding == FrameEncoding::kVarDCT &&
    263                            cparams.decoding_speed_tier < 4 &&
    264                            cparams.butteraugli_distance > 0.5f &&
    265                            !cparams.disable_perceptual_optimizations);
    266 
    267  if (cparams.epf != -1) {
    268    loop_filter->epf_iters = cparams.epf;
    269  } else if (cparams.disable_perceptual_optimizations) {
    270    loop_filter->epf_iters = 0;
    271    return true;
    272  } else {
    273    if (frame_header->encoding == FrameEncoding::kModular) {
    274      loop_filter->epf_iters = 0;
    275    } else {
    276      constexpr float kThresholds[3] = {0.7, 1.5, 4.0};
    277      loop_filter->epf_iters = 0;
    278      if (cparams.decoding_speed_tier < 3) {
    279        for (size_t i = cparams.decoding_speed_tier == 2 ? 1 : 0; i < 3; i++) {
    280          if (cparams.butteraugli_distance >= kThresholds[i]) {
    281            loop_filter->epf_iters++;
    282          }
    283        }
    284      }
    285    }
    286  }
    287  // Strength of EPF in modular mode.
    288  if (frame_header->encoding == FrameEncoding::kModular &&
    289      !cparams.IsLossless()) {
    290    // TODO(veluca): this formula is nonsense.
    291    loop_filter->epf_sigma_for_modular =
    292        std::max(cparams.butteraugli_distance, 1.0f);
    293  }
    294  if (frame_header->encoding == FrameEncoding::kModular &&
    295      cparams.lossy_palette) {
    296    loop_filter->epf_sigma_for_modular = 1.0f;
    297  }
    298 
    299  return true;
    300 }
    301 
    302 Status MakeFrameHeader(size_t xsize, size_t ysize,
    303                       const CompressParams& cparams,
    304                       const ProgressiveSplitter& progressive_splitter,
    305                       const FrameInfo& frame_info,
    306                       const jpeg::JPEGData* jpeg_data, bool streaming_mode,
    307                       FrameHeader* JXL_RESTRICT frame_header) {
    308  frame_header->nonserialized_is_preview = frame_info.is_preview;
    309  frame_header->is_last = frame_info.is_last;
    310  frame_header->save_before_color_transform =
    311      frame_info.save_before_color_transform;
    312  frame_header->frame_type = frame_info.frame_type;
    313  frame_header->name = frame_info.name;
    314 
    315  JXL_RETURN_IF_ERROR(progressive_splitter.InitPasses(&frame_header->passes));
    316 
    317  if (cparams.modular_mode) {
    318    frame_header->encoding = FrameEncoding::kModular;
    319    if (cparams.modular_group_size_shift == -1) {
    320      frame_header->group_size_shift = 1;
    321      // no point using groups when only one group is full and the others are
    322      // less than half full: multithreading will not really help much, while
    323      // compression does suffer
    324      if (xsize <= 400 && ysize <= 400) {
    325        frame_header->group_size_shift = 2;
    326      }
    327    } else {
    328      frame_header->group_size_shift = cparams.modular_group_size_shift;
    329    }
    330  }
    331 
    332  if (jpeg_data) {
    333    // we are transcoding a JPEG, so we don't get to choose
    334    frame_header->encoding = FrameEncoding::kVarDCT;
    335    frame_header->x_qm_scale = 2;
    336    frame_header->b_qm_scale = 2;
    337    JXL_RETURN_IF_ERROR(SetChromaSubsamplingFromJpegData(
    338        *jpeg_data, &frame_header->chroma_subsampling));
    339    JXL_RETURN_IF_ERROR(SetColorTransformFromJpegData(
    340        *jpeg_data, &frame_header->color_transform));
    341  } else {
    342    frame_header->color_transform = cparams.color_transform;
    343    if (!cparams.modular_mode &&
    344        (frame_header->chroma_subsampling.MaxHShift() != 0 ||
    345         frame_header->chroma_subsampling.MaxVShift() != 0)) {
    346      return JXL_FAILURE(
    347          "Chroma subsampling is not supported in VarDCT mode when not "
    348          "recompressing JPEGs");
    349    }
    350  }
    351  if (frame_header->color_transform != ColorTransform::kYCbCr &&
    352      (frame_header->chroma_subsampling.MaxHShift() != 0 ||
    353       frame_header->chroma_subsampling.MaxVShift() != 0)) {
    354    return JXL_FAILURE(
    355        "Chroma subsampling is not supported when color transform is not "
    356        "YCbCr");
    357  }
    358 
    359  frame_header->flags = FrameFlagsFromParams(cparams);
    360  // Non-photon noise is not supported in the Modular encoder for now.
    361  if (frame_header->encoding != FrameEncoding::kVarDCT &&
    362      cparams.photon_noise_iso == 0 && cparams.manual_noise.empty()) {
    363    frame_header->UpdateFlag(false, FrameHeader::Flags::kNoise);
    364  }
    365 
    366  JXL_RETURN_IF_ERROR(
    367      LoopFilterFromParams(cparams, streaming_mode, frame_header));
    368 
    369  frame_header->dc_level = frame_info.dc_level;
    370  if (frame_header->dc_level > 2) {
    371    // With 3 or more progressive_dc frames, the implementation does not yet
    372    // work, see enc_cache.cc.
    373    return JXL_FAILURE("progressive_dc > 2 is not yet supported");
    374  }
    375  if (cparams.progressive_dc > 0 &&
    376      (cparams.ec_resampling != 1 || cparams.resampling != 1)) {
    377    return JXL_FAILURE("Resampling not supported with DC frames");
    378  }
    379  if (cparams.resampling != 1 && cparams.resampling != 2 &&
    380      cparams.resampling != 4 && cparams.resampling != 8) {
    381    return JXL_FAILURE("Invalid resampling factor");
    382  }
    383  if (cparams.ec_resampling != 1 && cparams.ec_resampling != 2 &&
    384      cparams.ec_resampling != 4 && cparams.ec_resampling != 8) {
    385    return JXL_FAILURE("Invalid ec_resampling factor");
    386  }
    387  // Resized frames.
    388  if (frame_info.frame_type != FrameType::kDCFrame) {
    389    frame_header->frame_origin = frame_info.origin;
    390    size_t ups = 1;
    391    if (cparams.already_downsampled) ups = cparams.resampling;
    392 
    393    // TODO(lode): this is not correct in case of odd original image sizes in
    394    // combination with cparams.already_downsampled. Likely these values should
    395    // be set to respectively frame_header->default_xsize() and
    396    // frame_header->default_ysize() instead, the original (non downsampled)
    397    // intended decoded image dimensions. But it may be more subtle than that
    398    // if combined with crop. This issue causes custom_size_or_origin to be
    399    // incorrectly set to true in case of already_downsampled with odd output
    400    // image size when no cropping is used.
    401    frame_header->frame_size.xsize = xsize * ups;
    402    frame_header->frame_size.ysize = ysize * ups;
    403    if (frame_info.origin.x0 != 0 || frame_info.origin.y0 != 0 ||
    404        frame_header->frame_size.xsize != frame_header->default_xsize() ||
    405        frame_header->frame_size.ysize != frame_header->default_ysize()) {
    406      frame_header->custom_size_or_origin = true;
    407    }
    408  }
    409  // Upsampling.
    410  frame_header->upsampling = cparams.resampling;
    411  const std::vector<ExtraChannelInfo>& extra_channels =
    412      frame_header->nonserialized_metadata->m.extra_channel_info;
    413  frame_header->extra_channel_upsampling.clear();
    414  frame_header->extra_channel_upsampling.resize(extra_channels.size(),
    415                                                cparams.ec_resampling);
    416  frame_header->save_as_reference = frame_info.save_as_reference;
    417 
    418  // Set blending-related information.
    419  if (frame_info.blend || frame_header->custom_size_or_origin) {
    420    // Set blend_channel to the first alpha channel. These values are only
    421    // encoded in case a blend mode involving alpha is used and there are more
    422    // than one extra channels.
    423    size_t index = 0;
    424    if (frame_info.alpha_channel == -1) {
    425      if (extra_channels.size() > 1) {
    426        for (size_t i = 0; i < extra_channels.size(); i++) {
    427          if (extra_channels[i].type == ExtraChannel::kAlpha) {
    428            index = i;
    429            break;
    430          }
    431        }
    432      }
    433    } else {
    434      index = static_cast<size_t>(frame_info.alpha_channel);
    435      JXL_ENSURE(index == 0 || index < extra_channels.size());
    436    }
    437    frame_header->blending_info.alpha_channel = index;
    438    frame_header->blending_info.mode =
    439        frame_info.blend ? frame_info.blendmode : BlendMode::kReplace;
    440    frame_header->blending_info.source = frame_info.source;
    441    frame_header->blending_info.clamp = frame_info.clamp;
    442    const auto& extra_channel_info = frame_info.extra_channel_blending_info;
    443    for (size_t i = 0; i < extra_channels.size(); i++) {
    444      if (i < extra_channel_info.size()) {
    445        frame_header->extra_channel_blending_info[i] = extra_channel_info[i];
    446      } else {
    447        frame_header->extra_channel_blending_info[i].alpha_channel = index;
    448        BlendMode default_blend = frame_info.blendmode;
    449        if (extra_channels[i].type != ExtraChannel::kBlack && i != index) {
    450          // K needs to be blended, spot colors and other stuff gets added
    451          default_blend = BlendMode::kAdd;
    452        }
    453        frame_header->extra_channel_blending_info[i].mode =
    454            frame_info.blend ? default_blend : BlendMode::kReplace;
    455        frame_header->extra_channel_blending_info[i].source = 1;
    456      }
    457    }
    458  }
    459 
    460  frame_header->animation_frame.duration = frame_info.duration;
    461  frame_header->animation_frame.timecode = frame_info.timecode;
    462 
    463  if (jpeg_data) {
    464    frame_header->UpdateFlag(false, FrameHeader::kUseDcFrame);
    465    frame_header->UpdateFlag(true, FrameHeader::kSkipAdaptiveDCSmoothing);
    466  }
    467 
    468  return true;
    469 }
    470 
    471 // Invisible (alpha = 0) pixels tend to be a mess in optimized PNGs.
    472 // Since they have no visual impact whatsoever, we can replace them with
    473 // something that compresses better and reduces artifacts near the edges. This
    474 // does some kind of smooth stuff that seems to work.
    475 // Replace invisible pixels with a weighted average of the pixel to the left,
    476 // the pixel to the topright, and non-invisible neighbours.
    477 // Produces downward-blurry smears, with in the upwards direction only a 1px
    478 // edge duplication but not more. It would probably be better to smear in all
    479 // directions. That requires an alpha-weighed convolution with a large enough
    480 // kernel though, which might be overkill...
    481 void SimplifyInvisible(Image3F* image, const ImageF& alpha, bool lossless) {
    482  for (size_t c = 0; c < 3; ++c) {
    483    for (size_t y = 0; y < image->ysize(); ++y) {
    484      float* JXL_RESTRICT row = image->PlaneRow(c, y);
    485      const float* JXL_RESTRICT prow =
    486          (y > 0 ? image->PlaneRow(c, y - 1) : nullptr);
    487      const float* JXL_RESTRICT nrow =
    488          (y + 1 < image->ysize() ? image->PlaneRow(c, y + 1) : nullptr);
    489      const float* JXL_RESTRICT a = alpha.Row(y);
    490      const float* JXL_RESTRICT pa = (y > 0 ? alpha.Row(y - 1) : nullptr);
    491      const float* JXL_RESTRICT na =
    492          (y + 1 < image->ysize() ? alpha.Row(y + 1) : nullptr);
    493      for (size_t x = 0; x < image->xsize(); ++x) {
    494        if (a[x] == 0) {
    495          if (lossless) {
    496            row[x] = 0;
    497            continue;
    498          }
    499          float d = 0.f;
    500          row[x] = 0;
    501          if (x > 0) {
    502            row[x] += row[x - 1];
    503            d++;
    504            if (a[x - 1] > 0.f) {
    505              row[x] += row[x - 1];
    506              d++;
    507            }
    508          }
    509          if (x + 1 < image->xsize()) {
    510            if (y > 0) {
    511              row[x] += prow[x + 1];
    512              d++;
    513            }
    514            if (a[x + 1] > 0.f) {
    515              row[x] += 2.f * row[x + 1];
    516              d += 2.f;
    517            }
    518            if (y > 0 && pa[x + 1] > 0.f) {
    519              row[x] += 2.f * prow[x + 1];
    520              d += 2.f;
    521            }
    522            if (y + 1 < image->ysize() && na[x + 1] > 0.f) {
    523              row[x] += 2.f * nrow[x + 1];
    524              d += 2.f;
    525            }
    526          }
    527          if (y > 0 && pa[x] > 0.f) {
    528            row[x] += 2.f * prow[x];
    529            d += 2.f;
    530          }
    531          if (y + 1 < image->ysize() && na[x] > 0.f) {
    532            row[x] += 2.f * nrow[x];
    533            d += 2.f;
    534          }
    535          if (d > 1.f) row[x] /= d;
    536        }
    537      }
    538    }
    539  }
    540 }
    541 
    542 struct PixelStatsForChromacityAdjustment {
    543  float dx = 0;
    544  float db = 0;
    545  float exposed_blue = 0;
    546  static float CalcPlane(const ImageF* JXL_RESTRICT plane, const Rect& rect) {
    547    float xmax = 0;
    548    float ymax = 0;
    549    for (size_t ty = 1; ty < rect.ysize(); ++ty) {
    550      for (size_t tx = 1; tx < rect.xsize(); ++tx) {
    551        float cur = rect.Row(plane, ty)[tx];
    552        float prev_row = rect.Row(plane, ty - 1)[tx];
    553        float prev = rect.Row(plane, ty)[tx - 1];
    554        xmax = std::max(xmax, std::abs(cur - prev));
    555        ymax = std::max(ymax, std::abs(cur - prev_row));
    556      }
    557    }
    558    return std::max(xmax, ymax);
    559  }
    560  void CalcExposedBlue(const ImageF* JXL_RESTRICT plane_y,
    561                       const ImageF* JXL_RESTRICT plane_b, const Rect& rect) {
    562    float eb = 0;
    563    float xmax = 0;
    564    float ymax = 0;
    565    for (size_t ty = 1; ty < rect.ysize(); ++ty) {
    566      for (size_t tx = 1; tx < rect.xsize(); ++tx) {
    567        float cur_y = rect.Row(plane_y, ty)[tx];
    568        float cur_b = rect.Row(plane_b, ty)[tx];
    569        float exposed_b = cur_b - cur_y * 1.2;
    570        float diff_b = cur_b - cur_y;
    571        float prev_row = rect.Row(plane_b, ty - 1)[tx];
    572        float prev = rect.Row(plane_b, ty)[tx - 1];
    573        float diff_prev_row = prev_row - rect.Row(plane_y, ty - 1)[tx];
    574        float diff_prev = prev - rect.Row(plane_y, ty)[tx - 1];
    575        xmax = std::max(xmax, std::abs(diff_b - diff_prev));
    576        ymax = std::max(ymax, std::abs(diff_b - diff_prev_row));
    577        if (exposed_b >= 0) {
    578          exposed_b *= fabs(cur_b - prev) + fabs(cur_b - prev_row);
    579          eb = std::max(eb, exposed_b);
    580        }
    581      }
    582    }
    583    exposed_blue = eb;
    584    db = std::max(xmax, ymax);
    585  }
    586  void Calc(const Image3F* JXL_RESTRICT opsin, const Rect& rect) {
    587    dx = CalcPlane(&opsin->Plane(0), rect);
    588    CalcExposedBlue(&opsin->Plane(1), &opsin->Plane(2), rect);
    589  }
    590  int HowMuchIsXChannelPixelized() const {
    591    if (dx >= 0.026) {
    592      return 3;
    593    }
    594    if (dx >= 0.022) {
    595      return 2;
    596    }
    597    if (dx >= 0.015) {
    598      return 1;
    599    }
    600    return 0;
    601  }
    602  int HowMuchIsBChannelPixelized() const {
    603    int add = exposed_blue >= 0.13 ? 1 : 0;
    604    if (db > 0.38) {
    605      return 2 + add;
    606    }
    607    if (db > 0.33) {
    608      return 1 + add;
    609    }
    610    if (db > 0.28) {
    611      return add;
    612    }
    613    return 0;
    614  }
    615 };
    616 
    617 void ComputeChromacityAdjustments(const CompressParams& cparams,
    618                                  const Image3F& opsin, const Rect& rect,
    619                                  FrameHeader* frame_header) {
    620  if (frame_header->encoding != FrameEncoding::kVarDCT ||
    621      cparams.max_error_mode) {
    622    return;
    623  }
    624  // 1) Distance based approach for chromacity adjustment:
    625  float x_qm_scale_steps[3] = {2.5f, 5.5f, 9.5f};
    626  frame_header->x_qm_scale = 3;
    627  for (float x_qm_scale_step : x_qm_scale_steps) {
    628    if (cparams.original_butteraugli_distance > x_qm_scale_step) {
    629      frame_header->x_qm_scale++;
    630    }
    631  }
    632  // 2) Pixel-based approach for chromacity adjustment:
    633  // look at the individual pixels and make a guess how difficult
    634  // the image would be based on the worst case pixel.
    635  PixelStatsForChromacityAdjustment pixel_stats;
    636  if (cparams.speed_tier <= SpeedTier::kSquirrel) {
    637    pixel_stats.Calc(&opsin, rect);
    638  }
    639  // For X take the most severe adjustment.
    640  frame_header->x_qm_scale = std::max<int>(
    641      frame_header->x_qm_scale, 2 + pixel_stats.HowMuchIsXChannelPixelized());
    642  // B only adjusted by pixel-based approach.
    643  frame_header->b_qm_scale = 2 + pixel_stats.HowMuchIsBChannelPixelized();
    644 }
    645 
    646 void ComputeNoiseParams(const CompressParams& cparams, bool streaming_mode,
    647                        bool color_is_jpeg, const Image3F& opsin,
    648                        const FrameDimensions& frame_dim,
    649                        FrameHeader* frame_header, NoiseParams* noise_params) {
    650  if (cparams.photon_noise_iso > 0) {
    651    *noise_params = SimulatePhotonNoise(frame_dim.xsize, frame_dim.ysize,
    652                                        cparams.photon_noise_iso);
    653  } else if (cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) {
    654    for (size_t i = 0; i < NoiseParams::kNumNoisePoints; i++) {
    655      noise_params->lut[i] = cparams.manual_noise[i];
    656    }
    657  } else if (frame_header->encoding == FrameEncoding::kVarDCT &&
    658             frame_header->flags & FrameHeader::kNoise && !color_is_jpeg &&
    659             !streaming_mode) {
    660    // Don't start at zero amplitude since adding noise is expensive -- it
    661    // significantly slows down decoding, and this is unlikely to
    662    // completely go away even with advanced optimizations. After the
    663    // kNoiseModelingRampUpDistanceRange we have reached the full level,
    664    // i.e. noise is no longer represented by the compressed image, so we
    665    // can add full noise by the noise modeling itself.
    666    static const float kNoiseModelingRampUpDistanceRange = 0.6;
    667    static const float kNoiseLevelAtStartOfRampUp = 0.25;
    668    static const float kNoiseRampupStart = 1.0;
    669    // TODO(user) test and properly select quality_coef with smooth
    670    // filter
    671    float quality_coef = 1.0f;
    672    const float rampup = (cparams.butteraugli_distance - kNoiseRampupStart) /
    673                         kNoiseModelingRampUpDistanceRange;
    674    if (rampup < 1.0f) {
    675      quality_coef = kNoiseLevelAtStartOfRampUp +
    676                     (1.0f - kNoiseLevelAtStartOfRampUp) * rampup;
    677    }
    678    if (rampup < 0.0f) {
    679      quality_coef = kNoiseRampupStart;
    680    }
    681    if (!GetNoiseParameter(opsin, noise_params, quality_coef)) {
    682      frame_header->flags &= ~FrameHeader::kNoise;
    683    }
    684  }
    685 }
    686 
    687 Status DownsampleColorChannels(const CompressParams& cparams,
    688                               const FrameHeader& frame_header,
    689                               bool color_is_jpeg, Image3F* opsin) {
    690  if (color_is_jpeg || frame_header.upsampling == 1 ||
    691      cparams.already_downsampled) {
    692    return true;
    693  }
    694  if (frame_header.encoding == FrameEncoding::kVarDCT &&
    695      frame_header.upsampling == 2) {
    696    // TODO(lode): use the regular DownsampleImage, or adapt to the custom
    697    // coefficients, if there is are custom upscaling coefficients in
    698    // CustomTransformData
    699    if (cparams.speed_tier <= SpeedTier::kSquirrel) {
    700      // TODO(lode): DownsampleImage2_Iterative is currently too slow to
    701      // be used for squirrel, make it faster, and / or enable it only for
    702      // kitten.
    703      JXL_RETURN_IF_ERROR(DownsampleImage2_Iterative(opsin));
    704    } else {
    705      JXL_RETURN_IF_ERROR(DownsampleImage2_Sharper(opsin));
    706    }
    707  } else {
    708    JXL_ASSIGN_OR_RETURN(*opsin,
    709                         DownsampleImage(*opsin, frame_header.upsampling));
    710  }
    711  if (frame_header.encoding == FrameEncoding::kVarDCT) {
    712    JXL_RETURN_IF_ERROR(PadImageToBlockMultipleInPlace(opsin));
    713  }
    714  return true;
    715 }
    716 
    717 template <size_t L, typename V, typename R>
    718 void FindIndexOfSumMaximum(const V* array, R* idx, V* sum) {
    719  static_assert(L > 0);
    720  V maxval = 0;
    721  V val = 0;
    722  R maxidx = 0;
    723  for (size_t i = 0; i < L; ++i) {
    724    val += array[i];
    725    if (val > maxval) {
    726      maxval = val;
    727      maxidx = i;
    728    }
    729  }
    730  *idx = maxidx;
    731  *sum = maxval;
    732 }
    733 
    734 Status ComputeJPEGTranscodingData(const jpeg::JPEGData& jpeg_data,
    735                                  const FrameHeader& frame_header,
    736                                  ThreadPool* pool,
    737                                  ModularFrameEncoder* enc_modular,
    738                                  PassesEncoderState* enc_state) {
    739  PassesSharedState& shared = enc_state->shared;
    740  JxlMemoryManager* memory_manager = enc_state->memory_manager();
    741  const FrameDimensions& frame_dim = shared.frame_dim;
    742 
    743  const size_t xsize = frame_dim.xsize_padded;
    744  const size_t ysize = frame_dim.ysize_padded;
    745  const size_t xsize_blocks = frame_dim.xsize_blocks;
    746  const size_t ysize_blocks = frame_dim.ysize_blocks;
    747 
    748  // no-op chroma from luma
    749  JXL_ASSIGN_OR_RETURN(shared.cmap, ColorCorrelationMap::Create(
    750                                        memory_manager, xsize, ysize, false));
    751  shared.ac_strategy.FillDCT8();
    752  FillImage(static_cast<uint8_t>(0), &shared.epf_sharpness);
    753 
    754  enc_state->coeffs.clear();
    755  while (enc_state->coeffs.size() < enc_state->passes.size()) {
    756    JXL_ASSIGN_OR_RETURN(
    757        std::unique_ptr<ACImageT<int32_t>> coeffs,
    758        ACImageT<int32_t>::Make(memory_manager, kGroupDim * kGroupDim,
    759                                frame_dim.num_groups));
    760    enc_state->coeffs.emplace_back(std::move(coeffs));
    761  }
    762 
    763  // convert JPEG quantization table to a Quantizer object
    764  float dcquantization[3];
    765  std::vector<QuantEncoding> qe(kNumQuantTables, QuantEncoding::Library<0>());
    766 
    767  auto jpeg_c_map =
    768      JpegOrder(frame_header.color_transform, jpeg_data.components.size() == 1);
    769 
    770  std::vector<int> qt(192);
    771  for (size_t c = 0; c < 3; c++) {
    772    size_t jpeg_c = jpeg_c_map[c];
    773    const int32_t* quant =
    774        jpeg_data.quant[jpeg_data.components[jpeg_c].quant_idx].values.data();
    775 
    776    dcquantization[c] = 255 * 8.0f / quant[0];
    777    for (size_t y = 0; y < 8; y++) {
    778      for (size_t x = 0; x < 8; x++) {
    779        // JPEG XL transposes the DCT, JPEG doesn't.
    780        qt[c * 64 + 8 * x + y] = quant[8 * y + x];
    781      }
    782    }
    783  }
    784  JXL_RETURN_IF_ERROR(DequantMatricesSetCustomDC(
    785      memory_manager, &shared.matrices, dcquantization));
    786  float dcquantization_r[3] = {1.0f / dcquantization[0],
    787                               1.0f / dcquantization[1],
    788                               1.0f / dcquantization[2]};
    789 
    790  std::vector<int32_t> scaled_qtable(192);
    791  for (size_t c = 0; c < 3; c++) {
    792    for (size_t i = 0; i < 64; i++) {
    793      scaled_qtable[64 * c + i] =
    794          (1 << kCFLFixedPointPrecision) * qt[64 + i] / qt[64 * c + i];
    795    }
    796  }
    797 
    798  qe[static_cast<size_t>(AcStrategyType::DCT)] =
    799      QuantEncoding::RAW(std::move(qt));
    800  JXL_RETURN_IF_ERROR(
    801      DequantMatricesSetCustom(&shared.matrices, qe, enc_modular));
    802 
    803  // Ensure that InvGlobalScale() is 1.
    804  shared.quantizer = Quantizer(shared.matrices, 1, kGlobalScaleDenom);
    805  // Recompute MulDC() and InvMulDC().
    806  shared.quantizer.RecomputeFromGlobalScale();
    807 
    808  // Per-block dequant scaling should be 1.
    809  FillImage(static_cast<int32_t>(shared.quantizer.InvGlobalScale()),
    810            &shared.raw_quant_field);
    811 
    812  auto jpeg_row = [&](size_t c, size_t y) {
    813    return jpeg_data.components[jpeg_c_map[c]].coeffs.data() +
    814           jpeg_data.components[jpeg_c_map[c]].width_in_blocks * kDCTBlockSize *
    815               y;
    816  };
    817 
    818  bool DCzero = (frame_header.color_transform == ColorTransform::kYCbCr);
    819  // Compute chroma-from-luma for AC (doesn't seem to be useful for DC)
    820  if (frame_header.chroma_subsampling.Is444() &&
    821      enc_state->cparams.force_cfl_jpeg_recompression &&
    822      jpeg_data.components.size() == 3) {
    823    for (size_t c : {0, 2}) {
    824      ImageSB* map = (c == 0 ? &shared.cmap.ytox_map : &shared.cmap.ytob_map);
    825      const float kScale = kDefaultColorFactor;
    826      const int kOffset = 127;
    827      const float kBase = c == 0 ? shared.cmap.base().YtoXRatio(0)
    828                                 : shared.cmap.base().YtoBRatio(0);
    829      const float kZeroThresh =
    830          kScale * kZeroBiasDefault[c] *
    831          0.9999f;  // just epsilon less for better rounding
    832 
    833      auto process_row = [&](const uint32_t task,
    834                             const size_t thread) -> Status {
    835        size_t ty = task;
    836        int8_t* JXL_RESTRICT row_out = map->Row(ty);
    837        for (size_t tx = 0; tx < map->xsize(); ++tx) {
    838          const size_t y0 = ty * kColorTileDimInBlocks;
    839          const size_t x0 = tx * kColorTileDimInBlocks;
    840          const size_t y1 = std::min(frame_dim.ysize_blocks,
    841                                     (ty + 1) * kColorTileDimInBlocks);
    842          const size_t x1 = std::min(frame_dim.xsize_blocks,
    843                                     (tx + 1) * kColorTileDimInBlocks);
    844          int32_t d_num_zeros[257] = {0};
    845          // TODO(veluca): this needs SIMD + fixed point adaptation, and/or
    846          // conversion to the new CfL algorithm.
    847          for (size_t y = y0; y < y1; ++y) {
    848            const int16_t* JXL_RESTRICT row_m = jpeg_row(1, y);
    849            const int16_t* JXL_RESTRICT row_s = jpeg_row(c, y);
    850            for (size_t x = x0; x < x1; ++x) {
    851              for (size_t coeffpos = 1; coeffpos < kDCTBlockSize; coeffpos++) {
    852                const float scaled_m = row_m[x * kDCTBlockSize + coeffpos] *
    853                                       scaled_qtable[64 * c + coeffpos] *
    854                                       (1.0f / (1 << kCFLFixedPointPrecision));
    855                const float scaled_s =
    856                    kScale * row_s[x * kDCTBlockSize + coeffpos] +
    857                    (kOffset - kBase * kScale) * scaled_m;
    858                if (std::abs(scaled_m) > 1e-8f) {
    859                  float from;
    860                  float to;
    861                  if (scaled_m > 0) {
    862                    from = (scaled_s - kZeroThresh) / scaled_m;
    863                    to = (scaled_s + kZeroThresh) / scaled_m;
    864                  } else {
    865                    from = (scaled_s + kZeroThresh) / scaled_m;
    866                    to = (scaled_s - kZeroThresh) / scaled_m;
    867                  }
    868                  if (from < 0.0f) {
    869                    from = 0.0f;
    870                  }
    871                  if (to > 255.0f) {
    872                    to = 255.0f;
    873                  }
    874                  // Instead of clamping the both values
    875                  // we just check that range is sane.
    876                  if (from <= to) {
    877                    d_num_zeros[static_cast<int>(std::ceil(from))]++;
    878                    d_num_zeros[static_cast<int>(std::floor(to + 1))]--;
    879                  }
    880                }
    881              }
    882            }
    883          }
    884          int best = 0;
    885          int32_t best_sum = 0;
    886          FindIndexOfSumMaximum<256>(d_num_zeros, &best, &best_sum);
    887          int32_t offset_sum = 0;
    888          for (int i = 0; i < 256; ++i) {
    889            if (i <= kOffset) {
    890              offset_sum += d_num_zeros[i];
    891            }
    892          }
    893          row_out[tx] = 0;
    894          if (best_sum > offset_sum + 1) {
    895            row_out[tx] = best - kOffset;
    896          }
    897        }
    898        return true;
    899      };
    900 
    901      JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, map->ysize(), ThreadPool::NoInit,
    902                                    process_row, "FindCorrelation"));
    903    }
    904  }
    905 
    906  JXL_ASSIGN_OR_RETURN(
    907      Image3F dc, Image3F::Create(memory_manager, xsize_blocks, ysize_blocks));
    908  if (!frame_header.chroma_subsampling.Is444()) {
    909    ZeroFillImage(&dc);
    910    for (auto& coeff : enc_state->coeffs) {
    911      coeff->ZeroFill();
    912    }
    913  }
    914  // JPEG DC is from -1024 to 1023.
    915  std::vector<size_t> dc_counts[3] = {};
    916  dc_counts[0].resize(2048);
    917  dc_counts[1].resize(2048);
    918  dc_counts[2].resize(2048);
    919  size_t total_dc[3] = {};
    920  for (size_t c : {1, 0, 2}) {
    921    if (jpeg_data.components.size() == 1 && c != 1) {
    922      for (auto& coeff : enc_state->coeffs) {
    923        coeff->ZeroFillPlane(c);
    924      }
    925      ZeroFillImage(&dc.Plane(c));
    926      // Ensure no division by 0.
    927      dc_counts[c][1024] = 1;
    928      total_dc[c] = 1;
    929      continue;
    930    }
    931    size_t hshift = frame_header.chroma_subsampling.HShift(c);
    932    size_t vshift = frame_header.chroma_subsampling.VShift(c);
    933    ImageSB& map = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map);
    934    for (size_t group_index = 0; group_index < frame_dim.num_groups;
    935         group_index++) {
    936      const size_t gx = group_index % frame_dim.xsize_groups;
    937      const size_t gy = group_index / frame_dim.xsize_groups;
    938      int32_t* coeffs[kMaxNumPasses];
    939      for (size_t i = 0; i < enc_state->coeffs.size(); i++) {
    940        coeffs[i] = enc_state->coeffs[i]->PlaneRow(c, group_index, 0).ptr32;
    941      }
    942      int32_t block[64];
    943      for (size_t by = gy * kGroupDimInBlocks;
    944           by < ysize_blocks && by < (gy + 1) * kGroupDimInBlocks; ++by) {
    945        if ((by >> vshift) << vshift != by) continue;
    946        const int16_t* JXL_RESTRICT inputjpeg = jpeg_row(c, by >> vshift);
    947        const int16_t* JXL_RESTRICT inputjpegY = jpeg_row(1, by);
    948        float* JXL_RESTRICT fdc = dc.PlaneRow(c, by >> vshift);
    949        const int8_t* JXL_RESTRICT cm =
    950            map.ConstRow(by / kColorTileDimInBlocks);
    951        for (size_t bx = gx * kGroupDimInBlocks;
    952             bx < xsize_blocks && bx < (gx + 1) * kGroupDimInBlocks; ++bx) {
    953          if ((bx >> hshift) << hshift != bx) continue;
    954          size_t base = (bx >> hshift) * kDCTBlockSize;
    955          int idc;
    956          if (DCzero) {
    957            idc = inputjpeg[base];
    958          } else {
    959            idc = inputjpeg[base] + 1024 / qt[c * 64];
    960          }
    961          dc_counts[c][std::min(static_cast<uint32_t>(idc + 1024),
    962                                static_cast<uint32_t>(2047))]++;
    963          total_dc[c]++;
    964          fdc[bx >> hshift] = idc * dcquantization_r[c];
    965          if (c == 1 || !enc_state->cparams.force_cfl_jpeg_recompression ||
    966              !frame_header.chroma_subsampling.Is444()) {
    967            for (size_t y = 0; y < 8; y++) {
    968              for (size_t x = 0; x < 8; x++) {
    969                block[y * 8 + x] = inputjpeg[base + x * 8 + y];
    970              }
    971            }
    972          } else {
    973            const int32_t scale =
    974                ColorCorrelation::RatioJPEG(cm[bx / kColorTileDimInBlocks]);
    975 
    976            for (size_t y = 0; y < 8; y++) {
    977              for (size_t x = 0; x < 8; x++) {
    978                int Y = inputjpegY[kDCTBlockSize * bx + x * 8 + y];
    979                int QChroma = inputjpeg[kDCTBlockSize * bx + x * 8 + y];
    980                // Fixed-point multiply of CfL scale with quant table ratio
    981                // first, and Y value second.
    982                int coeff_scale = (scale * scaled_qtable[64 * c + y * 8 + x] +
    983                                   (1 << (kCFLFixedPointPrecision - 1))) >>
    984                                  kCFLFixedPointPrecision;
    985                int cfl_factor =
    986                    (Y * coeff_scale + (1 << (kCFLFixedPointPrecision - 1))) >>
    987                    kCFLFixedPointPrecision;
    988                int QCR = QChroma - cfl_factor;
    989                block[y * 8 + x] = QCR;
    990              }
    991            }
    992          }
    993          enc_state->progressive_splitter.SplitACCoefficients(
    994              block, AcStrategy::FromRawStrategy(AcStrategyType::DCT), bx, by,
    995              coeffs);
    996          for (size_t i = 0; i < enc_state->coeffs.size(); i++) {
    997            coeffs[i] += kDCTBlockSize;
    998          }
    999        }
   1000      }
   1001    }
   1002  }
   1003 
   1004  auto& dct = enc_state->shared.block_ctx_map.dc_thresholds;
   1005  auto& num_dc_ctxs = enc_state->shared.block_ctx_map.num_dc_ctxs;
   1006  num_dc_ctxs = 1;
   1007  for (size_t i = 0; i < 3; i++) {
   1008    dct[i].clear();
   1009    int num_thresholds = (CeilLog2Nonzero(total_dc[i]) - 12) / 2;
   1010    // up to 3 buckets per channel:
   1011    // dark/medium/bright, yellow/unsat/blue, green/unsat/red
   1012    num_thresholds = std::min(std::max(num_thresholds, 0), 2);
   1013    size_t cumsum = 0;
   1014    size_t cut = total_dc[i] / (num_thresholds + 1);
   1015    for (int j = 0; j < 2048; j++) {
   1016      cumsum += dc_counts[i][j];
   1017      if (cumsum > cut) {
   1018        dct[i].push_back(j - 1025);
   1019        cut = total_dc[i] * (dct[i].size() + 1) / (num_thresholds + 1);
   1020      }
   1021    }
   1022    num_dc_ctxs *= dct[i].size() + 1;
   1023  }
   1024 
   1025  auto& ctx_map = enc_state->shared.block_ctx_map.ctx_map;
   1026  ctx_map.clear();
   1027  ctx_map.resize(3 * kNumOrders * num_dc_ctxs, 0);
   1028 
   1029  int lbuckets = (dct[1].size() + 1);
   1030  for (size_t i = 0; i < num_dc_ctxs; i++) {
   1031    // up to 9 contexts for luma
   1032    ctx_map[i] = i / lbuckets;
   1033    // up to 3 contexts for chroma
   1034    ctx_map[kNumOrders * num_dc_ctxs + i] =
   1035        ctx_map[2 * kNumOrders * num_dc_ctxs + i] =
   1036            num_dc_ctxs / lbuckets + (i % lbuckets);
   1037  }
   1038  enc_state->shared.block_ctx_map.num_ctxs =
   1039      *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
   1040 
   1041  // disable DC frame for now
   1042  auto compute_dc_coeffs = [&](const uint32_t group_index,
   1043                               size_t /* thread */) -> Status {
   1044    const Rect r = enc_state->shared.frame_dim.DCGroupRect(group_index);
   1045    JXL_RETURN_IF_ERROR(enc_modular->AddVarDCTDC(frame_header, dc, r,
   1046                                                 group_index,
   1047                                                 /*nl_dc=*/false, enc_state,
   1048                                                 /*jpeg_transcode=*/true));
   1049    JXL_RETURN_IF_ERROR(enc_modular->AddACMetadata(
   1050        r, group_index, /*jpeg_transcode=*/true, enc_state));
   1051    return true;
   1052  };
   1053  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups,
   1054                                ThreadPool::NoInit, compute_dc_coeffs,
   1055                                "Compute DC coeffs"));
   1056 
   1057  return true;
   1058 }
   1059 
   1060 Status ComputeVarDCTEncodingData(const FrameHeader& frame_header,
   1061                                 const Image3F* linear,
   1062                                 Image3F* JXL_RESTRICT opsin, const Rect& rect,
   1063                                 const JxlCmsInterface& cms, ThreadPool* pool,
   1064                                 ModularFrameEncoder* enc_modular,
   1065                                 PassesEncoderState* enc_state,
   1066                                 AuxOut* aux_out) {
   1067  JXL_ENSURE((rect.xsize() % kBlockDim) == 0 &&
   1068             (rect.ysize() % kBlockDim) == 0);
   1069  JxlMemoryManager* memory_manager = enc_state->memory_manager();
   1070  // Save pre-Gaborish opsin for AR control field heuristics computation.
   1071  Image3F orig_opsin;
   1072  JXL_ASSIGN_OR_RETURN(
   1073      orig_opsin, Image3F::Create(memory_manager, rect.xsize(), rect.ysize()));
   1074  JXL_RETURN_IF_ERROR(CopyImageTo(rect, *opsin, Rect(orig_opsin), &orig_opsin));
   1075  JXL_RETURN_IF_ERROR(orig_opsin.ShrinkTo(enc_state->shared.frame_dim.xsize,
   1076                                          enc_state->shared.frame_dim.ysize));
   1077 
   1078  JXL_RETURN_IF_ERROR(LossyFrameHeuristics(frame_header, enc_state, enc_modular,
   1079                                           linear, opsin, rect, cms, pool,
   1080                                           aux_out));
   1081 
   1082  JXL_RETURN_IF_ERROR(InitializePassesEncoder(
   1083      frame_header, *opsin, rect, cms, pool, enc_state, enc_modular, aux_out));
   1084 
   1085  JXL_RETURN_IF_ERROR(
   1086      ComputeARHeuristics(frame_header, enc_state, orig_opsin, rect, pool));
   1087 
   1088  JXL_RETURN_IF_ERROR(ComputeACMetadata(pool, enc_state, enc_modular));
   1089 
   1090  return true;
   1091 }
   1092 
   1093 Status ComputeAllCoeffOrders(PassesEncoderState& enc_state,
   1094                             const FrameDimensions& frame_dim) {
   1095  auto used_orders_info = ComputeUsedOrders(
   1096      enc_state.cparams.speed_tier, enc_state.shared.ac_strategy,
   1097      Rect(enc_state.shared.raw_quant_field));
   1098  enc_state.used_orders.resize(enc_state.progressive_splitter.GetNumPasses());
   1099  for (size_t i = 0; i < enc_state.progressive_splitter.GetNumPasses(); i++) {
   1100    JXL_RETURN_IF_ERROR(ComputeCoeffOrder(
   1101        enc_state.cparams.speed_tier, *enc_state.coeffs[i],
   1102        enc_state.shared.ac_strategy, frame_dim, enc_state.used_orders[i],
   1103        enc_state.used_acs, used_orders_info.first, used_orders_info.second,
   1104        &enc_state.shared.coeff_orders[i * enc_state.shared.coeff_order_size]));
   1105  }
   1106  enc_state.used_acs |= used_orders_info.first;
   1107  return true;
   1108 }
   1109 
   1110 // Working area for TokenizeCoefficients (per-group!)
   1111 struct EncCache {
   1112  // Allocates memory when first called.
   1113  Status InitOnce(JxlMemoryManager* memory_manager) {
   1114    if (num_nzeroes.xsize() == 0) {
   1115      JXL_ASSIGN_OR_RETURN(num_nzeroes,
   1116                           Image3I::Create(memory_manager, kGroupDimInBlocks,
   1117                                           kGroupDimInBlocks));
   1118    }
   1119    return true;
   1120  }
   1121  // TokenizeCoefficients
   1122  Image3I num_nzeroes;
   1123 };
   1124 
   1125 Status TokenizeAllCoefficients(const FrameHeader& frame_header,
   1126                               ThreadPool* pool,
   1127                               PassesEncoderState* enc_state) {
   1128  PassesSharedState& shared = enc_state->shared;
   1129  std::vector<EncCache> group_caches;
   1130  JxlMemoryManager* memory_manager = enc_state->memory_manager();
   1131  const auto tokenize_group_init = [&](const size_t num_threads) -> Status {
   1132    group_caches.resize(num_threads);
   1133    return true;
   1134  };
   1135  const auto tokenize_group = [&](const uint32_t group_index,
   1136                                  const size_t thread) -> Status {
   1137    // Tokenize coefficients.
   1138    const Rect rect = shared.frame_dim.BlockGroupRect(group_index);
   1139    for (size_t idx_pass = 0; idx_pass < enc_state->passes.size(); idx_pass++) {
   1140      JXL_ENSURE(enc_state->coeffs[idx_pass]->Type() == ACType::k32);
   1141      const int32_t* JXL_RESTRICT ac_rows[3] = {
   1142          enc_state->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32,
   1143          enc_state->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32,
   1144          enc_state->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32,
   1145      };
   1146      // Ensure group cache is initialized.
   1147      JXL_RETURN_IF_ERROR(group_caches[thread].InitOnce(memory_manager));
   1148      JXL_RETURN_IF_ERROR(TokenizeCoefficients(
   1149          &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect,
   1150          ac_rows, shared.ac_strategy, frame_header.chroma_subsampling,
   1151          &group_caches[thread].num_nzeroes,
   1152          &enc_state->passes[idx_pass].ac_tokens[group_index], shared.quant_dc,
   1153          shared.raw_quant_field, shared.block_ctx_map));
   1154    }
   1155    return true;
   1156  };
   1157  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_groups,
   1158                                tokenize_group_init, tokenize_group,
   1159                                "TokenizeGroup"));
   1160  return true;
   1161 }
   1162 
   1163 Status EncodeGlobalDCInfo(const PassesSharedState& shared, BitWriter* writer,
   1164                          AuxOut* aux_out) {
   1165  // Encode quantizer DC and global scale.
   1166  QuantizerParams params = shared.quantizer.GetParams();
   1167  JXL_RETURN_IF_ERROR(
   1168      WriteQuantizerParams(params, writer, LayerType::Quant, aux_out));
   1169  JXL_RETURN_IF_ERROR(EncodeBlockCtxMap(shared.block_ctx_map, writer, aux_out));
   1170  JXL_RETURN_IF_ERROR(ColorCorrelationEncodeDC(shared.cmap.base(), writer,
   1171                                               LayerType::Dc, aux_out));
   1172  return true;
   1173 }
   1174 
   1175 // In streaming mode, this function only performs the histogram clustering and
   1176 // saves the histogram bitstreams in enc_state, the actual AC global bitstream
   1177 // is written in OutputAcGlobal() function after all the groups are processed.
   1178 Status EncodeGlobalACInfo(PassesEncoderState* enc_state, BitWriter* writer,
   1179                          ModularFrameEncoder* enc_modular, AuxOut* aux_out) {
   1180  PassesSharedState& shared = enc_state->shared;
   1181  JxlMemoryManager* memory_manager = enc_state->memory_manager();
   1182  JXL_RETURN_IF_ERROR(DequantMatricesEncode(memory_manager, shared.matrices,
   1183                                            writer, LayerType::Quant, aux_out,
   1184                                            enc_modular));
   1185  size_t num_histo_bits = CeilLog2Nonzero(shared.frame_dim.num_groups);
   1186  if (!enc_state->streaming_mode && num_histo_bits != 0) {
   1187    JXL_RETURN_IF_ERROR(
   1188        writer->WithMaxBits(num_histo_bits, LayerType::Ac, aux_out, [&] {
   1189          writer->Write(num_histo_bits, shared.num_histograms - 1);
   1190          return true;
   1191        }));
   1192  }
   1193 
   1194  for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses(); i++) {
   1195    // Encode coefficient orders.
   1196    if (!enc_state->streaming_mode) {
   1197      size_t order_bits = 0;
   1198      JXL_RETURN_IF_ERROR(U32Coder::CanEncode(
   1199          kOrderEnc, enc_state->used_orders[i], &order_bits));
   1200      JXL_RETURN_IF_ERROR(
   1201          writer->WithMaxBits(order_bits, LayerType::Order, aux_out, [&] {
   1202            return U32Coder::Write(kOrderEnc, enc_state->used_orders[i],
   1203                                   writer);
   1204          }));
   1205      JXL_RETURN_IF_ERROR(
   1206          EncodeCoeffOrders(enc_state->used_orders[i],
   1207                            &shared.coeff_orders[i * shared.coeff_order_size],
   1208                            writer, LayerType::Order, aux_out));
   1209    }
   1210 
   1211    // Encode histograms.
   1212    HistogramParams hist_params(enc_state->cparams.speed_tier,
   1213                                shared.block_ctx_map.NumACContexts());
   1214    if (enc_state->cparams.speed_tier > SpeedTier::kTortoise) {
   1215      hist_params.lz77_method = HistogramParams::LZ77Method::kNone;
   1216    }
   1217    if (enc_state->cparams.decoding_speed_tier >= 1) {
   1218      hist_params.max_histograms = 6;
   1219    }
   1220    size_t num_histogram_groups = shared.num_histograms;
   1221    if (enc_state->streaming_mode) {
   1222      size_t prev_num_histograms =
   1223          enc_state->passes[i].codes.encoding_info.size();
   1224      if (enc_state->initialize_global_state) {
   1225        prev_num_histograms += kNumFixedHistograms;
   1226        hist_params.add_fixed_histograms = true;
   1227      }
   1228      size_t remaining_histograms = kClustersLimit - prev_num_histograms;
   1229      // Heuristic to assign budget of new histograms to DC groups.
   1230      // TODO(szabadka) Tune this together with the DC group ordering.
   1231      size_t max_histograms = remaining_histograms < 20
   1232                                  ? std::min<size_t>(remaining_histograms, 4)
   1233                                  : remaining_histograms / 4;
   1234      hist_params.max_histograms =
   1235          std::min(max_histograms, hist_params.max_histograms);
   1236      num_histogram_groups = 1;
   1237    }
   1238    hist_params.streaming_mode = enc_state->streaming_mode;
   1239    hist_params.initialize_global_state = enc_state->initialize_global_state;
   1240    JXL_ASSIGN_OR_RETURN(
   1241        size_t cost,
   1242        BuildAndEncodeHistograms(
   1243            memory_manager, hist_params,
   1244            num_histogram_groups * shared.block_ctx_map.NumACContexts(),
   1245            enc_state->passes[i].ac_tokens, &enc_state->passes[i].codes,
   1246            &enc_state->passes[i].context_map, writer, LayerType::Ac, aux_out));
   1247    (void)cost;
   1248  }
   1249 
   1250  return true;
   1251 }
   1252 
   1253 Status EncodeGroups(const FrameHeader& frame_header,
   1254                    PassesEncoderState* enc_state,
   1255                    ModularFrameEncoder* enc_modular, ThreadPool* pool,
   1256                    std::vector<std::unique_ptr<BitWriter>>* group_codes,
   1257                    AuxOut* aux_out) {
   1258  const PassesSharedState& shared = enc_state->shared;
   1259  JxlMemoryManager* memory_manager = shared.memory_manager;
   1260  const FrameDimensions& frame_dim = shared.frame_dim;
   1261  const size_t num_groups = frame_dim.num_groups;
   1262  const size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
   1263  const size_t global_ac_index = frame_dim.num_dc_groups + 1;
   1264  const bool is_small_image =
   1265      !enc_state->streaming_mode && num_groups == 1 && num_passes == 1;
   1266  const size_t num_toc_entries =
   1267      is_small_image ? 1
   1268                     : AcGroupIndex(0, 0, num_groups, frame_dim.num_dc_groups) +
   1269                           num_groups * num_passes;
   1270  JXL_ENSURE(group_codes->empty());
   1271  group_codes->reserve(num_toc_entries);
   1272  for (size_t i = 0; i < num_toc_entries; ++i) {
   1273    group_codes->emplace_back(jxl::make_unique<BitWriter>(memory_manager));
   1274  }
   1275 
   1276  const auto get_output = [&](const size_t index) -> BitWriter* {
   1277    return (*group_codes)[is_small_image ? 0 : index].get();
   1278  };
   1279  auto ac_group_code = [&](size_t pass, size_t group) {
   1280    return get_output(AcGroupIndex(pass, group, frame_dim.num_groups,
   1281                                   frame_dim.num_dc_groups));
   1282  };
   1283 
   1284  if (enc_state->initialize_global_state) {
   1285    if (frame_header.flags & FrameHeader::kPatches) {
   1286      JXL_RETURN_IF_ERROR(PatchDictionaryEncoder::Encode(
   1287          shared.image_features.patches, get_output(0), LayerType::Dictionary,
   1288          aux_out));
   1289    }
   1290    if (frame_header.flags & FrameHeader::kSplines) {
   1291      JXL_RETURN_IF_ERROR(EncodeSplines(shared.image_features.splines,
   1292                                        get_output(0), LayerType::Splines,
   1293                                        HistogramParams(), aux_out));
   1294    }
   1295    if (frame_header.flags & FrameHeader::kNoise) {
   1296      JXL_RETURN_IF_ERROR(EncodeNoise(shared.image_features.noise_params,
   1297                                      get_output(0), LayerType::Noise,
   1298                                      aux_out));
   1299    }
   1300 
   1301    JXL_RETURN_IF_ERROR(DequantMatricesEncodeDC(shared.matrices, get_output(0),
   1302                                                LayerType::Quant, aux_out));
   1303    if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1304      JXL_RETURN_IF_ERROR(EncodeGlobalDCInfo(shared, get_output(0), aux_out));
   1305    }
   1306    JXL_RETURN_IF_ERROR(enc_modular->EncodeGlobalInfo(enc_state->streaming_mode,
   1307                                                      get_output(0), aux_out));
   1308    JXL_RETURN_IF_ERROR(enc_modular->EncodeStream(get_output(0), aux_out,
   1309                                                  LayerType::ModularGlobal,
   1310                                                  ModularStreamId::Global()));
   1311  }
   1312 
   1313  std::vector<std::unique_ptr<AuxOut>> aux_outs;
   1314  auto resize_aux_outs = [&aux_outs,
   1315                          aux_out](const size_t num_threads) -> Status {
   1316    if (aux_out == nullptr) {
   1317      aux_outs.resize(num_threads);
   1318    } else {
   1319      while (aux_outs.size() > num_threads) {
   1320        aux_out->Assimilate(*aux_outs.back());
   1321        aux_outs.pop_back();
   1322      }
   1323      while (num_threads > aux_outs.size()) {
   1324        aux_outs.emplace_back(jxl::make_unique<AuxOut>());
   1325      }
   1326    }
   1327    return true;
   1328  };
   1329 
   1330  std::atomic<bool> has_error{false};
   1331  const auto process_dc_group = [&](const uint32_t group_index,
   1332                                    const size_t thread) -> Status {
   1333    AuxOut* my_aux_out = aux_outs[thread].get();
   1334    uint32_t input_index = enc_state->streaming_mode ? 0 : group_index;
   1335    BitWriter* output = get_output(input_index + 1);
   1336    if (frame_header.encoding == FrameEncoding::kVarDCT &&
   1337        !(frame_header.flags & FrameHeader::kUseDcFrame)) {
   1338      JXL_RETURN_IF_ERROR(
   1339          output->WithMaxBits(2, LayerType::Dc, my_aux_out, [&] {
   1340            output->Write(2, enc_modular->extra_dc_precision[group_index]);
   1341            return true;
   1342          }));
   1343      JXL_RETURN_IF_ERROR(
   1344          enc_modular->EncodeStream(output, my_aux_out, LayerType::Dc,
   1345                                    ModularStreamId::VarDCTDC(group_index)));
   1346    }
   1347    JXL_RETURN_IF_ERROR(
   1348        enc_modular->EncodeStream(output, my_aux_out, LayerType::ModularDcGroup,
   1349                                  ModularStreamId::ModularDC(group_index)));
   1350    if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1351      const Rect& rect = enc_state->shared.frame_dim.DCGroupRect(input_index);
   1352      size_t nb_bits = CeilLog2Nonzero(rect.xsize() * rect.ysize());
   1353      if (nb_bits != 0) {
   1354        JXL_RETURN_IF_ERROR(output->WithMaxBits(
   1355            nb_bits, LayerType::ControlFields, my_aux_out, [&] {
   1356              output->Write(nb_bits,
   1357                            enc_modular->ac_metadata_size[group_index] - 1);
   1358              return true;
   1359            }));
   1360      }
   1361      JXL_RETURN_IF_ERROR(enc_modular->EncodeStream(
   1362          output, my_aux_out, LayerType::ControlFields,
   1363          ModularStreamId::ACMetadata(group_index)));
   1364    }
   1365    return true;
   1366  };
   1367  if (enc_state->streaming_mode) {
   1368    JXL_ENSURE(frame_dim.num_dc_groups == 1);
   1369    JXL_RETURN_IF_ERROR(resize_aux_outs(1));
   1370    JXL_RETURN_IF_ERROR(process_dc_group(enc_state->dc_group_index, 0));
   1371  } else {
   1372    JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_dc_groups,
   1373                                  resize_aux_outs, process_dc_group,
   1374                                  "EncodeDCGroup"));
   1375  }
   1376  if (has_error) return JXL_FAILURE("EncodeDCGroup failed");
   1377  if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1378    JXL_RETURN_IF_ERROR(EncodeGlobalACInfo(
   1379        enc_state, get_output(global_ac_index), enc_modular, aux_out));
   1380  }
   1381 
   1382  const auto process_group = [&](const uint32_t group_index,
   1383                                 const size_t thread) -> Status {
   1384    AuxOut* my_aux_out = aux_outs[thread].get();
   1385 
   1386    size_t ac_group_id =
   1387        enc_state->streaming_mode
   1388            ? enc_modular->ComputeStreamingAbsoluteAcGroupId(
   1389                  enc_state->dc_group_index, group_index, shared.frame_dim)
   1390            : group_index;
   1391 
   1392    for (size_t i = 0; i < num_passes; i++) {
   1393      JXL_DEBUG_V(2, "Encoding AC group %u [abs %" PRIuS "] pass %" PRIuS,
   1394                  group_index, ac_group_id, i);
   1395      if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1396        JXL_RETURN_IF_ERROR(EncodeGroupTokenizedCoefficients(
   1397            group_index, i, enc_state->histogram_idx[group_index], *enc_state,
   1398            ac_group_code(i, group_index), my_aux_out));
   1399      }
   1400      // Write all modular encoded data (color?, alpha, depth, extra channels)
   1401      JXL_RETURN_IF_ERROR(enc_modular->EncodeStream(
   1402          ac_group_code(i, group_index), my_aux_out, LayerType::ModularAcGroup,
   1403          ModularStreamId::ModularAC(ac_group_id, i)));
   1404      JXL_DEBUG_V(2,
   1405                  "AC group %u [abs %" PRIuS "] pass %" PRIuS
   1406                  " encoded size is %" PRIuS " bits",
   1407                  group_index, ac_group_id, i,
   1408                  ac_group_code(i, group_index)->BitsWritten());
   1409    }
   1410    return true;
   1411  };
   1412  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, resize_aux_outs,
   1413                                process_group, "EncodeGroupCoefficients"));
   1414  // Resizing aux_outs to 0 also Assimilates the array.
   1415  static_cast<void>(resize_aux_outs(0));
   1416 
   1417  for (std::unique_ptr<BitWriter>& bw : *group_codes) {
   1418    JXL_RETURN_IF_ERROR(bw->WithMaxBits(8, LayerType::Ac, aux_out, [&] {
   1419      bw->ZeroPadToByte();  // end of group.
   1420      return true;
   1421    }));
   1422  }
   1423  return true;
   1424 }
   1425 
   // Builds the encoder state needed for the rectangle (x0, y0, xsize, ysize)
   // of `frame_data` and then serializes it into `group_codes` by calling
   // EncodeGroups().
   //
   // Visible steps, in order: size and allocate the shared per-block images,
   // copy colour and extra channels from the input source, convert to XYB when
   // requested, simplify invisible pixels, pad to a block multiple, compute
   // chromacity adjustments and noise parameters, downsample, compute the
   // VarDCT (or JPEG transcoding) data and tokens, compute the modular data,
   // and finally encode all groups.
   //
   // `mutable_frame_header` is handed for modification to
   // ComputeChromacityAdjustments() and ComputeNoiseParams(), and has its
   // kPatches / kSplines flags refreshed in non-streaming mode.
   // Returns an error Status as soon as any step fails.
   1426 Status ComputeEncodingData(
   1427    const CompressParams& cparams, const FrameInfo& frame_info,
   1428    const CodecMetadata* metadata, JxlEncoderChunkedFrameAdapter& frame_data,
   1429    const jpeg::JPEGData* jpeg_data, size_t x0, size_t y0, size_t xsize,
   1430    size_t ysize, const JxlCmsInterface& cms, ThreadPool* pool,
   1431    FrameHeader& mutable_frame_header, ModularFrameEncoder& enc_modular,
   1432    PassesEncoderState& enc_state,
   1433    std::vector<std::unique_ptr<BitWriter>>* group_codes, AuxOut* aux_out) {
         // The requested rectangle must lie inside the input frame.
   1434  JXL_ENSURE(x0 + xsize <= frame_data.xsize);
   1435  JXL_ENSURE(y0 + ysize <= frame_data.ysize);
   1436  JxlMemoryManager* memory_manager = enc_state.memory_manager();
         // Read-only alias of the same header object; every write below goes
         // through `mutable_frame_header`.
   1437  const FrameHeader& frame_header = mutable_frame_header;
   1438  PassesSharedState& shared = enc_state.shared;
   1439  shared.metadata = metadata;
         // In streaming mode the frame dimensions describe only the requested
         // rectangle; otherwise they are taken from the frame header.
   1440  if (enc_state.streaming_mode) {
   1441    shared.frame_dim.Set(
   1442        xsize, ysize, frame_header.group_size_shift,
   1443        /*max_hshift=*/0, /*max_vshift=*/0,
   1444        mutable_frame_header.encoding == FrameEncoding::kModular,
   1445        /*upsampling=*/1);
   1446  } else {
   1447    shared.frame_dim = frame_header.ToFrameDimensions();
   1448  }
   1449 
   1450  shared.image_features.patches.SetShared(&shared.reference_frames);
   1451  const FrameDimensions& frame_dim = shared.frame_dim;
         // Per-block side images, all sized xsize_blocks x ysize_blocks.
   1452  JXL_ASSIGN_OR_RETURN(
   1453      shared.ac_strategy,
   1454      AcStrategyImage::Create(memory_manager, frame_dim.xsize_blocks,
   1455                              frame_dim.ysize_blocks));
   1456  JXL_ASSIGN_OR_RETURN(shared.raw_quant_field,
   1457                       ImageI::Create(memory_manager, frame_dim.xsize_blocks,
   1458                                      frame_dim.ysize_blocks));
   1459  JXL_ASSIGN_OR_RETURN(shared.epf_sharpness,
   1460                       ImageB::Create(memory_manager, frame_dim.xsize_blocks,
   1461                                      frame_dim.ysize_blocks));
   1462  JXL_ASSIGN_OR_RETURN(
   1463      shared.cmap, ColorCorrelationMap::Create(memory_manager, frame_dim.xsize,
   1464                                               frame_dim.ysize));
   1465  shared.coeff_order_size = kCoeffOrderMaxSize;
         // One block of kCoeffOrderMaxSize entries per pass (VarDCT only).
   1466  if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1467    shared.coeff_orders.resize(frame_header.passes.num_passes *
   1468                               kCoeffOrderMaxSize);
   1469  }
   1470 
   1471  JXL_ASSIGN_OR_RETURN(shared.quant_dc,
   1472                       ImageB::Create(memory_manager, frame_dim.xsize_blocks,
   1473                                      frame_dim.ysize_blocks));
   1474  JXL_ASSIGN_OR_RETURN(shared.dc_storage,
   1475                       Image3F::Create(memory_manager, frame_dim.xsize_blocks,
   1476                                       frame_dim.ysize_blocks));
   1477  shared.dc = &shared.dc_storage;
   1478 
   1479  const size_t num_extra_channels = metadata->m.num_extra_channels;
   1480  const ExtraChannelInfo* alpha_eci = metadata->m.Find(ExtraChannel::kAlpha);
   1481  const ExtraChannelInfo* black_eci = metadata->m.Find(ExtraChannel::kBlack);
         // NOTE(review): if Find() returned nullptr (the code below tests for
         // that) these subtractions produce a meaningless index; the indices
         // are only used further down after alpha_eci / black_eci are checked.
   1482  const size_t alpha_idx = alpha_eci - metadata->m.extra_channel_info.data();
   1483  const size_t black_idx = black_eci - metadata->m.extra_channel_info.data();
   1484  const ColorEncoding c_enc = metadata->m.color_encoding;
   1485 
   1486  // Make the image patch bigger than the currently processed group in streaming
   1487  // mode so that we can take into account border pixels around the group when
   1488  // computing inverse Gaborish and adaptive quantization map.
   1489  int max_border = enc_state.streaming_mode ? kBlockDim : 0;
   1490  Rect frame_rect(0, 0, frame_data.xsize, frame_data.ysize);
   1491  Rect frame_area_rect = Rect(x0, y0, xsize, ysize);
   1492  Rect patch_rect = frame_area_rect.Extend(max_border, frame_rect);
   1493  JXL_ENSURE(patch_rect.IsInside(frame_rect));
   1494 
   1495  // Allocating a large enough image avoids a copy when padding.
   1496  JXL_ASSIGN_OR_RETURN(
   1497      Image3F color,
   1498      Image3F::Create(memory_manager, RoundUpToBlockDim(patch_rect.xsize()),
   1499                      RoundUpToBlockDim(patch_rect.ysize())));
   1500  JXL_RETURN_IF_ERROR(color.ShrinkTo(patch_rect.xsize(), patch_rect.ysize()));
   1501  std::vector<ImageF> extra_channels(num_extra_channels);
   1502  for (auto& extra_channel : extra_channels) {
   1503    JXL_ASSIGN_OR_RETURN(
   1504        extra_channel,
   1505        ImageF::Create(memory_manager, patch_rect.xsize(), patch_rect.ysize()));
   1506  }
         // Non-owning pointers into `extra_channels`; null when the image has no
         // alpha / black channel.
   1507  ImageF* alpha = alpha_eci ? &extra_channels[alpha_idx] : nullptr;
   1508  ImageF* black = black_eci ? &extra_channels[black_idx] : nullptr;
   1509  bool has_interleaved_alpha = false;
   1510  JxlChunkedFrameInputSource input = frame_data.GetInputSource();
         // Colour samples are only copied when no JPEG data was supplied; extra
         // channels are copied in both cases.
   1511  if (!jpeg_data) {
   1512    JXL_RETURN_IF_ERROR(CopyColorChannels(input, patch_rect, frame_info,
   1513                                          metadata->m, pool, &color, alpha,
   1514                                          &has_interleaved_alpha));
   1515  }
   1516  JXL_RETURN_IF_ERROR(CopyExtraChannels(input, patch_rect, frame_info,
   1517                                        metadata->m, has_interleaved_alpha,
   1518                                        pool, &extra_channels));
   1519 
   1520  enc_state.cparams = cparams;
   1521 
         // `linear` stays nullptr unless the XYB + VarDCT + (speed tier <=
         // kKitten) branch below allocates `linear_storage`. It is handed to
         // ToXYB(), presumably to receive the linear-light image (confirm
         // against ToXYB).
   1522  Image3F linear_storage;
   1523  Image3F* linear = nullptr;
   1524 
   1525  if (!jpeg_data) {
   1526    if (frame_header.color_transform == ColorTransform::kXYB &&
   1527        frame_info.ib_needs_color_transform) {
   1528      if (frame_header.encoding == FrameEncoding::kVarDCT &&
   1529          cparams.speed_tier <= SpeedTier::kKitten) {
   1530        JXL_ASSIGN_OR_RETURN(linear_storage,
   1531                             Image3F::Create(memory_manager, patch_rect.xsize(),
   1532                                             patch_rect.ysize()));
   1533        linear = &linear_storage;
   1534      }
   1535      JXL_RETURN_IF_ERROR(ToXYB(c_enc, metadata->m.IntensityTarget(), black,
   1536                                pool, &color, cms, linear));
   1537    } else {
   1538      // Nothing to do.
   1539      // RGB or YCbCr: forward YCbCr is not implemented, this is only used when
   1540      // the input is already in YCbCr
   1541      // If encoding a special DC or reference frame: input is already in XYB.
   1542    }
   1543    bool lossless = cparams.IsLossless();
           // Invisible pixels are simplified only for regular frames with
           // non-associated alpha, when keep_invisible does not forbid it, the
           // colour and extra-channel resampling factors match, and perceptual
           // optimizations are enabled.
   1544    if (alpha && !alpha_eci->alpha_associated &&
   1545        frame_header.frame_type == FrameType::kRegularFrame &&
   1546        !ApplyOverride(cparams.keep_invisible, cparams.IsLossless()) &&
   1547        cparams.ec_resampling == cparams.resampling &&
   1548        !cparams.disable_perceptual_optimizations) {
   1549      // simplify invisible pixels
   1550      SimplifyInvisible(&color, *alpha, lossless);
   1551      if (linear) {
   1552        SimplifyInvisible(linear, *alpha, lossless);
   1553      }
   1554    }
   1555    JXL_RETURN_IF_ERROR(PadImageToBlockMultipleInPlace(&color));
   1556  }
   1557 
   1558  // Rectangle within color that corresponds to the currently processed group in
   1559  // streaming mode.
   1560  Rect group_rect(x0 - patch_rect.x0(), y0 - patch_rect.y0(),
   1561                  RoundUpToBlockDim(xsize), RoundUpToBlockDim(ysize));
   1562 
         // Chromacity adjustments are computed only while initializing the global
         // state, and never for JPEG input.
   1563  if (enc_state.initialize_global_state && !jpeg_data) {
   1564    ComputeChromacityAdjustments(cparams, color, group_rect,
   1565                                 &mutable_frame_header);
   1566  }
   1567 
   1568  bool has_jpeg_data = (jpeg_data != nullptr);
   1569  ComputeNoiseParams(cparams, enc_state.streaming_mode, has_jpeg_data, color,
   1570                     frame_dim, &mutable_frame_header,
   1571                     &shared.image_features.noise_params);
   1572 
   1573  JXL_RETURN_IF_ERROR(
   1574      DownsampleColorChannels(cparams, frame_header, has_jpeg_data, &color));
   1575 
         // Extra channels are downsampled here unless the caller already did so.
   1576  if (cparams.ec_resampling != 1 && !cparams.already_downsampled) {
   1577    for (ImageF& ec : extra_channels) {
   1578      JXL_ASSIGN_OR_RETURN(ec, DownsampleImage(ec, cparams.ec_resampling));
   1579    }
   1580  }
   1581 
         // Outside streaming mode the whole colour image is processed, so the
         // group rectangle is reset to cover all of `color`.
   1582  if (!enc_state.streaming_mode) {
   1583    group_rect = Rect(color);
   1584  }
   1585 
         // VarDCT path: size the per-pass token storage, compute the transform
         // data (JPEG transcoding or regular VarDCT), then coefficient orders
         // and tokens.
   1586  if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1587    enc_state.passes.resize(enc_state.progressive_splitter.GetNumPasses());
   1588    for (PassesEncoderState::PassData& pass : enc_state.passes) {
   1589      pass.ac_tokens.resize(shared.frame_dim.num_groups);
   1590    }
   1591    if (jpeg_data) {
   1592      JXL_RETURN_IF_ERROR(ComputeJPEGTranscodingData(
   1593          *jpeg_data, frame_header, pool, &enc_modular, &enc_state));
   1594    } else {
   1595      JXL_RETURN_IF_ERROR(ComputeVarDCTEncodingData(
   1596          frame_header, linear, &color, group_rect, cms, pool, &enc_modular,
   1597          &enc_state, aux_out));
   1598    }
   1599    JXL_RETURN_IF_ERROR(ComputeAllCoeffOrders(enc_state, frame_dim));
           // In non-streaming mode a single histogram set is used and
           // histogram_idx gets one entry per group.
   1600    if (!enc_state.streaming_mode) {
   1601      shared.num_histograms = 1;
   1602      enc_state.histogram_idx.resize(frame_dim.num_groups);
   1603    }
   1604    JXL_RETURN_IF_ERROR(
   1605        TokenizeAllCoefficients(frame_header, pool, &enc_state));
   1606  }
   1607 
         // The modular encoder runs when colour is coded in modular mode or when
         // there are extra channels; do_color tells it whether colour is included.
   1608  if (cparams.modular_mode || !extra_channels.empty()) {
   1609    JXL_RETURN_IF_ERROR(enc_modular.ComputeEncodingData(
   1610        frame_header, metadata->m, &color, extra_channels, group_rect,
   1611        frame_dim, frame_area_rect, &enc_state, cms, pool, aux_out,
   1612        /*do_color=*/cparams.modular_mode));
   1613  }
   1614 
   1615  if (!enc_state.streaming_mode) {
   1616    if (cparams.speed_tier < SpeedTier::kTortoise ||
   1617        !cparams.ModularPartIsLossless() || cparams.responsive ||
   1618        !cparams.custom_fixed_tree.empty()) {
   1619      // Use local trees if doing lossless modular, unless at very slow speeds.
   1620      JXL_RETURN_IF_ERROR(enc_modular.ComputeTree(pool));
   1621      JXL_RETURN_IF_ERROR(enc_modular.ComputeTokens(pool));
   1622    }
           // Make the header flags reflect whether patches / splines exist.
   1623    mutable_frame_header.UpdateFlag(shared.image_features.patches.HasAny(),
   1624                                    FrameHeader::kPatches);
   1625    mutable_frame_header.UpdateFlag(shared.image_features.splines.HasAny(),
   1626                                    FrameHeader::kSplines);
   1627  }
   1628 
   1629  JXL_RETURN_IF_ERROR(EncodeGroups(frame_header, &enc_state, &enc_modular, pool,
   1630                                   group_codes, aux_out));
         // Streaming mode: drop the modular stream data belonging to the DC group
         // that EncodeGroups() was just called for.
   1631  if (enc_state.streaming_mode) {
   1632    const size_t group_index = enc_state.dc_group_index;
   1633    enc_modular.ClearStreamData(ModularStreamId::VarDCTDC(group_index));
   1634    enc_modular.ClearStreamData(ModularStreamId::ACMetadata(group_index));
   1635    enc_modular.ClearModularStreamData();
   1636  }
   1637  return true;
   1638 }
   1639 
   // Reorders the AC group sections "center first": groups whose centre is
   // closest to a chosen image centre come first. Fills `permutation` so that
   // permutation[i] is the new position of original section i, and moves the
   // writers in `group_codes` to those positions.
   //
   // Returns immediately, leaving both outputs untouched, unless
   // cparams.centerfirst is set and there is more than one pass or more than
   // one group. Returns an error when an explicitly given cparams.center_x /
   // center_y is not inside the frame.
   // Assumes `group_codes` holds at least one writer per section listed in
   // `permutation` - TODO confirm against callers.
   1640 Status PermuteGroups(const CompressParams& cparams,
   1641                     const FrameDimensions& frame_dim, size_t num_passes,
   1642                     std::vector<coeff_order_t>* permutation,
   1643                     std::vector<std::unique_ptr<BitWriter>>* group_codes) {
   1644  const size_t num_groups = frame_dim.num_groups;
   1645  if (!cparams.centerfirst || (num_passes == 1 && num_groups == 1)) {
   1646    return true;
   1647  }
   1648  // Don't permute global DC/AC or DC.
         // The first num_dc_groups + 2 sections get the identity mapping.
   1649  permutation->resize(frame_dim.num_dc_groups + 2);
   1650  std::iota(permutation->begin(), permutation->end(), 0);
   1651  std::vector<coeff_order_t> ac_group_order(num_groups);
   1652  std::iota(ac_group_order.begin(), ac_group_order.end(), 0);
   1653  size_t group_dim = frame_dim.group_dim;
   1654 
   1655  // The center of the image is either given by parameters or chosen
   1656  // to be the middle of the image by default if center_x, center_y resp.
   1657  // are not provided.
   1658 
         // static_cast<size_t>(-1) is the "not provided" marker for center_x/y.
   1659  int64_t imag_cx;
   1660  if (cparams.center_x != static_cast<size_t>(-1)) {
   1661    JXL_RETURN_IF_ERROR(cparams.center_x < frame_dim.xsize);
   1662    imag_cx = cparams.center_x;
   1663  } else {
   1664    imag_cx = frame_dim.xsize / 2;
   1665  }
   1666 
   1667  int64_t imag_cy;
   1668  if (cparams.center_y != static_cast<size_t>(-1)) {
   1669    JXL_RETURN_IF_ERROR(cparams.center_y < frame_dim.ysize);
   1670    imag_cy = cparams.center_y;
   1671  } else {
   1672    imag_cy = frame_dim.ysize / 2;
   1673  }
   1674 
   1675  // The center of the group containing the center of the image.
   1676  int64_t cx = (imag_cx / group_dim) * group_dim + group_dim / 2;
   1677  int64_t cy = (imag_cy / group_dim) * group_dim + group_dim / 2;
   1678  // This identifies in what area of the central group the center of the image
   1679  // lies in.
   1680  double direction = -std::atan2(imag_cy - cy, imag_cx - cx);
   1681  // This identifies the side of the central group the center of the image
   1682  // lies closest to. This can take values 0, 1, 2, 3 corresponding to left,
   1683  // bottom, right, top.
   1684  int64_t side = std::fmod((direction + 5 * kPi / 4), 2 * kPi) * 2 / kPi;
         // Sort key for group `gid`: (Chebyshev distance between the group's
         // nominal centre and (cx, cy), angle around the centre). Pairs compare
         // lexicographically, so distance is the primary criterion.
   1685  auto get_distance_from_center = [&](size_t gid) {
   1686    Rect r = frame_dim.GroupRect(gid);
   1687    int64_t gcx = r.x0() + group_dim / 2;
   1688    int64_t gcy = r.y0() + group_dim / 2;
   1689    int64_t dx = gcx - cx;
   1690    int64_t dy = gcy - cy;
   1691    // The angle is determined by taking atan2 and adding an appropriate
   1692    // starting point depending on the side we want to start on.
   1693    double angle = std::remainder(
   1694        std::atan2(dy, dx) + kPi / 4 + side * (kPi / 2), 2 * kPi);
   1695    // Concentric squares in clockwise order.
   1696    return std::make_pair(std::max(std::abs(dx), std::abs(dy)), angle);
   1697  };
   1698  std::sort(ac_group_order.begin(), ac_group_order.end(),
   1699            [&](coeff_order_t a, coeff_order_t b) {
   1700              return get_distance_from_center(a) < get_distance_from_center(b);
   1701            });
         // Invert the order: inv_ac_group_order[g] is the rank of group g in the
         // sorted (center-first) sequence.
   1702  std::vector<coeff_order_t> inv_ac_group_order(ac_group_order.size(), 0);
   1703  for (size_t i = 0; i < ac_group_order.size(); i++) {
   1704    inv_ac_group_order[ac_group_order[i]] = i;
   1705  }
         // For every pass append the new position of each AC group; pass_start is
         // the index of the first section belonging to that pass.
   1706  for (size_t i = 0; i < num_passes; i++) {
   1707    size_t pass_start = permutation->size();
   1708    for (coeff_order_t v : inv_ac_group_order) {
   1709      permutation->push_back(pass_start + v);
   1710    }
   1711  }
         // Apply the permutation: the writer of section i moves to slot
         // permutation[i].
   1712  std::vector<std::unique_ptr<BitWriter>> new_group_codes(group_codes->size());
   1713  for (size_t i = 0; i < permutation->size(); i++) {
   1714    new_group_codes[(*permutation)[i]] = std::move((*group_codes)[i]);
   1715  }
   1716  group_codes->swap(new_group_codes);
   1717  return true;
   1718 }
   1719 
   1720 bool CanDoStreamingEncoding(const CompressParams& cparams,
   1721                            const FrameInfo& frame_info,
   1722                            const CodecMetadata& metadata,
   1723                            const JxlEncoderChunkedFrameAdapter& frame_data) {
   1724  if (cparams.buffering == 0) {
   1725    return false;
   1726  }
   1727  if (cparams.buffering == -1) {
   1728    if (cparams.speed_tier < SpeedTier::kTortoise) return false;
   1729    if (cparams.speed_tier < SpeedTier::kSquirrel &&
   1730        cparams.butteraugli_distance > 0.5f) {
   1731      return false;
   1732    }
   1733    if (cparams.speed_tier == SpeedTier::kSquirrel &&
   1734        cparams.butteraugli_distance >= 3.f) {
   1735      return false;
   1736    }
   1737  }
   1738 
   1739  // TODO(veluca): handle different values of `buffering`.
   1740  if (frame_data.xsize <= 2048 && frame_data.ysize <= 2048) {
   1741    return false;
   1742  }
   1743  if (frame_data.IsJPEG()) {
   1744    return false;
   1745  }
   1746  if (cparams.noise == Override::kOn || cparams.patches == Override::kOn) {
   1747    return false;
   1748  }
   1749  if (cparams.progressive_dc != 0 || frame_info.dc_level != 0) {
   1750    return false;
   1751  }
   1752  if (cparams.resampling != 1 || cparams.ec_resampling != 1) {
   1753    return false;
   1754  }
   1755  if (cparams.max_error_mode) {
   1756    return false;
   1757  }
   1758  if (!cparams.ModularPartIsLossless() || cparams.responsive > 0) {
   1759    if (metadata.m.num_extra_channels > 0 || cparams.modular_mode) {
   1760      return false;
   1761    }
   1762  }
   1763  ColorTransform ok_color_transform =
   1764      cparams.modular_mode ? ColorTransform::kNone : ColorTransform::kXYB;
   1765  if (cparams.color_transform != ok_color_transform) {
   1766    return false;
   1767  }
   1768  return true;
   1769 }
   1770 
   1771 Status ComputePermutationForStreaming(size_t xsize, size_t ysize,
   1772                                      size_t group_size, size_t num_passes,
   1773                                      std::vector<coeff_order_t>& permutation,
   1774                                      std::vector<size_t>& dc_group_order) {
   1775  // This is only valid in VarDCT mode, otherwise there can be group shift.
   1776  const size_t dc_group_size = group_size * kBlockDim;
   1777  const size_t group_xsize = DivCeil(xsize, group_size);
   1778  const size_t group_ysize = DivCeil(ysize, group_size);
   1779  const size_t dc_group_xsize = DivCeil(xsize, dc_group_size);
   1780  const size_t dc_group_ysize = DivCeil(ysize, dc_group_size);
   1781  const size_t num_groups = group_xsize * group_ysize;
   1782  const size_t num_dc_groups = dc_group_xsize * dc_group_ysize;
   1783  const size_t num_sections = 2 + num_dc_groups + num_passes * num_groups;
   1784  permutation.resize(num_sections);
   1785  size_t new_ix = 0;
   1786  // DC Global is first
   1787  permutation[0] = new_ix++;
   1788  // TODO(szabadka) Change the dc group order to center-first.
   1789  for (size_t dc_y = 0; dc_y < dc_group_ysize; ++dc_y) {
   1790    for (size_t dc_x = 0; dc_x < dc_group_xsize; ++dc_x) {
   1791      size_t dc_ix = dc_y * dc_group_xsize + dc_x;
   1792      dc_group_order.push_back(dc_ix);
   1793      permutation[1 + dc_ix] = new_ix++;
   1794      size_t ac_y0 = dc_y * kBlockDim;
   1795      size_t ac_x0 = dc_x * kBlockDim;
   1796      size_t ac_y1 = std::min<size_t>(group_ysize, ac_y0 + kBlockDim);
   1797      size_t ac_x1 = std::min<size_t>(group_xsize, ac_x0 + kBlockDim);
   1798      for (size_t pass = 0; pass < num_passes; ++pass) {
   1799        for (size_t ac_y = ac_y0; ac_y < ac_y1; ++ac_y) {
   1800          for (size_t ac_x = ac_x0; ac_x < ac_x1; ++ac_x) {
   1801            size_t group_ix = ac_y * group_xsize + ac_x;
   1802            size_t old_ix =
   1803                AcGroupIndex(pass, group_ix, num_groups, num_dc_groups);
   1804            permutation[old_ix] = new_ix++;
   1805          }
   1806        }
   1807      }
   1808    }
   1809  }
   1810  // AC Global is last
   1811  permutation[1 + num_dc_groups] = new_ix++;
   1812  JXL_ENSURE(new_ix == num_sections);
   1813  return true;
   1814 }
   1815 
   // Smallest section size that falls into each of the four TOC size buckets.
   constexpr size_t kGroupSizeOffset[4] = {0, 1024, 17408, 4211712};
   // Number of TOC bits spent on one entry of the corresponding bucket.
   constexpr size_t kTOCBits[4] = {12, 16, 24, 32};

   // Returns the index (0..3) of the bucket `group_size` belongs to, i.e. the
   // last bucket whose starting offset does not exceed it.
   size_t TOCBucket(size_t group_size) {
     for (size_t candidate = 3; candidate > 0; --candidate) {
       if (group_size >= kGroupSizeOffset[candidate]) return candidate;
     }
     return 0;
   }

   // Returns the number of whole bytes needed for the TOC entries of the given
   // section sizes (the bit total is rounded up to a byte).
   size_t TOCSize(const std::vector<size_t>& group_sizes) {
     const size_t total_bits = std::accumulate(
         group_sizes.begin(), group_sizes.end(), size_t{0},
         [](size_t bits_so_far, size_t section_size) {
           return bits_so_far + kTOCBits[TOCBucket(section_size)];
         });
     return (total_bits + 7) / 8;
   }
   1837 
   1838 StatusOr<PaddedBytes> EncodeTOC(JxlMemoryManager* memory_manager,
   1839                                const std::vector<size_t>& group_sizes,
   1840                                AuxOut* aux_out) {
   1841  BitWriter writer{memory_manager};
   1842  JXL_RETURN_IF_ERROR(writer.WithMaxBits(
   1843      32 * group_sizes.size(), LayerType::Toc, aux_out, [&]() -> Status {
   1844        for (size_t group_size : group_sizes) {
   1845          JXL_RETURN_IF_ERROR(U32Coder::Write(kTocDist, group_size, &writer));
   1846        }
   1847        writer.ZeroPadToByte();  // before first group
   1848        return true;
   1849      }));
   1850  return std::move(writer).TakeBytes();
   1851 }
   1852 
   1853 Status ComputeGroupDataOffset(size_t frame_header_size, size_t dc_global_size,
   1854                              size_t num_sections, size_t& min_dc_global_size,
   1855                              size_t& group_offset) {
   1856  size_t max_toc_bits = (num_sections - 1) * 32;
   1857  size_t min_toc_bits = (num_sections - 1) * 12;
   1858  size_t max_padding = (max_toc_bits - min_toc_bits + 7) / 8;
   1859  min_dc_global_size = dc_global_size;
   1860  size_t dc_global_bucket = TOCBucket(min_dc_global_size);
   1861  while (TOCBucket(min_dc_global_size + max_padding) > dc_global_bucket) {
   1862    dc_global_bucket = TOCBucket(min_dc_global_size + max_padding);
   1863    min_dc_global_size = kGroupSizeOffset[dc_global_bucket];
   1864  }
   1865  JXL_ENSURE(TOCBucket(min_dc_global_size) == dc_global_bucket);
   1866  JXL_ENSURE(TOCBucket(min_dc_global_size + max_padding) == dc_global_bucket);
   1867  max_toc_bits += kTOCBits[dc_global_bucket];
   1868  size_t max_toc_size = (max_toc_bits + 7) / 8;
   1869  group_offset = frame_header_size + max_toc_size + min_dc_global_size;
   1870  return true;
   1871 }
   1872 
   1873 size_t ComputeDcGlobalPadding(const std::vector<size_t>& group_sizes,
   1874                              size_t frame_header_size,
   1875                              size_t group_data_offset,
   1876                              size_t min_dc_global_size) {
   1877  std::vector<size_t> new_group_sizes = group_sizes;
   1878  new_group_sizes[0] = min_dc_global_size;
   1879  size_t toc_size = TOCSize(new_group_sizes);
   1880  size_t actual_offset = frame_header_size + toc_size + group_sizes[0];
   1881  return group_data_offset - actual_offset;
   1882 }
   1883 
   1884 Status OutputGroups(std::vector<std::unique_ptr<BitWriter>>&& group_codes,
   1885                    std::vector<size_t>* group_sizes,
   1886                    JxlEncoderOutputProcessorWrapper* output_processor) {
   1887  JXL_ENSURE(group_codes.size() >= 4);
   1888  {
   1889    PaddedBytes dc_group = std::move(*group_codes[1]).TakeBytes();
   1890    group_sizes->push_back(dc_group.size());
   1891    JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_group));
   1892  }
   1893  for (size_t i = 3; i < group_codes.size(); ++i) {
   1894    PaddedBytes ac_group = std::move(*group_codes[i]).TakeBytes();
   1895    group_sizes->push_back(ac_group.size());
   1896    JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_group));
   1897  }
   1898  return true;
   1899 }
   1900 
   1901 void RemoveUnusedHistograms(std::vector<uint8_t>& context_map,
   1902                            EntropyEncodingData& codes) {
   1903  std::vector<int> remap(256, -1);
   1904  std::vector<uint8_t> inv_remap;
   1905  for (uint8_t& context : context_map) {
   1906    const uint8_t histo_ix = context;
   1907    if (remap[histo_ix] == -1) {
   1908      remap[histo_ix] = inv_remap.size();
   1909      inv_remap.push_back(histo_ix);
   1910    }
   1911    context = remap[histo_ix];
   1912  }
   1913  EntropyEncodingData new_codes;
   1914  new_codes.use_prefix_code = codes.use_prefix_code;
   1915  new_codes.lz77 = codes.lz77;
   1916  for (uint8_t histo_idx : inv_remap) {
   1917    new_codes.encoding_info.emplace_back(
   1918        std::move(codes.encoding_info[histo_idx]));
   1919    new_codes.uint_config.emplace_back(codes.uint_config[histo_idx]);
   1920    new_codes.encoded_histograms.emplace_back(
   1921        std::move(codes.encoded_histograms[histo_idx]));
   1922  }
   1923  codes = std::move(new_codes);
   1924 }
   1925 
   1926 Status OutputAcGlobal(PassesEncoderState& enc_state,
   1927                      const FrameDimensions& frame_dim,
   1928                      std::vector<size_t>* group_sizes,
   1929                      JxlEncoderOutputProcessorWrapper* output_processor,
   1930                      AuxOut* aux_out) {
   1931  JXL_ENSURE(frame_dim.num_groups > 1);
   1932  JxlMemoryManager* memory_manager = enc_state.memory_manager();
   1933  BitWriter writer{memory_manager};
   1934  {
   1935    size_t num_histo_bits = CeilLog2Nonzero(frame_dim.num_groups);
   1936    JXL_RETURN_IF_ERROR(
   1937        writer.WithMaxBits(num_histo_bits + 1, LayerType::Ac, aux_out, [&] {
   1938          writer.Write(1, 1);  // default dequant matrices
   1939          writer.Write(num_histo_bits, frame_dim.num_dc_groups - 1);
   1940          return true;
   1941        }));
   1942  }
   1943  const PassesSharedState& shared = enc_state.shared;
   1944  for (size_t i = 0; i < enc_state.progressive_splitter.GetNumPasses(); i++) {
   1945    // Encode coefficient orders.
   1946    size_t order_bits = 0;
   1947    JXL_RETURN_IF_ERROR(
   1948        U32Coder::CanEncode(kOrderEnc, enc_state.used_orders[i], &order_bits));
   1949    JXL_RETURN_IF_ERROR(
   1950        writer.WithMaxBits(order_bits, LayerType::Order, aux_out, [&] {
   1951          return U32Coder::Write(kOrderEnc, enc_state.used_orders[i], &writer);
   1952        }));
   1953    JXL_RETURN_IF_ERROR(
   1954        EncodeCoeffOrders(enc_state.used_orders[i],
   1955                          &shared.coeff_orders[i * shared.coeff_order_size],
   1956                          &writer, LayerType::Order, aux_out));
   1957    // Fix up context map and entropy codes to remove any fix histograms that
   1958    // were not selected by clustering.
   1959    RemoveUnusedHistograms(enc_state.passes[i].context_map,
   1960                           enc_state.passes[i].codes);
   1961    JXL_RETURN_IF_ERROR(EncodeHistograms(enc_state.passes[i].context_map,
   1962                                         enc_state.passes[i].codes, &writer,
   1963                                         LayerType::Ac, aux_out));
   1964  }
   1965  JXL_RETURN_IF_ERROR(writer.WithMaxBits(8, LayerType::Ac, aux_out, [&] {
   1966    writer.ZeroPadToByte();  // end of group.
   1967    return true;
   1968  }));
   1969  PaddedBytes ac_global = std::move(writer).TakeBytes();
   1970  group_sizes->push_back(ac_global.size());
   1971  JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_global));
   1972  return true;
   1973 }
   1974 
   1975 Status EncodeFrameStreaming(JxlMemoryManager* memory_manager,
   1976                            const CompressParams& cparams,
   1977                            const FrameInfo& frame_info,
   1978                            const CodecMetadata* metadata,
   1979                            JxlEncoderChunkedFrameAdapter& frame_data,
   1980                            const JxlCmsInterface& cms, ThreadPool* pool,
   1981                            JxlEncoderOutputProcessorWrapper* output_processor,
   1982                            AuxOut* aux_out) {
   1983  PassesEncoderState enc_state{memory_manager};
   1984  SetProgressiveMode(cparams, &enc_state.progressive_splitter);
   1985  FrameHeader frame_header(metadata);
   1986  std::unique_ptr<jpeg::JPEGData> jpeg_data;
   1987  if (frame_data.IsJPEG()) {
   1988    jpeg_data = frame_data.TakeJPEGData();
   1989    JXL_ENSURE(jpeg_data);
   1990  }
   1991  JXL_RETURN_IF_ERROR(MakeFrameHeader(frame_data.xsize, frame_data.ysize,
   1992                                      cparams, enc_state.progressive_splitter,
   1993                                      frame_info, jpeg_data.get(), true,
   1994                                      &frame_header));
   1995  const size_t num_passes = enc_state.progressive_splitter.GetNumPasses();
   1996  JXL_ASSIGN_OR_RETURN(
   1997      ModularFrameEncoder enc_modular,
   1998      ModularFrameEncoder::Create(memory_manager, frame_header, cparams, true));
   1999  std::vector<coeff_order_t> permutation;
   2000  std::vector<size_t> dc_group_order;
   2001  size_t group_size = frame_header.ToFrameDimensions().group_dim;
   2002  JXL_RETURN_IF_ERROR(ComputePermutationForStreaming(
   2003      frame_data.xsize, frame_data.ysize, group_size, num_passes, permutation,
   2004      dc_group_order));
   2005  enc_state.shared.num_histograms = dc_group_order.size();
   2006  size_t dc_group_size = group_size * kBlockDim;
   2007  size_t dc_group_xsize = DivCeil(frame_data.xsize, dc_group_size);
   2008  size_t min_dc_global_size = 0;
   2009  size_t group_data_offset = 0;
   2010  PaddedBytes frame_header_bytes{memory_manager};
   2011  PaddedBytes dc_global_bytes{memory_manager};
   2012  std::vector<size_t> group_sizes;
   2013  size_t start_pos = output_processor->CurrentPosition();
   2014  for (size_t i = 0; i < dc_group_order.size(); ++i) {
   2015    size_t dc_ix = dc_group_order[i];
   2016    size_t dc_y = dc_ix / dc_group_xsize;
   2017    size_t dc_x = dc_ix % dc_group_xsize;
   2018    size_t y0 = dc_y * dc_group_size;
   2019    size_t x0 = dc_x * dc_group_size;
   2020    size_t ysize = std::min<size_t>(dc_group_size, frame_data.ysize - y0);
   2021    size_t xsize = std::min<size_t>(dc_group_size, frame_data.xsize - x0);
   2022    size_t group_xsize = DivCeil(xsize, group_size);
   2023    size_t group_ysize = DivCeil(ysize, group_size);
   2024    JXL_DEBUG_V(2,
   2025                "Encoding DC group #%" PRIuS " dc_y = %" PRIuS " dc_x = %" PRIuS
   2026                " (x0, y0) = (%" PRIuS ", %" PRIuS ") (xsize, ysize) = (%" PRIuS
   2027                ", %" PRIuS ")",
   2028                dc_ix, dc_y, dc_x, x0, y0, xsize, ysize);
   2029    enc_state.streaming_mode = true;
   2030    enc_state.initialize_global_state = (i == 0);
   2031    enc_state.dc_group_index = dc_ix;
   2032    enc_state.histogram_idx = std::vector<size_t>(group_xsize * group_ysize, i);
   2033    std::vector<std::unique_ptr<BitWriter>> group_codes;
   2034    JXL_RETURN_IF_ERROR(ComputeEncodingData(
   2035        cparams, frame_info, metadata, frame_data, jpeg_data.get(), x0, y0,
   2036        xsize, ysize, cms, pool, frame_header, enc_modular, enc_state,
   2037        &group_codes, aux_out));
   2038    JXL_ENSURE(enc_state.special_frames.empty());
   2039    if (i == 0) {
   2040      BitWriter writer{memory_manager};
   2041      JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out));
   2042      JXL_RETURN_IF_ERROR(
   2043          writer.WithMaxBits(8, LayerType::Header, aux_out, [&]() -> Status {
   2044            writer.Write(1, 1);  // write permutation
   2045            JXL_RETURN_IF_ERROR(EncodePermutation(
   2046                permutation.data(), /*skip=*/0, permutation.size(), &writer,
   2047                LayerType::Header, aux_out));
   2048            writer.ZeroPadToByte();
   2049            return true;
   2050          }));
   2051      frame_header_bytes = std::move(writer).TakeBytes();
   2052      dc_global_bytes = std::move(*group_codes[0]).TakeBytes();
   2053      JXL_RETURN_IF_ERROR(ComputeGroupDataOffset(
   2054          frame_header_bytes.size(), dc_global_bytes.size(), permutation.size(),
   2055          min_dc_global_size, group_data_offset));
   2056      JXL_DEBUG_V(2, "Frame header size: %" PRIuS, frame_header_bytes.size());
   2057      JXL_DEBUG_V(2, "DC global size: %" PRIuS ", min size for TOC: %" PRIuS,
   2058                  dc_global_bytes.size(), min_dc_global_size);
   2059      JXL_DEBUG_V(2, "Num groups: %" PRIuS " group data offset: %" PRIuS,
   2060                  permutation.size(), group_data_offset);
   2061      group_sizes.push_back(dc_global_bytes.size());
   2062      JXL_RETURN_IF_ERROR(
   2063          output_processor->Seek(start_pos + group_data_offset));
   2064    }
   2065    JXL_RETURN_IF_ERROR(
   2066        OutputGroups(std::move(group_codes), &group_sizes, output_processor));
   2067  }
   2068  if (frame_header.encoding == FrameEncoding::kVarDCT) {
   2069    JXL_RETURN_IF_ERROR(
   2070        OutputAcGlobal(enc_state, frame_header.ToFrameDimensions(),
   2071                       &group_sizes, output_processor, aux_out));
   2072  } else {
   2073    group_sizes.push_back(0);
   2074  }
   2075  JXL_ENSURE(group_sizes.size() == permutation.size());
   2076  size_t end_pos = output_processor->CurrentPosition();
   2077  JXL_RETURN_IF_ERROR(output_processor->Seek(start_pos));
   2078  size_t padding_size =
   2079      ComputeDcGlobalPadding(group_sizes, frame_header_bytes.size(),
   2080                             group_data_offset, min_dc_global_size);
   2081  group_sizes[0] += padding_size;
   2082  JXL_ASSIGN_OR_RETURN(PaddedBytes toc_bytes,
   2083                       EncodeTOC(memory_manager, group_sizes, aux_out));
   2084  std::vector<uint8_t> padding_bytes(padding_size);
   2085  JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_header_bytes));
   2086  JXL_RETURN_IF_ERROR(AppendData(*output_processor, toc_bytes));
   2087  JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_global_bytes));
   2088  JXL_RETURN_IF_ERROR(AppendData(*output_processor, padding_bytes));
   2089  JXL_DEBUG_V(2, "TOC size: %" PRIuS " padding bytes after DC global: %" PRIuS,
   2090              toc_bytes.size(), padding_size);
   2091  JXL_ENSURE(output_processor->CurrentPosition() ==
   2092             start_pos + group_data_offset);
   2093  JXL_RETURN_IF_ERROR(output_processor->Seek(end_pos));
   2094  return true;
   2095 }
   2096 
   2097 Status EncodeFrameOneShot(JxlMemoryManager* memory_manager,
   2098                          const CompressParams& cparams,
   2099                          const FrameInfo& frame_info,
   2100                          const CodecMetadata* metadata,
   2101                          JxlEncoderChunkedFrameAdapter& frame_data,
   2102                          const JxlCmsInterface& cms, ThreadPool* pool,
   2103                          JxlEncoderOutputProcessorWrapper* output_processor,
   2104                          AuxOut* aux_out) {
   2105  PassesEncoderState enc_state{memory_manager};
   2106  SetProgressiveMode(cparams, &enc_state.progressive_splitter);
   2107  FrameHeader frame_header(metadata);
   2108  std::unique_ptr<jpeg::JPEGData> jpeg_data;
   2109  if (frame_data.IsJPEG()) {
   2110    jpeg_data = frame_data.TakeJPEGData();
   2111    JXL_ENSURE(jpeg_data);
   2112  }
   2113  JXL_RETURN_IF_ERROR(MakeFrameHeader(frame_data.xsize, frame_data.ysize,
   2114                                      cparams, enc_state.progressive_splitter,
   2115                                      frame_info, jpeg_data.get(), false,
   2116                                      &frame_header));
   2117  const size_t num_passes = enc_state.progressive_splitter.GetNumPasses();
   2118  JXL_ASSIGN_OR_RETURN(ModularFrameEncoder enc_modular,
   2119                       ModularFrameEncoder::Create(memory_manager, frame_header,
   2120                                                   cparams, false));
   2121  std::vector<std::unique_ptr<BitWriter>> group_codes;
   2122  JXL_RETURN_IF_ERROR(ComputeEncodingData(
   2123      cparams, frame_info, metadata, frame_data, jpeg_data.get(), 0, 0,
   2124      frame_data.xsize, frame_data.ysize, cms, pool, frame_header, enc_modular,
   2125      enc_state, &group_codes, aux_out));
   2126 
   2127  BitWriter writer{memory_manager};
   2128  JXL_RETURN_IF_ERROR(writer.AppendByteAligned(enc_state.special_frames));
   2129  JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out));
   2130 
   2131  std::vector<coeff_order_t> permutation;
   2132  JXL_RETURN_IF_ERROR(PermuteGroups(cparams, enc_state.shared.frame_dim,
   2133                                    num_passes, &permutation, &group_codes));
   2134 
   2135  JXL_RETURN_IF_ERROR(
   2136      WriteGroupOffsets(group_codes, permutation, &writer, aux_out));
   2137 
   2138  JXL_RETURN_IF_ERROR(writer.AppendByteAligned(group_codes));
   2139  PaddedBytes frame_bytes = std::move(writer).TakeBytes();
   2140  JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_bytes));
   2141 
   2142  return true;
   2143 }
   2144 
   2145 }  // namespace
   2146 
   2147 std::vector<CompressParams> TectonicPlateSettingsLessPalette(
   2148    const CompressParams& cparams_orig) {
   2149  std::vector<CompressParams> all_params;
   2150  CompressParams cparams_attempt = cparams_orig;
   2151  cparams_attempt.speed_tier = SpeedTier::kGlacier;
   2152 
   2153  cparams_attempt.options.max_properties = 4;
   2154  cparams_attempt.options.nb_repeats = 1.0f;
   2155  cparams_attempt.modular_group_size_shift = 0;
   2156  cparams_attempt.channel_colors_percent = 0;
   2157  cparams_attempt.options.predictor = Predictor::Variable;
   2158  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2159  cparams_attempt.palette_colors = 1024;
   2160  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2161  cparams_attempt.patches = Override::kDefault;
   2162  all_params.push_back(cparams_attempt);
   2163  cparams_attempt.channel_colors_percent = 80.f;
   2164  cparams_attempt.modular_group_size_shift = 1;
   2165  cparams_attempt.palette_colors = 0;
   2166  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2167  all_params.push_back(cparams_attempt);
   2168  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2169  cparams_attempt.modular_group_size_shift = 2;
   2170  all_params.push_back(cparams_attempt);
   2171  cparams_attempt.modular_group_size_shift = 3;
   2172  cparams_attempt.patches = Override::kOff;
   2173  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2174  all_params.push_back(cparams_attempt);
   2175  cparams_attempt.palette_colors = 1024;
   2176  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2177  all_params.push_back(cparams_attempt);
   2178  cparams_attempt.patches = Override::kDefault;
   2179  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2180  all_params.push_back(cparams_attempt);
   2181  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2182  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2183  all_params.push_back(cparams_attempt);
   2184  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2185  cparams_attempt.options.nb_repeats = 0.9f;
   2186  cparams_attempt.modular_group_size_shift = 2;
   2187  all_params.push_back(cparams_attempt);
   2188  cparams_attempt.modular_group_size_shift = 3;
   2189  cparams_attempt.palette_colors = 0;
   2190  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2191  all_params.push_back(cparams_attempt);
   2192  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2193  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2194  all_params.push_back(cparams_attempt);
   2195  cparams_attempt.palette_colors = 1024;
   2196  cparams_attempt.options.nb_repeats = 0.95f;
   2197  cparams_attempt.modular_group_size_shift = 1;
   2198  cparams_attempt.channel_colors_percent = 0;
   2199  all_params.push_back(cparams_attempt);
   2200  cparams_attempt.modular_group_size_shift = 2;
   2201  cparams_attempt.palette_colors = 0;
   2202  all_params.push_back(cparams_attempt);
   2203  cparams_attempt.channel_colors_percent = 80.f;
   2204  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2205  all_params.push_back(cparams_attempt);
   2206  cparams_attempt.palette_colors = 1024;
   2207  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2208  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2209  cparams_attempt.modular_group_size_shift = 3;
   2210  all_params.push_back(cparams_attempt);
   2211  cparams_attempt.palette_colors = 0;
   2212  cparams_attempt.patches = Override::kOff;
   2213  all_params.push_back(cparams_attempt);
   2214  cparams_attempt.patches = Override::kDefault;
   2215  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2216  all_params.push_back(cparams_attempt);
   2217  cparams_attempt.palette_colors = 1024;
   2218  cparams_attempt.patches = Override::kOff;
   2219  all_params.push_back(cparams_attempt);
   2220  cparams_attempt.options.nb_repeats = 0.5f;
   2221  cparams_attempt.patches = Override::kDefault;
   2222  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2223  all_params.push_back(cparams_attempt);
   2224  cparams_attempt.options.predictor = Predictor::Zero;
   2225  cparams_attempt.options.nb_repeats = 0;
   2226  cparams_attempt.channel_colors_percent = 0;
   2227  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2228  cparams_attempt.patches = Override::kOff;
   2229  all_params.push_back(cparams_attempt);
   2230  cparams_attempt.channel_colors_percent = 80.f;
   2231  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2232  cparams_attempt.options.nb_repeats = 1.0f;
   2233  cparams_attempt.palette_colors = 0;
   2234  all_params.push_back(cparams_attempt);
   2235  cparams_attempt.patches = Override::kDefault;
   2236  cparams_attempt.options.predictor = Predictor::Best;
   2237  all_params.push_back(cparams_attempt);
   2238  cparams_attempt.options.nb_repeats = 0.9f;
   2239  cparams_attempt.patches = Override::kOff;
   2240  all_params.push_back(cparams_attempt);
   2241  cparams_attempt.palette_colors = 1024;
   2242  cparams_attempt.patches = Override::kDefault;
   2243  cparams_attempt.options.predictor = Predictor::Weighted;
   2244  cparams_attempt.options.nb_repeats = 1.0f;
   2245  all_params.push_back(cparams_attempt);
   2246  cparams_attempt.options.nb_repeats = 0.95f;
   2247  cparams_attempt.modular_group_size_shift = 2;
   2248  cparams_attempt.palette_colors = 0;
   2249  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2250  all_params.push_back(cparams_attempt);
   2251  return all_params;
   2252 }
   2253 
   2254 std::vector<CompressParams> TectonicPlateSettingsMorePalette(
   2255    const CompressParams& cparams_orig) {
   2256  std::vector<CompressParams> all_params;
   2257  CompressParams cparams_attempt = cparams_orig;
   2258  cparams_attempt.speed_tier = SpeedTier::kGlacier;
   2259 
   2260  cparams_attempt.options.max_properties = 4;
   2261  cparams_attempt.options.nb_repeats = 1.0f;
   2262  cparams_attempt.modular_group_size_shift = 0;
   2263  cparams_attempt.palette_colors = 70000;
   2264  cparams_attempt.options.predictor = Predictor::Variable;
   2265  cparams_attempt.channel_colors_percent = 80.f;
   2266  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2267  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2268  cparams_attempt.patches = Override::kDefault;
   2269  all_params.push_back(cparams_attempt);
   2270  cparams_attempt.modular_group_size_shift = 2;
   2271  cparams_attempt.channel_colors_percent = 0;
   2272  cparams_attempt.patches = Override::kOff;
   2273  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2274  all_params.push_back(cparams_attempt);
   2275  cparams_attempt.channel_colors_percent = 80.f;
   2276  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2277  cparams_attempt.modular_group_size_shift = 3;
   2278  all_params.push_back(cparams_attempt);
   2279  cparams_attempt.options.nb_repeats = 0.9f;
   2280  all_params.push_back(cparams_attempt);
   2281  cparams_attempt.patches = Override::kDefault;
   2282  cparams_attempt.options.nb_repeats = 0.95f;
   2283  cparams_attempt.modular_group_size_shift = 0;
   2284  all_params.push_back(cparams_attempt);
   2285  cparams_attempt.modular_group_size_shift = 3;
   2286  all_params.push_back(cparams_attempt);
   2287  cparams_attempt.patches = Override::kOff;
   2288  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2289  all_params.push_back(cparams_attempt);
   2290  cparams_attempt.options.nb_repeats = 0.5f;
   2291  all_params.push_back(cparams_attempt);
   2292  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2293  cparams_attempt.options.predictor = Predictor::Zero;
   2294  cparams_attempt.options.nb_repeats = 0;
   2295  all_params.push_back(cparams_attempt);
   2296  cparams_attempt.patches = Override::kDefault;
   2297  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2298  all_params.push_back(cparams_attempt);
   2299  cparams_attempt.options.nb_repeats = 0.01f;
   2300  cparams_attempt.palette_colors = 0;
   2301  cparams_attempt.patches = Override::kOff;
   2302  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2303  all_params.push_back(cparams_attempt);
   2304  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2305  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2306  cparams_attempt.palette_colors = 70000;
   2307  all_params.push_back(cparams_attempt);
   2308  cparams_attempt.options.nb_repeats = 1.0f;
   2309  cparams_attempt.modular_group_size_shift = 0;
   2310  cparams_attempt.channel_colors_percent = 0;
   2311  cparams_attempt.channel_colors_pre_transform_percent = 0;
   2312  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2313  all_params.push_back(cparams_attempt);
   2314  cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2315  cparams_attempt.modular_group_size_shift = 1;
   2316  all_params.push_back(cparams_attempt);
   2317  cparams_attempt.modular_group_size_shift = 2;
   2318  all_params.push_back(cparams_attempt);
   2319  cparams_attempt.channel_colors_percent = 80.f;
   2320  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2321  cparams_attempt.modular_group_size_shift = 3;
   2322  all_params.push_back(cparams_attempt);
   2323  cparams_attempt.options.nb_repeats = 0.5f;
   2324  cparams_attempt.modular_group_size_shift = 1;
   2325  cparams_attempt.channel_colors_percent = 0;
   2326  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2327  all_params.push_back(cparams_attempt);
   2328  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2329  cparams_attempt.modular_group_size_shift = 2;
   2330  all_params.push_back(cparams_attempt);
   2331  cparams_attempt.channel_colors_percent = 80.f;
   2332  cparams_attempt.modular_group_size_shift = 3;
   2333  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2334  all_params.push_back(cparams_attempt);
   2335  cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2336  cparams_attempt.options.predictor = Predictor::Select;
   2337  cparams_attempt.options.nb_repeats = 1.0f;
   2338  all_params.push_back(cparams_attempt);
   2339  return all_params;
   2340 }
   2341 
   2342 Status EncodeFrame(JxlMemoryManager* memory_manager,
   2343                   const CompressParams& cparams_orig,
   2344                   const FrameInfo& frame_info, const CodecMetadata* metadata,
   2345                   JxlEncoderChunkedFrameAdapter& frame_data,
   2346                   const JxlCmsInterface& cms, ThreadPool* pool,
   2347                   JxlEncoderOutputProcessorWrapper* output_processor,
   2348                   AuxOut* aux_out) {
   2349  CompressParams cparams = cparams_orig;
   2350  if (cparams.speed_tier == SpeedTier::kTectonicPlate &&
   2351      !cparams.IsLossless()) {
   2352    cparams.speed_tier = SpeedTier::kGlacier;
   2353  }
   2354  // Lightning mode is handled externally, so switch to Thunder mode to handle
   2355  // potentially weird cases.
   2356  if (cparams.speed_tier == SpeedTier::kLightning) {
   2357    cparams.speed_tier = SpeedTier::kThunder;
   2358  }
   2359  if (cparams.speed_tier == SpeedTier::kTectonicPlate) {
   2360    // Test palette performance to inform later trials.
   2361    std::vector<CompressParams> all_params;
   2362    CompressParams cparams_attempt = cparams_orig;
   2363    cparams_attempt.speed_tier = SpeedTier::kGlacier;
   2364 
   2365    cparams_attempt.options.max_properties = 4;
   2366    cparams_attempt.options.nb_repeats = 1.0f;
   2367    cparams_attempt.modular_group_size_shift = 3;
   2368    cparams_attempt.palette_colors = 0;
   2369    cparams_attempt.options.predictor = Predictor::Variable;
   2370    cparams_attempt.channel_colors_percent = 80.f;
   2371    cparams_attempt.channel_colors_pre_transform_percent = 95.f;
   2372    cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kDefault;
   2373    cparams_attempt.patches = Override::kDefault;
   2374    all_params.push_back(cparams_attempt);
   2375    cparams_attempt.options.predictor = Predictor::Zero;
   2376    cparams_attempt.options.nb_repeats = 0.01f;
   2377    cparams_attempt.palette_colors = 70000;
   2378    cparams_attempt.patches = Override::kOff;
   2379    cparams_attempt.options.wp_tree_mode = ModularOptions::TreeMode::kNoWP;
   2380    all_params.push_back(cparams_attempt);
   2381 
   2382    std::vector<size_t> size;
   2383    size.resize(all_params.size());
   2384 
   2385    const auto process_variant = [&](size_t task, size_t) -> Status {
   2386      std::vector<uint8_t> output(64);
   2387      uint8_t* next_out = output.data();
   2388      size_t avail_out = output.size();
   2389      JxlEncoderOutputProcessorWrapper local_output(memory_manager);
   2390      JXL_RETURN_IF_ERROR(local_output.SetAvailOut(&next_out, &avail_out));
   2391      JXL_RETURN_IF_ERROR(EncodeFrame(memory_manager, all_params[task],
   2392                                      frame_info, metadata, frame_data, cms,
   2393                                      nullptr, &local_output, aux_out));
   2394      size[task] = local_output.CurrentPosition();
   2395      return true;
   2396    };
   2397    JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, all_params.size(),
   2398                                  ThreadPool::NoInit, process_variant,
   2399                                  "Compress kTectonicPlate"));
   2400 
   2401    std::vector<CompressParams> all_params_test = all_params;
   2402    std::vector<size_t> size_test = size;
   2403    size_t best_idx_test = 0;
   2404 
   2405    if (size_test[0] <= size_test[1]) {
   2406      all_params = TectonicPlateSettingsLessPalette(cparams_orig);
   2407    } else {
   2408      best_idx_test = 1;
   2409      all_params = TectonicPlateSettingsMorePalette(cparams_orig);
   2410    }
   2411 
   2412    size.clear();
   2413    size.resize(all_params.size());
   2414 
   2415    JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, all_params.size(),
   2416                                  ThreadPool::NoInit, process_variant,
   2417                                  "Compress kTectonicPlate"));
   2418 
   2419    size_t best_idx = 0;
   2420    for (size_t i = 1; i < all_params.size(); i++) {
   2421      if (size[best_idx] > size[i]) {
   2422        best_idx = i;
   2423      }
   2424    }
   2425    if (size[best_idx] < size_test[best_idx_test]) {
   2426      cparams = all_params[best_idx];
   2427    } else {
   2428      cparams = all_params_test[best_idx_test];
   2429    }
   2430  }
   2431 
   2432  JXL_RETURN_IF_ERROR(ParamsPostInit(&cparams));
   2433 
   2434  if (cparams.butteraugli_distance < 0) {
   2435    return JXL_FAILURE("Expected non-negative distance");
   2436  }
   2437 
   2438  if (cparams.progressive_dc < 0) {
   2439    if (cparams.progressive_dc != -1) {
   2440      return JXL_FAILURE("Invalid progressive DC setting value (%d)",
   2441                         cparams.progressive_dc);
   2442    }
   2443    cparams.progressive_dc = 0;
   2444  }
   2445  if (cparams.ec_resampling < cparams.resampling) {
   2446    cparams.ec_resampling = cparams.resampling;
   2447  }
   2448  if (cparams.resampling > 1 || frame_info.is_preview) {
   2449    cparams.progressive_dc = 0;
   2450  }
   2451 
   2452  if (frame_info.dc_level + cparams.progressive_dc > 4) {
   2453    return JXL_FAILURE("Too many levels of progressive DC");
   2454  }
   2455 
   2456  if (cparams.butteraugli_distance != 0 &&
   2457      cparams.butteraugli_distance < kMinButteraugliDistance) {
   2458    return JXL_FAILURE("Butteraugli distance is too low (%f)",
   2459                       cparams.butteraugli_distance);
   2460  }
   2461 
   2462  if (frame_data.IsJPEG()) {
   2463    cparams.gaborish = Override::kOff;
   2464    cparams.epf = 0;
   2465    cparams.modular_mode = false;
   2466  }
   2467 
   2468  if (frame_data.xsize == 0 || frame_data.ysize == 0) {
   2469    return JXL_FAILURE("Empty image");
   2470  }
   2471 
   2472  // Assert that this metadata is correctly set up for the compression params,
   2473  // this should have been done by enc_file.cc
   2474  JXL_ENSURE(metadata->m.xyb_encoded ==
   2475             (cparams.color_transform == ColorTransform::kXYB));
   2476 
   2477  if (frame_data.IsJPEG() && cparams.color_transform == ColorTransform::kXYB) {
   2478    return JXL_FAILURE("Can't add JPEG frame to XYB codestream");
   2479  }
   2480 
   2481  if (CanDoStreamingEncoding(cparams, frame_info, *metadata, frame_data)) {
   2482    return EncodeFrameStreaming(memory_manager, cparams, frame_info, metadata,
   2483                                frame_data, cms, pool, output_processor,
   2484                                aux_out);
   2485  } else {
   2486    return EncodeFrameOneShot(memory_manager, cparams, frame_info, metadata,
   2487                              frame_data, cms, pool, output_processor, aux_out);
   2488  }
   2489 }
   2490 
   2491 Status EncodeFrame(JxlMemoryManager* memory_manager,
   2492                   const CompressParams& cparams_orig,
   2493                   const FrameInfo& frame_info, const CodecMetadata* metadata,
   2494                   ImageBundle& ib, const JxlCmsInterface& cms,
   2495                   ThreadPool* pool, BitWriter* writer, AuxOut* aux_out) {
   2496  JxlEncoderChunkedFrameAdapter frame_data(ib.xsize(), ib.ysize(),
   2497                                           ib.extra_channels().size());
   2498  std::vector<uint8_t> color;
   2499  if (ib.IsJPEG()) {
   2500    frame_data.SetJPEGData(std::move(ib.jpeg_data));
   2501  } else {
   2502    uint32_t num_channels =
   2503        ib.IsGray() && frame_info.ib_needs_color_transform ? 1 : 3;
   2504    size_t stride = ib.xsize() * num_channels * 4;
   2505    color.resize(ib.ysize() * stride);
   2506    JXL_RETURN_IF_ERROR(ConvertToExternal(
   2507        ib, /*bits_per_sample=*/32, /*float_out=*/true, num_channels,
   2508        JXL_NATIVE_ENDIAN, stride, pool, color.data(), color.size(),
   2509        /*out_callback=*/{}, Orientation::kIdentity));
   2510    JxlPixelFormat format{num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
   2511    frame_data.SetFromBuffer(0, color.data(), color.size(), format);
   2512  }
   2513  for (size_t ec = 0; ec < ib.extra_channels().size(); ++ec) {
   2514    JxlPixelFormat ec_format{1, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
   2515    size_t ec_stride = ib.xsize() * 4;
   2516    std::vector<uint8_t> ec_data(ib.ysize() * ec_stride);
   2517    const ImageF* channel = &ib.extra_channels()[ec];
   2518    JXL_RETURN_IF_ERROR(ConvertChannelsToExternal(
   2519        &channel, 1,
   2520        /*bits_per_sample=*/32,
   2521        /*float_out=*/true, JXL_NATIVE_ENDIAN, ec_stride, pool, ec_data.data(),
   2522        ec_data.size(), /*out_callback=*/{}, Orientation::kIdentity));
   2523    frame_data.SetFromBuffer(1 + ec, ec_data.data(), ec_data.size(), ec_format);
   2524  }
   2525  FrameInfo fi = frame_info;
   2526  fi.origin = ib.origin;
   2527  fi.blend = ib.blend;
   2528  fi.blendmode = ib.blendmode;
   2529  fi.duration = ib.duration;
   2530  fi.timecode = ib.timecode;
   2531  fi.name = ib.name;
   2532  std::vector<uint8_t> output(64);
   2533  uint8_t* next_out = output.data();
   2534  size_t avail_out = output.size();
   2535  JxlEncoderOutputProcessorWrapper output_processor(memory_manager);
   2536  JXL_RETURN_IF_ERROR(output_processor.SetAvailOut(&next_out, &avail_out));
   2537  JXL_RETURN_IF_ERROR(EncodeFrame(memory_manager, cparams_orig, fi, metadata,
   2538                                  frame_data, cms, pool, &output_processor,
   2539                                  aux_out));
   2540  JXL_RETURN_IF_ERROR(output_processor.SetFinalizedPosition());
   2541  JXL_RETURN_IF_ERROR(output_processor.CopyOutput(output, next_out, avail_out));
   2542  JXL_RETURN_IF_ERROR(writer->AppendByteAligned(Bytes(output)));
   2543  return true;
   2544 }
   2545 
   2546 }  // namespace jxl