dec_group.cc (33239B)
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/dec_group.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>

#include "lib/jxl/chroma_from_luma.h"
#include "lib/jxl/frame_header.h"

// Highway multi-target setup: this translation unit is re-included once per
// SIMD target by foreach_target.h, so code between HWY_BEFORE_NAMESPACE() and
// HWY_AFTER_NAMESPACE() is compiled once per target.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/ac_context.h"
#include "lib/jxl/ac_strategy.h"
#include "lib/jxl/base/bits.h"
#include "lib/jxl/base/common.h"
#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/base/rect.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/coeff_order.h"
#include "lib/jxl/common.h"  // kMaxNumPasses
#include "lib/jxl/dec_cache.h"
#include "lib/jxl/dec_transforms-inl.h"
#include "lib/jxl/dec_xyb.h"
#include "lib/jxl/entropy_coder.h"
#include "lib/jxl/quant_weights.h"
#include "lib/jxl/quantizer-inl.h"
#include "lib/jxl/quantizer.h"

// Target-independent declarations, guarded so they are emitted only on the
// first of the repeated per-target inclusions of this file.
#ifndef LIB_JXL_DEC_GROUP_CC
#define LIB_JXL_DEC_GROUP_CC
namespace jxl {

struct AuxOut;  // forward declaration; only used as an opaque parameter type

// Interface for reading groups for DecodeGroupImpl.
class GetBlock {
 public:
  // Called once per block row before any LoadBlock() call in that row.
  virtual void StartRow(size_t by) = 0;
  // Loads/accumulates the quantized AC coefficients of the varblock at
  // (bx, by) into block[0..2] (one plane per channel). `size` is the number
  // of coefficients per channel; `ac_type` selects 16- vs 32-bit storage.
  virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
                           size_t size, size_t log2_covered_blocks,
                           ACPtr block[3], ACType ac_type) = 0;
  virtual ~GetBlock() {}
};

// Controls whether DecodeGroupImpl renders to pixels or not.
enum DrawMode {
  // Render to pixels.
  kDraw = 0,
  // Don't render to pixels.
  kDontDraw = 1,
};

}  // namespace jxl
#endif  // LIB_JXL_DEC_GROUP_CC

HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::AllFalse;
using hwy::HWY_NAMESPACE::Gt;
// NOTE(review): the JPEG range check later in this file calls Lt(), but only
// Le is imported here — confirm whether this line should read
// `using hwy::HWY_NAMESPACE::Lt;` (these ops are documented as not found via
// ADL, see comment above).
using hwy::HWY_NAMESPACE::Le;
using hwy::HWY_NAMESPACE::MaskFromVec;
using hwy::HWY_NAMESPACE::Or;
using hwy::HWY_NAMESPACE::Rebind;
using hwy::HWY_NAMESPACE::ShiftRight;

// Vector descriptors used throughout this target's namespace.
using D = HWY_FULL(float);
using DU = HWY_FULL(uint32_t);
using DI = HWY_FULL(int32_t);
using DI16 = Rebind<int16_t, DI>;
using DI16_FULL = HWY_CAPPED(int16_t, kDCTBlockSize);
constexpr D d;
constexpr DI di;
constexpr DI16 di16;
constexpr DI16_FULL di16_full;

// Transposes an 8x8 int32 block in place. Used below to convert between the
// transposed JPEG XL coefficient layout and natural JPEG coefficient order.
// TODO(veluca): consider SIMDfying.
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
  for (size_t x = 0; x < 8; x++) {
    for (size_t y = x + 1; y < 8; y++) {
      std::swap(block[y * 8 + x], block[x * 8 + y]);
    }
  }
}

