tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

av1_fwd_txfm2d_test.cc (23845B)


      1 /*
      2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <math.h>
     13 #include <stdio.h>
     14 #include <stdlib.h>
     15 #include <tuple>
     16 #include <vector>
     17 
     18 #include "config/aom_config.h"
     19 #include "config/av1_rtcd.h"
     20 
     21 #include "test/acm_random.h"
     22 #include "test/util.h"
     23 #include "test/av1_txfm_test.h"
     24 #include "av1/common/av1_txfm.h"
     25 #include "av1/encoder/hybrid_fwd_txfm.h"
     26 
     27 using libaom_test::ACMRandom;
     28 using libaom_test::bd;
     29 using libaom_test::compute_avg_abs_error;
     30 using libaom_test::input_base;
     31 using libaom_test::tx_type_name;
     32 using libaom_test::TYPE_TXFM;
     33 
     34 using std::vector;
     35 
     36 namespace {
// Parameters of the AV1FwdTxfm2d accuracy test, in tuple order:
//   tx_type_, tx_size_, max_error_, max_avg_error_
// (unpacked by AV1FwdTxfm2d::SetUp() via GET_PARAM(0..3)).
using AV1FwdTxfm2dParam = std::tuple<TX_TYPE, TX_SIZE, double, double>;
     39 
     40 class AV1FwdTxfm2d : public ::testing::TestWithParam<AV1FwdTxfm2dParam> {
     41 public:
     42  void SetUp() override {
     43    tx_type_ = GET_PARAM(0);
     44    tx_size_ = GET_PARAM(1);
     45    max_error_ = GET_PARAM(2);
     46    max_avg_error_ = GET_PARAM(3);
     47    count_ = 500;
     48    TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg;
     49    av1_get_fwd_txfm_cfg(tx_type_, tx_size_, &fwd_txfm_flip_cfg);
     50    amplify_factor_ = libaom_test::get_amplification_factor(tx_type_, tx_size_);
     51    tx_width_ = tx_size_wide[fwd_txfm_flip_cfg.tx_size];
     52    tx_height_ = tx_size_high[fwd_txfm_flip_cfg.tx_size];
     53    ud_flip_ = fwd_txfm_flip_cfg.ud_flip;
     54    lr_flip_ = fwd_txfm_flip_cfg.lr_flip;
     55 
     56    fwd_txfm_ = libaom_test::fwd_txfm_func_ls[tx_size_];
     57    txfm2d_size_ = tx_width_ * tx_height_;
     58    input_ = reinterpret_cast<int16_t *>(
     59        aom_memalign(16, sizeof(input_[0]) * txfm2d_size_));
     60    ASSERT_NE(input_, nullptr);
     61    output_ = reinterpret_cast<int32_t *>(
     62        aom_memalign(16, sizeof(output_[0]) * txfm2d_size_));
     63    ASSERT_NE(output_, nullptr);
     64    ref_input_ = reinterpret_cast<double *>(
     65        aom_memalign(16, sizeof(ref_input_[0]) * txfm2d_size_));
     66    ASSERT_NE(ref_input_, nullptr);
     67    ref_output_ = reinterpret_cast<double *>(
     68        aom_memalign(16, sizeof(ref_output_[0]) * txfm2d_size_));
     69    ASSERT_NE(ref_output_, nullptr);
     70  }
     71 
     72  void RunFwdAccuracyCheck() {
     73    ACMRandom rnd(ACMRandom::DeterministicSeed());
     74    double avg_abs_error = 0;
     75    for (int ci = 0; ci < count_; ci++) {
     76      for (int ni = 0; ni < txfm2d_size_; ++ni) {
     77        input_[ni] = rnd.Rand16() % input_base;
     78        ref_input_[ni] = static_cast<double>(input_[ni]);
     79        output_[ni] = 0;
     80        ref_output_[ni] = 0;
     81      }
     82 
     83      fwd_txfm_(input_, output_, tx_width_, tx_type_, bd);
     84 
     85      if (lr_flip_ && ud_flip_) {
     86        libaom_test::fliplrud(ref_input_, tx_width_, tx_height_, tx_width_);
     87      } else if (lr_flip_) {
     88        libaom_test::fliplr(ref_input_, tx_width_, tx_height_, tx_width_);
     89      } else if (ud_flip_) {
     90        libaom_test::flipud(ref_input_, tx_width_, tx_height_, tx_width_);
     91      }
     92 
     93      libaom_test::reference_hybrid_2d(ref_input_, ref_output_, tx_type_,
     94                                       tx_size_);
     95 
     96      double actual_max_error = 0;
     97      for (int ni = 0; ni < txfm2d_size_; ++ni) {
     98        ref_output_[ni] = round(ref_output_[ni]);
     99        const double this_error =
    100            fabs(output_[ni] - ref_output_[ni]) / amplify_factor_;
    101        actual_max_error = AOMMAX(actual_max_error, this_error);
    102      }
    103      EXPECT_GE(max_error_, actual_max_error)
    104          << "tx_w: " << tx_width_ << " tx_h: " << tx_height_
    105          << ", tx_type = " << (int)tx_type_;
    106      if (actual_max_error > max_error_) {  // exit early.
    107        break;
    108      }
    109 
    110      avg_abs_error += compute_avg_abs_error<int32_t, double>(
    111          output_, ref_output_, txfm2d_size_);
    112    }
    113 
    114    avg_abs_error /= amplify_factor_;
    115    avg_abs_error /= count_;
    116    EXPECT_GE(max_avg_error_, avg_abs_error)
    117        << "tx_size = " << tx_size_ << ", tx_type = " << tx_type_;
    118  }
    119 
    120  void TearDown() override {
    121    aom_free(input_);
    122    aom_free(output_);
    123    aom_free(ref_input_);
    124    aom_free(ref_output_);
    125  }
    126 
    127 private:
    128  double max_error_;
    129  double max_avg_error_;
    130  int count_;
    131  double amplify_factor_;
    132  TX_TYPE tx_type_;
    133  TX_SIZE tx_size_;
    134  int tx_width_;
    135  int tx_height_;
    136  int txfm2d_size_;
    137  FwdTxfm2dFunc fwd_txfm_;
    138  int16_t *input_;
    139  int32_t *output_;
    140  double *ref_input_;
    141  double *ref_output_;
    142  int ud_flip_;  // flip upside down
    143  int lr_flip_;  // flip left to right
    144 };
    145 
// Per-transform-size upper bound on the mean absolute error between the C
// forward transform and the double-precision reference. The error is divided
// by the amplification factor and averaged over the test iterations. Indexed
// by TX_SIZE in the order given by the per-line comments. Passed to the
// accuracy test as max_avg_error_ (tuple element 3).
// NOTE(review): GetTxfm2dParamList() only reads the first TX_SIZES entries;
// the rectangular-size entries are not consumed in this file.
static constexpr double avg_error_ls[TX_SIZES_ALL] = {
  0.5,   // 4x4 transform
  0.5,   // 8x8 transform
  1.2,   // 16x16 transform
  6.1,   // 32x32 transform
  3.4,   // 64x64 transform
  0.57,  // 4x8 transform
  0.68,  // 8x4 transform
  0.92,  // 8x16 transform
  1.1,   // 16x8 transform
  4.1,   // 16x32 transform
  6,     // 32x16 transform
  3.5,   // 32x64 transform
  5.7,   // 64x32 transform
  0.6,   // 4x16 transform
  0.9,   // 16x4 transform
  1.2,   // 8x32 transform
  1.7,   // 32x8 transform
  2.0,   // 16x64 transform
  4.7,   // 64x16 transform
};
    167 
// Per-transform-size upper bound on the largest single-coefficient absolute
// error (divided by the amplification factor) between the C forward transform
// and the rounded double-precision reference, checked per test iteration.
// Indexed by TX_SIZE in the order given by the per-line comments. Passed to
// the accuracy test as max_error_ (tuple element 2).
// NOTE(review): GetTxfm2dParamList() only reads the first TX_SIZES entries;
// the rectangular-size entries are not consumed in this file.
static constexpr double max_error_ls[TX_SIZES_ALL] = {
  3,    // 4x4 transform
  5,    // 8x8 transform
  11,   // 16x16 transform
  70,   // 32x32 transform
  64,   // 64x64 transform
  3.9,  // 4x8 transform
  4.3,  // 8x4 transform
  12,   // 8x16 transform
  12,   // 16x8 transform
  32,   // 16x32 transform
  46,   // 32x16 transform
  136,  // 32x64 transform
  136,  // 64x32 transform
  5,    // 4x16 transform
  6,    // 16x4 transform
  21,   // 8x32 transform
  13,   // 32x8 transform
  30,   // 16x64 transform
  36,   // 64x16 transform
};
    189 
    190 vector<AV1FwdTxfm2dParam> GetTxfm2dParamList() {
    191  vector<AV1FwdTxfm2dParam> param_list;
    192  for (int s = 0; s < TX_SIZES; ++s) {
    193    const double max_error = max_error_ls[s];
    194    const double avg_error = avg_error_ls[s];
    195    for (int t = 0; t < TX_TYPES; ++t) {
    196      const TX_TYPE tx_type = static_cast<TX_TYPE>(t);
    197      const TX_SIZE tx_size = static_cast<TX_SIZE>(s);
    198      if (libaom_test::IsTxSizeTypeValid(tx_size, tx_type)) {
    199        param_list.push_back(
    200            AV1FwdTxfm2dParam(tx_type, tx_size, max_error, avg_error));
    201      }
    202    }
    203  }
    204  return param_list;
    205 }
    206 
// Registers the accuracy test for every (tx_type, tx_size, max_error,
// avg_error) tuple produced by GetTxfm2dParamList().
INSTANTIATE_TEST_SUITE_P(C, AV1FwdTxfm2d,
                         ::testing::ValuesIn(GetTxfm2dParamList()));

TEST_P(AV1FwdTxfm2d, RunFwdAccuracyCheck) { RunFwdAccuracyCheck(); }
    211 
    212 TEST(AV1FwdTxfm2d, CfgTest) {
    213  for (int bd_idx = 0; bd_idx < BD_NUM; ++bd_idx) {
    214    int bd = libaom_test::bd_arr[bd_idx];
    215    int8_t low_range = libaom_test::low_range_arr[bd_idx];
    216    int8_t high_range = libaom_test::high_range_arr[bd_idx];
    217    for (int tx_size = 0; tx_size < TX_SIZES_ALL; ++tx_size) {
    218      for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
    219        if (libaom_test::IsTxSizeTypeValid(static_cast<TX_SIZE>(tx_size),
    220                                           static_cast<TX_TYPE>(tx_type)) ==
    221            false) {
    222          continue;
    223        }
    224        TXFM_2D_FLIP_CFG cfg;
    225        av1_get_fwd_txfm_cfg(static_cast<TX_TYPE>(tx_type),
    226                             static_cast<TX_SIZE>(tx_size), &cfg);
    227        int8_t stage_range_col[MAX_TXFM_STAGE_NUM];
    228        int8_t stage_range_row[MAX_TXFM_STAGE_NUM];
    229        av1_gen_fwd_stage_range(stage_range_col, stage_range_row, &cfg, bd);
    230        libaom_test::txfm_stage_range_check(stage_range_col, cfg.stage_num_col,
    231                                            cfg.cos_bit_col, low_range,
    232                                            high_range);
    233        libaom_test::txfm_stage_range_check(stage_range_row, cfg.stage_num_row,
    234                                            cfg.cos_bit_row, low_range,
    235                                            high_range);
    236      }
    237    }
    238  }
    239 }
    240 
// Signature of the optimized low-bit-depth forward transforms under test
// (e.g. av1_lowbd_fwd_txfm_sse2). The match and speed tests pass a residual
// block with row stride diff_stride, an output coefficient buffer, and a
// TxfmParam whose tx_type, tx_size, tx_set_type and bd they set beforehand.
using lowbd_fwd_txfm_func = void (*)(const int16_t *src_diff, tran_low_t *coeff,
                                     int diff_stride, TxfmParam *txfm_param);
    243 
    244 void AV1FwdTxfm2dMatchTest(TX_SIZE tx_size, lowbd_fwd_txfm_func target_func) {
    245  const int bd = 8;
    246  TxfmParam param;
    247  memset(&param, 0, sizeof(param));
    248  const int rows = tx_size_high[tx_size];
    249  const int cols = tx_size_wide[tx_size];
    250  for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
    251    if (libaom_test::IsTxSizeTypeValid(
    252            tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
    253      continue;
    254    }
    255 
    256    FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
    257    if (ref_func != nullptr) {
    258      DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
    259      DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
    260      DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
    261      int input_stride = 64;
    262      ACMRandom rnd(ACMRandom::DeterministicSeed());
    263      for (int cnt = 0; cnt < 500; ++cnt) {
    264        if (cnt == 0) {
    265          for (int c = 0; c < cols; ++c) {
    266            for (int r = 0; r < rows; ++r) {
    267              input[r * input_stride + c] = (1 << bd) - 1;
    268            }
    269          }
    270        } else {
    271          for (int r = 0; r < rows; ++r) {
    272            for (int c = 0; c < cols; ++c) {
    273              input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
    274            }
    275          }
    276        }
    277        param.tx_type = (TX_TYPE)tx_type;
    278        param.tx_size = (TX_SIZE)tx_size;
    279        param.tx_set_type = EXT_TX_SET_ALL16;
    280        param.bd = bd;
    281        ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
    282        target_func(input, output, input_stride, &param);
    283        const int check_cols = AOMMIN(32, cols);
    284        const int check_rows = AOMMIN(32, rows * cols / check_cols);
    285        for (int r = 0; r < check_rows; ++r) {
    286          for (int c = 0; c < check_cols; ++c) {
    287            ASSERT_EQ(ref_output[r * check_cols + c],
    288                      output[r * check_cols + c])
    289                << "[" << r << "," << c << "] cnt:" << cnt
    290                << " tx_size: " << cols << "x" << rows
    291                << " tx_type: " << tx_type_name[tx_type];
    292          }
    293        }
    294      }
    295    }
    296  }
    297 }
    298 
    299 void AV1FwdTxfm2dSpeedTest(TX_SIZE tx_size, lowbd_fwd_txfm_func target_func) {
    300  TxfmParam param;
    301  memset(&param, 0, sizeof(param));
    302  const int rows = tx_size_high[tx_size];
    303  const int cols = tx_size_wide[tx_size];
    304  const int num_loops = 1000000 / (rows * cols);
    305 
    306  const int bd = 8;
    307  for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
    308    if (libaom_test::IsTxSizeTypeValid(
    309            tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
    310      continue;
    311    }
    312 
    313    FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
    314    if (ref_func != nullptr) {
    315      DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
    316      DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
    317      DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
    318      int input_stride = 64;
    319      ACMRandom rnd(ACMRandom::DeterministicSeed());
    320 
    321      for (int r = 0; r < rows; ++r) {
    322        for (int c = 0; c < cols; ++c) {
    323          input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
    324        }
    325      }
    326 
    327      param.tx_type = (TX_TYPE)tx_type;
    328      param.tx_size = (TX_SIZE)tx_size;
    329      param.tx_set_type = EXT_TX_SET_ALL16;
    330      param.bd = bd;
    331 
    332      aom_usec_timer ref_timer, test_timer;
    333 
    334      aom_usec_timer_start(&ref_timer);
    335      for (int i = 0; i < num_loops; ++i) {
    336        ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
    337      }
    338      aom_usec_timer_mark(&ref_timer);
    339      const int elapsed_time_c =
    340          static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
    341 
    342      aom_usec_timer_start(&test_timer);
    343      for (int i = 0; i < num_loops; ++i) {
    344        target_func(input, output, input_stride, &param);
    345      }
    346      aom_usec_timer_mark(&test_timer);
    347      const int elapsed_time_simd =
    348          static_cast<int>(aom_usec_timer_elapsed(&test_timer));
    349 
    350      printf(
    351          "txfm_size[%2dx%-2d] \t txfm_type[%d] \t c_time=%d \t"
    352          "simd_time=%d \t gain=%d \n",
    353          rows, cols, tx_type, elapsed_time_c, elapsed_time_simd,
    354          (elapsed_time_c / elapsed_time_simd));
    355    }
    356  }
    357 }
    358 
// Parameters of AV1FwdTxfm2dTest: (transform size, optimized low-bit-depth
// forward transform to compare against the C reference).
using LbdFwdTxfm2dParam = std::tuple<TX_SIZE, lowbd_fwd_txfm_func>;

// Match/speed suite for low-bit-depth forward transforms. It is instantiated
// only inside the SIMD #if blocks of this file, so an uninstantiated suite is
// explicitly allowed.
class AV1FwdTxfm2dTest : public ::testing::TestWithParam<LbdFwdTxfm2dParam> {};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AV1FwdTxfm2dTest);

// Bit-exactness against the C reference.
TEST_P(AV1FwdTxfm2dTest, match) {
  AV1FwdTxfm2dMatchTest(GET_PARAM(0), GET_PARAM(1));
}
// Timing comparison; DISABLED_ by default (run with
// --gtest_also_run_disabled_tests).
TEST_P(AV1FwdTxfm2dTest, DISABLED_Speed) {
  AV1FwdTxfm2dSpeedTest(GET_PARAM(0), GET_PARAM(1));
}
    370 TEST(AV1FwdTxfm2dTest, DCTScaleTest) {
    371  BitDepthInfo bd_info;
    372  bd_info.bit_depth = 8;
    373  bd_info.use_highbitdepth_buf = 0;
    374  DECLARE_ALIGNED(32, int16_t, src_diff[1024]);
    375  DECLARE_ALIGNED(32, tran_low_t, coeff[1024]);
    376 
    377  const TX_SIZE tx_size_list[4] = { TX_4X4, TX_8X8, TX_16X16, TX_32X32 };
    378  const int stride_list[4] = { 4, 8, 16, 32 };
    379  const int ref_scale_list[4] = { 64, 64, 64, 16 };
    380 
    381  for (int i = 0; i < 4; i++) {
    382    TX_SIZE tx_size = tx_size_list[i];
    383    int stride = stride_list[i];
    384    int array_size = stride * stride;
    385 
    386    for (int j = 0; j < array_size; j++) {
    387      src_diff[j] = 8;
    388      coeff[j] = 0;
    389    }
    390 
    391    av1_quick_txfm(/*use_hadamard=*/0, tx_size, bd_info, src_diff, stride,
    392                   coeff);
    393 
    394    double input_sse = 0;
    395    double output_sse = 0;
    396    for (int j = 0; j < array_size; j++) {
    397      input_sse += pow(src_diff[j], 2);
    398      output_sse += pow(coeff[j], 2);
    399    }
    400 
    401    double scale = output_sse / input_sse;
    402 
    403    EXPECT_NEAR(scale, ref_scale_list[i], 5);
    404  }
    405 }
    406 TEST(AV1FwdTxfm2dTest, HadamardScaleTest) {
    407  BitDepthInfo bd_info;
    408  bd_info.bit_depth = 8;
    409  bd_info.use_highbitdepth_buf = 0;
    410  DECLARE_ALIGNED(32, int16_t, src_diff[1024]);
    411  DECLARE_ALIGNED(32, tran_low_t, coeff[1024]);
    412 
    413  const TX_SIZE tx_size_list[4] = { TX_4X4, TX_8X8, TX_16X16, TX_32X32 };
    414  const int stride_list[4] = { 4, 8, 16, 32 };
    415  const int ref_scale_list[4] = { 1, 64, 64, 16 };
    416 
    417  for (int i = 0; i < 4; i++) {
    418    TX_SIZE tx_size = tx_size_list[i];
    419    int stride = stride_list[i];
    420    int array_size = stride * stride;
    421 
    422    for (int j = 0; j < array_size; j++) {
    423      src_diff[j] = 8;
    424      coeff[j] = 0;
    425    }
    426 
    427    av1_quick_txfm(/*use_hadamard=*/1, tx_size, bd_info, src_diff, stride,
    428                   coeff);
    429 
    430    double input_sse = 0;
    431    double output_sse = 0;
    432    for (int j = 0; j < array_size; j++) {
    433      input_sse += pow(src_diff[j], 2);
    434      output_sse += pow(coeff[j], 2);
    435    }
    436 
    437    double scale = output_sse / input_sse;
    438 
    439    EXPECT_NEAR(scale, ref_scale_list[i], 5);
    440  }
    441 }
    442 using ::testing::Combine;
    443 using ::testing::Values;
    444 using ::testing::ValuesIn;
    445 
