tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_patch_dictionary.cc (32854B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_patch_dictionary.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 #include <jxl/types.h>
     10 #include <sys/types.h>
     11 
     12 #include <algorithm>
     13 #include <atomic>
     14 #include <cstdint>
     15 #include <cstdlib>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "lib/jxl/base/common.h"
     20 #include "lib/jxl/base/compiler_specific.h"
     21 #include "lib/jxl/base/override.h"
     22 #include "lib/jxl/base/printf_macros.h"
     23 #include "lib/jxl/base/random.h"
     24 #include "lib/jxl/base/rect.h"
     25 #include "lib/jxl/base/status.h"
     26 #include "lib/jxl/dec_cache.h"
     27 #include "lib/jxl/dec_frame.h"
     28 #include "lib/jxl/enc_ans.h"
     29 #include "lib/jxl/enc_aux_out.h"
     30 #include "lib/jxl/enc_cache.h"
     31 #include "lib/jxl/enc_debug_image.h"
     32 #include "lib/jxl/enc_dot_dictionary.h"
     33 #include "lib/jxl/enc_frame.h"
     34 #include "lib/jxl/frame_header.h"
     35 #include "lib/jxl/image.h"
     36 #include "lib/jxl/image_bundle.h"
     37 #include "lib/jxl/image_ops.h"
     38 #include "lib/jxl/pack_signed.h"
     39 #include "lib/jxl/patch_dictionary_internal.h"
     40 
     41 namespace jxl {
     42 
     // Reference-frame slot used for the patch frame: this value is stored in
     // every PatchReferencePosition::ref built by FindBestPatchDictionary and is
     // also the id passed to RoundtripPatchFrame there.
     43 static constexpr size_t kPatchFrameReferenceId = 3;
     44 
     45 // static
     46 Status PatchDictionaryEncoder::Encode(const PatchDictionary& pdic,
     47                                      BitWriter* writer, LayerType layer,
     48                                      AuxOut* aux_out) {
     49  JXL_ENSURE(pdic.HasAny());
     50  JxlMemoryManager* memory_manager = writer->memory_manager();
     51  std::vector<std::vector<Token>> tokens(1);
     52 
     53  auto add_num = [&](int context, size_t num) {
     54    tokens[0].emplace_back(context, num);
     55  };
     56  size_t num_ref_patch = 0;
     57  for (size_t i = 0; i < pdic.positions_.size();) {
     58    size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
     59    while (i < pdic.positions_.size() &&
     60           pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
     61      i++;
     62    }
     63    num_ref_patch++;
     64  }
     65  add_num(kNumRefPatchContext, num_ref_patch);
     66  size_t blend_pos = 0;
     67  size_t blending_stride = pdic.blendings_stride_;
     68  // blending_stride == num_ec + 1; num_ec > 1 =>
     69  bool choose_alpha = (blending_stride > 1 + 1);
     70  for (size_t i = 0; i < pdic.positions_.size();) {
     71    size_t i_start = i;
     72    size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
     73    const auto& ref_pos = pdic.ref_positions_[ref_pos_idx];
     74    while (i < pdic.positions_.size() &&
     75           pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
     76      i++;
     77    }
     78    size_t num = i - i_start;
     79    JXL_ENSURE(num > 0);
     80    add_num(kReferenceFrameContext, ref_pos.ref);
     81    add_num(kPatchReferencePositionContext, ref_pos.x0);
     82    add_num(kPatchReferencePositionContext, ref_pos.y0);
     83    add_num(kPatchSizeContext, ref_pos.xsize - 1);
     84    add_num(kPatchSizeContext, ref_pos.ysize - 1);
     85    add_num(kPatchCountContext, num - 1);
     86    for (size_t j = i_start; j < i; j++) {
     87      const PatchPosition& pos = pdic.positions_[j];
     88      if (j == i_start) {
     89        add_num(kPatchPositionContext, pos.x);
     90        add_num(kPatchPositionContext, pos.y);
     91      } else {
     92        add_num(kPatchOffsetContext,
     93                PackSigned(pos.x - pdic.positions_[j - 1].x));
     94        add_num(kPatchOffsetContext,
     95                PackSigned(pos.y - pdic.positions_[j - 1].y));
     96      }
     97      for (size_t j = 0; j < blending_stride; ++j, ++blend_pos) {
     98        const PatchBlending& info = pdic.blendings_[blend_pos];
     99        add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode));
    100        if (UsesAlpha(info.mode) && choose_alpha) {
    101          add_num(kPatchAlphaChannelContext, info.alpha_channel);
    102        }
    103        if (UsesClamp(info.mode)) {
    104          add_num(kPatchClampContext, TO_JXL_BOOL(info.clamp));
    105        }
    106      }
    107    }
    108  }
    109 
    110  EntropyEncodingData codes;
    111  std::vector<uint8_t> context_map;
    112  JXL_ASSIGN_OR_RETURN(
    113      size_t cost,
    114      BuildAndEncodeHistograms(memory_manager, HistogramParams(),
    115                               kNumPatchDictionaryContexts, tokens, &codes,
    116                               &context_map, writer, layer, aux_out));
    117  (void)cost;
    118  JXL_RETURN_IF_ERROR(
    119      WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out));
    120  return true;
    121 }
    122 
    123 // static
    124 Status PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic,
    125                                            Image3F* opsin) {
    126  // TODO(veluca): this can likely be optimized knowing it runs on full images.
    127  for (size_t y = 0; y < opsin->ysize(); y++) {
    128    float* JXL_RESTRICT rows[3] = {
    129        opsin->PlaneRow(0, y),
    130        opsin->PlaneRow(1, y),
    131        opsin->PlaneRow(2, y),
    132    };
    133    size_t blending_stride = pdic.blendings_stride_;
    134    for (size_t pos_idx : pdic.GetPatchesForRow(y)) {
    135      const size_t blending_idx = pos_idx * blending_stride;
    136      const PatchPosition& pos = pdic.positions_[pos_idx];
    137      const PatchReferencePosition& ref_pos =
    138          pdic.ref_positions_[pos.ref_pos_idx];
    139      const PatchBlendMode mode = pdic.blendings_[blending_idx].mode;
    140      size_t by = pos.y;
    141      size_t bx = pos.x;
    142      size_t xsize = ref_pos.xsize;
    143      JXL_ENSURE(y >= by);
    144      JXL_ENSURE(y < by + ref_pos.ysize);
    145      size_t iy = y - by;
    146      size_t ref = ref_pos.ref;
    147      const float* JXL_RESTRICT ref_rows[3] = {
    148          pdic.reference_frames_->at(ref).frame->color()->ConstPlaneRow(
    149              0, ref_pos.y0 + iy) +
    150              ref_pos.x0,
    151          pdic.reference_frames_->at(ref).frame->color()->ConstPlaneRow(
    152              1, ref_pos.y0 + iy) +
    153              ref_pos.x0,
    154          pdic.reference_frames_->at(ref).frame->color()->ConstPlaneRow(
    155              2, ref_pos.y0 + iy) +
    156              ref_pos.x0,
    157      };
    158      for (size_t ix = 0; ix < xsize; ix++) {
    159        for (size_t c = 0; c < 3; c++) {
    160          if (mode == PatchBlendMode::kAdd) {
    161            rows[c][bx + ix] -= ref_rows[c][ix];
    162          } else if (mode == PatchBlendMode::kReplace) {
    163            rows[c][bx + ix] = 0;
    164          } else if (mode == PatchBlendMode::kNone) {
    165            // Nothing to do.
    166          } else {
    167            return JXL_UNREACHABLE("blending mode %u not yet implemented",
    168                                   static_cast<uint32_t>(mode));
    169          }
    170        }
    171      }
    172    }
    173  }
    174  return true;
    175 }
    176 
    177 namespace {
    178 
    // Per-colorspace constants used by the patch heuristics: a quantization
    // step and a distance weight for each of the three channels.
    struct PatchColorspaceInfo {
      // Divisor applied to a channel value before it is truncated to an integer.
      float kChannelDequant[3];
      // Weight of each channel's absolute difference in is_similar_v().
      float kChannelWeights[3];

      // Selects the constant set for XYB input (`is_xyb` true) or for the
      // non-XYB case (`is_xyb` false).
      explicit PatchColorspaceInfo(bool is_xyb) {
        if (is_xyb) {
          kChannelDequant[0] = 0.01615;
          kChannelDequant[1] = 0.08875;
          kChannelDequant[2] = 0.1922;
          kChannelWeights[0] = 30.0;
          kChannelWeights[1] = 3.0;
          kChannelWeights[2] = 1.0;
        } else {
          kChannelDequant[0] = 20.0f / 255;
          kChannelDequant[1] = 22.0f / 255;
          kChannelDequant[2] = 20.0f / 255;
          kChannelWeights[0] = 0.017 * 255;
          kChannelWeights[1] = 0.02 * 255;
          kChannelWeights[2] = 0.017 * 255;
        }
      }

      // Returns `val` expressed in quantization steps of channel `c` (0..2).
      float ScaleForQuantization(float val, size_t c) const {
        return val / kChannelDequant[c];
      }

      // Returns `val` quantized for channel `c`, rounding toward zero.
      int Quantize(float val, size_t c) const {
        return static_cast<int>(truncf(ScaleForQuantization(val, c)));
      }

      // Returns true when the channel-weighted sum of absolute differences
      // between the pixels `v1` and `v2` does not exceed `threshold`.
      bool is_similar_v(const float v1[3], const float v2[3],
                        float threshold) const {
        float distance = 0;
        for (size_t c = 0; c < 3; c++) {
          distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c];
        }
        return distance <= threshold;
      }
    };
    217 
    218 StatusOr<std::vector<PatchInfo>> FindTextLikePatches(
    219    const CompressParams& cparams, const Image3F& opsin,
    220    const PassesEncoderState* JXL_RESTRICT state, ThreadPool* pool,
    221    AuxOut* aux_out, bool is_xyb) {
         // Heuristic search for small, repeated details (for example glyphs) that
         // sit on near-uniform backgrounds in `opsin`.
         //
         // Returns one PatchInfo per distinct quantized patch: `.first` is the
         // QuantizedPatch and `.second` lists the top-left (x, y) of each
         // occurrence. The result is empty when patches are switched off, when
         // the override check after the screenshot pass fails, or when no
         // candidate survives the filters below.
         //
         // `cparams` is only used here for debug dumps; the patches override is
         // read from `state->cparams`. `pool` parallelizes only the first
         // (screenshot-likeness) pass. `aux_out` is not used in this function.
         // `is_xyb` selects the constants in PatchColorspaceInfo and which debug
         // dump routine is called.
    222  std::vector<PatchInfo> info;
    223  if (state->cparams.patches == Override::kOff) return info;
    224  const auto& frame_dim = state->shared.frame_dim;
    225  JxlMemoryManager* memory_manager = opsin.memory_manager();
    226 
    227  PatchColorspaceInfo pci(is_xyb);
         // Threshold used by is_similar() when growing background regions.
    228  float kSimilarThreshold = 0.8f;
    229 
         // Weighted comparison (PatchColorspaceInfo::is_similar_v) of the pixels
         // at p1 and p2 in the three planes `rows`. Each pair is (x, y); `stride`
         // is the distance between consecutive rows, in elements.
    230  auto is_similar_impl = [&pci](std::pair<uint32_t, uint32_t> p1,
    231                                std::pair<uint32_t, uint32_t> p2,
    232                                const float* JXL_RESTRICT rows[3],
    233                                size_t stride, float threshold) {
    234    float v1[3];
    235    float v2[3];
    236    for (size_t c = 0; c < 3; c++) {
    237      v1[c] = rows[c][p1.second * stride + p1.first];
    238      v2[c] = rows[c][p2.second * stride + p2.first];
    239    }
    240    return pci.is_similar_v(v1, v2, threshold);
    241  };
    242 
         // Set from the pool workers running process_row, hence atomic.
    243  std::atomic<bool> has_screenshot_areas{false};
         // Row-0 pointers of each plane; pixel (x, y) is addressed as
         // y * opsin_stride + x from these bases.
    244  const size_t opsin_stride = opsin.PixelsPerRow();
    245  const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0),
    246                                             opsin.ConstPlaneRow(1, 0),
    247                                             opsin.ConstPlaneRow(2, 0)};
    248 
         // True when the two opsin pixels differ by at most 1e-4 in every channel.
    249  auto is_same = [&opsin_rows, opsin_stride](std::pair<uint32_t, uint32_t> p1,
    250                                             std::pair<uint32_t, uint32_t> p2) {
    251    for (auto& opsin_row : opsin_rows) {
    252      float v1 = opsin_row[p1.second * opsin_stride + p1.first];
    253      float v2 = opsin_row[p2.second * opsin_stride + p2.first];
    254      if (std::fabs(v1 - v2) > 1e-4) {
    255        return false;
    256      }
    257    }
    258    return true;
    259  };
    260 
         // Looser, weighted comparison of two opsin pixels (kSimilarThreshold).
    261  auto is_similar = [&](std::pair<uint32_t, uint32_t> p1,
    262                        std::pair<uint32_t, uint32_t> p2) {
    263    return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold);
    264  };
    265 
         // kPatchSide: side of the aligned squares tested for flatness.
         // kExtraSide: margin added on every side of a square for the
         // neighborhood vote in process_row.
    266  constexpr int64_t kPatchSide = 4;
    267  constexpr int64_t kExtraSide = 4;
    268 
    269  // Look for kPatchSide size squares, naturally aligned, that all have the same
    270  // pixel values.
         // The map has one byte per square and is allocated with DivCeil, but the
         // scan below only visits squares that fit entirely inside the frame, so
         // squares cut by the right/bottom edge stay 0.
    271  JXL_ASSIGN_OR_RETURN(
    272      ImageB is_screenshot_like,
    273      ImageB::Create(memory_manager, DivCeil(frame_dim.xsize, kPatchSide),
    274                     DivCeil(frame_dim.ysize, kPatchSide)));
    275  ZeroFillImage(&is_screenshot_like);
    276  uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0);
    277  const size_t screenshot_stride = is_screenshot_like.PixelsPerRow();
         // Processes one row of squares; `y` is the square row index, not a pixel
         // row.
    278  const auto process_row = [&](const uint32_t y,
    279                               size_t /* thread */) -> Status {
    280    for (uint64_t x = 0; x < frame_dim.xsize / kPatchSide; x++) {
             // Step 1: every pixel of the square must match its top-left pixel.
    281      bool all_same = true;
    282      for (size_t iy = 0; iy < static_cast<size_t>(kPatchSide); iy++) {
    283        for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) {
    284          size_t cx = x * kPatchSide + ix;
    285          size_t cy = y * kPatchSide + iy;
    286          if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) {
    287            all_same = false;
                   // Leaves only the ix loop; the remaining rows are still
                   // scanned, but all_same cannot become true again.
    288            break;
    289          }
    290        }
    291      }
    292      if (!all_same) continue;
             // Step 2: in the window grown by kExtraSide on every side, count the
             // in-frame pixels (num) and those matching the square's top-left
             // pixel (num_same).
    293      size_t num = 0;
    294      size_t num_same = 0;
    295      for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) {
    296        for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) {
    297          int64_t cx = x * kPatchSide + ix;
    298          int64_t cy = y * kPatchSide + iy;
    299          if (cx < 0 || static_cast<uint64_t>(cx) >= frame_dim.xsize ||  //
    300              cy < 0 || static_cast<uint64_t>(cy) >= frame_dim.ysize) {
    301            continue;
    302          }
    303          num++;
    304          if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++;
    305        }
    306      }
    307      // Too few equal pixels nearby.
             // (fewer than 7/8 of the in-frame window pixels match)
    308      if (num_same * 8 < num * 7) continue;
    309      screenshot_row[y * screenshot_stride + x] = 1;
    310      has_screenshot_areas = true;
    311    }
    312    return true;
    313  };
    314  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.ysize / kPatchSide,
    315                                ThreadPool::NoInit, process_row,
    316                                "IsScreenshotLike"));
    317 
    318  // TODO(veluca): also parallelize the rest of this function.
    319  if (WantDebugOutput(cparams)) {
    320    JXL_RETURN_IF_ERROR(
    321        DumpPlaneNormalized(cparams, "screenshot_like", is_screenshot_like));
    322  }
    323 
         // Neighborhood radius of both flood fills below; 1 visits the 8
         // surrounding pixels.
    324  constexpr int kSearchRadius = 1;
    325 
         // Stop here unless ApplyOverride allows patches given whether any
         // screenshot-like square was found.
         // NOTE(review): ApplyOverride is defined elsewhere; presumably an
         // explicit "on" forces the search and the default falls back to
         // has_screenshot_areas - confirm.
    326  if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) {
    327    return info;
    328  }
    329 
    330  // Search for "similar enough" pixels near the screenshot-like areas.
         // is_background marks every pixel reached by the fill; `background`
         // stores, for each such pixel, the color of the seed pixel it was reached
         // from.
    331  JXL_ASSIGN_OR_RETURN(
    332      ImageB is_background,
    333      ImageB::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
    334  ZeroFillImage(&is_background);
    335  JXL_ASSIGN_OR_RETURN(
    336      Image3F background,
    337      Image3F::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
    338  ZeroFillImage(&background);
         // Maximum Manhattan distance between a pixel and the seed it grows from.
    339  constexpr size_t kDistanceLimit = 50;
    340  float* JXL_RESTRICT background_rows[3] = {
    341      background.PlaneRow(0, 0),
    342      background.PlaneRow(1, 0),
    343      background.PlaneRow(2, 0),
    344  };
    345  const size_t background_stride = background.PixelsPerRow();
    346  uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0);
    347  const size_t is_background_stride = is_background.PixelsPerRow();
         // FIFO of (pixel, seed pixel) pairs. Entries are consumed by advancing
         // queue_front rather than by erasing from the vector.
    348  std::vector<
    349      std::pair<std::pair<uint32_t, uint32_t>, std::pair<uint32_t, uint32_t>>>
    350      queue;
    351  size_t queue_front = 0;
         // Seeds: every pixel inside a screenshot-like square, with itself as its
         // seed.
    352  for (size_t y = 0; y < frame_dim.ysize; y++) {
    353    for (size_t x = 0; x < frame_dim.xsize; x++) {
    354      if (!screenshot_row[screenshot_stride * (y / kPatchSide) +
    355                          (x / kPatchSide)])
    356        continue;
    357      queue.push_back({{x, y}, {x, y}});
    358    }
    359  }
    360  while (queue.size() != queue_front) {
    361    std::pair<uint32_t, uint32_t> cur = queue[queue_front].first;
    362    std::pair<uint32_t, uint32_t> src = queue[queue_front].second;
    363    queue_front++;
           // A pixel can be queued several times; only its first visit counts.
    364    if (is_background_row[cur.second * is_background_stride + cur.first])
    365      continue;
    366    is_background_row[cur.second * is_background_stride + cur.first] = 1;
           // The background color of `cur` is the opsin color of its seed.
    367    for (size_t c = 0; c < 3; c++) {
    368      background_rows[c][cur.second * background_stride + cur.first] =
    369          opsin_rows[c][src.second * opsin_stride + src.first];
    370    }
    371    for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
    372      for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
    373        if (dx == 0 && dy == 0) continue;
    374        int next_first = cur.first + dx;
    375        int next_second = cur.second + dy;
    376        if (next_first < 0 || next_second < 0 ||
    377            static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
    378            static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
    379          continue;
    380        }
               // Do not grow further than kDistanceLimit (Manhattan distance)
               // away from the seed.
    381        if (static_cast<uint32_t>(
    382                std::abs(next_first - static_cast<int>(src.first)) +
    383                std::abs(next_second - static_cast<int>(src.second))) >
    384            kDistanceLimit) {
    385          continue;
    386        }
    387        std::pair<uint32_t, uint32_t> next{next_first, next_second};
               // Grow into neighbors similar to the seed. A neighbor that lies in
               // a screenshot-like square must additionally match the seed
               // exactly (is_same).
    388        if (is_similar(src, next)) {
    389          if (!screenshot_row[next.second / kPatchSide * screenshot_stride +
    390                              next.first / kPatchSide] ||
    391              is_same(src, next)) {
    392            if (!is_background_row[next.second * is_background_stride +
    393                                   next.first])
    394              queue.emplace_back(next, src);
    395          }
    396        }
    397      }
    398    }
    399  }
    400  queue.clear();
    401 
         // Debug-only image: pixels of accepted components are painted with a
         // value drawn from rng.UniformF(0.5, 1.0). It is only allocated when
         // debug output is requested.
    402  ImageF ccs;
    403  Rng rng(0);
    404  bool paint_ccs = false;
    405  if (WantDebugOutput(cparams)) {
    406    JXL_RETURN_IF_ERROR(
    407        DumpPlaneNormalized(cparams, "is_background", is_background));
    408    if (is_xyb) {
    409      JXL_RETURN_IF_ERROR(DumpXybImage(cparams, "background", background));
    410    } else {
    411      JXL_RETURN_IF_ERROR(DumpImage(cparams, "background", background));
    412    }
    413    JXL_ASSIGN_OR_RETURN(
    414        ccs, ImageF::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
    415    ZeroFillImage(&ccs);
    416    paint_ccs = true;
    417  }
    418 
         // kVerySimilarThreshold: used when comparing the background colors
         // around one component with each other (is_similar_b).
         // kHasSimilarThreshold: used when comparing opsin pixels near a
         // component with the component's reference background color.
    419  constexpr float kVerySimilarThreshold = 0.03f;
    420  constexpr float kHasSimilarThreshold = 0.03f;
    421 
    422  const float* JXL_RESTRICT const_background_rows[3] = {
    423      background_rows[0], background_rows[1], background_rows[2]};
         // Same weighted comparison as is_similar, but on the `background` image
         // and with kVerySimilarThreshold. Coordinates arrive as int pairs and are
         // converted to the uint32_t pairs is_similar_impl takes.
    424  auto is_similar_b = [&](std::pair<int, int> p1, std::pair<int, int> p2) {
    425    return is_similar_impl(p1, p2, const_background_rows, background_stride,
    426                           kVerySimilarThreshold);
    427  };
    428 
         // kMinPeak: smallest peak quantized magnitude a patch must reach.
         // kHasSimilarRadius: margin added around a component's bounding box when
         // looking for pixels close to the reference color.
    429  constexpr int kMinPeak = 2;
    430  constexpr int kHasSimilarRadius = 2;
    431 
    432  // Find small CC outside the "similar enough" areas, compute bounding boxes,
    433  // and run heuristics to exclude some patches.
    434  JXL_ASSIGN_OR_RETURN(
    435      ImageB visited,
    436      ImageB::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
    437  ZeroFillImage(&visited);
    438  uint8_t* JXL_RESTRICT visited_row = visited.Row(0);
    439  const size_t visited_stride = visited.PixelsPerRow();
    440  std::vector<std::pair<uint32_t, uint32_t>> cc;
    441  std::vector<std::pair<uint32_t, uint32_t>> stack;
    442  for (size_t y = 0; y < frame_dim.ysize; y++) {
    443    for (size_t x = 0; x < frame_dim.xsize; x++) {
             // Only non-background pixels start a component. Pixels already
             // visited by an earlier fill are not filtered out here: the fill
             // below pops them immediately, found_border stays false, and the
             // candidate is dropped by the check after the fill.
    444      if (is_background_row[y * is_background_stride + x]) continue;
    445      cc.clear();
    446      stack.clear();
    447      stack.emplace_back(x, y);
             // Bounding box of the component, inclusive on all sides.
    448      size_t min_x = x;
    449      size_t max_x = x;
    450      size_t min_y = y;
    451      size_t max_y = y;
             // First background pixel found next to the component; its stored
             // background color becomes the component's reference color.
    452      std::pair<uint32_t, uint32_t> reference;
    453      bool found_border = false;
    454      bool all_similar = true;
             // Stack-based flood fill over non-background neighbors.
    455      while (!stack.empty()) {
    456        std::pair<uint32_t, uint32_t> cur = stack.back();
    457        stack.pop_back();
    458        if (visited_row[cur.second * visited_stride + cur.first]) continue;
    459        visited_row[cur.second * visited_stride + cur.first] = 1;
    460        if (cur.first < min_x) min_x = cur.first;
    461        if (cur.first > max_x) max_x = cur.first;
    462        if (cur.second < min_y) min_y = cur.second;
    463        if (cur.second > max_y) max_y = cur.second;
               // The pixel list is only needed for the debug painting.
    464        if (paint_ccs) {
    465          cc.push_back(cur);
    466        }
    467        for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
    468          for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
    469            if (dx == 0 && dy == 0) continue;
    470            int next_first = static_cast<int32_t>(cur.first) + dx;
    471            int next_second = static_cast<int32_t>(cur.second) + dy;
    472            if (next_first < 0 || next_second < 0 ||
    473                static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
    474                static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
    475              continue;
    476            }
    477            std::pair<uint32_t, uint32_t> next{next_first, next_second};
                   // Non-background neighbors extend the component. Background
                   // neighbors must all look alike in `background`; the first one
                   // seen is the reference they are compared against.
    478            if (!is_background_row[next.second * is_background_stride +
    479                                   next.first]) {
    480              stack.push_back(next);
    481            } else {
    482              if (!found_border) {
    483                reference = next;
    484                found_border = true;
    485              } else {
    486                if (!is_similar_b(next, reference)) all_similar = false;
    487              }
    488            }
    489          }
    490        }
    491      }
             // Reject components with no background neighbor, with inconsistent
             // surrounding background, or whose bounding box is wider or taller
             // than kMaxPatchSize pixels.
    492      if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize ||
    493          max_y - min_y >= kMaxPatchSize) {
    494        continue;
    495      }
    496      size_t bpos = background_stride * reference.second + reference.first;
    497      float ref[3] = {background_rows[0][bpos], background_rows[1][bpos],
    498                      background_rows[2][bpos]};
             // Require at least one opsin pixel close to the reference color
             // inside the bounding box grown by kHasSimilarRadius (clamped to the
             // frame). The scan does not stop at the first hit.
    499      bool has_similar = false;
    500      for (size_t iy = std::max<int>(
    501               static_cast<int32_t>(min_y) - kHasSimilarRadius, 0);
    502           iy < std::min(max_y + kHasSimilarRadius + 1, frame_dim.ysize);
    503           iy++) {
    504        for (size_t ix = std::max<int>(
    505                 static_cast<int32_t>(min_x) - kHasSimilarRadius, 0);
    506             ix < std::min(max_x + kHasSimilarRadius + 1, frame_dim.xsize);
    507             ix++) {
    508          size_t opos = opsin_stride * iy + ix;
    509          float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos],
    510                         opsin_rows[2][opos]};
    511          if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) {
    512            has_similar = true;
    513          }
    514        }
    515      }
    516      if (!has_similar) continue;
             // Tentatively add the patch; its single occurrence is the top-left
             // corner of the bounding box.
    517      info.emplace_back();
    518      info.back().second.emplace_back(min_x, min_y);
    519      QuantizedPatch& patch = info.back().first;
    520      patch.xsize = max_x - min_x + 1;
    521      patch.ysize = max_y - min_y + 1;
             // Store the whole bounding box as a residual against the reference
             // color, both as float (fpixels) and quantized (pixels), and track
             // the largest quantized magnitude. Channels are visited in the
             // order 1, 0, 2.
    522      int max_value = 0;
    523      for (size_t c : {1, 0, 2}) {
    524        for (size_t iy = min_y; iy <= max_y; iy++) {
    525          for (size_t ix = min_x; ix <= max_x; ix++) {
    526            size_t offset = (iy - min_y) * patch.xsize + ix - min_x;
    527            patch.fpixels[c][offset] =
    528                opsin_rows[c][iy * opsin_stride + ix] - ref[c];
    529            int val = pci.Quantize(patch.fpixels[c][offset], c);
    530            patch.pixels[c][offset] = val;
    531            if (std::abs(val) > max_value) max_value = std::abs(val);
    532          }
    533        }
    534      }
             // Drop patches whose strongest quantized sample is below kMinPeak.
    535      if (max_value < kMinPeak) {
    536        info.pop_back();
    537        continue;
    538      }
    539      if (paint_ccs) {
    540        float cc_color = rng.UniformF(0.5, 1.0);
    541        for (std::pair<uint32_t, uint32_t> p : cc) {
    542          ccs.Row(p.second)[p.first] = cc_color;
    543        }
    544      }
    545    }
    546  }
    547 
    548  if (paint_ccs) {
    549    JXL_ENSURE(WantDebugOutput(cparams));
    550    JXL_RETURN_IF_ERROR(DumpPlaneNormalized(cparams, "ccs", ccs));
    551  }
    552  if (info.empty()) {
    553    return info;
    554  }
    555 
    556  // Remove duplicates.
         // Sorting is meant to bring equal patches next to each other. Each group
         // of equal patches is merged into one entry holding all occurrence
         // positions, and the vector is compacted in place; `unique` is the slot
         // currently being filled. A slot is only kept when it collected at least
         // kMinPatchOccurrences positions, otherwise the next patch overwrites it.
    557  constexpr size_t kMinPatchOccurrences = 2;
    558  std::sort(info.begin(), info.end());
    559  size_t unique = 0;
    560  for (size_t i = 1; i < info.size(); i++) {
    561    if (info[i].first == info[unique].first) {
    562      info[unique].second.insert(info[unique].second.end(),
    563                                 info[i].second.begin(), info[i].second.end());
    564    } else {
    565      if (info[unique].second.size() >= kMinPatchOccurrences) {
    566        unique++;
    567      }
    568      info[unique] = info[i];
    569    }
    570  }
         // The last group is still pending in slot `unique`; keep it only if it
         // occurs often enough.
    571  if (info[unique].second.size() >= kMinPatchOccurrences) {
    572    unique++;
    573  }
    574  info.resize(unique);
    575 
         // Area, in pixels, of the largest surviving patch.
    576  size_t max_patch_size = 0;
    577 
    578  for (const auto& patch : info) {
    579    size_t pixels = patch.first.xsize * patch.first.ysize;
    580    if (pixels > max_patch_size) max_patch_size = pixels;
    581  }
    582 
    583  // don't use patches if all patches are smaller than this
    584  constexpr size_t kMinMaxPatchSize = 20;
    585  if (max_patch_size < kMinMaxPatchSize) {
    586    info.clear();
    587  }
    588 
    589  return info;
    590 }
    591 
    592 }  // namespace
    593 
    594 Status FindBestPatchDictionary(const Image3F& opsin,
    595                               PassesEncoderState* JXL_RESTRICT state,
    596                               const JxlCmsInterface& cms, ThreadPool* pool,
    597                               AuxOut* aux_out, bool is_xyb) {
    598  JXL_ASSIGN_OR_RETURN(
    599      std::vector<PatchInfo> info,
    600      FindTextLikePatches(state->cparams, opsin, state, pool, aux_out, is_xyb));
    601  JxlMemoryManager* memory_manager = opsin.memory_manager();
    602 
    603  // TODO(veluca): this doesn't work if both dots and patches are enabled.
    604  // For now, since dots and patches are not likely to occur in the same kind of
    605  // images, disable dots if some patches were found.
    606  if (info.empty() &&
    607      ApplyOverride(
    608          state->cparams.dots,
    609          state->cparams.speed_tier <= SpeedTier::kSquirrel &&
    610              state->cparams.butteraugli_distance >= kMinButteraugliForDots &&
    611              !state->cparams.disable_perceptual_optimizations)) {
    612    Rect rect(0, 0, state->shared.frame_dim.xsize,
    613              state->shared.frame_dim.ysize);
    614    JXL_ASSIGN_OR_RETURN(info,
    615                         FindDotDictionary(state->cparams, opsin, rect,
    616                                           state->shared.cmap.base(), pool));
    617  }
    618 
    619  if (info.empty()) return true;
    620 
    621  std::sort(
    622      info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) {
    623        return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize;
    624      });
    625 
    626  size_t max_x_size = 0;
    627  size_t max_y_size = 0;
    628  size_t total_pixels = 0;
    629 
    630  for (const auto& patch : info) {
    631    size_t pixels = patch.first.xsize * patch.first.ysize;
    632    if (max_x_size < patch.first.xsize) max_x_size = patch.first.xsize;
    633    if (max_y_size < patch.first.ysize) max_y_size = patch.first.ysize;
    634    total_pixels += pixels;
    635  }
    636 
    637  // Bin-packing & conversion of patches.
    638  constexpr float kBinPackingSlackness = 1.05f;
    639  size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels));
    640  size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels));
    641  std::vector<std::pair<size_t, size_t>> ref_positions(info.size());
    642  // TODO(veluca): allow partial overlaps of patches that have the same pixels.
    643  size_t max_y = 0;
    644  do {
    645    max_y = 0;
    646    // Increase packed image size.
    647    ref_xsize = ref_xsize * kBinPackingSlackness + 1;
    648    ref_ysize = ref_ysize * kBinPackingSlackness + 1;
    649 
    650    JXL_ASSIGN_OR_RETURN(ImageB occupied,
    651                         ImageB::Create(memory_manager, ref_xsize, ref_ysize));
    652    ZeroFillImage(&occupied);
    653    uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0);
    654    size_t occupied_stride = occupied.PixelsPerRow();
    655 
    656    bool success = true;
    657    // For every patch...
    658    for (size_t patch = 0; patch < info.size(); patch++) {
    659      size_t x0 = 0;
    660      size_t y0 = 0;
    661      size_t xsize = info[patch].first.xsize;
    662      size_t ysize = info[patch].first.ysize;
    663      bool found = false;
    664      // For every possible start position ...
    665      for (; y0 + ysize <= ref_ysize; y0++) {
    666        x0 = 0;
    667        for (; x0 + xsize <= ref_xsize; x0++) {
    668          bool has_occupied_pixel = false;
    669          size_t x = x0;
    670          // Check if it is possible to place the patch in this position in the
    671          // reference frame.
    672          for (size_t y = y0; y < y0 + ysize; y++) {
    673            x = x0;
    674            for (; x < x0 + xsize; x++) {
    675              if (occupied_rows[y * occupied_stride + x]) {
    676                has_occupied_pixel = true;
    677                break;
    678              }
    679            }
    680          }  // end of positioning check
    681          if (!has_occupied_pixel) {
    682            found = true;
    683            break;
    684          }
    685          x0 = x;  // Jump to next pixel after the occupied one.
    686        }
    687        if (found) break;
    688      }  // end of start position checking
    689 
    690      // We didn't find a possible position: repeat from the beginning with a
    691      // larger reference frame size.
    692      if (!found) {
    693        success = false;
    694        break;
    695      }
    696 
    697      // We found a position: mark the corresponding positions in the reference
    698      // image as used.
    699      ref_positions[patch] = {x0, y0};
    700      for (size_t y = y0; y < y0 + ysize; y++) {
    701        for (size_t x = x0; x < x0 + xsize; x++) {
    702          occupied_rows[y * occupied_stride + x] = JXL_TRUE;
    703        }
    704      }
    705      max_y = std::max(max_y, y0 + ysize);
    706    }
    707 
    708    if (success) break;
    709  } while (true);
    710 
    711  JXL_ENSURE(ref_ysize >= max_y);
    712 
    713  ref_ysize = max_y;
    714 
    715  JXL_ASSIGN_OR_RETURN(Image3F reference_frame,
    716                       Image3F::Create(memory_manager, ref_xsize, ref_ysize));
    717  // TODO(veluca): figure out a better way to fill the image.
    718  ZeroFillImage(&reference_frame);
    719  std::vector<PatchPosition> positions;
    720  std::vector<PatchReferencePosition> pref_positions;
    721  std::vector<PatchBlending> blendings;
    722  float* JXL_RESTRICT ref_rows[3] = {
    723      reference_frame.PlaneRow(0, 0),
    724      reference_frame.PlaneRow(1, 0),
    725      reference_frame.PlaneRow(2, 0),
    726  };
    727  size_t ref_stride = reference_frame.PixelsPerRow();
    728  size_t num_ec = state->shared.metadata->m.num_extra_channels;
    729 
    730  for (size_t i = 0; i < info.size(); i++) {
    731    PatchReferencePosition ref_pos;
    732    ref_pos.xsize = info[i].first.xsize;
    733    ref_pos.ysize = info[i].first.ysize;
    734    ref_pos.x0 = ref_positions[i].first;
    735    ref_pos.y0 = ref_positions[i].second;
    736    ref_pos.ref = kPatchFrameReferenceId;
    737    for (size_t y = 0; y < ref_pos.ysize; y++) {
    738      for (size_t x = 0; x < ref_pos.xsize; x++) {
    739        for (size_t c = 0; c < 3; c++) {
    740          ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] =
    741              info[i].first.fpixels[c][y * ref_pos.xsize + x];
    742        }
    743      }
    744    }
    745    for (const auto& pos : info[i].second) {
    746      JXL_DEBUG_V(4, "Patch %" PRIuS "x%" PRIuS " at position %u,%u",
    747                  ref_pos.xsize, ref_pos.ysize, pos.first, pos.second);
    748      positions.emplace_back(
    749          PatchPosition{pos.first, pos.second, pref_positions.size()});
    750      // Add blending for color channels, ignore other channels.
    751      blendings.push_back({PatchBlendMode::kAdd, 0, false});
    752      for (size_t j = 0; j < num_ec; ++j) {
    753        blendings.push_back({PatchBlendMode::kNone, 0, false});
    754      }
    755    }
    756    pref_positions.emplace_back(ref_pos);
    757  }
    758 
    759  CompressParams cparams = state->cparams;
    760  // Recursive application of patches could create very weird issues.
    761  cparams.patches = Override::kOff;
    762 
    763  if (WantDebugOutput(cparams)) {
    764    if (is_xyb) {
    765      JXL_RETURN_IF_ERROR(
    766          DumpXybImage(cparams, "patch_reference", reference_frame));
    767    } else {
    768      JXL_RETURN_IF_ERROR(
    769          DumpImage(cparams, "patch_reference", reference_frame));
    770    }
    771  }
    772 
    773  JXL_RETURN_IF_ERROR(RoundtripPatchFrame(&reference_frame, state,
    774                                          kPatchFrameReferenceId, cparams, cms,
    775                                          pool, aux_out, /*subtract=*/true));
    776 
    777  // TODO(veluca): this assumes that applying patches is commutative, which is
    778  // not true for all blending modes. This code only produces kAdd patches, so
    779  // this works out.
    780  PatchDictionaryEncoder::SetPositions(
    781      &state->shared.image_features.patches, std::move(positions),
    782      std::move(pref_positions), std::move(blendings), num_ec + 1);
    783  return true;
    784 }
    785 
    786 Status RoundtripPatchFrame(Image3F* reference_frame,
    787                           PassesEncoderState* JXL_RESTRICT state, int idx,
    788                           CompressParams& cparams, const JxlCmsInterface& cms,
    789                           ThreadPool* pool, AuxOut* aux_out, bool subtract) {
    790  JxlMemoryManager* memory_manager = state->memory_manager();
    791  FrameInfo patch_frame_info;
    792  cparams.resampling = 1;
    793  cparams.ec_resampling = 1;
    794  cparams.dots = Override::kOff;
    795  cparams.noise = Override::kOff;
    796  cparams.modular_mode = true;
    797  cparams.responsive = 0;
    798  cparams.progressive_dc = 0;
    799  cparams.progressive_mode = Override::kOff;
    800  cparams.qprogressive_mode = Override::kOff;
    801  // Use gradient predictor and not Predictor::Best.
    802  cparams.options.predictor = Predictor::Gradient;
    803  patch_frame_info.save_as_reference = idx;  // always saved.
    804  patch_frame_info.frame_type = FrameType::kReferenceOnly;
    805  patch_frame_info.save_before_color_transform = true;
    806  ImageBundle ib(memory_manager, &state->shared.metadata->m);
    807  // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is
    808  // no simple way to express that yet.
    809  patch_frame_info.ib_needs_color_transform = false;
    810  JXL_RETURN_IF_ERROR(ib.SetFromImage(
    811      std::move(*reference_frame), state->shared.metadata->m.color_encoding));
    812  if (!ib.metadata()->extra_channel_info.empty()) {
    813    // Add placeholder extra channels to the patch image: patch encoding does
    814    // not yet support extra channels, but the codec expects that the amount of
    815    // extra channels in frames matches that in the metadata of the codestream.
    816    std::vector<ImageF> extra_channels;
    817    extra_channels.reserve(ib.metadata()->extra_channel_info.size());
    818    for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) {
    819      JXL_ASSIGN_OR_RETURN(
    820          ImageF ch, ImageF::Create(memory_manager, ib.xsize(), ib.ysize()));
    821      extra_channels.emplace_back(std::move(ch));
    822      // Must initialize the image with data to not affect blending with
    823      // uninitialized memory.
    824      // TODO(lode): patches must copy and use the real extra channels instead.
    825      ZeroFillImage(&extra_channels.back());
    826    }
    827    JXL_RETURN_IF_ERROR(ib.SetExtraChannels(std::move(extra_channels)));
    828  }
    829  auto special_frame = jxl::make_unique<BitWriter>(memory_manager);
    830  AuxOut patch_aux_out;
    831  JXL_RETURN_IF_ERROR(EncodeFrame(
    832      memory_manager, cparams, patch_frame_info, state->shared.metadata, ib,
    833      cms, pool, special_frame.get(), aux_out ? &patch_aux_out : nullptr));
    834  if (aux_out) {
    835    for (const auto& l : patch_aux_out.layers) {
    836      aux_out->layer(LayerType::Dictionary).Assimilate(l);
    837    }
    838  }
    839  const Span<const uint8_t> encoded = special_frame->GetSpan();
    840  state->special_frames.emplace_back(std::move(special_frame));
    841  if (subtract) {
    842    ImageBundle decoded(memory_manager, &state->shared.metadata->m);
    843    PassesDecoderState dec_state(memory_manager);
    844    JXL_RETURN_IF_ERROR(dec_state.output_encoding_info.SetFromMetadata(
    845        *state->shared.metadata));
    846    const uint8_t* frame_start = encoded.data();
    847    size_t encoded_size = encoded.size();
    848    JXL_RETURN_IF_ERROR(DecodeFrame(&dec_state, pool, frame_start, encoded_size,
    849                                    /*frame_header=*/nullptr, &decoded,
    850                                    *state->shared.metadata));
    851    frame_start += decoded.decoded_bytes();
    852    encoded_size -= decoded.decoded_bytes();
    853    size_t ref_xsize =
    854        dec_state.shared_storage.reference_frames[idx].frame->color()->xsize();
    855    // if the frame itself uses patches, we need to decode another frame
    856    if (!ref_xsize) {
    857      JXL_RETURN_IF_ERROR(DecodeFrame(
    858          &dec_state, pool, frame_start, encoded_size,
    859          /*frame_header=*/nullptr, &decoded, *state->shared.metadata));
    860    }
    861    JXL_ENSURE(encoded_size == 0);
    862    state->shared.reference_frames[idx] =
    863        std::move(dec_state.shared_storage.reference_frames[idx]);
    864  } else {
    865    *state->shared.reference_frames[idx].frame = std::move(ib);
    866  }
    867  return true;
    868 }
    869 
    870 }  // namespace jxl