dec_modular.cc (33601B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_modular.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <cstdint> 11 #include <vector> 12 13 #include "lib/jxl/frame_header.h" 14 15 #undef HWY_TARGET_INCLUDE 16 #define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc" 17 #include <hwy/foreach_target.h> 18 #include <hwy/highway.h> 19 20 #include "lib/jxl/base/compiler_specific.h" 21 #include "lib/jxl/base/printf_macros.h" 22 #include "lib/jxl/base/rect.h" 23 #include "lib/jxl/base/status.h" 24 #include "lib/jxl/compressed_dc.h" 25 #include "lib/jxl/epf.h" 26 #include "lib/jxl/modular/encoding/encoding.h" 27 #include "lib/jxl/modular/modular_image.h" 28 #include "lib/jxl/modular/transform/transform.h" 29 30 HWY_BEFORE_NAMESPACE(); 31 namespace jxl { 32 namespace HWY_NAMESPACE { 33 34 // These templates are not found via ADL. 35 using hwy::HWY_NAMESPACE::Add; 36 using hwy::HWY_NAMESPACE::Mul; 37 using hwy::HWY_NAMESPACE::Rebind; 38 39 void MultiplySum(const size_t xsize, 40 const pixel_type* const JXL_RESTRICT row_in, 41 const pixel_type* const JXL_RESTRICT row_in_Y, 42 const float factor, float* const JXL_RESTRICT row_out) { 43 const HWY_FULL(float) df; 44 const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float 45 const auto factor_v = Set(df, factor); 46 for (size_t x = 0; x < xsize; x += Lanes(di)) { 47 const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x)); 48 const auto out = Mul(ConvertTo(df, in), factor_v); 49 Store(out, df, row_out + x); 50 } 51 } 52 53 void RgbFromSingle(const size_t xsize, 54 const pixel_type* const JXL_RESTRICT row_in, 55 const float factor, float* out_r, float* out_g, 56 float* out_b) { 57 const HWY_FULL(float) df; 58 const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float 59 60 const auto factor_v = Set(df, factor); 61 for (size_t x = 0; x < xsize; x += Lanes(di)) { 62 const auto in = Load(di, row_in + x); 63 const auto out = Mul(ConvertTo(df, in), factor_v); 64 Store(out, df, out_r + x); 65 Store(out, df, out_g + x); 66 Store(out, df, out_b + x); 67 } 68 } 69 70 void SingleFromSingle(const size_t xsize, 71 const pixel_type* const JXL_RESTRICT row_in, 72 const float factor, float* row_out) { 73 const HWY_FULL(float) df; 74 const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float 75 76 const auto factor_v = Set(df, factor); 77 for (size_t x = 0; x < xsize; x += Lanes(di)) { 78 const auto in = Load(di, row_in + x); 79 const auto out = Mul(ConvertTo(df, in), factor_v); 80 Store(out, df, row_out + x); 81 } 82 } 83 // NOLINTNEXTLINE(google-readability-namespace-comments) 84 } // namespace HWY_NAMESPACE 85 } // namespace jxl 86 HWY_AFTER_NAMESPACE(); 87 88 #if HWY_ONCE 89 namespace jxl { 90 HWY_EXPORT(MultiplySum); // Local function 91 HWY_EXPORT(RgbFromSingle); // Local function 92 HWY_EXPORT(SingleFromSingle); // Local function 93 94 // Slow conversion using double precision multiplication, only 95 // needed when the bit depth is too high for single precision 96 void SingleFromSingleAccurate(const size_t xsize, 97 const pixel_type* const JXL_RESTRICT row_in, 98 const double factor, float* row_out) { 99 for (size_t x = 0; x < xsize; x++) { 100 row_out[x] = row_in[x] * factor; 101 } 102 } 103 104 // convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int 105 // back to binary32 float 106 Status int_to_float(const pixel_type* const JXL_RESTRICT row_in, 107 float* const JXL_RESTRICT row_out, const size_t xsize, 108 const int bits, const int exp_bits) { 109 static_assert(sizeof(pixel_type) == sizeof(float)); 110 if (bits == 32) { 111 JXL_ENSURE(exp_bits == 8); 112 memcpy(row_out, row_in, xsize * sizeof(float)); 113 return true; 114 } 115 int exp_bias = (1 << (exp_bits - 1)) - 1; 116 int sign_shift = bits - 1; 117 int mant_bits = bits - exp_bits - 1; 118 int mant_shift = 23 - mant_bits; 119 for (size_t x = 0; x < xsize; ++x) { 120 uint32_t f; 121 memcpy(&f, &row_in[x], 4); 122 int signbit = (f >> sign_shift); 123 f &= (1 << sign_shift) - 1; 124 if (f == 0) { 125 row_out[x] = (signbit ? -0.f : 0.f); 126 continue; 127 } 128 int exp = (f >> mant_bits); 129 int mantissa = (f & ((1 << mant_bits) - 1)); 130 mantissa <<= mant_shift; 131 // Try to normalize only if there is space for maneuver. 132 if (exp == 0 && exp_bits < 8) { 133 // subnormal number 134 while ((mantissa & 0x800000) == 0) { 135 mantissa <<= 1; 136 exp--; 137 } 138 exp++; 139 // remove leading 1 because it is implicit now 140 mantissa &= 0x7fffff; 141 } 142 exp -= exp_bias; 143 // broke up the arbitrary float into its parts, now reassemble into 144 // binary32 145 exp += 127; 146 JXL_ENSURE(exp >= 0); 147 f = (signbit ? 0x80000000 : 0); 148 f |= (exp << 23); 149 f |= mantissa; 150 memcpy(&row_out[x], &f, 4); 151 } 152 return true; 153 } 154 155 #if JXL_DEBUG_V_LEVEL >= 1 156 std::string ModularStreamId::DebugString() const { 157 std::ostringstream os; 158 os << (kind == GlobalData ? "ModularGlobal" 159 : kind == VarDCTDC ? "VarDCTDC" 160 : kind == ModularDC ? "ModularDC" 161 : kind == ACMetadata ? "ACMeta" 162 : kind == QuantTable ? "QuantTable" 163 : kind == ModularAC ? "ModularAC" 164 : ""); 165 if (kind == VarDCTDC || kind == ModularDC || kind == ACMetadata || 166 kind == ModularAC) { 167 os << " group " << group_id; 168 } 169 if (kind == ModularAC) { 170 os << " pass " << pass_id; 171 } 172 if (kind == QuantTable) { 173 os << " " << quant_table_id; 174 } 175 return os.str(); 176 } 177 #endif 178 179 Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, 180 const FrameHeader& frame_header, 181 bool allow_truncated_group) { 182 JxlMemoryManager* memory_manager = this->memory_manager(); 183 bool decode_color = frame_header.encoding == FrameEncoding::kModular; 184 const auto& metadata = frame_header.nonserialized_metadata->m; 185 bool is_gray = metadata.color_encoding.IsGray(); 186 size_t nb_chans = 3; 187 if (is_gray && frame_header.color_transform == ColorTransform::kNone) { 188 nb_chans = 1; 189 } 190 do_color = decode_color; 191 size_t nb_extra = metadata.extra_channel_info.size(); 192 bool has_tree = static_cast<bool>(reader->ReadBits(1)); 193 if (!allow_truncated_group || 194 reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) { 195 if (has_tree) { 196 size_t tree_size_limit = 197 std::min(static_cast<size_t>(1 << 22), 198 1024 + frame_dim.xsize * frame_dim.ysize * 199 (nb_chans + nb_extra) / 16); 200 JXL_RETURN_IF_ERROR( 201 DecodeTree(memory_manager, reader, &tree, tree_size_limit)); 202 JXL_RETURN_IF_ERROR(DecodeHistograms( 203 memory_manager, reader, (tree.size() + 1) / 2, &code, &context_map)); 204 } 205 } 206 if (!do_color) nb_chans = 0; 207 208 bool fp = metadata.bit_depth.floating_point_sample; 209 210 // bits_per_sample is just metadata for XYB images. 211 if (metadata.bit_depth.bits_per_sample >= 32 && do_color && 212 frame_header.color_transform != ColorTransform::kXYB) { 213 if (metadata.bit_depth.bits_per_sample == 32 && fp == false) { 214 return JXL_FAILURE("uint32_t not supported in dec_modular"); 215 } else if (metadata.bit_depth.bits_per_sample > 32) { 216 return JXL_FAILURE("bits_per_sample > 32 not supported"); 217 } 218 } 219 220 JXL_ASSIGN_OR_RETURN( 221 Image gi, 222 Image::Create(memory_manager, frame_dim.xsize, frame_dim.ysize, 223 metadata.bit_depth.bits_per_sample, nb_chans + nb_extra)); 224 225 all_same_shift = true; 226 if (frame_header.color_transform == ColorTransform::kYCbCr) { 227 for (size_t c = 0; c < nb_chans; c++) { 228 gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c); 229 gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c); 230 size_t xsize_shifted = 231 DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift); 232 size_t ysize_shifted = 233 DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift); 234 JXL_RETURN_IF_ERROR(gi.channel[c].shrink(xsize_shifted, ysize_shifted)); 235 if (gi.channel[c].hshift != gi.channel[0].hshift || 236 gi.channel[c].vshift != gi.channel[0].vshift) 237 all_same_shift = false; 238 } 239 } 240 241 for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) { 242 size_t ecups = frame_header.extra_channel_upsampling[ec]; 243 JXL_RETURN_IF_ERROR( 244 gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups), 245 DivCeil(frame_dim.ysize_upsampled, ecups))); 246 gi.channel[c].hshift = gi.channel[c].vshift = 247 CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling); 248 if (gi.channel[c].hshift != gi.channel[0].hshift || 249 gi.channel[c].vshift != gi.channel[0].vshift) 250 all_same_shift = false; 251 } 252 253 JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s", 254 gi.DebugString().c_str()); 255 ModularOptions options; 256 options.max_chan_size = frame_dim.group_dim; 257 options.group_dim = frame_dim.group_dim; 258 Status dec_status = ModularGenericDecompress( 259 reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim), 260 &options, 261 /*undo_transforms=*/false, &tree, &code, &context_map, 262 allow_truncated_group); 263 if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); 264 if (dec_status.IsFatalError()) { 265 return JXL_FAILURE("Failed to decode global modular info"); 266 } 267 268 // TODO(eustas): are we sure this can be done after partial decode? 269 have_something = false; 270 for (size_t c = 0; c < gi.channel.size(); c++) { 271 Channel& gic = gi.channel[c]; 272 if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim && 273 gic.h <= frame_dim.group_dim) 274 have_something = true; 275 } 276 // move global transforms to groups if possible 277 if (!have_something && all_same_shift) { 278 if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) { 279 global_transform = gi.transform; 280 gi.transform.clear(); 281 // TODO(jon): also move no-delta-palette out (trickier though) 282 } 283 } 284 full_image = std::move(gi); 285 JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s", 286 full_image.DebugString().c_str()); 287 return dec_status; 288 } 289 290 void ModularFrameDecoder::MaybeDropFullImage() { 291 if (full_image.transform.empty() && !have_something && all_same_shift) { 292 use_full_image = false; 293 JXL_DEBUG_V(6, "Dropping full image"); 294 for (auto& ch : full_image.channel) { 295 // keep metadata on channels around, but dealloc their planes 296 ch.plane = Plane<pixel_type>(); 297 } 298 } 299 } 300 301 Status ModularFrameDecoder::DecodeGroup( 302 const FrameHeader& frame_header, const Rect& rect, BitReader* reader, 303 int minShift, int maxShift, const ModularStreamId& stream, bool zerofill, 304 PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input, 305 bool allow_truncated, bool* should_run_pipeline) { 306 JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s", 307 stream.DebugString().c_str(), Description(rect).c_str(), minShift, 308 maxShift, zerofill ? "using zerofill" : ""); 309 JXL_ENSURE(stream.kind == ModularStreamId::Kind::ModularDC || 310 stream.kind == ModularStreamId::Kind::ModularAC); 311 const size_t xsize = rect.xsize(); 312 const size_t ysize = rect.ysize(); 313 JXL_ASSIGN_OR_RETURN(Image gi, Image::Create(memory_manager_, xsize, ysize, 314 full_image.bitdepth, 0)); 315 // start at the first bigger-than-groupsize non-metachannel 316 size_t c = full_image.nb_meta_channels; 317 for (; c < full_image.channel.size(); c++) { 318 Channel& fc = full_image.channel[c]; 319 if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break; 320 } 321 size_t beginc = c; 322 for (; c < full_image.channel.size(); c++) { 323 Channel& fc = full_image.channel[c]; 324 int shift = std::min(fc.hshift, fc.vshift); 325 if (shift > maxShift) continue; 326 if (shift < minShift) continue; 327 Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, 328 rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); 329 if (r.xsize() == 0 || r.ysize() == 0) continue; 330 if (zerofill && use_full_image) { 331 for (size_t y = 0; y < r.ysize(); ++y) { 332 pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y); 333 memset(row_out, 0, r.xsize() * sizeof(*row_out)); 334 } 335 } else { 336 JXL_ASSIGN_OR_RETURN( 337 Channel gc, Channel::Create(memory_manager_, r.xsize(), r.ysize())); 338 if (zerofill) ZeroFillImage(&gc.plane); 339 gc.hshift = fc.hshift; 340 gc.vshift = fc.vshift; 341 gi.channel.emplace_back(std::move(gc)); 342 } 343 } 344 if (zerofill && use_full_image) return true; 345 // Return early if there's nothing to decode. Otherwise there might be 346 // problems later (in ModularImageToDecodedRect). 347 if (gi.channel.empty()) { 348 if (dec_state && should_run_pipeline) { 349 const auto* metadata = frame_header.nonserialized_metadata; 350 if (do_color || metadata->m.num_extra_channels > 0) { 351 // Signal to FrameDecoder that we do not have some of the required input 352 // for the render pipeline. 353 *should_run_pipeline = false; 354 } 355 } 356 JXL_DEBUG_V(6, "Nothing to decode, returning early."); 357 return true; 358 } 359 ModularOptions options; 360 if (!zerofill) { 361 auto status = ModularGenericDecompress( 362 reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options, 363 /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated); 364 if (!allow_truncated) JXL_RETURN_IF_ERROR(status); 365 if (status.IsFatalError()) return status; 366 } 367 // Undo global transforms that have been pushed to the group level 368 if (!use_full_image) { 369 JXL_ENSURE(render_pipeline_input); 370 for (const auto& t : global_transform) { 371 JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header)); 372 } 373 JXL_RETURN_IF_ERROR(ModularImageToDecodedRect( 374 frame_header, gi, dec_state, nullptr, *render_pipeline_input, 375 Rect(0, 0, gi.w, gi.h))); 376 return true; 377 } 378 int gic = 0; 379 for (c = beginc; c < full_image.channel.size(); c++) { 380 Channel& fc = full_image.channel[c]; 381 int shift = std::min(fc.hshift, fc.vshift); 382 if (shift > maxShift) continue; 383 if (shift < minShift) continue; 384 Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, 385 rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); 386 if (r.xsize() == 0 || r.ysize() == 0) continue; 387 JXL_ENSURE(use_full_image); 388 JXL_RETURN_IF_ERROR( 389 CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()), 390 /*from=*/gi.channel[gic].plane, 391 /*rect_to=*/r, /*to=*/&fc.plane)); 392 gic++; 393 } 394 return true; 395 } 396 397 Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header, 398 size_t group_id, BitReader* reader, 399 PassesDecoderState* dec_state) { 400 JxlMemoryManager* memory_manager = dec_state->memory_manager(); 401 const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id); 402 JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str()); 403 // TODO(eustas): investigate if we could reduce the impact of 404 // EvalRationalPolynomial; generally speaking, the limit is 405 // 2**(128/(3*magic)), where 128 comes from IEEE 754 exponent, 406 // 3 comes from XybToRgb that cubes the values, and "magic" is 407 // the sum of all other contributions. 2**18 is known to lead 408 // to NaN on input found by fuzzing (see commit message). 409 JXL_ASSIGN_OR_RETURN(Image image, 410 Image::Create(memory_manager, r.xsize(), r.ysize(), 411 full_image.bitdepth, 3)); 412 size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim); 413 reader->Refill(); 414 size_t extra_precision = reader->ReadFixedBits<2>(); 415 float mul = 1.0f / (1 << extra_precision); 416 ModularOptions options; 417 for (size_t c = 0; c < 3; c++) { 418 Channel& ch = image.channel[c < 2 ? c ^ 1 : c]; 419 ch.w >>= frame_header.chroma_subsampling.HShift(c); 420 ch.h >>= frame_header.chroma_subsampling.VShift(c); 421 JXL_RETURN_IF_ERROR(ch.shrink()); 422 } 423 if (!ModularGenericDecompress( 424 reader, image, /*header=*/nullptr, stream_id, &options, 425 /*undo_transforms=*/true, &tree, &code, &context_map)) { 426 return JXL_FAILURE("Failed to decode VarDCT DC group (DC group id %d)", 427 static_cast<int>(group_id)); 428 } 429 DequantDC(r, &dec_state->shared_storage.dc_storage, 430 &dec_state->shared_storage.quant_dc, image, 431 dec_state->shared->quantizer.MulDC(), mul, 432 dec_state->shared->cmap.base().DCFactors(), 433 frame_header.chroma_subsampling, dec_state->shared->block_ctx_map); 434 return true; 435 } 436 437 Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header, 438 size_t group_id, BitReader* reader, 439 PassesDecoderState* dec_state) { 440 JxlMemoryManager* memory_manager = dec_state->memory_manager(); 441 const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id); 442 JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str()); 443 size_t upper_bound = r.xsize() * r.ysize(); 444 reader->Refill(); 445 size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1; 446 size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim); 447 // YToX, YToB, ACS + QF, EPF 448 JXL_ASSIGN_OR_RETURN(Image image, 449 Image::Create(memory_manager, r.xsize(), r.ysize(), 450 full_image.bitdepth, 4)); 451 static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); 452 Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); 453 JXL_ASSIGN_OR_RETURN( 454 image.channel[0], 455 Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3)); 456 JXL_ASSIGN_OR_RETURN( 457 image.channel[1], 458 Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3)); 459 JXL_ASSIGN_OR_RETURN(image.channel[2], 460 Channel::Create(memory_manager, count, 2, 0, 0)); 461 ModularOptions options; 462 if (!ModularGenericDecompress( 463 reader, image, /*header=*/nullptr, stream_id, &options, 464 /*undo_transforms=*/true, &tree, &code, &context_map)) { 465 return JXL_FAILURE("Failed to decode AC metadata"); 466 } 467 JXL_RETURN_IF_ERROR( 468 ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, 469 cr, &dec_state->shared_storage.cmap.ytox_map)); 470 JXL_RETURN_IF_ERROR( 471 ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, 472 cr, &dec_state->shared_storage.cmap.ytob_map)); 473 size_t num = 0; 474 bool is444 = frame_header.chroma_subsampling.Is444(); 475 auto& ac_strategy = dec_state->shared_storage.ac_strategy; 476 size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize()); 477 size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize()); 478 uint32_t local_used_acs = 0; 479 for (size_t iy = 0; iy < r.ysize(); iy++) { 480 size_t y = r.y0() + iy; 481 int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy); 482 uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy); 483 int32_t* row_in_1 = image.channel[2].plane.Row(0); 484 int32_t* row_in_2 = image.channel[2].plane.Row(1); 485 int32_t* row_in_3 = image.channel[3].plane.Row(iy); 486 for (size_t ix = 0; ix < r.xsize(); ix++) { 487 size_t x = r.x0() + ix; 488 int sharpness = row_in_3[ix]; 489 if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) { 490 return JXL_FAILURE("Corrupted sharpness field"); 491 } 492 row_epf[ix] = sharpness; 493 if (ac_strategy.IsValid(x, y)) { 494 continue; 495 } 496 497 if (num >= count) return JXL_FAILURE("Corrupted stream"); 498 499 if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) { 500 return JXL_FAILURE("Invalid AC strategy"); 501 } 502 local_used_acs |= 1u << row_in_1[num]; 503 AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]); 504 if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) && 505 !is444) { 506 return JXL_FAILURE( 507 "AC strategy not compatible with chroma subsampling"); 508 } 509 // Ensure that blocks do not overflow *AC* groups. 510 size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks; 511 size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks; 512 size_t next_x_dct_block = x + acs.covered_blocks_x(); 513 size_t next_y_dct_block = y + acs.covered_blocks_y(); 514 if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) { 515 return JXL_FAILURE("Invalid AC strategy, x overflow"); 516 } 517 if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) { 518 return JXL_FAILURE("Invalid AC strategy, y overflow"); 519 } 520 JXL_RETURN_IF_ERROR( 521 ac_strategy.SetNoBoundsCheck(x, y, AcStrategyType(row_in_1[num]))); 522 row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1, 523 row_in_2[num])); 524 num++; 525 } 526 } 527 dec_state->used_acs |= local_used_acs; 528 if (frame_header.loop_filter.epf_iters > 0) { 529 JXL_RETURN_IF_ERROR(ComputeSigma(frame_header.loop_filter, r, dec_state)); 530 } 531 return true; 532 } 533 534 Status ModularFrameDecoder::ModularImageToDecodedRect( 535 const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state, 536 jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input, 537 Rect modular_rect) const { 538 const auto* metadata = frame_header.nonserialized_metadata; 539 JXL_ENSURE(gi.transform.empty()); 540 541 auto get_row = [&](size_t c, size_t y) { 542 const auto& buffer = render_pipeline_input.GetBuffer(c); 543 return buffer.second.Row(buffer.first, y); 544 }; 545 546 size_t c = 0; 547 if (do_color) { 548 const bool rgb_from_gray = 549 metadata->m.color_encoding.IsGray() && 550 frame_header.color_transform == ColorTransform::kNone; 551 const bool fp = metadata->m.bit_depth.floating_point_sample && 552 frame_header.color_transform != ColorTransform::kXYB; 553 for (; c < 3; c++) { 554 double factor = full_image.bitdepth < 32 555 ? 1.0 / ((1u << full_image.bitdepth) - 1) 556 : 0; 557 size_t c_in = c; 558 if (frame_header.color_transform == ColorTransform::kXYB) { 559 factor = dec_state->shared->matrices.DCQuants()[c]; 560 // XYB is encoded as YX(B-Y) 561 if (c < 2) c_in = 1 - c; 562 } else if (rgb_from_gray) { 563 c_in = 0; 564 } 565 JXL_ENSURE(c_in < gi.channel.size()); 566 Channel& ch_in = gi.channel[c_in]; 567 // TODO(eustas): could we detect it on earlier stage? 568 if (ch_in.w == 0 || ch_in.h == 0) { 569 return JXL_FAILURE("Empty image"); 570 } 571 JXL_ENSURE(ch_in.hshift <= 3 && ch_in.vshift <= 3); 572 Rect r = render_pipeline_input.GetBuffer(c).second; 573 Rect mr(modular_rect.x0() >> ch_in.hshift, 574 modular_rect.y0() >> ch_in.vshift, 575 DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), 576 DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); 577 mr = mr.Crop(ch_in.plane); 578 size_t xsize_shifted = r.xsize(); 579 size_t ysize_shifted = r.ysize(); 580 if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { 581 return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS 582 "x%" PRIuS 583 " modular channel into " 584 "a %" PRIuS "x%" PRIuS " rect", 585 mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); 586 } 587 if (frame_header.color_transform == ColorTransform::kXYB && c == 2) { 588 JXL_ENSURE(!fp); 589 const auto process_row = [&](const uint32_t task, 590 size_t /* thread */) -> Status { 591 const size_t y = task; 592 const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); 593 const pixel_type* const JXL_RESTRICT row_in_Y = 594 mr.Row(&gi.channel[0].plane, y); 595 float* const JXL_RESTRICT row_out = get_row(c, y); 596 HWY_DYNAMIC_DISPATCH(MultiplySum) 597 (xsize_shifted, row_in, row_in_Y, factor, row_out); 598 return true; 599 }; 600 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted, 601 ThreadPool::NoInit, process_row, 602 "ModularIntToFloat")); 603 } else if (fp) { 604 int bits = metadata->m.bit_depth.bits_per_sample; 605 int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample; 606 const auto process_row = [&](const uint32_t task, 607 size_t /* thread */) -> Status { 608 const size_t y = task; 609 const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); 610 if (rgb_from_gray) { 611 for (size_t cc = 0; cc < 3; cc++) { 612 float* const JXL_RESTRICT row_out = get_row(cc, y); 613 JXL_RETURN_IF_ERROR( 614 int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits)); 615 } 616 } else { 617 float* const JXL_RESTRICT row_out = get_row(c, y); 618 JXL_RETURN_IF_ERROR( 619 int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits)); 620 } 621 return true; 622 }; 623 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted, 624 ThreadPool::NoInit, process_row, 625 "ModularIntToFloat_losslessfloat")); 626 } else { 627 const auto process_row = [&](const uint32_t task, 628 size_t /* thread */) -> Status { 629 const size_t y = task; 630 const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); 631 if (rgb_from_gray) { 632 if (full_image.bitdepth < 23) { 633 HWY_DYNAMIC_DISPATCH(RgbFromSingle) 634 (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y), 635 get_row(2, y)); 636 } else { 637 SingleFromSingleAccurate(xsize_shifted, row_in, factor, 638 get_row(0, y)); 639 SingleFromSingleAccurate(xsize_shifted, row_in, factor, 640 get_row(1, y)); 641 SingleFromSingleAccurate(xsize_shifted, row_in, factor, 642 get_row(2, y)); 643 } 644 } else { 645 float* const JXL_RESTRICT row_out = get_row(c, y); 646 if (full_image.bitdepth < 23) { 647 HWY_DYNAMIC_DISPATCH(SingleFromSingle) 648 (xsize_shifted, row_in, factor, row_out); 649 } else { 650 SingleFromSingleAccurate(xsize_shifted, row_in, factor, row_out); 651 } 652 } 653 return true; 654 }; 655 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted, 656 ThreadPool::NoInit, process_row, 657 "ModularIntToFloat")); 658 } 659 if (rgb_from_gray) { 660 break; 661 } 662 } 663 if (rgb_from_gray) { 664 c = 1; 665 } 666 } 667 size_t num_extra_channels = metadata->m.num_extra_channels; 668 for (size_t ec = 0; ec < num_extra_channels; ec++, c++) { 669 const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec]; 670 int bits = eci.bit_depth.bits_per_sample; 671 int exp_bits = eci.bit_depth.exponent_bits_per_sample; 672 bool fp = eci.bit_depth.floating_point_sample; 673 JXL_ENSURE(fp || bits < 32); 674 const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1)); 675 JXL_ENSURE(c < gi.channel.size()); 676 Channel& ch_in = gi.channel[c]; 677 const auto& buffer = render_pipeline_input.GetBuffer(3 + ec); 678 Rect r = buffer.second; 679 Rect mr(modular_rect.x0() >> ch_in.hshift, 680 modular_rect.y0() >> ch_in.vshift, 681 DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), 682 DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); 683 mr = mr.Crop(ch_in.plane); 684 if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { 685 return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS 686 "x%" PRIuS 687 " modular channel into " 688 "a %" PRIuS "x%" PRIuS " rect", 689 mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); 690 } 691 for (size_t y = 0; y < r.ysize(); ++y) { 692 float* const JXL_RESTRICT row_out = r.Row(buffer.first, y); 693 const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); 694 if (fp) { 695 JXL_RETURN_IF_ERROR( 696 int_to_float(row_in, row_out, r.xsize(), bits, exp_bits)); 697 } else { 698 if (full_image.bitdepth < 23) { 699 HWY_DYNAMIC_DISPATCH(SingleFromSingle) 700 (r.xsize(), row_in, factor, row_out); 701 } else { 702 SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out); 703 } 704 } 705 } 706 } 707 return true; 708 } 709 710 Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header, 711 PassesDecoderState* dec_state, 712 jxl::ThreadPool* pool, 713 bool inplace) { 714 if (!use_full_image) return true; 715 JxlMemoryManager* memory_manager = dec_state->memory_manager(); 716 Image gi{memory_manager}; 717 if (inplace) { 718 gi = std::move(full_image); 719 } else { 720 JXL_ASSIGN_OR_RETURN(gi, Image::Clone(full_image)); 721 } 722 size_t xsize = gi.w; 723 size_t ysize = gi.h; 724 725 JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s", 726 gi.DebugString().c_str()); 727 728 // Don't use threads if total image size is smaller than a group 729 if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr; 730 731 // Undo the global transforms 732 gi.undo_transforms(global_header.wp_header, pool); 733 JXL_ENSURE(global_transform.empty()); 734 if (gi.error) return JXL_FAILURE("Undoing transforms failed"); 735 736 for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) { 737 dec_state->render_pipeline->ClearDone(i); 738 } 739 740 const auto init = [&](size_t num_threads) -> Status { 741 bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT || 742 (frame_header.flags & FrameHeader::kNoise)); 743 JXL_RETURN_IF_ERROR(dec_state->render_pipeline->PrepareForThreads( 744 num_threads, use_group_ids)); 745 return true; 746 }; 747 const auto process_group = [&](const uint32_t group, 748 size_t thread_id) -> Status { 749 RenderPipelineInput input = 750 dec_state->render_pipeline->GetInputBuffers(group, thread_id); 751 JXL_RETURN_IF_ERROR(ModularImageToDecodedRect( 752 frame_header, gi, dec_state, nullptr, input, 753 dec_state->shared->frame_dim.GroupRect(group))); 754 JXL_RETURN_IF_ERROR(input.Done()); 755 return true; 756 }; 757 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, 758 dec_state->shared->frame_dim.num_groups, init, 759 process_group, "ModularToRect")); 760 return true; 761 } 762 763 static constexpr const float kAlmostZero = 1e-8f; 764 765 Status ModularFrameDecoder::DecodeQuantTable( 766 JxlMemoryManager* memory_manager, size_t required_size_x, 767 size_t required_size_y, BitReader* br, QuantEncoding* encoding, size_t idx, 768 ModularFrameDecoder* modular_frame_decoder) { 769 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den)); 770 if (encoding->qraw.qtable_den < kAlmostZero) { 771 // qtable[] values are already checked for <= 0 so the denominator may not 772 // be negative. 773 return JXL_FAILURE("Invalid qtable_den: value too small"); 774 } 775 JXL_ASSIGN_OR_RETURN( 776 Image image, 777 Image::Create(memory_manager, required_size_x, required_size_y, 8, 3)); 778 ModularOptions options; 779 if (modular_frame_decoder) { 780 JXL_ASSIGN_OR_RETURN(ModularStreamId qt, ModularStreamId::QuantTable(idx)); 781 JXL_RETURN_IF_ERROR(ModularGenericDecompress( 782 br, image, /*header=*/nullptr, qt.ID(modular_frame_decoder->frame_dim), 783 &options, /*undo_transforms=*/true, &modular_frame_decoder->tree, 784 &modular_frame_decoder->code, &modular_frame_decoder->context_map)); 785 } else { 786 JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr, 787 0, &options, 788 /*undo_transforms=*/true)); 789 } 790 if (!encoding->qraw.qtable) { 791 encoding->qraw.qtable = 792 new std::vector<int>(required_size_x * required_size_y * 3); 793 } else { 794 JXL_ENSURE(encoding->qraw.qtable->size() == 795 required_size_x * required_size_y * 3); 796 } 797 int* qtable = encoding->qraw.qtable->data(); 798 for (size_t c = 0; c < 3; c++) { 799 for (size_t y = 0; y < required_size_y; y++) { 800 int32_t* JXL_RESTRICT row = image.channel[c].Row(y); 801 for (size_t x = 0; x < required_size_x; x++) { 802 qtable[c * required_size_x * required_size_y + y * required_size_x + 803 x] = row[x]; 804 if (row[x] <= 0) { 805 return JXL_FAILURE("Invalid raw quantization table"); 806 } 807 } 808 } 809 } 810 return true; 811 } 812 813 } // namespace jxl 814 #endif // HWY_ONCE