tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

quant_weights.h (15590B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #ifndef LIB_JXL_QUANT_WEIGHTS_H_
      7 #define LIB_JXL_QUANT_WEIGHTS_H_
      8 
#include <jxl/memory_manager.h>

#include <array>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include "lib/jxl/ac_strategy.h"
#include "lib/jxl/base/common.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/frame_dimensions.h"
#include "lib/jxl/memory_manager_internal.h"
     23 
     24 namespace jxl {
     25 
     26 static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea;
     27 static constexpr size_t kNumPredefinedTables = 1;
     28 static constexpr size_t kCeilLog2NumPredefinedTables = 0;
     29 static constexpr size_t kLog2NumQuantModes = 3;
     30 
     31 struct DctQuantWeightParams {
     32  static constexpr size_t kLog2MaxDistanceBands = 4;
     33  static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
     34  typedef std::array<std::array<float, kMaxDistanceBands>, 3>
     35      DistanceBandsArray;
     36 
     37  size_t num_distance_bands = 0;
     38  DistanceBandsArray distance_bands = {};
     39 
     40  constexpr DctQuantWeightParams() : num_distance_bands(0) {}
     41 
     42  constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
     43                                 size_t num_dist_bands)
     44      : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}
     45 
     46  template <size_t num_dist_bands>
     47  explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
     48    num_distance_bands = num_dist_bands;
     49    for (size_t c = 0; c < 3; c++) {
     50      memcpy(distance_bands[c].data(), dist_bands[c],
     51             sizeof(float) * num_dist_bands);
     52    }
     53  }
     54 };
     55 
     56 // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
     57 struct QuantEncodingInternal {
     58  enum Mode {
     59    kQuantModeLibrary,
     60    kQuantModeID,
     61    kQuantModeDCT2,
     62    kQuantModeDCT4,
     63    kQuantModeDCT4X8,
     64    kQuantModeAFV,
     65    kQuantModeDCT,
     66    kQuantModeRAW,
     67  };
     68 
     69  template <Mode mode>
     70  struct Tag {};
     71 
     72  typedef std::array<std::array<float, 3>, 3> IdWeights;
     73  typedef std::array<std::array<float, 6>, 3> DCT2Weights;
     74  typedef std::array<std::array<float, 2>, 3> DCT4Multipliers;
     75  typedef std::array<std::array<float, 9>, 3> AFVWeights;
     76  typedef std::array<float, 3> DCT4x8Multipliers;
     77 
     78  template <size_t A>
     79  static constexpr QuantEncodingInternal Library() {
     80    static_assert(A < kNumPredefinedTables);
     81    return QuantEncodingInternal(Tag<kQuantModeLibrary>(), A);
     82  }
     83  constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */,
     84                                  uint8_t predefined)
     85      : mode(kQuantModeLibrary), predefined(predefined) {}
     86 
     87  // Identity
     88  // xybweights is an array of {xweights, yweights, bweights}.
     89  static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) {
     90    return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights);
     91  }
     92  constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */,
     93                                  const IdWeights& xybweights)
     94      : mode(kQuantModeID), idweights(xybweights) {}
     95 
     96  // DCT2
     97  static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) {
     98    return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights);
     99  }
    100  constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */,
    101                                  const DCT2Weights& xybweights)
    102      : mode(kQuantModeDCT2), dct2weights(xybweights) {}
    103 
    104  // DCT4
    105  static constexpr QuantEncodingInternal DCT4(
    106      const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) {
    107    return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul);
    108  }
    109  constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */,
    110                                  const DctQuantWeightParams& params,
    111                                  const DCT4Multipliers& xybmul)
    112      : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {}
    113 
    114  // DCT4x8
    115  static constexpr QuantEncodingInternal DCT4X8(
    116      const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) {
    117    return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul);
    118  }
    119  constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */,
    120                                  const DctQuantWeightParams& params,
    121                                  const DCT4x8Multipliers& xybmul)
    122      : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {}
    123 
    124  // DCT
    125  static constexpr QuantEncodingInternal DCT(
    126      const DctQuantWeightParams& params) {
    127    return QuantEncodingInternal(Tag<kQuantModeDCT>(), params);
    128  }
    129  constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */,
    130                                  const DctQuantWeightParams& params)
    131      : mode(kQuantModeDCT), dct_params(params) {}
    132 
    133  // AFV
    134  static constexpr QuantEncodingInternal AFV(
    135      const DctQuantWeightParams& params4x8,
    136      const DctQuantWeightParams& params4x4, const AFVWeights& weights) {
    137    return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4,
    138                                 weights);
    139  }
    140  constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */,
    141                                  const DctQuantWeightParams& params4x8,
    142                                  const DctQuantWeightParams& params4x4,
    143                                  const AFVWeights& weights)
    144      : mode(kQuantModeAFV),
    145        dct_params(params4x8),
    146        afv_weights(weights),
    147        dct_params_afv_4x4(params4x4) {}
    148 
    149  // This constructor is not constexpr so it can't be used in any of the
    150  // constexpr cases above.
    151  explicit QuantEncodingInternal(Mode mode) : mode(mode) {}
    152 
    153  Mode mode;
    154 
    155  // Weights for DCT4+ tables.
    156  DctQuantWeightParams dct_params;
    157 
    158  union {
    159    // Weights for identity.
    160    IdWeights idweights;
    161 
    162    // Weights for DCT2.
    163    DCT2Weights dct2weights;
    164 
    165    // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV.
    166    DCT4Multipliers dct4multipliers;
    167 
    168    // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1,
    169    // 0);  {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) +
    170    // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated
    171    // as in GetQuantWeights for DC and are used for other coefficients.
    172    AFVWeights afv_weights = {};
    173 
    174    // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4.
    175    DCT4x8Multipliers dct4x8multipliers;
    176 
    177    // Only used in kQuantModeRAW mode.
    178    struct {
    179      // explicit quantization table (like in JPEG)
    180      std::vector<int>* qtable = nullptr;
    181      float qtable_den = 1.f / (8 * 255);
    182    } qraw;
    183  };
    184 
    185  // Weights for 4x4 sub-block in AFV.
    186  DctQuantWeightParams dct_params_afv_4x4;
    187 
    188  union {
    189    // Which predefined table to use. Only used if mode is kQuantModeLibrary.
    190    uint8_t predefined = 0;
    191 
    192    // Which other quant table to copy; must copy from a table that comes before
    193    // the current one. Only used if mode is kQuantModeCopy.
    194    uint8_t source;
    195  };
    196 };
    197 
    198 class QuantEncoding final : public QuantEncodingInternal {
    199 public:
    200  QuantEncoding(const QuantEncoding& other)
    201      : QuantEncodingInternal(
    202            static_cast<const QuantEncodingInternal&>(other)) {
    203    if (mode == kQuantModeRAW && qraw.qtable) {
    204      // Need to make a copy of the passed *qtable.
    205      qraw.qtable = new std::vector<int>(*other.qraw.qtable);
    206    }
    207  }
    208  QuantEncoding(QuantEncoding&& other) noexcept
    209      : QuantEncodingInternal(
    210            static_cast<const QuantEncodingInternal&>(other)) {
    211    // Steal the qtable from the other object if any.
    212    if (mode == kQuantModeRAW) {
    213      other.qraw.qtable = nullptr;
    214    }
    215  }
    216  QuantEncoding& operator=(const QuantEncoding& other) {
    217    if (mode == kQuantModeRAW && qraw.qtable) {
    218      delete qraw.qtable;
    219    }
    220    *static_cast<QuantEncodingInternal*>(this) =
    221        QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other));
    222    if (mode == kQuantModeRAW && qraw.qtable) {
    223      // Need to make a copy of the passed *qtable.
    224      qraw.qtable = new std::vector<int>(*other.qraw.qtable);
    225    }
    226    return *this;
    227  }
    228 
    229  ~QuantEncoding() {
    230    if (mode == kQuantModeRAW && qraw.qtable) {
    231      delete qraw.qtable;
    232    }
    233  }
    234 
    235  // Wrappers of the QuantEncodingInternal:: static functions that return a
    236  // QuantEncoding instead. This is using the explicit and private cast from
    237  // QuantEncodingInternal to QuantEncoding, which would be inlined anyway.
    238  // In general, you should use this wrappers. The only reason to directly
    239  // create a QuantEncodingInternal instance is if you need a constexpr version
    240  // of this class. Note that RAW() is not supported in that case since it uses
    241  // a std::vector.
    242  template <size_t A>
    243  static QuantEncoding Library() {
    244    return QuantEncoding(QuantEncodingInternal::Library<A>());
    245  }
    246  static QuantEncoding Identity(const IdWeights& xybweights) {
    247    return QuantEncoding(QuantEncodingInternal::Identity(xybweights));
    248  }
    249  static QuantEncoding DCT2(const DCT2Weights& xybweights) {
    250    return QuantEncoding(QuantEncodingInternal::DCT2(xybweights));
    251  }
    252  static QuantEncoding DCT4(const DctQuantWeightParams& params,
    253                            const DCT4Multipliers& xybmul) {
    254    return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul));
    255  }
    256  static QuantEncoding DCT4X8(const DctQuantWeightParams& params,
    257                              const DCT4x8Multipliers& xybmul) {
    258    return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul));
    259  }
    260  static QuantEncoding DCT(const DctQuantWeightParams& params) {
    261    return QuantEncoding(QuantEncodingInternal::DCT(params));
    262  }
    263  static QuantEncoding AFV(const DctQuantWeightParams& params4x8,
    264                           const DctQuantWeightParams& params4x4,
    265                           const AFVWeights& weights) {
    266    return QuantEncoding(
    267        QuantEncodingInternal::AFV(params4x8, params4x4, weights));
    268  }
    269 
    270  // RAW, note that this one is not a constexpr one.
    271  static QuantEncoding RAW(std::vector<int>&& qtable, int shift = 0) {
    272    QuantEncoding encoding(kQuantModeRAW);
    273    encoding.qraw.qtable = new std::vector<int>();
    274    *encoding.qraw.qtable = qtable;
    275    encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255));
    276    return encoding;
    277  }
    278 
    279 private:
    280  explicit QuantEncoding(const QuantEncodingInternal& other)
    281      : QuantEncodingInternal(other) {}
    282 
    283  explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg)
    284      : QuantEncodingInternal(mode_arg) {}
    285 };
    286 
