tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

tx_search.c (154649B)


      1 /*
      2 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include "av1/common/cfl.h"
     13 #include "av1/common/reconintra.h"
     14 #include "av1/encoder/block.h"
     15 #include "av1/encoder/hybrid_fwd_txfm.h"
     16 #include "av1/common/idct.h"
     17 #include "av1/encoder/model_rd.h"
     18 #include "av1/encoder/random.h"
     19 #include "av1/encoder/rdopt_utils.h"
     20 #include "av1/encoder/sorting_network.h"
     21 #include "av1/encoder/tx_prune_model_weights.h"
     22 #include "av1/encoder/tx_search.h"
     23 #include "av1/encoder/txb_rdopt.h"
     24 
// Constant offset of 100 for transform-type pruning probability thresholds.
// NOTE(review): not referenced in this excerpt; the name suggests it is
// combined with tx-type pruning thresholds later in this file — confirm the
// exact arithmetic at its use sites.
#define PROB_THRESH_OFFSET_TX_TYPE 100
     26 
// Argument bundle for per-transform-block RD cost evaluation.
// NOTE(review): its users are outside this excerpt; the field descriptions
// below are inferred from names/types and should be confirmed at those sites.
struct rdcost_block_args {
  const AV1_COMP *cpi;  // Encoder instance.
  MACROBLOCK *x;        // Macroblock whose transform blocks are evaluated.
  // Above/left entropy contexts, one entry per mi unit (MAX_MIB_SIZE).
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  RD_STATS rd_stats;   // Presumably the RD stats accumulated so far.
  int64_t current_rd;  // Presumably the RD cost accumulated so far.
  int64_t best_rd;     // Presumably the RD budget to beat (early-exit bound).
  int exit_early;      // Presumably set once the search should stop.
  int incomplete_exit;  // Presumably set when stopping before all blocks ran.
  FAST_TX_SEARCH_MODE ftxs_mode;  // Fast transform-search mode in effect.
  int skip_trellis;  // Presumably nonzero to skip trellis optimization.
};
     40 
// Result record for one transform-type candidate.
// NOTE(review): users are outside this excerpt; descriptions are inferred
// from field names — confirm.
typedef struct {
  int64_t rd;           // RD cost of the candidate.
  int txb_entropy_ctx;  // Transform-block entropy context it produced.
  TX_TYPE tx_type;      // The transform type that was tried.
} TxCandidateInfo;
     46 
     47 // origin_threshold * 128 / 100
     48 static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
     49  {
     50      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
     51      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
     52  },
     53  {
     54      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
     55      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
     56  },
     57  {
     58      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
     59      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
     60  },
     61 };
     62 
     63 // lookup table for predict_skip_txfm
     64 // int max_tx_size = max_txsize_rect_lookup[bsize];
     65 // if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
     66 //   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
     67 static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
     68  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
     69  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
     70  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
     71  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
     72 };
     73 
     74 // look-up table for sqrt of number of pixels in a transform block
     75 // rounded up to the nearest integer.
     76 static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
     77                                                     12, 12, 23, 23, 32, 32, 8,
     78                                                     8,  16, 16, 23, 23 };
     79 
     80 static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
     81  const int rows = block_size_high[bsize];
     82  const int cols = block_size_wide[bsize];
     83  const int16_t *diff = x->plane[0].src_diff;
     84  const uint32_t hash =
     85      av1_get_crc32c_value(&x->txfm_search_info.mb_rd_record->crc_calculator,
     86                           (uint8_t *)diff, 2 * rows * cols);
     87  return (hash << 5) + bsize;
     88 }
     89 
     90 static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
     91                                      const int64_t ref_best_rd,
     92                                      const uint32_t hash) {
     93  int32_t match_index = -1;
     94  if (ref_best_rd != INT64_MAX) {
     95    for (int i = 0; i < mb_rd_record->num; ++i) {
     96      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
     97      // If there is a match in the mb_rd_record, fetch the RD decision and
     98      // terminate early.
     99      if (mb_rd_record->mb_rd_info[index].hash_value == hash) {
    100        match_index = index;
    101        break;
    102      }
    103    }
    104  }
    105  return match_index;
    106 }
    107 
    108 static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info,
    109                                    RD_STATS *const rd_stats,
    110                                    MACROBLOCK *const x) {
    111  MACROBLOCKD *const xd = &x->e_mbd;
    112  MB_MODE_INFO *const mbmi = xd->mi[0];
    113  mbmi->tx_size = mb_rd_info->tx_size;
    114  memcpy(x->txfm_search_info.blk_skip, mb_rd_info->blk_skip,
    115         sizeof(mb_rd_info->blk_skip[0]) * n4);
    116  av1_copy(mbmi->inter_tx_size, mb_rd_info->inter_tx_size);
    117  av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4);
    118  *rd_stats = mb_rd_info->rd_stats;
    119 }
    120 
    121 int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row,
    122                            int blk_col, const BLOCK_SIZE plane_bsize,
    123                            const BLOCK_SIZE tx_bsize,
    124                            unsigned int *block_mse_q8) {
    125  int visible_rows, visible_cols;
    126  const MACROBLOCKD *xd = &x->e_mbd;
    127  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
    128                     NULL, &visible_cols, &visible_rows);
    129  const int diff_stride = block_size_wide[plane_bsize];
    130  const int16_t *diff = x->plane[plane].src_diff;
    131 
    132  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
    133  uint64_t sse =
    134      aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
    135  if (block_mse_q8 != NULL) {
    136    if (visible_cols > 0 && visible_rows > 0)
    137      *block_mse_q8 =
    138          (unsigned int)((256 * sse) / (visible_cols * visible_rows));
    139    else
    140      *block_mse_q8 = UINT_MAX;
    141  }
    142  return sse;
    143 }
    144 
    145 // Computes the residual block's SSE and mean on all visible 4x4s in the
    146 // transform block
    147 static inline int64_t pixel_diff_stats(
    148    MACROBLOCK *x, int plane, int blk_row, int blk_col,
    149    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
    150    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
    151  int visible_rows, visible_cols;
    152  const MACROBLOCKD *xd = &x->e_mbd;
    153  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
    154                     NULL, &visible_cols, &visible_rows);
    155  const int diff_stride = block_size_wide[plane_bsize];
    156  const int16_t *diff = x->plane[plane].src_diff;
    157 
    158  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
    159  uint64_t sse = 0;
    160  int sum = 0;
    161  sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
    162  if (visible_cols > 0 && visible_rows > 0) {
    163    double norm_factor = 1.0 / (visible_cols * visible_rows);
    164    int sign_sum = sum > 0 ? 1 : -1;
    165    // Conversion to transform domain
    166    *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
    167    *per_px_mean = sign_sum * (*per_px_mean);
    168    *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
    169    *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
    170  } else {
    171    *block_mse_q8 = UINT_MAX;
    172  }
    173  return sse;
    174 }
    175 
    176 // Uses simple features on top of DCT coefficients to quickly predict
    177 // whether optimal RD decision is to skip encoding the residual.
    178 // The sse value is stored in dist.
    179 static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
    180                             int reduced_tx_set) {
    181  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
    182  const int bw = block_size_wide[bsize];
    183  const int bh = block_size_high[bsize];
    184  const MACROBLOCKD *xd = &x->e_mbd;
    185  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
    186 
    187  *dist = av1_pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
    188 
    189  const int64_t mse = *dist / bw / bh;
    190  // Normalized quantizer takes the transform upscaling factor (8 for tx size
    191  // smaller than 32) into account.
    192  const int16_t normalized_dc_q = dc_q >> 3;
    193  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
    194  // For faster early skip decision, use dist to compare against threshold so
    195  // that quality risk is less for the skip=1 decision. Otherwise, use mse
    196  // since the fwd_txfm coeff checks will take care of quality
    197  // TODO(any): Use dist to return 0 when skip_txfm_level is 1
    198  int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
    199  // Predict not to skip when error is larger than threshold.
    200  if (pred_err > mse_thresh) return 0;
    201  // Return as skip otherwise for aggressive early skip
    202  else if (txfm_params->skip_txfm_level >= 2)
    203    return 1;
    204 
    205  const int max_tx_size = max_predict_sf_tx_size[bsize];
    206  const int tx_h = tx_size_high[max_tx_size];
    207  const int tx_w = tx_size_wide[max_tx_size];
    208  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
    209  TxfmParam param;
    210  param.tx_type = DCT_DCT;
    211  param.tx_size = max_tx_size;
    212  param.bd = xd->bd;
    213  param.is_hbd = is_cur_buf_hbd(xd);
    214  param.lossless = 0;
    215  param.tx_set_type = av1_get_ext_tx_set_type(
    216      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
    217  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
    218  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
    219  const int16_t *src_diff = x->plane[0].src_diff;
    220  const int n_coeff = tx_w * tx_h;
    221  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
    222  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
    223  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
    224  for (int row = 0; row < bh; row += tx_h) {
    225    for (int col = 0; col < bw; col += tx_w) {
    226      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
    227      // Operating on TX domain, not pixels; we want the QTX quantizers
    228      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
    229      if (dc_coef >= dc_thresh) return 0;
    230      for (int i = 1; i < n_coeff; ++i) {
    231        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
    232        if (ac_coef >= ac_thresh) return 0;
    233      }
    234    }
    235    src_diff += tx_h * bw;
    236  }
    237  return 1;
    238 }
    239 
    240 // Used to set proper context for early termination with skip = 1.
    241 static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
    242                                 BLOCK_SIZE bsize, int64_t dist) {
    243  MACROBLOCKD *const xd = &x->e_mbd;
    244  MB_MODE_INFO *const mbmi = xd->mi[0];
    245  const int n4 = bsize_to_num_blk(bsize);
    246  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
    247  memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
    248  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
    249  mbmi->tx_size = tx_size;
    250  for (int i = 0; i < n4; ++i)
    251    set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
    252  rd_stats->skip_txfm = 1;
    253  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
    254  rd_stats->dist = rd_stats->sse = (dist << 4);
    255  // Though decision is to make the block as skip based on luma stats,
    256  // it is possible that block becomes non skip after chroma rd. In addition
    257  // intermediate non skip costs calculated by caller function will be
    258  // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
    259  // accounted). Hence intermediate rate is populated to code the luma tx blks
    260  // as skip, the caller function based on final rd decision (i.e., skip vs
    261  // non-skip) sets the final rate accordingly. Here the rate populated
    262  // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
    263  // size possible) in the current block. Eg: For 128*128 block, rate would be
    264  // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
    265  // block as 'all zeros'
    266  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
    267  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
    268  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
    269  ENTROPY_CONTEXT *ta = ctxa;
    270  ENTROPY_CONTEXT *tl = ctxl;
    271  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    272  TXB_CTX txb_ctx;
    273  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
    274  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
    275                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
    276  rd_stats->rate = zero_blk_rate *
    277                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
    278                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
    279 }
    280 
    281 static inline void save_mb_rd_info(int n4, uint32_t hash,
    282                                   const MACROBLOCK *const x,
    283                                   const RD_STATS *const rd_stats,
    284                                   MB_RD_RECORD *mb_rd_record) {
    285  int index;
    286  if (mb_rd_record->num < RD_RECORD_BUFFER_LEN) {
    287    index =
    288        (mb_rd_record->index_start + mb_rd_record->num) % RD_RECORD_BUFFER_LEN;
    289    ++mb_rd_record->num;
    290  } else {
    291    index = mb_rd_record->index_start;
    292    mb_rd_record->index_start =
    293        (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
    294  }
    295  MB_RD_INFO *const mb_rd_info = &mb_rd_record->mb_rd_info[index];
    296  const MACROBLOCKD *const xd = &x->e_mbd;
    297  const MB_MODE_INFO *const mbmi = xd->mi[0];
    298  mb_rd_info->hash_value = hash;
    299  mb_rd_info->tx_size = mbmi->tx_size;
    300  memcpy(mb_rd_info->blk_skip, x->txfm_search_info.blk_skip,
    301         sizeof(mb_rd_info->blk_skip[0]) * n4);
    302  av1_copy(mb_rd_info->inter_tx_size, mbmi->inter_tx_size);
    303  av1_copy_array(mb_rd_info->tx_type_map, xd->tx_type_map, n4);
    304  mb_rd_info->rd_stats = *rd_stats;
    305 }
    306 
    307 static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
    308                                 const SPEED_FEATURES *sf,
    309                                 int tx_size_search_method) {
    310  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
    311 
    312  if (sf->tx_sf.tx_size_search_lgr_block) {
    313    if (mi_width > mi_size_wide[BLOCK_64X64] ||
    314        mi_height > mi_size_high[BLOCK_64X64])
    315      return MAX_VARTX_DEPTH;
    316  }
    317 
    318  if (is_inter) {
    319    return (mi_height != mi_width)
    320               ? sf->tx_sf.inter_tx_size_search_init_depth_rect
    321               : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
    322  } else {
    323    return (mi_height != mi_width)
    324               ? sf->tx_sf.intra_tx_size_search_init_depth_rect
    325               : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
    326  }
    327 }
    328 
// Forward declaration: select_tx_block() is defined later in this
// translation unit.
// NOTE(review): presumably declared early because code preceding the
// definition (or the definition itself, recursively) refers to it — confirm.
static inline void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode);
    335 
    336 // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
    337 // 0: Do not collect any RD stats
    338 // 1: Collect RD stats for transform units
    339 // 2: Collect RD stats for partition units
    340 #if CONFIG_COLLECT_RD_STATS
    341 
    342 static inline void get_energy_distribution_fine(
    343    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    344    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    345    double *verdist) {
    346  const int bw = block_size_wide[bsize];
    347  const int bh = block_size_high[bsize];
    348  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
    349 
    350  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    351    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    352    // functions for the 16 (very small) sub-blocks of this block.
    353    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    354    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    355    assert(bw <= 32);
    356    assert(bh <= 32);
    357    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    358    if (cpi->common.seq_params->use_highbitdepth) {
    359      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
    360      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
    361      for (int i = 0; i < bh; ++i)
    362        for (int j = 0; j < bw; ++j) {
    363          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
    364          esq[index] +=
    365              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
    366              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
    367        }
    368    } else {
    369      for (int i = 0; i < bh; ++i)
    370        for (int j = 0; j < bw; ++j) {
    371          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
    372          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
    373                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
    374        }
    375    }
    376  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    377    const int f_index =
    378        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    379    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    380    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    381    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    382    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    383    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    384    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
    385                                 dst_stride, &esq[1]);
    386    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
    387                                 dst_stride, &esq[2]);
    388    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
    389                                 dst_stride, &esq[3]);
    390    src += bh / 4 * src_stride;
    391    dst += bh / 4 * dst_stride;
    392 
    393    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    394    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
    395                                 dst_stride, &esq[5]);
    396    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
    397                                 dst_stride, &esq[6]);
    398    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
    399                                 dst_stride, &esq[7]);
    400    src += bh / 4 * src_stride;
    401    dst += bh / 4 * dst_stride;
    402 
    403    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    404    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
    405                                 dst_stride, &esq[9]);
    406    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
    407                                 dst_stride, &esq[10]);
    408    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
    409                                 dst_stride, &esq[11]);
    410    src += bh / 4 * src_stride;
    411    dst += bh / 4 * dst_stride;
    412 
    413    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    414    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
    415                                 dst_stride, &esq[13]);
    416    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
    417                                 dst_stride, &esq[14]);
    418    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
    419                                 dst_stride, &esq[15]);
    420  }
    421 
    422  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
    423                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
    424                 esq[12] + esq[13] + esq[14] + esq[15];
    425  if (total > 0) {
    426    const double e_recip = 1.0 / total;
    427    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    428    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    429    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    430    if (need_4th) {
    431      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    432    }
    433    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    434    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    435    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    436    if (need_4th) {
    437      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    438    }
    439  } else {
    440    hordist[0] = verdist[0] = 0.25;
    441    hordist[1] = verdist[1] = 0.25;
    442    hordist[2] = verdist[2] = 0.25;
    443    if (need_4th) {
    444      hordist[3] = verdist[3] = 0.25;
    445    }
    446  }
    447 }
    448 
    449 static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
    450  double sum = 0.0;
    451  for (int j = 0; j < h; ++j) {
    452    for (int i = 0; i < w; ++i) {
    453      const int err = diff[j * stride + i];
    454      sum += err * err;
    455    }
    456  }
    457  assert(w > 0 && h > 0);
    458  return sum / (w * h);
    459 }
    460 
    461 static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
    462  double sum = 0.0;
    463  for (int j = 0; j < h; ++j) {
    464    for (int i = 0; i < w; ++i) {
    465      sum += abs(diff[j * stride + i]);
    466    }
    467  }
    468  assert(w > 0 && h > 0);
    469  return sum / (w * h);
    470 }
    471 
    472 static inline void get_2x2_normalized_sses_and_sads(
    473    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    474    int src_stride, const uint8_t *const dst, int dst_stride,
    475    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    476    double *const sad_norm_arr) {
    477  const BLOCK_SIZE tx_bsize_half =
    478      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
    479  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
    480    const int half_width = block_size_wide[tx_bsize] / 2;
    481    const int half_height = block_size_high[tx_bsize] / 2;
    482    for (int row = 0; row < 2; ++row) {
    483      for (int col = 0; col < 2; ++col) {
    484        const int16_t *const this_src_diff =
    485            src_diff + row * half_height * diff_stride + col * half_width;
    486        if (sse_norm_arr) {
    487          sse_norm_arr[row * 2 + col] =
    488              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
    489        }
    490        if (sad_norm_arr) {
    491          sad_norm_arr[row * 2 + col] =
    492              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
    493        }
    494      }
    495    }
    496  } else {  // use function pointers to calculate stats
    497    const int half_width = block_size_wide[tx_bsize_half];
    498    const int half_height = block_size_high[tx_bsize_half];
    499    const int num_samples_half = half_width * half_height;
    500    for (int row = 0; row < 2; ++row) {
    501      for (int col = 0; col < 2; ++col) {
    502        const uint8_t *const this_src =
    503            src + row * half_height * src_stride + col * half_width;
    504        const uint8_t *const this_dst =
    505            dst + row * half_height * dst_stride + col * half_width;
    506 
    507        if (sse_norm_arr) {
    508          unsigned int this_sse;
    509          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
    510                                             dst_stride, &this_sse);
    511          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
    512        }
    513 
    514        if (sad_norm_arr) {
    515          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
    516              this_src, src_stride, this_dst, dst_stride);
    517          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
    518        }
    519      }
    520    }
    521  }
    522 }
    523 
    524 #if CONFIG_COLLECT_RD_STATS == 1
    525 static double get_mean(const int16_t *diff, int stride, int w, int h) {
    526  double sum = 0.0;
    527  for (int j = 0; j < h; ++j) {
    528    for (int i = 0; i < w; ++i) {
    529      sum += diff[j * stride + i];
    530    }
    531  }
    532  assert(w > 0 && h > 0);
    533  return sum / (w * h);
    534 }
// Debug statistics dump (compiled only when CONFIG_COLLECT_RD_STATS == 1).
// For a random sample of calls, appends one space-separated line of luma
// (plane 0) transform-unit features to "tu_stats.txt". Columns in output
// order (N = tx width * tx height):
//   rate/N, dist/N
//   SSE/N, SAD/N of src vs dst over the whole transform block
//   4 quadrant SSE/N values, then 4 quadrant SAD/N values
//   q_step, tx width, tx height, row 1-D tx type, column 1-D tx type
//   model rate/N, model dist/N (MODELRD_CURVFIT model)
//   residual mean, horizontal correlation, vertical correlation
//   hdist[0..3], vdist[0..3] (energy distribution over a 4x4 grid)
//   x->rdmult, rd
static inline void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  // Ignore results whose rate or distortion is at the sentinel maximum.
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Generate small sample to restrict output size: only calls where
  // lcg_rand16() % 256 == 0 are written.
  // NOTE(review): `seed` is function-static, unsynchronized state — confirm
  // this path only runs single-threaded.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 > 0) return;

  const char output_file[] = "tu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  // Quantizer step from dequant_QTX[1] (presumably the AC dequantizer),
  // shifted by 3 for 8-bit buffers or by bd - 5 for high bitdepth.
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int num_samples = txw * txh;

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;

  fprintf(fout, "%g %g", rate_norm, dist_norm);

  // Source and reconstruction of this transform block (4x4-unit offsets).
  const int src_stride = p->src.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm = (double)sad / num_samples;

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];

  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);

  const double mean = get_mean(src_diff, diff_stride, txw, txh);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  // need_4th = 1 so all four entries of hdist/vdist are filled.
  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
    627 #endif  // CONFIG_COLLECT_RD_STATS == 1
    628 
    629 #if CONFIG_COLLECT_RD_STATS >= 2
    630 static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
    631  const AV1_COMMON *cm = &cpi->common;
    632  const int num_planes = av1_num_planes(cm);
    633  const MACROBLOCKD *xd = &x->e_mbd;
    634  const MB_MODE_INFO *mbmi = xd->mi[0];
    635  int64_t total_sse = 0;
    636  for (int plane = 0; plane < num_planes; ++plane) {
    637    const struct macroblock_plane *const p = &x->plane[plane];
    638    const struct macroblockd_plane *const pd = &xd->plane[plane];
    639    const BLOCK_SIZE bs =
    640        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
    641    unsigned int sse;
    642 
    643    if (plane) continue;
    644 
    645    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
    646                            pd->dst.stride, &sse);
    647    total_sse += sse;
    648  }
    649  total_sse <<= 4;
    650  return total_sse;
    651 }
    652 
    653 static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
    654                             int64_t sse, int *est_residue_cost,
    655                             int64_t *est_dist) {
    656  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
    657  if (md->ready) {
    658    if (sse < md->dist_mean) {
    659      *est_residue_cost = 0;
    660      *est_dist = sse;
    661    } else {
    662      *est_dist = (int64_t)round(md->dist_mean);
    663      const double est_ld = md->a * sse + md->b;
    664      // Clamp estimated rate cost by INT_MAX / 2.
    665      // TODO(angiebird@google.com): find better solution than clamping.
    666      if (fabs(est_ld) < 1e-2) {
    667        *est_residue_cost = INT_MAX / 2;
    668      } else {
    669        double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
    670        if (est_residue_cost_dbl < 0) {
    671          *est_residue_cost = 0;
    672        } else {
    673          *est_residue_cost =
    674              (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
    675        }
    676      }
    677      if (*est_residue_cost <= 0) {
    678        *est_residue_cost = 0;
    679        *est_dist = sse;
    680      }
    681    }
    682    return 1;
    683  }
    684  return 0;
    685 }
    686 
    687 static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
    688                                   const uint8_t *dst8, int dst_stride, int w,
    689                                   int h) {
    690  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    691  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
    692  double sum = 0.0;
    693  for (int j = 0; j < h; ++j) {
    694    for (int i = 0; i < w; ++i) {
    695      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
    696      sum += diff;
    697    }
    698  }
    699  assert(w > 0 && h > 0);
    700  return sum / (w * h);
    701 }
    702 
    703 static double get_diff_mean(const uint8_t *src, int src_stride,
    704                            const uint8_t *dst, int dst_stride, int w, int h) {
    705  double sum = 0.0;
    706  for (int j = 0; j < h; ++j) {
    707    for (int i = 0; i < w; ++i) {
    708      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
    709      sum += diff;
    710    }
    711  }
    712  assert(w > 0 && h > 0);
    713  return sum / (w * h);
    714 }
    715 
// Debug helper (compiled only when CONFIG_COLLECT_RD_STATS >= 2). Appends one
// whitespace-separated line of plane-0 (luma) statistics for the current
// prediction unit to "pu_stats.txt", opened in append mode in the current
// working directory.
//
// Fields written, in order:
//   rate, dist and RD cost, each divided by num_samples (visible width *
//   visible height from get_txb_dimensions());
//   SSE per sample and SAD divided by the full block pixel count;
//   four values from each of the two output arrays of
//   get_2x2_normalized_sses_and_sads() (SSEs, then SADs);
//   q_step, rdmult, visible width and visible height;
//   MODELRD_CURVFIT model rate, dist and RD cost per sample;
//   mean residual, horizontal and vertical correlation of src_diff;
//   four horizontal then four vertical get_energy_distribution_fine() values;
//   and, only when inter_mode_rd_model_estimation == 1, the per-sample
//   residue cost, dist and RD cost estimated by get_est_rate_dist().
//
// Nothing is written when rate/dist are INT_MAX/INT64_MAX, when RD-model
// estimation is enabled but tile_data is NULL or its model for plane_bsize is
// not ready, when the random sampling below rejects the block, or when the
// output file cannot be opened.
static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                            const TileDataEnc *tile_data,
                                            MACROBLOCK *x,
                                            const RD_STATS *const rd_stats,
                                            BLOCK_SIZE plane_bsize) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
      (tile_data == NULL ||
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
    return;
  // NOTE(review): tile_data is used again further down, so this cast is
  // redundant.
  (void)tile_data;
  // Generate small sample to restrict output size.
  // The block is kept only when a 16-bit LCG draw modulo
  // 2^(14 - num_pels_log2_lookup[plane_bsize]) equals 1, so larger blocks are
  // sampled more often. NOTE(review): when num_pels_log2_lookup[plane_bsize]
  // is 14 the modulus is 1, the remainder is always 0, and the block is never
  // sampled. The seed is a function-local static shared by every call.
  static unsigned int seed = 95014;

  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
      1)
    return;

  const char output_file[] = "pu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int diff_stride = block_size_wide[plane_bsize];
  // bw / bh receive the visible width and height of the block.
  int bw, bh;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
                     &bh);
  const int num_samples = bw * bh;
  // Quantizer step taken from dequant_QTX[1] (presumably the AC dequantizer),
  // shifted down by a bit-depth-dependent amount.
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  // Bits above 8-bit depth; used below to divide the 2x2 SSEs by
  // 2^(2 * shift), and the 2x2 SADs and the mean residual by 2^shift.
  const int shift = (xd->bd - 8);

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;
  const double rdcost_norm =
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;

  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src = p->src.buf;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int16_t *const src_diff = p->src_diff;

  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
  const double sse_norm = (double)sse / num_samples;

  // Note: SAD is normalized by the full block pixel count, not num_samples.
  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm =
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  if (shift) {
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rdcost_norm =
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    // The early return at the top guarantees tile_data is non-NULL here.
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
    841 #endif  // CONFIG_COLLECT_RD_STATS >= 2
    842 #endif  // CONFIG_COLLECT_RD_STATS
    843 
    844 static inline void inverse_transform_block_facade(MACROBLOCK *const x,
    845                                                  int plane, int block,
    846                                                  int blk_row, int blk_col,
    847                                                  int eob, int reduced_tx_set) {
    848  if (!eob) return;
    849  struct macroblock_plane *const p = &x->plane[plane];
    850  MACROBLOCKD *const xd = &x->e_mbd;
    851  tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
    852  const PLANE_TYPE plane_type = get_plane_type(plane);
    853  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
    854  const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
    855                                          tx_size, reduced_tx_set);
    856 
    857  struct macroblockd_plane *const pd = &xd->plane[plane];
    858  const int dst_stride = pd->dst.stride;
    859  uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
    860  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
    861                              dst_stride, eob, reduced_tx_set);
    862 }
    863 
    864 static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
    865                               int block, int blk_row, int blk_col,
    866                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    867                               const TXB_CTX *const txb_ctx, int skip_trellis,
    868                               TX_TYPE best_tx_type, int do_quant,
    869                               int *rate_cost, uint16_t best_eob) {
    870  const AV1_COMMON *cm = &cpi->common;
    871  MACROBLOCKD *xd = &x->e_mbd;
    872  MB_MODE_INFO *mbmi = xd->mi[0];
    873  const int is_inter = is_inter_block(mbmi);
    874  if (!is_inter && best_eob &&
    875      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
    876       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    877    // if the quantized coefficients are stored in the dqcoeff buffer, we don't
    878    // need to do transform and quantization again.
    879    if (do_quant) {
    880      TxfmParam txfm_param_intra;
    881      QUANT_PARAM quant_param_intra;
    882      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
    883      av1_setup_quant(tx_size, !skip_trellis,
    884                      skip_trellis
    885                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
    886                                                    : AV1_XFORM_QUANT_FP)
    887                          : AV1_XFORM_QUANT_FP,
    888                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
    889      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
    890                        &quant_param_intra);
    891      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
    892                      &txfm_param_intra, &quant_param_intra);
    893      if (quant_param_intra.use_optimize_b) {
    894        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
    895                       rate_cost);
    896      }
    897    }
    898 
    899    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
    900                                   x->plane[plane].eobs[block],
    901                                   cm->features.reduced_tx_set_used);
    902 
    903    // This may happen because of hash collision. The eob stored in the hash
    904    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    905    // is DCT_DCT in this case.
    906    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
    907        best_tx_type != DCT_DCT) {
    908      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    909    }
    910  }
    911 }
    912 
    913 static unsigned pixel_dist_visible_only(
    914    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    915    const int src_stride, const uint8_t *dst, const int dst_stride,
    916    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    917    int visible_cols) {
    918  unsigned sse;
    919 
    920  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    921    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    922    return sse;
    923  }
    924 
    925 #if CONFIG_AV1_HIGHBITDEPTH
    926  const MACROBLOCKD *xd = &x->e_mbd;
    927  if (is_cur_buf_hbd(xd)) {
    928    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
    929                                             visible_cols, visible_rows);
    930    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
    931  }
    932 #else
    933  (void)x;
    934 #endif
    935  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
    936                         visible_rows);
    937  return sse;
    938 }
    939 
    940 // Compute the pixel domain distortion from src and dst on all visible 4x4s in
    941 // the
    942 // transform block.
    943 static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
    944                           int plane, const uint8_t *src, const int src_stride,
    945                           const uint8_t *dst, const int dst_stride,
    946                           int blk_row, int blk_col,
    947                           const BLOCK_SIZE plane_bsize,
    948                           const BLOCK_SIZE tx_bsize) {
    949  int txb_rows, txb_cols, visible_rows, visible_cols;
    950  const MACROBLOCKD *xd = &x->e_mbd;
    951 
    952  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
    953                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
    954  assert(visible_rows > 0);
    955  assert(visible_cols > 0);
    956 
    957  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
    958                                         dst_stride, tx_bsize, txb_rows,
    959                                         txb_cols, visible_rows, visible_cols);
    960 
    961  return sse;
    962 }
    963 
    964 static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
    965                                           int plane, BLOCK_SIZE plane_bsize,
    966                                           int block, int blk_row, int blk_col,
    967                                           TX_SIZE tx_size) {
    968  MACROBLOCKD *const xd = &x->e_mbd;
    969  const struct macroblock_plane *const p = &x->plane[plane];
    970  const uint16_t eob = p->eobs[block];
    971  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
    972  const int bsw = block_size_wide[tx_bsize];
    973  const int bsh = block_size_high[tx_bsize];
    974  const int src_stride = x->plane[plane].src.stride;
    975  const int dst_stride = xd->plane[plane].dst.stride;
    976  // Scale the transform block index to pixel unit.
    977  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
    978  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
    979  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
    980  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
    981  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
    982 
    983  assert(cpi != NULL);
    984  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
    985 
    986  uint8_t *recon;
    987  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
    988 
    989 #if CONFIG_AV1_HIGHBITDEPTH
    990  if (is_cur_buf_hbd(xd)) {
    991    recon = CONVERT_TO_BYTEPTR(recon16);
    992    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
    993                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
    994  } else {
    995    recon = (uint8_t *)recon16;
    996    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
    997  }
    998 #else
    999  recon = (uint8_t *)recon16;
   1000  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
   1001 #endif
   1002 
   1003  const PLANE_TYPE plane_type = get_plane_type(plane);
   1004  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
   1005                                    cpi->common.features.reduced_tx_set_used);
   1006  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
   1007                              MAX_TX_SIZE, eob,
   1008                              cpi->common.features.reduced_tx_set_used);
   1009 
   1010  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
   1011                         blk_row, blk_col, plane_bsize, tx_bsize);
   1012 }
   1013 
// Pruning thresholds, presumably chosen by callers outside this chunk and
// passed as prune_factor to prune_txk_type / prune_txk_type_separ.
// prune_factors is on a 1/1000 scale, matching the 1000 * relative-cost
// comparison in prune_txk_type. Note that prune_txk_type_separ compares
// against an 1800-scaled relative cost instead. mul_factors is on a 1/100
// scale.
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100
   1017 
   1018 // R-D costs are sorted in ascending order.
   1019 static inline void sort_rd(int64_t rds[], int txk[], int len) {
   1020  int i, j, k;
   1021 
   1022  for (i = 1; i <= len - 1; ++i) {
   1023    for (j = 0; j < i; ++j) {
   1024      if (rds[j] > rds[i]) {
   1025        int64_t temprd;
   1026        int tempi;
   1027 
   1028        temprd = rds[i];
   1029        tempi = txk[i];
   1030 
   1031        for (k = i; k > j; k--) {
   1032          rds[k] = rds[k - 1];
   1033          txk[k] = txk[k - 1];
   1034        }
   1035 
   1036        rds[j] = temprd;
   1037        txk[j] = tempi;
   1038        break;
   1039      }
   1040    }
   1041  }
   1042 }
   1043 
   1044 static inline int64_t av1_block_error_qm(
   1045    const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size,
   1046    const qm_val_t *qmatrix, const int16_t *scan, int64_t *ssz, int bd) {
   1047  int i;
   1048  int64_t error = 0, sqcoeff = 0;
   1049  int shift = 2 * (bd - 8);
   1050  int rounding = (1 << shift) >> 1;
   1051 
   1052  for (i = 0; i < block_size; i++) {
   1053    int64_t weight = qmatrix[scan[i]];
   1054    int64_t dd = coeff[i] - dqcoeff[i];
   1055    dd *= weight;
   1056    int64_t cc = coeff[i];
   1057    cc *= weight;
   1058    // The ranges of coeff and dqcoeff are
   1059    //  bd8 : 18 bits (including sign)
   1060    //  bd10: 20 bits (including sign)
   1061    //  bd12: 22 bits (including sign)
   1062    // As AOM_QM_BITS is 5, the intermediate quantities in the calculation
   1063    // below should fit in 54 bits, thus no overflow should happen.
   1064    error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
   1065    sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
   1066  }
   1067 
   1068  error = (error + rounding) >> shift;
   1069  sqcoeff = (sqcoeff + rounding) >> shift;
   1070 
   1071  *ssz = sqcoeff;
   1072  return error;
   1073 }
   1074 
   1075 static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
   1076                                        TX_SIZE tx_size,
   1077                                        const qm_val_t *qmatrix,
   1078                                        const int16_t *scan, int64_t *out_dist,
   1079                                        int64_t *out_sse) {
   1080  const struct macroblock_plane *const p = &x->plane[plane];
   1081  // Transform domain distortion computation is more efficient as it does
   1082  // not involve an inverse transform, but it is less accurate.
   1083  const int buffer_length = av1_get_max_eob(tx_size);
   1084  int64_t this_sse;
   1085  // TX-domain results need to shift down to Q2/D10 to match pixel
   1086  // domain distortion values which are in Q2^2
   1087  int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
   1088  const int block_offset = BLOCK_OFFSET(block);
   1089  tran_low_t *const coeff = p->coeff + block_offset;
   1090  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
   1091 #if CONFIG_AV1_HIGHBITDEPTH
   1092  MACROBLOCKD *const xd = &x->e_mbd;
   1093  if (is_cur_buf_hbd(xd)) {
   1094    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
   1095      *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length,
   1096                                         &this_sse, xd->bd);
   1097    } else {
   1098      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
   1099                                     scan, &this_sse, xd->bd);
   1100    }
   1101  } else {
   1102 #endif
   1103    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
   1104      *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
   1105    } else {
   1106      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
   1107                                     scan, &this_sse, 8);
   1108    }
   1109 #if CONFIG_AV1_HIGHBITDEPTH
   1110  }
   1111 #endif
   1112 
   1113  *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
   1114  *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
   1115 }
   1116 
   1117 static uint16_t prune_txk_type_separ(
   1118    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size,
   1119    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
   1120    int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx,
   1121    int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) {
   1122  const AV1_COMMON *cm = &cpi->common;
   1123  MACROBLOCKD *xd = &x->e_mbd;
   1124 
   1125  int idx;
   1126 
   1127  int64_t rds_v[4];
   1128  int64_t rds_h[4];
   1129  int idx_v[4] = { 0, 1, 2, 3 };
   1130  int idx_h[4] = { 0, 1, 2, 3 };
   1131  int skip_v[4] = { 0 };
   1132  int skip_h[4] = { 0 };
   1133  const int idx_map[16] = {
   1134    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
   1135    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
   1136    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
   1137    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
   1138  };
   1139 
   1140  const int sel_pattern_v[16] = {
   1141    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
   1142  };
   1143  const int sel_pattern_h[16] = {
   1144    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
   1145  };
   1146 
   1147  QUANT_PARAM quant_param;
   1148  TxfmParam txfm_param;
   1149  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
   1150  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
   1151                  &quant_param);
   1152  int tx_type;
   1153  // to ensure we can try ones even outside of ext_tx_set of current block
   1154  // this function should only be called for size < 16
   1155  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
   1156  txfm_param.tx_set_type = EXT_TX_SET_ALL16;
   1157 
   1158  int rate_cost = 0;
   1159  int64_t dist = 0, sse = 0;
   1160  // evaluate horizontal with vertical DCT
   1161  for (idx = 0; idx < 4; ++idx) {
   1162    tx_type = idx_map[idx];
   1163    txfm_param.tx_type = tx_type;
   1164 
   1165    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
   1166                      &quant_param);
   1167 
   1168    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
   1169                    &quant_param);
   1170 
   1171    const SCAN_ORDER *const scan_order =
   1172        get_scan(txfm_param.tx_size, txfm_param.tx_type);
   1173    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
   1174                         scan_order->scan, &dist, &sse);
   1175 
   1176    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
   1177                                              txb_ctx, reduced_tx_set_used, 0);
   1178 
   1179    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
   1180 
   1181    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
   1182      skip_h[idx] = 1;
   1183    }
   1184  }
   1185  sort_rd(rds_h, idx_h, 4);
   1186  for (idx = 1; idx < 4; idx++) {
   1187    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
   1188  }
   1189 
   1190  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;
   1191 
   1192  // evaluate vertical with the best horizontal chosen
   1193  rds_v[0] = rds_h[0];
   1194  int start_v = 1, end_v = 4;
   1195  const int *idx_map_v = idx_map + idx_h[0];
   1196 
   1197  for (idx = start_v; idx < end_v; ++idx) {
   1198    tx_type = idx_map_v[idx_v[idx] * 4];
   1199    txfm_param.tx_type = tx_type;
   1200 
   1201    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
   1202                      &quant_param);
   1203 
   1204    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
   1205                    &quant_param);
   1206 
   1207    const SCAN_ORDER *const scan_order =
   1208        get_scan(txfm_param.tx_size, txfm_param.tx_type);
   1209    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
   1210                         scan_order->scan, &dist, &sse);
   1211 
   1212    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
   1213                                              txb_ctx, reduced_tx_set_used, 0);
   1214 
   1215    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
   1216 
   1217    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
   1218      skip_v[idx] = 1;
   1219    }
   1220  }
   1221  sort_rd(rds_v, idx_v, 4);
   1222  for (idx = 1; idx < 4; idx++) {
   1223    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
   1224  }
   1225 
   1226  // combine rd_h and rd_v to prune tx candidates
   1227  int i_v, i_h;
   1228  int64_t rds[16];
   1229  int num_cand = 0, last = TX_TYPES - 1;
   1230 
   1231  for (int i = 0; i < 16; i++) {
   1232    i_v = sel_pattern_v[i];
   1233    i_h = sel_pattern_h[i];
   1234    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
   1235    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
   1236        skip_v[idx_v[i_v]]) {
   1237      txk_map[last] = tx_type;
   1238      last--;
   1239    } else {
   1240      txk_map[num_cand] = tx_type;
   1241      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
   1242      if (rds[num_cand] == 0) rds[num_cand] = 1;
   1243      num_cand++;
   1244    }
   1245  }
   1246  sort_rd(rds, txk_map, num_cand);
   1247 
   1248  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
   1249  num_sel = AOMMIN(num_sel, num_cand);
   1250 
   1251  for (int i = 1; i < num_sel; i++) {
   1252    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
   1253    if (factor < (int64_t)prune_factor)
   1254      prune &= ~(1 << txk_map[i]);
   1255    else
   1256      break;
   1257  }
   1258  return prune;
   1259 }
   1260 
   1261 static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
   1262                               int block, TX_SIZE tx_size, int blk_row,
   1263                               int blk_col, BLOCK_SIZE plane_bsize,
   1264                               int *txk_map, uint16_t allowed_tx_mask,
   1265                               int prune_factor, const TXB_CTX *const txb_ctx,
   1266                               int reduced_tx_set_used) {
   1267  const AV1_COMMON *cm = &cpi->common;
   1268  MACROBLOCKD *xd = &x->e_mbd;
   1269  int tx_type;
   1270 
   1271  int64_t rds[TX_TYPES];
   1272 
   1273  int num_cand = 0;
   1274  int last = TX_TYPES - 1;
   1275 
   1276  TxfmParam txfm_param;
   1277  QUANT_PARAM quant_param;
   1278  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
   1279  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
   1280                  &quant_param);
   1281 
   1282  for (int idx = 0; idx < TX_TYPES; idx++) {
   1283    tx_type = idx;
   1284    int rate_cost = 0;
   1285    int64_t dist = 0, sse = 0;
   1286    if (!(allowed_tx_mask & (1 << tx_type))) {
   1287      txk_map[last] = tx_type;
   1288      last--;
   1289      continue;
   1290    }
   1291    txfm_param.tx_type = tx_type;
   1292 
   1293    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
   1294                      &quant_param);
   1295 
   1296    // do txfm and quantization
   1297    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
   1298                    &quant_param);
   1299    // estimate rate cost
   1300    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
   1301                                              txb_ctx, reduced_tx_set_used, 0);
   1302    // tx domain dist
   1303    const SCAN_ORDER *const scan_order =
   1304        get_scan(txfm_param.tx_size, txfm_param.tx_type);
   1305    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
   1306                         scan_order->scan, &dist, &sse);
   1307 
   1308    txk_map[num_cand] = tx_type;
   1309    rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
   1310    if (rds[num_cand] == 0) rds[num_cand] = 1;
   1311    num_cand++;
   1312  }
   1313 
   1314  if (num_cand == 0) return (uint16_t)0xFFFF;
   1315 
   1316  sort_rd(rds, txk_map, num_cand);
   1317  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
   1318 
   1319  // 0 < prune_factor <= 1000 controls aggressiveness
   1320  int64_t factor = 0;
   1321  for (int idx = 1; idx < num_cand; idx++) {
   1322    factor = 1000 * (rds[idx] - rds[0]) / rds[0];
   1323    if (factor < (int64_t)prune_factor)
   1324      prune &= ~(1 << txk_map[idx]);
   1325    else
   1326      break;
   1327  }
   1328  return prune;
   1329 }
   1330 
// Thresholds for the 2-D transform-type pruning model, indexed by TX_SIZE
// (entry order below). get_adaptive_thresholds() picks one value from a row
// according to the pruning aggressiveness.
// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
// Sizes without a row are NULL. get_adaptive_thresholds() does not check for
// NULL, so it must not be called with those sizes. Most rows have 14 entries,
// but TX_16X16 has only 10 (indices 0-9). get_adaptive_thresholds() indexes
// at most 12 for EXT_TX_SET_ALL16 and at most 9 for
// EXT_TX_SET_DTT9_IDTX_1DDCT. NOTE(review): this presumably relies on 16x16
// only being pruned with the DTT9_IDTX_1DDCT set — confirm with callers.
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};
   1391 
   1392 static inline float get_adaptive_thresholds(
   1393    TX_SIZE tx_size, TxSetType tx_set_type,
   1394    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
   1395  const int prune_aggr_table[5][2] = {
   1396    { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
   1397  };
   1398  int pruning_aggressiveness = 0;
   1399  if (tx_set_type == EXT_TX_SET_ALL16)
   1400    pruning_aggressiveness =
   1401        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
   1402  else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
   1403    pruning_aggressiveness =
   1404        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];
   1405 
   1406  return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
   1407 }
   1408 
   1409 static inline void get_energy_distribution_finer(const int16_t *diff,
   1410                                                 int stride, int bw, int bh,
   1411                                                 float *hordist,
   1412                                                 float *verdist) {
   1413  // First compute downscaled block energy values (esq); downscale factors
   1414  // are defined by w_shift and h_shift.
   1415  unsigned int esq[256];
   1416  const int w_shift = bw <= 8 ? 0 : 1;
   1417  const int h_shift = bh <= 8 ? 0 : 1;
   1418  const int esq_w = bw >> w_shift;
   1419  const int esq_h = bh >> h_shift;
   1420  const int esq_sz = esq_w * esq_h;
   1421  int i, j;
   1422  memset(esq, 0, esq_sz * sizeof(esq[0]));
   1423  if (w_shift) {
   1424    for (i = 0; i < bh; i++) {
   1425      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
   1426      const int16_t *cur_diff_row = diff + i * stride;
   1427      for (j = 0; j < bw; j += 2) {
   1428        cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
   1429                                cur_diff_row[j + 1] * cur_diff_row[j + 1]);
   1430      }
   1431    }
   1432  } else {
   1433    for (i = 0; i < bh; i++) {
   1434      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
   1435      const int16_t *cur_diff_row = diff + i * stride;
   1436      for (j = 0; j < bw; j++) {
   1437        cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
   1438      }
   1439    }
   1440  }
   1441 
   1442  uint64_t total = 0;
   1443  for (i = 0; i < esq_sz; i++) total += esq[i];
   1444 
   1445  // Output hordist and verdist arrays are normalized 1D projections of esq
   1446  if (total == 0) {
   1447    float hor_val = 1.0f / esq_w;
   1448    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
   1449    float ver_val = 1.0f / esq_h;
   1450    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
   1451    return;
   1452  }
   1453 
   1454  const float e_recip = 1.0f / (float)total;
   1455  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
   1456  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
   1457  const unsigned int *cur_esq_row;
   1458  for (i = 0; i < esq_h - 1; i++) {
   1459    cur_esq_row = esq + i * esq_w;
   1460    for (j = 0; j < esq_w - 1; j++) {
   1461      hordist[j] += (float)cur_esq_row[j];
   1462      verdist[i] += (float)cur_esq_row[j];
   1463    }
   1464    verdist[i] += (float)cur_esq_row[j];
   1465  }
   1466  cur_esq_row = esq + i * esq_w;
   1467  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
   1468 
   1469  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
   1470  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
   1471 }
   1472 
   1473 static inline bool check_bit_mask(uint16_t mask, int val) {
   1474  return mask & (1 << val);
   1475 }
   1476 
   1477 static inline void set_bit_mask(uint16_t *mask, int val) {
   1478  *mask |= (1 << val);
   1479 }
   1480 
   1481 static inline void unset_bit_mask(uint16_t *mask, int val) {
   1482  *mask &= ~(1 << val);
   1483 }
   1484 
// Narrows the set of 2D transform types searched for a luma transform block
// using two NN models.
//
// Horizontal and vertical feature vectors are built from the block's
// src_diff: the energy profile from get_energy_distribution_finer(), whose
// last slot is overwritten by av1_get_horver_correlation_full(). The models
// nn_config_hor / nn_config_ver each produce 4 scores. Their outer product
// (vscores[i] * hscores[j] at index i * 4 + j), passed through
// av1_nn_fast_softmax_16(), gives one score per entry of tx_type_table_2D.
//
// Parameters:
//   bsize:            block size; its width is the src_diff stride.
//   blk_row, blk_col: position of the transform block, in units of 4 samples.
//   tx_set_type:      only EXT_TX_SET_ALL16 and EXT_TX_SET_DTT9_IDTX_1DDCT are
//                     handled; for other sets the function returns without
//                     touching txk_map or *allowed_tx_mask.
//   txk_map:          out, 16 ints: tx type search order.
//   allowed_tx_mask:  in/out: bit mask of allowed TX_TYPEs.
//
// Types that are not in *allowed_tx_mask, or that score below the adaptive
// threshold, are dropped. The highest-scoring allowed type is always kept.
// If even that type is below the threshold, it becomes the only allowed type
// and txk_map is reset to tx_type_table_2D order. Otherwise txk_map receives
// the surviving types as ordered by av1_sort_fi32_8/16, followed by
// TX_TYPE_INVALID. For TX_TYPE_PRUNE_4 and above, only the leading types
// covering about 30% of the surviving score mass (at least 2) are kept.
// NOTE(review): the cumulative-score step assumes the sort is descending by
// score — confirm in sorting_network.h.
static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                        int blk_row, int blk_col, TxSetType tx_set_type,
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
                        uint16_t *allowed_tx_mask) {
  // This table is used because the search order is different from the enum
  // order. Entry i * 4 + j pairs vertical class i with horizontal class j.
  static const int tx_type_table_2D[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
  if (tx_set_type != EXT_TX_SET_ALL16 &&
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
    return;
#if CONFIG_NN_V2
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#else
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#endif
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.

  float hfeatures[16], vfeatures[16];
  float hscores[4], vscores[4];
  float scores_2D_raw[16];
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  // One feature per downscaled column/row; the last slot holds correlation.
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
  assert(hfeatures_num <= 16);
  assert(vfeatures_num <= 16);

  const struct macroblock_plane *const p = &x->plane[0];
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  // Fills features [0, num - 2] with the energy profile.
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                vfeatures);

  // Fills feature [num - 1] with horizontal/vertical correlation.
  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
                                  &hfeatures[hfeatures_num - 1],
                                  &vfeatures[vfeatures_num - 1]);

#if CONFIG_NN_V2
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
#else
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
#endif

  // Outer product of vertical and horizontal scores: row i = vertical class.
  for (int i = 0; i < 4; i++) {
    float *cur_scores_2D = scores_2D_raw + i * 4;
    cur_scores_2D[0] = vscores[i] * hscores[0];
    cur_scores_2D[1] = vscores[i] * hscores[1];
    cur_scores_2D[2] = vscores[i] * hscores[2];
    cur_scores_2D[3] = vscores[i] * hscores[3];
  }

  assert(TX_TYPES == 16);
  // This version of the function only works when there are at most 16 classes.
  // So we will need to change the optimization or use av1_nn_softmax instead if
  // this ever gets changed.
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);

  const float score_thresh =
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);

  // Always keep the TX type with the highest score, prune all others with
  // score below score_thresh.
  int max_score_i = 0;
  float max_score = 0.0f;
  uint16_t allow_bitmask = 0;
  float sum_score = 0.0;
  // Populate the allow bit mask from score_thresh and allowed_tx_mask, and
  // accumulate the scores of the tx types that pass. The passing types and
  // their scores are compacted into tx_type_allowed / scores_2D for sorting;
  // unused slots keep TX_TYPE_INVALID / -1 so they sort last.
  int allow_count = 0;
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID };
  float scores_2D[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
  };
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
    const int allow_tx_type =
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
    if (!allow_tx_type) {
      continue;
    }
    if (scores_2D_raw[tx_idx] > max_score) {
      max_score = scores_2D_raw[tx_idx];
      max_score_i = tx_idx;
    }
    if (scores_2D_raw[tx_idx] >= score_thresh) {
      // Set allow mask based on score_thresh
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);

      // Accumulate score of allowed tx type
      sum_score += scores_2D_raw[tx_idx];

      scores_2D[allow_count] = scores_2D_raw[tx_idx];
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
      allow_count += 1;
    }
  }
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
    // If even the tx_type with max score is pruned, this means that no other
    // tx_type is feasible. When this happens, we force enable max_score_i and
    // end the search.
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
    *allowed_tx_mask = allow_bitmask;
    return;
  }

  // Sort tx type probability of all types (the 8-wide network is enough when
  // at most 8 slots are populated).
  if (allow_count <= 8) {
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
  } else {
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
  }

  // Enable more pruning based on tx type probability and number of allowed tx
  // types
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
    float temp_score = 0.0;
    float score_ratio = 0.0;
    int tx_idx, tx_count = 0;
    // Scales cumulative scores to a percentage of sum_score.
    const float inv_sum_score = 100 / sum_score;
    // Get allowed tx types based on sorted probability score and tx count
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
      // Stop once the tx types kept so far cover more than 30% of the
      // cumulative probability and at least 2 tx types have been kept.
      if (score_ratio > 30.0 && tx_count >= 2) break;

      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
      // Calculate cumulative probability
      temp_score += scores_2D[tx_idx];

      // Calculate percentage of cumulative probability of allowed tx type
      score_ratio = temp_score * inv_sum_score;
      tx_count++;
    }
    // Set remaining tx types as pruned
    for (; tx_idx < allow_count; tx_idx++)
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
  }

  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
  *allowed_tx_mask = allow_bitmask;
}
   1640 
   1641 static float get_dev(float mean, double x2_sum, int num) {
   1642  const float e_x2 = (float)(x2_sum / num);
   1643  const float diff = e_x2 - mean * mean;
   1644  const float dev = (diff > 0) ? sqrtf(diff) : 0;
   1645  return dev;
   1646 }
   1647 
   1648 // Writes the features required by the ML model to predict tx split based on
   1649 // mean and standard deviation values of the block and sub-blocks.
   1650 // Returns the number of elements written to the output array which is at most
   1651 // 12 currently. Hence 'features' buffer should be able to accommodate at least
   1652 // 12 elements.
   1653 static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
   1654                                        int bh, float *features) {
   1655  const int16_t *const data_ptr = &data[0];
   1656  const int subh = (bh >= bw) ? (bh >> 1) : bh;
   1657  const int subw = (bw >= bh) ? (bw >> 1) : bw;
   1658  const int num = bw * bh;
   1659  const int sub_num = subw * subh;
   1660  int feature_idx = 2;
   1661  int total_x_sum = 0;
   1662  int64_t total_x2_sum = 0;
   1663  int num_sub_blks = 0;
   1664  double mean2_sum = 0.0f;
   1665  float dev_sum = 0.0f;
   1666 
   1667  for (int row = 0; row < bh; row += subh) {
   1668    for (int col = 0; col < bw; col += subw) {
   1669      int x_sum;
   1670      int64_t x2_sum;
   1671      // TODO(any): Write a SIMD version. Clear registers.
   1672      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
   1673                          &x_sum, &x2_sum);
   1674      total_x_sum += x_sum;
   1675      total_x2_sum += x2_sum;
   1676 
   1677      const float mean = (float)x_sum / sub_num;
   1678      const float dev = get_dev(mean, (double)x2_sum, sub_num);
   1679      features[feature_idx++] = mean;
   1680      features[feature_idx++] = dev;
   1681      mean2_sum += (double)(mean * mean);
   1682      dev_sum += dev;
   1683      num_sub_blks++;
   1684    }
   1685  }
   1686 
   1687  const float lvl0_mean = (float)total_x_sum / num;
   1688  features[0] = lvl0_mean;
   1689  features[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
   1690 
   1691  // Deviation of means.
   1692  features[feature_idx++] = get_dev(lvl0_mean, mean2_sum, num_sub_blks);
   1693  // Mean of deviations.
   1694  features[feature_idx++] = dev_sum / num_sub_blks;
   1695 
   1696  return feature_idx;
   1697 }
   1698 
   1699 static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
   1700                               int blk_col, TX_SIZE tx_size) {
   1701  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
   1702  if (!nn_config) return -1;
   1703 
   1704  const int diff_stride = block_size_wide[bsize];
   1705  const int16_t *diff =
   1706      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
   1707  const int bw = tx_size_wide[tx_size];
   1708  const int bh = tx_size_high[tx_size];
   1709 
   1710  float features[64] = { 0.0f };
   1711  get_mean_dev_features(diff, diff_stride, bw, bh, features);
   1712 
   1713  float score = 0.0f;
   1714  av1_nn_predict(features, nn_config, 1, &score);
   1715 
   1716  int int_score = (int)(score * 10000);
   1717  return clamp(int_score, -80000, 80000);
   1718 }
   1719 
// Builds the mask of transform types to try in the RD search for one
// transform block (bit t set => TX_TYPE t is allowed).
//
// The candidate set is narrowed in stages:
//   1. A single type may be forced: the default intra/inter type, the most
//      probable type from frame statistics, DCT_DCT for LOW_TXFM_RD luma, the
//      type returned by av1_get_tx_type() for chroma, or DCT_DCT for lossless,
//      tx sizes above 32x32, DCT-only sets or DCT-only encoder configs.
//   2. Otherwise the fast search uses the {DCT_DCT, V_DCT, H_DCT} subset.
//   3. Otherwise (luma only) the full ext tx set is optionally pruned by frame
//      statistics, then by RD estimates (prune_txk_type /
//      prune_txk_type_separ) or by the NN model (prune_tx_2D).
//
// allowed_txk_types: out. Set to a specific TX_TYPE when only that type is
//   forced, and to TX_TYPES otherwise.
// txk_map: tx type search order. It is passed to the pruning helpers above,
//   which may rewrite it.
// Returns a non-zero mask. If every type ends up pruned, it falls back to
// DCT_DCT (luma) or the av1_get_tx_type() result (chroma).
static inline uint16_t get_tx_mask(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
    int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
    int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int is_inter = is_inter_block(mbmi);
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
  // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
  // TX_TYPES, only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;

  // Frame-level tx type probabilities for this frame update type and tx size.
  const FRAME_UPDATE_TYPE update_type =
      get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
  int use_actual_frame_probs = 1;
  const int *tx_type_probs;
#if CONFIG_FPMT_TEST
  // The parallel-encode simulation reads the temporary copy of the statistics.
  use_actual_frame_probs =
      (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ? 0 : 1;
  if (!use_actual_frame_probs) {
    tx_type_probs =
        (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size];
  }
#endif
  if (use_actual_frame_probs) {
    tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];
  }

  // Stage 1: decide whether a single tx type is forced.
  if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
      (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
    txk_allowed =
        get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
  } else if (is_inter &&
             txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
    if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
        txfm_params->default_inter_tx_type_prob_thresh) {
      txk_allowed = DEFAULT_INTER_TX_TYPE;
    } else {
      // Force the most probable type among indices 1..TX_TYPES-1 if its
      // probability exceeds the threshold plus PROB_THRESH_OFFSET_TX_TYPE.
      int force_tx_type = 0;
      int max_prob = 0;
      const int tx_type_prob_threshold =
          txfm_params->default_inter_tx_type_prob_thresh +
          PROB_THRESH_OFFSET_TX_TYPE;
      for (int i = 1; i < TX_TYPES; i++) {  // find maximum probability.
        if (tx_type_probs[i] > max_prob) {
          max_prob = tx_type_probs[i];
          force_tx_type = i;
        }
      }
      if (max_prob > tx_type_prob_threshold)  // force tx type with max prob.
        txk_allowed = force_tx_type;
      else if (x->rd_model == LOW_TXFM_RD) {
        if (plane == 0) txk_allowed = DCT_DCT;
      }
    }
  } else if (x->rd_model == LOW_TXFM_RD) {
    if (plane == 0) txk_allowed = DCT_DCT;
  }

  const TxSetType tx_set_type = av1_get_ext_tx_set_type(
      tx_size, is_inter, cm->features.reduced_tx_set_used);

  TX_TYPE uv_tx_type = DCT_DCT;
  if (plane) {
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
    uv_tx_type = txk_allowed =
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
                        cm->features.reduced_tx_set_used);
  }
  // Filter-intra blocks use the intra direction mapped from their filter mode.
  PREDICTION_MODE intra_dir =
      mbmi->filter_intra_mode_info.use_filter_intra
          ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
          : mbmi->mode;
  // Stage 2: mask of types permitted by the tx set (or by the reduced intra
  // set when that speed feature applies to this set).
  uint16_t ext_tx_used_flag =
      cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
              tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
          ? av1_reduced_intra_tx_used_flag[intra_dir]
          : av1_ext_tx_used_flag[tx_set_type];

  if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
    ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];

  // Force DCT_DCT for lossless, tx sizes above 32x32, sets whose only allowed
  // type is bit 0, or when the encoder config requests DCT only.
  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
      ext_tx_used_flag == 0x0001 ||
      (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
      (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
    txk_allowed = DCT_DCT;
  }

  if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
    ext_tx_used_flag &= DCT_ADST_TX_MASK;

  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
  if (txk_allowed < TX_TYPES) {
    allowed_tx_mask = 1 << txk_allowed;
    allowed_tx_mask &= ext_tx_used_flag;
  } else if (fast_tx_search) {
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
    allowed_tx_mask &= ext_tx_used_flag;
  } else {
    // Stage 3: full search, reachable for luma only.
    assert(plane == 0);
    allowed_tx_mask = ext_tx_used_flag;
    int num_allowed = 0;
    int i;

    if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
      // Prune types whose frame probability is below a per-update-type
      // threshold, but always keep the most probable allowed type.
      static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
                                            { 10, 17, 17, 10, 17, 17, 17 } };
      const int thresh =
          thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
                    [update_type];
      uint16_t prune = 0;
      int max_prob = -1;
      int max_idx = 0;
      for (i = 0; i < TX_TYPES; i++) {
        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
          max_prob = tx_type_probs[i];
          max_idx = i;
        }
        if (tx_type_probs[i] < thresh) prune |= (1 << i);
      }
      if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
      allowed_tx_mask &= (~prune);
    }
    for (i = 0; i < TX_TYPES; i++) {
      if (allowed_tx_mask & (1 << i)) num_allowed++;
    }
    assert(num_allowed > 0);

    if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
      // Prune using RD estimates; pf and mf are looked up per prune mode.
      int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
      int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
      if (num_allowed <= 7) {
        const uint16_t prune =
            prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
                           plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
                           cm->features.reduced_tx_set_used);
        allowed_tx_mask &= (~prune);
      } else {
        // num_sel = num_allowed * mf / 100, rounded to nearest.
        const int num_sel = (num_allowed * mf + 50) / 100;
        const uint16_t prune = prune_txk_type_separ(
            cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
            txk_map, allowed_tx_mask, pf, txb_ctx,
            cm->features.reduced_tx_set_used, ref_best_rd, num_sel);

        allowed_tx_mask &= (~prune);
      }
    } else {
      assert(num_allowed > 0);
      // NN-based pruning only pays off when more types than this remain.
      int allowed_tx_count =
          (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
      // !fast_tx_search && txk_end != txk_start && plane == 0
      if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
          num_allowed > allowed_tx_count) {
        prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
                    txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
      }
    }
  }

  // Need to have at least one transform type allowed.
  if (allowed_tx_mask == 0) {
    txk_allowed = (plane ? uv_tx_type : DCT_DCT);
    allowed_tx_mask = (1 << txk_allowed);
  }

  assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
  *allowed_txk_types = txk_allowed;
  return allowed_tx_mask;
}
   1893 
   1894 #if CONFIG_RD_DEBUG
   1895 static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
   1896                                         int txb_coeff_cost) {
   1897  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;
   1898 }
   1899 #endif
   1900 
   1901 static inline int cost_coeffs(MACROBLOCK *x, int plane, int block,
   1902                              TX_SIZE tx_size, const TX_TYPE tx_type,
   1903                              const TXB_CTX *const txb_ctx,
   1904                              int reduced_tx_set_used) {
   1905 #if TXCOEFF_COST_TIMER
   1906  struct aom_usec_timer timer;
   1907  aom_usec_timer_start(&timer);
   1908 #endif
   1909  const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
   1910                                       txb_ctx, reduced_tx_set_used);
   1911 #if TXCOEFF_COST_TIMER
   1912  AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
   1913  aom_usec_timer_mark(&timer);
   1914  const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
   1915  tmp_cm->txcoeff_cost_timer += elapsed_time;
   1916  ++tmp_cm->txcoeff_cost_count;
   1917 #endif
   1918  return cost;
   1919 }
   1920 
   1921 static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
   1922                                          QUANT_PARAM *quant_param, int plane,
   1923                                          int block, TX_SIZE tx_size,
   1924                                          int quant_b_adapt, int qstep,
   1925                                          unsigned int coeff_opt_satd_threshold,
   1926                                          int skip_trellis, int dc_only_blk) {
   1927  if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
   1928    return skip_trellis;
   1929 
   1930  const struct macroblock_plane *const p = &x->plane[plane];
   1931  const int block_offset = BLOCK_OFFSET(block);
   1932  tran_low_t *const coeff_ptr = p->coeff + block_offset;
   1933  const int n_coeffs = av1_get_max_eob(tx_size);
   1934  const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
   1935  int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
   1936  satd = RIGHT_SIGNED_SHIFT(satd, shift);
   1937  satd >>= (x->e_mbd.bd - 8);
   1938 
   1939  const int skip_block_trellis =
   1940      ((uint64_t)satd >
   1941       (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);
   1942 
   1943  av1_setup_quant(
   1944      tx_size, !skip_block_trellis,
   1945      skip_block_trellis
   1946          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
   1947          : AV1_XFORM_QUANT_FP,
   1948      quant_b_adapt, quant_param);
   1949 
   1950  return skip_block_trellis;
   1951 }
   1952 
   1953 // Predict DC only blocks if the residual variance is below a qstep based
   1954 // threshold.For such blocks, transform type search is bypassed.
   1955 static inline void predict_dc_only_block(
   1956    MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
   1957    int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
   1958    int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
   1959    int *dc_only_blk) {
   1960  MACROBLOCKD *xd = &x->e_mbd;
   1961  MB_MODE_INFO *mbmi = xd->mi[0];
   1962  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
   1963  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
   1964  uint64_t block_var = UINT64_MAX;
   1965  const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
   1966  *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
   1967                                txsize_to_bsize[tx_size], block_mse_q8,
   1968                                per_px_mean, &block_var);
   1969  assert((*block_mse_q8) != UINT_MAX);
   1970  uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
   1971  if (is_cur_buf_hbd(xd))
   1972    block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
   1973 
   1974  if (block_var >= var_threshold) return;
   1975  const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level;
   1976  assert(predict_dc_level != 0);
   1977 
   1978  // Prediction of skip block if residual mean and variance are less
   1979  // than qstep based threshold
   1980  if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) {
   1981    // If the normalized mean of residual block is less than the dc qstep and
   1982    // the  normalized block variance is less than ac qstep, then the block is
   1983    // assumed to be a skip block and its rdcost is updated accordingly.
   1984    best_rd_stats->skip_txfm = 1;
   1985 
   1986    x->plane[plane].eobs[block] = 0;
   1987 
   1988    if (is_cur_buf_hbd(xd))
   1989      *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);
   1990 
   1991    best_rd_stats->dist = (*block_sse) << 4;
   1992    best_rd_stats->sse = best_rd_stats->dist;
   1993 
   1994    ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
   1995    ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
   1996    av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
   1997    ENTROPY_CONTEXT *ta = ctxa;
   1998    ENTROPY_CONTEXT *tl = ctxl;
   1999    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
   2000    TXB_CTX txb_ctx_tmp;
   2001    const PLANE_TYPE plane_type = get_plane_type(plane);
   2002    get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
   2003    const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
   2004                                  .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
   2005    best_rd_stats->rate = zero_blk_rate;
   2006 
   2007    best_rd_stats->rdcost =
   2008        RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);
   2009 
   2010    x->plane[plane].txb_entropy_ctx[block] = 0;
   2011  } else if (predict_dc_level > 1) {
   2012    // Predict DC only blocks based on residual variance.
   2013    // For chroma plane, this prediction is disabled for intra blocks.
   2014    if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
   2015  }
   2016 }
   2017 
   2018 // Search for the best transform type for a given transform block.
   2019 // This function can be used for both inter and intra, both luma and chroma.
   2020 static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
   2021                           int block, int blk_row, int blk_col,
   2022                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
   2023                           const TXB_CTX *const txb_ctx,
   2024                           FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
   2025                           int64_t ref_best_rd, RD_STATS *best_rd_stats) {
   2026  const AV1_COMMON *cm = &cpi->common;
   2027  MACROBLOCKD *xd = &x->e_mbd;
   2028  MB_MODE_INFO *mbmi = xd->mi[0];
   2029  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
   2030  int64_t best_rd = INT64_MAX;
   2031  uint16_t best_eob = 0;
   2032  TX_TYPE best_tx_type = DCT_DCT;
   2033  int rate_cost = 0;
   2034  struct macroblock_plane *const p = &x->plane[plane];
   2035  tran_low_t *orig_dqcoeff = p->dqcoeff;
   2036  tran_low_t *best_dqcoeff = x->dqcoeff_buf;
   2037  const int tx_type_map_idx =
   2038      plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
   2039  av1_invalid_rd_stats(best_rd_stats);
   2040 
   2041  skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
   2042                                   DRY_RUN_NORMAL);
   2043 
   2044  uint8_t best_txb_ctx = 0;
   2045  // txk_allowed = TX_TYPES: >1 tx types are allowed
   2046  // txk_allowed < TX_TYPES: only that specific tx type is allowed.
   2047  TX_TYPE txk_allowed = TX_TYPES;
   2048  int txk_map[TX_TYPES] = {
   2049    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
   2050  };
   2051  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
   2052  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
   2053 
   2054  const uint8_t txw = tx_size_wide[tx_size];
   2055  const uint8_t txh = tx_size_high[tx_size];
   2056  int64_t block_sse;
   2057  unsigned int block_mse_q8;
   2058  int dc_only_blk = 0;
   2059  const bool predict_dc_block =
   2060      txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64;
   2061  int64_t per_px_mean = INT64_MAX;
   2062  if (predict_dc_block) {
   2063    predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
   2064                          blk_col, best_rd_stats, &block_sse, &block_mse_q8,
   2065                          &per_px_mean, &dc_only_blk);
   2066    if (best_rd_stats->skip_txfm == 1) {
   2067      const TX_TYPE tx_type = DCT_DCT;
   2068      if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
   2069      return;
   2070    }
   2071  } else {
   2072    block_sse = av1_pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
   2073                                    txsize_to_bsize[tx_size], &block_mse_q8);
   2074    assert(block_mse_q8 != UINT_MAX);
   2075  }
   2076 
   2077  // Bit mask to indicate which transform types are allowed in the RD search.
   2078  uint16_t tx_mask;
   2079 
   2080  // Use DCT_DCT transform for DC only block.
   2081  if (dc_only_blk || cpi->sf.rt_sf.dct_only_palette_nonrd == 1)
   2082    tx_mask = 1 << DCT_DCT;
   2083  else
   2084    tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
   2085                          tx_size, txb_ctx, ftxs_mode, ref_best_rd,
   2086                          &txk_allowed, txk_map);
   2087  const uint16_t allowed_tx_mask = tx_mask;
   2088 
   2089  if (is_cur_buf_hbd(xd)) {
   2090    block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
   2091    block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
   2092  }
   2093  block_sse *= 16;
   2094  // Use mse / qstep^2 based threshold logic to take decision of R-D
   2095  // optimization of coeffs. For smaller residuals, coeff optimization
   2096  // would be helpful. For larger residuals, R-D optimization may not be
   2097  // effective.
   2098  // TODO(any): Experiment with variance and mean based thresholds
   2099  const int perform_block_coeff_opt =
   2100      ((uint64_t)block_mse_q8 <=
   2101       (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
   2102  skip_trellis |= !perform_block_coeff_opt;
   2103 
   2104  // Flag to indicate if distortion should be calculated in transform domain or
   2105  // not during iterating through transform type candidates.
   2106  // Transform domain distortion is accurate for higher residuals.
   2107  // TODO(any): Experiment with variance and mean based thresholds
   2108  int use_transform_domain_distortion =
   2109      (txfm_params->use_transform_domain_distortion > 0) &&
   2110      (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
   2111      // Any 64-pt transforms only preserves half the coefficients.
   2112      // Therefore transform domain distortion is not valid for these
   2113      // transform sizes.
   2114      (txsize_sqr_up_map[tx_size] != TX_64X64) &&
   2115      // Use pixel domain distortion for DC only blocks
   2116      !dc_only_blk;
   2117  // Flag to indicate if an extra calculation of distortion in the pixel domain
   2118  // should be performed at the end, after the best transform type has been
   2119  // decided.
   2120  int calc_pixel_domain_distortion_final =
   2121      txfm_params->use_transform_domain_distortion == 1 &&
   2122      use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
   2123  if (calc_pixel_domain_distortion_final &&
   2124      (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
   2125    calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
   2126 
   2127  const uint16_t *eobs_ptr = x->plane[plane].eobs;
   2128 
   2129  TxfmParam txfm_param;
   2130  QUANT_PARAM quant_param;
   2131  int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
   2132  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
   2133  av1_setup_quant(tx_size, !skip_trellis,
   2134                  skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
   2135                                                         : AV1_XFORM_QUANT_FP)
   2136                               : AV1_XFORM_QUANT_FP,
   2137                  cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
   2138 
   2139  // Iterate through all transform type candidates.
   2140  for (int idx = 0; idx < TX_TYPES; ++idx) {
   2141    const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
   2142    if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
   2143      continue;
   2144    txfm_param.tx_type = tx_type;
   2145    if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
   2146      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
   2147                        &quant_param);
   2148    }
   2149    if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
   2150    RD_STATS this_rd_stats;
   2151    av1_invalid_rd_stats(&this_rd_stats);
   2152 
   2153    if (!dc_only_blk)
   2154      av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
   2155    else
   2156      av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
   2157 
   2158    skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
   2159        x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
   2160        qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
   2161 
   2162    av1_quant(x, plane, block, &txfm_param, &quant_param);
   2163 
   2164    // Calculate rate cost of quantized coefficients.
   2165    if (quant_param.use_optimize_b) {
   2166      // TODO(aomedia:3209): update Trellis quantization to take into account
   2167      // quantization matrices.
   2168      av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
   2169                     &rate_cost);
   2170    } else {
   2171      rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
   2172                              cm->features.reduced_tx_set_used);
   2173    }
   2174 
   2175    // If rd cost based on coeff rate alone is already more than best_rd,
   2176    // terminate early.
   2177    if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
   2178 
   2179    // Calculate distortion.
   2180    if (eobs_ptr[block] == 0) {
   2181      // When eob is 0, pixel domain distortion is more efficient and accurate.
   2182      this_rd_stats.dist = this_rd_stats.sse = block_sse;
   2183    } else if (dc_only_blk) {
   2184      this_rd_stats.sse = block_sse;
   2185      this_rd_stats.dist = dist_block_px_domain(
   2186          cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
   2187    } else if (use_transform_domain_distortion) {
   2188      const SCAN_ORDER *const scan_order =
   2189          get_scan(txfm_param.tx_size, txfm_param.tx_type);
   2190      dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
   2191                           scan_order->scan, &this_rd_stats.dist,
   2192                           &this_rd_stats.sse);
   2193    } else {
   2194      int64_t sse_diff = INT64_MAX;
   2195      // high_energy threshold assumes that every pixel within a txfm block
   2196      // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
   2197      // for 8 bit.
   2198      const int64_t high_energy_thresh =
   2199          ((int64_t)128 * 128 * tx_size_2d[tx_size]);
   2200      const int is_high_energy = (block_sse >= high_energy_thresh);
   2201      if (tx_size == TX_64X64 || is_high_energy) {
   2202        // Because 3 out 4 quadrants of transform coefficients are forced to
   2203        // zero, the inverse transform has a tendency to overflow. sse_diff
   2204        // is effectively the energy of those 3 quadrants, here we use it
   2205        // to decide if we should do pixel domain distortion. If the energy
   2206        // is mostly in first quadrant, then it is unlikely that we have
   2207        // overflow issue in inverse transform.
   2208        const SCAN_ORDER *const scan_order =
   2209            get_scan(txfm_param.tx_size, txfm_param.tx_type);
   2210        dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
   2211                             scan_order->scan, &this_rd_stats.dist,
   2212                             &this_rd_stats.sse);
   2213        sse_diff = block_sse - this_rd_stats.sse;
   2214      }
   2215      if (tx_size != TX_64X64 || !is_high_energy ||
   2216          (sse_diff * 2) < this_rd_stats.sse) {
   2217        const int64_t tx_domain_dist = this_rd_stats.dist;
   2218        this_rd_stats.dist = dist_block_px_domain(
   2219            cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
   2220        // For high energy blocks, occasionally, the pixel domain distortion
   2221        // can be artificially low due to clamping at reconstruction stage
   2222        // even when inverse transform output is hugely different from the
   2223        // actual residue.
   2224        if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
   2225          this_rd_stats.dist = tx_domain_dist;
   2226      } else {
   2227        assert(sse_diff < INT64_MAX);
   2228        this_rd_stats.dist += sse_diff;
   2229      }
   2230      this_rd_stats.sse = block_sse;
   2231    }
   2232 
   2233    this_rd_stats.rate = rate_cost;
   2234 
   2235    const int64_t rd =
   2236        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
   2237 
   2238    if (rd < best_rd) {
   2239      best_rd = rd;
   2240      *best_rd_stats = this_rd_stats;
   2241      best_tx_type = tx_type;
   2242      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
   2243      best_eob = x->plane[plane].eobs[block];
   2244      // Swap dqcoeff buffers
   2245      tran_low_t *const tmp_dqcoeff = best_dqcoeff;
   2246      best_dqcoeff = p->dqcoeff;
   2247      p->dqcoeff = tmp_dqcoeff;
   2248    }
   2249 
   2250 #if CONFIG_COLLECT_RD_STATS == 1
   2251    if (plane == 0) {
   2252      PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
   2253                              plane_bsize, tx_size, tx_type, rd);
   2254    }
   2255 #endif  // CONFIG_COLLECT_RD_STATS == 1
   2256 
   2257 #if COLLECT_TX_SIZE_DATA
   2258    // Generate small sample to restrict output size.
   2259    static unsigned int seed = 21743;
   2260    if (lcg_rand16(&seed) % 200 == 0) {
   2261      FILE *fp = NULL;
   2262 
   2263      if (within_border) {
   2264        fp = fopen(av1_tx_size_data_output_file, "a");
   2265      }
   2266 
   2267      if (fp) {
   2268        // Transform info and RD
   2269        const int txb_w = tx_size_wide[tx_size];
   2270        const int txb_h = tx_size_high[tx_size];
   2271 
   2272        // Residue signal.
   2273        const int diff_stride = block_size_wide[plane_bsize];
   2274        struct macroblock_plane *const p = &x->plane[plane];
   2275        const int16_t *src_diff =
   2276            &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
   2277 
   2278        for (int r = 0; r < txb_h; ++r) {
   2279          for (int c = 0; c < txb_w; ++c) {
   2280            fprintf(fp, "%d,", src_diff[c]);
   2281          }
   2282          src_diff += diff_stride;
   2283        }
   2284 
   2285        fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
   2286        fprintf(fp, "\n");
   2287        fclose(fp);
   2288      }
   2289    }
   2290 #endif  // COLLECT_TX_SIZE_DATA
   2291 
   2292    // If the current best RD cost is much worse than the reference RD cost,
   2293    // terminate early.
   2294    if (cpi->sf.tx_sf.adaptive_txb_search_level) {
   2295      if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
   2296          ref_best_rd) {
   2297        break;
   2298      }
   2299    }
   2300 
   2301    // Terminate transform type search if the block has been quantized to
   2302    // all zero.
   2303    if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
   2304  }
   2305 
   2306  assert(best_rd != INT64_MAX);
   2307 
   2308  best_rd_stats->skip_txfm = best_eob == 0;
   2309  if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
   2310  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
   2311  x->plane[plane].eobs[block] = best_eob;
   2312  skip_trellis = skip_trellis_based_on_satd[best_tx_type];
   2313 
   2314  // Point dqcoeff to the quantized coefficients corresponding to the best
   2315  // transform type, then we can skip transform and quantization, e.g. in the
   2316  // final pixel domain distortion calculation and recon_intra().
   2317  p->dqcoeff = best_dqcoeff;
   2318 
   2319  if (calc_pixel_domain_distortion_final && best_eob) {
   2320    best_rd_stats->dist = dist_block_px_domain(
   2321        cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
   2322    best_rd_stats->sse = block_sse;
   2323  }
   2324 
   2325  // Intra mode needs decoded pixels such that the next transform block
   2326  // can use them for prediction.
   2327  recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
   2328              txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
   2329  p->dqcoeff = orig_dqcoeff;
   2330 }
   2331 
   2332 // Pick transform type for a luma transform block of tx_size. Note this function
   2333 // is used only for inter-predicted blocks.
   2334 static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
   2335                              TX_SIZE tx_size, int blk_row, int blk_col,
   2336                              int block, int plane_bsize, TXB_CTX *txb_ctx,
   2337                              RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode,
   2338                              int64_t ref_rdcost) {
   2339  assert(is_inter_block(x->e_mbd.mi[0]));
   2340  RD_STATS this_rd_stats;
   2341  const int skip_trellis = 0;
   2342  search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
   2343                 txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);
   2344 
   2345  av1_merge_rd_stats(rd_stats, &this_rd_stats);
   2346 }
   2347 
// Evaluates coding the luma transform block at (blk_row, blk_col) with a
// single tx_size transform, i.e. without splitting it any further.
//
// On return:
//   - rd_stats holds the stats of the best transform type found by
//     tx_type_rd(). These are replaced by the all-zero ("skip") rate and
//     distortion when that is at least as cheap; this never happens in
//     lossless segments. zero_rate and skip_txfm are set. When tx_size could
//     still have been split, the "no split" partition flag cost is added to
//     the rate.
//   - no_split receives this candidate's RD cost, entropy context and
//     transform type, so the caller can restore them if no-split wins.
// Side effects: sets mbmi->inter_tx_size for this block's index and the
// block's blk_skip flag. Through tx_type_rd(), it also sets the transform
// type, eob and entropy context of `block`.
static inline void try_tx_block_no_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  struct macroblock_plane *const p = &x->plane[0];
  const int bw = mi_size_wide[plane_bsize];
  const ENTROPY_CONTEXT *const pta = ta + blk_col;
  const ENTROPY_CONTEXT *const ptl = tl + blk_row;
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
  // Rate of signalling this transform block as all-zero (used below together
  // with dist = sse when the block is skipped).
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->zero_rate = zero_blk_rate;
  const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
  mbmi->inter_tx_size[index] = tx_size;
  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
             rd_stats, ftxs_mode, ref_best_rd);
  assert(rd_stats->rate < INT_MAX);

  // Code the block as all-zero if the type search already produced a skip,
  // or if the zero-block cost is no worse. Lossless segments never do this.
  const int pick_skip_txfm =
      !xd->lossless[mbmi->segment_id] &&
      (rd_stats->skip_txfm == 1 ||
       RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
           RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
  if (pick_skip_txfm) {
#if CONFIG_RD_DEBUG
    update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
#endif  // CONFIG_RD_DEBUG
    rd_stats->rate = zero_blk_rate;
    rd_stats->dist = rd_stats->sse;
    p->eobs[block] = 0;
    // A skipped block is recorded with DCT_DCT.
    update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
  }
  rd_stats->skip_txfm = pick_skip_txfm;
  set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
               pick_skip_txfm);

  // Pay for the "no split" partition flag (index 0) only when a split would
  // have been allowed at this size and depth.
  if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
    rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];

  no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
  no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
  // Read back from the map, so this may reflect the DCT_DCT override above.
  no_split->tx_type =
      xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
}
   2398 
// Evaluates splitting the tx_size luma transform block at (blk_row, blk_col)
// into sub_tx_size_map[tx_size] sub-blocks. For each sub-block inside the
// visible area (max_block_high()/max_block_wide()), it recurses into
// select_tx_block() at depth + 1.
//
// split_rd_stats receives the merged stats of all visited sub-blocks plus the
// "split" partition flag cost (txfm_partition_cost[...][1]). Its rdcost is
// set to INT64_MAX, and the function returns early, when any sub-block search
// reports an invalid cost or the accumulated RD cost exceeds ref_best_rd.
// Otherwise rdcost holds the RD cost of the whole split.
// Each sub-block gets no_split_rd / nblks as its prev_level_rd, and the
// remaining budget ref_best_rd - rdcost-so-far as its ref_best_rd.
static inline void try_tx_block_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
  const int txb_width = tx_size_wide_unit[tx_size];
  const int txb_height = tx_size_high_unit[tx_size];
  // Transform size after splitting current block.
  const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
  const int sub_txb_width = tx_size_wide_unit[sub_txs];
  const int sub_txb_height = tx_size_high_unit[sub_txs];
  // Block-index step between consecutive sub-blocks.
  const int sub_step = sub_txb_width * sub_txb_height;
  const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
  assert(nblks > 0);
  av1_init_rd_stats(split_rd_stats);
  split_rd_stats->rate =
      x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];

  for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
    const int offsetr = blk_row + r;
    if (offsetr >= max_blocks_high) break;
    for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
      assert(blk_idx < 4);
      const int offsetc = blk_col + c;
      // Sub-blocks outside the visible area are skipped and do not advance
      // the block index.
      if (offsetc >= max_blocks_wide) continue;

      RD_STATS this_rd_stats;
      int this_cost_valid = 1;
      select_tx_block(cpi, x, offsetr, offsetc, block, sub_txs, depth + 1,
                      plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
                      no_split_rd / nblks, ref_best_rd - split_rd_stats->rdcost,
                      &this_cost_valid, ftxs_mode);
      if (!this_cost_valid) {
        split_rd_stats->rdcost = INT64_MAX;
        return;
      }
      av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
      split_rd_stats->rdcost =
          RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
      // Stop as soon as the partial split is already too expensive.
      if (split_rd_stats->rdcost > ref_best_rd) {
        split_rd_stats->rdcost = INT64_MAX;
        return;
      }
      block += sub_step;
    }
  }
}
   2451 
   2452 static float get_var(float mean, double x2_sum, int num) {
   2453  const float e_x2 = (float)(x2_sum / num);
   2454  const float diff = e_x2 - mean * mean;
   2455  return diff;
   2456 }
   2457 
   2458 static inline void get_blk_var_dev(const int16_t *data, int stride, int bw,
   2459                                   int bh, float *dev_of_mean,
   2460                                   float *var_of_vars) {
   2461  const int16_t *const data_ptr = &data[0];
   2462  const int subh = (bh >= bw) ? (bh >> 1) : bh;
   2463  const int subw = (bw >= bh) ? (bw >> 1) : bw;
   2464  const int num = bw * bh;
   2465  const int sub_num = subw * subh;
   2466  int total_x_sum = 0;
   2467  int64_t total_x2_sum = 0;
   2468  int blk_idx = 0;
   2469  float var_sum = 0.0f;
   2470  float mean_sum = 0.0f;
   2471  double var2_sum = 0.0f;
   2472  double mean2_sum = 0.0f;
   2473 
   2474  for (int row = 0; row < bh; row += subh) {
   2475    for (int col = 0; col < bw; col += subw) {
   2476      int x_sum;
   2477      int64_t x2_sum;
   2478      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
   2479                          &x_sum, &x2_sum);
   2480      total_x_sum += x_sum;
   2481      total_x2_sum += x2_sum;
   2482 
   2483      const float mean = (float)x_sum / sub_num;
   2484      const float var = get_var(mean, (double)x2_sum, sub_num);
   2485      mean_sum += mean;
   2486      mean2_sum += (double)(mean * mean);
   2487      var_sum += var;
   2488      var2_sum += var * var;
   2489      blk_idx++;
   2490    }
   2491  }
   2492 
   2493  const float lvl0_mean = (float)total_x_sum / num;
   2494  const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
   2495  mean_sum += lvl0_mean;
   2496  mean2_sum += (double)(lvl0_mean * lvl0_mean);
   2497  var_sum += block_var;
   2498  var2_sum += block_var * block_var;
   2499  const float av_mean = mean_sum / 5;
   2500 
   2501  if (blk_idx > 1) {
   2502    // Deviation of means.
   2503    *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1));
   2504    // Variance of variances.
   2505    const float mean_var = var_sum / (blk_idx + 1);
   2506    *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1));
   2507  }
   2508 }
   2509 
   2510 static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
   2511                                    int blk_row, int blk_col, TX_SIZE tx_size,
   2512                                    int *try_no_split, int *try_split,
   2513                                    int pruning_level) {
   2514  const int diff_stride = block_size_wide[bsize];
   2515  const int16_t *diff =
   2516      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
   2517  const int bw = tx_size_wide[tx_size];
   2518  const int bh = tx_size_high[tx_size];
   2519  float dev_of_means = 0.0f;
   2520  float var_of_vars = 0.0f;
   2521 
   2522  // This function calculates the deviation of means, and the variance of pixel
   2523  // variances of the block as well as it's sub-blocks.
   2524  get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars);
   2525  const int dc_q = x->plane[0].dequant_QTX[0] >> 3;
   2526  const int ac_q = x->plane[0].dequant_QTX[1] >> 3;
   2527  const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
   2528  const int no_split_thresh_scale = no_split_thresh_scales[pruning_level];
   2529  const int split_thresh_scales[4] = { 0, 24, 10, 8 };
   2530  const int split_thresh_scale = split_thresh_scales[pruning_level];
   2531 
   2532  if ((dev_of_means <= dc_q) &&
   2533      (split_thresh_scale * var_of_vars <= ac_q * ac_q)) {
   2534    *try_split = 0;
   2535  }
   2536  if ((dev_of_means > no_split_thresh_scale * dc_q) &&
   2537      (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) {
   2538    *try_no_split = 0;
   2539  }
   2540 }
   2541 
   2542 // Search for the best transform partition(recursive)/type for a given
   2543 // inter-predicted luma block. The obtained transform selection will be saved
   2544 // in xd->mi[0], the corresponding RD stats will be saved in rd_stats.
   2545 static inline void select_tx_block(
   2546    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
   2547    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
   2548    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
   2549    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
   2550    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) {
   2551  assert(tx_size < TX_SIZES_ALL);
   2552  av1_init_rd_stats(rd_stats);
   2553  if (ref_best_rd < 0) {
   2554    *is_cost_valid = 0;
   2555    return;
   2556  }
   2557 
   2558  MACROBLOCKD *const xd = &x->e_mbd;
   2559  assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
   2560         blk_col < max_block_wide(xd, plane_bsize, 0));
   2561  MB_MODE_INFO *const mbmi = xd->mi[0];
   2562  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
   2563                                         mbmi->bsize, tx_size);
   2564  struct macroblock_plane *const p = &x->plane[0];
   2565 
   2566  int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
   2567                      txsize_sqr_up_map[tx_size] != TX_64X64) &&
   2568                     (cpi->oxcf.txfm_cfg.enable_rect_tx ||
   2569                      tx_size_wide[tx_size] == tx_size_high[tx_size]);
   2570  int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
   2571  TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
   2572 
   2573  // Prune tx_split and no-split based on sub-block properties.
   2574  if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
   2575      cpi->sf.tx_sf.prune_tx_size_level > 0) {
   2576    prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
   2577                            &try_no_split, &try_split,
   2578                            cpi->sf.tx_sf.prune_tx_size_level);
   2579  }
   2580 
   2581  if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
   2582    if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
   2583  }
   2584 
   2585  // Try using current block as a single transform block without split.
   2586  if (try_no_split) {
   2587    try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
   2588                          plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
   2589                          ftxs_mode, &no_split);
   2590 
   2591    // Speed features for early termination.
   2592    const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
   2593    if (search_level) {
   2594      if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
   2595        *is_cost_valid = 0;
   2596        return;
   2597      }
   2598      if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
   2599        try_split = 0;
   2600      }
   2601    }
   2602    if (cpi->sf.tx_sf.txb_split_cap) {
   2603      if (p->eobs[block] == 0) try_split = 0;
   2604    }
   2605  }
   2606 
   2607  // ML based speed feature to skip searching for split transform blocks.
   2608  if (x->e_mbd.bd == 8 && try_split &&
   2609      !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
   2610    const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
   2611    if (threshold >= 0) {
   2612      const int split_score =
   2613          ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
   2614      if (split_score < -threshold) try_split = 0;
   2615    }
   2616  }
   2617 
   2618  RD_STATS split_rd_stats;
   2619  split_rd_stats.rdcost = INT64_MAX;
   2620  // Try splitting current block into smaller transform blocks.
   2621  if (try_split) {
   2622    try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
   2623                       plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
   2624                       AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
   2625                       &split_rd_stats);
   2626  }
   2627 
   2628  if (no_split.rd < split_rd_stats.rdcost) {
   2629    ENTROPY_CONTEXT *pta = ta + blk_col;
   2630    ENTROPY_CONTEXT *ptl = tl + blk_row;
   2631    p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
   2632    av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
   2633    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
   2634                          tx_size);
   2635    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
   2636      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
   2637        const int index =
   2638            av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
   2639        mbmi->inter_tx_size[index] = tx_size;
   2640      }
   2641    }
   2642    mbmi->tx_size = tx_size;
   2643    update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
   2644    const int bw = mi_size_wide[plane_bsize];
   2645    set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
   2646                 rd_stats->skip_txfm);
   2647  } else {
   2648    *rd_stats = split_rd_stats;
   2649    if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
   2650  }
   2651 }
   2652 
   2653 static inline void choose_largest_tx_size(const AV1_COMP *const cpi,
   2654                                          MACROBLOCK *x, RD_STATS *rd_stats,
   2655                                          int64_t ref_best_rd, BLOCK_SIZE bs) {
   2656  MACROBLOCKD *const xd = &x->e_mbd;
   2657  MB_MODE_INFO *const mbmi = xd->mi[0];
   2658  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
   2659  mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
   2660 
   2661  // If tx64 is not enabled, we need to go down to the next available size
   2662  if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
   2663    static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
   2664      TX_4X4,    // 4x4 transform
   2665      TX_8X8,    // 8x8 transform
   2666      TX_16X16,  // 16x16 transform
   2667      TX_32X32,  // 32x32 transform
   2668      TX_32X32,  // 64x64 transform
   2669      TX_4X8,    // 4x8 transform
   2670      TX_8X4,    // 8x4 transform
   2671      TX_8X16,   // 8x16 transform
   2672      TX_16X8,   // 16x8 transform
   2673      TX_16X32,  // 16x32 transform
   2674      TX_32X16,  // 32x16 transform
   2675      TX_32X32,  // 32x64 transform
   2676      TX_32X32,  // 64x32 transform
   2677      TX_4X16,   // 4x16 transform
   2678      TX_16X4,   // 16x4 transform
   2679      TX_8X32,   // 8x32 transform
   2680      TX_32X8,   // 32x8 transform
   2681      TX_16X32,  // 16x64 transform
   2682      TX_32X16,  // 64x16 transform
   2683    };
   2684    mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
   2685  } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
   2686             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
   2687    static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
   2688      TX_4X4,    // 4x4 transform
   2689      TX_8X8,    // 8x8 transform
   2690      TX_16X16,  // 16x16 transform
   2691      TX_32X32,  // 32x32 transform
   2692      TX_64X64,  // 64x64 transform
   2693      TX_4X4,    // 4x8 transform
   2694      TX_4X4,    // 8x4 transform
   2695      TX_8X8,    // 8x16 transform
   2696      TX_8X8,    // 16x8 transform
   2697      TX_16X16,  // 16x32 transform
   2698      TX_16X16,  // 32x16 transform
   2699      TX_32X32,  // 32x64 transform
   2700      TX_32X32,  // 64x32 transform
   2701      TX_4X4,    // 4x16 transform
   2702      TX_4X4,    // 16x4 transform
   2703      TX_8X8,    // 8x32 transform
   2704      TX_8X8,    // 32x8 transform
   2705      TX_16X16,  // 16x64 transform
   2706      TX_16X16,  // 64x16 transform
   2707    };
   2708    mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
   2709  } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
   2710             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
   2711    static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
   2712      TX_4X4,    // 4x4 transform
   2713      TX_8X8,    // 8x8 transform
   2714      TX_16X16,  // 16x16 transform
   2715      TX_32X32,  // 32x32 transform
   2716      TX_32X32,  // 64x64 transform
   2717      TX_4X4,    // 4x8 transform
   2718      TX_4X4,    // 8x4 transform
   2719      TX_8X8,    // 8x16 transform
   2720      TX_8X8,    // 16x8 transform
   2721      TX_16X16,  // 16x32 transform
   2722      TX_16X16,  // 32x16 transform
   2723      TX_32X32,  // 32x64 transform
   2724      TX_32X32,  // 64x32 transform
   2725      TX_4X4,    // 4x16 transform
   2726      TX_4X4,    // 16x4 transform
   2727      TX_8X8,    // 8x32 transform
   2728      TX_8X8,    // 32x8 transform
   2729      TX_16X16,  // 16x64 transform
   2730      TX_16X16,  // 64x16 transform
   2731    };
   2732 
   2733    mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
   2734  }
   2735 
   2736  const int skip_ctx = av1_get_skip_txfm_context(xd);
   2737  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
   2738  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
   2739  // Skip RDcost is used only for Inter blocks
   2740  const int64_t skip_txfm_rd =
   2741      is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
   2742  const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
   2743  const int skip_trellis = 0;
   2744  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
   2745                       AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
   2746                       mbmi->tx_size, FTXS_NONE, skip_trellis);
   2747 }
   2748 
   2749 static inline void choose_smallest_tx_size(const AV1_COMP *const cpi,
   2750                                           MACROBLOCK *x, RD_STATS *rd_stats,
   2751                                           int64_t ref_best_rd, BLOCK_SIZE bs) {
   2752  MACROBLOCKD *const xd = &x->e_mbd;
   2753  MB_MODE_INFO *const mbmi = xd->mi[0];
   2754 
   2755  mbmi->tx_size = TX_4X4;
   2756  // TODO(any) : Pass this_rd based on skip/non-skip cost
   2757  const int skip_trellis = 0;
   2758  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
   2759                       FTXS_NONE, skip_trellis);
   2760 }
   2761 
   2762 #if !CONFIG_REALTIME_ONLY
   2763 static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
   2764                                            int blk_col, BLOCK_SIZE bsize,
   2765                                            TX_SIZE tx_size) {
   2766  const MACROBLOCKD *const xd = &x->e_mbd;
   2767  const MB_MODE_INFO *const mbmi = xd->mi[0];
   2768 
   2769  // Disable the pruning logic using NN model for the following cases:
   2770  // 1) Lossless coding as only 4x4 transform is evaluated in this case
   2771  // 2) When transform and current block sizes do not match as the features are
   2772  // obtained over the current block
   2773  // 3) When operating bit-depth is not 8-bit as the input features are not
   2774  // scaled according to bit-depth.
   2775  if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
   2776      xd->bd != 8)
   2777    return;
   2778 
   2779  // Currently NN model based pruning is supported only when largest transform
   2780  // size is 8x8
   2781  if (tx_size != TX_8X8) return;
   2782 
   2783  // Neural network model is a sequential neural net and was trained using SGD
   2784  // optimizer. The model can be further improved in terms of speed/quality by
   2785  // considering the following experiments:
   2786  // 1) Generate ML model by training with balanced data for different learning
   2787  // rates and optimizers.
   2788  // 2) Experiment with ML model by adding features related to the statistics of
   2789  // top and left pixels to capture the accuracy of reconstructed neighbouring
   2790  // pixels for 4x4 blocks numbered 1, 2, 3 in 8x8 block, source variance of 4x4
   2791  // sub-blocks, etc.
   2792  // 3) Generate ML models for transform blocks other than 8x8.
   2793  const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
   2794  const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;
   2795 
   2796  float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
   2797  const int diff_stride = block_size_wide[bsize];
   2798 
   2799  const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
   2800                        MI_SIZE * blk_col;
   2801  const int bw = tx_size_wide[tx_size];
   2802  const int bh = tx_size_high[tx_size];
   2803 
   2804  int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);
   2805 
   2806  features[feature_idx++] = log1pf((float)x->source_variance);
   2807 
   2808  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
   2809  const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f);
   2810  features[feature_idx++] = log_dc_q_square;
   2811  assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
   2812  for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
   2813    features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
   2814                  av1_intra_tx_split_8x8_std[i];
   2815  }
   2816 
   2817  float score;
   2818  av1_nn_predict(features, nn_config, 1, &score);
   2819 
   2820  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
   2821  if (score <= intra_tx_prune_thresh[0])
   2822    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
   2823  else if (score > intra_tx_prune_thresh[1])
   2824    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
   2825 }
   2826 #endif  // !CONFIG_REALTIME_ONLY
   2827 
   2828 /*!\brief Transform type search for luma macroblock with fixed transform size.
   2829 *
   2830 * \ingroup transform_search
   2831 * Search for the best transform type and return the transform coefficients RD
   2832 * cost of current luma macroblock with the given uniform transform size.
   2833 *
   2834 * \param[in]    x              Pointer to structure holding the data for the
   2835                                current encoding macroblock
   2836 * \param[in]    cpi            Top-level encoder structure
   2837 * \param[in]    rd_stats       Pointer to struct to keep track of the RD stats
   2838 * \param[in]    ref_best_rd    Best RD cost seen for this block so far
   2839 * \param[in]    bs             Size of the current macroblock
   2840 * \param[in]    tx_size        The given transform size
   2841 * \param[in]    ftxs_mode      Transform search mode specifying desired speed
   2842                                and quality tradeoff
   2843 * \param[in]    skip_trellis   Binary flag indicating if trellis optimization
   2844                                should be skipped
   2845 * \return       An int64_t value that is the best RD cost found.
   2846 */
   2847 static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
   2848                                RD_STATS *rd_stats, int64_t ref_best_rd,
   2849                                BLOCK_SIZE bs, TX_SIZE tx_size,
   2850                                FAST_TX_SEARCH_MODE ftxs_mode,
   2851                                int skip_trellis) {
   2852  assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
   2853  MACROBLOCKD *const xd = &x->e_mbd;
   2854  MB_MODE_INFO *const mbmi = xd->mi[0];
   2855  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
   2856  const ModeCosts *mode_costs = &x->mode_costs;
   2857  const int is_inter = is_inter_block(mbmi);
   2858  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
   2859                        block_signals_txsize(mbmi->bsize);
   2860  int tx_size_rate = 0;
   2861  if (tx_select) {
   2862    const int ctx = txfm_partition_context(
   2863        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
   2864    tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
   2865                            : tx_size_cost(x, bs, tx_size);
   2866  }
   2867  const int skip_ctx = av1_get_skip_txfm_context(xd);
   2868  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
   2869  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
   2870  const int64_t skip_txfm_rd =
   2871      is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
   2872  const int64_t no_this_rd =
   2873      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
   2874 
   2875  mbmi->tx_size = tx_size;
   2876  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
   2877                       AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
   2878                       tx_size, ftxs_mode, skip_trellis);
   2879  if (rd_stats->rate == INT_MAX) return INT64_MAX;
   2880 
   2881  int64_t rd;
   2882  // rdstats->rate should include all the rate except skip/non-skip cost as the
   2883  // same is accounted in the caller functions after rd evaluation of all
   2884  // planes. However the decisions should be done after considering the
   2885  // skip/non-skip header cost
   2886  if (rd_stats->skip_txfm && is_inter) {
   2887    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
   2888  } else {
   2889    // Intra blocks are always signalled as non-skip
   2890    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
   2891                rd_stats->dist);
   2892    rd_stats->rate += tx_size_rate;
   2893  }
   2894  // Check if forcing the block to skip transform leads to smaller RD cost.
   2895  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
   2896    int64_t temp_skip_txfm_rd =
   2897        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
   2898    if (temp_skip_txfm_rd <= rd) {
   2899      rd = temp_skip_txfm_rd;
   2900      rd_stats->rate = 0;
   2901      rd_stats->dist = rd_stats->sse;
   2902      rd_stats->skip_txfm = 1;
   2903    }
   2904  }
   2905 
   2906  return rd;
   2907 }
   2908 