// Dequantizes one vector's worth of coefficients at offset `k` for all three
// channels: multiplies by the per-channel dequant matrix rows, applies the AC
// quantization bias, and adds the chroma-from-luma prediction (x_cc_mul /
// b_cc_mul scale the dequantized Y into X and B). Results go to `block`,
// laid out as three consecutive planes of `size` floats.
template <ACType ac_type>
void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
                 Vec<D> scaled_dequant_b,
                 const float* JXL_RESTRICT dequant_matrices, size_t size,
                 size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
                 const float* JXL_RESTRICT biases, ACPtr qblock[3],
                 float* JXL_RESTRICT block) {
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
  const auto y_mul =
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
  const auto b_mul =
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);

  Vec<DI> quantized_x_int;
  Vec<DI> quantized_y_int;
  Vec<DI> quantized_b_int;
  if (ac_type == ACType::k16) {
    // 16-bit coefficient storage: widen to 32-bit before dequantization.
    Rebind<int16_t, DI> di16;
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
  } else {
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
  }

  const auto dequant_x_cc =
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
  const auto dequant_y =
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
  const auto dequant_b_cc =
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);

  // Chroma-from-luma: X and B are predicted from the dequantized Y channel.
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
  Store(dequant_x, d, block + k);
  Store(dequant_y, d, block + size + k);
  Store(dequant_b, d, block + 2 * size + k);
}

// Dequantizes a whole varblock (covered_blocks 8x8 blocks) for all three
// channels into `block`, then overwrites the lowest-frequency coefficients of
// each channel from the DC image rows (dc_row/sbx/dc_stride).
template <ACType ac_type>
void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant,
                  float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul,
                  Vec<D> b_cc_mul, AcStrategyType kind, size_t size,
                  const Quantizer& quantizer, size_t covered_blocks,
                  const size_t* sbx,
                  const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
                  size_t dc_stride, const float* JXL_RESTRICT biases,
                  ACPtr qblock[3], float* JXL_RESTRICT block,
                  float* JXL_RESTRICT scratch) {
  // Per-block scalar dequant factor; X and B additionally get the
  // per-channel DM multipliers.
  const auto scaled_dequant_s = inv_global_scale / quant;

  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);

  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);

  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
                         qblock, block);
  }
  for (size_t c = 0; c < 3; c++) {
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
                            block + c * size, scratch);
  }
}

