dec_patch_dictionary.cc (13175B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_patch_dictionary.h" 7 8 #include <jxl/memory_manager.h> 9 #include <sys/types.h> 10 11 #include <algorithm> 12 #include <cstdint> 13 #include <cstdlib> 14 #include <utility> 15 #include <vector> 16 17 #include "lib/jxl/base/printf_macros.h" 18 #include "lib/jxl/base/status.h" 19 #include "lib/jxl/blending.h" 20 #include "lib/jxl/common.h" // kMaxNumReferenceFrames 21 #include "lib/jxl/dec_ans.h" 22 #include "lib/jxl/image.h" 23 #include "lib/jxl/image_bundle.h" 24 #include "lib/jxl/pack_signed.h" 25 #include "lib/jxl/patch_dictionary_internal.h" 26 27 namespace jxl { 28 29 Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br, 30 size_t xsize, size_t ysize, 31 size_t num_extra_channels, 32 bool* uses_extra_channels) { 33 positions_.clear(); 34 blendings_stride_ = num_extra_channels + 1; 35 std::vector<uint8_t> context_map; 36 ANSCode code; 37 JXL_RETURN_IF_ERROR(DecodeHistograms( 38 memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map)); 39 JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder, 40 ANSSymbolReader::Create(&code, br)); 41 42 auto read_num = [&](size_t context) { 43 size_t r = decoder.ReadHybridUint(context, br, context_map); 44 return r; 45 }; 46 47 size_t num_ref_patch = read_num(kNumRefPatchContext); 48 // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8 49 // bytes per size_t) 50 const size_t num_pixels = xsize * ysize; 51 const size_t max_ref_patches = 1024 + num_pixels / 4; 52 const size_t max_patches = max_ref_patches * 4; 53 const size_t max_blending_infos = max_patches * 4; 54 if (num_ref_patch > max_ref_patches) { 55 return JXL_FAILURE("Too many patches in dictionary"); 56 } 57 58 size_t total_patches = 0; 59 size_t next_size = 1; 60 61 for (size_t id = 0; id < num_ref_patch; id++) { 62 PatchReferencePosition ref_pos; 63 ref_pos.ref = read_num(kReferenceFrameContext); 64 if (ref_pos.ref >= kMaxNumReferenceFrames || 65 reference_frames_->at(ref_pos.ref).frame->xsize() == 0) { 66 return JXL_FAILURE("Invalid reference frame ID"); 67 } 68 if (!reference_frames_->at(ref_pos.ref).ib_is_in_xyb) { 69 return JXL_FAILURE( 70 "Patches cannot use frames saved post color transforms"); 71 } 72 const ImageBundle& ib = *reference_frames_->at(ref_pos.ref).frame; 73 ref_pos.x0 = read_num(kPatchReferencePositionContext); 74 ref_pos.y0 = read_num(kPatchReferencePositionContext); 75 ref_pos.xsize = read_num(kPatchSizeContext) + 1; 76 ref_pos.ysize = read_num(kPatchSizeContext) + 1; 77 if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) { 78 return JXL_FAILURE("Invalid position specified in reference frame"); 79 } 80 if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) { 81 return JXL_FAILURE("Invalid position specified in reference frame"); 82 } 83 size_t id_count = read_num(kPatchCountContext); 84 if (id_count > max_patches) { 85 return JXL_FAILURE("Too many patches in dictionary"); 86 } 87 id_count++; 88 total_patches += id_count; 89 if (total_patches > max_patches) { 90 return JXL_FAILURE("Too many patches in dictionary"); 91 } 92 if (next_size < total_patches) { 93 next_size *= 2; 94 next_size = std::min<size_t>(next_size, max_patches); 95 } 96 if (next_size * blendings_stride_ > max_blending_infos) { 97 return JXL_FAILURE("Too many patches in dictionary"); 98 } 99 positions_.reserve(next_size); 100 blendings_.reserve(next_size * blendings_stride_); 101 bool choose_alpha = (num_extra_channels > 1); 102 for (size_t i = 0; i < id_count; i++) { 103 PatchPosition pos; 104 pos.ref_pos_idx = ref_positions_.size(); 105 if (i == 0) { 106 pos.x = read_num(kPatchPositionContext); 107 pos.y = read_num(kPatchPositionContext); 108 } else { 109 ssize_t deltax = UnpackSigned(read_num(kPatchOffsetContext)); 110 if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) { 111 return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS 112 " base x %" PRIdS " delta x)", 113 positions_.back().x, deltax); 114 } 115 pos.x = positions_.back().x + deltax; 116 ssize_t deltay = UnpackSigned(read_num(kPatchOffsetContext)); 117 if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) { 118 return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS 119 " base y %" PRIdS " delta y)", 120 positions_.back().y, deltay); 121 } 122 pos.y = positions_.back().y + deltay; 123 } 124 if (pos.x + ref_pos.xsize > xsize) { 125 return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS 126 " > %" PRIuS, 127 pos.x, ref_pos.xsize, xsize); 128 } 129 if (pos.y + ref_pos.ysize > ysize) { 130 return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS 131 " > %" PRIuS, 132 pos.y, ref_pos.ysize, ysize); 133 } 134 for (size_t j = 0; j < blendings_stride_; j++) { 135 uint32_t blend_mode = read_num(kPatchBlendModeContext); 136 if (blend_mode >= kNumPatchBlendModes) { 137 return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode); 138 } 139 PatchBlending info; 140 info.mode = static_cast<PatchBlendMode>(blend_mode); 141 if (UsesAlpha(info.mode)) { 142 *uses_extra_channels = true; 143 } 144 if (info.mode != PatchBlendMode::kNone && j > 0) { 145 *uses_extra_channels = true; 146 } 147 if (UsesAlpha(info.mode) && choose_alpha) { 148 info.alpha_channel = read_num(kPatchAlphaChannelContext); 149 if (info.alpha_channel >= num_extra_channels) { 150 return JXL_FAILURE( 151 "Invalid alpha channel for blending: %u out of %u\n", 152 info.alpha_channel, static_cast<uint32_t>(num_extra_channels)); 153 } 154 } else { 155 info.alpha_channel = 0; 156 } 157 if (UsesClamp(info.mode)) { 158 info.clamp = static_cast<bool>(read_num(kPatchClampContext)); 159 } else { 160 info.clamp = false; 161 } 162 blendings_.push_back(info); 163 } 164 positions_.emplace_back(pos); 165 } 166 ref_positions_.emplace_back(ref_pos); 167 } 168 positions_.shrink_to_fit(); 169 170 if (!decoder.CheckANSFinalState()) { 171 return JXL_FAILURE("ANS checksum failure."); 172 } 173 174 ComputePatchTree(); 175 return true; 176 } 177 178 int PatchDictionary::GetReferences() const { 179 int result = 0; 180 for (const auto& ref_pos : ref_positions_) { 181 result |= (1 << static_cast<int>(ref_pos.ref)); 182 } 183 return result; 184 } 185 186 namespace { 187 struct PatchInterval { 188 size_t idx; 189 size_t y0, y1; 190 }; 191 } // namespace 192 193 void PatchDictionary::ComputePatchTree() { 194 patch_tree_.clear(); 195 num_patches_.clear(); 196 sorted_patches_y0_.clear(); 197 sorted_patches_y1_.clear(); 198 if (positions_.empty()) { 199 return; 200 } 201 // Create a y-interval for each patch. 202 std::vector<PatchInterval> intervals(positions_.size()); 203 for (size_t i = 0; i < positions_.size(); ++i) { 204 const auto& pos = positions_[i]; 205 intervals[i].idx = i; 206 intervals[i].y0 = pos.y; 207 intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize; 208 } 209 auto sort_by_y0 = [&intervals](size_t start, size_t end) { 210 std::sort(intervals.data() + start, intervals.data() + end, 211 [](const PatchInterval& i0, const PatchInterval& i1) { 212 return i0.y0 < i1.y0; 213 }); 214 }; 215 auto sort_by_y1 = [&intervals](size_t start, size_t end) { 216 std::sort(intervals.data() + start, intervals.data() + end, 217 [](const PatchInterval& i0, const PatchInterval& i1) { 218 return i0.y1 < i1.y1; 219 }); 220 }; 221 // Count the number of patches for each row. 222 sort_by_y1(0, intervals.size()); 223 num_patches_.resize(intervals.back().y1); 224 for (auto iv : intervals) { 225 for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++; 226 } 227 PatchTreeNode root; 228 root.start = 0; 229 root.num = intervals.size(); 230 patch_tree_.push_back(root); 231 size_t next = 0; 232 while (next < patch_tree_.size()) { 233 auto& node = patch_tree_[next]; 234 size_t start = node.start; 235 size_t end = node.start + node.num; 236 // Choose the y_center for this node to be the median of interval starts. 237 sort_by_y0(start, end); 238 size_t middle_idx = start + node.num / 2; 239 node.y_center = intervals[middle_idx].y0; 240 // Divide the intervals in [start, end) into three groups: 241 // * those completely to the right of y_center: [right_start, end) 242 // * those overlapping y_center: [left_end, right_start) 243 // * those completely to the left of y_center: [start, left_end) 244 size_t right_start = middle_idx; 245 while (right_start < end && intervals[right_start].y0 == node.y_center) { 246 ++right_start; 247 } 248 sort_by_y1(start, right_start); 249 size_t left_end = right_start; 250 while (left_end > start && intervals[left_end - 1].y1 > node.y_center) { 251 --left_end; 252 } 253 // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node. 254 node.num = right_start - left_end; 255 node.start = sorted_patches_y0_.size(); 256 for (ssize_t i = static_cast<ssize_t>(right_start) - 1; 257 i >= static_cast<ssize_t>(left_end); --i) { 258 sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx); 259 } 260 sort_by_y0(left_end, right_start); 261 for (size_t i = left_end; i < right_start; ++i) { 262 sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx); 263 } 264 // Create the left and right nodes (if not empty). 265 node.left_child = node.right_child = -1; 266 if (left_end > start) { 267 PatchTreeNode left; 268 left.start = start; 269 left.num = left_end - left.start; 270 patch_tree_[next].left_child = patch_tree_.size(); 271 patch_tree_.push_back(left); 272 } 273 if (right_start < end) { 274 PatchTreeNode right; 275 right.start = right_start; 276 right.num = end - right.start; 277 patch_tree_[next].right_child = patch_tree_.size(); 278 patch_tree_.push_back(right); 279 } 280 ++next; 281 } 282 } 283 284 std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const { 285 std::vector<size_t> result; 286 if (y < num_patches_.size() && num_patches_[y] > 0) { 287 result.reserve(num_patches_[y]); 288 for (ssize_t tree_idx = 0; tree_idx != -1;) { 289 JXL_DASSERT(tree_idx < static_cast<ssize_t>(patch_tree_.size())); 290 const auto& node = patch_tree_[tree_idx]; 291 if (y <= node.y_center) { 292 for (size_t i = 0; i < node.num; ++i) { 293 const auto& p = sorted_patches_y0_[node.start + i]; 294 if (y < p.first) break; 295 result.push_back(p.second); 296 } 297 tree_idx = y < node.y_center ? node.left_child : -1; 298 } else { 299 for (size_t i = 0; i < node.num; ++i) { 300 const auto& p = sorted_patches_y1_[node.start + i]; 301 if (y >= p.first) break; 302 result.push_back(p.second); 303 } 304 tree_idx = node.right_child; 305 } 306 } 307 // Ensure that he relative order of patches that affect the same pixels is 308 // preserved. This is important for patches that have a blend mode 309 // different from kAdd. 310 std::sort(result.begin(), result.end()); 311 } 312 return result; 313 } 314 315 // Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed 316 // to be located at position (x0, y) in the frame. 317 Status PatchDictionary::AddOneRow( 318 float* const* inout, size_t y, size_t x0, size_t xsize, 319 const std::vector<ExtraChannelInfo>& extra_channel_info) const { 320 size_t num_ec = extra_channel_info.size(); 321 JXL_ENSURE(num_ec + 1 <= blendings_stride_); 322 std::vector<const float*> fg_ptrs(3 + num_ec); 323 for (size_t pos_idx : GetPatchesForRow(y)) { 324 const size_t blending_idx = pos_idx * blendings_stride_; 325 const PatchPosition& pos = positions_[pos_idx]; 326 const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx]; 327 size_t by = pos.y; 328 size_t bx = pos.x; 329 size_t patch_xsize = ref_pos.xsize; 330 JXL_ENSURE(y >= by); 331 JXL_ENSURE(y < by + ref_pos.ysize); 332 size_t iy = y - by; 333 size_t ref = ref_pos.ref; 334 if (bx >= x0 + xsize) continue; 335 if (bx + patch_xsize < x0) continue; 336 size_t patch_x0 = std::max(bx, x0); 337 size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize); 338 for (size_t c = 0; c < 3; c++) { 339 fg_ptrs[c] = reference_frames_->at(ref).frame->color()->ConstPlaneRow( 340 c, ref_pos.y0 + iy) + 341 ref_pos.x0 + x0 - bx; 342 } 343 for (size_t i = 0; i < num_ec; i++) { 344 fg_ptrs[3 + i] = 345 reference_frames_->at(ref).frame->extra_channels()[i].ConstRow( 346 ref_pos.y0 + iy) + 347 ref_pos.x0 + x0 - bx; 348 } 349 JXL_RETURN_IF_ERROR(PerformBlending( 350 memory_manager_, inout, fg_ptrs.data(), inout, patch_x0 - x0, 351 patch_x1 - patch_x0, blendings_[blending_idx], 352 blendings_.data() + blending_idx + 1, extra_channel_info)); 353 } 354 return true; 355 } 356 } // namespace jxl