// Search for the best uniform transform size and type for the current coding
// block. Only luma is searched, and within one candidate every transform block
// uses the same size.
//
// Starting from start_tx, each iteration evaluates one transform size with
// uniform_txfm_yrd(), then moves to the next smaller size via
// sub_tx_size_map[]. At most MAX_TX_DEPTH - init_depth + 1 sizes are tried.
// The best size, its per-block transform types (xd->tx_type_map) and its skip
// flags (txfm_info->blk_skip) are written back into the block state at the end.
//
// On return, rd_stats holds the stats of the best size found. It is left
// invalidated if no evaluated size produced an RD cost below INT64_MAX.
// NOTE(review): the final restore is guarded by rd_stats->rate != INT_MAX,
// which presumes av1_invalid_rd_stats() sets rate to INT_MAX — confirm.
static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
                                               MACROBLOCK *x,
                                               RD_STATS *rd_stats,
                                               int64_t ref_best_rd,
                                               BLOCK_SIZE bs) {
  av1_invalid_rd_stats(rd_stats);

  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
  int start_tx;
  // The split depth can be at most MAX_TX_DEPTH, so init_depth controls how
  // many splits are allowed during the RD search.
  int init_depth;

  if (tx_select) {
    // Start from the largest rectangular size for bs. The speed features
    // decide the first depth to search.
    start_tx = max_rect_tx_size;
    init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
                                       is_inter_block(mbmi), &cpi->sf,
                                       txfm_params->tx_size_search_method);
    // With init_depth == MAX_TX_DEPTH only one size is evaluated. If that size
    // has a 64 dimension while 64-point transforms are disabled, the loop
    // below would skip it and evaluate nothing, so step down one size first.
    if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
        txsize_sqr_up_map[start_tx] == TX_64X64) {
      start_tx = sub_tx_size_map[start_tx];
    }
  } else {
    // Fixed tx mode: evaluate only the size implied by the mode. Setting
    // init_depth = MAX_TX_DEPTH gives a single loop iteration.
    const TX_SIZE chosen_tx_size =
        tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
    start_tx = chosen_tx_size;
    init_depth = MAX_TX_DEPTH;
  }

  const int skip_trellis = 0;
  uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  TX_SIZE best_tx_size = max_rect_tx_size;
  int64_t best_rd = INT64_MAX;
  const int num_blks = bsize_to_num_blk(bs);
  x->rd_model = FULL_TXFM_RD;
  // Per-depth RD costs, used by the low-variance pruning at the loop end.
  int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
  for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
       depth++, tx_size = sub_tx_size_map[tx_size]) {
    // Skip sizes disabled by the encoder configuration (64-point or
    // rectangular transforms).
    if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
         txsize_sqr_up_map[tx_size] == TX_64X64) ||
        (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
         tx_size_wide[tx_size] != tx_size_high[tx_size])) {
      continue;
    }

