quant_weights.h (15590B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #ifndef LIB_JXL_QUANT_WEIGHTS_H_ 7 #define LIB_JXL_QUANT_WEIGHTS_H_ 8 9 #include <jxl/memory_manager.h> 10 11 #include <array> 12 #include <cstdint> 13 #include <cstring> 14 #include <vector> 15 16 #include "lib/jxl/ac_strategy.h" 17 #include "lib/jxl/base/common.h" 18 #include "lib/jxl/base/compiler_specific.h" 19 #include "lib/jxl/base/status.h" 20 #include "lib/jxl/dec_bit_reader.h" 21 #include "lib/jxl/frame_dimensions.h" 22 #include "lib/jxl/memory_manager_internal.h" 23 24 namespace jxl { 25 26 static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; 27 static constexpr size_t kNumPredefinedTables = 1; 28 static constexpr size_t kCeilLog2NumPredefinedTables = 0; 29 static constexpr size_t kLog2NumQuantModes = 3; 30 31 struct DctQuantWeightParams { 32 static constexpr size_t kLog2MaxDistanceBands = 4; 33 static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands); 34 typedef std::array<std::array<float, kMaxDistanceBands>, 3> 35 DistanceBandsArray; 36 37 size_t num_distance_bands = 0; 38 DistanceBandsArray distance_bands = {}; 39 40 constexpr DctQuantWeightParams() : num_distance_bands(0) {} 41 42 constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands, 43 size_t num_dist_bands) 44 : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {} 45 46 template <size_t num_dist_bands> 47 explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) { 48 num_distance_bands = num_dist_bands; 49 for (size_t c = 0; c < 3; c++) { 50 memcpy(distance_bands[c].data(), dist_bands[c], 51 sizeof(float) * num_dist_bands); 52 } 53 } 54 }; 55 56 // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) 57 struct QuantEncodingInternal { 58 enum Mode { 59 kQuantModeLibrary, 60 kQuantModeID, 61 kQuantModeDCT2, 62 kQuantModeDCT4, 63 kQuantModeDCT4X8, 64 kQuantModeAFV, 65 kQuantModeDCT, 66 kQuantModeRAW, 67 }; 68 69 template <Mode mode> 70 struct Tag {}; 71 72 typedef std::array<std::array<float, 3>, 3> IdWeights; 73 typedef std::array<std::array<float, 6>, 3> DCT2Weights; 74 typedef std::array<std::array<float, 2>, 3> DCT4Multipliers; 75 typedef std::array<std::array<float, 9>, 3> AFVWeights; 76 typedef std::array<float, 3> DCT4x8Multipliers; 77 78 template <size_t A> 79 static constexpr QuantEncodingInternal Library() { 80 static_assert(A < kNumPredefinedTables); 81 return QuantEncodingInternal(Tag<kQuantModeLibrary>(), A); 82 } 83 constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */, 84 uint8_t predefined) 85 : mode(kQuantModeLibrary), predefined(predefined) {} 86 87 // Identity 88 // xybweights is an array of {xweights, yweights, bweights}. 89 static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { 90 return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights); 91 } 92 constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */, 93 const IdWeights& xybweights) 94 : mode(kQuantModeID), idweights(xybweights) {} 95 96 // DCT2 97 static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { 98 return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights); 99 } 100 constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */, 101 const DCT2Weights& xybweights) 102 : mode(kQuantModeDCT2), dct2weights(xybweights) {} 103 104 // DCT4 105 static constexpr QuantEncodingInternal DCT4( 106 const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { 107 return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul); 108 } 109 constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */, 110 const DctQuantWeightParams& params, 111 const DCT4Multipliers& xybmul) 112 : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} 113 114 // DCT4x8 115 static constexpr QuantEncodingInternal DCT4X8( 116 const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { 117 return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul); 118 } 119 constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */, 120 const DctQuantWeightParams& params, 121 const DCT4x8Multipliers& xybmul) 122 : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} 123 124 // DCT 125 static constexpr QuantEncodingInternal DCT( 126 const DctQuantWeightParams& params) { 127 return QuantEncodingInternal(Tag<kQuantModeDCT>(), params); 128 } 129 constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */, 130 const DctQuantWeightParams& params) 131 : mode(kQuantModeDCT), dct_params(params) {} 132 133 // AFV 134 static constexpr QuantEncodingInternal AFV( 135 const DctQuantWeightParams& params4x8, 136 const DctQuantWeightParams& params4x4, const AFVWeights& weights) { 137 return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4, 138 weights); 139 } 140 constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */, 141 const DctQuantWeightParams& params4x8, 142 const DctQuantWeightParams& params4x4, 143 const AFVWeights& weights) 144 : mode(kQuantModeAFV), 145 dct_params(params4x8), 146 afv_weights(weights), 147 dct_params_afv_4x4(params4x4) {} 148 149 // This constructor is not constexpr so it can't be used in any of the 150 // constexpr cases above. 151 explicit QuantEncodingInternal(Mode mode) : mode(mode) {} 152 153 Mode mode; 154 155 // Weights for DCT4+ tables. 156 DctQuantWeightParams dct_params; 157 158 union { 159 // Weights for identity. 160 IdWeights idweights; 161 162 // Weights for DCT2. 163 DCT2Weights dct2weights; 164 165 // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. 166 DCT4Multipliers dct4multipliers; 167 168 // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, 169 // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + 170 // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated 171 // as in GetQuantWeights for DC and are used for other coefficients. 172 AFVWeights afv_weights = {}; 173 174 // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. 175 DCT4x8Multipliers dct4x8multipliers; 176 177 // Only used in kQuantModeRAW mode. 178 struct { 179 // explicit quantization table (like in JPEG) 180 std::vector<int>* qtable = nullptr; 181 float qtable_den = 1.f / (8 * 255); 182 } qraw; 183 }; 184 185 // Weights for 4x4 sub-block in AFV. 186 DctQuantWeightParams dct_params_afv_4x4; 187 188 union { 189 // Which predefined table to use. Only used if mode is kQuantModeLibrary. 190 uint8_t predefined = 0; 191 192 // Which other quant table to copy; must copy from a table that comes before 193 // the current one. Only used if mode is kQuantModeCopy. 194 uint8_t source; 195 }; 196 }; 197 198 class QuantEncoding final : public QuantEncodingInternal { 199 public: 200 QuantEncoding(const QuantEncoding& other) 201 : QuantEncodingInternal( 202 static_cast<const QuantEncodingInternal&>(other)) { 203 if (mode == kQuantModeRAW && qraw.qtable) { 204 // Need to make a copy of the passed *qtable. 205 qraw.qtable = new std::vector<int>(*other.qraw.qtable); 206 } 207 } 208 QuantEncoding(QuantEncoding&& other) noexcept 209 : QuantEncodingInternal( 210 static_cast<const QuantEncodingInternal&>(other)) { 211 // Steal the qtable from the other object if any. 212 if (mode == kQuantModeRAW) { 213 other.qraw.qtable = nullptr; 214 } 215 } 216 QuantEncoding& operator=(const QuantEncoding& other) { 217 if (mode == kQuantModeRAW && qraw.qtable) { 218 delete qraw.qtable; 219 } 220 *static_cast<QuantEncodingInternal*>(this) = 221 QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other)); 222 if (mode == kQuantModeRAW && qraw.qtable) { 223 // Need to make a copy of the passed *qtable. 224 qraw.qtable = new std::vector<int>(*other.qraw.qtable); 225 } 226 return *this; 227 } 228 229 ~QuantEncoding() { 230 if (mode == kQuantModeRAW && qraw.qtable) { 231 delete qraw.qtable; 232 } 233 } 234 235 // Wrappers of the QuantEncodingInternal:: static functions that return a 236 // QuantEncoding instead. This is using the explicit and private cast from 237 // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. 238 // In general, you should use this wrappers. The only reason to directly 239 // create a QuantEncodingInternal instance is if you need a constexpr version 240 // of this class. Note that RAW() is not supported in that case since it uses 241 // a std::vector. 242 template <size_t A> 243 static QuantEncoding Library() { 244 return QuantEncoding(QuantEncodingInternal::Library<A>()); 245 } 246 static QuantEncoding Identity(const IdWeights& xybweights) { 247 return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); 248 } 249 static QuantEncoding DCT2(const DCT2Weights& xybweights) { 250 return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); 251 } 252 static QuantEncoding DCT4(const DctQuantWeightParams& params, 253 const DCT4Multipliers& xybmul) { 254 return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); 255 } 256 static QuantEncoding DCT4X8(const DctQuantWeightParams& params, 257 const DCT4x8Multipliers& xybmul) { 258 return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); 259 } 260 static QuantEncoding DCT(const DctQuantWeightParams& params) { 261 return QuantEncoding(QuantEncodingInternal::DCT(params)); 262 } 263 static QuantEncoding AFV(const DctQuantWeightParams& params4x8, 264 const DctQuantWeightParams& params4x4, 265 const AFVWeights& weights) { 266 return QuantEncoding( 267 QuantEncodingInternal::AFV(params4x8, params4x4, weights)); 268 } 269 270 // RAW, note that this one is not a constexpr one. 271 static QuantEncoding RAW(std::vector<int>&& qtable, int shift = 0) { 272 QuantEncoding encoding(kQuantModeRAW); 273 encoding.qraw.qtable = new std::vector<int>(); 274 *encoding.qraw.qtable = qtable; 275 encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); 276 return encoding; 277 } 278 279 private: 280 explicit QuantEncoding(const QuantEncodingInternal& other) 281 : QuantEncodingInternal(other) {} 282 283 explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg) 284 : QuantEncodingInternal(mode_arg) {} 285 }; 286 287 // A constexpr QuantEncodingInternal instance is often downcasted to the 288 // QuantEncoding subclass even if the instance wasn't an instance of the 289 // subclass. This is safe because user will upcast to QuantEncodingInternal to 290 // access any of its members. 291 static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), 292 "Don't add any members to QuantEncoding"); 293 294 // Let's try to keep these 2**N for possible future simplicity. 295 const float kInvDCQuant[3] = { 296 4096.0f, 297 512.0f, 298 256.0f, 299 }; 300 301 const float kDCQuant[3] = { 302 1.0f / kInvDCQuant[0], 303 1.0f / kInvDCQuant[1], 304 1.0f / kInvDCQuant[2], 305 }; 306 307 class ModularFrameEncoder; 308 class ModularFrameDecoder; 309 310 enum class QuantTable : size_t { 311 DCT = 0, 312 IDENTITY, 313 DCT2X2, 314 DCT4X4, 315 DCT16X16, 316 DCT32X32, 317 // DCT16X8 318 DCT8X16, 319 // DCT32X8 320 DCT8X32, 321 // DCT32X16 322 DCT16X32, 323 DCT4X8, 324 // DCT8X4 325 AFV0, 326 // AFV1 327 // AFV2 328 // AFV3 329 DCT64X64, 330 // DCT64X32, 331 DCT32X64, 332 DCT128X128, 333 // DCT128X64, 334 DCT64X128, 335 DCT256X256, 336 // DCT256X128, 337 DCT128X256 338 }; 339 340 static constexpr uint8_t kNumQuantTables = 341 static_cast<uint8_t>(QuantTable::DCT128X256) + 1; 342 343 static const std::array<QuantTable, AcStrategy::kNumValidStrategies> 344 kAcStrategyToQuantTableMap = { 345 QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, 346 QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, 347 QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, 348 QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, 349 QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, 350 QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, 351 QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, 352 QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, 353 QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, 354 }; 355 356 class DequantMatrices { 357 public: 358 DequantMatrices(); 359 360 static const QuantEncoding* Library(); 361 362 typedef std::array<QuantEncodingInternal, 363 kNumPredefinedTables * kNumQuantTables> 364 DequantLibraryInternal; 365 // Return the array of library kNumPredefinedTables QuantEncoding entries as 366 // a constexpr array. Use Library() to obtain a pointer to the copy in the 367 // .cc file. 368 static DequantLibraryInternal LibraryInit(); 369 370 // Returns aligned memory. 371 JXL_INLINE const float* Matrix(AcStrategyType quant_kind, size_t c) const { 372 JXL_DASSERT((1 << static_cast<uint32_t>(quant_kind)) & computed_mask_); 373 return &table_[table_offsets_[static_cast<size_t>(quant_kind) * 3 + c]]; 374 } 375 376 JXL_INLINE const float* InvMatrix(AcStrategyType quant_kind, size_t c) const { 377 size_t quant_table_idx = static_cast<uint32_t>(quant_kind); 378 JXL_DASSERT((1 << quant_table_idx) & computed_mask_); 379 return &inv_table_[table_offsets_[quant_table_idx * 3 + c]]; 380 } 381 382 // DC quants are used in modular mode for XYB multipliers. 383 JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } 384 JXL_INLINE const float* DCQuants() const { return dc_quant_; } 385 386 JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } 387 388 // For encoder. 389 void SetEncodings(const std::vector<QuantEncoding>& encodings) { 390 encodings_ = encodings; 391 computed_mask_ = 0; 392 } 393 394 // For encoder. 395 void SetDCQuant(const float dc[3]) { 396 for (size_t c = 0; c < 3; c++) { 397 dc_quant_[c] = 1.0f / dc[c]; 398 inv_dc_quant_[c] = dc[c]; 399 } 400 } 401 402 Status Decode(JxlMemoryManager* memory_manager, BitReader* br, 403 ModularFrameDecoder* modular_frame_decoder = nullptr); 404 Status DecodeDC(BitReader* br); 405 406 const std::vector<QuantEncoding>& encodings() const { return encodings_; } 407 408 static constexpr auto required_size_x = 409 to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16}); 410 static_assert(kNumQuantTables == required_size_x.size(), 411 "Update this array when adding or removing quant tables."); 412 413 static constexpr auto required_size_y = 414 to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32}); 415 static_assert(kNumQuantTables == required_size_y.size(), 416 "Update this array when adding or removing quant tables."); 417 418 // MUST be equal `sum(dot(required_size_x, required_size_y))`. 419 static constexpr size_t kSumRequiredXy = 2056; 420 421 Status EnsureComputed(JxlMemoryManager* memory_manager, uint32_t acs_mask); 422 423 private: 424 static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3; 425 426 uint32_t computed_mask_ = 0; 427 // kTotalTableSize entries followed by kTotalTableSize for inv_table 428 AlignedMemory table_storage_; 429 const float* table_; 430 const float* inv_table_; 431 float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; 432 float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; 433 size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; 434 std::vector<QuantEncoding> encodings_; 435 }; 436 437 } // namespace jxl 438 439 #endif // LIB_JXL_QUANT_WEIGHTS_H_