av1_txfm_test.cc (11649B)
1 /* 2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include "test/av1_txfm_test.h" 13 14 #include <stdio.h> 15 16 #include <memory> 17 #include <new> 18 19 namespace libaom_test { 20 21 const char *tx_type_name[] = { 22 "DCT_DCT", 23 "ADST_DCT", 24 "DCT_ADST", 25 "ADST_ADST", 26 "FLIPADST_DCT", 27 "DCT_FLIPADST", 28 "FLIPADST_FLIPADST", 29 "ADST_FLIPADST", 30 "FLIPADST_ADST", 31 "IDTX", 32 "V_DCT", 33 "H_DCT", 34 "V_ADST", 35 "H_ADST", 36 "V_FLIPADST", 37 "H_FLIPADST", 38 }; 39 40 int get_txfm1d_size(TX_SIZE tx_size) { return tx_size_wide[tx_size]; } 41 42 void get_txfm1d_type(TX_TYPE txfm2d_type, TYPE_TXFM *type0, TYPE_TXFM *type1) { 43 switch (txfm2d_type) { 44 case DCT_DCT: 45 *type0 = TYPE_DCT; 46 *type1 = TYPE_DCT; 47 break; 48 case ADST_DCT: 49 *type0 = TYPE_ADST; 50 *type1 = TYPE_DCT; 51 break; 52 case DCT_ADST: 53 *type0 = TYPE_DCT; 54 *type1 = TYPE_ADST; 55 break; 56 case ADST_ADST: 57 *type0 = TYPE_ADST; 58 *type1 = TYPE_ADST; 59 break; 60 case FLIPADST_DCT: 61 *type0 = TYPE_ADST; 62 *type1 = TYPE_DCT; 63 break; 64 case DCT_FLIPADST: 65 *type0 = TYPE_DCT; 66 *type1 = TYPE_ADST; 67 break; 68 case FLIPADST_FLIPADST: 69 *type0 = TYPE_ADST; 70 *type1 = TYPE_ADST; 71 break; 72 case ADST_FLIPADST: 73 *type0 = TYPE_ADST; 74 *type1 = TYPE_ADST; 75 break; 76 case FLIPADST_ADST: 77 *type0 = TYPE_ADST; 78 *type1 = TYPE_ADST; 79 break; 80 case IDTX: 81 *type0 = TYPE_IDTX; 82 *type1 = TYPE_IDTX; 83 break; 84 case H_DCT: 85 *type0 = TYPE_IDTX; 86 *type1 = TYPE_DCT; 87 break; 88 case V_DCT: 89 *type0 = TYPE_DCT; 90 *type1 = TYPE_IDTX; 91 break; 92 case H_ADST: 93 *type0 = TYPE_IDTX; 94 *type1 = TYPE_ADST; 95 break; 96 case V_ADST: 97 *type0 = TYPE_ADST; 98 *type1 = TYPE_IDTX; 99 break; 100 case H_FLIPADST: 101 *type0 = TYPE_IDTX; 102 *type1 = TYPE_ADST; 103 break; 104 case V_FLIPADST: 105 *type0 = TYPE_ADST; 106 *type1 = TYPE_IDTX; 107 break; 108 default: 109 *type0 = TYPE_DCT; 110 *type1 = TYPE_DCT; 111 assert(0); 112 break; 113 } 114 } 115 116 double Sqrt2 = pow(2, 0.5); 117 double invSqrt2 = 1 / pow(2, 0.5); 118 119 static double dct_matrix(double n, double k, int size) { 120 return cos(PI * (2 * n + 1) * k / (2 * size)); 121 } 122 123 void reference_dct_1d(const double *in, double *out, int size) { 124 for (int k = 0; k < size; ++k) { 125 out[k] = 0; 126 for (int n = 0; n < size; ++n) { 127 out[k] += in[n] * dct_matrix(n, k, size); 128 } 129 if (k == 0) out[k] = out[k] * invSqrt2; 130 } 131 } 132 133 void reference_idct_1d(const double *in, double *out, int size) { 134 for (int k = 0; k < size; ++k) { 135 out[k] = 0; 136 for (int n = 0; n < size; ++n) { 137 if (n == 0) 138 out[k] += invSqrt2 * in[n] * dct_matrix(k, n, size); 139 else 140 out[k] += in[n] * dct_matrix(k, n, size); 141 } 142 } 143 } 144 145 // TODO(any): Copied from the old 'fadst4' (same as the new 'av1_fadst4' 146 // function). Should be replaced by a proper reference function that takes 147 // 'double' input & output. 148 static void fadst4_new(const tran_low_t *input, tran_low_t *output) { 149 tran_high_t x0, x1, x2, x3; 150 tran_high_t s0, s1, s2, s3, s4, s5, s6, s7; 151 152 x0 = input[0]; 153 x1 = input[1]; 154 x2 = input[2]; 155 x3 = input[3]; 156 157 if (!(x0 | x1 | x2 | x3)) { 158 output[0] = output[1] = output[2] = output[3] = 0; 159 return; 160 } 161 162 s0 = sinpi_1_9 * x0; 163 s1 = sinpi_4_9 * x0; 164 s2 = sinpi_2_9 * x1; 165 s3 = sinpi_1_9 * x1; 166 s4 = sinpi_3_9 * x2; 167 s5 = sinpi_4_9 * x3; 168 s6 = sinpi_2_9 * x3; 169 s7 = x0 + x1 - x3; 170 171 x0 = s0 + s2 + s5; 172 x1 = sinpi_3_9 * s7; 173 x2 = s1 - s3 + s6; 174 x3 = s4; 175 176 s0 = x0 + x3; 177 s1 = x1; 178 s2 = x2 - x3; 179 s3 = x2 - x0 + x3; 180 181 // 1-D transform scaling factor is sqrt(2). 182 output[0] = (tran_low_t)fdct_round_shift(s0); 183 output[1] = (tran_low_t)fdct_round_shift(s1); 184 output[2] = (tran_low_t)fdct_round_shift(s2); 185 output[3] = (tran_low_t)fdct_round_shift(s3); 186 } 187 188 void reference_adst_1d(const double *in, double *out, int size) { 189 if (size == 4) { // Special case. 190 tran_low_t int_input[4]; 191 for (int i = 0; i < 4; ++i) { 192 int_input[i] = static_cast<tran_low_t>(round(in[i])); 193 } 194 tran_low_t int_output[4]; 195 fadst4_new(int_input, int_output); 196 for (int i = 0; i < 4; ++i) { 197 out[i] = int_output[i]; 198 } 199 return; 200 } 201 202 for (int k = 0; k < size; ++k) { 203 out[k] = 0; 204 for (int n = 0; n < size; ++n) { 205 out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size)); 206 } 207 } 208 } 209 210 static void reference_idtx_1d(const double *in, double *out, int size) { 211 double scale = 0; 212 if (size == 4) 213 scale = Sqrt2; 214 else if (size == 8) 215 scale = 2; 216 else if (size == 16) 217 scale = 2 * Sqrt2; 218 else if (size == 32) 219 scale = 4; 220 else if (size == 64) 221 scale = 4 * Sqrt2; 222 for (int k = 0; k < size; ++k) { 223 out[k] = in[k] * scale; 224 } 225 } 226 227 void reference_hybrid_1d(double *in, double *out, int size, int type) { 228 if (type == TYPE_DCT) 229 reference_dct_1d(in, out, size); 230 else if (type == TYPE_ADST) 231 reference_adst_1d(in, out, size); 232 else 233 reference_idtx_1d(in, out, size); 234 } 235 236 double get_amplification_factor(TX_TYPE tx_type, TX_SIZE tx_size) { 237 TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg; 238 av1_get_fwd_txfm_cfg(tx_type, tx_size, &fwd_txfm_flip_cfg); 239 const int tx_width = tx_size_wide[fwd_txfm_flip_cfg.tx_size]; 240 const int tx_height = tx_size_high[fwd_txfm_flip_cfg.tx_size]; 241 const int8_t *shift = fwd_txfm_flip_cfg.shift; 242 const int amplify_bit = shift[0] + shift[1] + shift[2]; 243 double amplify_factor = 244 amplify_bit >= 0 ? (1 << amplify_bit) : (1.0 / (1 << -amplify_bit)); 245 246 // For rectangular transforms, we need to multiply by an extra factor. 247 const int rect_type = get_rect_tx_log_ratio(tx_width, tx_height); 248 if (abs(rect_type) == 1) { 249 amplify_factor *= pow(2, 0.5); 250 } 251 return amplify_factor; 252 } 253 254 void reference_hybrid_2d(double *in, double *out, TX_TYPE tx_type, 255 TX_SIZE tx_size) { 256 // Get transform type and size of each dimension. 257 TYPE_TXFM type0; 258 TYPE_TXFM type1; 259 get_txfm1d_type(tx_type, &type0, &type1); 260 const int tx_width = tx_size_wide[tx_size]; 261 const int tx_height = tx_size_high[tx_size]; 262 263 std::unique_ptr<double[]> temp_in( 264 new (std::nothrow) double[AOMMAX(tx_width, tx_height)]); 265 std::unique_ptr<double[]> temp_out( 266 new (std::nothrow) double[AOMMAX(tx_width, tx_height)]); 267 std::unique_ptr<double[]> out_interm( 268 new (std::nothrow) double[tx_width * tx_height]); 269 ASSERT_NE(temp_in, nullptr); 270 ASSERT_NE(temp_out, nullptr); 271 ASSERT_NE(out_interm, nullptr); 272 273 // Transform columns. 274 for (int c = 0; c < tx_width; ++c) { 275 for (int r = 0; r < tx_height; ++r) { 276 temp_in[r] = in[r * tx_width + c]; 277 } 278 reference_hybrid_1d(temp_in.get(), temp_out.get(), tx_height, type0); 279 for (int r = 0; r < tx_height; ++r) { 280 out_interm[r * tx_width + c] = temp_out[r]; 281 } 282 } 283 284 // Transform rows. 285 for (int r = 0; r < tx_height; ++r) { 286 reference_hybrid_1d(out_interm.get() + r * tx_width, temp_out.get(), 287 tx_width, type1); 288 for (int c = 0; c < tx_width; ++c) { 289 out[c * tx_height + r] = temp_out[c]; 290 } 291 } 292 293 // These transforms use an approximate 2D DCT transform, by only keeping the 294 // top-left quarter of the coefficients, and repacking them in the first 295 // quarter indices. 296 // TODO(urvang): Refactor this code. 297 if (tx_width == 64 && tx_height == 64) { // tx_size == TX_64X64 298 // Zero out top-right 32x32 area. 299 for (int col = 0; col < 32; ++col) { 300 memset(out + col * 64 + 32, 0, 32 * sizeof(*out)); 301 } 302 // Zero out the bottom 64x32 area. 303 memset(out + 32 * 64, 0, 32 * 64 * sizeof(*out)); 304 // Re-pack non-zero coeffs in the first 32x32 indices. 305 for (int col = 1; col < 32; ++col) { 306 memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out)); 307 } 308 } else if (tx_width == 32 && tx_height == 64) { // tx_size == TX_32X64 309 // Zero out right 32x32 area. 310 for (int col = 0; col < 32; ++col) { 311 memset(out + col * 64 + 32, 0, 32 * sizeof(*out)); 312 } 313 // Re-pack non-zero coeffs in the first 32x32 indices. 314 for (int col = 1; col < 32; ++col) { 315 memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out)); 316 } 317 } else if (tx_width == 64 && tx_height == 32) { // tx_size == TX_64X32 318 // Zero out the bottom 32x32 area. 319 memset(out + 32 * 32, 0, 32 * 32 * sizeof(*out)); 320 // Note: no repacking needed here. 321 } else if (tx_width == 16 && tx_height == 64) { // tx_size == TX_16X64 322 // Note: no repacking needed here. 323 // Zero out right 32x16 area. 324 for (int col = 0; col < 16; ++col) { 325 memset(out + col * 64 + 32, 0, 32 * sizeof(*out)); 326 } 327 // Re-pack non-zero coeffs in the first 32x16 indices. 328 for (int col = 1; col < 16; ++col) { 329 memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out)); 330 } 331 } else if (tx_width == 64 && tx_height == 16) { // tx_size == TX_64X16 332 // Zero out the bottom 16x32 area. 333 memset(out + 16 * 32, 0, 16 * 32 * sizeof(*out)); 334 } 335 336 // Apply appropriate scale. 337 const double amplify_factor = get_amplification_factor(tx_type, tx_size); 338 for (int c = 0; c < tx_width; ++c) { 339 for (int r = 0; r < tx_height; ++r) { 340 out[c * tx_height + r] *= amplify_factor; 341 } 342 } 343 } 344 345 template <typename Type> 346 void fliplr(Type *dest, int width, int height, int stride) { 347 for (int r = 0; r < height; ++r) { 348 for (int c = 0; c < width / 2; ++c) { 349 const Type tmp = dest[r * stride + c]; 350 dest[r * stride + c] = dest[r * stride + width - 1 - c]; 351 dest[r * stride + width - 1 - c] = tmp; 352 } 353 } 354 } 355 356 template <typename Type> 357 void flipud(Type *dest, int width, int height, int stride) { 358 for (int c = 0; c < width; ++c) { 359 for (int r = 0; r < height / 2; ++r) { 360 const Type tmp = dest[r * stride + c]; 361 dest[r * stride + c] = dest[(height - 1 - r) * stride + c]; 362 dest[(height - 1 - r) * stride + c] = tmp; 363 } 364 } 365 } 366 367 template <typename Type> 368 void fliplrud(Type *dest, int width, int height, int stride) { 369 for (int r = 0; r < height / 2; ++r) { 370 for (int c = 0; c < width; ++c) { 371 const Type tmp = dest[r * stride + c]; 372 dest[r * stride + c] = dest[(height - 1 - r) * stride + width - 1 - c]; 373 dest[(height - 1 - r) * stride + width - 1 - c] = tmp; 374 } 375 } 376 } 377 378 template void fliplr<double>(double *dest, int width, int height, int stride); 379 template void flipud<double>(double *dest, int width, int height, int stride); 380 template void fliplrud<double>(double *dest, int width, int height, int stride); 381 382 int bd_arr[BD_NUM] = { 8, 10, 12 }; 383 384 int8_t low_range_arr[BD_NUM] = { 18, 32, 32 }; 385 int8_t high_range_arr[BD_NUM] = { 32, 32, 32 }; 386 387 void txfm_stage_range_check(const int8_t *stage_range, int stage_num, 388 int8_t cos_bit, int low_range, int high_range) { 389 for (int i = 0; i < stage_num; ++i) { 390 EXPECT_LE(stage_range[i], low_range); 391 ASSERT_LE(stage_range[i] + cos_bit, high_range) << "stage = " << i; 392 } 393 for (int i = 0; i < stage_num - 1; ++i) { 394 // make sure there is no overflow while doing half_btf() 395 ASSERT_LE(stage_range[i + 1] + cos_bit, high_range) << "stage = " << i; 396 } 397 } 398 } // namespace libaom_test