tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

av1_txfm_test.cc (11649B)


      1 /*
      2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include "test/av1_txfm_test.h"
     13 
     14 #include <stdio.h>
     15 
     16 #include <memory>
     17 #include <new>
     18 
     19 namespace libaom_test {
     20 
     21 const char *tx_type_name[] = {
     22  "DCT_DCT",
     23  "ADST_DCT",
     24  "DCT_ADST",
     25  "ADST_ADST",
     26  "FLIPADST_DCT",
     27  "DCT_FLIPADST",
     28  "FLIPADST_FLIPADST",
     29  "ADST_FLIPADST",
     30  "FLIPADST_ADST",
     31  "IDTX",
     32  "V_DCT",
     33  "H_DCT",
     34  "V_ADST",
     35  "H_ADST",
     36  "V_FLIPADST",
     37  "H_FLIPADST",
     38 };
     39 
     40 int get_txfm1d_size(TX_SIZE tx_size) { return tx_size_wide[tx_size]; }
     41 
     42 void get_txfm1d_type(TX_TYPE txfm2d_type, TYPE_TXFM *type0, TYPE_TXFM *type1) {
     43  switch (txfm2d_type) {
     44    case DCT_DCT:
     45      *type0 = TYPE_DCT;
     46      *type1 = TYPE_DCT;
     47      break;
     48    case ADST_DCT:
     49      *type0 = TYPE_ADST;
     50      *type1 = TYPE_DCT;
     51      break;
     52    case DCT_ADST:
     53      *type0 = TYPE_DCT;
     54      *type1 = TYPE_ADST;
     55      break;
     56    case ADST_ADST:
     57      *type0 = TYPE_ADST;
     58      *type1 = TYPE_ADST;
     59      break;
     60    case FLIPADST_DCT:
     61      *type0 = TYPE_ADST;
     62      *type1 = TYPE_DCT;
     63      break;
     64    case DCT_FLIPADST:
     65      *type0 = TYPE_DCT;
     66      *type1 = TYPE_ADST;
     67      break;
     68    case FLIPADST_FLIPADST:
     69      *type0 = TYPE_ADST;
     70      *type1 = TYPE_ADST;
     71      break;
     72    case ADST_FLIPADST:
     73      *type0 = TYPE_ADST;
     74      *type1 = TYPE_ADST;
     75      break;
     76    case FLIPADST_ADST:
     77      *type0 = TYPE_ADST;
     78      *type1 = TYPE_ADST;
     79      break;
     80    case IDTX:
     81      *type0 = TYPE_IDTX;
     82      *type1 = TYPE_IDTX;
     83      break;
     84    case H_DCT:
     85      *type0 = TYPE_IDTX;
     86      *type1 = TYPE_DCT;
     87      break;
     88    case V_DCT:
     89      *type0 = TYPE_DCT;
     90      *type1 = TYPE_IDTX;
     91      break;
     92    case H_ADST:
     93      *type0 = TYPE_IDTX;
     94      *type1 = TYPE_ADST;
     95      break;
     96    case V_ADST:
     97      *type0 = TYPE_ADST;
     98      *type1 = TYPE_IDTX;
     99      break;
    100    case H_FLIPADST:
    101      *type0 = TYPE_IDTX;
    102      *type1 = TYPE_ADST;
    103      break;
    104    case V_FLIPADST:
    105      *type0 = TYPE_ADST;
    106      *type1 = TYPE_IDTX;
    107      break;
    108    default:
    109      *type0 = TYPE_DCT;
    110      *type1 = TYPE_DCT;
    111      assert(0);
    112      break;
    113  }
    114 }
    115 
    116 double Sqrt2 = pow(2, 0.5);
    117 double invSqrt2 = 1 / pow(2, 0.5);
    118 
    119 static double dct_matrix(double n, double k, int size) {
    120  return cos(PI * (2 * n + 1) * k / (2 * size));
    121 }
    122 
    123 void reference_dct_1d(const double *in, double *out, int size) {
    124  for (int k = 0; k < size; ++k) {
    125    out[k] = 0;
    126    for (int n = 0; n < size; ++n) {
    127      out[k] += in[n] * dct_matrix(n, k, size);
    128    }
    129    if (k == 0) out[k] = out[k] * invSqrt2;
    130  }
    131 }
    132 
    133 void reference_idct_1d(const double *in, double *out, int size) {
    134  for (int k = 0; k < size; ++k) {
    135    out[k] = 0;
    136    for (int n = 0; n < size; ++n) {
    137      if (n == 0)
    138        out[k] += invSqrt2 * in[n] * dct_matrix(k, n, size);
    139      else
    140        out[k] += in[n] * dct_matrix(k, n, size);
    141    }
    142  }
    143 }
    144 
// TODO(any): Copied from the old 'fadst4' (same as the new 'av1_fadst4'
// function). Should be replaced by a proper reference function that takes
// 'double' input & output.
//
// Fixed-point 4-point forward ADST, used by reference_adst_1d() for size 4.
// Reads input[0..3] and writes output[0..3]; an all-zero input short-circuits
// to an all-zero output. The sinpi_*_9 constants and fdct_round_shift() are
// defined elsewhere. NOTE(review): presumably sinpi_k_9 are fixed-point
// sin(k * pi / 9) values and fdct_round_shift() is a rounding right shift by
// their precision — confirm against their definitions.
static void fadst4_new(const tran_low_t *input, tran_low_t *output) {
  tran_high_t x0, x1, x2, x3;
  tran_high_t s0, s1, s2, s3, s4, s5, s6, s7;

  x0 = input[0];
  x1 = input[1];
  x2 = input[2];
  x3 = input[3];

  if (!(x0 | x1 | x2 | x3)) {
    output[0] = output[1] = output[2] = output[3] = 0;
    return;
  }

  // Multiply each input by the sine constants it contributes to; s7 collects
  // the inputs that are later scaled by sinpi_3_9 together.
  s0 = sinpi_1_9 * x0;
  s1 = sinpi_4_9 * x0;
  s2 = sinpi_2_9 * x1;
  s3 = sinpi_1_9 * x1;
  s4 = sinpi_3_9 * x2;
  s5 = sinpi_4_9 * x3;
  s6 = sinpi_2_9 * x3;
  s7 = x0 + x1 - x3;

  // Combine the partial products into four intermediate terms.
  x0 = s0 + s2 + s5;
  x1 = sinpi_3_9 * s7;
  x2 = s1 - s3 + s6;
  x3 = s4;

  // Final sums; x3 (sinpi_3_9 times the original input[2]) feeds three of the
  // four outputs.
  s0 = x0 + x3;
  s1 = x1;
  s2 = x2 - x3;
  s3 = x2 - x0 + x3;

  // 1-D transform scaling factor is sqrt(2).
  output[0] = (tran_low_t)fdct_round_shift(s0);
  output[1] = (tran_low_t)fdct_round_shift(s1);
  output[2] = (tran_low_t)fdct_round_shift(s2);
  output[3] = (tran_low_t)fdct_round_shift(s3);
}
    187 
    188 void reference_adst_1d(const double *in, double *out, int size) {
    189  if (size == 4) {  // Special case.
    190    tran_low_t int_input[4];
    191    for (int i = 0; i < 4; ++i) {
    192      int_input[i] = static_cast<tran_low_t>(round(in[i]));
    193    }
    194    tran_low_t int_output[4];
    195    fadst4_new(int_input, int_output);
    196    for (int i = 0; i < 4; ++i) {
    197      out[i] = int_output[i];
    198    }
    199    return;
    200  }
    201 
    202  for (int k = 0; k < size; ++k) {
    203    out[k] = 0;
    204    for (int n = 0; n < size; ++n) {
    205      out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
    206    }
    207  }
    208 }
    209 
    210 static void reference_idtx_1d(const double *in, double *out, int size) {
    211  double scale = 0;
    212  if (size == 4)
    213    scale = Sqrt2;
    214  else if (size == 8)
    215    scale = 2;
    216  else if (size == 16)
    217    scale = 2 * Sqrt2;
    218  else if (size == 32)
    219    scale = 4;
    220  else if (size == 64)
    221    scale = 4 * Sqrt2;
    222  for (int k = 0; k < size; ++k) {
    223    out[k] = in[k] * scale;
    224  }
    225 }
    226 
    227 void reference_hybrid_1d(double *in, double *out, int size, int type) {
    228  if (type == TYPE_DCT)
    229    reference_dct_1d(in, out, size);
    230  else if (type == TYPE_ADST)
    231    reference_adst_1d(in, out, size);
    232  else
    233    reference_idtx_1d(in, out, size);
    234 }
    235 
    236 double get_amplification_factor(TX_TYPE tx_type, TX_SIZE tx_size) {
    237  TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg;
    238  av1_get_fwd_txfm_cfg(tx_type, tx_size, &fwd_txfm_flip_cfg);
    239  const int tx_width = tx_size_wide[fwd_txfm_flip_cfg.tx_size];
    240  const int tx_height = tx_size_high[fwd_txfm_flip_cfg.tx_size];
    241  const int8_t *shift = fwd_txfm_flip_cfg.shift;
    242  const int amplify_bit = shift[0] + shift[1] + shift[2];
    243  double amplify_factor =
    244      amplify_bit >= 0 ? (1 << amplify_bit) : (1.0 / (1 << -amplify_bit));
    245 
    246  // For rectangular transforms, we need to multiply by an extra factor.
    247  const int rect_type = get_rect_tx_log_ratio(tx_width, tx_height);
    248  if (abs(rect_type) == 1) {
    249    amplify_factor *= pow(2, 0.5);
    250  }
    251  return amplify_factor;
    252 }
    253 
// Reference 2D forward transform of a tx_size block of type tx_type.
//
// `in` is read row-major: element (r, c) is in[r * tx_width + c]. Each column
// is transformed with the 1D kernel type0 and then each row with type1, both
// as returned by get_txfm1d_type(). The result is written to `out`
// transposed: coefficient (r, c) is stored at out[c * tx_height + r].
// For sizes with a 64-point dimension, part of the output is zeroed and, for
// some sizes, repacked (see below). Finally, every one of the
// tx_width * tx_height entries of `out` is multiplied by
// get_amplification_factor(tx_type, tx_size).
//
// Fails the current gtest test (ASSERT_NE) if a scratch allocation fails.
void reference_hybrid_2d(double *in, double *out, TX_TYPE tx_type,
                         TX_SIZE tx_size) {
  // Get transform type and size of each dimension.
  TYPE_TXFM type0;
  TYPE_TXFM type1;
  get_txfm1d_type(tx_type, &type0, &type1);
  const int tx_width = tx_size_wide[tx_size];
  const int tx_height = tx_size_high[tx_size];

  // Scratch: one column/row in and out at a time, plus the full
  // column-transformed block in the same row-major layout as `in`.
  std::unique_ptr<double[]> temp_in(
      new (std::nothrow) double[AOMMAX(tx_width, tx_height)]);
  std::unique_ptr<double[]> temp_out(
      new (std::nothrow) double[AOMMAX(tx_width, tx_height)]);
  std::unique_ptr<double[]> out_interm(
      new (std::nothrow) double[tx_width * tx_height]);
  ASSERT_NE(temp_in, nullptr);
  ASSERT_NE(temp_out, nullptr);
  ASSERT_NE(out_interm, nullptr);

  // Transform columns.
  for (int c = 0; c < tx_width; ++c) {
    for (int r = 0; r < tx_height; ++r) {
      temp_in[r] = in[r * tx_width + c];
    }
    reference_hybrid_1d(temp_in.get(), temp_out.get(), tx_height, type0);
    for (int r = 0; r < tx_height; ++r) {
      out_interm[r * tx_width + c] = temp_out[r];
    }
  }

  // Transform rows, storing the result transposed into `out`.
  for (int r = 0; r < tx_height; ++r) {
    reference_hybrid_1d(out_interm.get() + r * tx_width, temp_out.get(),
                        tx_width, type1);
    for (int c = 0; c < tx_width; ++c) {
      out[c * tx_height + r] = temp_out[c];
    }
  }

  // These transforms use an approximate 2D DCT transform, by only keeping the
  // top-left quarter of the coefficients, and repacking them in the first
  // quarter indices.
  // The top/bottom/left/right wording below describes `out` in its storage
  // order: tx_width consecutive runs of tx_height entries.
  // NOTE(review): repacking does not clear the entries after the packed
  // prefix (e.g. for 64x64, indices 32*32 .. 32*64-1 keep stale values);
  // presumably callers only compare the packed prefix — confirm.
  // TODO(urvang): Refactor this code.
  if (tx_width == 64 && tx_height == 64) {  // tx_size == TX_64X64
    // Zero out top-right 32x32 area.
    for (int col = 0; col < 32; ++col) {
      memset(out + col * 64 + 32, 0, 32 * sizeof(*out));
    }
    // Zero out the bottom 64x32 area.
    memset(out + 32 * 64, 0, 32 * 64 * sizeof(*out));
    // Re-pack non-zero coeffs in the first 32x32 indices.
    // Source and destination never overlap: col * 32 + 31 < col * 64 for
    // col >= 1.
    for (int col = 1; col < 32; ++col) {
      memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out));
    }
  } else if (tx_width == 32 && tx_height == 64) {  // tx_size == TX_32X64
    // Zero out right 32x32 area.
    for (int col = 0; col < 32; ++col) {
      memset(out + col * 64 + 32, 0, 32 * sizeof(*out));
    }
    // Re-pack non-zero coeffs in the first 32x32 indices.
    for (int col = 1; col < 32; ++col) {
      memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out));
    }
  } else if (tx_width == 64 && tx_height == 32) {  // tx_size == TX_64X32
    // Zero out the bottom 32x32 area.
    memset(out + 32 * 32, 0, 32 * 32 * sizeof(*out));
    // Note: no repacking needed here.
  } else if (tx_width == 16 && tx_height == 64) {  // tx_size == TX_16X64
    // Zero out right 32x16 area.
    for (int col = 0; col < 16; ++col) {
      memset(out + col * 64 + 32, 0, 32 * sizeof(*out));
    }
    // Re-pack non-zero coeffs in the first 32x16 indices (each 64-entry run
    // is shortened to its first 32 entries).
    for (int col = 1; col < 16; ++col) {
      memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out));
    }
  } else if (tx_width == 64 && tx_height == 16) {  // tx_size == TX_64X16
    // Zero out the bottom 16x32 area.
    // Note: no repacking needed here.
    memset(out + 16 * 32, 0, 16 * 32 * sizeof(*out));
  }

  // Apply appropriate scale.
  const double amplify_factor = get_amplification_factor(tx_type, tx_size);
  for (int c = 0; c < tx_width; ++c) {
    for (int r = 0; r < tx_height; ++r) {
      out[c * tx_height + r] *= amplify_factor;
    }
  }
}
    344 
    345 template <typename Type>
    346 void fliplr(Type *dest, int width, int height, int stride) {
    347  for (int r = 0; r < height; ++r) {
    348    for (int c = 0; c < width / 2; ++c) {
    349      const Type tmp = dest[r * stride + c];
    350      dest[r * stride + c] = dest[r * stride + width - 1 - c];
    351      dest[r * stride + width - 1 - c] = tmp;
    352    }
    353  }
    354 }
    355 
    356 template <typename Type>
    357 void flipud(Type *dest, int width, int height, int stride) {
    358  for (int c = 0; c < width; ++c) {
    359    for (int r = 0; r < height / 2; ++r) {
    360      const Type tmp = dest[r * stride + c];
    361      dest[r * stride + c] = dest[(height - 1 - r) * stride + c];
    362      dest[(height - 1 - r) * stride + c] = tmp;
    363    }
    364  }
    365 }
    366 
    367 template <typename Type>
    368 void fliplrud(Type *dest, int width, int height, int stride) {
    369  for (int r = 0; r < height / 2; ++r) {
    370    for (int c = 0; c < width; ++c) {
    371      const Type tmp = dest[r * stride + c];
    372      dest[r * stride + c] = dest[(height - 1 - r) * stride + width - 1 - c];
    373      dest[(height - 1 - r) * stride + width - 1 - c] = tmp;
    374    }
    375  }
    376 }
    377 
    378 template void fliplr<double>(double *dest, int width, int height, int stride);
    379 template void flipud<double>(double *dest, int width, int height, int stride);
    380 template void fliplrud<double>(double *dest, int width, int height, int stride);
    381 
    382 int bd_arr[BD_NUM] = { 8, 10, 12 };
    383 
    384 int8_t low_range_arr[BD_NUM] = { 18, 32, 32 };
    385 int8_t high_range_arr[BD_NUM] = { 32, 32, 32 };
    386 
    387 void txfm_stage_range_check(const int8_t *stage_range, int stage_num,
    388                            int8_t cos_bit, int low_range, int high_range) {
    389  for (int i = 0; i < stage_num; ++i) {
    390    EXPECT_LE(stage_range[i], low_range);
    391    ASSERT_LE(stage_range[i] + cos_bit, high_range) << "stage = " << i;
    392  }
    393  for (int i = 0; i < stage_num - 1; ++i) {
    394    // make sure there is no overflow while doing half_btf()
    395    ASSERT_LE(stage_range[i + 1] + cos_bit, high_range) << "stage = " << i;
    396  }
    397 }
    398 }  // namespace libaom_test