#if !CONFIG_REALTIME_ONLY
    // The NN model run while evaluating the first size decided that the
    // smaller (split) sizes need not be searched.
    if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break;

    // Set the flag to enable the evaluation of NN classifier to prune transform
    // depths. As the features are based on intra residual information of
    // largest transform, the evaluation of NN model is enabled only for this
    // case.
    txfm_params->enable_nn_prune_intra_tx_depths =
        (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx);
#endif

    RD_STATS this_rd_stats;
    // When the speed feature use_rd_based_breakout_for_intra_tx_search is
    // enabled, use the known minimum best_rd for early termination.
    const int64_t rd_thresh =
        cpi->sf.tx_sf.use_rd_based_breakout_for_intra_tx_search
            ? AOMMIN(ref_best_rd, best_rd)
            : ref_best_rd;
    rd[depth] = uniform_txfm_yrd(cpi, x, &this_rd_stats, rd_thresh, bs, tx_size,
                                 FTXS_NONE, skip_trellis);
    // Remember the best size together with its transform types and skip
    // flags. Evaluating later sizes overwrites the live block state.
    if (rd[depth] < best_rd) {
      av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
      av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
      best_tx_size = tx_size;
      best_rd = rd[depth];
      *rd_stats = this_rd_stats;
    }
    // No smaller transform size exists.
    if (tx_size == TX_4X4) break;
    // If we are searching three depths, prune the smallest size depending
    // on rd results for the first two depths for low contrast blocks.
    if (depth > init_depth && depth != MAX_TX_DEPTH &&
        x->source_variance < 256) {
      if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
    }
  }

  // Restore the state of the best size found (if any).
  if (rd_stats->rate != INT_MAX) {
    mbmi->tx_size = best_tx_size;
    av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
    av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
  }