// A constexpr QuantEncodingInternal instance is often downcast to the
// QuantEncoding subclass even though it was never constructed as one. This is
// safe because users upcast back to QuantEncodingInternal before accessing
// any of its members.
    291 static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal),
    292              "Don't add any members to QuantEncoding");
    293 
    294 // Let's try to keep these 2**N for possible future simplicity.
    295 const float kInvDCQuant[3] = {
    296    4096.0f,
    297    512.0f,
    298    256.0f,
    299 };
    300 
    301 const float kDCQuant[3] = {
    302    1.0f / kInvDCQuant[0],
    303    1.0f / kInvDCQuant[1],
    304    1.0f / kInvDCQuant[2],
    305 };
    306 
    307 class ModularFrameEncoder;
    308 class ModularFrameDecoder;
    309 
    310 enum class QuantTable : size_t {
    311  DCT = 0,
    312  IDENTITY,
    313  DCT2X2,
    314  DCT4X4,
    315  DCT16X16,
    316  DCT32X32,
    317  // DCT16X8
    318  DCT8X16,
    319  // DCT32X8
    320  DCT8X32,
    321  // DCT32X16
    322  DCT16X32,
    323  DCT4X8,
    324  // DCT8X4
    325  AFV0,
    326  // AFV1
    327  // AFV2
    328  // AFV3
    329  DCT64X64,
    330  // DCT64X32,
    331  DCT32X64,
    332  DCT128X128,
    333  // DCT128X64,
    334  DCT64X128,
    335  DCT256X256,
    336  // DCT256X128,
    337  DCT128X256
    338 };
    339 
    340 static constexpr uint8_t kNumQuantTables =
    341    static_cast<uint8_t>(QuantTable::DCT128X256) + 1;
    342 
    343 static const std::array<QuantTable, AcStrategy::kNumValidStrategies>
    344    kAcStrategyToQuantTableMap = {
    345        QuantTable::DCT,        QuantTable::IDENTITY,   QuantTable::DCT2X2,
    346        QuantTable::DCT4X4,     QuantTable::DCT16X16,   QuantTable::DCT32X32,
    347        QuantTable::DCT8X16,    QuantTable::DCT8X16,    QuantTable::DCT8X32,
    348        QuantTable::DCT8X32,    QuantTable::DCT16X32,   QuantTable::DCT16X32,
    349        QuantTable::DCT4X8,     QuantTable::DCT4X8,     QuantTable::AFV0,
    350        QuantTable::AFV0,       QuantTable::AFV0,       QuantTable::AFV0,
    351        QuantTable::DCT64X64,   QuantTable::DCT32X64,   QuantTable::DCT32X64,
    352        QuantTable::DCT128X128, QuantTable::DCT64X128,  QuantTable::DCT64X128,
    353        QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256,
    354 };
    355 
    356 class DequantMatrices {
    357 public:
    358  DequantMatrices();
    359 
    360  static const QuantEncoding* Library();
    361 
    362  typedef std::array<QuantEncodingInternal,
    363                     kNumPredefinedTables * kNumQuantTables>
    364      DequantLibraryInternal;
    365  // Return the array of library kNumPredefinedTables QuantEncoding entries as
    366  // a constexpr array. Use Library() to obtain a pointer to the copy in the
    367  // .cc file.
    368  static DequantLibraryInternal LibraryInit();
    369 
    370  // Returns aligned memory.
    371  JXL_INLINE const float* Matrix(AcStrategyType quant_kind, size_t c) const {
    372    JXL_DASSERT((1 << static_cast<uint32_t>(quant_kind)) & computed_mask_);
    373    return &table_[table_offsets_[static_cast<size_t>(quant_kind) * 3 + c]];
    374  }
    375 
    376  JXL_INLINE const float* InvMatrix(AcStrategyType quant_kind, size_t c) const {
    377    size_t quant_table_idx = static_cast<uint32_t>(quant_kind);
    378    JXL_DASSERT((1 << quant_table_idx) & computed_mask_);
    379    return &inv_table_[table_offsets_[quant_table_idx * 3 + c]];
    380  }
    381 
    382  // DC quants are used in modular mode for XYB multipliers.
    383  JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; }
    384  JXL_INLINE const float* DCQuants() const { return dc_quant_; }
    385 
    386  JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; }
    387 
    388  // For encoder.
    389  void SetEncodings(const std::vector<QuantEncoding>& encodings) {
    390    encodings_ = encodings;
    391    computed_mask_ = 0;
    392  }
    393 
    394  // For encoder.
    395  void SetDCQuant(const float dc[3]) {
    396    for (size_t c = 0; c < 3; c++) {
    397      dc_quant_[c] = 1.0f / dc[c];
    398      inv_dc_quant_[c] = dc[c];
    399    }
    400  }
    401 
    402  Status Decode(JxlMemoryManager* memory_manager, BitReader* br,
    403                ModularFrameDecoder* modular_frame_decoder = nullptr);
    404  Status DecodeDC(BitReader* br);
    405 
    406  const std::vector<QuantEncoding>& encodings() const { return encodings_; }
    407 
    408  static constexpr auto required_size_x =
    409      to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16});
    410  static_assert(kNumQuantTables == required_size_x.size(),
    411                "Update this array when adding or removing quant tables.");
    412 
    413  static constexpr auto required_size_y =
    414      to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32});
    415  static_assert(kNumQuantTables == required_size_y.size(),
    416                "Update this array when adding or removing quant tables.");
    417 
    418  // MUST be equal `sum(dot(required_size_x, required_size_y))`.
    419  static constexpr size_t kSumRequiredXy = 2056;
    420 
    421  Status EnsureComputed(JxlMemoryManager* memory_manager, uint32_t acs_mask);
    422 
    423 private:
    424  static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3;
    425 
    426  uint32_t computed_mask_ = 0;
    427  // kTotalTableSize entries followed by kTotalTableSize for inv_table
    428  AlignedMemory table_storage_;
    429  const float* table_;
    430  const float* inv_table_;
    431  float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]};
    432  float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]};
    433  size_t table_offsets_[AcStrategy::kNumValidStrategies * 3];
    434  std::vector<QuantEncoding> encodings_;
    435 };
    436 
    437 }  // namespace jxl
    438 
    439 #endif  // LIB_JXL_QUANT_WEIGHTS_H_