tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

quant_weights_test.cc (10176B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 #include "lib/jxl/quant_weights.h"
      6 
      7 #include <jxl/memory_manager.h>
      8 
      9 #include <algorithm>
     10 #include <cmath>
     11 #include <cstdlib>
     12 #include <hwy/base.h>  // HWY_ALIGN_MAX
     13 #include <hwy/tests/hwy_gtest.h>
     14 #include <iterator>
     15 #include <numeric>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "lib/jxl/ac_strategy.h"
     20 #include "lib/jxl/base/random.h"
     21 #include "lib/jxl/base/status.h"
     22 #include "lib/jxl/dct_for_test.h"
     23 #include "lib/jxl/dec_transforms_testonly.h"
     24 #include "lib/jxl/enc_modular.h"
     25 #include "lib/jxl/enc_params.h"
     26 #include "lib/jxl/enc_quant_weights.h"
     27 #include "lib/jxl/enc_transforms.h"
     28 #include "lib/jxl/frame_header.h"
     29 #include "lib/jxl/image_metadata.h"
     30 #include "lib/jxl/test_memory_manager.h"
     31 #include "lib/jxl/testing.h"
     32 
     33 namespace jxl {
     34 namespace {
     35 
     36 // This should have been static assert; not compiling though with C++<17.
     37 TEST(QuantWeightsTest, Invariant) {
     38  size_t sum = 0;
     39  ASSERT_EQ(DequantMatrices::required_size_x.size(),
     40            DequantMatrices::required_size_y.size());
     41  for (size_t i = 0; i < DequantMatrices::required_size_x.size(); ++i) {
     42    sum += DequantMatrices::required_size_x[i] *
     43           DequantMatrices::required_size_y[i];
     44  }
     45  ASSERT_EQ(DequantMatrices::kSumRequiredXy, sum);
     46 }
     47 
     48 template <typename T>
     49 void CheckSimilar(T a, T b) {
     50  EXPECT_EQ(a, b);
     51 }
     52 // minimum exponent = -15.
     53 template <>
     54 void CheckSimilar(float a, float b) {
     55  float m = std::max(std::abs(a), std::abs(b));
     56  // 10 bits of precision are used in the format. Relative error should be
     57  // below 2^-10.
     58  EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b;
     59 }
     60 
     61 TEST(QuantWeightsTest, DC) {
     62  JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
     63  DequantMatrices mat;
     64  float dc_quant[3] = {1e+5, 1e+3, 1e+1};
     65  ASSERT_TRUE(DequantMatricesSetCustomDC(memory_manager, &mat, dc_quant));
     66  for (size_t c = 0; c < 3; c++) {
     67    CheckSimilar(mat.InvDCQuant(c), dc_quant[c]);
     68  }
     69 }
     70 
     71 void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) {
     72  ASSERT_TRUE(encodings.size() == kNumQuantTables);
     73  DequantMatrices mat;
     74  CodecMetadata metadata;
     75  FrameHeader frame_header(&metadata);
     76  JXL_ASSIGN_OR_QUIT(
     77      ModularFrameEncoder encoder,
     78      ModularFrameEncoder::Create(jxl::test::MemoryManager(), frame_header,
     79                                  CompressParams{}, false),
     80      "Failed to create ModularFrameEncoder.");
     81  ASSERT_TRUE(DequantMatricesSetCustom(&mat, encodings, &encoder));
     82  const std::vector<QuantEncoding>& encodings_dec = mat.encodings();
     83  for (size_t i = 0; i < encodings.size(); i++) {
     84    const QuantEncoding& e = encodings[i];
     85    const QuantEncoding& d = encodings_dec[i];
     86    // Check values roundtripped correctly.
     87    EXPECT_EQ(e.mode, d.mode);
     88    EXPECT_EQ(e.predefined, d.predefined);
     89    EXPECT_EQ(e.source, d.source);
     90 
     91    EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands),
     92              static_cast<uint64_t>(d.dct_params.num_distance_bands));
     93    for (size_t c = 0; c < 3; c++) {
     94      for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
     95        CheckSimilar(e.dct_params.distance_bands[c][j],
     96                     d.dct_params.distance_bands[c][j]);
     97      }
     98    }
     99 
    100    if (e.mode == QuantEncoding::kQuantModeRAW) {
    101      EXPECT_FALSE(!e.qraw.qtable);
    102      EXPECT_FALSE(!d.qraw.qtable);
    103      EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size());
    104      for (size_t j = 0; j < e.qraw.qtable->size(); j++) {
    105        EXPECT_EQ(e.qraw.qtable->at(j), d.qraw.qtable->at(j));
    106      }
    107      EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f);
    108    } else {
    109      // modes different than kQuantModeRAW use one of the other fields used
    110      // here, which all happen to be arrays of floats.
    111      for (size_t c = 0; c < 3; c++) {
    112        for (size_t j = 0; j < 3; j++) {
    113          CheckSimilar(e.idweights[c][j], d.idweights[c][j]);
    114        }
    115        for (size_t j = 0; j < 6; j++) {
    116          CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]);
    117        }
    118        for (size_t j = 0; j < 2; j++) {
    119          CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]);
    120        }
    121        CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]);
    122        for (size_t j = 0; j < 9; j++) {
    123          CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]);
    124        }
    125        for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
    126          CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j],
    127                       d.dct_params_afv_4x4.distance_bands[c][j]);
    128        }
    129      }
    130    }
    131  }
    132 }
    133 
    134 TEST(QuantWeightsTest, AllDefault) {
    135  std::vector<QuantEncoding> encodings(kNumQuantTables,
    136                                       QuantEncoding::Library<0>());
    137  RoundtripMatrices(encodings);
    138 }
    139 
    140 void TestSingleQuantMatrix(QuantTable kind) {
    141  std::vector<QuantEncoding> encodings(kNumQuantTables,
    142                                       QuantEncoding::Library<0>());
    143  size_t quant_table_idx = static_cast<size_t>(kind);
    144  encodings[quant_table_idx] = DequantMatrices::Library()[quant_table_idx];
    145  RoundtripMatrices(encodings);
    146 }
    147 
    148 // Ensure we can reasonably represent default quant tables.
    149 TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(QuantTable::DCT); }
    150 TEST(QuantWeightsTest, IDENTITY) {
    151  TestSingleQuantMatrix(QuantTable::IDENTITY);
    152 }
    153 TEST(QuantWeightsTest, DCT2X2) { TestSingleQuantMatrix(QuantTable::DCT2X2); }
    154 TEST(QuantWeightsTest, DCT4X4) { TestSingleQuantMatrix(QuantTable::DCT4X4); }
    155 TEST(QuantWeightsTest, DCT16X16) {
    156  TestSingleQuantMatrix(QuantTable::DCT16X16);
    157 }
    158 TEST(QuantWeightsTest, DCT32X32) {
    159  TestSingleQuantMatrix(QuantTable::DCT32X32);
    160 }
    161 TEST(QuantWeightsTest, DCT8X16) { TestSingleQuantMatrix(QuantTable::DCT8X16); }
    162 TEST(QuantWeightsTest, DCT8X32) { TestSingleQuantMatrix(QuantTable::DCT8X32); }
    163 TEST(QuantWeightsTest, DCT16X32) {
    164  TestSingleQuantMatrix(QuantTable::DCT16X32);
    165 }
    166 TEST(QuantWeightsTest, DCT4X8) { TestSingleQuantMatrix(QuantTable::DCT4X8); }
    167 TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(QuantTable::AFV0); }
    168 TEST(QuantWeightsTest, RAW) {
    169  std::vector<QuantEncoding> encodings(kNumQuantTables,
    170                                       QuantEncoding::Library<0>());
    171  std::vector<int> matrix(3 * 32 * 32);
    172  Rng rng(0);
    173  for (int& v : matrix) v = rng.UniformI(1, 256);
    174  QuantTable quant_table =
    175      kAcStrategyToQuantTableMap[static_cast<size_t>(AcStrategyType::DCT32X32)];
    176  encodings[static_cast<size_t>(quant_table)] =
    177      QuantEncoding::RAW(std::move(matrix), 2);
    178  RoundtripMatrices(encodings);
    179 }
    180 
    181 class QuantWeightsTargetTest : public hwy::TestWithParamTarget {};
    182 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest);
    183 
    184 TEST_P(QuantWeightsTargetTest, DCTUniform) {
    185  JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
    186  constexpr float kUniformQuant = 4;
    187  float weights[3][2] = {{1.0f / kUniformQuant, 0},
    188                         {1.0f / kUniformQuant, 0},
    189                         {1.0f / kUniformQuant, 0}};
    190  DctQuantWeightParams dct_params(weights);
    191  std::vector<QuantEncoding> encodings(kNumQuantTables,
    192                                       QuantEncoding::DCT(dct_params));
    193  DequantMatrices dequant_matrices;
    194  CodecMetadata metadata;
    195  FrameHeader frame_header(&metadata);
    196  JXL_ASSIGN_OR_QUIT(
    197      ModularFrameEncoder encoder,
    198      ModularFrameEncoder::Create(jxl::test::MemoryManager(), frame_header,
    199                                  CompressParams{}, false),
    200      "Failed to create ModularFrameEncoder.");
    201  ASSERT_TRUE(DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder));
    202  ASSERT_TRUE(dequant_matrices.EnsureComputed(memory_manager, ~0u));
    203 
    204  const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant,
    205                             1.0f / kUniformQuant};
    206  ASSERT_TRUE(
    207      DequantMatricesSetCustomDC(memory_manager, &dequant_matrices, dc_quant));
    208 
    209  HWY_ALIGN_MAX float scratch_space[16 * 16 * 5];
    210 
    211  // DCT8
    212  {
    213    HWY_ALIGN_MAX float pixels[64];
    214    std::iota(std::begin(pixels), std::end(pixels), 0);
    215    HWY_ALIGN_MAX float coeffs[64];
    216    const AcStrategyType dct = AcStrategyType::DCT;
    217    TransformFromPixels(dct, pixels, 8, coeffs, scratch_space);
    218    HWY_ALIGN_MAX double slow_coeffs[64];
    219    for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i];
    220    DCTSlow<8>(slow_coeffs);
    221 
    222    for (size_t i = 0; i < 64; i++) {
    223      // DCTSlow doesn't multiply/divide by 1/N, so we do it manually.
    224      slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
    225      coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
    226                  dequant_matrices.Matrix(dct, 0)[i];
    227    }
    228    IDCTSlow<8>(slow_coeffs);
    229    TransformToPixels(dct, coeffs, pixels, 8, scratch_space);
    230    for (size_t i = 0; i < 64; i++) {
    231      EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
    232    }
    233  }
    234 
    235  // DCT16
    236  {
    237    HWY_ALIGN_MAX float pixels[64 * 4];
    238    std::iota(std::begin(pixels), std::end(pixels), 0);
    239    HWY_ALIGN_MAX float coeffs[64 * 4];
    240    const AcStrategyType dct = AcStrategyType::DCT16X16;
    241    TransformFromPixels(dct, pixels, 16, coeffs, scratch_space);
    242    HWY_ALIGN_MAX double slow_coeffs[64 * 4];
    243    for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i];
    244    DCTSlow<16>(slow_coeffs);
    245 
    246    for (size_t i = 0; i < 64 * 4; i++) {
    247      slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
    248      coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
    249                  dequant_matrices.Matrix(dct, 0)[i];
    250    }
    251 
    252    IDCTSlow<16>(slow_coeffs);
    253    TransformToPixels(dct, coeffs, pixels, 16, scratch_space);
    254    for (size_t i = 0; i < 64 * 4; i++) {
    255      EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
    256    }
    257  }
    258 
    259  // Check that all matrices have the same DC quantization, i.e. that they all
    260  // have the same scaling.
    261  for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
    262    AcStrategyType kind = static_cast<AcStrategyType>(i);
    263    EXPECT_NEAR(dequant_matrices.Matrix(kind, 0)[0], kUniformQuant, 1e-6);
    264  }
    265 }
    266 
    267 }  // namespace
    268 }  // namespace jxl