// Decodes every varblock of one group: fetches quantized coefficients via
// `get_block`, then either (a) dequantizes + IDCTs them into the render
// pipeline buffers, (b) writes them back as JPEG coefficient blocks when
// `jpeg_data` is non-null, or (c) skips rendering entirely when
// draw == kDontDraw (coefficients are still read/accumulated).
Status DecodeGroupImpl(const FrameHeader& frame_header,
                       GetBlock* JXL_RESTRICT get_block,
                       GroupDecCache* JXL_RESTRICT group_dec_cache,
                       PassesDecoderState* JXL_RESTRICT dec_state,
                       size_t thread, size_t group_idx,
                       RenderPipelineInput& render_pipeline_input,
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
  // TODO(veluca): investigate cache usage in this function.
  const Rect block_rect =
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;

  const size_t xsize_blocks = block_rect.xsize();
  const size_t ysize_blocks = block_rect.ysize();

  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();

  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();

  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;

  // Valid range for JPEG DCT coefficients (checked below after CfL restore).
  const auto kJpegDctMin = Set(di16_full, -4095);
  const auto kJpegDctMax = Set(di16_full, 4095);

  size_t idct_stride[3];
  for (size_t c = 0; c < 3; c++) {
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
  }

  // Fixed-point, transposed copy of the JPEG quantization tables (filled only
  // when jpeg_data != nullptr).
  HWY_ALIGN int32_t scaled_qtable[64 * 3];

  ACType ac_type = dec_state->coefficients->Type();
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
                                              : DequantBlock<ACType::k32>;
  // Whether or not coefficients should be stored for future usage, and/or read
  // from past usage.
  bool accumulate = !dec_state->coefficients->IsEmpty();
  // Offset of the current block in the group.
  size_t offset = 0;

  std::array<int, 3> jpeg_c_map;
  bool jpeg_is_gray = false;
  std::array<int, 3> dcoff = {};

  // TODO(veluca): all of this should be done only once per image.
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
  if (jpeg_data) {
    if (!color_correlation.IsJPEGCompatible()) {
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
    }
    jpeg_is_gray = (jpeg_data->components.size() == 1);
    JXL_ENSURE(frame_header.color_transform != ColorTransform::kXYB);
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
    const std::vector<QuantEncoding>& qe =
        dec_state->shared->matrices.encodings();
    // JPEG reconstruction requires the raw quant table with the exact
    // denominator 1/(8*255); anything else cannot round-trip.
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
      return JXL_FAILURE(
          "Quantization table is not a JPEG quantization table.");
    }
    JXL_ENSURE(qe[0].qraw.qtable->size() == 3 * 8 * 8);
    int* qtable = qe[0].qraw.qtable->data();
    for (size_t c = 0; c < 3; c++) {
      if (frame_header.color_transform == ColorTransform::kNone) {
        dcoff[c] = 1024 / qtable[64 * c];
      }
      for (size_t i = 0; i < 64; i++) {
        // Transpose the matrix, as it will be used on the transposed block.
        // n comes from the Y (channel 1) table, d from channel c's table.
        int n = qtable[64 + i];
        int d = qtable[64 * c + i];
        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
          return JXL_FAILURE("Invalid JPEG quantization table");
        }
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
            (1 << kCFLFixedPointPrecision) * n / d;
      }
    }
  }

  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
  // Per-channel group rectangles in subsampled block units.
  Rect r[3];
  for (size_t i = 0; i < 3; i++) {
    r[i] =
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
                        dec_state->shared->dc->Plane(i).ysize()})) {
      return JXL_FAILURE("Frame dimensions are too big for the image.");
    }
  }

  for (size_t by = 0; by < ysize_blocks; ++by) {
    get_block->StartRow(by);
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};

    const int32_t* JXL_RESTRICT row_quant =
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);

    const float* JXL_RESTRICT dc_rows[3] = {
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
    };

    // Color-correlation (CfL) tile row for this block row.
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);

    const int8_t* JXL_RESTRICT row_cmap[3] = {
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
        nullptr,  // Y channel has no CfL factor
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
    };

    float* JXL_RESTRICT idct_row[3];
    int16_t* JXL_RESTRICT jpeg_row[3];
    for (size_t c = 0; c < 3; c++) {
      const auto& buffer = render_pipeline_input.GetBuffer(c);
      idct_row[c] = buffer.second.Row(buffer.first, sby[c] * kBlockDim);
      if (jpeg_data) {
        auto& component = jpeg_data->components[jpeg_c_map[c]];
        jpeg_row[c] =
            component.coeffs.data() +
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
                kDCTBlockSize;
      }
    }

    size_t bx = 0;
    // Outer loop over CfL tiles; inner loop over blocks within the tile.
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
         tx++) {
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
      // Increment bx by llf_x because those iterations would otherwise
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
        AcStrategy acs = acs_row[bx];
        const size_t llf_x = acs.covered_blocks_x();

        // Can only happen in the second or lower rows of a varblock.
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
          bx += llf_x;
          continue;
        }
        const size_t log2_covered_blocks = acs.log2_covered_blocks();

        const size_t covered_blocks = 1 << log2_covered_blocks;
        const size_t size = covered_blocks * kDCTBlockSize;

        ACPtr qblock[3];
        if (accumulate) {
          // Progressive decoding: accumulate into the persistent coefficient
          // image so later passes add on top of earlier ones.
          for (size_t c = 0; c < 3; c++) {
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
          }
        } else {
          // No point in reading from bitstream without accumulating and not
          // drawing.
          JXL_ENSURE(draw == kDraw);
          // Use zero-initialized per-group scratch storage instead.
          if (ac_type == ACType::k16) {
            memset(group_dec_cache->dec_group_qblock16, 0,
                   size * 3 * sizeof(int16_t));
            for (size_t c = 0; c < 3; c++) {
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
            }
          } else {
            memset(group_dec_cache->dec_group_qblock, 0,
                   size * 3 * sizeof(int32_t));
            for (size_t c = 0; c < 3; c++) {
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
            }
          }
        }
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
        offset += size;
        if (draw == kDontDraw) {
          bx += llf_x;
          continue;
        }

        if (JXL_UNLIKELY(jpeg_data)) {
          // JPEG reconstruction path: emit coefficients, not pixels.
          if (acs.Strategy() != AcStrategyType::DCT) {
            return JXL_FAILURE(
                "Can only decode to JPEG if only DCT-8 is used.");
          }

          HWY_ALIGN int32_t transposed_dct_y[64];
          // Y (channel 1) first so transposed_dct_y is available for X/B CfL.
          for (size_t c : {1, 0, 2}) {
            // Propagate only Y for grayscale.
            if (jpeg_is_gray && c != 1) {
              continue;
            }
            // Skip subsampled positions that do not map back to (bx, by).
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
              continue;
            }
            int16_t* JXL_RESTRICT jpeg_pos =
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
            // JPEG XL is transposed, JPEG is not.
            auto* transposed_dct = qblock[c].ptr32;
            Transpose8x8InPlace(transposed_dct);
            // No CfL - no need to store the y block converted to integers.
            if (!cs.Is444() ||
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
              for (size_t i = 0; i < 64; i += Lanes(d)) {
                const auto ini = Load(di, transposed_dct + i);
                const auto ini16 = DemoteTo(di16, ini);
                StoreU(ini16, di16, jpeg_pos + i);
              }
            } else if (c == 1) {
              // Y channel: save for restoring X/B, but nothing else to do.
              for (size_t i = 0; i < 64; i += Lanes(d)) {
                const auto ini = Load(di, transposed_dct + i);
                Store(ini, di, transposed_dct_y + i);
                const auto ini16 = DemoteTo(di16, ini);
                StoreU(ini16, di16, jpeg_pos + i);
              }
            } else {
              // transposed_dct_y contains the y channel block, transposed.
              // Undo CfL in fixed point: coeff += round(y * qratio * scale).
              const auto scale =
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
              for (int i = 0; i < 64; i += Lanes(d)) {
                auto in = Load(di, transposed_dct + i);
                auto in_y = Load(di, transposed_dct_y + i);
                auto qt = Load(di, scaled_qtable + c * size + i);
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
                    Add(Mul(qt, scale), round));
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
                    Add(Mul(in_y, coeff_scale), round));
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
              }
            }
            // DC comes from the DC image, shifted by the per-channel offset.
            jpeg_pos[0] =
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
            // Validate the whole block against the legal JPEG DCT range.
            // NOTE(review): Lt() is used here but only Le is in the
            // using-declaration list at the top of this namespace — verify.
            auto overflow = MaskFromVec(Set(di16_full, 0));
            auto underflow = MaskFromVec(Set(di16_full, 0));
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
              auto in = LoadU(di16_full, jpeg_pos + i);
              overflow = Or(overflow, Gt(in, kJpegDctMax));
              underflow = Or(underflow, Lt(in, kJpegDctMin));
            }
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
              return JXL_FAILURE("JPEG DCT coefficients out of range");
            }
          }
        } else {
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
          // Dequantize and add predictions.
          dequant_block(
              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.Strategy(),
              size, dec_state->shared->quantizer,
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
              dc_stride,
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
              block, group_dec_cache->scratch_space);

          for (size_t c : {1, 0, 2}) {
            // Skip subsampled positions that do not map back to (bx, by).
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
              continue;
            }
            // IDCT
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
                              idct_stride[c], group_dec_cache->scratch_space);
          }
        }
        bx += llf_x;
      }
    }
  }
  return true;
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {
namespace {
// Decode quantized AC coefficients of DCT blocks.
// LLF components in the output block will not be modified.
template <ACType ac_type, bool uses_lz77>
Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
                        int32_t* JXL_RESTRICT row_nzeros,
                        const int32_t* JXL_RESTRICT row_nzeros_top,
                        size_t nzeros_stride, size_t c, size_t bx, size_t by,
                        size_t lbx, AcStrategy acs,
                        const coeff_order_t* JXL_RESTRICT coeff_order,
                        BitReader* JXL_RESTRICT br,
                        ANSSymbolReader* JXL_RESTRICT decoder,
                        const std::vector<uint8_t>& context_map,
                        const uint8_t* qdc_row, const int32_t* qf_row,
                        const BlockCtxMap& block_ctx_map, ACPtr block,
                        size_t shift = 0) {
  // Equal to number of LLF coefficients.
  const size_t covered_blocks = 1 << log2_covered_blocks;
  const size_t size = covered_blocks * kDCTBlockSize;
  // Predict the non-zero count from the top and left neighbors (default 32).
  int32_t predicted_nzeros =
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);

  size_t ord = kStrategyOrder[acs.RawStrategy()];
  const coeff_order_t* JXL_RESTRICT order =
      &coeff_order[CoeffOrderOffset(ord, c)];

  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
  const int32_t nzero_ctx =
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;

  size_t nzeros =
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
  // The LLF positions never carry AC symbols, so at most size - covered_blocks
  // coefficients can be non-zero.
  if (nzeros > size - covered_blocks) {
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
                       " 8x8 blocks",
                       nzeros, covered_blocks);
  }
  // Record the per-8x8-block average nzeros for neighbor prediction.
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
      row_nzeros[bx + x + y * nzeros_stride] =
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
    }
  }

  const size_t histo_offset =
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);

  // `prev` tracks whether the previous decoded coefficient was non-zero.
  size_t prev = (nzeros > size / 16 ? 0 : 1);
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
    const size_t ctx =
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
                                          log2_covered_blocks, prev);
    const size_t u_coeff =
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
    // signed integer to avoid undefined behavior of shifting negative numbers.
    const size_t magnitude = u_coeff >> 1;
    const size_t neg_sign = (~u_coeff) & 1;
    const intptr_t coeff =
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
    // += (not =) so multiple passes accumulate into the same block.
    if (ac_type == ACType::k16) {
      block.ptr16[order[k]] += coeff;
    } else {
      block.ptr32[order[k]] += coeff;
    }
    prev = static_cast<size_t>(u_coeff != 0);
    nzeros -= prev;
  }
  if (JXL_UNLIKELY(nzeros != 0)) {
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
                       "), channel %" PRIuS,
                       nzeros, bx, by, c);
  }

  return true;
}