#if !CONFIG_REALTIME_ONLY
  // Reset the flags to avoid any unintentional evaluation of NN model and
  // consumption of prune depths.
  txfm_params->enable_nn_prune_intra_tx_depths = false;
  txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE;
#endif
}
   3010 
// Searches the best transform type for one transform block of the given
// plane. The result is accumulated into the struct rdcost_block_args that arg
// points to:
//  - the block's RD stats are merged into args->rd_stats;
//  - its RD cost is added to args->current_rd;
//  - args->exit_early is set once current_rd exceeds args->best_rd.
// When called after exit_early has already been set, it only marks
// args->incomplete_exit and returns.
// NOTE(review): the (plane, block, blk_row, blk_col, plane_bsize, tx_size,
// void *arg) signature suggests a per-transform-block visitor callback —
// confirm against its call site.
static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
                                 BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                                 void *arg) {
  struct rdcost_block_args *args = arg;
  // An earlier block already exceeded the RD budget: record that this block
  // was not evaluated.
  if (args->exit_early) {
    args->incomplete_exit = 1;
    return;
  }

  MACROBLOCK *const x = args->x;
  MACROBLOCKD *const xd = &x->e_mbd;
  const int is_inter = is_inter_block(xd->mi[0]);
  const AV1_COMP *cpi = args->cpi;
  ENTROPY_CONTEXT *a = args->t_above + blk_col;
  ENTROPY_CONTEXT *l = args->t_left + blk_row;
  const AV1_COMMON *cm = &cpi->common;
  RD_STATS this_rd_stats;
  av1_init_rd_stats(&this_rd_stats);

  if (!is_inter) {
    // Intra: build this transform block's prediction and residual before
    // searching its transform type.
    av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
    av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
#if !CONFIG_REALTIME_ONLY
    const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
    if (txfm_params->enable_nn_prune_intra_tx_depths) {
      ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize,
                                      tx_size);
      // The model decided the largest transform depth can be pruned: abandon
      // this size by invalidating the accumulated stats and stopping.
      if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) {
        av1_invalid_rd_stats(&args->rd_stats);
        args->exit_early = 1;
        return;
      }
    }