#if AOM_ARCH_X86 && HAVE_SSE2
// Transform sizes checked against av1_lowbd_fwd_txfm_sse2. The commented-out
// entries (TX_64X64, TX_32X64, TX_64X32) are not instantiated for SSE2.
static constexpr TX_SIZE fwd_txfm_for_sse2[] = {
  TX_4X4,
  TX_8X8,
  TX_16X16,
  TX_32X32,
  // TX_64X64,
  TX_4X8,
  TX_8X4,
  TX_8X16,
  TX_16X8,
  TX_16X32,
  TX_32X16,
  // TX_32X64,
  // TX_64X32,
  TX_4X16,
  TX_16X4,
  TX_8X32,
  TX_32X8,
  TX_16X64,
  TX_64X16,
};

INSTANTIATE_TEST_SUITE_P(SSE2, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_sse2),
                                 Values(av1_lowbd_fwd_txfm_sse2)));
#endif  // AOM_ARCH_X86 && HAVE_SSE2
    473 
#if HAVE_SSE4_1
// All 19 transform sizes, checked against av1_lowbd_fwd_txfm_sse4_1.
static constexpr TX_SIZE fwd_txfm_for_sse41[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32, TX_64X64, TX_4X8,   TX_8X4,
  TX_8X16, TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32, TX_4X16,
  TX_16X4, TX_8X32, TX_32X8,  TX_16X64, TX_64X16
};

