quant_weights_test.cc (10176B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 #include "lib/jxl/quant_weights.h" 6 7 #include <jxl/memory_manager.h> 8 9 #include <algorithm> 10 #include <cmath> 11 #include <cstdlib> 12 #include <hwy/base.h> // HWY_ALIGN_MAX 13 #include <hwy/tests/hwy_gtest.h> 14 #include <iterator> 15 #include <numeric> 16 #include <utility> 17 #include <vector> 18 19 #include "lib/jxl/ac_strategy.h" 20 #include "lib/jxl/base/random.h" 21 #include "lib/jxl/base/status.h" 22 #include "lib/jxl/dct_for_test.h" 23 #include "lib/jxl/dec_transforms_testonly.h" 24 #include "lib/jxl/enc_modular.h" 25 #include "lib/jxl/enc_params.h" 26 #include "lib/jxl/enc_quant_weights.h" 27 #include "lib/jxl/enc_transforms.h" 28 #include "lib/jxl/frame_header.h" 29 #include "lib/jxl/image_metadata.h" 30 #include "lib/jxl/test_memory_manager.h" 31 #include "lib/jxl/testing.h" 32 33 namespace jxl { 34 namespace { 35 36 // This should have been static assert; not compiling though with C++<17. 37 TEST(QuantWeightsTest, Invariant) { 38 size_t sum = 0; 39 ASSERT_EQ(DequantMatrices::required_size_x.size(), 40 DequantMatrices::required_size_y.size()); 41 for (size_t i = 0; i < DequantMatrices::required_size_x.size(); ++i) { 42 sum += DequantMatrices::required_size_x[i] * 43 DequantMatrices::required_size_y[i]; 44 } 45 ASSERT_EQ(DequantMatrices::kSumRequiredXy, sum); 46 } 47 48 template <typename T> 49 void CheckSimilar(T a, T b) { 50 EXPECT_EQ(a, b); 51 } 52 // minimum exponent = -15. 53 template <> 54 void CheckSimilar(float a, float b) { 55 float m = std::max(std::abs(a), std::abs(b)); 56 // 10 bits of precision are used in the format. Relative error should be 57 // below 2^-10. 58 EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b; 59 } 60 61 TEST(QuantWeightsTest, DC) { 62 JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); 63 DequantMatrices mat; 64 float dc_quant[3] = {1e+5, 1e+3, 1e+1}; 65 ASSERT_TRUE(DequantMatricesSetCustomDC(memory_manager, &mat, dc_quant)); 66 for (size_t c = 0; c < 3; c++) { 67 CheckSimilar(mat.InvDCQuant(c), dc_quant[c]); 68 } 69 } 70 71 void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) { 72 ASSERT_TRUE(encodings.size() == kNumQuantTables); 73 DequantMatrices mat; 74 CodecMetadata metadata; 75 FrameHeader frame_header(&metadata); 76 JXL_ASSIGN_OR_QUIT( 77 ModularFrameEncoder encoder, 78 ModularFrameEncoder::Create(jxl::test::MemoryManager(), frame_header, 79 CompressParams{}, false), 80 "Failed to create ModularFrameEncoder."); 81 ASSERT_TRUE(DequantMatricesSetCustom(&mat, encodings, &encoder)); 82 const std::vector<QuantEncoding>& encodings_dec = mat.encodings(); 83 for (size_t i = 0; i < encodings.size(); i++) { 84 const QuantEncoding& e = encodings[i]; 85 const QuantEncoding& d = encodings_dec[i]; 86 // Check values roundtripped correctly. 87 EXPECT_EQ(e.mode, d.mode); 88 EXPECT_EQ(e.predefined, d.predefined); 89 EXPECT_EQ(e.source, d.source); 90 91 EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands), 92 static_cast<uint64_t>(d.dct_params.num_distance_bands)); 93 for (size_t c = 0; c < 3; c++) { 94 for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { 95 CheckSimilar(e.dct_params.distance_bands[c][j], 96 d.dct_params.distance_bands[c][j]); 97 } 98 } 99 100 if (e.mode == QuantEncoding::kQuantModeRAW) { 101 EXPECT_FALSE(!e.qraw.qtable); 102 EXPECT_FALSE(!d.qraw.qtable); 103 EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size()); 104 for (size_t j = 0; j < e.qraw.qtable->size(); j++) { 105 EXPECT_EQ(e.qraw.qtable->at(j), d.qraw.qtable->at(j)); 106 } 107 EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f); 108 } else { 109 // modes different than kQuantModeRAW use one of the other fields used 110 // here, which all happen to be arrays of floats. 111 for (size_t c = 0; c < 3; c++) { 112 for (size_t j = 0; j < 3; j++) { 113 CheckSimilar(e.idweights[c][j], d.idweights[c][j]); 114 } 115 for (size_t j = 0; j < 6; j++) { 116 CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]); 117 } 118 for (size_t j = 0; j < 2; j++) { 119 CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]); 120 } 121 CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]); 122 for (size_t j = 0; j < 9; j++) { 123 CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]); 124 } 125 for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { 126 CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j], 127 d.dct_params_afv_4x4.distance_bands[c][j]); 128 } 129 } 130 } 131 } 132 } 133 134 TEST(QuantWeightsTest, AllDefault) { 135 std::vector<QuantEncoding> encodings(kNumQuantTables, 136 QuantEncoding::Library<0>()); 137 RoundtripMatrices(encodings); 138 } 139 140 void TestSingleQuantMatrix(QuantTable kind) { 141 std::vector<QuantEncoding> encodings(kNumQuantTables, 142 QuantEncoding::Library<0>()); 143 size_t quant_table_idx = static_cast<size_t>(kind); 144 encodings[quant_table_idx] = DequantMatrices::Library()[quant_table_idx]; 145 RoundtripMatrices(encodings); 146 } 147 148 // Ensure we can reasonably represent default quant tables. 149 TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(QuantTable::DCT); } 150 TEST(QuantWeightsTest, IDENTITY) { 151 TestSingleQuantMatrix(QuantTable::IDENTITY); 152 } 153 TEST(QuantWeightsTest, DCT2X2) { TestSingleQuantMatrix(QuantTable::DCT2X2); } 154 TEST(QuantWeightsTest, DCT4X4) { TestSingleQuantMatrix(QuantTable::DCT4X4); } 155 TEST(QuantWeightsTest, DCT16X16) { 156 TestSingleQuantMatrix(QuantTable::DCT16X16); 157 } 158 TEST(QuantWeightsTest, DCT32X32) { 159 TestSingleQuantMatrix(QuantTable::DCT32X32); 160 } 161 TEST(QuantWeightsTest, DCT8X16) { TestSingleQuantMatrix(QuantTable::DCT8X16); } 162 TEST(QuantWeightsTest, DCT8X32) { TestSingleQuantMatrix(QuantTable::DCT8X32); } 163 TEST(QuantWeightsTest, DCT16X32) { 164 TestSingleQuantMatrix(QuantTable::DCT16X32); 165 } 166 TEST(QuantWeightsTest, DCT4X8) { TestSingleQuantMatrix(QuantTable::DCT4X8); } 167 TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(QuantTable::AFV0); } 168 TEST(QuantWeightsTest, RAW) { 169 std::vector<QuantEncoding> encodings(kNumQuantTables, 170 QuantEncoding::Library<0>()); 171 std::vector<int> matrix(3 * 32 * 32); 172 Rng rng(0); 173 for (int& v : matrix) v = rng.UniformI(1, 256); 174 QuantTable quant_table = 175 kAcStrategyToQuantTableMap[static_cast<size_t>(AcStrategyType::DCT32X32)]; 176 encodings[static_cast<size_t>(quant_table)] = 177 QuantEncoding::RAW(std::move(matrix), 2); 178 RoundtripMatrices(encodings); 179 } 180 181 class QuantWeightsTargetTest : public hwy::TestWithParamTarget {}; 182 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest); 183 184 TEST_P(QuantWeightsTargetTest, DCTUniform) { 185 JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); 186 constexpr float kUniformQuant = 4; 187 float weights[3][2] = {{1.0f / kUniformQuant, 0}, 188 {1.0f / kUniformQuant, 0}, 189 {1.0f / kUniformQuant, 0}}; 190 DctQuantWeightParams dct_params(weights); 191 std::vector<QuantEncoding> encodings(kNumQuantTables, 192 QuantEncoding::DCT(dct_params)); 193 DequantMatrices dequant_matrices; 194 CodecMetadata metadata; 195 FrameHeader frame_header(&metadata); 196 JXL_ASSIGN_OR_QUIT( 197 ModularFrameEncoder encoder, 198 ModularFrameEncoder::Create(jxl::test::MemoryManager(), frame_header, 199 CompressParams{}, false), 200 "Failed to create ModularFrameEncoder."); 201 ASSERT_TRUE(DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder)); 202 ASSERT_TRUE(dequant_matrices.EnsureComputed(memory_manager, ~0u)); 203 204 const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, 205 1.0f / kUniformQuant}; 206 ASSERT_TRUE( 207 DequantMatricesSetCustomDC(memory_manager, &dequant_matrices, dc_quant)); 208 209 HWY_ALIGN_MAX float scratch_space[16 * 16 * 5]; 210 211 // DCT8 212 { 213 HWY_ALIGN_MAX float pixels[64]; 214 std::iota(std::begin(pixels), std::end(pixels), 0); 215 HWY_ALIGN_MAX float coeffs[64]; 216 const AcStrategyType dct = AcStrategyType::DCT; 217 TransformFromPixels(dct, pixels, 8, coeffs, scratch_space); 218 HWY_ALIGN_MAX double slow_coeffs[64]; 219 for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i]; 220 DCTSlow<8>(slow_coeffs); 221 222 for (size_t i = 0; i < 64; i++) { 223 // DCTSlow doesn't multiply/divide by 1/N, so we do it manually. 224 slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; 225 coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * 226 dequant_matrices.Matrix(dct, 0)[i]; 227 } 228 IDCTSlow<8>(slow_coeffs); 229 TransformToPixels(dct, coeffs, pixels, 8, scratch_space); 230 for (size_t i = 0; i < 64; i++) { 231 EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); 232 } 233 } 234 235 // DCT16 236 { 237 HWY_ALIGN_MAX float pixels[64 * 4]; 238 std::iota(std::begin(pixels), std::end(pixels), 0); 239 HWY_ALIGN_MAX float coeffs[64 * 4]; 240 const AcStrategyType dct = AcStrategyType::DCT16X16; 241 TransformFromPixels(dct, pixels, 16, coeffs, scratch_space); 242 HWY_ALIGN_MAX double slow_coeffs[64 * 4]; 243 for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i]; 244 DCTSlow<16>(slow_coeffs); 245 246 for (size_t i = 0; i < 64 * 4; i++) { 247 slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; 248 coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * 249 dequant_matrices.Matrix(dct, 0)[i]; 250 } 251 252 IDCTSlow<16>(slow_coeffs); 253 TransformToPixels(dct, coeffs, pixels, 16, scratch_space); 254 for (size_t i = 0; i < 64 * 4; i++) { 255 EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); 256 } 257 } 258 259 // Check that all matrices have the same DC quantization, i.e. that they all 260 // have the same scaling. 261 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 262 AcStrategyType kind = static_cast<AcStrategyType>(i); 263 EXPECT_NEAR(dequant_matrices.Matrix(kind, 0)[0], kUniformQuant, 1e-6); 264 } 265 } 266 267 } // namespace 268 } // namespace jxl