#endif
  }

  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
  // The RD budget left for this block is best_rd - current_rd.
  search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                 &txb_ctx, args->ftxs_mode, args->skip_trellis,
                 args->best_rd - args->current_rd, &this_rd_stats);

#if !CONFIG_REALTIME_ONLY
  // Store the luma transform block for chroma-from-luma when requested.
  if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
    assert(!is_inter || plane_bsize < BLOCK_8X8);
    cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
  }
#endif

#if CONFIG_RD_DEBUG
  update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
#endif  // CONFIG_RD_DEBUG
  av1_set_txb_context(x, plane, block, tx_size, a, l);

  // Row-major index of this transform block in blk_skip, in MI units of the
  // plane block's width.
  const int blk_idx =
      blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;

  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
  // Luma: mark the block as skippable when it ended up with no coefficients
  // (eob == 0). Chroma entries are always cleared.
  if (plane == 0)
    set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
                 x->plane[plane].eobs[block] == 0);
  else
    set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);

  int64_t rd;
  if (is_inter) {
    // Inter blocks may be coded as skip: take the cheaper of coding the
    // coefficients and dropping them (distortion = sse).
    const int64_t no_skip_txfm_rd =
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
    const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
    rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
    this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
  } else {
    // Signal non-skip_txfm for Intra blocks
    rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
    this_rd_stats.skip_txfm = 0;
  }

  av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);

  args->current_rd += rd;
  // Over budget: later blocks will only mark incomplete_exit.
  if (args->current_rd > args->best_rd) args->exit_early = 1;
}
   3095 