INSTANTIATE_TEST_SUITE_P(SSE4_1, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_sse41),
                                 Values(av1_lowbd_fwd_txfm_sse4_1)));
#endif  // HAVE_SSE4_1
    485 
#if HAVE_AVX2
// All 19 transform sizes, checked against av1_lowbd_fwd_txfm_avx2.
static constexpr TX_SIZE fwd_txfm_for_avx2[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32, TX_64X64, TX_4X8,   TX_8X4,
  TX_8X16, TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32, TX_4X16,
  TX_16X4, TX_8X32, TX_32X8,  TX_16X64, TX_64X16,
};

INSTANTIATE_TEST_SUITE_P(AVX2, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_avx2),
                                 Values(av1_lowbd_fwd_txfm_avx2)));
#endif  // HAVE_AVX2
    497 
#if CONFIG_HIGHWAY && HAVE_AVX512
// All 19 transform sizes, checked against av1_lowbd_fwd_txfm_avx512. This is
// built only when both CONFIG_HIGHWAY and HAVE_AVX512 are set.
static constexpr TX_SIZE fwd_txfm_for_avx512[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32, TX_64X64, TX_4X8,   TX_8X4,
  TX_8X16, TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32, TX_4X16,
  TX_16X4, TX_8X32, TX_32X8,  TX_16X64, TX_64X16,
};