// Structs used by DecodeGroupImpl to get a quantized block.
// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
// image provided by the encoder.

struct GetBlockFromBitstream : public GetBlock {
  // Caches row pointers (quant field, quant DC, per-pass nzeros rows) for the
  // given block row.
  void StartRow(size_t by) override {
    qf_row = rect.ConstRow(*qf, by);
    for (size_t c = 0; c < 3; c++) {
      size_t sby = by >> vshift[c];
      quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
      for (size_t i = 0; i < num_passes; i++) {
        row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
        row_nzeros_top[i][c] =
            sby == 0
                ? nullptr
                : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
      }
    }
  }

  // Decodes the coefficients of one varblock from every pass, Y channel
  // first, accumulating into block[0..2].
  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
                   size_t log2_covered_blocks, ACPtr block[3],
                   ACType ac_type) override {
    // NOTE(review): stray empty statement below — harmless, consider removal.
    ;
    for (size_t c : {1, 0, 2}) {
      size_t sbx = bx >> hshift[c];
      size_t sby = by >> vshift[c];
      // Skip subsampled positions that do not map back to (bx, by).
      if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
        continue;
      }

      for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
        // Pick the DecodeACVarBlock instantiation matching this pass's
        // LZ77 usage and the coefficient bit width.
        auto decode_ac_varblock =
            decoders[pass].UsesLZ77()
                ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1>
                                          : DecodeACVarBlock<ACType::k32, 1>)
                : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0>
                                          : DecodeACVarBlock<ACType::k32, 0>);
        JXL_RETURN_IF_ERROR(decode_ac_varblock(
            ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
            row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
            &coeff_orders[pass * coeff_order_size], readers[pass],
            &decoders[pass], context_map[pass], quant_dc_row, qf_row,
            *block_ctx_map, block[c], shift_for_pass[pass]));
      }
    }
    return true;
  }

  // Binds this reader to the shared decoder state for `num_passes` passes
  // starting at `first_pass`, reads the per-pass histogram selector, and
  // creates the per-pass ANS readers.
  Status Init(const FrameHeader& frame_header,
              BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes,
              size_t group_idx, size_t histo_selector_bits, const Rect& rect,
              GroupDecCache* JXL_RESTRICT group_dec_cache,
              PassesDecoderState* dec_state, size_t first_pass) {
    for (size_t i = 0; i < 3; i++) {
      hshift[i] = frame_header.chroma_subsampling.HShift(i);
      vshift[i] = frame_header.chroma_subsampling.VShift(i);
    }
    this->coeff_order_size = dec_state->shared->coeff_order_size;
    this->coeff_orders =
        dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
    this->context_map = dec_state->context_map.data() + first_pass;
    this->readers = readers;
    this->num_passes = num_passes;
    this->shift_for_pass = frame_header.passes.shift + first_pass;
    this->group_dec_cache = group_dec_cache;
    this->rect = rect;
    block_ctx_map = &dec_state->shared->block_ctx_map;
    qf = &dec_state->shared->raw_quant_field;
    quant_dc = &dec_state->shared->quant_dc;

    for (size_t pass = 0; pass < num_passes; pass++) {
      // Select which histogram set to use among those of the current pass.
      size_t cur_histogram = 0;
      if (histo_selector_bits != 0) {
        cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
      }
      if (cur_histogram >= dec_state->shared->num_histograms) {
        return JXL_FAILURE("Invalid histogram selector");
      }
      ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();

      JXL_ASSIGN_OR_RETURN(
          decoders[pass],
          ANSSymbolReader::Create(&dec_state->code[pass + first_pass],
                                  readers[pass]));
    }
    // All passes must share one nzeros stride so row pointers are uniform.
    nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
    for (size_t i = 0; i < num_passes; i++) {
      JXL_ENSURE(
          nzeros_stride ==
          static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
    }
    return true;
  }

  // All pointers below are borrowed from PassesDecoderState/GroupDecCache
  // and remain owned by them.
  const uint32_t* shift_for_pass = nullptr;  // not owned
  const coeff_order_t* JXL_RESTRICT coeff_orders;
  size_t coeff_order_size;
  const std::vector<uint8_t>* JXL_RESTRICT context_map;
  ANSSymbolReader decoders[kMaxNumPasses];
  BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
  size_t num_passes;
  size_t ctx_offset[kMaxNumPasses];
  size_t nzeros_stride;
  int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
  const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
  GroupDecCache* JXL_RESTRICT group_dec_cache;
  const BlockCtxMap* block_ctx_map;
  const ImageI* qf;
  const ImageB* quant_dc;
  const int32_t* qf_row;
  const uint8_t* quant_dc_row;
  Rect rect;
  size_t hshift[3], vshift[3];
};

