tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dec_patch_dictionary.cc (13175B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_patch_dictionary.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 #include <sys/types.h>
     10 
     11 #include <algorithm>
     12 #include <cstdint>
     13 #include <cstdlib>
     14 #include <utility>
     15 #include <vector>
     16 
     17 #include "lib/jxl/base/printf_macros.h"
     18 #include "lib/jxl/base/status.h"
     19 #include "lib/jxl/blending.h"
     20 #include "lib/jxl/common.h"  // kMaxNumReferenceFrames
     21 #include "lib/jxl/dec_ans.h"
     22 #include "lib/jxl/image.h"
     23 #include "lib/jxl/image_bundle.h"
     24 #include "lib/jxl/pack_signed.h"
     25 #include "lib/jxl/patch_dictionary_internal.h"
     26 
     27 namespace jxl {
     28 
     29 Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br,
     30                               size_t xsize, size_t ysize,
     31                               size_t num_extra_channels,
     32                               bool* uses_extra_channels) {
     33  positions_.clear();
     34  blendings_stride_ = num_extra_channels + 1;
     35  std::vector<uint8_t> context_map;
     36  ANSCode code;
     37  JXL_RETURN_IF_ERROR(DecodeHistograms(
     38      memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map));
     39  JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder,
     40                       ANSSymbolReader::Create(&code, br));
     41 
     42  auto read_num = [&](size_t context) {
     43    size_t r = decoder.ReadHybridUint(context, br, context_map);
     44    return r;
     45  };
     46 
     47  size_t num_ref_patch = read_num(kNumRefPatchContext);
     48  // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8
     49  // bytes per size_t)
     50  const size_t num_pixels = xsize * ysize;
     51  const size_t max_ref_patches = 1024 + num_pixels / 4;
     52  const size_t max_patches = max_ref_patches * 4;
     53  const size_t max_blending_infos = max_patches * 4;
     54  if (num_ref_patch > max_ref_patches) {
     55    return JXL_FAILURE("Too many patches in dictionary");
     56  }
     57 
     58  size_t total_patches = 0;
     59  size_t next_size = 1;
     60 
     61  for (size_t id = 0; id < num_ref_patch; id++) {
     62    PatchReferencePosition ref_pos;
     63    ref_pos.ref = read_num(kReferenceFrameContext);
     64    if (ref_pos.ref >= kMaxNumReferenceFrames ||
     65        reference_frames_->at(ref_pos.ref).frame->xsize() == 0) {
     66      return JXL_FAILURE("Invalid reference frame ID");
     67    }
     68    if (!reference_frames_->at(ref_pos.ref).ib_is_in_xyb) {
     69      return JXL_FAILURE(
     70          "Patches cannot use frames saved post color transforms");
     71    }
     72    const ImageBundle& ib = *reference_frames_->at(ref_pos.ref).frame;
     73    ref_pos.x0 = read_num(kPatchReferencePositionContext);
     74    ref_pos.y0 = read_num(kPatchReferencePositionContext);
     75    ref_pos.xsize = read_num(kPatchSizeContext) + 1;
     76    ref_pos.ysize = read_num(kPatchSizeContext) + 1;
     77    if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) {
     78      return JXL_FAILURE("Invalid position specified in reference frame");
     79    }
     80    if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) {
     81      return JXL_FAILURE("Invalid position specified in reference frame");
     82    }
     83    size_t id_count = read_num(kPatchCountContext);
     84    if (id_count > max_patches) {
     85      return JXL_FAILURE("Too many patches in dictionary");
     86    }
     87    id_count++;
     88    total_patches += id_count;
     89    if (total_patches > max_patches) {
     90      return JXL_FAILURE("Too many patches in dictionary");
     91    }
     92    if (next_size < total_patches) {
     93      next_size *= 2;
     94      next_size = std::min<size_t>(next_size, max_patches);
     95    }
     96    if (next_size * blendings_stride_ > max_blending_infos) {
     97      return JXL_FAILURE("Too many patches in dictionary");
     98    }
     99    positions_.reserve(next_size);
    100    blendings_.reserve(next_size * blendings_stride_);
    101    bool choose_alpha = (num_extra_channels > 1);
    102    for (size_t i = 0; i < id_count; i++) {
    103      PatchPosition pos;
    104      pos.ref_pos_idx = ref_positions_.size();
    105      if (i == 0) {
    106        pos.x = read_num(kPatchPositionContext);
    107        pos.y = read_num(kPatchPositionContext);
    108      } else {
    109        ssize_t deltax = UnpackSigned(read_num(kPatchOffsetContext));
    110        if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) {
    111          return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS
    112                             " base x %" PRIdS " delta x)",
    113                             positions_.back().x, deltax);
    114        }
    115        pos.x = positions_.back().x + deltax;
    116        ssize_t deltay = UnpackSigned(read_num(kPatchOffsetContext));
    117        if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) {
    118          return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS
    119                             " base y %" PRIdS " delta y)",
    120                             positions_.back().y, deltay);
    121        }
    122        pos.y = positions_.back().y + deltay;
    123      }
    124      if (pos.x + ref_pos.xsize > xsize) {
    125        return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS
    126                           " > %" PRIuS,
    127                           pos.x, ref_pos.xsize, xsize);
    128      }
    129      if (pos.y + ref_pos.ysize > ysize) {
    130        return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS
    131                           " > %" PRIuS,
    132                           pos.y, ref_pos.ysize, ysize);
    133      }
    134      for (size_t j = 0; j < blendings_stride_; j++) {
    135        uint32_t blend_mode = read_num(kPatchBlendModeContext);
    136        if (blend_mode >= kNumPatchBlendModes) {
    137          return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode);
    138        }
    139        PatchBlending info;
    140        info.mode = static_cast<PatchBlendMode>(blend_mode);
    141        if (UsesAlpha(info.mode)) {
    142          *uses_extra_channels = true;
    143        }
    144        if (info.mode != PatchBlendMode::kNone && j > 0) {
    145          *uses_extra_channels = true;
    146        }
    147        if (UsesAlpha(info.mode) && choose_alpha) {
    148          info.alpha_channel = read_num(kPatchAlphaChannelContext);
    149          if (info.alpha_channel >= num_extra_channels) {
    150            return JXL_FAILURE(
    151                "Invalid alpha channel for blending: %u out of %u\n",
    152                info.alpha_channel, static_cast<uint32_t>(num_extra_channels));
    153          }
    154        } else {
    155          info.alpha_channel = 0;
    156        }
    157        if (UsesClamp(info.mode)) {
    158          info.clamp = static_cast<bool>(read_num(kPatchClampContext));
    159        } else {
    160          info.clamp = false;
    161        }
    162        blendings_.push_back(info);
    163      }
    164      positions_.emplace_back(pos);
    165    }
    166    ref_positions_.emplace_back(ref_pos);
    167  }
    168  positions_.shrink_to_fit();
    169 
    170  if (!decoder.CheckANSFinalState()) {
    171    return JXL_FAILURE("ANS checksum failure.");
    172  }
    173 
    174  ComputePatchTree();
    175  return true;
    176 }
    177 
    178 int PatchDictionary::GetReferences() const {
    179  int result = 0;
    180  for (const auto& ref_pos : ref_positions_) {
    181    result |= (1 << static_cast<int>(ref_pos.ref));
    182  }
    183  return result;
    184 }
    185 
    186 namespace {
    187 struct PatchInterval {
    188  size_t idx;
    189  size_t y0, y1;
    190 };
    191 }  // namespace
    192 