INSTANTIATE_TEST_SUITE_P(AVX512, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_avx512),
                                 Values(av1_lowbd_fwd_txfm_avx512)));
#endif  // CONFIG_HIGHWAY && HAVE_AVX512
    509 
#if HAVE_NEON

// All 19 transform sizes, checked against av1_lowbd_fwd_txfm_neon.
static constexpr TX_SIZE fwd_txfm_for_neon[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32, TX_64X64, TX_4X8,   TX_8X4,
  TX_8X16, TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32, TX_4X16,
  TX_16X4, TX_8X32, TX_32X8,  TX_16X64, TX_64X16
};

INSTANTIATE_TEST_SUITE_P(NEON, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_neon),
                                 Values(av1_lowbd_fwd_txfm_neon)));

#endif  // HAVE_NEON
    523 
// Signature of the high-bit-depth forward transform under test. It is the same
// shape as lowbd_fwd_txfm_func; the highbd match/speed tests set txfm_param's
// tx_type, tx_size, tx_set_type and bd (10 or 12) before each call.
using Highbd_fwd_txfm_func = void (*)(const int16_t *src_diff,
                                      tran_low_t *coeff, int diff_stride,
                                      TxfmParam *txfm_param);
    527 
    528 void AV1HighbdFwdTxfm2dMatchTest(TX_SIZE tx_size,
    529                                 Highbd_fwd_txfm_func target_func) {
    530  const int bd_ar[2] = { 10, 12 };
    531  TxfmParam param;
    532  memset(&param, 0, sizeof(param));
    533  const int rows = tx_size_high[tx_size];
    534  const int cols = tx_size_wide[tx_size];
    535  for (int i = 0; i < 2; ++i) {
    536    const int bd = bd_ar[i];
    537    for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
    538      if (libaom_test::IsTxSizeTypeValid(
    539              tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
    540        continue;
    541      }
    542 
    543      FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
    544      if (ref_func != nullptr) {
    545        DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
    546        DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
    547        DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
    548        int input_stride = 64;
    549        ACMRandom rnd(ACMRandom::DeterministicSeed());
    550        for (int cnt = 0; cnt < 500; ++cnt) {
    551          if (cnt == 0) {
    552            for (int r = 0; r < rows; ++r) {
    553              for (int c = 0; c < cols; ++c) {
    554                input[r * input_stride + c] = (1 << bd) - 1;
    555              }
    556            }
    557          } else {
    558            for (int r = 0; r < rows; ++r) {
    559              for (int c = 0; c < cols; ++c) {
    560                input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
    561              }
    562            }
    563          }
    564          param.tx_type = (TX_TYPE)tx_type;
    565          param.tx_size = (TX_SIZE)tx_size;
    566          param.tx_set_type = EXT_TX_SET_ALL16;
    567          param.bd = bd;
    568 
    569          ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
    570          target_func(input, output, input_stride, &param);
    571          const int check_cols = AOMMIN(32, cols);
    572          const int check_rows = AOMMIN(32, rows * cols / check_cols);
    573          for (int r = 0; r < check_rows; ++r) {
    574            for (int c = 0; c < check_cols; ++c) {
    575              ASSERT_EQ(ref_output[c * check_rows + r],
    576                        output[c * check_rows + r])
    577                  << "[" << r << "," << c << "] cnt:" << cnt
    578                  << " tx_size: " << cols << "x" << rows
    579                  << " tx_type: " << tx_type;
    580            }
    581          }
    582        }
    583      }
    584    }
    585  }
    586 }
    587 
    588 void AV1HighbdFwdTxfm2dSpeedTest(TX_SIZE tx_size,
    589                                 Highbd_fwd_txfm_func target_func) {
    590  const int bd_ar[2] = { 10, 12 };
    591  TxfmParam param;
    592  memset(&param, 0, sizeof(param));
    593  const int rows = tx_size_high[tx_size];
    594  const int cols = tx_size_wide[tx_size];
    595  const int num_loops = 1000000 / (rows * cols);
    596 
    597  for (int i = 0; i < 2; ++i) {
    598    const int bd = bd_ar[i];
    599    for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
    600      if (libaom_test::IsTxSizeTypeValid(
    601              tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
    602        continue;
    603      }
    604 
    605      FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
    606      if (ref_func != nullptr) {
    607        DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
    608        DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
    609        DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
    610        int input_stride = 64;
    611        ACMRandom rnd(ACMRandom::DeterministicSeed());
    612 
    613        for (int r = 0; r < rows; ++r) {
    614          for (int c = 0; c < cols; ++c) {
    615            input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
    616          }
    617        }
    618 
    619        param.tx_type = (TX_TYPE)tx_type;
    620        param.tx_size = (TX_SIZE)tx_size;
    621        param.tx_set_type = EXT_TX_SET_ALL16;
    622        param.bd = bd;
    623 
    624        aom_usec_timer ref_timer, test_timer;
    625 
    626        aom_usec_timer_start(&ref_timer);
    627        for (int j = 0; j < num_loops; ++j) {
    628          ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
    629        }
    630        aom_usec_timer_mark(&ref_timer);
    631        const int elapsed_time_c =
    632            static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
    633 
    634        aom_usec_timer_start(&test_timer);
    635        for (int j = 0; j < num_loops; ++j) {
    636          target_func(input, output, input_stride, &param);
    637        }
    638        aom_usec_timer_mark(&test_timer);
    639        const int elapsed_time_simd =
    640            static_cast<int>(aom_usec_timer_elapsed(&test_timer));
    641 
    642        printf(
    643            "txfm_size[%2dx%-2d] \t txfm_type[%d] \t c_time=%d \t"
    644            "simd_time=%d \t gain=%d \n",
    645            cols, rows, tx_type, elapsed_time_c, elapsed_time_simd,
    646            (elapsed_time_c / elapsed_time_simd));
    647      }
    648    }
    649  }
    650 }
    651 
