quant_weights.cc (53271B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 #include "lib/jxl/quant_weights.h" 6 7 #include <jxl/memory_manager.h> 8 9 #include <cmath> 10 #include <cstdio> 11 #include <cstdlib> 12 13 #include "lib/jxl/base/compiler_specific.h" 14 #include "lib/jxl/base/status.h" 15 #include "lib/jxl/dct_scales.h" 16 #include "lib/jxl/dec_modular.h" 17 #include "lib/jxl/fields.h" 18 #include "lib/jxl/memory_manager_internal.h" 19 20 #undef HWY_TARGET_INCLUDE 21 #define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc" 22 #include <hwy/foreach_target.h> 23 #include <hwy/highway.h> 24 25 #include "lib/jxl/base/fast_math-inl.h" 26 27 HWY_BEFORE_NAMESPACE(); 28 namespace jxl { 29 namespace HWY_NAMESPACE { 30 31 // These templates are not found via ADL. 32 using hwy::HWY_NAMESPACE::Lt; 33 using hwy::HWY_NAMESPACE::MulAdd; 34 using hwy::HWY_NAMESPACE::Sqrt; 35 36 // kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y) 37 // coefficient in component c. Higher weights correspond to finer quantization 38 // intervals and more bits spent in encoding. 39 40 static constexpr const float kAlmostZero = 1e-8f; 41 42 void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights, 43 float* weights) { 44 for (size_t c = 0; c < 3; c++) { 45 size_t start = c * 64; 46 weights[start] = 0xBAD; 47 weights[start + 1] = weights[start + 8] = dct2weights[c][0]; 48 weights[start + 9] = dct2weights[c][1]; 49 for (size_t y = 0; y < 2; y++) { 50 for (size_t x = 0; x < 2; x++) { 51 weights[start + y * 8 + x + 2] = dct2weights[c][2]; 52 weights[start + (y + 2) * 8 + x] = dct2weights[c][2]; 53 } 54 } 55 for (size_t y = 0; y < 2; y++) { 56 for (size_t x = 0; x < 2; x++) { 57 weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3]; 58 } 59 } 60 for (size_t y = 0; y < 4; y++) { 61 for (size_t x = 0; x < 4; x++) { 62 weights[start + y * 8 + x + 4] = dct2weights[c][4]; 63 weights[start + (y + 4) * 8 + x] = dct2weights[c][4]; 64 } 65 } 66 for (size_t y = 0; y < 4; y++) { 67 for (size_t x = 0; x < 4; x++) { 68 weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5]; 69 } 70 } 71 } 72 } 73 74 void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights, 75 float* weights) { 76 for (size_t c = 0; c < 3; c++) { 77 for (int i = 0; i < 64; i++) { 78 weights[64 * c + i] = idweights[c][0]; 79 } 80 weights[64 * c + 1] = idweights[c][1]; 81 weights[64 * c + 8] = idweights[c][1]; 82 weights[64 * c + 9] = idweights[c][2]; 83 } 84 } 85 86 StatusOr<float> Interpolate(float pos, float max, const float* array, 87 size_t len) { 88 float scaled_pos = pos * (len - 1) / max; 89 size_t idx = scaled_pos; 90 JXL_ENSURE(idx + 1 < len); 91 float a = array[idx]; 92 float b = array[idx + 1]; 93 return a * FastPowf(b / a, scaled_pos - idx); 94 } 95 96 float Mult(float v) { 97 if (v > 0.0f) return 1.0f + v; 98 return 1.0f / (1.0f - v); 99 } 100 101 using DF4 = HWY_CAPPED(float, 4); 102 103 hwy::HWY_NAMESPACE::Vec<DF4> InterpolateVec( 104 hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) { 105 HWY_CAPPED(int32_t, 4) di; 106 107 auto idx = ConvertTo(di, scaled_pos); 108 109 auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx)); 110 111 // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but 112 // it's probably slower. 113 auto a = GatherIndex(DF4(), array, idx); 114 auto b = GatherIndex(DF4(), array + 1, idx); 115 116 return Mul(a, FastPowf(DF4(), Div(b, a), frac)); 117 } 118 119 // Computes quant weights for a COLS*ROWS-sized transform, using num_bands 120 // eccentricity bands and num_ebands eccentricity bands. If print_mode is 1, 121 // prints the resulting matrix; if print_mode is 2, prints the matrix in a 122 // format suitable for a 3d plot with gnuplot. 123 Status GetQuantWeights( 124 size_t ROWS, size_t COLS, 125 const DctQuantWeightParams::DistanceBandsArray& distance_bands, 126 size_t num_bands, float* out) { 127 for (size_t c = 0; c < 3; c++) { 128 float bands[DctQuantWeightParams::kMaxDistanceBands] = { 129 distance_bands[c][0]}; 130 if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); 131 for (size_t i = 1; i < num_bands; i++) { 132 bands[i] = bands[i - 1] * Mult(distance_bands[c][i]); 133 if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); 134 } 135 float scale = (num_bands - 1) / (kSqrt2 + 1e-6f); 136 float rcpcol = scale / (COLS - 1); 137 float rcprow = scale / (ROWS - 1); 138 JXL_ENSURE(COLS >= Lanes(DF4())); 139 HWY_ALIGN float l0123[4] = {0, 1, 2, 3}; 140 for (uint32_t y = 0; y < ROWS; y++) { 141 float dy = y * rcprow; 142 float dy2 = dy * dy; 143 for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) { 144 auto dx = 145 Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol)); 146 auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2))); 147 auto weight = num_bands == 1 ? Set(DF4(), bands[0]) 148 : InterpolateVec(scaled_distance, bands); 149 StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x); 150 } 151 } 152 } 153 return true; 154 } 155 156 // TODO(veluca): SIMD-fy. With 256x256, this is actually slow. 157 Status ComputeQuantTable(const QuantEncoding& encoding, 158 float* JXL_RESTRICT table, 159 float* JXL_RESTRICT inv_table, size_t table_num, 160 QuantTable kind, size_t* pos) { 161 constexpr size_t N = kBlockDim; 162 size_t quant_table_idx = static_cast<size_t>(kind); 163 size_t wrows = 8 * DequantMatrices::required_size_x[quant_table_idx]; 164 size_t wcols = 8 * DequantMatrices::required_size_y[quant_table_idx]; 165 size_t num = wrows * wcols; 166 167 std::vector<float> weights(3 * num); 168 169 switch (encoding.mode) { 170 case QuantEncoding::kQuantModeLibrary: { 171 // Library and copy quant encoding should get replaced by the actual 172 // parameters by the caller. 173 JXL_ENSURE(false); 174 break; 175 } 176 case QuantEncoding::kQuantModeID: { 177 JXL_ENSURE(num == kDCTBlockSize); 178 GetQuantWeightsIdentity(encoding.idweights, weights.data()); 179 break; 180 } 181 case QuantEncoding::kQuantModeDCT2: { 182 JXL_ENSURE(num == kDCTBlockSize); 183 GetQuantWeightsDCT2(encoding.dct2weights, weights.data()); 184 break; 185 } 186 case QuantEncoding::kQuantModeDCT4: { 187 JXL_ENSURE(num == kDCTBlockSize); 188 float weights4x4[3 * 4 * 4]; 189 // Always use 4x4 GetQuantWeights for DCT4 quantization tables. 190 JXL_RETURN_IF_ERROR( 191 GetQuantWeights(4, 4, encoding.dct_params.distance_bands, 192 encoding.dct_params.num_distance_bands, weights4x4)); 193 for (size_t c = 0; c < 3; c++) { 194 for (size_t y = 0; y < kBlockDim; y++) { 195 for (size_t x = 0; x < kBlockDim; x++) { 196 weights[c * num + y * kBlockDim + x] = 197 weights4x4[c * 16 + (y / 2) * 4 + (x / 2)]; 198 } 199 } 200 weights[c * num + 1] /= encoding.dct4multipliers[c][0]; 201 weights[c * num + N] /= encoding.dct4multipliers[c][0]; 202 weights[c * num + N + 1] /= encoding.dct4multipliers[c][1]; 203 } 204 break; 205 } 206 case QuantEncoding::kQuantModeDCT4X8: { 207 JXL_ENSURE(num == kDCTBlockSize); 208 float weights4x8[3 * 4 * 8]; 209 // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables. 210 JXL_RETURN_IF_ERROR( 211 GetQuantWeights(4, 8, encoding.dct_params.distance_bands, 212 encoding.dct_params.num_distance_bands, weights4x8)); 213 for (size_t c = 0; c < 3; c++) { 214 for (size_t y = 0; y < kBlockDim; y++) { 215 for (size_t x = 0; x < kBlockDim; x++) { 216 weights[c * num + y * kBlockDim + x] = 217 weights4x8[c * 32 + (y / 2) * 8 + x]; 218 } 219 } 220 weights[c * num + N] /= encoding.dct4x8multipliers[c]; 221 } 222 break; 223 } 224 case QuantEncoding::kQuantModeDCT: { 225 JXL_RETURN_IF_ERROR(GetQuantWeights( 226 wrows, wcols, encoding.dct_params.distance_bands, 227 encoding.dct_params.num_distance_bands, weights.data())); 228 break; 229 } 230 case QuantEncoding::kQuantModeRAW: { 231 if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { 232 return JXL_FAILURE("Invalid table encoding"); 233 } 234 int* qtable = encoding.qraw.qtable->data(); 235 for (size_t i = 0; i < 3 * num; i++) { 236 weights[i] = 1.f / (encoding.qraw.qtable_den * qtable[i]); 237 } 238 break; 239 } 240 case QuantEncoding::kQuantModeAFV: { 241 constexpr float kFreqs[] = { 242 0xBAD, 243 0xBAD, 244 0.8517778890324296, 245 5.37778436506804, 246 0xBAD, 247 0xBAD, 248 4.734747904497923, 249 5.449245381693219, 250 1.6598270267479331, 251 4, 252 7.275749096817861, 253 10.423227632456525, 254 2.662932286148962, 255 7.630657783650829, 256 8.962388608184032, 257 12.97166202570235, 258 }; 259 260 float weights4x8[3 * 4 * 8]; 261 JXL_RETURN_IF_ERROR(( 262 GetQuantWeights(4, 8, encoding.dct_params.distance_bands, 263 encoding.dct_params.num_distance_bands, weights4x8))); 264 float weights4x4[3 * 4 * 4]; 265 JXL_RETURN_IF_ERROR((GetQuantWeights( 266 4, 4, encoding.dct_params_afv_4x4.distance_bands, 267 encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); 268 269 constexpr float lo = 0.8517778890324296; 270 constexpr float hi = 12.97166202570235f - lo + 1e-6f; 271 for (size_t c = 0; c < 3; c++) { 272 float bands[4]; 273 bands[0] = encoding.afv_weights[c][5]; 274 if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); 275 for (size_t i = 1; i < 4; i++) { 276 bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); 277 if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); 278 } 279 size_t start = c * 64; 280 auto set_weight = [&start, &weights](size_t x, size_t y, float val) { 281 weights[start + y * 8 + x] = val; 282 }; 283 weights[start] = 1; // Not used, but causes MSAN error otherwise. 284 // Weights for (0, 1) and (1, 0). 285 set_weight(0, 1, encoding.afv_weights[c][0]); 286 set_weight(1, 0, encoding.afv_weights[c][1]); 287 // AFV special weights for 3-pixel corner. 288 set_weight(0, 2, encoding.afv_weights[c][2]); 289 set_weight(2, 0, encoding.afv_weights[c][3]); 290 set_weight(2, 2, encoding.afv_weights[c][4]); 291 292 // All other AFV weights. 293 for (size_t y = 0; y < 4; y++) { 294 for (size_t x = 0; x < 4; x++) { 295 if (x < 2 && y < 2) continue; 296 JXL_ASSIGN_OR_RETURN( 297 float val, Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4)); 298 set_weight(2 * x, 2 * y, val); 299 } 300 } 301 302 // Put 4x8 weights in odd rows, except (1, 0). 303 for (size_t y = 0; y < kBlockDim / 2; y++) { 304 for (size_t x = 0; x < kBlockDim; x++) { 305 if (x == 0 && y == 0) continue; 306 weights[c * num + (2 * y + 1) * kBlockDim + x] = 307 weights4x8[c * 32 + y * 8 + x]; 308 } 309 } 310 // Put 4x4 weights in even rows / odd columns, except (0, 1). 311 for (size_t y = 0; y < kBlockDim / 2; y++) { 312 for (size_t x = 0; x < kBlockDim / 2; x++) { 313 if (x == 0 && y == 0) continue; 314 weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] = 315 weights4x4[c * 16 + y * 4 + x]; 316 } 317 } 318 } 319 break; 320 } 321 } 322 size_t prev_pos = *pos; 323 HWY_CAPPED(float, 64) d; 324 for (size_t i = 0; i < num * 3; i += Lanes(d)) { 325 auto inv_val = LoadU(d, weights.data() + i); 326 if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) || 327 !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) { 328 return JXL_FAILURE("Invalid quantization table"); 329 } 330 auto val = Div(Set(d, 1.0f), inv_val); 331 StoreU(val, d, table + *pos + i); 332 StoreU(inv_val, d, inv_table + *pos + i); 333 } 334 (*pos) += 3 * num; 335 336 // Ensure that the lowest frequencies have a 0 inverse table. 337 // This does not affect en/decoding, but allows AC strategy selection to be 338 // slightly simpler. 339 size_t xs = DequantMatrices::required_size_x[quant_table_idx]; 340 size_t ys = DequantMatrices::required_size_y[quant_table_idx]; 341 CoefficientLayout(&ys, &xs); 342 for (size_t c = 0; c < 3; c++) { 343 for (size_t y = 0; y < ys; y++) { 344 for (size_t x = 0; x < xs; x++) { 345 inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs + 346 x] = 0; 347 } 348 } 349 } 350 return true; 351 } 352 353 // NOLINTNEXTLINE(google-readability-namespace-comments) 354 } // namespace HWY_NAMESPACE 355 } // namespace jxl 356 HWY_AFTER_NAMESPACE(); 357 358 #if HWY_ONCE 359 360 namespace jxl { 361 namespace { 362 363 HWY_EXPORT(ComputeQuantTable); 364 365 constexpr const float kAlmostZero = 1e-8f; 366 367 Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { 368 params->num_distance_bands = 369 br->ReadFixedBits<DctQuantWeightParams::kLog2MaxDistanceBands>() + 1; 370 for (size_t c = 0; c < 3; c++) { 371 for (size_t i = 0; i < params->num_distance_bands; i++) { 372 JXL_RETURN_IF_ERROR(F16Coder::Read(br, ¶ms->distance_bands[c][i])); 373 } 374 if (params->distance_bands[c][0] < kAlmostZero) { 375 return JXL_FAILURE("Distance band seed is too small"); 376 } 377 params->distance_bands[c][0] *= 64.0f; 378 } 379 return true; 380 } 381 382 Status Decode(JxlMemoryManager* memory_manager, BitReader* br, 383 QuantEncoding* encoding, size_t required_size_x, 384 size_t required_size_y, size_t idx, 385 ModularFrameDecoder* modular_frame_decoder) { 386 size_t required_size = required_size_x * required_size_y; 387 required_size_x *= kBlockDim; 388 required_size_y *= kBlockDim; 389 int mode = br->ReadFixedBits<kLog2NumQuantModes>(); 390 switch (mode) { 391 case QuantEncoding::kQuantModeLibrary: { 392 encoding->predefined = br->ReadFixedBits<kCeilLog2NumPredefinedTables>(); 393 if (encoding->predefined >= kNumPredefinedTables) { 394 return JXL_FAILURE("Invalid predefined table"); 395 } 396 break; 397 } 398 case QuantEncoding::kQuantModeID: { 399 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 400 for (size_t c = 0; c < 3; c++) { 401 for (size_t i = 0; i < 3; i++) { 402 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i])); 403 if (std::abs(encoding->idweights[c][i]) < kAlmostZero) { 404 return JXL_FAILURE("ID Quantizer is too small"); 405 } 406 encoding->idweights[c][i] *= 64; 407 } 408 } 409 break; 410 } 411 case QuantEncoding::kQuantModeDCT2: { 412 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 413 for (size_t c = 0; c < 3; c++) { 414 for (size_t i = 0; i < 6; i++) { 415 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i])); 416 if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) { 417 return JXL_FAILURE("Quantizer is too small"); 418 } 419 encoding->dct2weights[c][i] *= 64; 420 } 421 } 422 break; 423 } 424 case QuantEncoding::kQuantModeDCT4X8: { 425 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 426 for (size_t c = 0; c < 3; c++) { 427 JXL_RETURN_IF_ERROR( 428 F16Coder::Read(br, &encoding->dct4x8multipliers[c])); 429 if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) { 430 return JXL_FAILURE("DCT4X8 multiplier is too small"); 431 } 432 } 433 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 434 break; 435 } 436 case QuantEncoding::kQuantModeDCT4: { 437 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 438 for (size_t c = 0; c < 3; c++) { 439 for (size_t i = 0; i < 2; i++) { 440 JXL_RETURN_IF_ERROR( 441 F16Coder::Read(br, &encoding->dct4multipliers[c][i])); 442 if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) { 443 return JXL_FAILURE("DCT4 multiplier is too small"); 444 } 445 } 446 } 447 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 448 break; 449 } 450 case QuantEncoding::kQuantModeAFV: { 451 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 452 for (size_t c = 0; c < 3; c++) { 453 for (size_t i = 0; i < 9; i++) { 454 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i])); 455 } 456 for (size_t i = 0; i < 6; i++) { 457 encoding->afv_weights[c][i] *= 64; 458 } 459 } 460 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 461 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4)); 462 break; 463 } 464 case QuantEncoding::kQuantModeDCT: { 465 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 466 break; 467 } 468 case QuantEncoding::kQuantModeRAW: { 469 // Set mode early, to avoid mem-leak. 470 encoding->mode = QuantEncoding::kQuantModeRAW; 471 JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable( 472 memory_manager, required_size_x, required_size_y, br, encoding, idx, 473 modular_frame_decoder)); 474 break; 475 } 476 default: 477 return JXL_FAILURE("Invalid quantization table encoding"); 478 } 479 encoding->mode = static_cast<QuantEncoding::Mode>(mode); 480 return true; 481 } 482 483 } // namespace 484 485 #if JXL_CXX_LANG < JXL_CXX_17 486 constexpr const std::array<int, 17> DequantMatrices::required_size_x; 487 constexpr const std::array<int, 17> DequantMatrices::required_size_y; 488 constexpr const size_t DequantMatrices::kSumRequiredXy; 489 #endif 490 491 Status DequantMatrices::Decode(JxlMemoryManager* memory_manager, BitReader* br, 492 ModularFrameDecoder* modular_frame_decoder) { 493 size_t all_default = br->ReadBits(1); 494 size_t num_tables = all_default ? 0 : static_cast<size_t>(kNumQuantTables); 495 encodings_.clear(); 496 encodings_.resize(kNumQuantTables, QuantEncoding::Library<0>()); 497 for (size_t i = 0; i < num_tables; i++) { 498 JXL_RETURN_IF_ERROR(jxl::Decode(memory_manager, br, &encodings_[i], 499 required_size_x[i % kNumQuantTables], 500 required_size_y[i % kNumQuantTables], i, 501 modular_frame_decoder)); 502 } 503 computed_mask_ = 0; 504 return true; 505 } 506 507 Status DequantMatrices::DecodeDC(BitReader* br) { 508 bool all_default = static_cast<bool>(br->ReadBits(1)); 509 if (!br->AllReadsWithinBounds()) return JXL_FAILURE("EOS during DecodeDC"); 510 if (!all_default) { 511 for (size_t c = 0; c < 3; c++) { 512 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c])); 513 dc_quant_[c] *= 1.0f / 128.0f; 514 // Negative values and nearly zero are invalid values. 515 if (dc_quant_[c] < kAlmostZero) { 516 return JXL_FAILURE("Invalid dc_quant: coefficient is too small."); 517 } 518 inv_dc_quant_[c] = 1.0f / dc_quant_[c]; 519 } 520 } 521 return true; 522 } 523 524 constexpr float V(float v) { return static_cast<float>(v); } 525 526 namespace { 527 struct DequantMatricesLibraryDef { 528 // DCT8 529 static constexpr QuantEncodingInternal DCT() { 530 return QuantEncodingInternal::DCT(DctQuantWeightParams({{{{ 531 V(3150.0), 532 V(0.0), 533 V(-0.4), 534 V(-0.4), 535 V(-0.4), 536 V(-2.0), 537 }}, 538 {{ 539 V(560.0), 540 V(0.0), 541 V(-0.3), 542 V(-0.3), 543 V(-0.3), 544 V(-0.3), 545 }}, 546 {{ 547 V(512.0), 548 V(-2.0), 549 V(-1.0), 550 V(0.0), 551 V(-1.0), 552 V(-2.0), 553 }}}}, 554 6)); 555 } 556 557 // Identity 558 static constexpr QuantEncodingInternal IDENTITY() { 559 return QuantEncodingInternal::Identity({{{{ 560 V(280.0), 561 V(3160.0), 562 V(3160.0), 563 }}, 564 {{ 565 V(60.0), 566 V(864.0), 567 V(864.0), 568 }}, 569 {{ 570 V(18.0), 571 V(200.0), 572 V(200.0), 573 }}}}); 574 } 575 576 // DCT2 577 static constexpr QuantEncodingInternal DCT2X2() { 578 return QuantEncodingInternal::DCT2({{{{ 579 V(3840.0), 580 V(2560.0), 581 V(1280.0), 582 V(640.0), 583 V(480.0), 584 V(300.0), 585 }}, 586 {{ 587 V(960.0), 588 V(640.0), 589 V(320.0), 590 V(180.0), 591 V(140.0), 592 V(120.0), 593 }}, 594 {{ 595 V(640.0), 596 V(320.0), 597 V(128.0), 598 V(64.0), 599 V(32.0), 600 V(16.0), 601 }}}}); 602 } 603 604 // DCT4 (quant_kind 3) 605 static constexpr QuantEncodingInternal DCT4X4() { 606 return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{{ 607 V(2200.0), 608 V(0.0), 609 V(0.0), 610 V(0.0), 611 }}, 612 {{ 613 V(392.0), 614 V(0.0), 615 V(0.0), 616 V(0.0), 617 }}, 618 {{ 619 V(112.0), 620 V(-0.25), 621 V(-0.25), 622 V(-0.5), 623 }}}}, 624 4), 625 /* kMul */ 626 {{{{ 627 V(1.0), 628 V(1.0), 629 }}, 630 {{ 631 V(1.0), 632 V(1.0), 633 }}, 634 {{ 635 V(1.0), 636 V(1.0), 637 }}}}); 638 } 639 640 // DCT16 641 static constexpr QuantEncodingInternal DCT16X16() { 642 return QuantEncodingInternal::DCT( 643 DctQuantWeightParams({{{{ 644 V(8996.8725711814115328), 645 V(-1.3000777393353804), 646 V(-0.49424529824571225), 647 V(-0.439093774457103443), 648 V(-0.6350101832695744), 649 V(-0.90177264050827612), 650 V(-1.6162099239887414), 651 }}, 652 {{ 653 V(3191.48366296844234752), 654 V(-0.67424582104194355), 655 V(-0.80745813428471001), 656 V(-0.44925837484843441), 657 V(-0.35865440981033403), 658 V(-0.31322389111877305), 659 V(-0.37615025315725483), 660 }}, 661 {{ 662 V(1157.50408145487200256), 663 V(-2.0531423165804414), 664 V(-1.4), 665 V(-0.50687130033378396), 666 V(-0.42708730624733904), 667 V(-1.4856834539296244), 668 V(-4.9209142884401604), 669 }}}}, 670 7)); 671 } 672 673 // DCT32 674 static constexpr QuantEncodingInternal DCT32X32() { 675 return QuantEncodingInternal::DCT( 676 DctQuantWeightParams({{{{ 677 V(15718.40830982518931456), 678 V(-1.025), 679 V(-0.98), 680 V(-0.9012), 681 V(-0.4), 682 V(-0.48819395464), 683 V(-0.421064), 684 V(-0.27), 685 }}, 686 {{ 687 V(7305.7636810695983104), 688 V(-0.8041958212306401), 689 V(-0.7633036457487539), 690 V(-0.55660379990111464), 691 V(-0.49785304658857626), 692 V(-0.43699592683512467), 693 V(-0.40180866526242109), 694 V(-0.27321683125358037), 695 }}, 696 {{ 697 V(3803.53173721215041536), 698 V(-3.060733579805728), 699 V(-2.0413270132490346), 700 V(-2.0235650159727417), 701 V(-0.5495389509954993), 702 V(-0.4), 703 V(-0.4), 704 V(-0.3), 705 }}}}, 706 8)); 707 } 708 709 // DCT16X8 710 static constexpr QuantEncodingInternal DCT8X16() { 711 return QuantEncodingInternal::DCT( 712 DctQuantWeightParams({{{{ 713 V(7240.7734393502), 714 V(-0.7), 715 V(-0.7), 716 V(-0.2), 717 V(-0.2), 718 V(-0.2), 719 V(-0.5), 720 }}, 721 {{ 722 V(1448.15468787004), 723 V(-0.5), 724 V(-0.5), 725 V(-0.5), 726 V(-0.2), 727 V(-0.2), 728 V(-0.2), 729 }}, 730 {{ 731 V(506.854140754517), 732 V(-1.4), 733 V(-0.2), 734 V(-0.5), 735 V(-0.5), 736 V(-1.5), 737 V(-3.6), 738 }}}}, 739 7)); 740 } 741 742 // DCT32X8 743 static constexpr QuantEncodingInternal DCT8X32() { 744 return QuantEncodingInternal::DCT( 745 DctQuantWeightParams({{{{ 746 V(16283.2494710648897), 747 V(-1.7812845336559429), 748 V(-1.6309059012653515), 749 V(-1.0382179034313539), 750 V(-0.85), 751 V(-0.7), 752 V(-0.9), 753 V(-1.2360638576849587), 754 }}, 755 {{ 756 V(5089.15750884921511936), 757 V(-0.320049391452786891), 758 V(-0.35362849922161446), 759 V(-0.30340000000000003), 760 V(-0.61), 761 V(-0.5), 762 V(-0.5), 763 V(-0.6), 764 }}, 765 {{ 766 V(3397.77603275308720128), 767 V(-0.321327362693153371), 768 V(-0.34507619223117997), 769 V(-0.70340000000000003), 770 V(-0.9), 771 V(-1.0), 772 V(-1.0), 773 V(-1.1754605576265209), 774 }}}}, 775 8)); 776 } 777 778 // DCT32X16 779 static constexpr QuantEncodingInternal DCT16X32() { 780 return QuantEncodingInternal::DCT( 781 DctQuantWeightParams({{{{ 782 V(13844.97076442300573), 783 V(-0.97113799999999995), 784 V(-0.658), 785 V(-0.42026), 786 V(-0.22712), 787 V(-0.2206), 788 V(-0.226), 789 V(-0.6), 790 }}, 791 {{ 792 V(4798.964084220744293), 793 V(-0.61125308982767057), 794 V(-0.83770786552491361), 795 V(-0.79014862079498627), 796 V(-0.2692727459704829), 797 V(-0.38272769465388551), 798 V(-0.22924222653091453), 799 V(-0.20719098826199578), 800 }}, 801 {{ 802 V(1807.236946760964614), 803 V(-1.2), 804 V(-1.2), 805 V(-0.7), 806 V(-0.7), 807 V(-0.7), 808 V(-0.4), 809 V(-0.5), 810 }}}}, 811 8)); 812 } 813 814 // DCT4X8 and 8x4 815 static constexpr QuantEncodingInternal DCT4X8() { 816 return QuantEncodingInternal::DCT4X8( 817 DctQuantWeightParams({{ 818 {{ 819 V(2198.050556016380522), 820 V(-0.96269623020744692), 821 V(-0.76194253026666783), 822 V(-0.6551140670773547), 823 }}, 824 {{ 825 V(764.3655248643528689), 826 V(-0.92630200888366945), 827 V(-0.9675229603596517), 828 V(-0.27845290869168118), 829 }}, 830 {{ 831 V(527.107573587542228), 832 V(-1.4594385811273854), 833 V(-1.450082094097871593), 834 V(-1.5843722511996204), 835 }}, 836 }}, 837 4), 838 /* kMuls */ 839 {{ 840 V(1.0), 841 V(1.0), 842 V(1.0), 843 }}); 844 } 845 // AFV 846 static QuantEncodingInternal AFV0() { 847 return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params, 848 {{{{ 849 // 4x4/4x8 DC tendency. 850 V(3072.0), 851 V(3072.0), 852 // AFV corner. 853 V(256.0), 854 V(256.0), 855 V(256.0), 856 // AFV high freqs. 857 V(414.0), 858 V(0.0), 859 V(0.0), 860 V(0.0), 861 }}, 862 {{ 863 // 4x4/4x8 DC tendency. 864 V(1024.0), 865 V(1024.0), 866 // AFV corner. 867 V(50), 868 V(50), 869 V(50), 870 // AFV high freqs. 871 V(58.0), 872 V(0.0), 873 V(0.0), 874 V(0.0), 875 }}, 876 {{ 877 // 4x4/4x8 DC tendency. 878 V(384.0), 879 V(384.0), 880 // AFV corner. 881 V(12.0), 882 V(12.0), 883 V(12.0), 884 // AFV high freqs. 885 V(22.0), 886 V(-0.25), 887 V(-0.25), 888 V(-0.25), 889 }}}}); 890 } 891 892 // DCT64 893 static QuantEncodingInternal DCT64X64() { 894 return QuantEncodingInternal::DCT( 895 DctQuantWeightParams({{{{ 896 V(0.9 * 26629.073922049845), 897 V(-1.025), 898 V(-0.78), 899 V(-0.65012), 900 V(-0.19041574084286472), 901 V(-0.20819395464), 902 V(-0.421064), 903 V(-0.32733845535848671), 904 }}, 905 {{ 906 V(0.9 * 9311.3238710010046), 907 V(-0.3041958212306401), 908 V(-0.3633036457487539), 909 V(-0.35660379990111464), 910 V(-0.3443074455424403), 911 V(-0.33699592683512467), 912 V(-0.30180866526242109), 913 V(-0.27321683125358037), 914 }}, 915 {{ 916 V(0.9 * 4992.2486445538634), 917 V(-1.2), 918 V(-1.2), 919 V(-0.8), 920 V(-0.7), 921 V(-0.7), 922 V(-0.4), 923 V(-0.5), 924 }}}}, 925 8)); 926 } 927 928 // DCT64X32 929 static QuantEncodingInternal DCT32X64() { 930 return QuantEncodingInternal::DCT( 931 DctQuantWeightParams({{{{ 932 V(0.65 * 23629.073922049845), 933 V(-1.025), 934 V(-0.78), 935 V(-0.65012), 936 V(-0.19041574084286472), 937 V(-0.20819395464), 938 V(-0.421064), 939 V(-0.32733845535848671), 940 }}, 941 {{ 942 V(0.65 * 8611.3238710010046), 943 V(-0.3041958212306401), 944 V(-0.3633036457487539), 945 V(-0.35660379990111464), 946 V(-0.3443074455424403), 947 V(-0.33699592683512467), 948 V(-0.30180866526242109), 949 V(-0.27321683125358037), 950 }}, 951 {{ 952 V(0.65 * 4492.2486445538634), 953 V(-1.2), 954 V(-1.2), 955 V(-0.8), 956 V(-0.7), 957 V(-0.7), 958 V(-0.4), 959 V(-0.5), 960 }}}}, 961 8)); 962 } 963 // DCT128X128 964 static QuantEncodingInternal DCT128X128() { 965 return QuantEncodingInternal::DCT( 966 DctQuantWeightParams({{{{ 967 V(1.8 * 26629.073922049845), 968 V(-1.025), 969 V(-0.78), 970 V(-0.65012), 971 V(-0.19041574084286472), 972 V(-0.20819395464), 973 V(-0.421064), 974 V(-0.32733845535848671), 975 }}, 976 {{ 977 V(1.8 * 9311.3238710010046), 978 V(-0.3041958212306401), 979 V(-0.3633036457487539), 980 V(-0.35660379990111464), 981 V(-0.3443074455424403), 982 V(-0.33699592683512467), 983 V(-0.30180866526242109), 984 V(-0.27321683125358037), 985 }}, 986 {{ 987 V(1.8 * 4992.2486445538634), 988 V(-1.2), 989 V(-1.2), 990 V(-0.8), 991 V(-0.7), 992 V(-0.7), 993 V(-0.4), 994 V(-0.5), 995 }}}}, 996 8)); 997 } 998 999 // DCT128X64 1000 static QuantEncodingInternal DCT64X128() { 1001 return QuantEncodingInternal::DCT( 1002 DctQuantWeightParams({{{{ 1003 V(1.3 * 23629.073922049845), 1004 V(-1.025), 1005 V(-0.78), 1006 V(-0.65012), 1007 V(-0.19041574084286472), 1008 V(-0.20819395464), 1009 V(-0.421064), 1010 V(-0.32733845535848671), 1011 }}, 1012 {{ 1013 V(1.3 * 8611.3238710010046), 1014 V(-0.3041958212306401), 1015 V(-0.3633036457487539), 1016 V(-0.35660379990111464), 1017 V(-0.3443074455424403), 1018 V(-0.33699592683512467), 1019 V(-0.30180866526242109), 1020 V(-0.27321683125358037), 1021 }}, 1022 {{ 1023 V(1.3 * 4492.2486445538634), 1024 V(-1.2), 1025 V(-1.2), 1026 V(-0.8), 1027 V(-0.7), 1028 V(-0.7), 1029 V(-0.4), 1030 V(-0.5), 1031 }}}}, 1032 8)); 1033 } 1034 // DCT256X256 1035 static QuantEncodingInternal DCT256X256() { 1036 return QuantEncodingInternal::DCT( 1037 DctQuantWeightParams({{{{ 1038 V(3.6 * 26629.073922049845), 1039 V(-1.025), 1040 V(-0.78), 1041 V(-0.65012), 1042 V(-0.19041574084286472), 1043 V(-0.20819395464), 1044 V(-0.421064), 1045 V(-0.32733845535848671), 1046 }}, 1047 {{ 1048 V(3.6 * 9311.3238710010046), 1049 V(-0.3041958212306401), 1050 V(-0.3633036457487539), 1051 V(-0.35660379990111464), 1052 V(-0.3443074455424403), 1053 V(-0.33699592683512467), 1054 V(-0.30180866526242109), 1055 V(-0.27321683125358037), 1056 }}, 1057 {{ 1058 V(3.6 * 4992.2486445538634), 1059 V(-1.2), 1060 V(-1.2), 1061 V(-0.8), 1062 V(-0.7), 1063 V(-0.7), 1064 V(-0.4), 1065 V(-0.5), 1066 }}}}, 1067 8)); 1068 } 1069 1070 // DCT256X128 1071 static QuantEncodingInternal DCT128X256() { 1072 return QuantEncodingInternal::DCT( 1073 DctQuantWeightParams({{{{ 1074 V(2.6 * 23629.073922049845), 1075 V(-1.025), 1076 V(-0.78), 1077 V(-0.65012), 1078 V(-0.19041574084286472), 1079 V(-0.20819395464), 1080 V(-0.421064), 1081 V(-0.32733845535848671), 1082 }}, 1083 {{ 1084 V(2.6 * 8611.3238710010046), 1085 V(-0.3041958212306401), 1086 V(-0.3633036457487539), 1087 V(-0.35660379990111464), 1088 V(-0.3443074455424403), 1089 V(-0.33699592683512467), 1090 V(-0.30180866526242109), 1091 V(-0.27321683125358037), 1092 }}, 1093 {{ 1094 V(2.6 * 4492.2486445538634), 1095 V(-1.2), 1096 V(-1.2), 1097 V(-0.8), 1098 V(-0.7), 1099 V(-0.7), 1100 V(-0.4), 1101 V(-0.5), 1102 }}}}, 1103 8)); 1104 } 1105 }; 1106 } // namespace 1107 1108 DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() { 1109 static_assert(kNumQuantTables == 17, 1110 "Update this function when adding new quantization kinds."); 1111 static_assert(kNumPredefinedTables == 1, 1112 "Update this function when adding new quantization matrices to " 1113 "the library."); 1114 1115 // The library and the indices need to be kept in sync manually. 1116 static_assert(0 == static_cast<uint8_t>(QuantTable::DCT), 1117 "Update the DequantLibrary array below."); 1118 static_assert(1 == static_cast<uint8_t>(QuantTable::IDENTITY), 1119 "Update the DequantLibrary array below."); 1120 static_assert(2 == static_cast<uint8_t>(QuantTable::DCT2X2), 1121 "Update the DequantLibrary array below."); 1122 static_assert(3 == static_cast<uint8_t>(QuantTable::DCT4X4), 1123 "Update the DequantLibrary array below."); 1124 static_assert(4 == static_cast<uint8_t>(QuantTable::DCT16X16), 1125 "Update the DequantLibrary array below."); 1126 static_assert(5 == static_cast<uint8_t>(QuantTable::DCT32X32), 1127 "Update the DequantLibrary array below."); 1128 static_assert(6 == static_cast<uint8_t>(QuantTable::DCT8X16), 1129 "Update the DequantLibrary array below."); 1130 static_assert(7 == static_cast<uint8_t>(QuantTable::DCT8X32), 1131 "Update the DequantLibrary array below."); 1132 static_assert(8 == static_cast<uint8_t>(QuantTable::DCT16X32), 1133 "Update the DequantLibrary array below."); 1134 static_assert(9 == static_cast<uint8_t>(QuantTable::DCT4X8), 1135 "Update the DequantLibrary array below."); 1136 static_assert(10 == static_cast<uint8_t>(QuantTable::AFV0), 1137 "Update the DequantLibrary array below."); 1138 static_assert(11 == static_cast<uint8_t>(QuantTable::DCT64X64), 1139 "Update the DequantLibrary array below."); 1140 static_assert(12 == static_cast<uint8_t>(QuantTable::DCT32X64), 1141 "Update the DequantLibrary array below."); 1142 static_assert(13 == static_cast<uint8_t>(QuantTable::DCT128X128), 1143 "Update the DequantLibrary array below."); 1144 static_assert(14 == static_cast<uint8_t>(QuantTable::DCT64X128), 1145 "Update the DequantLibrary array below."); 1146 static_assert(15 == static_cast<uint8_t>(QuantTable::DCT256X256), 1147 "Update the DequantLibrary array below."); 1148 static_assert(16 == static_cast<uint8_t>(QuantTable::DCT128X256), 1149 "Update the DequantLibrary array below."); 1150 return DequantMatrices::DequantLibraryInternal{{ 1151 DequantMatricesLibraryDef::DCT(), 1152 DequantMatricesLibraryDef::IDENTITY(), 1153 DequantMatricesLibraryDef::DCT2X2(), 1154 DequantMatricesLibraryDef::DCT4X4(), 1155 DequantMatricesLibraryDef::DCT16X16(), 1156 DequantMatricesLibraryDef::DCT32X32(), 1157 DequantMatricesLibraryDef::DCT8X16(), 1158 DequantMatricesLibraryDef::DCT8X32(), 1159 DequantMatricesLibraryDef::DCT16X32(), 1160 DequantMatricesLibraryDef::DCT4X8(), 1161 DequantMatricesLibraryDef::AFV0(), 1162 DequantMatricesLibraryDef::DCT64X64(), 1163 DequantMatricesLibraryDef::DCT32X64(), 1164 // Same default for large transforms (128+) as for 64x* transforms. 1165 DequantMatricesLibraryDef::DCT128X128(), 1166 DequantMatricesLibraryDef::DCT64X128(), 1167 DequantMatricesLibraryDef::DCT256X256(), 1168 DequantMatricesLibraryDef::DCT128X256(), 1169 }}; 1170 } 1171 1172 const QuantEncoding* DequantMatrices::Library() { 1173 static const DequantMatrices::DequantLibraryInternal kDequantLibrary = 1174 DequantMatrices::LibraryInit(); 1175 // Downcast the result to a const QuantEncoding* from QuantEncodingInternal* 1176 // since the subclass (QuantEncoding) doesn't add any new members and users 1177 // will need to upcast to QuantEncodingInternal to access the members of that 1178 // class. This allows to have kDequantLibrary as a constexpr value while still 1179 // allowing to create QuantEncoding::RAW() instances that use std::vector in 1180 // C++11. 1181 return reinterpret_cast<const QuantEncoding*>(kDequantLibrary.data()); 1182 } 1183 1184 DequantMatrices::DequantMatrices() { 1185 encodings_.resize(kNumQuantTables, QuantEncoding::Library<0>()); 1186 size_t pos = 0; 1187 size_t offsets[kNumQuantTables * 3]; 1188 for (size_t i = 0; i < static_cast<size_t>(kNumQuantTables); i++) { 1189 size_t num = required_size_x[i] * required_size_y[i] * kDCTBlockSize; 1190 for (size_t c = 0; c < 3; c++) { 1191 offsets[3 * i + c] = pos + c * num; 1192 } 1193 pos += 3 * num; 1194 } 1195 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 1196 for (size_t c = 0; c < 3; c++) { 1197 table_offsets_[i * 3 + c] = 1198 offsets[static_cast<size_t>(kAcStrategyToQuantTableMap[i]) * 3 + c]; 1199 } 1200 } 1201 } 1202 1203 Status DequantMatrices::EnsureComputed(JxlMemoryManager* memory_manager, 1204 uint32_t acs_mask) { 1205 const QuantEncoding* library = Library(); 1206 1207 if (!table_storage_) { 1208 size_t table_storage_bytes = 2 * kTotalTableSize * sizeof(float); 1209 JXL_ASSIGN_OR_RETURN( 1210 table_storage_, 1211 AlignedMemory::Create(memory_manager, table_storage_bytes)); 1212 table_ = table_storage_.address<float>(); 1213 inv_table_ = table_ + kTotalTableSize; 1214 } 1215 1216 size_t offsets[kNumQuantTables * 3 + 1]; 1217 size_t pos = 0; 1218 for (size_t i = 0; i < kNumQuantTables; i++) { 1219 size_t num = required_size_x[i] * required_size_y[i] * kDCTBlockSize; 1220 for (size_t c = 0; c < 3; c++) { 1221 offsets[3 * i + c] = pos + c * num; 1222 } 1223 pos += 3 * num; 1224 } 1225 offsets[kNumQuantTables * 3] = pos; 1226 JXL_ENSURE(pos == kTotalTableSize); 1227 1228 uint32_t kind_mask = 0; 1229 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 1230 if (acs_mask & (1u << i)) { 1231 kind_mask |= 1u << static_cast<uint32_t>(kAcStrategyToQuantTableMap[i]); 1232 } 1233 } 1234 uint32_t computed_kind_mask = 0; 1235 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 1236 if (computed_mask_ & (1u << i)) { 1237 computed_kind_mask |= 1238 1u << static_cast<uint32_t>(kAcStrategyToQuantTableMap[i]); 1239 } 1240 } 1241 for (size_t table = 0; table < kNumQuantTables; table++) { 1242 if ((1 << table) & computed_kind_mask) continue; 1243 if ((1 << table) & ~kind_mask) continue; 1244 size_t pos = offsets[table * 3]; 1245 float* mutable_table = table_storage_.address<float>(); 1246 if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { 1247 JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( 1248 library[table], mutable_table, mutable_table + kTotalTableSize, table, 1249 QuantTable(table), &pos)); 1250 } else { 1251 JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( 1252 encodings_[table], mutable_table, mutable_table + kTotalTableSize, 1253 table, QuantTable(table), &pos)); 1254 } 1255 JXL_ENSURE(pos == offsets[table * 3 + 3]); 1256 } 1257 computed_mask_ |= acs_mask; 1258 1259 return true; 1260 } 1261 1262 } // namespace jxl 1263 #endif