// Reads quantized coefficients directly from the encoder's in-memory
// coefficient images (used for encode-side roundtrips, 32-bit only).
struct GetBlockFromEncoder : public GetBlock {
  void StartRow(size_t by) override {}

  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
                   size_t log2_covered_blocks, ACPtr block[3],
                   ACType ac_type) override {
    JXL_ENSURE(ac_type == ACType::k32);
    for (size_t c = 0; c < 3; c++) {
      // for each pass
      for (size_t i = 0; i < quantized_ac->size(); i++) {
        for (size_t k = 0; k < size; k++) {
          // TODO(veluca): SIMD.
          // Accumulate each pass's coefficients, scaled by the pass shift.
          block[c].ptr32[k] +=
              rows[i][c][offset + k] * (1 << shift_for_pass[i]);
        }
      }
    }
    offset += size;
    return true;
  }

  // Factory: caches the plane row base pointers of every pass/channel.
  static StatusOr<GetBlockFromEncoder> Create(
      const std::vector<std::unique_ptr<ACImage>>& ac, size_t group_idx,
      const uint32_t* shift_for_pass) {
    GetBlockFromEncoder result(ac, group_idx, shift_for_pass);
    // TODO(veluca): not supported with chroma subsampling.
    for (size_t i = 0; i < ac.size(); i++) {
      JXL_ENSURE(ac[i]->Type() == ACType::k32);
      for (size_t c = 0; c < 3; c++) {
        result.rows[i][c] = ac[i]->PlaneRow(c, group_idx, 0).ptr32;
      }
    }
    return result;
  }

  const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
  // Running offset (in coefficients) into the group's coefficient planes.
  size_t offset = 0;
  const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
  const uint32_t* shift_for_pass = nullptr;  // not owned

 private:
  GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
                      size_t group_idx, const uint32_t* shift_for_pass)
      : quantized_ac(&ac), shift_for_pass(shift_for_pass) {}
};