// Estimates the luma RD cost of the current block when every transform block
// uses tx_size and the DCT_DCT transform type (no transform type search).
// Coefficients come from av1_xform()/av1_quant() with AV1_XFORM_QUANT_B. Rate
// comes from cost_coeffs() and distortion from dist_block_tx_domain().
//
// Transform blocks are visited row by row. After a block pushes the
// accumulated RD cost (args.current_rd) above ref_best_rd, the walk stops. If
// the stop is noticed when starting another block (incomplete_exit), rd_stats
// is invalidated and INT64_MAX is returned. Otherwise the return value is the
// RD cost including the transform-size and skip/non-skip signaling costs;
// rd_stats->rate excludes the skip flag cost (see the comment below).
//
// Differences from uniform_txfm_yrd():
//  - the transform-size cost always uses txfm_partition_cost, even for intra
//    blocks;
//  - the skip header cost seeding current_rd is not restricted to inter
//    blocks.
//
// NOTE(review): if ref_best_rd is first exceeded on the last row of transform
// blocks, the inner loop breaks and the outer loop ends without setting
// incomplete_exit. The partially accumulated stats are then returned as valid.
// Confirm that callers reject the resulting cost, which is then expected to
// exceed ref_best_rd.
int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                              RD_STATS *rd_stats, int64_t ref_best_rd,
                              BLOCK_SIZE bs, TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const ModeCosts *mode_costs = &x->mode_costs;
  const int is_inter = is_inter_block(mbmi);
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
                        block_signals_txsize(mbmi->bsize);
  int tx_size_rate = 0;
  if (tx_select) {
    const int ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
    tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
  }
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
  // Header-only RD costs. The cheaper one seeds the running RD total.
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
  const int64_t no_this_rd =
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
  mbmi->tx_size = tx_size;

  const uint8_t txw_unit = tx_size_wide_unit[tx_size];
  const uint8_t txh_unit = tx_size_high_unit[tx_size];
  // Advance of the block index i per transform block.
  const int step = txw_unit * txh_unit;
  const int max_blocks_wide = max_block_wide(xd, bs, 0);
  const int max_blocks_high = max_block_high(xd, bs, 0);

  struct rdcost_block_args args;
  av1_zero(args);
  args.x = x;
  args.cpi = cpi;
  args.best_rd = ref_best_rd;
  args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
  av1_init_rd_stats(&args.rd_stats);
  av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
  int i = 0;
  for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
       blk_row += txh_unit) {
    for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
      RD_STATS this_rd_stats;
      av1_init_rd_stats(&this_rd_stats);

      // The previous row ended because the budget was exceeded: this block
      // is left unevaluated, so the result is incomplete.
      if (args.exit_early) {
        args.incomplete_exit = 1;
        break;
      }

      ENTROPY_CONTEXT *a = args.t_above + blk_col;
      ENTROPY_CONTEXT *l = args.t_left + blk_row;
      TXB_CTX txb_ctx;
      get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);

      // Forward transform and quantize with DCT_DCT only.
      TxfmParam txfm_param;
      QUANT_PARAM quant_param;
      av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
      av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);

      av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
      av1_quant(x, 0, i, &txfm_param, &quant_param);

      this_rd_stats.rate =
          cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);

      const SCAN_ORDER *const scan_order =
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
      dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix,
                           scan_order->scan, &this_rd_stats.dist,
                           &this_rd_stats.sse);

      // Per-block cost is the cheaper of coding the coefficients and
      // dropping them (distortion = sse).
      const int64_t no_skip_txfm_rd =
          RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
      const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);

      this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];

      av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
      args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);

      // Over budget: stop this row. The next row (if any) marks the result
      // incomplete. The entropy context of this block is not updated.
      if (args.current_rd > ref_best_rd) {
        args.exit_early = 1;
        break;
      }

      av1_set_txb_context(x, 0, i, tx_size, a, l);
      i += step;
    }
  }

  if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);

  *rd_stats = args.rd_stats;
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  int64_t rd;
  // rdstats->rate should include all the rate except skip/non-skip cost as the
  // same is accounted in the caller functions after rd evaluation of all
  // planes. However the decisions should be done after considering the
  // skip/non-skip header cost
  if (rd_stats->skip_txfm && is_inter) {
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  } else {
    // Intra blocks are always signalled as non-skip
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
                rd_stats->dist);
    rd_stats->rate += tx_size_rate;
  }
  // Check if forcing the block to skip transform leads to smaller RD cost.
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
    int64_t temp_skip_txfm_rd =
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
    if (temp_skip_txfm_rd <= rd) {
      rd = temp_skip_txfm_rd;
      rd_stats->rate = 0;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
    }
  }

  return rd;
}
   3219 