// Parameters of AV1HighbdFwdTxfm2dTest: (transform size, high-bit-depth
// forward transform to compare against the C reference).
using HighbdFwdTxfm2dParam = std::tuple<TX_SIZE, Highbd_fwd_txfm_func>;

// Match/speed suite for high-bit-depth forward transforms. It is instantiated
// only inside the SIMD #if blocks of this file, so an uninstantiated suite is
// explicitly allowed.
class AV1HighbdFwdTxfm2dTest
    : public ::testing::TestWithParam<HighbdFwdTxfm2dParam> {};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AV1HighbdFwdTxfm2dTest);

// Bit-exactness against the C reference at bit depths 10 and 12.
TEST_P(AV1HighbdFwdTxfm2dTest, match) {
  AV1HighbdFwdTxfm2dMatchTest(GET_PARAM(0), GET_PARAM(1));
}

// Timing comparison; DISABLED_ by default (run with
// --gtest_also_run_disabled_tests).
TEST_P(AV1HighbdFwdTxfm2dTest, DISABLED_Speed) {
  AV1HighbdFwdTxfm2dSpeedTest(GET_PARAM(0), GET_PARAM(1));
}
    665 
    666 using ::testing::Combine;
    667 using ::testing::Values;
    668 using ::testing::ValuesIn;
    669 