HWY_EXPORT(DecodeGroupImpl);

}  // namespace

// Decodes one group's AC data for `num_passes` passes starting at
// `first_pass`, drawing pixels when this completes the frame's passes (or
// when force_draw is set). For the DC-only case (no passes decoded yet) it
// instead 8x-upsamples the DC image into the render pipeline buffers.
Status DecodeGroup(const FrameHeader& frame_header,
                   BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
                   size_t num_passes, size_t group_idx,
                   PassesDecoderState* JXL_RESTRICT dec_state,
                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
                   RenderPipelineInput& render_pipeline_input,
                   jpeg::JPEGData* JXL_RESTRICT jpeg_data, size_t first_pass,
                   bool force_draw, bool dc_only, bool* should_run_pipeline) {
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
  // Draw only once all passes of the frame have been decoded, unless forced.
  DrawMode draw =
      (num_passes + first_pass == frame_header.passes.num_passes) || force_draw
          ? kDraw
          : kDontDraw;

  if (should_run_pipeline) {
    *should_run_pipeline = draw != kDontDraw;
  }

  // DC-only preview: upscale the DC image 8x instead of decoding AC.
  if (draw == kDraw && num_passes == 0 && first_pass == 0) {
    JXL_RETURN_IF_ERROR(group_dec_cache->InitDCBufferOnce(memory_manager));
    const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
    for (size_t c : {0, 1, 2}) {
      size_t hs = cs.HShift(c);
      size_t vs = cs.VShift(c);
      // We reuse filter_input_storage here as it is not currently in use.
      const Rect src_rect_precs =
          dec_state->shared->frame_dim.BlockGroupRect(group_idx);
      const Rect src_rect =
          Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs,
               src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs);
      // Copy the group's DC into dc_buffer with a 2-pixel border for the
      // 5-tap upsampler.
      const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(),
                           src_rect.ysize());
      JXL_RETURN_IF_ERROR(
          CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2,
                                 copy_rect, &group_dec_cache->dc_buffer));
      // Mirrorpad. Interleaving left and right padding ensures that padding
      // works out correctly even for images with DC size of 1.
      for (size_t y = 0; y < src_rect.ysize() + 4; y++) {
        size_t xend = kRenderPipelineXOffset +
                      (dec_state->shared->dc->Plane(c).xsize() >> hs) -
                      src_rect.x0();
        for (size_t ix = 0; ix < 2; ix++) {
          if (src_rect.x0() == 0) {
            group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] =
                group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix];
          }
          if (src_rect.x0() + src_rect.xsize() + 2 >=
              (dec_state->shared->dc->xsize() >> hs)) {
            group_dec_cache->dc_buffer.Row(y)[xend + ix] =
                group_dec_cache->dc_buffer.Row(y)[xend - ix - 1];
          }
        }
      }
      const auto& buffer = render_pipeline_input.GetBuffer(c);
      Rect dst_rect = buffer.second;
      ImageF* upsampling_dst = buffer.first;
      JXL_ENSURE(dst_rect.IsInside(*upsampling_dst));

      // 5 input rows -> 8 output rows per processed row (8x upsampling).
      RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5));
      RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8));
      for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize();
           y++) {
        for (ssize_t iy = 0; iy < 5; iy++) {
          // Mirror vertically at the image borders.
          input_rows[0][iy] = group_dec_cache->dc_buffer.Row(
              Mirror(static_cast<ssize_t>(y) + iy - 2,
                     dec_state->shared->dc->Plane(c).ysize() >> vs) +
              2 - src_rect.y0());
        }
        for (size_t iy = 0; iy < 8; iy++) {
          output_rows[0][iy] =
              dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) -
              kRenderPipelineXOffset;
        }
        // Arguments set to 0/nullptr are not used.
        JXL_RETURN_IF_ERROR(dec_state->upsampler8x->ProcessRow(
            input_rows, output_rows,
            /*xextra=*/0, src_rect.xsize(), 0, 0, thread));
      }
    }
    return true;
  }

  size_t histo_selector_bits = 0;
  if (dc_only) {
    JXL_ENSURE(num_passes == 0);
  } else {
    JXL_ENSURE(dec_state->shared->num_histograms > 0);
    histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
  }

  auto get_block = jxl::make_unique<GetBlockFromBitstream>();
  JXL_RETURN_IF_ERROR(get_block->Init(
      frame_header, readers, num_passes, group_idx, histo_selector_bits,
      dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache,
      dec_state, first_pass));

  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
      frame_header, get_block.get(), group_dec_cache, dec_state, thread,
      group_idx, render_pipeline_input, jpeg_data, draw));

  // Verify stream integrity: every pass's ANS decoder must end in the
  // expected final state.
  for (size_t pass = 0; pass < num_passes; pass++) {
    if (!get_block->decoders[pass].CheckANSFinalState()) {
      return JXL_FAILURE("ANS checksum failure.");
    }
  }
  return true;
}

// Encoder-side roundtrip: decodes a group from the encoder's in-memory
// quantized coefficients (all passes merged) and always draws.
Status DecodeGroupForRoundtrip(const FrameHeader& frame_header,
                               const std::vector<std::unique_ptr<ACImage>>& ac,
                               size_t group_idx,
                               PassesDecoderState* JXL_RESTRICT dec_state,
                               GroupDecCache* JXL_RESTRICT group_dec_cache,
                               size_t thread,
                               RenderPipelineInput& render_pipeline_input,
                               jpeg::JPEGData* JXL_RESTRICT jpeg_data,
                               AuxOut* aux_out) {
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
  JXL_ASSIGN_OR_RETURN(
      GetBlockFromEncoder get_block,
      GetBlockFromEncoder::Create(ac, group_idx, frame_header.passes.shift));
  // Mark every AC strategy as potentially used, since the encoder image may
  // contain any of them.
  JXL_RETURN_IF_ERROR(group_dec_cache->InitOnce(
      memory_manager,
      /*num_passes=*/0,
      /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1));

  return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
      frame_header, &get_block, group_dec_cache, dec_state, thread, group_idx,
      render_pipeline_input, jpeg_data, kDraw);
}

}  // namespace jxl
#endif  // HWY_ONCE