// Search for the best transform type for a luma inter-predicted block, given
// the transform block partitions.
// This function is used only when some speed features are enabled.
//
// Recursively walks the transform partition rooted at (blk_row, blk_col) with
// size tx_size:
//  - If tx_size equals the size stored in mbmi->inter_tx_size for this
//    position, the transform type is searched with tx_type_rd(). The result is
//    compared against coding the block as all-zero (zero_blk_rate, sse).
//  - Otherwise the area is split into sub_tx_size_map[tx_size] blocks, each
//    processed recursively at depth + 1 with the remaining RD budget.
// The split/no-split signaling cost (txfm_partition_cost) is added when
// tx_size > TX_4X4 and depth < MAX_VARTX_DEPTH. rd_stats is invalidated if
// any sub-block returns an invalid (INT_MAX) rate. above_ctx/left_ctx and
// tx_above/tx_left are updated in place as blocks are finalized.
static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                                int blk_col, int block, TX_SIZE tx_size,
                                BLOCK_SIZE plane_bsize, int depth,
                                ENTROPY_CONTEXT *above_ctx,
                                ENTROPY_CONTEXT *left_ctx,
                                TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
                                int64_t ref_best_rd, RD_STATS *rd_stats,
                                FAST_TX_SEARCH_MODE ftxs_mode) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(is_inter_block(mbmi));
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);

  // The position lies outside the visible part of the block: nothing to do.
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
      plane_bsize, blk_row, blk_col)];
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                         mbmi->bsize, tx_size);

  av1_init_rd_stats(rd_stats);
  if (tx_size == plane_tx_size) {
    // Leaf of the decided partition: search the transform type.
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx;
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);

    // Rate of signaling this transform block as having no coefficients.
    const int zero_blk_rate =
        x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
            .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
    rd_stats->zero_rate = zero_blk_rate;
    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
               rd_stats, ftxs_mode, ref_best_rd);
    const int mi_width = mi_size_wide[plane_bsize];
    TxfmSearchInfo *txfm_info = &x->txfm_search_info;
    // Code the block as all-zero if that is at least as cheap, or if the type
    // search already marked it skipped. The eob, entropy context and transform
    // type (DCT_DCT) are then reset to match an all-zero block.
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
        rd_stats->skip_txfm == 1) {
      rd_stats->rate = zero_blk_rate;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
      x->plane[0].eobs[block] = 0;
      x->plane[0].txb_entropy_ctx[block] = 0;
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    } else {
      rd_stats->skip_txfm = 0;
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
    }
    // Cost of signaling "no further split" where a split would be codable.
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
  } else {
    // The decided partition is finer here: recurse into the sub-blocks.
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
    const int txb_width = tx_size_wide_unit[sub_txs];
    const int txb_height = tx_size_high_unit[sub_txs];
    const int step = txb_height * txb_width;
    // Clip the sub-block walk to the visible area.
    const int row_end =
        AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
    const int col_end =
        AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
    RD_STATS pn_rd_stats;
    // RD cost consumed so far by the sub-blocks; shrinks their budget.
    int64_t this_rd = 0;
    assert(txb_width > 0 && txb_height > 0);

    for (int row = 0; row < row_end; row += txb_height) {
      const int offsetr = blk_row + row;
      for (int col = 0; col < col_end; col += txb_width) {
        const int offsetc = blk_col + col;

        av1_init_rd_stats(&pn_rd_stats);
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
        if (pn_rd_stats.rate == INT_MAX) {
          av1_invalid_rd_stats(rd_stats);
          return;
        }
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
        block += step;
      }
    }

    // Cost of signaling the split.
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
  }
}
   3316 
   3317 // search for tx type with tx sizes already decided for a inter-predicted luma
   3318 // partition block. It's used only when some speed features are enabled.
   3319 // Return value 0: early termination triggered, no valid rd cost available;
   3320 //              1: rd cost values are valid.
   3321 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
   3322                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
   3323                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
   3324  if (ref_best_rd < 0) {
   3325    av1_invalid_rd_stats(rd_stats);
   3326    return 0;
   3327  }
   3328 
   3329  av1_init_rd_stats(rd_stats);
   3330 
   3331  MACROBLOCKD *const xd = &x->e_mbd;
   3332  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
   3333  const struct macroblockd_plane *const pd = &xd->plane[0];
   3334  const int mi_width = mi_size_wide[bsize];
   3335  const int mi_height = mi_size_high[bsize];
   3336  const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
   3337  const int bh = tx_size_high_unit[max_tx_size];
   3338  const int bw = tx_size_wide_unit[max_tx_size];
   3339  const int step = bw * bh;
   3340  const int init_depth = get_search_init_depth(
   3341      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
   3342  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
   3343  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
   3344  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
   3345  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
   3346  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
   3347  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
   3348  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
   3349 
   3350  int64_t this_rd = 0;
   3351  for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
   3352    for (int idx = 0; idx < mi_width; idx += bw) {
   3353      RD_STATS pn_rd_stats;
   3354      av1_init_rd_stats(&pn_rd_stats);
   3355      tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
   3356                   ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
   3357                   &pn_rd_stats, ftxs_mode);
   3358      if (pn_rd_stats.rate == INT_MAX) {
   3359        av1_invalid_rd_stats(rd_stats);
   3360        return 0;
   3361      }
   3362      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
   3363      this_rd +=
   3364          AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
   3365                 RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
   3366      block += step;
   3367    }
   3368  }
   3369 
   3370  const int skip_ctx = av1_get_skip_txfm_context(xd);
   3371  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
   3372  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
   3373  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
   3374  this_rd =
   3375      RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
   3376  if (skip_txfm_rd < this_rd) {
   3377    this_rd = skip_txfm_rd;
   3378    rd_stats->rate = 0;
   3379    rd_stats->dist = rd_stats->sse;
   3380    rd_stats->skip_txfm = 1;
   3381  }
   3382 
   3383  const int is_cost_valid = this_rd > ref_best_rd;
   3384  if (!is_cost_valid) {
   3385    // reset cost value
   3386    av1_invalid_rd_stats(rd_stats);
   3387  }
   3388  return is_cost_valid;
   3389 }
   3390 
