tx_search.c (154649B)
1 /* 2 * Copyright (c) 2020, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include "av1/common/cfl.h" 13 #include "av1/common/reconintra.h" 14 #include "av1/encoder/block.h" 15 #include "av1/encoder/hybrid_fwd_txfm.h" 16 #include "av1/common/idct.h" 17 #include "av1/encoder/model_rd.h" 18 #include "av1/encoder/random.h" 19 #include "av1/encoder/rdopt_utils.h" 20 #include "av1/encoder/sorting_network.h" 21 #include "av1/encoder/tx_prune_model_weights.h" 22 #include "av1/encoder/tx_search.h" 23 #include "av1/encoder/txb_rdopt.h" 24 25 #define PROB_THRESH_OFFSET_TX_TYPE 100 26 27 struct rdcost_block_args { 28 const AV1_COMP *cpi; 29 MACROBLOCK *x; 30 ENTROPY_CONTEXT t_above[MAX_MIB_SIZE]; 31 ENTROPY_CONTEXT t_left[MAX_MIB_SIZE]; 32 RD_STATS rd_stats; 33 int64_t current_rd; 34 int64_t best_rd; 35 int exit_early; 36 int incomplete_exit; 37 FAST_TX_SEARCH_MODE ftxs_mode; 38 int skip_trellis; 39 }; 40 41 typedef struct { 42 int64_t rd; 43 int txb_entropy_ctx; 44 TX_TYPE tx_type; 45 } TxCandidateInfo; 46 47 // origin_threshold * 128 / 100 48 static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = { 49 { 50 64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68, 51 68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68, 52 }, 53 { 54 88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68, 55 68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68, 56 }, 57 { 58 90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74, 59 74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74, 60 }, 61 }; 62 63 // lookup table for predict_skip_txfm 64 // int max_tx_size = 
max_txsize_rect_lookup[bsize]; 65 // if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16) 66 // max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16); 67 static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = { 68 TX_4X4, TX_4X8, TX_8X4, TX_8X8, TX_8X16, TX_16X8, 69 TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, 70 TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16, TX_16X4, 71 TX_8X8, TX_8X8, TX_16X16, TX_16X16, 72 }; 73 74 // look-up table for sqrt of number of pixels in a transform block 75 // rounded up to the nearest integer. 76 static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4, 8, 16, 32, 32, 6, 6, 77 12, 12, 23, 23, 32, 32, 8, 78 8, 16, 16, 23, 23 }; 79 80 static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) { 81 const int rows = block_size_high[bsize]; 82 const int cols = block_size_wide[bsize]; 83 const int16_t *diff = x->plane[0].src_diff; 84 const uint32_t hash = 85 av1_get_crc32c_value(&x->txfm_search_info.mb_rd_record->crc_calculator, 86 (uint8_t *)diff, 2 * rows * cols); 87 return (hash << 5) + bsize; 88 } 89 90 static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record, 91 const int64_t ref_best_rd, 92 const uint32_t hash) { 93 int32_t match_index = -1; 94 if (ref_best_rd != INT64_MAX) { 95 for (int i = 0; i < mb_rd_record->num; ++i) { 96 const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN; 97 // If there is a match in the mb_rd_record, fetch the RD decision and 98 // terminate early. 
99 if (mb_rd_record->mb_rd_info[index].hash_value == hash) { 100 match_index = index; 101 break; 102 } 103 } 104 } 105 return match_index; 106 } 107 108 static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info, 109 RD_STATS *const rd_stats, 110 MACROBLOCK *const x) { 111 MACROBLOCKD *const xd = &x->e_mbd; 112 MB_MODE_INFO *const mbmi = xd->mi[0]; 113 mbmi->tx_size = mb_rd_info->tx_size; 114 memcpy(x->txfm_search_info.blk_skip, mb_rd_info->blk_skip, 115 sizeof(mb_rd_info->blk_skip[0]) * n4); 116 av1_copy(mbmi->inter_tx_size, mb_rd_info->inter_tx_size); 117 av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4); 118 *rd_stats = mb_rd_info->rd_stats; 119 } 120 121 int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row, 122 int blk_col, const BLOCK_SIZE plane_bsize, 123 const BLOCK_SIZE tx_bsize, 124 unsigned int *block_mse_q8) { 125 int visible_rows, visible_cols; 126 const MACROBLOCKD *xd = &x->e_mbd; 127 get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL, 128 NULL, &visible_cols, &visible_rows); 129 const int diff_stride = block_size_wide[plane_bsize]; 130 const int16_t *diff = x->plane[plane].src_diff; 131 132 diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2); 133 uint64_t sse = 134 aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows); 135 if (block_mse_q8 != NULL) { 136 if (visible_cols > 0 && visible_rows > 0) 137 *block_mse_q8 = 138 (unsigned int)((256 * sse) / (visible_cols * visible_rows)); 139 else 140 *block_mse_q8 = UINT_MAX; 141 } 142 return sse; 143 } 144 145 // Computes the residual block's SSE and mean on all visible 4x4s in the 146 // transform block 147 static inline int64_t pixel_diff_stats( 148 MACROBLOCK *x, int plane, int blk_row, int blk_col, 149 const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize, 150 unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) { 151 int visible_rows, visible_cols; 152 const MACROBLOCKD *xd = &x->e_mbd; 
153 get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL, 154 NULL, &visible_cols, &visible_rows); 155 const int diff_stride = block_size_wide[plane_bsize]; 156 const int16_t *diff = x->plane[plane].src_diff; 157 158 diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2); 159 uint64_t sse = 0; 160 int sum = 0; 161 sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum); 162 if (visible_cols > 0 && visible_rows > 0) { 163 double norm_factor = 1.0 / (visible_cols * visible_rows); 164 int sign_sum = sum > 0 ? 1 : -1; 165 // Conversion to transform domain 166 *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7; 167 *per_px_mean = sign_sum * (*per_px_mean); 168 *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse)); 169 *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum)); 170 } else { 171 *block_mse_q8 = UINT_MAX; 172 } 173 return sse; 174 } 175 176 // Uses simple features on top of DCT coefficients to quickly predict 177 // whether optimal RD decision is to skip encoding the residual. 178 // The sse value is stored in dist. 179 static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist, 180 int reduced_tx_set) { 181 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 182 const int bw = block_size_wide[bsize]; 183 const int bh = block_size_high[bsize]; 184 const MACROBLOCKD *xd = &x->e_mbd; 185 const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd); 186 187 *dist = av1_pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL); 188 189 const int64_t mse = *dist / bw / bh; 190 // Normalized quantizer takes the transform upscaling factor (8 for tx size 191 // smaller than 32) into account. 192 const int16_t normalized_dc_q = dc_q >> 3; 193 const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8; 194 // For faster early skip decision, use dist to compare against threshold so 195 // that quality risk is less for the skip=1 decision. 
Otherwise, use mse 196 // since the fwd_txfm coeff checks will take care of quality 197 // TODO(any): Use dist to return 0 when skip_txfm_level is 1 198 int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse; 199 // Predict not to skip when error is larger than threshold. 200 if (pred_err > mse_thresh) return 0; 201 // Return as skip otherwise for aggressive early skip 202 else if (txfm_params->skip_txfm_level >= 2) 203 return 1; 204 205 const int max_tx_size = max_predict_sf_tx_size[bsize]; 206 const int tx_h = tx_size_high[max_tx_size]; 207 const int tx_w = tx_size_wide[max_tx_size]; 208 DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]); 209 TxfmParam param; 210 param.tx_type = DCT_DCT; 211 param.tx_size = max_tx_size; 212 param.bd = xd->bd; 213 param.is_hbd = is_cur_buf_hbd(xd); 214 param.lossless = 0; 215 param.tx_set_type = av1_get_ext_tx_set_type( 216 param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set); 217 const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2); 218 const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize]; 219 const int16_t *src_diff = x->plane[0].src_diff; 220 const int n_coeff = tx_w * tx_h; 221 const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd); 222 const uint32_t dc_thresh = max_qcoef_thresh * dc_q; 223 const uint32_t ac_thresh = max_qcoef_thresh * ac_q; 224 for (int row = 0; row < bh; row += tx_h) { 225 for (int col = 0; col < bw; col += tx_w) { 226 av1_fwd_txfm(src_diff + col, coefs, bw, ¶m); 227 // Operating on TX domain, not pixels; we want the QTX quantizers 228 const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7); 229 if (dc_coef >= dc_thresh) return 0; 230 for (int i = 1; i < n_coeff; ++i) { 231 const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7); 232 if (ac_coef >= ac_thresh) return 0; 233 } 234 } 235 src_diff += tx_h * bw; 236 } 237 return 1; 238 } 239 240 // Used to set proper context for early termination with skip = 1. 
241 static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats, 242 BLOCK_SIZE bsize, int64_t dist) { 243 MACROBLOCKD *const xd = &x->e_mbd; 244 MB_MODE_INFO *const mbmi = xd->mi[0]; 245 const int n4 = bsize_to_num_blk(bsize); 246 const TX_SIZE tx_size = max_txsize_rect_lookup[bsize]; 247 memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4); 248 memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size)); 249 mbmi->tx_size = tx_size; 250 for (int i = 0; i < n4; ++i) 251 set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1); 252 rd_stats->skip_txfm = 1; 253 if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2); 254 rd_stats->dist = rd_stats->sse = (dist << 4); 255 // Though decision is to make the block as skip based on luma stats, 256 // it is possible that block becomes non skip after chroma rd. In addition 257 // intermediate non skip costs calculated by caller function will be 258 // incorrect, if rate is set as zero (i.e., if zero_blk_rate is not 259 // accounted). Hence intermediate rate is populated to code the luma tx blks 260 // as skip, the caller function based on final rd decision (i.e., skip vs 261 // non-skip) sets the final rate accordingly. Here the rate populated 262 // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx 263 // size possible) in the current block. 
Eg: For 128*128 block, rate would be 264 // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx 265 // block as 'all zeros' 266 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE]; 267 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE]; 268 av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl); 269 ENTROPY_CONTEXT *ta = ctxa; 270 ENTROPY_CONTEXT *tl = ctxl; 271 const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size); 272 TXB_CTX txb_ctx; 273 get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx); 274 const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y] 275 .txb_skip_cost[txb_ctx.txb_skip_ctx][1]; 276 rd_stats->rate = zero_blk_rate * 277 (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) * 278 (block_size_high[bsize] >> tx_size_high_log2[tx_size]); 279 } 280 281 static inline void save_mb_rd_info(int n4, uint32_t hash, 282 const MACROBLOCK *const x, 283 const RD_STATS *const rd_stats, 284 MB_RD_RECORD *mb_rd_record) { 285 int index; 286 if (mb_rd_record->num < RD_RECORD_BUFFER_LEN) { 287 index = 288 (mb_rd_record->index_start + mb_rd_record->num) % RD_RECORD_BUFFER_LEN; 289 ++mb_rd_record->num; 290 } else { 291 index = mb_rd_record->index_start; 292 mb_rd_record->index_start = 293 (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN; 294 } 295 MB_RD_INFO *const mb_rd_info = &mb_rd_record->mb_rd_info[index]; 296 const MACROBLOCKD *const xd = &x->e_mbd; 297 const MB_MODE_INFO *const mbmi = xd->mi[0]; 298 mb_rd_info->hash_value = hash; 299 mb_rd_info->tx_size = mbmi->tx_size; 300 memcpy(mb_rd_info->blk_skip, x->txfm_search_info.blk_skip, 301 sizeof(mb_rd_info->blk_skip[0]) * n4); 302 av1_copy(mb_rd_info->inter_tx_size, mbmi->inter_tx_size); 303 av1_copy_array(mb_rd_info->tx_type_map, xd->tx_type_map, n4); 304 mb_rd_info->rd_stats = *rd_stats; 305 } 306 307 static int get_search_init_depth(int mi_width, int mi_height, int is_inter, 308 const SPEED_FEATURES *sf, 309 int tx_size_search_method) { 310 if (tx_size_search_method == USE_LARGESTALL) return 
MAX_VARTX_DEPTH; 311 312 if (sf->tx_sf.tx_size_search_lgr_block) { 313 if (mi_width > mi_size_wide[BLOCK_64X64] || 314 mi_height > mi_size_high[BLOCK_64X64]) 315 return MAX_VARTX_DEPTH; 316 } 317 318 if (is_inter) { 319 return (mi_height != mi_width) 320 ? sf->tx_sf.inter_tx_size_search_init_depth_rect 321 : sf->tx_sf.inter_tx_size_search_init_depth_sqr; 322 } else { 323 return (mi_height != mi_width) 324 ? sf->tx_sf.intra_tx_size_search_init_depth_rect 325 : sf->tx_sf.intra_tx_size_search_init_depth_sqr; 326 } 327 } 328 329 static inline void select_tx_block( 330 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, 331 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta, 332 ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, 333 RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd, 334 int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode); 335 336 // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values 337 // 0: Do not collect any RD stats 338 // 1: Collect RD stats for transform units 339 // 2: Collect RD stats for partition units 340 #if CONFIG_COLLECT_RD_STATS 341 342 static inline void get_energy_distribution_fine( 343 const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride, 344 const uint8_t *dst, int dst_stride, int need_4th, double *hordist, 345 double *verdist) { 346 const int bw = block_size_wide[bsize]; 347 const int bh = block_size_high[bsize]; 348 unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; 349 350 if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) { 351 // Special cases: calculate 'esq' values manually, as we don't have 'vf' 352 // functions for the 16 (very small) sub-blocks of this block. 353 const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3; 354 const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 
2 : 3; 355 assert(bw <= 32); 356 assert(bh <= 32); 357 assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15); 358 if (cpi->common.seq_params->use_highbitdepth) { 359 const uint16_t *src16 = CONVERT_TO_SHORTPTR(src); 360 const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst); 361 for (int i = 0; i < bh; ++i) 362 for (int j = 0; j < bw; ++j) { 363 const int index = (j >> w_shift) + ((i >> h_shift) << 2); 364 esq[index] += 365 (src16[j + i * src_stride] - dst16[j + i * dst_stride]) * 366 (src16[j + i * src_stride] - dst16[j + i * dst_stride]); 367 } 368 } else { 369 for (int i = 0; i < bh; ++i) 370 for (int j = 0; j < bw; ++j) { 371 const int index = (j >> w_shift) + ((i >> h_shift) << 2); 372 esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) * 373 (src[j + i * src_stride] - dst[j + i * dst_stride]); 374 } 375 } 376 } else { // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks. 377 const int f_index = 378 (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16; 379 assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL); 380 const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index; 381 assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]); 382 assert(block_size_high[bsize] == 4 * block_size_high[subsize]); 383 cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]); 384 cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, 385 dst_stride, &esq[1]); 386 cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, 387 dst_stride, &esq[2]); 388 cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, 389 dst_stride, &esq[3]); 390 src += bh / 4 * src_stride; 391 dst += bh / 4 * dst_stride; 392 393 cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]); 394 cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, 395 dst_stride, &esq[5]); 396 cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, 397 dst_stride, &esq[6]); 398 
cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, 399 dst_stride, &esq[7]); 400 src += bh / 4 * src_stride; 401 dst += bh / 4 * dst_stride; 402 403 cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]); 404 cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, 405 dst_stride, &esq[9]); 406 cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, 407 dst_stride, &esq[10]); 408 cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, 409 dst_stride, &esq[11]); 410 src += bh / 4 * src_stride; 411 dst += bh / 4 * dst_stride; 412 413 cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]); 414 cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, 415 dst_stride, &esq[13]); 416 cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, 417 dst_stride, &esq[14]); 418 cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, 419 dst_stride, &esq[15]); 420 } 421 422 double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] + 423 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] + 424 esq[12] + esq[13] + esq[14] + esq[15]; 425 if (total > 0) { 426 const double e_recip = 1.0 / total; 427 hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip; 428 hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip; 429 hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip; 430 if (need_4th) { 431 hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip; 432 } 433 verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip; 434 verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip; 435 verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip; 436 if (need_4th) { 437 verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip; 438 } 439 } else { 440 hordist[0] = verdist[0] = 0.25; 441 hordist[1] = verdist[1] = 0.25; 442 
hordist[2] = verdist[2] = 0.25; 443 if (need_4th) { 444 hordist[3] = verdist[3] = 0.25; 445 } 446 } 447 } 448 449 static double get_sse_norm(const int16_t *diff, int stride, int w, int h) { 450 double sum = 0.0; 451 for (int j = 0; j < h; ++j) { 452 for (int i = 0; i < w; ++i) { 453 const int err = diff[j * stride + i]; 454 sum += err * err; 455 } 456 } 457 assert(w > 0 && h > 0); 458 return sum / (w * h); 459 } 460 461 static double get_sad_norm(const int16_t *diff, int stride, int w, int h) { 462 double sum = 0.0; 463 for (int j = 0; j < h; ++j) { 464 for (int i = 0; i < w; ++i) { 465 sum += abs(diff[j * stride + i]); 466 } 467 } 468 assert(w > 0 && h > 0); 469 return sum / (w * h); 470 } 471 472 static inline void get_2x2_normalized_sses_and_sads( 473 const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src, 474 int src_stride, const uint8_t *const dst, int dst_stride, 475 const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr, 476 double *const sad_norm_arr) { 477 const BLOCK_SIZE tx_bsize_half = 478 get_partition_subsize(tx_bsize, PARTITION_SPLIT); 479 if (tx_bsize_half == BLOCK_INVALID) { // manually calculate stats 480 const int half_width = block_size_wide[tx_bsize] / 2; 481 const int half_height = block_size_high[tx_bsize] / 2; 482 for (int row = 0; row < 2; ++row) { 483 for (int col = 0; col < 2; ++col) { 484 const int16_t *const this_src_diff = 485 src_diff + row * half_height * diff_stride + col * half_width; 486 if (sse_norm_arr) { 487 sse_norm_arr[row * 2 + col] = 488 get_sse_norm(this_src_diff, diff_stride, half_width, half_height); 489 } 490 if (sad_norm_arr) { 491 sad_norm_arr[row * 2 + col] = 492 get_sad_norm(this_src_diff, diff_stride, half_width, half_height); 493 } 494 } 495 } 496 } else { // use function pointers to calculate stats 497 const int half_width = block_size_wide[tx_bsize_half]; 498 const int half_height = block_size_high[tx_bsize_half]; 499 const int num_samples_half = half_width * half_height; 
500 for (int row = 0; row < 2; ++row) { 501 for (int col = 0; col < 2; ++col) { 502 const uint8_t *const this_src = 503 src + row * half_height * src_stride + col * half_width; 504 const uint8_t *const this_dst = 505 dst + row * half_height * dst_stride + col * half_width; 506 507 if (sse_norm_arr) { 508 unsigned int this_sse; 509 cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst, 510 dst_stride, &this_sse); 511 sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half; 512 } 513 514 if (sad_norm_arr) { 515 const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf( 516 this_src, src_stride, this_dst, dst_stride); 517 sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half; 518 } 519 } 520 } 521 } 522 } 523 524 #if CONFIG_COLLECT_RD_STATS == 1 525 static double get_mean(const int16_t *diff, int stride, int w, int h) { 526 double sum = 0.0; 527 for (int j = 0; j < h; ++j) { 528 for (int i = 0; i < w; ++i) { 529 sum += diff[j * stride + i]; 530 } 531 } 532 assert(w > 0 && h > 0); 533 return sum / (w * h); 534 } 535 static inline void PrintTransformUnitStats( 536 const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats, 537 int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size, 538 TX_TYPE tx_type, int64_t rd) { 539 if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return; 540 541 // Generate small sample to restrict output size. 
542 static unsigned int seed = 21743; 543 if (lcg_rand16(&seed) % 256 > 0) return; 544 545 const char output_file[] = "tu_stats.txt"; 546 FILE *fout = fopen(output_file, "a"); 547 if (!fout) return; 548 549 const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size]; 550 const MACROBLOCKD *const xd = &x->e_mbd; 551 const int plane = 0; 552 struct macroblock_plane *const p = &x->plane[plane]; 553 const struct macroblockd_plane *const pd = &xd->plane[plane]; 554 const int txw = tx_size_wide[tx_size]; 555 const int txh = tx_size_high[tx_size]; 556 const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3; 557 const int q_step = p->dequant_QTX[1] >> dequant_shift; 558 const int num_samples = txw * txh; 559 560 const double rate_norm = (double)rd_stats->rate / num_samples; 561 const double dist_norm = (double)rd_stats->dist / num_samples; 562 563 fprintf(fout, "%g %g", rate_norm, dist_norm); 564 565 const int src_stride = p->src.stride; 566 const uint8_t *const src = 567 &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2]; 568 const int dst_stride = pd->dst.stride; 569 const uint8_t *const dst = 570 &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2]; 571 unsigned int sse; 572 cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse); 573 const double sse_norm = (double)sse / num_samples; 574 575 const unsigned int sad = 576 cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride); 577 const double sad_norm = (double)sad / num_samples; 578 579 fprintf(fout, " %g %g", sse_norm, sad_norm); 580 581 const int diff_stride = block_size_wide[plane_bsize]; 582 const int16_t *const src_diff = 583 &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2]; 584 585 double sse_norm_arr[4], sad_norm_arr[4]; 586 get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst, 587 dst_stride, src_diff, diff_stride, 588 sse_norm_arr, sad_norm_arr); 589 for (int i = 0; i < 4; ++i) { 590 fprintf(fout, " %g", sse_norm_arr[i]); 591 } 592 
for (int i = 0; i < 4; ++i) { 593 fprintf(fout, " %g", sad_norm_arr[i]); 594 } 595 596 const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type]; 597 const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type]; 598 599 fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size], 600 tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col); 601 602 int model_rate; 603 int64_t model_dist; 604 model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples, 605 &model_rate, &model_dist); 606 const double model_rate_norm = (double)model_rate / num_samples; 607 const double model_dist_norm = (double)model_dist / num_samples; 608 fprintf(fout, " %g %g", model_rate_norm, model_dist_norm); 609 610 const double mean = get_mean(src_diff, diff_stride, txw, txh); 611 float hor_corr, vert_corr; 612 av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr, 613 &vert_corr); 614 fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr); 615 616 double hdist[4] = { 0 }, vdist[4] = { 0 }; 617 get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride, 618 1, hdist, vdist); 619 fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2], 620 hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]); 621 622 fprintf(fout, " %d %" PRId64, x->rdmult, rd); 623 624 fprintf(fout, "\n"); 625 fclose(fout); 626 } 627 #endif // CONFIG_COLLECT_RD_STATS == 1 628 629 #if CONFIG_COLLECT_RD_STATS >= 2 630 static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) { 631 const AV1_COMMON *cm = &cpi->common; 632 const int num_planes = av1_num_planes(cm); 633 const MACROBLOCKD *xd = &x->e_mbd; 634 const MB_MODE_INFO *mbmi = xd->mi[0]; 635 int64_t total_sse = 0; 636 for (int plane = 0; plane < num_planes; ++plane) { 637 const struct macroblock_plane *const p = &x->plane[plane]; 638 const struct macroblockd_plane *const pd = &xd->plane[plane]; 639 const BLOCK_SIZE bs = 640 get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y); 641 unsigned int sse; 
642 643 if (plane) continue; 644 645 cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, 646 pd->dst.stride, &sse); 647 total_sse += sse; 648 } 649 total_sse <<= 4; 650 return total_sse; 651 } 652 653 static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize, 654 int64_t sse, int *est_residue_cost, 655 int64_t *est_dist) { 656 const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize]; 657 if (md->ready) { 658 if (sse < md->dist_mean) { 659 *est_residue_cost = 0; 660 *est_dist = sse; 661 } else { 662 *est_dist = (int64_t)round(md->dist_mean); 663 const double est_ld = md->a * sse + md->b; 664 // Clamp estimated rate cost by INT_MAX / 2. 665 // TODO(angiebird@google.com): find better solution than clamping. 666 if (fabs(est_ld) < 1e-2) { 667 *est_residue_cost = INT_MAX / 2; 668 } else { 669 double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld); 670 if (est_residue_cost_dbl < 0) { 671 *est_residue_cost = 0; 672 } else { 673 *est_residue_cost = 674 (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2); 675 } 676 } 677 if (*est_residue_cost <= 0) { 678 *est_residue_cost = 0; 679 *est_dist = sse; 680 } 681 } 682 return 1; 683 } 684 return 0; 685 } 686 687 static double get_highbd_diff_mean(const uint8_t *src8, int src_stride, 688 const uint8_t *dst8, int dst_stride, int w, 689 int h) { 690 const uint16_t *src = CONVERT_TO_SHORTPTR(src8); 691 const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8); 692 double sum = 0.0; 693 for (int j = 0; j < h; ++j) { 694 for (int i = 0; i < w; ++i) { 695 const int diff = src[j * src_stride + i] - dst[j * dst_stride + i]; 696 sum += diff; 697 } 698 } 699 assert(w > 0 && h > 0); 700 return sum / (w * h); 701 } 702 703 static double get_diff_mean(const uint8_t *src, int src_stride, 704 const uint8_t *dst, int dst_stride, int w, int h) { 705 double sum = 0.0; 706 for (int j = 0; j < h; ++j) { 707 for (int i = 0; i < w; ++i) { 708 const int diff = src[j * src_stride + i] - dst[j * 
dst_stride + i]; 709 sum += diff; 710 } 711 } 712 assert(w > 0 && h > 0); 713 return sum / (w * h); 714 } 715 716 static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi, 717 const TileDataEnc *tile_data, 718 MACROBLOCK *x, 719 const RD_STATS *const rd_stats, 720 BLOCK_SIZE plane_bsize) { 721 if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return; 722 723 if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 && 724 (tile_data == NULL || 725 !tile_data->inter_mode_rd_models[plane_bsize].ready)) 726 return; 727 (void)tile_data; 728 // Generate small sample to restrict output size. 729 static unsigned int seed = 95014; 730 731 if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) != 732 1) 733 return; 734 735 const char output_file[] = "pu_stats.txt"; 736 FILE *fout = fopen(output_file, "a"); 737 if (!fout) return; 738 739 MACROBLOCKD *const xd = &x->e_mbd; 740 const int plane = 0; 741 struct macroblock_plane *const p = &x->plane[plane]; 742 struct macroblockd_plane *pd = &xd->plane[plane]; 743 const int diff_stride = block_size_wide[plane_bsize]; 744 int bw, bh; 745 get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw, 746 &bh); 747 const int num_samples = bw * bh; 748 const int dequant_shift = (is_cur_buf_hbd(xd)) ? 
xd->bd - 5 : 3; 749 const int q_step = p->dequant_QTX[1] >> dequant_shift; 750 const int shift = (xd->bd - 8); 751 752 const double rate_norm = (double)rd_stats->rate / num_samples; 753 const double dist_norm = (double)rd_stats->dist / num_samples; 754 const double rdcost_norm = 755 (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples; 756 757 fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm); 758 759 const int src_stride = p->src.stride; 760 const uint8_t *const src = p->src.buf; 761 const int dst_stride = pd->dst.stride; 762 const uint8_t *const dst = pd->dst.buf; 763 const int16_t *const src_diff = p->src_diff; 764 765 int64_t sse = calculate_sse(xd, p, pd, bw, bh); 766 const double sse_norm = (double)sse / num_samples; 767 768 const unsigned int sad = 769 cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride); 770 const double sad_norm = 771 (double)sad / (1 << num_pels_log2_lookup[plane_bsize]); 772 773 fprintf(fout, " %g %g", sse_norm, sad_norm); 774 775 double sse_norm_arr[4], sad_norm_arr[4]; 776 get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst, 777 dst_stride, src_diff, diff_stride, 778 sse_norm_arr, sad_norm_arr); 779 if (shift) { 780 for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift)); 781 for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift); 782 } 783 for (int i = 0; i < 4; ++i) { 784 fprintf(fout, " %g", sse_norm_arr[i]); 785 } 786 for (int i = 0; i < 4; ++i) { 787 fprintf(fout, " %g", sad_norm_arr[i]); 788 } 789 790 fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh); 791 792 int model_rate; 793 int64_t model_dist; 794 model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples, 795 &model_rate, &model_dist); 796 const double model_rdcost_norm = 797 (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples; 798 const double model_rate_norm = (double)model_rate / num_samples; 799 const double model_dist_norm = (double)model_dist / 
num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS >= 2
#endif  // CONFIG_COLLECT_RD_STATS

// Inverse-transforms one transform block of 'plane' into the plane's dst
// buffer.
//
// The dequantized coefficients are read from x->plane[plane].dqcoeff at
// BLOCK_OFFSET(block); the destination is the dst buffer position addressed by
// (blk_row, blk_col). The transform size and type are looked up with
// av1_get_tx_size() / av1_get_tx_type() rather than passed in.
//
// Does nothing when 'eob' is 0.
static inline void inverse_transform_block_facade(MACROBLOCK *const x,
                                                  int plane, int block,
                                                  int blk_row, int blk_col,
                                                  int eob, int reduced_tx_set) {
  if (!eob) return;
  struct macroblock_plane *const p = &x->plane[plane];
  MACROBLOCKD *const xd = &x->e_mbd;
  tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
                                          tx_size, reduced_tx_set);

  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int dst_stride = pd->dst.stride;
  // blk_row / blk_col are scaled by MI_SIZE_LOG2 to get a pixel offset.
  uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}

// Reconstructs an intra transform block into the dst buffer using
// 'best_tx_type'.
//
// Work is done only when all of the following hold:
//   - the current block is not an inter block,
//   - 'best_eob' is non-zero,
//   - the transform block stops short of the bottom edge or of the right edge
//     of 'plane_bsize' (i.e. it is not the bottom-right transform block).
// NOTE(review): presumably the last condition exists because only then do
// later transform blocks depend on these reconstructed pixels - confirm
// against the callers.
//
// When 'do_quant' is non-zero the forward transform and quantization are
// redone first (followed by av1_optimize_b() if the quantizer setup asks for
// it, which receives 'rate_cost'); otherwise the coefficients already present
// in the dqcoeff buffer are used as they are.
static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  if (!is_inter && best_eob &&
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    // if the quantized coefficients are stored in the dqcoeff buffer, we don't
    // need to do transform and quantization again.
    if (do_quant) {
      TxfmParam txfm_param_intra;
      QUANT_PARAM quant_param_intra;
      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
      // Trellis optimization is requested exactly when skip_trellis is 0; in
      // that case the FP quantizer is used. With trellis skipped, the
      // quantizer is selected by USE_B_QUANT_NO_TRELLIS.
      av1_setup_quant(tx_size, !skip_trellis,
                      skip_trellis
                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                    : AV1_XFORM_QUANT_FP)
                          : AV1_XFORM_QUANT_FP,
                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                        &quant_param_intra);
      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                      &txfm_param_intra, &quant_param_intra);
      if (quant_param_intra.use_optimize_b) {
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                       rate_cost);
      }
    }

    // Note that the eob passed here is the one currently stored for the block,
    // not 'best_eob'.
    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
                                   x->plane[plane].eobs[block],
                                   cm->features.reduced_tx_set_used);

    // This may happen because of hash collision. The eob stored in the hash
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    // is DCT_DCT in this case.
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
        best_tx_type != DCT_DCT) {
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    }
  }
}

// Returns the sum of squared errors between 'src' and 'dst' over the visible
// part of a transform block.
//
// If the whole transform block is visible, the per-block-size function
// fn_ptr[tx_bsize].vf is called and only the SSE it writes is used (its return
// value is ignored). Otherwise the SSE is computed over a
// visible_cols x visible_rows region; for high bit-depth buffers that value is
// scaled down by 2 * (bd - 8) bits via ROUND_POWER_OF_TWO.
static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {
  unsigned sse;

  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    return sse;
  }

#if CONFIG_AV1_HIGHBITDEPTH
  const MACROBLOCKD *xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
                                             visible_cols, visible_rows);
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
  }
#else
  // 'x' is only needed for the high bit-depth path.
  (void)x;
#endif
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
                         visible_rows);
  return sse;
}

// Compute the pixel domain distortion from src and dst on all visible 4x4s in
// the transform block.
//
// The full and visible dimensions of the transform block are obtained from
// get_txb_dimensions(); the visible dimensions are asserted to be positive.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {
  int txb_rows, txb_cols, visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;

  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
  assert(visible_rows > 0);
  assert(visible_cols > 0);

  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
                                         dst_stride, tx_bsize, txb_rows,
                                         txb_cols, visible_rows, visible_cols);

  return sse;
}

// Pixel-domain distortion of one transform block.
//
// The current contents of the dst buffer for the block are copied into a local
// scratch buffer (stride MAX_TX_SIZE), the dequantized coefficients are
// inverse-transformed into that scratch buffer, and the result is compared
// against the source with pixel_dist(). The dst buffer itself is not written.
//
// Returns 16 times the SSE reported by pixel_dist().
// NOTE(review): the factor 16 presumably matches the distortion scale used by
// the rest of the RD code - confirm.
static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const uint16_t eob = p->eobs[block];
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int bsw = block_size_wide[tx_bsize];
  const int bsh = block_size_high[tx_bsize];
  const int src_stride = x->plane[plane].src.stride;
  const int dst_stride = xd->plane[plane].dst.stride;
  // Scale the transform block index to pixel unit.
  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);

  assert(cpi != NULL);
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);

  // Scratch reconstruction buffer. It is declared with 16-bit elements so the
  // same storage serves both the 8-bit path (viewed as uint8_t) and the high
  // bit-depth path.
  uint8_t *recon;
  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);

#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    recon = CONVERT_TO_BYTEPTR(recon16);
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
  } else {
    recon = (uint8_t *)recon16;
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
  }
#else
  recon = (uint8_t *)recon16;
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
#endif

  const PLANE_TYPE plane_type = get_plane_type(plane);
  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                    cpi->common.features.reduced_tx_set_used);
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                              MAX_TX_SIZE, eob,
                              cpi->common.features.reduced_tx_set_used);

  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
                         blk_row, blk_col, plane_bsize, tx_bsize);
}

// pruning thresholds for prune_txk_type and prune_txk_type_separ
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100

// R-D costs are sorted in ascending order.
1019 static inline void sort_rd(int64_t rds[], int txk[], int len) { 1020 int i, j, k; 1021 1022 for (i = 1; i <= len - 1; ++i) { 1023 for (j = 0; j < i; ++j) { 1024 if (rds[j] > rds[i]) { 1025 int64_t temprd; 1026 int tempi; 1027 1028 temprd = rds[i]; 1029 tempi = txk[i]; 1030 1031 for (k = i; k > j; k--) { 1032 rds[k] = rds[k - 1]; 1033 txk[k] = txk[k - 1]; 1034 } 1035 1036 rds[j] = temprd; 1037 txk[j] = tempi; 1038 break; 1039 } 1040 } 1041 } 1042 } 1043 1044 static inline int64_t av1_block_error_qm( 1045 const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, 1046 const qm_val_t *qmatrix, const int16_t *scan, int64_t *ssz, int bd) { 1047 int i; 1048 int64_t error = 0, sqcoeff = 0; 1049 int shift = 2 * (bd - 8); 1050 int rounding = (1 << shift) >> 1; 1051 1052 for (i = 0; i < block_size; i++) { 1053 int64_t weight = qmatrix[scan[i]]; 1054 int64_t dd = coeff[i] - dqcoeff[i]; 1055 dd *= weight; 1056 int64_t cc = coeff[i]; 1057 cc *= weight; 1058 // The ranges of coeff and dqcoeff are 1059 // bd8 : 18 bits (including sign) 1060 // bd10: 20 bits (including sign) 1061 // bd12: 22 bits (including sign) 1062 // As AOM_QM_BITS is 5, the intermediate quantities in the calculation 1063 // below should fit in 54 bits, thus no overflow should happen. 
1064 error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS); 1065 sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS); 1066 } 1067 1068 error = (error + rounding) >> shift; 1069 sqcoeff = (sqcoeff + rounding) >> shift; 1070 1071 *ssz = sqcoeff; 1072 return error; 1073 } 1074 1075 static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block, 1076 TX_SIZE tx_size, 1077 const qm_val_t *qmatrix, 1078 const int16_t *scan, int64_t *out_dist, 1079 int64_t *out_sse) { 1080 const struct macroblock_plane *const p = &x->plane[plane]; 1081 // Transform domain distortion computation is more efficient as it does 1082 // not involve an inverse transform, but it is less accurate. 1083 const int buffer_length = av1_get_max_eob(tx_size); 1084 int64_t this_sse; 1085 // TX-domain results need to shift down to Q2/D10 to match pixel 1086 // domain distortion values which are in Q2^2 1087 int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2; 1088 const int block_offset = BLOCK_OFFSET(block); 1089 tran_low_t *const coeff = p->coeff + block_offset; 1090 tran_low_t *const dqcoeff = p->dqcoeff + block_offset; 1091 #if CONFIG_AV1_HIGHBITDEPTH 1092 MACROBLOCKD *const xd = &x->e_mbd; 1093 if (is_cur_buf_hbd(xd)) { 1094 if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) { 1095 *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, 1096 &this_sse, xd->bd); 1097 } else { 1098 *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix, 1099 scan, &this_sse, xd->bd); 1100 } 1101 } else { 1102 #endif 1103 if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) { 1104 *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse); 1105 } else { 1106 *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix, 1107 scan, &this_sse, 8); 1108 } 1109 #if CONFIG_AV1_HIGHBITDEPTH 1110 } 1111 #endif 1112 1113 *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift); 1114 *out_sse = 
RIGHT_SIGNED_SHIFT(this_sse, shift); 1115 } 1116 1117 static uint16_t prune_txk_type_separ( 1118 const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size, 1119 int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map, 1120 int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx, 1121 int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) { 1122 const AV1_COMMON *cm = &cpi->common; 1123 MACROBLOCKD *xd = &x->e_mbd; 1124 1125 int idx; 1126 1127 int64_t rds_v[4]; 1128 int64_t rds_h[4]; 1129 int idx_v[4] = { 0, 1, 2, 3 }; 1130 int idx_h[4] = { 0, 1, 2, 3 }; 1131 int skip_v[4] = { 0 }; 1132 int skip_h[4] = { 0 }; 1133 const int idx_map[16] = { 1134 DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT, 1135 ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST, 1136 FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST, 1137 H_DCT, H_ADST, H_FLIPADST, IDTX 1138 }; 1139 1140 const int sel_pattern_v[16] = { 1141 0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3 1142 }; 1143 const int sel_pattern_h[16] = { 1144 0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3 1145 }; 1146 1147 QUANT_PARAM quant_param; 1148 TxfmParam txfm_param; 1149 av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param); 1150 av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt, 1151 &quant_param); 1152 int tx_type; 1153 // to ensure we can try ones even outside of ext_tx_set of current block 1154 // this function should only be called for size < 16 1155 assert(txsize_sqr_up_map[tx_size] <= TX_16X16); 1156 txfm_param.tx_set_type = EXT_TX_SET_ALL16; 1157 1158 int rate_cost = 0; 1159 int64_t dist = 0, sse = 0; 1160 // evaluate horizontal with vertical DCT 1161 for (idx = 0; idx < 4; ++idx) { 1162 tx_type = idx_map[idx]; 1163 txfm_param.tx_type = tx_type; 1164 1165 av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type, 1166 &quant_param); 1167 1168 av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param, 1169 &quant_param); 1170 1171 const 
SCAN_ORDER *const scan_order = 1172 get_scan(txfm_param.tx_size, txfm_param.tx_type); 1173 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix, 1174 scan_order->scan, &dist, &sse); 1175 1176 rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type, 1177 txb_ctx, reduced_tx_set_used, 0); 1178 1179 rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist); 1180 1181 if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) { 1182 skip_h[idx] = 1; 1183 } 1184 } 1185 sort_rd(rds_h, idx_h, 4); 1186 for (idx = 1; idx < 4; idx++) { 1187 if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1; 1188 } 1189 1190 if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF; 1191 1192 // evaluate vertical with the best horizontal chosen 1193 rds_v[0] = rds_h[0]; 1194 int start_v = 1, end_v = 4; 1195 const int *idx_map_v = idx_map + idx_h[0]; 1196 1197 for (idx = start_v; idx < end_v; ++idx) { 1198 tx_type = idx_map_v[idx_v[idx] * 4]; 1199 txfm_param.tx_type = tx_type; 1200 1201 av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type, 1202 &quant_param); 1203 1204 av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param, 1205 &quant_param); 1206 1207 const SCAN_ORDER *const scan_order = 1208 get_scan(txfm_param.tx_size, txfm_param.tx_type); 1209 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix, 1210 scan_order->scan, &dist, &sse); 1211 1212 rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type, 1213 txb_ctx, reduced_tx_set_used, 0); 1214 1215 rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist); 1216 1217 if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) { 1218 skip_v[idx] = 1; 1219 } 1220 } 1221 sort_rd(rds_v, idx_v, 4); 1222 for (idx = 1; idx < 4; idx++) { 1223 if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1; 1224 } 1225 1226 // combine rd_h and rd_v to prune tx candidates 1227 int i_v, i_h; 1228 int64_t rds[16]; 1229 int num_cand = 0, last = TX_TYPES - 1; 1230 1231 for (int i = 0; i < 16; i++) { 1232 
i_v = sel_pattern_v[i]; 1233 i_h = sel_pattern_h[i]; 1234 tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]]; 1235 if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] || 1236 skip_v[idx_v[i_v]]) { 1237 txk_map[last] = tx_type; 1238 last--; 1239 } else { 1240 txk_map[num_cand] = tx_type; 1241 rds[num_cand] = rds_v[i_v] + rds_h[i_h]; 1242 if (rds[num_cand] == 0) rds[num_cand] = 1; 1243 num_cand++; 1244 } 1245 } 1246 sort_rd(rds, txk_map, num_cand); 1247 1248 uint16_t prune = (uint16_t)(~(1 << txk_map[0])); 1249 num_sel = AOMMIN(num_sel, num_cand); 1250 1251 for (int i = 1; i < num_sel; i++) { 1252 int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]); 1253 if (factor < (int64_t)prune_factor) 1254 prune &= ~(1 << txk_map[i]); 1255 else 1256 break; 1257 } 1258 return prune; 1259 } 1260 1261 static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, 1262 int block, TX_SIZE tx_size, int blk_row, 1263 int blk_col, BLOCK_SIZE plane_bsize, 1264 int *txk_map, uint16_t allowed_tx_mask, 1265 int prune_factor, const TXB_CTX *const txb_ctx, 1266 int reduced_tx_set_used) { 1267 const AV1_COMMON *cm = &cpi->common; 1268 MACROBLOCKD *xd = &x->e_mbd; 1269 int tx_type; 1270 1271 int64_t rds[TX_TYPES]; 1272 1273 int num_cand = 0; 1274 int last = TX_TYPES - 1; 1275 1276 TxfmParam txfm_param; 1277 QUANT_PARAM quant_param; 1278 av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param); 1279 av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt, 1280 &quant_param); 1281 1282 for (int idx = 0; idx < TX_TYPES; idx++) { 1283 tx_type = idx; 1284 int rate_cost = 0; 1285 int64_t dist = 0, sse = 0; 1286 if (!(allowed_tx_mask & (1 << tx_type))) { 1287 txk_map[last] = tx_type; 1288 last--; 1289 continue; 1290 } 1291 txfm_param.tx_type = tx_type; 1292 1293 av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type, 1294 &quant_param); 1295 1296 // do txfm and quantization 1297 av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, 
&txfm_param, 1298 &quant_param); 1299 // estimate rate cost 1300 rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type, 1301 txb_ctx, reduced_tx_set_used, 0); 1302 // tx domain dist 1303 const SCAN_ORDER *const scan_order = 1304 get_scan(txfm_param.tx_size, txfm_param.tx_type); 1305 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix, 1306 scan_order->scan, &dist, &sse); 1307 1308 txk_map[num_cand] = tx_type; 1309 rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist); 1310 if (rds[num_cand] == 0) rds[num_cand] = 1; 1311 num_cand++; 1312 } 1313 1314 if (num_cand == 0) return (uint16_t)0xFFFF; 1315 1316 sort_rd(rds, txk_map, num_cand); 1317 uint16_t prune = (uint16_t)(~(1 << txk_map[0])); 1318 1319 // 0 < prune_factor <= 1000 controls aggressiveness 1320 int64_t factor = 0; 1321 for (int idx = 1; idx < num_cand; idx++) { 1322 factor = 1000 * (rds[idx] - rds[0]) / rds[0]; 1323 if (factor < (int64_t)prune_factor) 1324 prune &= ~(1 << txk_map[idx]); 1325 else 1326 break; 1327 } 1328 return prune; 1329 } 1330 1331 // These thresholds were calibrated to provide a certain number of TX types 1332 // pruned by the model on average, i.e. 
selecting a threshold with index i 1333 // will lead to pruning i+1 TX types on average 1334 static const float *prune_2D_adaptive_thresholds[] = { 1335 // TX_4X4 1336 (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f, 1337 0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f, 1338 0.09778f, 0.11780f }, 1339 // TX_8X8 1340 (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f, 1341 0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f, 1342 0.10803f, 0.14124f }, 1343 // TX_16X16 1344 (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f, 1345 0.06897f, 0.07629f, 0.08875f, 0.11169f }, 1346 // TX_32X32 1347 NULL, 1348 // TX_64X64 1349 NULL, 1350 // TX_4X8 1351 (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f, 1352 0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f, 1353 0.10168f, 0.12585f }, 1354 // TX_8X4 1355 (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f, 1356 0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f, 1357 0.10583f, 0.13123f }, 1358 // TX_8X16 1359 (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f, 1360 0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f, 1361 0.10730f, 0.14221f }, 1362 // TX_16X8 1363 (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f, 1364 0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f, 1365 0.10339f, 0.13464f }, 1366 // TX_16X32 1367 NULL, 1368 // TX_32X16 1369 NULL, 1370 // TX_32X64 1371 NULL, 1372 // TX_64X32 1373 NULL, 1374 // TX_4X16 1375 (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f, 1376 0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f, 1377 0.10242f, 0.12878f }, 1378 // TX_16X4 1379 (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f, 1380 0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f, 1381 0.10217f, 0.12610f }, 1382 // TX_8X32 1383 NULL, 1384 // TX_32X8 1385 NULL, 1386 // TX_16X64 1387 
NULL, 1388 // TX_64X16 1389 NULL, 1390 }; 1391 1392 static inline float get_adaptive_thresholds( 1393 TX_SIZE tx_size, TxSetType tx_set_type, 1394 TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) { 1395 const int prune_aggr_table[5][2] = { 1396 { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 } 1397 }; 1398 int pruning_aggressiveness = 0; 1399 if (tx_set_type == EXT_TX_SET_ALL16) 1400 pruning_aggressiveness = 1401 prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0]; 1402 else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) 1403 pruning_aggressiveness = 1404 prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1]; 1405 1406 return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness]; 1407 } 1408 1409 static inline void get_energy_distribution_finer(const int16_t *diff, 1410 int stride, int bw, int bh, 1411 float *hordist, 1412 float *verdist) { 1413 // First compute downscaled block energy values (esq); downscale factors 1414 // are defined by w_shift and h_shift. 1415 unsigned int esq[256]; 1416 const int w_shift = bw <= 8 ? 0 : 1; 1417 const int h_shift = bh <= 8 ? 
0 : 1; 1418 const int esq_w = bw >> w_shift; 1419 const int esq_h = bh >> h_shift; 1420 const int esq_sz = esq_w * esq_h; 1421 int i, j; 1422 memset(esq, 0, esq_sz * sizeof(esq[0])); 1423 if (w_shift) { 1424 for (i = 0; i < bh; i++) { 1425 unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w; 1426 const int16_t *cur_diff_row = diff + i * stride; 1427 for (j = 0; j < bw; j += 2) { 1428 cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] + 1429 cur_diff_row[j + 1] * cur_diff_row[j + 1]); 1430 } 1431 } 1432 } else { 1433 for (i = 0; i < bh; i++) { 1434 unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w; 1435 const int16_t *cur_diff_row = diff + i * stride; 1436 for (j = 0; j < bw; j++) { 1437 cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j]; 1438 } 1439 } 1440 } 1441 1442 uint64_t total = 0; 1443 for (i = 0; i < esq_sz; i++) total += esq[i]; 1444 1445 // Output hordist and verdist arrays are normalized 1D projections of esq 1446 if (total == 0) { 1447 float hor_val = 1.0f / esq_w; 1448 for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val; 1449 float ver_val = 1.0f / esq_h; 1450 for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val; 1451 return; 1452 } 1453 1454 const float e_recip = 1.0f / (float)total; 1455 memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0])); 1456 memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0])); 1457 const unsigned int *cur_esq_row; 1458 for (i = 0; i < esq_h - 1; i++) { 1459 cur_esq_row = esq + i * esq_w; 1460 for (j = 0; j < esq_w - 1; j++) { 1461 hordist[j] += (float)cur_esq_row[j]; 1462 verdist[i] += (float)cur_esq_row[j]; 1463 } 1464 verdist[i] += (float)cur_esq_row[j]; 1465 } 1466 cur_esq_row = esq + i * esq_w; 1467 for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j]; 1468 1469 for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip; 1470 for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip; 1471 } 1472 1473 static inline bool check_bit_mask(uint16_t mask, int val) { 1474 return mask & (1 << val); 1475 } 

// Sets bit 'val' of *mask.
static inline void set_bit_mask(uint16_t *mask, int val) {
  *mask |= (1 << val);
}

// Clears bit 'val' of *mask.
static inline void unset_bit_mask(uint16_t *mask, int val) {
  *mask &= ~(1 << val);
}

// Prunes 2D transform types with a pair of neural-network models.
//
// Features derived from the luma residual of the transform block are fed to a
// horizontal and a vertical model. The products of their outputs, passed
// through a 16-way softmax, give one score per 2D transform type. Types that
// are not set in *allowed_tx_mask or whose score is below the adaptive
// threshold are dropped; for prune_2d_txfm_mode >= TX_TYPE_PRUNE_4 a further
// cut based on cumulative score is applied.
//
// On return *allowed_tx_mask holds the surviving types and 'txk_map' holds a
// list of transform types for the caller (see the two memcpy calls below).
//
// The function returns without touching its outputs when 'tx_set_type' is
// neither EXT_TX_SET_ALL16 nor EXT_TX_SET_DTT9_IDTX_1DDCT, or when no model is
// available for 'tx_size'.
static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                        int blk_row, int blk_col, TxSetType tx_set_type,
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
                        uint16_t *allowed_tx_mask) {
  // This table is used because the search order is different from the enum
  // order.
  static const int tx_type_table_2D[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
  if (tx_set_type != EXT_TX_SET_ALL16 &&
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
    return;
#if CONFIG_NN_V2
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#else
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#endif
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.

  float hfeatures[16], vfeatures[16];
  float hscores[4], vscores[4];
  float scores_2D_raw[16];
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
  assert(hfeatures_num <= 16);
  assert(vfeatures_num <= 16);

  // The residual is always taken from plane 0; blk_row / blk_col are scaled
  // by 4 to get a sample offset.
  const struct macroblock_plane *const p = &x->plane[0];
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                vfeatures);

  // The last feature of each vector is the correlation value.
  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
                                  &hfeatures[hfeatures_num - 1],
                                  &vfeatures[vfeatures_num - 1]);

#if CONFIG_NN_V2
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
#else
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
#endif

  // Raw 2D score = vertical score * horizontal score, laid out with the
  // vertical index selecting the row (same layout as tx_type_table_2D).
  for (int i = 0; i < 4; i++) {
    float *cur_scores_2D = scores_2D_raw + i * 4;
    cur_scores_2D[0] = vscores[i] * hscores[0];
    cur_scores_2D[1] = vscores[i] * hscores[1];
    cur_scores_2D[2] = vscores[i] * hscores[2];
    cur_scores_2D[3] = vscores[i] * hscores[3];
  }

  assert(TX_TYPES == 16);
  // This version of the function only works when there are at most 16 classes.
  // So we will need to change the optimization or use av1_nn_softmax instead if
  // this ever gets changed.
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);

  const float score_thresh =
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);

  // Always keep the TX type with the highest score, prune all others with
  // score below score_thresh.
  int max_score_i = 0;
  float max_score = 0.0f;
  uint16_t allow_bitmask = 0;
  float sum_score = 0.0;
  // Calculate sum of allowed tx type score and Populate allow bit mask based
  // on score_thresh and allowed_tx_mask
  int allow_count = 0;
  // Compact lists of the surviving types and their scores; unused slots keep
  // TX_TYPE_INVALID and a score of -1.
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID };
  float scores_2D[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
  };
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
    const int allow_tx_type =
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
    if (!allow_tx_type) {
      continue;
    }
    // The maximum is tracked over the caller-allowed types only.
    if (scores_2D_raw[tx_idx] > max_score) {
      max_score = scores_2D_raw[tx_idx];
      max_score_i = tx_idx;
    }
    if (scores_2D_raw[tx_idx] >= score_thresh) {
      // Set allow mask based on score_thresh
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);

      // Accumulate score of allowed tx type
      sum_score += scores_2D_raw[tx_idx];

      scores_2D[allow_count] = scores_2D_raw[tx_idx];
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
      allow_count += 1;
    }
  }
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
    // If even the tx_type with max score is pruned, this means that no other
    // tx_type is feasible. When this happens, we force enable max_score_i and
    // end the search.
    // In this case txk_map receives the default table order.
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
    *allowed_tx_mask = allow_bitmask;
    return;
  }

  // Sort tx type probability of all types
  // NOTE(review): the loop below treats the result as ordered from highest to
  // lowest score - confirm against av1_sort_fi32_8 / av1_sort_fi32_16.
  if (allow_count <= 8) {
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
  } else {
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
  }

  // Enable more pruning based on tx type probability and number of allowed tx
  // types
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
    float temp_score = 0.0;
    float score_ratio = 0.0;
    int tx_idx, tx_count = 0;
    // Scale factor that turns a cumulative score into a percentage of the
    // total score of the surviving types.
    const float inv_sum_score = 100 / sum_score;
    // Get allowed tx types based on sorted probability score and tx count
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
      // Skip the tx type which has more than 30% of cumulative
      // probability and allowed tx type count is more than 2
      if (score_ratio > 30.0 && tx_count >= 2) break;

      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
      // Calculate cumulative probability
      temp_score += scores_2D[tx_idx];

      // Calculate percentage of cumulative probability of allowed tx type
      score_ratio = temp_score * inv_sum_score;
      tx_count++;
    }
    // Set remaining tx types as pruned
    for (; tx_idx < allow_count; tx_idx++)
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
  }

  // txk_map receives the sorted list, including any TX_TYPE_INVALID padding.
  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
  *allowed_tx_mask = allow_bitmask;
}

// Returns the standard deviation sqrt(x2_sum / num - mean^2), or 0 when the
// value under the square root is not positive.
static float get_dev(float mean, double x2_sum, int num) {
  const float e_x2 = (float)(x2_sum / num);
  const float diff = e_x2 - mean * mean;
  const float dev = (diff > 0) ? sqrtf(diff) : 0;
  return dev;
}

// Writes the features required by the ML model to predict tx split based on
// mean and standard deviation values of the block and sub-blocks.
// Returns the number of elements written to the output array which is at most
// 12 currently. Hence 'features' buffer should be able to accommodate at least
// 12 elements.
//
// Layout of 'features':
//   [0], [1]  mean and deviation of the whole block
//   then      (mean, deviation) for each sub-block, in raster order
//   then      deviation of the sub-block means, mean of the sub-block
//             deviations
// Sub-blocks halve the longer dimension of the block, or both dimensions when
// the block is square.
static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
                                        int bh, float *features) {
  const int16_t *const data_ptr = &data[0];
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int num = bw * bh;
  const int sub_num = subw * subh;
  // Slots 0 and 1 are filled after the loop, once the block totals are known.
  int feature_idx = 2;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int num_sub_blks = 0;
  double mean2_sum = 0.0f;
  float dev_sum = 0.0f;

  for (int row = 0; row < bh; row += subh) {
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      // TODO(any): Write a SIMD version. Clear registers.
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
                          &x_sum, &x2_sum);
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float dev = get_dev(mean, (double)x2_sum, sub_num);
      features[feature_idx++] = mean;
      features[feature_idx++] = dev;
      mean2_sum += (double)(mean * mean);
      dev_sum += dev;
      num_sub_blks++;
    }
  }

  const float lvl0_mean = (float)total_x_sum / num;
  features[0] = lvl0_mean;
  features[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);

  // Deviation of means.
  features[feature_idx++] = get_dev(lvl0_mean, mean2_sum, num_sub_blks);
  // Mean of deviations.
  features[feature_idx++] = dev_sum / num_sub_blks;

  return feature_idx;
}

// Runs the tx-split model for 'tx_size' on the luma residual of the transform
// block at (blk_row, blk_col).
//
// Returns -1 when no model exists for 'tx_size'; otherwise the model output
// multiplied by 10000, truncated to int and clamped to [-80000, 80000].
static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {
  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
  if (!nn_config) return -1;

  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  // Zero-initialized; get_mean_dev_features() fills only the leading entries.
  float features[64] = { 0.0f };
  get_mean_dev_features(diff, diff_stride, bw, bh, features);

  float score = 0.0f;
  av1_nn_predict(features, nn_config, 1, &score);

  int int_score = (int)(score * 10000);
  return clamp(int_score, -80000, 80000);
}

static inline uint16_t get_tx_mask(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
    int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
    int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int is_inter = is_inter_block(mbmi);
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
  // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
  // TX_TYPES, only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;

  const FRAME_UPDATE_TYPE update_type =
      get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
  int use_actual_frame_probs = 1;
  const int *tx_type_probs;
#if CONFIG_FPMT_TEST
  use_actual_frame_probs =
      (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ?
0 : 1; 1742 if (!use_actual_frame_probs) { 1743 tx_type_probs = 1744 (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size]; 1745 } 1746 #endif 1747 if (use_actual_frame_probs) { 1748 tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size]; 1749 } 1750 1751 if ((!is_inter && txfm_params->use_default_intra_tx_type) || 1752 (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) { 1753 txk_allowed = 1754 get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools); 1755 } else if (is_inter && 1756 txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) { 1757 if (tx_type_probs[DEFAULT_INTER_TX_TYPE] > 1758 txfm_params->default_inter_tx_type_prob_thresh) { 1759 txk_allowed = DEFAULT_INTER_TX_TYPE; 1760 } else { 1761 int force_tx_type = 0; 1762 int max_prob = 0; 1763 const int tx_type_prob_threshold = 1764 txfm_params->default_inter_tx_type_prob_thresh + 1765 PROB_THRESH_OFFSET_TX_TYPE; 1766 for (int i = 1; i < TX_TYPES; i++) { // find maximum probability. 1767 if (tx_type_probs[i] > max_prob) { 1768 max_prob = tx_type_probs[i]; 1769 force_tx_type = i; 1770 } 1771 } 1772 if (max_prob > tx_type_prob_threshold) // force tx type with max prob. 1773 txk_allowed = force_tx_type; 1774 else if (x->rd_model == LOW_TXFM_RD) { 1775 if (plane == 0) txk_allowed = DCT_DCT; 1776 } 1777 } 1778 } else if (x->rd_model == LOW_TXFM_RD) { 1779 if (plane == 0) txk_allowed = DCT_DCT; 1780 } 1781 1782 const TxSetType tx_set_type = av1_get_ext_tx_set_type( 1783 tx_size, is_inter, cm->features.reduced_tx_set_used); 1784 1785 TX_TYPE uv_tx_type = DCT_DCT; 1786 if (plane) { 1787 // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y 1788 uv_tx_type = txk_allowed = 1789 av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size, 1790 cm->features.reduced_tx_set_used); 1791 } 1792 PREDICTION_MODE intra_dir = 1793 mbmi->filter_intra_mode_info.use_filter_intra 1794 ? 
fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode] 1795 : mbmi->mode; 1796 uint16_t ext_tx_used_flag = 1797 cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 && 1798 tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT 1799 ? av1_reduced_intra_tx_used_flag[intra_dir] 1800 : av1_ext_tx_used_flag[tx_set_type]; 1801 1802 if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2) 1803 ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir]; 1804 1805 if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 || 1806 ext_tx_used_flag == 0x0001 || 1807 (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) || 1808 (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) { 1809 txk_allowed = DCT_DCT; 1810 } 1811 1812 if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0) 1813 ext_tx_used_flag &= DCT_ADST_TX_MASK; 1814 1815 uint16_t allowed_tx_mask = 0; // 1: allow; 0: skip. 1816 if (txk_allowed < TX_TYPES) { 1817 allowed_tx_mask = 1 << txk_allowed; 1818 allowed_tx_mask &= ext_tx_used_flag; 1819 } else if (fast_tx_search) { 1820 allowed_tx_mask = 0x0c01; // V_DCT, H_DCT, DCT_DCT 1821 allowed_tx_mask &= ext_tx_used_flag; 1822 } else { 1823 assert(plane == 0); 1824 allowed_tx_mask = ext_tx_used_flag; 1825 int num_allowed = 0; 1826 int i; 1827 1828 if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) { 1829 static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 }, 1830 { 10, 17, 17, 10, 17, 17, 17 } }; 1831 const int thresh = 1832 thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1] 1833 [update_type]; 1834 uint16_t prune = 0; 1835 int max_prob = -1; 1836 int max_idx = 0; 1837 for (i = 0; i < TX_TYPES; i++) { 1838 if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) { 1839 max_prob = tx_type_probs[i]; 1840 max_idx = i; 1841 } 1842 if (tx_type_probs[i] < thresh) prune |= (1 << i); 1843 } 1844 if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx); 1845 allowed_tx_mask &= (~prune); 1846 } 1847 for (i = 
0; i < TX_TYPES; i++) { 1848 if (allowed_tx_mask & (1 << i)) num_allowed++; 1849 } 1850 assert(num_allowed > 0); 1851 1852 if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) { 1853 int pf = prune_factors[txfm_params->prune_2d_txfm_mode]; 1854 int mf = mul_factors[txfm_params->prune_2d_txfm_mode]; 1855 if (num_allowed <= 7) { 1856 const uint16_t prune = 1857 prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col, 1858 plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx, 1859 cm->features.reduced_tx_set_used); 1860 allowed_tx_mask &= (~prune); 1861 } else { 1862 const int num_sel = (num_allowed * mf + 50) / 100; 1863 const uint16_t prune = prune_txk_type_separ( 1864 cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize, 1865 txk_map, allowed_tx_mask, pf, txb_ctx, 1866 cm->features.reduced_tx_set_used, ref_best_rd, num_sel); 1867 1868 allowed_tx_mask &= (~prune); 1869 } 1870 } else { 1871 assert(num_allowed > 0); 1872 int allowed_tx_count = 1873 (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5; 1874 // !fast_tx_search && txk_end != txk_start && plane == 0 1875 if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter && 1876 num_allowed > allowed_tx_count) { 1877 prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type, 1878 txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask); 1879 } 1880 } 1881 } 1882 1883 // Need to have at least one transform type allowed. 1884 if (allowed_tx_mask == 0) { 1885 txk_allowed = (plane ? 
uv_tx_type : DCT_DCT); 1886 allowed_tx_mask = (1 << txk_allowed); 1887 } 1888 1889 assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed)); 1890 *allowed_txk_types = txk_allowed; 1891 return allowed_tx_mask; 1892 } 1893 1894 #if CONFIG_RD_DEBUG 1895 static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane, 1896 int txb_coeff_cost) { 1897 rd_stats->txb_coeff_cost[plane] += txb_coeff_cost; 1898 } 1899 #endif 1900 1901 static inline int cost_coeffs(MACROBLOCK *x, int plane, int block, 1902 TX_SIZE tx_size, const TX_TYPE tx_type, 1903 const TXB_CTX *const txb_ctx, 1904 int reduced_tx_set_used) { 1905 #if TXCOEFF_COST_TIMER 1906 struct aom_usec_timer timer; 1907 aom_usec_timer_start(&timer); 1908 #endif 1909 const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type, 1910 txb_ctx, reduced_tx_set_used); 1911 #if TXCOEFF_COST_TIMER 1912 AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common; 1913 aom_usec_timer_mark(&timer); 1914 const int64_t elapsed_time = aom_usec_timer_elapsed(&timer); 1915 tmp_cm->txcoeff_cost_timer += elapsed_time; 1916 ++tmp_cm->txcoeff_cost_count; 1917 #endif 1918 return cost; 1919 } 1920 1921 static int skip_trellis_opt_based_on_satd(MACROBLOCK *x, 1922 QUANT_PARAM *quant_param, int plane, 1923 int block, TX_SIZE tx_size, 1924 int quant_b_adapt, int qstep, 1925 unsigned int coeff_opt_satd_threshold, 1926 int skip_trellis, int dc_only_blk) { 1927 if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX)) 1928 return skip_trellis; 1929 1930 const struct macroblock_plane *const p = &x->plane[plane]; 1931 const int block_offset = BLOCK_OFFSET(block); 1932 tran_low_t *const coeff_ptr = p->coeff + block_offset; 1933 const int n_coeffs = av1_get_max_eob(tx_size); 1934 const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)); 1935 int satd = (dc_only_blk) ? 
abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs); 1936 satd = RIGHT_SIGNED_SHIFT(satd, shift); 1937 satd >>= (x->e_mbd.bd - 8); 1938 1939 const int skip_block_trellis = 1940 ((uint64_t)satd > 1941 (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]); 1942 1943 av1_setup_quant( 1944 tx_size, !skip_block_trellis, 1945 skip_block_trellis 1946 ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP) 1947 : AV1_XFORM_QUANT_FP, 1948 quant_b_adapt, quant_param); 1949 1950 return skip_block_trellis; 1951 } 1952 1953 // Predict DC only blocks if the residual variance is below a qstep based 1954 // threshold.For such blocks, transform type search is bypassed. 1955 static inline void predict_dc_only_block( 1956 MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size, 1957 int block, int blk_row, int blk_col, RD_STATS *best_rd_stats, 1958 int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean, 1959 int *dc_only_blk) { 1960 MACROBLOCKD *xd = &x->e_mbd; 1961 MB_MODE_INFO *mbmi = xd->mi[0]; 1962 const int dequant_shift = (is_cur_buf_hbd(xd)) ? 
xd->bd - 5 : 3; 1963 const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift; 1964 uint64_t block_var = UINT64_MAX; 1965 const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3; 1966 *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize, 1967 txsize_to_bsize[tx_size], block_mse_q8, 1968 per_px_mean, &block_var); 1969 assert((*block_mse_q8) != UINT_MAX); 1970 uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep); 1971 if (is_cur_buf_hbd(xd)) 1972 block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2); 1973 1974 if (block_var >= var_threshold) return; 1975 const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level; 1976 assert(predict_dc_level != 0); 1977 1978 // Prediction of skip block if residual mean and variance are less 1979 // than qstep based threshold 1980 if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) { 1981 // If the normalized mean of residual block is less than the dc qstep and 1982 // the normalized block variance is less than ac qstep, then the block is 1983 // assumed to be a skip block and its rdcost is updated accordingly. 
1984 best_rd_stats->skip_txfm = 1; 1985 1986 x->plane[plane].eobs[block] = 0; 1987 1988 if (is_cur_buf_hbd(xd)) 1989 *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2); 1990 1991 best_rd_stats->dist = (*block_sse) << 4; 1992 best_rd_stats->sse = best_rd_stats->dist; 1993 1994 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE]; 1995 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE]; 1996 av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl); 1997 ENTROPY_CONTEXT *ta = ctxa; 1998 ENTROPY_CONTEXT *tl = ctxl; 1999 const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size); 2000 TXB_CTX txb_ctx_tmp; 2001 const PLANE_TYPE plane_type = get_plane_type(plane); 2002 get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp); 2003 const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type] 2004 .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1]; 2005 best_rd_stats->rate = zero_blk_rate; 2006 2007 best_rd_stats->rdcost = 2008 RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse); 2009 2010 x->plane[plane].txb_entropy_ctx[block] = 0; 2011 } else if (predict_dc_level > 1) { 2012 // Predict DC only blocks based on residual variance. 2013 // For chroma plane, this prediction is disabled for intra blocks. 2014 if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1; 2015 } 2016 } 2017 2018 // Search for the best transform type for a given transform block. 2019 // This function can be used for both inter and intra, both luma and chroma. 
2020 static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, 2021 int block, int blk_row, int blk_col, 2022 BLOCK_SIZE plane_bsize, TX_SIZE tx_size, 2023 const TXB_CTX *const txb_ctx, 2024 FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis, 2025 int64_t ref_best_rd, RD_STATS *best_rd_stats) { 2026 const AV1_COMMON *cm = &cpi->common; 2027 MACROBLOCKD *xd = &x->e_mbd; 2028 MB_MODE_INFO *mbmi = xd->mi[0]; 2029 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 2030 int64_t best_rd = INT64_MAX; 2031 uint16_t best_eob = 0; 2032 TX_TYPE best_tx_type = DCT_DCT; 2033 int rate_cost = 0; 2034 struct macroblock_plane *const p = &x->plane[plane]; 2035 tran_low_t *orig_dqcoeff = p->dqcoeff; 2036 tran_low_t *best_dqcoeff = x->dqcoeff_buf; 2037 const int tx_type_map_idx = 2038 plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col; 2039 av1_invalid_rd_stats(best_rd_stats); 2040 2041 skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id], 2042 DRY_RUN_NORMAL); 2043 2044 uint8_t best_txb_ctx = 0; 2045 // txk_allowed = TX_TYPES: >1 tx types are allowed 2046 // txk_allowed < TX_TYPES: only that specific tx type is allowed. 2047 TX_TYPE txk_allowed = TX_TYPES; 2048 int txk_map[TX_TYPES] = { 2049 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 2050 }; 2051 const int dequant_shift = (is_cur_buf_hbd(xd)) ? 
xd->bd - 5 : 3; 2052 const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift; 2053 2054 const uint8_t txw = tx_size_wide[tx_size]; 2055 const uint8_t txh = tx_size_high[tx_size]; 2056 int64_t block_sse; 2057 unsigned int block_mse_q8; 2058 int dc_only_blk = 0; 2059 const bool predict_dc_block = 2060 txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64; 2061 int64_t per_px_mean = INT64_MAX; 2062 if (predict_dc_block) { 2063 predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row, 2064 blk_col, best_rd_stats, &block_sse, &block_mse_q8, 2065 &per_px_mean, &dc_only_blk); 2066 if (best_rd_stats->skip_txfm == 1) { 2067 const TX_TYPE tx_type = DCT_DCT; 2068 if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type; 2069 return; 2070 } 2071 } else { 2072 block_sse = av1_pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, 2073 txsize_to_bsize[tx_size], &block_mse_q8); 2074 assert(block_mse_q8 != UINT_MAX); 2075 } 2076 2077 // Bit mask to indicate which transform types are allowed in the RD search. 2078 uint16_t tx_mask; 2079 2080 // Use DCT_DCT transform for DC only block. 2081 if (dc_only_blk || cpi->sf.rt_sf.dct_only_palette_nonrd == 1) 2082 tx_mask = 1 << DCT_DCT; 2083 else 2084 tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize, 2085 tx_size, txb_ctx, ftxs_mode, ref_best_rd, 2086 &txk_allowed, txk_map); 2087 const uint16_t allowed_tx_mask = tx_mask; 2088 2089 if (is_cur_buf_hbd(xd)) { 2090 block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2); 2091 block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2); 2092 } 2093 block_sse *= 16; 2094 // Use mse / qstep^2 based threshold logic to take decision of R-D 2095 // optimization of coeffs. For smaller residuals, coeff optimization 2096 // would be helpful. For larger residuals, R-D optimization may not be 2097 // effective. 
2098 // TODO(any): Experiment with variance and mean based thresholds 2099 const int perform_block_coeff_opt = 2100 ((uint64_t)block_mse_q8 <= 2101 (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep); 2102 skip_trellis |= !perform_block_coeff_opt; 2103 2104 // Flag to indicate if distortion should be calculated in transform domain or 2105 // not during iterating through transform type candidates. 2106 // Transform domain distortion is accurate for higher residuals. 2107 // TODO(any): Experiment with variance and mean based thresholds 2108 int use_transform_domain_distortion = 2109 (txfm_params->use_transform_domain_distortion > 0) && 2110 (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) && 2111 // Any 64-pt transforms only preserves half the coefficients. 2112 // Therefore transform domain distortion is not valid for these 2113 // transform sizes. 2114 (txsize_sqr_up_map[tx_size] != TX_64X64) && 2115 // Use pixel domain distortion for DC only blocks 2116 !dc_only_blk; 2117 // Flag to indicate if an extra calculation of distortion in the pixel domain 2118 // should be performed at the end, after the best transform type has been 2119 // decided. 2120 int calc_pixel_domain_distortion_final = 2121 txfm_params->use_transform_domain_distortion == 1 && 2122 use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD; 2123 if (calc_pixel_domain_distortion_final && 2124 (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001)) 2125 calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0; 2126 2127 const uint16_t *eobs_ptr = x->plane[plane].eobs; 2128 2129 TxfmParam txfm_param; 2130 QUANT_PARAM quant_param; 2131 int skip_trellis_based_on_satd[TX_TYPES] = { 0 }; 2132 av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param); 2133 av1_setup_quant(tx_size, !skip_trellis, 2134 skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? 
AV1_XFORM_QUANT_B 2135 : AV1_XFORM_QUANT_FP) 2136 : AV1_XFORM_QUANT_FP, 2137 cpi->oxcf.q_cfg.quant_b_adapt, &quant_param); 2138 2139 // Iterate through all transform type candidates. 2140 for (int idx = 0; idx < TX_TYPES; ++idx) { 2141 const TX_TYPE tx_type = (TX_TYPE)txk_map[idx]; 2142 if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type)) 2143 continue; 2144 txfm_param.tx_type = tx_type; 2145 if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) { 2146 av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type, 2147 &quant_param); 2148 } 2149 if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type; 2150 RD_STATS this_rd_stats; 2151 av1_invalid_rd_stats(&this_rd_stats); 2152 2153 if (!dc_only_blk) 2154 av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param); 2155 else 2156 av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean); 2157 2158 skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd( 2159 x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt, 2160 qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk); 2161 2162 av1_quant(x, plane, block, &txfm_param, &quant_param); 2163 2164 // Calculate rate cost of quantized coefficients. 2165 if (quant_param.use_optimize_b) { 2166 // TODO(aomedia:3209): update Trellis quantization to take into account 2167 // quantization matrices. 2168 av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx, 2169 &rate_cost); 2170 } else { 2171 rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx, 2172 cm->features.reduced_tx_set_used); 2173 } 2174 2175 // If rd cost based on coeff rate alone is already more than best_rd, 2176 // terminate early. 2177 if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue; 2178 2179 // Calculate distortion. 2180 if (eobs_ptr[block] == 0) { 2181 // When eob is 0, pixel domain distortion is more efficient and accurate. 
2182 this_rd_stats.dist = this_rd_stats.sse = block_sse; 2183 } else if (dc_only_blk) { 2184 this_rd_stats.sse = block_sse; 2185 this_rd_stats.dist = dist_block_px_domain( 2186 cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size); 2187 } else if (use_transform_domain_distortion) { 2188 const SCAN_ORDER *const scan_order = 2189 get_scan(txfm_param.tx_size, txfm_param.tx_type); 2190 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix, 2191 scan_order->scan, &this_rd_stats.dist, 2192 &this_rd_stats.sse); 2193 } else { 2194 int64_t sse_diff = INT64_MAX; 2195 // high_energy threshold assumes that every pixel within a txfm block 2196 // has a residue energy of at least 25% of the maximum, i.e. 128 * 128 2197 // for 8 bit. 2198 const int64_t high_energy_thresh = 2199 ((int64_t)128 * 128 * tx_size_2d[tx_size]); 2200 const int is_high_energy = (block_sse >= high_energy_thresh); 2201 if (tx_size == TX_64X64 || is_high_energy) { 2202 // Because 3 out 4 quadrants of transform coefficients are forced to 2203 // zero, the inverse transform has a tendency to overflow. sse_diff 2204 // is effectively the energy of those 3 quadrants, here we use it 2205 // to decide if we should do pixel domain distortion. If the energy 2206 // is mostly in first quadrant, then it is unlikely that we have 2207 // overflow issue in inverse transform. 
2208 const SCAN_ORDER *const scan_order = 2209 get_scan(txfm_param.tx_size, txfm_param.tx_type); 2210 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix, 2211 scan_order->scan, &this_rd_stats.dist, 2212 &this_rd_stats.sse); 2213 sse_diff = block_sse - this_rd_stats.sse; 2214 } 2215 if (tx_size != TX_64X64 || !is_high_energy || 2216 (sse_diff * 2) < this_rd_stats.sse) { 2217 const int64_t tx_domain_dist = this_rd_stats.dist; 2218 this_rd_stats.dist = dist_block_px_domain( 2219 cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size); 2220 // For high energy blocks, occasionally, the pixel domain distortion 2221 // can be artificially low due to clamping at reconstruction stage 2222 // even when inverse transform output is hugely different from the 2223 // actual residue. 2224 if (is_high_energy && this_rd_stats.dist < tx_domain_dist) 2225 this_rd_stats.dist = tx_domain_dist; 2226 } else { 2227 assert(sse_diff < INT64_MAX); 2228 this_rd_stats.dist += sse_diff; 2229 } 2230 this_rd_stats.sse = block_sse; 2231 } 2232 2233 this_rd_stats.rate = rate_cost; 2234 2235 const int64_t rd = 2236 RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist); 2237 2238 if (rd < best_rd) { 2239 best_rd = rd; 2240 *best_rd_stats = this_rd_stats; 2241 best_tx_type = tx_type; 2242 best_txb_ctx = x->plane[plane].txb_entropy_ctx[block]; 2243 best_eob = x->plane[plane].eobs[block]; 2244 // Swap dqcoeff buffers 2245 tran_low_t *const tmp_dqcoeff = best_dqcoeff; 2246 best_dqcoeff = p->dqcoeff; 2247 p->dqcoeff = tmp_dqcoeff; 2248 } 2249 2250 #if CONFIG_COLLECT_RD_STATS == 1 2251 if (plane == 0) { 2252 PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col, 2253 plane_bsize, tx_size, tx_type, rd); 2254 } 2255 #endif // CONFIG_COLLECT_RD_STATS == 1 2256 2257 #if COLLECT_TX_SIZE_DATA 2258 // Generate small sample to restrict output size. 
2259 static unsigned int seed = 21743; 2260 if (lcg_rand16(&seed) % 200 == 0) { 2261 FILE *fp = NULL; 2262 2263 if (within_border) { 2264 fp = fopen(av1_tx_size_data_output_file, "a"); 2265 } 2266 2267 if (fp) { 2268 // Transform info and RD 2269 const int txb_w = tx_size_wide[tx_size]; 2270 const int txb_h = tx_size_high[tx_size]; 2271 2272 // Residue signal. 2273 const int diff_stride = block_size_wide[plane_bsize]; 2274 struct macroblock_plane *const p = &x->plane[plane]; 2275 const int16_t *src_diff = 2276 &p->src_diff[(blk_row * diff_stride + blk_col) * 4]; 2277 2278 for (int r = 0; r < txb_h; ++r) { 2279 for (int c = 0; c < txb_w; ++c) { 2280 fprintf(fp, "%d,", src_diff[c]); 2281 } 2282 src_diff += diff_stride; 2283 } 2284 2285 fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd); 2286 fprintf(fp, "\n"); 2287 fclose(fp); 2288 } 2289 } 2290 #endif // COLLECT_TX_SIZE_DATA 2291 2292 // If the current best RD cost is much worse than the reference RD cost, 2293 // terminate early. 2294 if (cpi->sf.tx_sf.adaptive_txb_search_level) { 2295 if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) > 2296 ref_best_rd) { 2297 break; 2298 } 2299 } 2300 2301 // Terminate transform type search if the block has been quantized to 2302 // all zero. 2303 if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break; 2304 } 2305 2306 assert(best_rd != INT64_MAX); 2307 2308 best_rd_stats->skip_txfm = best_eob == 0; 2309 if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type); 2310 x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx; 2311 x->plane[plane].eobs[block] = best_eob; 2312 skip_trellis = skip_trellis_based_on_satd[best_tx_type]; 2313 2314 // Point dqcoeff to the quantized coefficients corresponding to the best 2315 // transform type, then we can skip transform and quantization, e.g. in the 2316 // final pixel domain distortion calculation and recon_intra(). 
2317 p->dqcoeff = best_dqcoeff; 2318 2319 if (calc_pixel_domain_distortion_final && best_eob) { 2320 best_rd_stats->dist = dist_block_px_domain( 2321 cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size); 2322 best_rd_stats->sse = block_sse; 2323 } 2324 2325 // Intra mode needs decoded pixels such that the next transform block 2326 // can use them for prediction. 2327 recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, 2328 txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob); 2329 p->dqcoeff = orig_dqcoeff; 2330 } 2331 2332 // Pick transform type for a luma transform block of tx_size. Note this function 2333 // is used only for inter-predicted blocks. 2334 static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x, 2335 TX_SIZE tx_size, int blk_row, int blk_col, 2336 int block, int plane_bsize, TXB_CTX *txb_ctx, 2337 RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode, 2338 int64_t ref_rdcost) { 2339 assert(is_inter_block(x->e_mbd.mi[0])); 2340 RD_STATS this_rd_stats; 2341 const int skip_trellis = 0; 2342 search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size, 2343 txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats); 2344 2345 av1_merge_rd_stats(rd_stats, &this_rd_stats); 2346 } 2347 2348 static inline void try_tx_block_no_split( 2349 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, 2350 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, 2351 const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl, 2352 int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd, 2353 FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) { 2354 MACROBLOCKD *const xd = &x->e_mbd; 2355 MB_MODE_INFO *const mbmi = xd->mi[0]; 2356 struct macroblock_plane *const p = &x->plane[0]; 2357 const int bw = mi_size_wide[plane_bsize]; 2358 const ENTROPY_CONTEXT *const pta = ta + blk_col; 2359 const ENTROPY_CONTEXT *const ptl = tl + blk_row; 2360 const TX_SIZE txs_ctx = 
get_txsize_entropy_ctx(tx_size); 2361 TXB_CTX txb_ctx; 2362 get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx); 2363 const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y] 2364 .txb_skip_cost[txb_ctx.txb_skip_ctx][1]; 2365 rd_stats->zero_rate = zero_blk_rate; 2366 const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col); 2367 mbmi->inter_tx_size[index] = tx_size; 2368 tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx, 2369 rd_stats, ftxs_mode, ref_best_rd); 2370 assert(rd_stats->rate < INT_MAX); 2371 2372 const int pick_skip_txfm = 2373 !xd->lossless[mbmi->segment_id] && 2374 (rd_stats->skip_txfm == 1 || 2375 RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >= 2376 RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse)); 2377 if (pick_skip_txfm) { 2378 #if CONFIG_RD_DEBUG 2379 update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate); 2380 #endif // CONFIG_RD_DEBUG 2381 rd_stats->rate = zero_blk_rate; 2382 rd_stats->dist = rd_stats->sse; 2383 p->eobs[block] = 0; 2384 update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT); 2385 } 2386 rd_stats->skip_txfm = pick_skip_txfm; 2387 set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col, 2388 pick_skip_txfm); 2389 2390 if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) 2391 rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0]; 2392 2393 no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); 2394 no_split->txb_entropy_ctx = p->txb_entropy_ctx[block]; 2395 no_split->tx_type = 2396 xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col]; 2397 } 2398 2399 static inline void try_tx_block_split( 2400 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, 2401 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta, 2402 ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, 2403 int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd, 2404 
FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) { 2405 assert(tx_size < TX_SIZES_ALL); 2406 MACROBLOCKD *const xd = &x->e_mbd; 2407 const int max_blocks_high = max_block_high(xd, plane_bsize, 0); 2408 const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0); 2409 const int txb_width = tx_size_wide_unit[tx_size]; 2410 const int txb_height = tx_size_high_unit[tx_size]; 2411 // Transform size after splitting current block. 2412 const TX_SIZE sub_txs = sub_tx_size_map[tx_size]; 2413 const int sub_txb_width = tx_size_wide_unit[sub_txs]; 2414 const int sub_txb_height = tx_size_high_unit[sub_txs]; 2415 const int sub_step = sub_txb_width * sub_txb_height; 2416 const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width); 2417 assert(nblks > 0); 2418 av1_init_rd_stats(split_rd_stats); 2419 split_rd_stats->rate = 2420 x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1]; 2421 2422 for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) { 2423 const int offsetr = blk_row + r; 2424 if (offsetr >= max_blocks_high) break; 2425 for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) { 2426 assert(blk_idx < 4); 2427 const int offsetc = blk_col + c; 2428 if (offsetc >= max_blocks_wide) continue; 2429 2430 RD_STATS this_rd_stats; 2431 int this_cost_valid = 1; 2432 select_tx_block(cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, 2433 plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats, 2434 no_split_rd / nblks, ref_best_rd - split_rd_stats->rdcost, 2435 &this_cost_valid, ftxs_mode); 2436 if (!this_cost_valid) { 2437 split_rd_stats->rdcost = INT64_MAX; 2438 return; 2439 } 2440 av1_merge_rd_stats(split_rd_stats, &this_rd_stats); 2441 split_rd_stats->rdcost = 2442 RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist); 2443 if (split_rd_stats->rdcost > ref_best_rd) { 2444 split_rd_stats->rdcost = INT64_MAX; 2445 return; 2446 } 2447 block += sub_step; 2448 } 2449 } 2450 } 2451 2452 static float get_var(float mean, double 
x2_sum, int num) { 2453 const float e_x2 = (float)(x2_sum / num); 2454 const float diff = e_x2 - mean * mean; 2455 return diff; 2456 } 2457 2458 static inline void get_blk_var_dev(const int16_t *data, int stride, int bw, 2459 int bh, float *dev_of_mean, 2460 float *var_of_vars) { 2461 const int16_t *const data_ptr = &data[0]; 2462 const int subh = (bh >= bw) ? (bh >> 1) : bh; 2463 const int subw = (bw >= bh) ? (bw >> 1) : bw; 2464 const int num = bw * bh; 2465 const int sub_num = subw * subh; 2466 int total_x_sum = 0; 2467 int64_t total_x2_sum = 0; 2468 int blk_idx = 0; 2469 float var_sum = 0.0f; 2470 float mean_sum = 0.0f; 2471 double var2_sum = 0.0f; 2472 double mean2_sum = 0.0f; 2473 2474 for (int row = 0; row < bh; row += subh) { 2475 for (int col = 0; col < bw; col += subw) { 2476 int x_sum; 2477 int64_t x2_sum; 2478 aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh, 2479 &x_sum, &x2_sum); 2480 total_x_sum += x_sum; 2481 total_x2_sum += x2_sum; 2482 2483 const float mean = (float)x_sum / sub_num; 2484 const float var = get_var(mean, (double)x2_sum, sub_num); 2485 mean_sum += mean; 2486 mean2_sum += (double)(mean * mean); 2487 var_sum += var; 2488 var2_sum += var * var; 2489 blk_idx++; 2490 } 2491 } 2492 2493 const float lvl0_mean = (float)total_x_sum / num; 2494 const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num); 2495 mean_sum += lvl0_mean; 2496 mean2_sum += (double)(lvl0_mean * lvl0_mean); 2497 var_sum += block_var; 2498 var2_sum += block_var * block_var; 2499 const float av_mean = mean_sum / 5; 2500 2501 if (blk_idx > 1) { 2502 // Deviation of means. 2503 *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1)); 2504 // Variance of variances. 
2505 const float mean_var = var_sum / (blk_idx + 1); 2506 *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1)); 2507 } 2508 } 2509 2510 static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize, 2511 int blk_row, int blk_col, TX_SIZE tx_size, 2512 int *try_no_split, int *try_split, 2513 int pruning_level) { 2514 const int diff_stride = block_size_wide[bsize]; 2515 const int16_t *diff = 2516 x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col; 2517 const int bw = tx_size_wide[tx_size]; 2518 const int bh = tx_size_high[tx_size]; 2519 float dev_of_means = 0.0f; 2520 float var_of_vars = 0.0f; 2521 2522 // This function calculates the deviation of means, and the variance of pixel 2523 // variances of the block as well as it's sub-blocks. 2524 get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars); 2525 const int dc_q = x->plane[0].dequant_QTX[0] >> 3; 2526 const int ac_q = x->plane[0].dequant_QTX[1] >> 3; 2527 const int no_split_thresh_scales[4] = { 0, 24, 8, 8 }; 2528 const int no_split_thresh_scale = no_split_thresh_scales[pruning_level]; 2529 const int split_thresh_scales[4] = { 0, 24, 10, 8 }; 2530 const int split_thresh_scale = split_thresh_scales[pruning_level]; 2531 2532 if ((dev_of_means <= dc_q) && 2533 (split_thresh_scale * var_of_vars <= ac_q * ac_q)) { 2534 *try_split = 0; 2535 } 2536 if ((dev_of_means > no_split_thresh_scale * dc_q) && 2537 (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) { 2538 *try_no_split = 0; 2539 } 2540 } 2541 2542 // Search for the best transform partition(recursive)/type for a given 2543 // inter-predicted luma block. The obtained transform selection will be saved 2544 // in xd->mi[0], the corresponding RD stats will be saved in rd_stats. 
2545 static inline void select_tx_block( 2546 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, 2547 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta, 2548 ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, 2549 RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd, 2550 int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) { 2551 assert(tx_size < TX_SIZES_ALL); 2552 av1_init_rd_stats(rd_stats); 2553 if (ref_best_rd < 0) { 2554 *is_cost_valid = 0; 2555 return; 2556 } 2557 2558 MACROBLOCKD *const xd = &x->e_mbd; 2559 assert(blk_row < max_block_high(xd, plane_bsize, 0) && 2560 blk_col < max_block_wide(xd, plane_bsize, 0)); 2561 MB_MODE_INFO *const mbmi = xd->mi[0]; 2562 const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row, 2563 mbmi->bsize, tx_size); 2564 struct macroblock_plane *const p = &x->plane[0]; 2565 2566 int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 || 2567 txsize_sqr_up_map[tx_size] != TX_64X64) && 2568 (cpi->oxcf.txfm_cfg.enable_rect_tx || 2569 tx_size_wide[tx_size] == tx_size_high[tx_size]); 2570 int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH; 2571 TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES }; 2572 2573 // Prune tx_split and no-split based on sub-block properties. 2574 if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 && 2575 cpi->sf.tx_sf.prune_tx_size_level > 0) { 2576 prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size, 2577 &try_no_split, &try_split, 2578 cpi->sf.tx_sf.prune_tx_size_level); 2579 } 2580 2581 if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) { 2582 if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0; 2583 } 2584 2585 // Try using current block as a single transform block without split. 
2586 if (try_no_split) { 2587 try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth, 2588 plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd, 2589 ftxs_mode, &no_split); 2590 2591 // Speed features for early termination. 2592 const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level; 2593 if (search_level) { 2594 if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) { 2595 *is_cost_valid = 0; 2596 return; 2597 } 2598 if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) { 2599 try_split = 0; 2600 } 2601 } 2602 if (cpi->sf.tx_sf.txb_split_cap) { 2603 if (p->eobs[block] == 0) try_split = 0; 2604 } 2605 } 2606 2607 // ML based speed feature to skip searching for split transform blocks. 2608 if (x->e_mbd.bd == 8 && try_split && 2609 !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) { 2610 const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh; 2611 if (threshold >= 0) { 2612 const int split_score = 2613 ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size); 2614 if (split_score < -threshold) try_split = 0; 2615 } 2616 } 2617 2618 RD_STATS split_rd_stats; 2619 split_rd_stats.rdcost = INT64_MAX; 2620 // Try splitting current block into smaller transform blocks. 
2621 if (try_split) { 2622 try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth, 2623 plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd, 2624 AOMMIN(no_split.rd, ref_best_rd), ftxs_mode, 2625 &split_rd_stats); 2626 } 2627 2628 if (no_split.rd < split_rd_stats.rdcost) { 2629 ENTROPY_CONTEXT *pta = ta + blk_col; 2630 ENTROPY_CONTEXT *ptl = tl + blk_row; 2631 p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx; 2632 av1_set_txb_context(x, 0, block, tx_size, pta, ptl); 2633 txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size, 2634 tx_size); 2635 for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) { 2636 for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) { 2637 const int index = 2638 av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx); 2639 mbmi->inter_tx_size[index] = tx_size; 2640 } 2641 } 2642 mbmi->tx_size = tx_size; 2643 update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type); 2644 const int bw = mi_size_wide[plane_bsize]; 2645 set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col, 2646 rd_stats->skip_txfm); 2647 } else { 2648 *rd_stats = split_rd_stats; 2649 if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0; 2650 } 2651 } 2652 2653 static inline void choose_largest_tx_size(const AV1_COMP *const cpi, 2654 MACROBLOCK *x, RD_STATS *rd_stats, 2655 int64_t ref_best_rd, BLOCK_SIZE bs) { 2656 MACROBLOCKD *const xd = &x->e_mbd; 2657 MB_MODE_INFO *const mbmi = xd->mi[0]; 2658 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 2659 mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type); 2660 2661 // If tx64 is not enabled, we need to go down to the next available size 2662 if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) { 2663 static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = { 2664 TX_4X4, // 4x4 transform 2665 TX_8X8, // 8x8 transform 2666 TX_16X16, // 16x16 transform 2667 TX_32X32, // 32x32 transform 2668 
TX_32X32, // 64x64 transform 2669 TX_4X8, // 4x8 transform 2670 TX_8X4, // 8x4 transform 2671 TX_8X16, // 8x16 transform 2672 TX_16X8, // 16x8 transform 2673 TX_16X32, // 16x32 transform 2674 TX_32X16, // 32x16 transform 2675 TX_32X32, // 32x64 transform 2676 TX_32X32, // 64x32 transform 2677 TX_4X16, // 4x16 transform 2678 TX_16X4, // 16x4 transform 2679 TX_8X32, // 8x32 transform 2680 TX_32X8, // 32x8 transform 2681 TX_16X32, // 16x64 transform 2682 TX_32X16, // 64x16 transform 2683 }; 2684 mbmi->tx_size = tx_size_max_32[mbmi->tx_size]; 2685 } else if (cpi->oxcf.txfm_cfg.enable_tx64 && 2686 !cpi->oxcf.txfm_cfg.enable_rect_tx) { 2687 static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = { 2688 TX_4X4, // 4x4 transform 2689 TX_8X8, // 8x8 transform 2690 TX_16X16, // 16x16 transform 2691 TX_32X32, // 32x32 transform 2692 TX_64X64, // 64x64 transform 2693 TX_4X4, // 4x8 transform 2694 TX_4X4, // 8x4 transform 2695 TX_8X8, // 8x16 transform 2696 TX_8X8, // 16x8 transform 2697 TX_16X16, // 16x32 transform 2698 TX_16X16, // 32x16 transform 2699 TX_32X32, // 32x64 transform 2700 TX_32X32, // 64x32 transform 2701 TX_4X4, // 4x16 transform 2702 TX_4X4, // 16x4 transform 2703 TX_8X8, // 8x32 transform 2704 TX_8X8, // 32x8 transform 2705 TX_16X16, // 16x64 transform 2706 TX_16X16, // 64x16 transform 2707 }; 2708 mbmi->tx_size = tx_size_max_square[mbmi->tx_size]; 2709 } else if (!cpi->oxcf.txfm_cfg.enable_tx64 && 2710 !cpi->oxcf.txfm_cfg.enable_rect_tx) { 2711 static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = { 2712 TX_4X4, // 4x4 transform 2713 TX_8X8, // 8x8 transform 2714 TX_16X16, // 16x16 transform 2715 TX_32X32, // 32x32 transform 2716 TX_32X32, // 64x64 transform 2717 TX_4X4, // 4x8 transform 2718 TX_4X4, // 8x4 transform 2719 TX_8X8, // 8x16 transform 2720 TX_8X8, // 16x8 transform 2721 TX_16X16, // 16x32 transform 2722 TX_16X16, // 32x16 transform 2723 TX_32X32, // 32x64 transform 2724 TX_32X32, // 64x32 transform 2725 TX_4X4, // 4x16 transform 2726 
TX_4X4, // 16x4 transform 2727 TX_8X8, // 8x32 transform 2728 TX_8X8, // 32x8 transform 2729 TX_16X16, // 16x64 transform 2730 TX_16X16, // 64x16 transform 2731 }; 2732 2733 mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size]; 2734 } 2735 2736 const int skip_ctx = av1_get_skip_txfm_context(xd); 2737 const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0]; 2738 const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1]; 2739 // Skip RDcost is used only for Inter blocks 2740 const int64_t skip_txfm_rd = 2741 is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX; 2742 const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0); 2743 const int skip_trellis = 0; 2744 av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 2745 AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs, 2746 mbmi->tx_size, FTXS_NONE, skip_trellis); 2747 } 2748 2749 static inline void choose_smallest_tx_size(const AV1_COMP *const cpi, 2750 MACROBLOCK *x, RD_STATS *rd_stats, 2751 int64_t ref_best_rd, BLOCK_SIZE bs) { 2752 MACROBLOCKD *const xd = &x->e_mbd; 2753 MB_MODE_INFO *const mbmi = xd->mi[0]; 2754 2755 mbmi->tx_size = TX_4X4; 2756 // TODO(any) : Pass this_rd based on skip/non-skip cost 2757 const int skip_trellis = 0; 2758 av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size, 2759 FTXS_NONE, skip_trellis); 2760 } 2761 2762 #if !CONFIG_REALTIME_ONLY 2763 static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row, 2764 int blk_col, BLOCK_SIZE bsize, 2765 TX_SIZE tx_size) { 2766 const MACROBLOCKD *const xd = &x->e_mbd; 2767 const MB_MODE_INFO *const mbmi = xd->mi[0]; 2768 2769 // Disable the pruning logic using NN model for the following cases: 2770 // 1) Lossless coding as only 4x4 transform is evaluated in this case 2771 // 2) When transform and current block sizes do not match as the features are 2772 // obtained over the current block 2773 // 3) When operating bit-depth is not 8-bit as the input 
features are not 2774 // scaled according to bit-depth. 2775 if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize || 2776 xd->bd != 8) 2777 return; 2778 2779 // Currently NN model based pruning is supported only when largest transform 2780 // size is 8x8 2781 if (tx_size != TX_8X8) return; 2782 2783 // Neural network model is a sequential neural net and was trained using SGD 2784 // optimizer. The model can be further improved in terms of speed/quality by 2785 // considering the following experiments: 2786 // 1) Generate ML model by training with balanced data for different learning 2787 // rates and optimizers. 2788 // 2) Experiment with ML model by adding features related to the statistics of 2789 // top and left pixels to capture the accuracy of reconstructed neighbouring 2790 // pixels for 4x4 blocks numbered 1, 2, 3 in 8x8 block, source variance of 4x4 2791 // sub-blocks, etc. 2792 // 3) Generate ML models for transform blocks other than 8x8. 2793 const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8; 2794 const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8; 2795 2796 float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f }; 2797 const int diff_stride = block_size_wide[bsize]; 2798 2799 const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride + 2800 MI_SIZE * blk_col; 2801 const int bw = tx_size_wide[tx_size]; 2802 const int bh = tx_size_high[tx_size]; 2803 2804 int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features); 2805 2806 features[feature_idx++] = log1pf((float)x->source_variance); 2807 2808 const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8); 2809 const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f); 2810 features[feature_idx++] = log_dc_q_square; 2811 assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES); 2812 for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) { 2813 features[i] = (features[i] - 
av1_intra_tx_split_8x8_mean[i]) / 2814 av1_intra_tx_split_8x8_std[i]; 2815 } 2816 2817 float score; 2818 av1_nn_predict(features, nn_config, 1, &score); 2819 2820 TxfmSearchParams *const txfm_params = &x->txfm_search_params; 2821 if (score <= intra_tx_prune_thresh[0]) 2822 txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT; 2823 else if (score > intra_tx_prune_thresh[1]) 2824 txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST; 2825 } 2826 #endif // !CONFIG_REALTIME_ONLY 2827 2828 /*!\brief Transform type search for luma macroblock with fixed transform size. 2829 * 2830 * \ingroup transform_search 2831 * Search for the best transform type and return the transform coefficients RD 2832 * cost of current luma macroblock with the given uniform transform size. 2833 * 2834 * \param[in] x Pointer to structure holding the data for the 2835 current encoding macroblock 2836 * \param[in] cpi Top-level encoder structure 2837 * \param[in] rd_stats Pointer to struct to keep track of the RD stats 2838 * \param[in] ref_best_rd Best RD cost seen for this block so far 2839 * \param[in] bs Size of the current macroblock 2840 * \param[in] tx_size The given transform size 2841 * \param[in] ftxs_mode Transform search mode specifying desired speed 2842 and quality tradeoff 2843 * \param[in] skip_trellis Binary flag indicating if trellis optimization 2844 should be skipped 2845 * \return An int64_t value that is the best RD cost found. 
2846 */ 2847 static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, 2848 RD_STATS *rd_stats, int64_t ref_best_rd, 2849 BLOCK_SIZE bs, TX_SIZE tx_size, 2850 FAST_TX_SEARCH_MODE ftxs_mode, 2851 int skip_trellis) { 2852 assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs))); 2853 MACROBLOCKD *const xd = &x->e_mbd; 2854 MB_MODE_INFO *const mbmi = xd->mi[0]; 2855 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 2856 const ModeCosts *mode_costs = &x->mode_costs; 2857 const int is_inter = is_inter_block(mbmi); 2858 const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT && 2859 block_signals_txsize(mbmi->bsize); 2860 int tx_size_rate = 0; 2861 if (tx_select) { 2862 const int ctx = txfm_partition_context( 2863 xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size); 2864 tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0] 2865 : tx_size_cost(x, bs, tx_size); 2866 } 2867 const int skip_ctx = av1_get_skip_txfm_context(xd); 2868 const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0]; 2869 const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1]; 2870 const int64_t skip_txfm_rd = 2871 is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX; 2872 const int64_t no_this_rd = 2873 RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0); 2874 2875 mbmi->tx_size = tx_size; 2876 av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 2877 AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs, 2878 tx_size, ftxs_mode, skip_trellis); 2879 if (rd_stats->rate == INT_MAX) return INT64_MAX; 2880 2881 int64_t rd; 2882 // rdstats->rate should include all the rate except skip/non-skip cost as the 2883 // same is accounted in the caller functions after rd evaluation of all 2884 // planes. 
However the decisions should be done after considering the 2885 // skip/non-skip header cost 2886 if (rd_stats->skip_txfm && is_inter) { 2887 rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse); 2888 } else { 2889 // Intra blocks are always signalled as non-skip 2890 rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate, 2891 rd_stats->dist); 2892 rd_stats->rate += tx_size_rate; 2893 } 2894 // Check if forcing the block to skip transform leads to smaller RD cost. 2895 if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) { 2896 int64_t temp_skip_txfm_rd = 2897 RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse); 2898 if (temp_skip_txfm_rd <= rd) { 2899 rd = temp_skip_txfm_rd; 2900 rd_stats->rate = 0; 2901 rd_stats->dist = rd_stats->sse; 2902 rd_stats->skip_txfm = 1; 2903 } 2904 } 2905 2906 return rd; 2907 } 2908 2909 // Search for the best uniform transform size and type for current coding block. 2910 static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi, 2911 MACROBLOCK *x, 2912 RD_STATS *rd_stats, 2913 int64_t ref_best_rd, 2914 BLOCK_SIZE bs) { 2915 av1_invalid_rd_stats(rd_stats); 2916 2917 MACROBLOCKD *const xd = &x->e_mbd; 2918 MB_MODE_INFO *const mbmi = xd->mi[0]; 2919 TxfmSearchParams *const txfm_params = &x->txfm_search_params; 2920 const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs]; 2921 const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT; 2922 int start_tx; 2923 // The split depth can be at most MAX_TX_DEPTH, so the init_depth controls 2924 // how many times of splitting is allowed during the RD search. 
2925 int init_depth; 2926 2927 if (tx_select) { 2928 start_tx = max_rect_tx_size; 2929 init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs], 2930 is_inter_block(mbmi), &cpi->sf, 2931 txfm_params->tx_size_search_method); 2932 if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 && 2933 txsize_sqr_up_map[start_tx] == TX_64X64) { 2934 start_tx = sub_tx_size_map[start_tx]; 2935 } 2936 } else { 2937 const TX_SIZE chosen_tx_size = 2938 tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type); 2939 start_tx = chosen_tx_size; 2940 init_depth = MAX_TX_DEPTH; 2941 } 2942 2943 const int skip_trellis = 0; 2944 uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE]; 2945 uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE]; 2946 TX_SIZE best_tx_size = max_rect_tx_size; 2947 int64_t best_rd = INT64_MAX; 2948 const int num_blks = bsize_to_num_blk(bs); 2949 x->rd_model = FULL_TXFM_RD; 2950 int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX }; 2951 TxfmSearchInfo *txfm_info = &x->txfm_search_info; 2952 for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH; 2953 depth++, tx_size = sub_tx_size_map[tx_size]) { 2954 if ((!cpi->oxcf.txfm_cfg.enable_tx64 && 2955 txsize_sqr_up_map[tx_size] == TX_64X64) || 2956 (!cpi->oxcf.txfm_cfg.enable_rect_tx && 2957 tx_size_wide[tx_size] != tx_size_high[tx_size])) { 2958 continue; 2959 } 2960 2961 #if !CONFIG_REALTIME_ONLY 2962 if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break; 2963 2964 // Set the flag to enable the evaluation of NN classifier to prune transform 2965 // depths. As the features are based on intra residual information of 2966 // largest transform, the evaluation of NN model is enabled only for this 2967 // case. 
2968 txfm_params->enable_nn_prune_intra_tx_depths = 2969 (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx); 2970 #endif 2971 2972 RD_STATS this_rd_stats; 2973 // When the speed feature use_rd_based_breakout_for_intra_tx_search is 2974 // enabled, use the known minimum best_rd for early termination. 2975 const int64_t rd_thresh = 2976 cpi->sf.tx_sf.use_rd_based_breakout_for_intra_tx_search 2977 ? AOMMIN(ref_best_rd, best_rd) 2978 : ref_best_rd; 2979 rd[depth] = uniform_txfm_yrd(cpi, x, &this_rd_stats, rd_thresh, bs, tx_size, 2980 FTXS_NONE, skip_trellis); 2981 if (rd[depth] < best_rd) { 2982 av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks); 2983 av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks); 2984 best_tx_size = tx_size; 2985 best_rd = rd[depth]; 2986 *rd_stats = this_rd_stats; 2987 } 2988 if (tx_size == TX_4X4) break; 2989 // If we are searching three depths, prune the smallest size depending 2990 // on rd results for the first two depths for low contrast blocks. 2991 if (depth > init_depth && depth != MAX_TX_DEPTH && 2992 x->source_variance < 256) { 2993 if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break; 2994 } 2995 } 2996 2997 if (rd_stats->rate != INT_MAX) { 2998 mbmi->tx_size = best_tx_size; 2999 av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks); 3000 av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks); 3001 } 3002 3003 #if !CONFIG_REALTIME_ONLY 3004 // Reset the flags to avoid any unintentional evaluation of NN model and 3005 // consumption of prune depths. 3006 txfm_params->enable_nn_prune_intra_tx_depths = false; 3007 txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE; 3008 #endif 3009 } 3010 3011 // Search for the best transform type for the given transform block in the 3012 // given plane/channel, and calculate the corresponding RD cost. 
3013 static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col, 3014 BLOCK_SIZE plane_bsize, TX_SIZE tx_size, 3015 void *arg) { 3016 struct rdcost_block_args *args = arg; 3017 if (args->exit_early) { 3018 args->incomplete_exit = 1; 3019 return; 3020 } 3021 3022 MACROBLOCK *const x = args->x; 3023 MACROBLOCKD *const xd = &x->e_mbd; 3024 const int is_inter = is_inter_block(xd->mi[0]); 3025 const AV1_COMP *cpi = args->cpi; 3026 ENTROPY_CONTEXT *a = args->t_above + blk_col; 3027 ENTROPY_CONTEXT *l = args->t_left + blk_row; 3028 const AV1_COMMON *cm = &cpi->common; 3029 RD_STATS this_rd_stats; 3030 av1_init_rd_stats(&this_rd_stats); 3031 3032 if (!is_inter) { 3033 av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size); 3034 av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size); 3035 #if !CONFIG_REALTIME_ONLY 3036 const TxfmSearchParams *const txfm_params = &x->txfm_search_params; 3037 if (txfm_params->enable_nn_prune_intra_tx_depths) { 3038 ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize, 3039 tx_size); 3040 if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) { 3041 av1_invalid_rd_stats(&args->rd_stats); 3042 args->exit_early = 1; 3043 return; 3044 } 3045 } 3046 #endif 3047 } 3048 3049 TXB_CTX txb_ctx; 3050 get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx); 3051 search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, 3052 &txb_ctx, args->ftxs_mode, args->skip_trellis, 3053 args->best_rd - args->current_rd, &this_rd_stats); 3054 3055 #if !CONFIG_REALTIME_ONLY 3056 if (plane == AOM_PLANE_Y && xd->cfl.store_y) { 3057 assert(!is_inter || plane_bsize < BLOCK_8X8); 3058 cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize); 3059 } 3060 #endif 3061 3062 #if CONFIG_RD_DEBUG 3063 update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate); 3064 #endif // CONFIG_RD_DEBUG 3065 av1_set_txb_context(x, plane, block, tx_size, a, l); 3066 3067 const int blk_idx = 
3068 blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col; 3069 3070 TxfmSearchInfo *txfm_info = &x->txfm_search_info; 3071 if (plane == 0) 3072 set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 3073 x->plane[plane].eobs[block] == 0); 3074 else 3075 set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0); 3076 3077 int64_t rd; 3078 if (is_inter) { 3079 const int64_t no_skip_txfm_rd = 3080 RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist); 3081 const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse); 3082 rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd); 3083 this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block]; 3084 } else { 3085 // Signal non-skip_txfm for Intra blocks 3086 rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist); 3087 this_rd_stats.skip_txfm = 0; 3088 } 3089 3090 av1_merge_rd_stats(&args->rd_stats, &this_rd_stats); 3091 3092 args->current_rd += rd; 3093 if (args->current_rd > args->best_rd) args->exit_early = 1; 3094 } 3095 3096 int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, 3097 RD_STATS *rd_stats, int64_t ref_best_rd, 3098 BLOCK_SIZE bs, TX_SIZE tx_size) { 3099 MACROBLOCKD *const xd = &x->e_mbd; 3100 MB_MODE_INFO *const mbmi = xd->mi[0]; 3101 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 3102 const ModeCosts *mode_costs = &x->mode_costs; 3103 const int is_inter = is_inter_block(mbmi); 3104 const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT && 3105 block_signals_txsize(mbmi->bsize); 3106 int tx_size_rate = 0; 3107 if (tx_select) { 3108 const int ctx = txfm_partition_context( 3109 xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size); 3110 tx_size_rate = mode_costs->txfm_partition_cost[ctx][0]; 3111 } 3112 const int skip_ctx = av1_get_skip_txfm_context(xd); 3113 const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0]; 3114 const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1]; 3115 const int64_t skip_txfm_rd 
= RDCOST(x->rdmult, skip_txfm_rate, 0); 3116 const int64_t no_this_rd = 3117 RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0); 3118 mbmi->tx_size = tx_size; 3119 3120 const uint8_t txw_unit = tx_size_wide_unit[tx_size]; 3121 const uint8_t txh_unit = tx_size_high_unit[tx_size]; 3122 const int step = txw_unit * txh_unit; 3123 const int max_blocks_wide = max_block_wide(xd, bs, 0); 3124 const int max_blocks_high = max_block_high(xd, bs, 0); 3125 3126 struct rdcost_block_args args; 3127 av1_zero(args); 3128 args.x = x; 3129 args.cpi = cpi; 3130 args.best_rd = ref_best_rd; 3131 args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd); 3132 av1_init_rd_stats(&args.rd_stats); 3133 av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left); 3134 int i = 0; 3135 for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit; 3136 blk_row += txh_unit) { 3137 for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) { 3138 RD_STATS this_rd_stats; 3139 av1_init_rd_stats(&this_rd_stats); 3140 3141 if (args.exit_early) { 3142 args.incomplete_exit = 1; 3143 break; 3144 } 3145 3146 ENTROPY_CONTEXT *a = args.t_above + blk_col; 3147 ENTROPY_CONTEXT *l = args.t_left + blk_row; 3148 TXB_CTX txb_ctx; 3149 get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx); 3150 3151 TxfmParam txfm_param; 3152 QUANT_PARAM quant_param; 3153 av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param); 3154 av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param); 3155 3156 av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param); 3157 av1_quant(x, 0, i, &txfm_param, &quant_param); 3158 3159 this_rd_stats.rate = 3160 cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0); 3161 3162 const SCAN_ORDER *const scan_order = 3163 get_scan(txfm_param.tx_size, txfm_param.tx_type); 3164 dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix, 3165 scan_order->scan, &this_rd_stats.dist, 3166 &this_rd_stats.sse); 3167 3168 const int64_t no_skip_txfm_rd = 3169 
RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist); 3170 const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse); 3171 3172 this_rd_stats.skip_txfm &= !x->plane[0].eobs[i]; 3173 3174 av1_merge_rd_stats(&args.rd_stats, &this_rd_stats); 3175 args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd); 3176 3177 if (args.current_rd > ref_best_rd) { 3178 args.exit_early = 1; 3179 break; 3180 } 3181 3182 av1_set_txb_context(x, 0, i, tx_size, a, l); 3183 i += step; 3184 } 3185 } 3186 3187 if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats); 3188 3189 *rd_stats = args.rd_stats; 3190 if (rd_stats->rate == INT_MAX) return INT64_MAX; 3191 3192 int64_t rd; 3193 // rdstats->rate should include all the rate except skip/non-skip cost as the 3194 // same is accounted in the caller functions after rd evaluation of all 3195 // planes. However the decisions should be done after considering the 3196 // skip/non-skip header cost 3197 if (rd_stats->skip_txfm && is_inter) { 3198 rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse); 3199 } else { 3200 // Intra blocks are always signalled as non-skip 3201 rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate, 3202 rd_stats->dist); 3203 rd_stats->rate += tx_size_rate; 3204 } 3205 // Check if forcing the block to skip transform leads to smaller RD cost. 3206 if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) { 3207 int64_t temp_skip_txfm_rd = 3208 RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse); 3209 if (temp_skip_txfm_rd <= rd) { 3210 rd = temp_skip_txfm_rd; 3211 rd_stats->rate = 0; 3212 rd_stats->dist = rd_stats->sse; 3213 rd_stats->skip_txfm = 1; 3214 } 3215 } 3216 3217 return rd; 3218 } 3219 3220 // Search for the best transform type for a luma inter-predicted block, given 3221 // the transform block partitions. 3222 // This function is used only when some speed features are enabled. 
3223 static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, 3224 int blk_col, int block, TX_SIZE tx_size, 3225 BLOCK_SIZE plane_bsize, int depth, 3226 ENTROPY_CONTEXT *above_ctx, 3227 ENTROPY_CONTEXT *left_ctx, 3228 TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, 3229 int64_t ref_best_rd, RD_STATS *rd_stats, 3230 FAST_TX_SEARCH_MODE ftxs_mode) { 3231 assert(tx_size < TX_SIZES_ALL); 3232 MACROBLOCKD *const xd = &x->e_mbd; 3233 MB_MODE_INFO *const mbmi = xd->mi[0]; 3234 assert(is_inter_block(mbmi)); 3235 const int max_blocks_high = max_block_high(xd, plane_bsize, 0); 3236 const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0); 3237 3238 if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return; 3239 3240 const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index( 3241 plane_bsize, blk_row, blk_col)]; 3242 const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row, 3243 mbmi->bsize, tx_size); 3244 3245 av1_init_rd_stats(rd_stats); 3246 if (tx_size == plane_tx_size) { 3247 ENTROPY_CONTEXT *ta = above_ctx + blk_col; 3248 ENTROPY_CONTEXT *tl = left_ctx + blk_row; 3249 const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size); 3250 TXB_CTX txb_ctx; 3251 get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx); 3252 3253 const int zero_blk_rate = 3254 x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)] 3255 .txb_skip_cost[txb_ctx.txb_skip_ctx][1]; 3256 rd_stats->zero_rate = zero_blk_rate; 3257 tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx, 3258 rd_stats, ftxs_mode, ref_best_rd); 3259 const int mi_width = mi_size_wide[plane_bsize]; 3260 TxfmSearchInfo *txfm_info = &x->txfm_search_info; 3261 if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >= 3262 RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) || 3263 rd_stats->skip_txfm == 1) { 3264 rd_stats->rate = zero_blk_rate; 3265 rd_stats->dist = rd_stats->sse; 3266 rd_stats->skip_txfm = 1; 3267 
set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1); 3268 x->plane[0].eobs[block] = 0; 3269 x->plane[0].txb_entropy_ctx[block] = 0; 3270 update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT); 3271 } else { 3272 rd_stats->skip_txfm = 0; 3273 set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0); 3274 } 3275 if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) 3276 rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0]; 3277 av1_set_txb_context(x, 0, block, tx_size, ta, tl); 3278 txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size, 3279 tx_size); 3280 } else { 3281 const TX_SIZE sub_txs = sub_tx_size_map[tx_size]; 3282 const int txb_width = tx_size_wide_unit[sub_txs]; 3283 const int txb_height = tx_size_high_unit[sub_txs]; 3284 const int step = txb_height * txb_width; 3285 const int row_end = 3286 AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row); 3287 const int col_end = 3288 AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col); 3289 RD_STATS pn_rd_stats; 3290 int64_t this_rd = 0; 3291 assert(txb_width > 0 && txb_height > 0); 3292 3293 for (int row = 0; row < row_end; row += txb_height) { 3294 const int offsetr = blk_row + row; 3295 for (int col = 0; col < col_end; col += txb_width) { 3296 const int offsetc = blk_col + col; 3297 3298 av1_init_rd_stats(&pn_rd_stats); 3299 tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize, 3300 depth + 1, above_ctx, left_ctx, tx_above, tx_left, 3301 ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode); 3302 if (pn_rd_stats.rate == INT_MAX) { 3303 av1_invalid_rd_stats(rd_stats); 3304 return; 3305 } 3306 av1_merge_rd_stats(rd_stats, &pn_rd_stats); 3307 this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist); 3308 block += step; 3309 } 3310 } 3311 3312 if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) 3313 rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1]; 3314 } 3315 } 3316 3317 // search for tx type with tx sizes already decided 
for a inter-predicted luma 3318 // partition block. It's used only when some speed features are enabled. 3319 // Return value 0: early termination triggered, no valid rd cost available; 3320 // 1: rd cost values are valid. 3321 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, 3322 RD_STATS *rd_stats, BLOCK_SIZE bsize, 3323 int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) { 3324 if (ref_best_rd < 0) { 3325 av1_invalid_rd_stats(rd_stats); 3326 return 0; 3327 } 3328 3329 av1_init_rd_stats(rd_stats); 3330 3331 MACROBLOCKD *const xd = &x->e_mbd; 3332 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 3333 const struct macroblockd_plane *const pd = &xd->plane[0]; 3334 const int mi_width = mi_size_wide[bsize]; 3335 const int mi_height = mi_size_high[bsize]; 3336 const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0); 3337 const int bh = tx_size_high_unit[max_tx_size]; 3338 const int bw = tx_size_wide_unit[max_tx_size]; 3339 const int step = bw * bh; 3340 const int init_depth = get_search_init_depth( 3341 mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method); 3342 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE]; 3343 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE]; 3344 TXFM_CONTEXT tx_above[MAX_MIB_SIZE]; 3345 TXFM_CONTEXT tx_left[MAX_MIB_SIZE]; 3346 av1_get_entropy_contexts(bsize, pd, ctxa, ctxl); 3347 memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width); 3348 memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height); 3349 3350 int64_t this_rd = 0; 3351 for (int idy = 0, block = 0; idy < mi_height; idy += bh) { 3352 for (int idx = 0; idx < mi_width; idx += bw) { 3353 RD_STATS pn_rd_stats; 3354 av1_init_rd_stats(&pn_rd_stats); 3355 tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth, 3356 ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd, 3357 &pn_rd_stats, ftxs_mode); 3358 if (pn_rd_stats.rate == INT_MAX) { 3359 av1_invalid_rd_stats(rd_stats); 3360 return 0; 3361 } 3362 
av1_merge_rd_stats(rd_stats, &pn_rd_stats); 3363 this_rd += 3364 AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist), 3365 RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse)); 3366 block += step; 3367 } 3368 } 3369 3370 const int skip_ctx = av1_get_skip_txfm_context(xd); 3371 const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0]; 3372 const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1]; 3373 const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse); 3374 this_rd = 3375 RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist); 3376 if (skip_txfm_rd < this_rd) { 3377 this_rd = skip_txfm_rd; 3378 rd_stats->rate = 0; 3379 rd_stats->dist = rd_stats->sse; 3380 rd_stats->skip_txfm = 1; 3381 } 3382 3383 const int is_cost_valid = this_rd > ref_best_rd; 3384 if (!is_cost_valid) { 3385 // reset cost value 3386 av1_invalid_rd_stats(rd_stats); 3387 } 3388 return is_cost_valid; 3389 } 3390 3391 // Search for the best transform size and type for current inter-predicted 3392 // luma block with recursive transform block partitioning. The obtained 3393 // transform selection will be saved in xd->mi[0], the corresponding RD stats 3394 // will be saved in rd_stats. The returned value is the corresponding RD cost. 
3395 static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x, 3396 RD_STATS *rd_stats, BLOCK_SIZE bsize, 3397 int64_t ref_best_rd) { 3398 MACROBLOCKD *const xd = &x->e_mbd; 3399 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 3400 assert(is_inter_block(xd->mi[0])); 3401 assert(bsize < BLOCK_SIZES_ALL); 3402 const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD; 3403 int64_t rd_thresh = ref_best_rd; 3404 if (rd_thresh == 0) { 3405 av1_invalid_rd_stats(rd_stats); 3406 return INT64_MAX; 3407 } 3408 if (fast_tx_search && rd_thresh < INT64_MAX) { 3409 if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3); 3410 } 3411 assert(rd_thresh > 0); 3412 const FAST_TX_SEARCH_MODE ftxs_mode = 3413 fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE; 3414 const struct macroblockd_plane *const pd = &xd->plane[0]; 3415 assert(bsize < BLOCK_SIZES_ALL); 3416 const int mi_width = mi_size_wide[bsize]; 3417 const int mi_height = mi_size_high[bsize]; 3418 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE]; 3419 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE]; 3420 TXFM_CONTEXT tx_above[MAX_MIB_SIZE]; 3421 TXFM_CONTEXT tx_left[MAX_MIB_SIZE]; 3422 av1_get_entropy_contexts(bsize, pd, ctxa, ctxl); 3423 memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width); 3424 memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height); 3425 const int init_depth = get_search_init_depth( 3426 mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method); 3427 const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize]; 3428 const int bh = tx_size_high_unit[max_tx_size]; 3429 const int bw = tx_size_wide_unit[max_tx_size]; 3430 const int step = bw * bh; 3431 const int skip_ctx = av1_get_skip_txfm_context(xd); 3432 const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0]; 3433 const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1]; 3434 int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0); 
3435 int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0); 3436 int block = 0; 3437 3438 av1_init_rd_stats(rd_stats); 3439 for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) { 3440 for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) { 3441 const int64_t best_rd_sofar = 3442 (rd_thresh == INT64_MAX) 3443 ? INT64_MAX 3444 : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd))); 3445 int is_cost_valid = 1; 3446 RD_STATS pn_rd_stats; 3447 // Search for the best transform block size and type for the sub-block. 3448 select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize, 3449 ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX, 3450 best_rd_sofar, &is_cost_valid, ftxs_mode); 3451 if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) { 3452 av1_invalid_rd_stats(rd_stats); 3453 return INT64_MAX; 3454 } 3455 av1_merge_rd_stats(rd_stats, &pn_rd_stats); 3456 skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse); 3457 no_skip_txfm_rd = 3458 RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist); 3459 block += step; 3460 } 3461 } 3462 3463 if (rd_stats->rate == INT_MAX) return INT64_MAX; 3464 3465 rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd); 3466 3467 // If fast_tx_search is true, only DCT and 1D DCT were tested in 3468 // select_inter_block_yrd() above. Do a better search for tx type with 3469 // tx sizes already decided. 
3470 if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) { 3471 if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE)) 3472 return INT64_MAX; 3473 } 3474 3475 int64_t final_rd; 3476 if (rd_stats->skip_txfm) { 3477 final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse); 3478 } else { 3479 final_rd = 3480 RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist); 3481 if (!xd->lossless[xd->mi[0]->segment_id]) { 3482 final_rd = 3483 AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse)); 3484 } 3485 } 3486 3487 return final_rd; 3488 } 3489 3490 // Return 1 to terminate transform search early. The decision is made based on 3491 // the comparison with the reference RD cost and the model-estimated RD cost. 3492 static inline int model_based_tx_search_prune(const AV1_COMP *cpi, 3493 MACROBLOCK *x, BLOCK_SIZE bsize, 3494 int64_t ref_best_rd) { 3495 const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level; 3496 assert(level >= 0 && level <= 2); 3497 int model_rate; 3498 int64_t model_dist; 3499 uint8_t model_skip; 3500 MACROBLOCKD *const xd = &x->e_mbd; 3501 model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE]( 3502 cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL, 3503 NULL, NULL, NULL); 3504 if (model_skip) return 0; 3505 const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist); 3506 // TODO(debargha, urvang): Improve the model and make the check below 3507 // tighter. 
3508 static const int prune_factor_by8[] = { 3, 5 }; 3509 const int factor = prune_factor_by8[level - 1]; 3510 return ((model_rd * factor) >> 3) > ref_best_rd; 3511 } 3512 3513 void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, 3514 RD_STATS *rd_stats, BLOCK_SIZE bsize, 3515 int64_t ref_best_rd) { 3516 MACROBLOCKD *const xd = &x->e_mbd; 3517 const TxfmSearchParams *txfm_params = &x->txfm_search_params; 3518 assert(is_inter_block(xd->mi[0])); 3519 3520 av1_invalid_rd_stats(rd_stats); 3521 3522 // If modeled RD cost is a lot worse than the best so far, terminate early. 3523 if (cpi->sf.tx_sf.model_based_prune_tx_search_level && 3524 ref_best_rd != INT64_MAX) { 3525 if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return; 3526 } 3527 3528 // Hashing based speed feature. If the hash of the prediction residue block is 3529 // found in the hash table, use previous search results and terminate early. 3530 uint32_t hash = 0; 3531 MB_RD_RECORD *mb_rd_record = NULL; 3532 const int mi_row = x->e_mbd.mi_row; 3533 const int mi_col = x->e_mbd.mi_col; 3534 const int within_border = 3535 mi_row >= xd->tile.mi_row_start && 3536 (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) && 3537 mi_col >= xd->tile.mi_col_start && 3538 (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end); 3539 const int is_mb_rd_hash_enabled = 3540 (within_border && cpi->sf.rd_sf.use_mb_rd_hash); 3541 const int n4 = bsize_to_num_blk(bsize); 3542 if (is_mb_rd_hash_enabled) { 3543 hash = get_block_residue_hash(x, bsize); 3544 mb_rd_record = x->txfm_search_info.mb_rd_record; 3545 const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash); 3546 if (match_index != -1) { 3547 MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index]; 3548 fetch_mb_rd_info(n4, mb_rd_info, rd_stats, x); 3549 return; 3550 } 3551 } 3552 3553 // If we predict that skip is the optimal RD decision - set the respective 3554 // context and terminate early. 
3555 int64_t dist; 3556 if (txfm_params->skip_txfm_level && 3557 predict_skip_txfm(x, bsize, &dist, 3558 cpi->common.features.reduced_tx_set_used)) { 3559 set_skip_txfm(x, rd_stats, bsize, dist); 3560 // Save the RD search results into mb_rd_record. 3561 if (is_mb_rd_hash_enabled) 3562 save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record); 3563 return; 3564 } 3565 #if CONFIG_SPEED_STATS 3566 ++x->txfm_search_info.tx_search_count; 3567 #endif // CONFIG_SPEED_STATS 3568 3569 const int64_t rd = 3570 select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd); 3571 3572 if (rd == INT64_MAX) { 3573 // We should always find at least one candidate unless ref_best_rd is less 3574 // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type 3575 // might have failed to find something better) 3576 assert(ref_best_rd != INT64_MAX); 3577 av1_invalid_rd_stats(rd_stats); 3578 return; 3579 } 3580 3581 // Save the RD search results into mb_rd_record. 3582 if (is_mb_rd_hash_enabled) { 3583 assert(mb_rd_record != NULL); 3584 save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record); 3585 } 3586 } 3587 3588 void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, 3589 RD_STATS *rd_stats, BLOCK_SIZE bs, 3590 int64_t ref_best_rd) { 3591 MACROBLOCKD *const xd = &x->e_mbd; 3592 MB_MODE_INFO *const mbmi = xd->mi[0]; 3593 const TxfmSearchParams *tx_params = &x->txfm_search_params; 3594 assert(bs == mbmi->bsize); 3595 const int is_inter = is_inter_block(mbmi); 3596 const int mi_row = xd->mi_row; 3597 const int mi_col = xd->mi_col; 3598 3599 av1_init_rd_stats(rd_stats); 3600 3601 // Hashing based speed feature for inter blocks. If the hash of the residue 3602 // block is found in the table, use previously saved search results and 3603 // terminate early. 
3604 uint32_t hash = 0; 3605 MB_RD_RECORD *mb_rd_record = NULL; 3606 const int num_blks = bsize_to_num_blk(bs); 3607 if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) { 3608 const int within_border = 3609 mi_row >= xd->tile.mi_row_start && 3610 (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) && 3611 mi_col >= xd->tile.mi_col_start && 3612 (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end); 3613 if (within_border) { 3614 hash = get_block_residue_hash(x, bs); 3615 mb_rd_record = x->txfm_search_info.mb_rd_record; 3616 const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash); 3617 if (match_index != -1) { 3618 MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index]; 3619 fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x); 3620 return; 3621 } 3622 } 3623 } 3624 3625 // If we predict that skip is the optimal RD decision - set the respective 3626 // context and terminate early. 3627 int64_t dist; 3628 if (tx_params->skip_txfm_level && is_inter && 3629 !xd->lossless[mbmi->segment_id] && 3630 predict_skip_txfm(x, bs, &dist, 3631 cpi->common.features.reduced_tx_set_used)) { 3632 // Populate rdstats as per skip decision 3633 set_skip_txfm(x, rd_stats, bs, dist); 3634 // Save the RD search results into mb_rd_record. 3635 if (mb_rd_record) { 3636 save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record); 3637 } 3638 return; 3639 } 3640 3641 if (xd->lossless[mbmi->segment_id]) { 3642 // Lossless mode can only pick the smallest (4x4) transform size. 3643 choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs); 3644 } else if (tx_params->tx_size_search_method == USE_LARGESTALL) { 3645 choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs); 3646 } else { 3647 choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs); 3648 } 3649 3650 // Save the RD search results into mb_rd_record for possible reuse in future. 
3651 if (mb_rd_record) { 3652 save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record); 3653 } 3654 } 3655 3656 int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats, 3657 BLOCK_SIZE bsize, int64_t ref_best_rd) { 3658 av1_init_rd_stats(rd_stats); 3659 if (ref_best_rd < 0) return 0; 3660 if (!x->e_mbd.is_chroma_ref) return 1; 3661 3662 MACROBLOCKD *const xd = &x->e_mbd; 3663 MB_MODE_INFO *const mbmi = xd->mi[0]; 3664 struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U]; 3665 const int is_inter = is_inter_block(mbmi); 3666 int64_t this_rd = 0, skip_txfm_rd = 0; 3667 const BLOCK_SIZE plane_bsize = 3668 get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y); 3669 3670 if (is_inter) { 3671 for (int plane = 1; plane < MAX_MB_PLANE; ++plane) 3672 av1_subtract_plane(x, plane_bsize, plane); 3673 } 3674 3675 const int skip_trellis = 0; 3676 const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd); 3677 int is_cost_valid = 1; 3678 for (int plane = 1; plane < MAX_MB_PLANE; ++plane) { 3679 RD_STATS this_rd_stats; 3680 int64_t chroma_ref_best_rd = ref_best_rd; 3681 // For inter blocks, refined ref_best_rd is used for early exit 3682 // For intra blocks, even though current rd crosses ref_best_rd, early 3683 // exit is not recommended as current rd is used for gating subsequent 3684 // modes as well (say, for angular modes) 3685 // TODO(any): Extend the early exit mechanism for intra modes as well 3686 if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter && 3687 chroma_ref_best_rd != INT64_MAX) 3688 chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd); 3689 av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane, 3690 plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis); 3691 if (this_rd_stats.rate == INT_MAX) { 3692 is_cost_valid = 0; 3693 break; 3694 } 3695 av1_merge_rd_stats(rd_stats, &this_rd_stats); 3696 this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); 3697 
skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse); 3698 if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) { 3699 is_cost_valid = 0; 3700 break; 3701 } 3702 } 3703 3704 if (!is_cost_valid) { 3705 // reset cost value 3706 av1_invalid_rd_stats(rd_stats); 3707 } 3708 3709 return is_cost_valid; 3710 } 3711 3712 void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi, 3713 RD_STATS *rd_stats, int64_t ref_best_rd, 3714 int64_t current_rd, int plane, BLOCK_SIZE plane_bsize, 3715 TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode, 3716 int skip_trellis) { 3717 assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size)); 3718 3719 if (!cpi->oxcf.txfm_cfg.enable_tx64 && 3720 txsize_sqr_up_map[tx_size] == TX_64X64) { 3721 av1_invalid_rd_stats(rd_stats); 3722 return; 3723 } 3724 3725 if (current_rd > ref_best_rd) { 3726 av1_invalid_rd_stats(rd_stats); 3727 return; 3728 } 3729 3730 MACROBLOCKD *const xd = &x->e_mbd; 3731 const struct macroblockd_plane *const pd = &xd->plane[plane]; 3732 struct rdcost_block_args args; 3733 av1_zero(args); 3734 args.x = x; 3735 args.cpi = cpi; 3736 args.best_rd = ref_best_rd; 3737 args.current_rd = current_rd; 3738 args.ftxs_mode = ftxs_mode; 3739 args.skip_trellis = skip_trellis; 3740 av1_init_rd_stats(&args.rd_stats); 3741 3742 av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left); 3743 av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm, 3744 &args); 3745 3746 MB_MODE_INFO *const mbmi = xd->mi[0]; 3747 const int is_inter = is_inter_block(mbmi); 3748 const int invalid_rd = is_inter ? 
args.incomplete_exit : args.exit_early; 3749 3750 if (invalid_rd) { 3751 av1_invalid_rd_stats(rd_stats); 3752 } else { 3753 *rd_stats = args.rd_stats; 3754 } 3755 } 3756 3757 int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, 3758 RD_STATS *rd_stats, RD_STATS *rd_stats_y, 3759 RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) { 3760 MACROBLOCKD *const xd = &x->e_mbd; 3761 TxfmSearchParams *txfm_params = &x->txfm_search_params; 3762 const int skip_ctx = av1_get_skip_txfm_context(xd); 3763 const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0], 3764 x->mode_costs.skip_txfm_cost[skip_ctx][1] }; 3765 const int64_t min_header_rate = 3766 mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]); 3767 // Account for minimum skip and non_skip rd. 3768 // Eventually either one of them will be added to mode_rate 3769 const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0); 3770 if (min_header_rd_possible > ref_best_rd) { 3771 av1_invalid_rd_stats(rd_stats_y); 3772 return 0; 3773 } 3774 3775 const AV1_COMMON *cm = &cpi->common; 3776 MB_MODE_INFO *const mbmi = xd->mi[0]; 3777 const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0); 3778 const int64_t rd_thresh = 3779 ref_best_rd == INT64_MAX ? 
INT64_MAX : ref_best_rd - mode_rd; 3780 av1_init_rd_stats(rd_stats); 3781 av1_init_rd_stats(rd_stats_y); 3782 rd_stats->rate = mode_rate; 3783 3784 // cost and distortion 3785 av1_subtract_plane(x, bsize, 0); 3786 if (txfm_params->tx_mode_search_type == TX_MODE_SELECT && 3787 !xd->lossless[mbmi->segment_id]) { 3788 av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh); 3789 #if CONFIG_COLLECT_RD_STATS == 2 3790 PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize); 3791 #endif // CONFIG_COLLECT_RD_STATS == 2 3792 } else { 3793 av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh); 3794 memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size)); 3795 for (int i = 0; i < xd->height * xd->width; ++i) 3796 set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm); 3797 } 3798 3799 if (rd_stats_y->rate == INT_MAX) return 0; 3800 3801 av1_merge_rd_stats(rd_stats, rd_stats_y); 3802 3803 const int64_t non_skip_txfm_rdcosty = 3804 RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist); 3805 const int64_t skip_txfm_rdcosty = 3806 RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse); 3807 const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty); 3808 if (min_rdcosty > ref_best_rd) return 0; 3809 3810 av1_init_rd_stats(rd_stats_uv); 3811 const int num_planes = av1_num_planes(cm); 3812 if (num_planes > 1) { 3813 int64_t ref_best_chroma_rd = ref_best_rd; 3814 // Calculate best rd cost possible for chroma 3815 if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && 3816 (ref_best_chroma_rd != INT64_MAX)) { 3817 ref_best_chroma_rd = (ref_best_chroma_rd - 3818 AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty)); 3819 } 3820 const int is_cost_valid_uv = 3821 av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd); 3822 if (!is_cost_valid_uv) return 0; 3823 av1_merge_rd_stats(rd_stats, rd_stats_uv); 3824 } 3825 3826 int choose_skip_txfm = 
rd_stats->skip_txfm; 3827 if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) { 3828 const int64_t rdcost_no_skip_txfm = RDCOST( 3829 x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0], 3830 rd_stats->dist); 3831 const int64_t rdcost_skip_txfm = 3832 RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse); 3833 if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1; 3834 } 3835 if (choose_skip_txfm) { 3836 rd_stats_y->rate = 0; 3837 rd_stats_uv->rate = 0; 3838 rd_stats->rate = mode_rate + skip_txfm_cost[1]; 3839 rd_stats->dist = rd_stats->sse; 3840 rd_stats_y->dist = rd_stats_y->sse; 3841 rd_stats_uv->dist = rd_stats_uv->sse; 3842 mbmi->skip_txfm = 1; 3843 if (rd_stats->skip_txfm) { 3844 const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); 3845 if (tmprd > ref_best_rd) return 0; 3846 } 3847 } else { 3848 rd_stats->rate += skip_txfm_cost[0]; 3849 mbmi->skip_txfm = 0; 3850 } 3851 3852 return 1; 3853 }