// Builds an interval tree over the row ranges covered by the patches, so that
// GetPatchesForRow() can find the patches intersecting a row without scanning
// all of positions_. It also fills num_patches_[y] with the number of patches
// covering row y. num_patches_ is sized to the largest patch end row.
//
// Every tree node stores the patches whose row range [y0, y1) contains the
// node's y_center. They are stored twice, both at indices
// [node.start, node.start + node.num):
//   * in sorted_patches_y0_, ordered by ascending y0;
//   * in sorted_patches_y1_, ordered by descending y1.
// Patches entirely above y_center go to left_child; patches entirely below go
// to right_child. A missing child is -1.
void PatchDictionary::ComputePatchTree() {
  patch_tree_.clear();
  num_patches_.clear();
  sorted_patches_y0_.clear();
  sorted_patches_y1_.clear();
  if (positions_.empty()) {
    return;
  }
  // Create a y-interval for each patch.
  std::vector<PatchInterval> intervals(positions_.size());
  for (size_t i = 0; i < positions_.size(); ++i) {
    const auto& pos = positions_[i];
    intervals[i].idx = i;
    intervals[i].y0 = pos.y;
    intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize;
  }
  // Helpers that sort the sub-range [start, end) of `intervals` by start row
  // (y0) or by end row (y1).
  auto sort_by_y0 = [&intervals](size_t start, size_t end) {
    std::sort(intervals.data() + start, intervals.data() + end,
              [](const PatchInterval& i0, const PatchInterval& i1) {
                return i0.y0 < i1.y0;
              });
  };
  auto sort_by_y1 = [&intervals](size_t start, size_t end) {
    std::sort(intervals.data() + start, intervals.data() + end,
              [](const PatchInterval& i0, const PatchInterval& i1) {
                return i0.y1 < i1.y1;
              });
  };
  // Count the number of patches for each row.
  // After sorting by y1, the last interval has the largest end row.
  sort_by_y1(0, intervals.size());
  num_patches_.resize(intervals.back().y1);
  for (auto iv : intervals) {
    for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++;
  }
  // Nodes are appended to patch_tree_ and processed in index order (a queue),
  // so the tree is built breadth-first. Until a node is processed,
  // start/num describe a sub-range of `intervals`. Processing rewrites them
  // to index into sorted_patches_y0_/sorted_patches_y1_ instead.
  PatchTreeNode root;
  root.start = 0;
  root.num = intervals.size();
  patch_tree_.push_back(root);
  size_t next = 0;
  while (next < patch_tree_.size()) {
    // NOTE: `node` is invalidated once patch_tree_ grows (push_back below).
    // That is why the child indices are written through patch_tree_[next].
    auto& node = patch_tree_[next];
    size_t start = node.start;
    size_t end = node.start + node.num;
    // Choose the y_center for this node to be the median of interval starts.
    sort_by_y0(start, end);
    size_t middle_idx = start + node.num / 2;
    node.y_center = intervals[middle_idx].y0;
    // Divide the intervals in [start, end) into three groups:
    //   * those completely to the right of y_center: [right_start, end)
    //   * those overlapping y_center: [left_end, right_start)
    //   * those completely to the left of y_center: [start, left_end)
    // Intervals after the median that also start exactly at y_center overlap
    // it, so they are pulled into the middle group.
    size_t right_start = middle_idx;
    while (right_start < end && intervals[right_start].y0 == node.y_center) {
      ++right_start;
    }
    // Of the intervals starting at or before y_center, those ending at or
    // before y_center (y1 <= y_center) lie completely to the left.
    sort_by_y1(start, right_start);
    size_t left_end = right_start;
    while (left_end > start && intervals[left_end - 1].y1 > node.y_center) {
      --left_end;
    }
    // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node.
    // Both arrays receive the same number of entries per node, so node.start
    // indexes into either of them.
    node.num = right_start - left_end;
    node.start = sorted_patches_y0_.size();
    // [left_end, right_start) is currently in ascending y1 order; walking it
    // backwards emits descending y1.
    for (ssize_t i = static_cast<ssize_t>(right_start) - 1;
         i >= static_cast<ssize_t>(left_end); --i) {
      sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx);
    }
    sort_by_y0(left_end, right_start);
    for (size_t i = left_end; i < right_start; ++i) {
      sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx);
    }
    // Create the left and right nodes (if not empty).
    node.left_child = node.right_child = -1;
    if (left_end > start) {
      PatchTreeNode left;
      left.start = start;
      left.num = left_end - left.start;
      patch_tree_[next].left_child = patch_tree_.size();
      patch_tree_.push_back(left);
    }
    if (right_start < end) {
      PatchTreeNode right;
      right.start = right_start;
      right.num = end - right.start;
      patch_tree_[next].right_child = patch_tree_.size();
      patch_tree_.push_back(right);
    }
    ++next;
  }
}
    283 
    284 std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const {
    285  std::vector<size_t> result;
    286  if (y < num_patches_.size() && num_patches_[y] > 0) {
    287    result.reserve(num_patches_[y]);
    288    for (ssize_t tree_idx = 0; tree_idx != -1;) {
    289      JXL_DASSERT(tree_idx < static_cast<ssize_t>(patch_tree_.size()));
    290      const auto& node = patch_tree_[tree_idx];
    291      if (y <= node.y_center) {
    292        for (size_t i = 0; i < node.num; ++i) {
    293          const auto& p = sorted_patches_y0_[node.start + i];
    294          if (y < p.first) break;
    295          result.push_back(p.second);
    296        }
    297        tree_idx = y < node.y_center ? node.left_child : -1;
    298      } else {
    299        for (size_t i = 0; i < node.num; ++i) {
    300          const auto& p = sorted_patches_y1_[node.start + i];
    301          if (y >= p.first) break;
    302          result.push_back(p.second);
    303        }
    304        tree_idx = node.right_child;
    305      }
    306    }
    307    // Ensure that he relative order of patches that affect the same pixels is
    308    // preserved. This is important for patches that have a blend mode
    309    // different from kAdd.
    310    std::sort(result.begin(), result.end());
    311  }
    312  return result;
    313 }
    314 
    315 // Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed
    316 // to be located at position (x0, y) in the frame.
    317 Status PatchDictionary::AddOneRow(
    318    float* const* inout, size_t y, size_t x0, size_t xsize,
    319    const std::vector<ExtraChannelInfo>& extra_channel_info) const {
    320  size_t num_ec = extra_channel_info.size();
    321  JXL_ENSURE(num_ec + 1 <= blendings_stride_);
    322  std::vector<const float*> fg_ptrs(3 + num_ec);
    323  for (size_t pos_idx : GetPatchesForRow(y)) {
    324    const size_t blending_idx = pos_idx * blendings_stride_;
    325    const PatchPosition& pos = positions_[pos_idx];
    326    const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx];
    327    size_t by = pos.y;
    328    size_t bx = pos.x;
    329    size_t patch_xsize = ref_pos.xsize;
    330    JXL_ENSURE(y >= by);
    331    JXL_ENSURE(y < by + ref_pos.ysize);
    332    size_t iy = y - by;
    333    size_t ref = ref_pos.ref;
    334    if (bx >= x0 + xsize) continue;
    335    if (bx + patch_xsize < x0) continue;
    336    size_t patch_x0 = std::max(bx, x0);
    337    size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize);
    338    for (size_t c = 0; c < 3; c++) {
    339      fg_ptrs[c] = reference_frames_->at(ref).frame->color()->ConstPlaneRow(
    340                       c, ref_pos.y0 + iy) +
    341                   ref_pos.x0 + x0 - bx;
    342    }
    343    for (size_t i = 0; i < num_ec; i++) {
    344      fg_ptrs[3 + i] =
    345          reference_frames_->at(ref).frame->extra_channels()[i].ConstRow(
    346              ref_pos.y0 + iy) +
    347          ref_pos.x0 + x0 - bx;
    348    }
    349    JXL_RETURN_IF_ERROR(PerformBlending(
    350        memory_manager_, inout, fg_ptrs.data(), inout, patch_x0 - x0,
    351        patch_x1 - patch_x0, blendings_[blending_idx],
    352        blendings_.data() + blending_idx + 1, extra_channel_info));
    353  }
    354  return true;
    355 }
    356 }  // namespace jxl