// Search for the best transform size and type for the current inter-predicted
// luma block with recursive transform block partitioning. The block is tiled
// in units of max_txsize_rect_lookup[bsize], and select_tx_block() searches
// each unit. The obtained transform selection is saved in xd->mi[0] by the
// callees; the accumulated luma RD stats are saved in rd_stats.
//
// ref_best_rd is the RD cost to beat: 0 makes the search fail immediately and
// INT64_MAX leaves it unbounded. Returns the final RD cost, including the cost
// of the skip_txfm flag, or INT64_MAX when no acceptable candidate was found
// (rd_stats is then left in an invalid state).
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  assert(is_inter_block(xd->mi[0]));
  assert(bsize < BLOCK_SIZES_ALL);
  // Any tx_size_search_method beyond USE_FULL_RD selects the fast search,
  // which only evaluates DCT and 1D-DCT transform types (see ftxs_mode).
  const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
  int64_t rd_thresh = ref_best_rd;
  // A zero budget cannot be beaten.
  if (rd_thresh == 0) {
    av1_invalid_rd_stats(rd_stats);
    return INT64_MAX;
  }
  // For the fast search, relax a finite threshold by 1/8, skipping the
  // increment if it would overflow int64_t. NOTE(review): presumably because
  // the restricted tx-type set overestimates the cost and the refinement
  // below may bring it back under ref_best_rd -- confirm.
  if (fast_tx_search && rd_thresh < INT64_MAX) {
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
  }
  assert(rd_thresh > 0);
  const FAST_TX_SEARCH_MODE ftxs_mode =
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
  const struct macroblockd_plane *const pd = &xd->plane[0];
  assert(bsize < BLOCK_SIZES_ALL);
  const int mi_width = mi_size_wide[bsize];
  const int mi_height = mi_size_high[bsize];
  // Work on local copies of the entropy and transform-size contexts;
  // select_tx_block() receives these copies.
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
  const int init_depth = get_search_init_depth(
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
  // The block is visited in units of the largest rectangular transform size.
  // `step` is the advance of the block index per unit.
  const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
  const int bh = tx_size_high_unit[max_tx_size];
  const int bw = tx_size_wide_unit[max_tx_size];
  const int step = bw * bh;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
  // Running RD costs of signalling the block as skipped (distortion = SSE)
  // versus coding it. Initially these are just the flag costs.
  int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
  int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
  int block = 0;

  av1_init_rd_stats(rd_stats);
  for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
    for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
      // Budget left for this unit: the threshold minus the cheaper of the
      // skip / no-skip costs accumulated over the previous units.
      const int64_t best_rd_sofar =
          (rd_thresh == INT64_MAX)
              ? INT64_MAX
              : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
      int is_cost_valid = 1;
      RD_STATS pn_rd_stats;
      // Search for the best transform block size and type for the sub-block.
      select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
                      ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
                      best_rd_sofar, &is_cost_valid, ftxs_mode);
      if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
        av1_invalid_rd_stats(rd_stats);
        return INT64_MAX;
      }
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
      skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
      no_skip_txfm_rd =
          RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
      block += step;
    }
  }

  // Guard against an invalid accumulated rate.
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  // Ties favour skipping the transform.
  rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);

  // If fast_tx_search is true, only DCT and 1D DCT were tested in
  // select_tx_block() above. Do a better search for tx type with
  // tx sizes already decided.
  if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
      return INT64_MAX;
  }

  // Final cost of the chosen skip decision. For a coded (non-skip) block, the
  // skip cost still caps the result unless the segment is lossless.
  int64_t final_rd;
  if (rd_stats->skip_txfm) {
    final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
  } else {
    final_rd =
        RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
    if (!xd->lossless[xd->mi[0]->segment_id]) {
      final_rd =
          AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
    }
  }

  return final_rd;
}
   3489 
   3490 // Return 1 to terminate transform search early. The decision is made based on
   3491 // the comparison with the reference RD cost and the model-estimated RD cost.
   3492 static inline int model_based_tx_search_prune(const AV1_COMP *cpi,
   3493                                              MACROBLOCK *x, BLOCK_SIZE bsize,
   3494                                              int64_t ref_best_rd) {
   3495  const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
   3496  assert(level >= 0 && level <= 2);
   3497  int model_rate;
   3498  int64_t model_dist;
   3499  uint8_t model_skip;
   3500  MACROBLOCKD *const xd = &x->e_mbd;
   3501  model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
   3502      cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
   3503      NULL, NULL, NULL);
   3504  if (model_skip) return 0;
   3505  const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
   3506  // TODO(debargha, urvang): Improve the model and make the check below
   3507  // tighter.
   3508  static const int prune_factor_by8[] = { 3, 5 };
   3509  const int factor = prune_factor_by8[level - 1];
   3510  return ((model_rd * factor) >> 3) > ref_best_rd;
   3511 }
   3512 
   3513 void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
   3514                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
   3515                                         int64_t ref_best_rd) {
   3516  MACROBLOCKD *const xd = &x->e_mbd;
   3517  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
   3518  assert(is_inter_block(xd->mi[0]));
   3519 
   3520  av1_invalid_rd_stats(rd_stats);
   3521 
   3522  // If modeled RD cost is a lot worse than the best so far, terminate early.
   3523  if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
   3524      ref_best_rd != INT64_MAX) {
   3525    if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
   3526  }
   3527 
   3528  // Hashing based speed feature. If the hash of the prediction residue block is
   3529  // found in the hash table, use previous search results and terminate early.
   3530  uint32_t hash = 0;
   3531  MB_RD_RECORD *mb_rd_record = NULL;
   3532  const int mi_row = x->e_mbd.mi_row;
   3533  const int mi_col = x->e_mbd.mi_col;
   3534  const int within_border =
   3535      mi_row >= xd->tile.mi_row_start &&
   3536      (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
   3537      mi_col >= xd->tile.mi_col_start &&
   3538      (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
   3539  const int is_mb_rd_hash_enabled =
   3540      (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
   3541  const int n4 = bsize_to_num_blk(bsize);
   3542  if (is_mb_rd_hash_enabled) {
   3543    hash = get_block_residue_hash(x, bsize);
   3544    mb_rd_record = x->txfm_search_info.mb_rd_record;
   3545    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
   3546    if (match_index != -1) {
   3547      MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
   3548      fetch_mb_rd_info(n4, mb_rd_info, rd_stats, x);
   3549      return;
   3550    }
   3551  }
   3552 
   3553  // If we predict that skip is the optimal RD decision - set the respective
   3554  // context and terminate early.
   3555  int64_t dist;
   3556  if (txfm_params->skip_txfm_level &&
   3557      predict_skip_txfm(x, bsize, &dist,
   3558                        cpi->common.features.reduced_tx_set_used)) {
   3559    set_skip_txfm(x, rd_stats, bsize, dist);
   3560    // Save the RD search results into mb_rd_record.
   3561    if (is_mb_rd_hash_enabled)
   3562      save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
   3563    return;
   3564  }
   3565 #if CONFIG_SPEED_STATS
   3566  ++x->txfm_search_info.tx_search_count;
   3567 #endif  // CONFIG_SPEED_STATS
   3568 
   3569  const int64_t rd =
   3570      select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd);
   3571 
   3572  if (rd == INT64_MAX) {
   3573    // We should always find at least one candidate unless ref_best_rd is less
   3574    // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
   3575    // might have failed to find something better)
   3576    assert(ref_best_rd != INT64_MAX);
   3577    av1_invalid_rd_stats(rd_stats);
   3578    return;
   3579  }
   3580 
   3581  // Save the RD search results into mb_rd_record.
   3582  if (is_mb_rd_hash_enabled) {
   3583    assert(mb_rd_record != NULL);
   3584    save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
   3585  }
   3586 }
   3587 
   3588 void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
   3589                                       RD_STATS *rd_stats, BLOCK_SIZE bs,
   3590                                       int64_t ref_best_rd) {
   3591  MACROBLOCKD *const xd = &x->e_mbd;
   3592  MB_MODE_INFO *const mbmi = xd->mi[0];
   3593  const TxfmSearchParams *tx_params = &x->txfm_search_params;
   3594  assert(bs == mbmi->bsize);
   3595  const int is_inter = is_inter_block(mbmi);
   3596  const int mi_row = xd->mi_row;
   3597  const int mi_col = xd->mi_col;
   3598 
   3599  av1_init_rd_stats(rd_stats);
   3600 
   3601  // Hashing based speed feature for inter blocks. If the hash of the residue
   3602  // block is found in the table, use previously saved search results and
   3603  // terminate early.
   3604  uint32_t hash = 0;
   3605  MB_RD_RECORD *mb_rd_record = NULL;
   3606  const int num_blks = bsize_to_num_blk(bs);
   3607  if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
   3608    const int within_border =
   3609        mi_row >= xd->tile.mi_row_start &&
   3610        (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
   3611        mi_col >= xd->tile.mi_col_start &&
   3612        (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
   3613    if (within_border) {
   3614      hash = get_block_residue_hash(x, bs);
   3615      mb_rd_record = x->txfm_search_info.mb_rd_record;
   3616      const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
   3617      if (match_index != -1) {
   3618        MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
   3619        fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x);
   3620        return;
   3621      }
   3622    }
   3623  }
   3624 
   3625  // If we predict that skip is the optimal RD decision - set the respective
   3626  // context and terminate early.
   3627  int64_t dist;
   3628  if (tx_params->skip_txfm_level && is_inter &&
   3629      !xd->lossless[mbmi->segment_id] &&
   3630      predict_skip_txfm(x, bs, &dist,
   3631                        cpi->common.features.reduced_tx_set_used)) {
   3632    // Populate rdstats as per skip decision
   3633    set_skip_txfm(x, rd_stats, bs, dist);
   3634    // Save the RD search results into mb_rd_record.
   3635    if (mb_rd_record) {
   3636      save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
   3637    }
   3638    return;
   3639  }
   3640 
   3641  if (xd->lossless[mbmi->segment_id]) {
   3642    // Lossless mode can only pick the smallest (4x4) transform size.
   3643    choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
   3644  } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
   3645    choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
   3646  } else {
   3647    choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
   3648  }
   3649 
   3650  // Save the RD search results into mb_rd_record for possible reuse in future.
   3651  if (mb_rd_record) {
   3652    save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
   3653  }
   3654 }
   3655 
   3656 int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
   3657                  BLOCK_SIZE bsize, int64_t ref_best_rd) {
   3658  av1_init_rd_stats(rd_stats);
   3659  if (ref_best_rd < 0) return 0;
   3660  if (!x->e_mbd.is_chroma_ref) return 1;
   3661 
   3662  MACROBLOCKD *const xd = &x->e_mbd;
   3663  MB_MODE_INFO *const mbmi = xd->mi[0];
   3664  struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
   3665  const int is_inter = is_inter_block(mbmi);
   3666  int64_t this_rd = 0, skip_txfm_rd = 0;
   3667  const BLOCK_SIZE plane_bsize =
   3668      get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
   3669 
   3670  if (is_inter) {
   3671    for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
   3672      av1_subtract_plane(x, plane_bsize, plane);
   3673  }
   3674 
   3675  const int skip_trellis = 0;
   3676  const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
   3677  int is_cost_valid = 1;
   3678  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
   3679    RD_STATS this_rd_stats;
   3680    int64_t chroma_ref_best_rd = ref_best_rd;
   3681    // For inter blocks, refined ref_best_rd is used for early exit
   3682    // For intra blocks, even though current rd crosses ref_best_rd, early
   3683    // exit is not recommended as current rd is used for gating subsequent
   3684    // modes as well (say, for angular modes)
   3685    // TODO(any): Extend the early exit mechanism for intra modes as well
   3686    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
   3687        chroma_ref_best_rd != INT64_MAX)
   3688      chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
   3689    av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
   3690                         plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
   3691    if (this_rd_stats.rate == INT_MAX) {
   3692      is_cost_valid = 0;
   3693      break;
   3694    }
   3695    av1_merge_rd_stats(rd_stats, &this_rd_stats);
   3696    this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
   3697    skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
   3698    if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
   3699      is_cost_valid = 0;
   3700      break;
   3701    }
   3702  }
   3703 
   3704  if (!is_cost_valid) {
   3705    // reset cost value
   3706    av1_invalid_rd_stats(rd_stats);
   3707  }
   3708 
   3709  return is_cost_valid;
   3710 }
   3711 
   3712 void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
   3713                          RD_STATS *rd_stats, int64_t ref_best_rd,
   3714                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
   3715                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
   3716                          int skip_trellis) {
   3717  assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));
   3718 
   3719  if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
   3720      txsize_sqr_up_map[tx_size] == TX_64X64) {
   3721    av1_invalid_rd_stats(rd_stats);
   3722    return;
   3723  }
   3724 
   3725  if (current_rd > ref_best_rd) {
   3726    av1_invalid_rd_stats(rd_stats);
   3727    return;
   3728  }
   3729 
   3730  MACROBLOCKD *const xd = &x->e_mbd;
   3731  const struct macroblockd_plane *const pd = &xd->plane[plane];
   3732  struct rdcost_block_args args;
   3733  av1_zero(args);
   3734  args.x = x;
   3735  args.cpi = cpi;
   3736  args.best_rd = ref_best_rd;
   3737  args.current_rd = current_rd;
   3738  args.ftxs_mode = ftxs_mode;
   3739  args.skip_trellis = skip_trellis;
   3740  av1_init_rd_stats(&args.rd_stats);
   3741 
   3742  av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
   3743  av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
   3744                                         &args);
   3745 
   3746  MB_MODE_INFO *const mbmi = xd->mi[0];
   3747  const int is_inter = is_inter_block(mbmi);
   3748  const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
   3749 
   3750  if (invalid_rd) {
   3751    av1_invalid_rd_stats(rd_stats);
   3752  } else {
   3753    *rd_stats = args.rd_stats;
   3754  }
   3755 }
   3756 
// Transform RD search for the whole block. Luma is searched first, then chroma
// when the frame has more than one plane, and finally the skip_txfm decision
// is made.
//
// Inputs:
//   bsize       - block size being searched.
//   mode_rate   - rate already committed before the transform search; it
//                 seeds rd_stats->rate.
//   ref_best_rd - RD cost to beat (INT64_MAX means unbounded).
// Outputs:
//   rd_stats    - combined stats: mode_rate, luma, chroma and skip flag cost.
//   rd_stats_y  - luma-only stats.
//   rd_stats_uv - chroma-only stats.
//   mbmi->skip_txfm is written when 1 is returned.
// Returns 0 when the search fails or an intermediate cost check exceeds
// ref_best_rd, and 1 otherwise. On a 0 return the output stats are not all
// invalidated, so callers should not rely on them.
int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  // Index 0: cost of signalling "not skipped"; index 1: cost of "skipped".
  const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
                                  x->mode_costs.skip_txfm_cost[skip_ctx][1] };
  const int64_t min_header_rate =
      mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
  // Account for minimum skip and non_skip rd.
  // Eventually either one of them will be added to mode_rate
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }

  const AV1_COMMON *cm = &cpi->common;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  // Budget left for the luma search after paying for mode_rate.
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  const int64_t rd_thresh =
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // cost and distortion
  av1_subtract_plane(x, bsize, 0);
  // Per-sub-block (recursive) tx sizes are searched only under TX_MODE_SELECT
  // for non-lossless segments. Otherwise one tx size covers the whole block.
  if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
      !xd->lossless[mbmi->segment_id]) {
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
#if CONFIG_COLLECT_RD_STATS == 2
    // NOTE(review): tile_data is neither a parameter nor a local of this
    // function, so this line presumably fails to compile when
    // CONFIG_COLLECT_RD_STATS == 2 -- confirm.
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
  } else {
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    // Propagate the uniform choice. Every inter_tx_size entry gets
    // mbmi->tx_size, and every blk_skip flag gets the luma skip decision.
    // NOTE(review): memset writes bytes, so this assumes the element type of
    // inter_tx_size is one byte wide -- confirm.
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
    for (int i = 0; i < xd->height * xd->width; ++i)
      set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
  }

  if (rd_stats_y->rate == INT_MAX) return 0;

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  // Early rejection: the cheaper of coding or skipping the luma (both costs
  // include mode_rate) already exceeds the budget.
  const int64_t non_skip_txfm_rdcosty =
      RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
  const int64_t skip_txfm_rdcosty =
      RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
  const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
  if (min_rdcosty > ref_best_rd) return 0;

  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    int64_t ref_best_chroma_rd = ref_best_rd;
    // Calculate best rd cost possible for chroma
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
        (ref_best_chroma_rd != INT64_MAX)) {
      ref_best_chroma_rd = (ref_best_chroma_rd -
                            AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty));
    }
    const int is_cost_valid_uv =
        av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
    if (!is_cost_valid_uv) return 0;
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  // Start from rd_stats->skip_txfm as accumulated by the merges above
  // (presumably set only when every merged component chose skip). For
  // non-lossless segments, also skip when coding is not cheaper than skipping.
  // mode_rate is common to both costs and is omitted; ties choose skip.
  int choose_skip_txfm = rd_stats->skip_txfm;
  if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
    const int64_t rdcost_no_skip_txfm = RDCOST(
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
        rd_stats->dist);
    const int64_t rdcost_skip_txfm =
        RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
    if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
  }
  if (choose_skip_txfm) {
    // Skipped block: no coefficient rate, and the distortion becomes the SSE.
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->rate = mode_rate + skip_txfm_cost[1];
    rd_stats->dist = rd_stats->sse;
    rd_stats_y->dist = rd_stats_y->sse;
    rd_stats_uv->dist = rd_stats_uv->sse;
    mbmi->skip_txfm = 1;
    // When the skip came from the searches themselves, re-check the final
    // skip cost (which now includes the chroma SSE) against the budget.
    if (rd_stats->skip_txfm) {
      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (tmprd > ref_best_rd) return 0;
    }
  } else {
    rd_stats->rate += skip_txfm_cost[0];
    mbmi->skip_txfm = 0;
  }

  return 1;
}