#if HAVE_SSE4_1
// Sizes checked for the SSE4_1 high-bit-depth instantiation. The six
// 1:4 / 4:1 sizes are only included when CONFIG_REALTIME_ONLY is off.
// NOTE(review): the function under test is av1_highbd_fwd_txfm (not an
// _sse4_1-suffixed symbol). It is presumably the run-time-dispatched entry
// point, so the variant exercised depends on the host CPU — confirm.
static constexpr TX_SIZE Highbd_fwd_txfm_for_sse4_1[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32, TX_64X64, TX_4X8,   TX_8X4,
  TX_8X16, TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32,
#if !CONFIG_REALTIME_ONLY
  TX_4X16, TX_16X4, TX_8X32,  TX_32X8,  TX_16X64, TX_64X16,
#endif  // !CONFIG_REALTIME_ONLY
};

INSTANTIATE_TEST_SUITE_P(SSE4_1, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_sse4_1),
                                 Values(av1_highbd_fwd_txfm)));
#endif  // HAVE_SSE4_1
#if HAVE_AVX2
// Sizes checked for the AVX2 high-bit-depth instantiation (a subset of the
// SSE4_1 list). It also passes av1_highbd_fwd_txfm as the function under test.
static constexpr TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_8X8,   TX_16X16,
                                                        TX_32X32, TX_64X64,
                                                        TX_8X16,  TX_16X8 };

INSTANTIATE_TEST_SUITE_P(AVX2, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_avx2),
                                 Values(av1_highbd_fwd_txfm)));
#endif  // HAVE_AVX2
    692 
#if HAVE_NEON
// Sizes checked for the NEON high-bit-depth instantiation. The six
// 1:4 / 4:1 sizes are only included when CONFIG_REALTIME_ONLY is off.
// av1_highbd_fwd_txfm is passed as the function under test.
static constexpr TX_SIZE Highbd_fwd_txfm_for_neon[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32, TX_64X64, TX_4X8,   TX_8X4,
  TX_8X16, TX_16X8, TX_16X32, TX_32X16, TX_32X64, TX_64X32,
#if !CONFIG_REALTIME_ONLY
  TX_4X16, TX_16X4, TX_8X32,  TX_32X8,  TX_16X64, TX_64X16
#endif  // !CONFIG_REALTIME_ONLY
};

INSTANTIATE_TEST_SUITE_P(NEON, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_neon),
                                 Values(av1_highbd_fwd_txfm)));
#endif  // HAVE_NEON
    706 
    707 }  // namespace