mv_prec.c (16531B)
1 /* 2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include "config/aom_config.h" 13 14 #include "av1/encoder/encodemv.h" 15 #if !CONFIG_REALTIME_ONLY 16 #include "av1/encoder/misc_model_weights.h" 17 #endif // !CONFIG_REALTIME_ONLY 18 #include "av1/encoder/mv_prec.h" 19 20 #if !CONFIG_REALTIME_ONLY 21 static inline int_mv get_ref_mv_for_mv_stats( 22 const MB_MODE_INFO *mbmi, const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame, 23 int ref_idx) { 24 int ref_mv_idx = mbmi->ref_mv_idx; 25 if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEW_NEARMV) { 26 assert(has_second_ref(mbmi)); 27 ref_mv_idx += 1; 28 } 29 30 const MV_REFERENCE_FRAME *ref_frames = mbmi->ref_frame; 31 const int8_t ref_frame_type = av1_ref_frame_type(ref_frames); 32 const CANDIDATE_MV *curr_ref_mv_stack = mbmi_ext_frame->ref_mv_stack; 33 34 if (ref_frames[1] > INTRA_FRAME) { 35 assert(ref_idx == 0 || ref_idx == 1); 36 return ref_idx ? curr_ref_mv_stack[ref_mv_idx].comp_mv 37 : curr_ref_mv_stack[ref_mv_idx].this_mv; 38 } 39 40 assert(ref_idx == 0); 41 return ref_mv_idx < mbmi_ext_frame->ref_mv_count 42 ? curr_ref_mv_stack[ref_mv_idx].this_mv 43 : mbmi_ext_frame->global_mvs[ref_frame_type]; 44 } 45 46 static inline int get_symbol_cost(const aom_cdf_prob *cdf, int symbol) { 47 const aom_cdf_prob cur_cdf = AOM_ICDF(cdf[symbol]); 48 const aom_cdf_prob prev_cdf = symbol ? AOM_ICDF(cdf[symbol - 1]) : 0; 49 const aom_cdf_prob p15 = AOMMAX(cur_cdf - prev_cdf, EC_MIN_PROB); 50 51 return av1_cost_symbol(p15); 52 } 53 54 static inline int keep_one_comp_stat(MV_STATS *mv_stats, int comp_val, 55 int comp_idx, const AV1_COMP *cpi, 56 int *rates) { 57 assert(comp_val != 0 && "mv component should not have zero value!"); 58 const int sign = comp_val < 0; 59 const int mag = sign ? -comp_val : comp_val; 60 const int mag_minus_1 = mag - 1; 61 int offset; 62 const int mv_class = av1_get_mv_class(mag_minus_1, &offset); 63 const int int_part = offset >> 3; // int mv data 64 const int frac_part = (offset >> 1) & 3; // fractional mv data 65 const int high_part = offset & 1; // high precision mv data 66 const int use_hp = cpi->common.features.allow_high_precision_mv; 67 int r_idx = 0; 68 69 const MACROBLOCK *const x = &cpi->td.mb; 70 const MACROBLOCKD *const xd = &x->e_mbd; 71 FRAME_CONTEXT *ec_ctx = xd->tile_ctx; 72 nmv_context *nmvc = &ec_ctx->nmvc; 73 nmv_component *mvcomp_ctx = nmvc->comps; 74 nmv_component *cur_mvcomp_ctx = &mvcomp_ctx[comp_idx]; 75 aom_cdf_prob *sign_cdf = cur_mvcomp_ctx->sign_cdf; 76 aom_cdf_prob *class_cdf = cur_mvcomp_ctx->classes_cdf; 77 aom_cdf_prob *class0_cdf = cur_mvcomp_ctx->class0_cdf; 78 aom_cdf_prob(*bits_cdf)[3] = cur_mvcomp_ctx->bits_cdf; 79 aom_cdf_prob *frac_part_cdf = mv_class 80 ? (cur_mvcomp_ctx->fp_cdf) 81 : (cur_mvcomp_ctx->class0_fp_cdf[int_part]); 82 aom_cdf_prob *high_part_cdf = 83 mv_class ? (cur_mvcomp_ctx->hp_cdf) : (cur_mvcomp_ctx->class0_hp_cdf); 84 85 const int sign_rate = get_symbol_cost(sign_cdf, sign); 86 rates[r_idx++] = sign_rate; 87 update_cdf(sign_cdf, sign, 2); 88 89 const int class_rate = get_symbol_cost(class_cdf, mv_class); 90 rates[r_idx++] = class_rate; 91 update_cdf(class_cdf, mv_class, MV_CLASSES); 92 93 int int_bit_rate = 0; 94 if (mv_class == MV_CLASS_0) { 95 int_bit_rate = get_symbol_cost(class0_cdf, int_part); 96 update_cdf(class0_cdf, int_part, CLASS0_SIZE); 97 } else { 98 const int n = mv_class + CLASS0_BITS - 1; // number of bits 99 for (int i = 0; i < n; ++i) { 100 int_bit_rate += get_symbol_cost(bits_cdf[i], (int_part >> i) & 1); 101 update_cdf(bits_cdf[i], (int_part >> i) & 1, 2); 102 } 103 } 104 rates[r_idx++] = int_bit_rate; 105 const int frac_part_rate = get_symbol_cost(frac_part_cdf, frac_part); 106 rates[r_idx++] = frac_part_rate; 107 update_cdf(frac_part_cdf, frac_part, MV_FP_SIZE); 108 const int high_part_rate = 109 use_hp ? get_symbol_cost(high_part_cdf, high_part) : 0; 110 if (use_hp) { 111 update_cdf(high_part_cdf, high_part, 2); 112 } 113 rates[r_idx++] = high_part_rate; 114 115 mv_stats->last_bit_zero += !high_part; 116 mv_stats->last_bit_nonzero += high_part; 117 const int total_rate = 118 (sign_rate + class_rate + int_bit_rate + frac_part_rate + high_part_rate); 119 return total_rate; 120 } 121 122 static inline void keep_one_mv_stat(MV_STATS *mv_stats, const MV *ref_mv, 123 const MV *cur_mv, const AV1_COMP *cpi) { 124 const MACROBLOCK *const x = &cpi->td.mb; 125 const MACROBLOCKD *const xd = &x->e_mbd; 126 FRAME_CONTEXT *ec_ctx = xd->tile_ctx; 127 nmv_context *nmvc = &ec_ctx->nmvc; 128 aom_cdf_prob *joint_cdf = nmvc->joints_cdf; 129 const int use_hp = cpi->common.features.allow_high_precision_mv; 130 131 const MV diff = { cur_mv->row - ref_mv->row, cur_mv->col - ref_mv->col }; 132 const int mv_joint = av1_get_mv_joint(&diff); 133 // TODO(chiyotsai@google.com): Estimate hp_diff when we are using lp 134 const MV hp_diff = diff; 135 const int hp_mv_joint = av1_get_mv_joint(&hp_diff); 136 const MV truncated_diff = { (diff.row / 2) * 2, (diff.col / 2) * 2 }; 137 const MV lp_diff = use_hp ? truncated_diff : diff; 138 const int lp_mv_joint = av1_get_mv_joint(&lp_diff); 139 140 const int mv_joint_rate = get_symbol_cost(joint_cdf, mv_joint); 141 const int hp_mv_joint_rate = get_symbol_cost(joint_cdf, hp_mv_joint); 142 const int lp_mv_joint_rate = get_symbol_cost(joint_cdf, lp_mv_joint); 143 144 update_cdf(joint_cdf, mv_joint, MV_JOINTS); 145 146 mv_stats->total_mv_rate += mv_joint_rate; 147 mv_stats->hp_total_mv_rate += hp_mv_joint_rate; 148 mv_stats->lp_total_mv_rate += lp_mv_joint_rate; 149 mv_stats->mv_joint_count[mv_joint]++; 150 151 for (int comp_idx = 0; comp_idx < 2; comp_idx++) { 152 const int comp_val = comp_idx ? diff.col : diff.row; 153 const int hp_comp_val = comp_idx ? hp_diff.col : hp_diff.row; 154 const int lp_comp_val = comp_idx ? lp_diff.col : lp_diff.row; 155 int rates[5]; 156 av1_zero_array(rates, 5); 157 158 const int comp_rate = 159 comp_val ? keep_one_comp_stat(mv_stats, comp_val, comp_idx, cpi, rates) 160 : 0; 161 // TODO(chiyotsai@google.com): Properly get hp rate when use_hp is false 162 const int hp_rate = 163 hp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] + rates[4] : 0; 164 const int lp_rate = 165 lp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] : 0; 166 167 mv_stats->total_mv_rate += comp_rate; 168 mv_stats->hp_total_mv_rate += hp_rate; 169 mv_stats->lp_total_mv_rate += lp_rate; 170 } 171 } 172 173 static inline void collect_mv_stats_b(MV_STATS *mv_stats, const AV1_COMP *cpi, 174 int mi_row, int mi_col) { 175 const AV1_COMMON *cm = &cpi->common; 176 const CommonModeInfoParams *const mi_params = &cm->mi_params; 177 178 if (mi_row >= mi_params->mi_rows || mi_col >= mi_params->mi_cols) { 179 return; 180 } 181 182 const MB_MODE_INFO *mbmi = 183 mi_params->mi_grid_base[mi_row * mi_params->mi_stride + mi_col]; 184 const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame = 185 cpi->mbmi_ext_info.frame_base + 186 get_mi_ext_idx(mi_row, mi_col, cm->mi_params.mi_alloc_bsize, 187 cpi->mbmi_ext_info.stride); 188 189 if (!is_inter_block(mbmi)) { 190 mv_stats->intra_count++; 191 return; 192 } 193 mv_stats->inter_count++; 194 195 const PREDICTION_MODE mode = mbmi->mode; 196 const int is_compound = has_second_ref(mbmi); 197 198 if (mode == NEWMV || mode == NEW_NEWMV) { 199 // All mvs are new 200 for (int ref_idx = 0; ref_idx < 1 + is_compound; ++ref_idx) { 201 const MV ref_mv = 202 get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv; 203 const MV cur_mv = mbmi->mv[ref_idx].as_mv; 204 keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi); 205 } 206 } else if (mode == NEAREST_NEWMV || mode == NEAR_NEWMV || 207 mode == NEW_NEARESTMV || mode == NEW_NEARMV) { 208 // has exactly one new_mv 209 mv_stats->default_mvs += 1; 210 211 const int ref_idx = (mode == NEAREST_NEWMV || mode == NEAR_NEWMV); 212 const MV ref_mv = 213 get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv; 214 const MV cur_mv = mbmi->mv[ref_idx].as_mv; 215 216 keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi); 217 } else { 218 // No new_mv 219 mv_stats->default_mvs += 1 + is_compound; 220 } 221 222 // Add texture information 223 const BLOCK_SIZE bsize = mbmi->bsize; 224 const int num_rows = block_size_high[bsize]; 225 const int num_cols = block_size_wide[bsize]; 226 const int y_stride = cpi->source->y_stride; 227 const int px_row = 4 * mi_row, px_col = 4 * mi_col; 228 const int buf_is_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH; 229 const int bd = cm->seq_params->bit_depth; 230 if (buf_is_hbd) { 231 uint16_t *source_buf = 232 CONVERT_TO_SHORTPTR(cpi->source->y_buffer) + px_row * y_stride + px_col; 233 for (int row = 0; row < num_rows - 1; row++) { 234 for (int col = 0; col < num_cols - 1; col++) { 235 const int offset = row * y_stride + col; 236 const int horz_diff = 237 abs(source_buf[offset + 1] - source_buf[offset]) >> (bd - 8); 238 const int vert_diff = 239 abs(source_buf[offset + y_stride] - source_buf[offset]) >> (bd - 8); 240 mv_stats->horz_text += horz_diff; 241 mv_stats->vert_text += vert_diff; 242 mv_stats->diag_text += horz_diff * vert_diff; 243 } 244 } 245 } else { 246 uint8_t *source_buf = cpi->source->y_buffer + px_row * y_stride + px_col; 247 for (int row = 0; row < num_rows - 1; row++) { 248 for (int col = 0; col < num_cols - 1; col++) { 249 const int offset = row * y_stride + col; 250 const int horz_diff = abs(source_buf[offset + 1] - source_buf[offset]); 251 const int vert_diff = 252 abs(source_buf[offset + y_stride] - source_buf[offset]); 253 mv_stats->horz_text += horz_diff; 254 mv_stats->vert_text += vert_diff; 255 mv_stats->diag_text += horz_diff * vert_diff; 256 } 257 } 258 } 259 } 260 261 // Split block 262 static inline void collect_mv_stats_sb(MV_STATS *mv_stats, const AV1_COMP *cpi, 263 int mi_row, int mi_col, 264 BLOCK_SIZE bsize) { 265 assert(bsize < BLOCK_SIZES_ALL); 266 const AV1_COMMON *cm = &cpi->common; 267 268 if (mi_row >= cm->mi_params.mi_rows || mi_col >= cm->mi_params.mi_cols) 269 return; 270 271 const PARTITION_TYPE partition = get_partition(cm, mi_row, mi_col, bsize); 272 const BLOCK_SIZE subsize = get_partition_subsize(bsize, partition); 273 274 const int hbs = mi_size_wide[bsize] / 2; 275 const int qbs = mi_size_wide[bsize] / 4; 276 switch (partition) { 277 case PARTITION_NONE: 278 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 279 break; 280 case PARTITION_HORZ: 281 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 282 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); 283 break; 284 case PARTITION_VERT: 285 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 286 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); 287 break; 288 case PARTITION_SPLIT: 289 collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, subsize); 290 collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col + hbs, subsize); 291 collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col, subsize); 292 collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col + hbs, subsize); 293 break; 294 case PARTITION_HORZ_A: 295 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 296 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); 297 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); 298 break; 299 case PARTITION_HORZ_B: 300 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 301 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); 302 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs); 303 break; 304 case PARTITION_VERT_A: 305 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 306 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); 307 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); 308 break; 309 case PARTITION_VERT_B: 310 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); 311 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); 312 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs); 313 break; 314 case PARTITION_HORZ_4: 315 for (int i = 0; i < 4; ++i) { 316 const int this_mi_row = mi_row + i * qbs; 317 collect_mv_stats_b(mv_stats, cpi, this_mi_row, mi_col); 318 } 319 break; 320 case PARTITION_VERT_4: 321 for (int i = 0; i < 4; ++i) { 322 const int this_mi_col = mi_col + i * qbs; 323 collect_mv_stats_b(mv_stats, cpi, mi_row, this_mi_col); 324 } 325 break; 326 default: assert(0); 327 } 328 } 329 330 static inline void collect_mv_stats_tile(MV_STATS *mv_stats, 331 const AV1_COMP *cpi, 332 const TileInfo *tile_info) { 333 const AV1_COMMON *cm = &cpi->common; 334 const int mi_row_start = tile_info->mi_row_start; 335 const int mi_row_end = tile_info->mi_row_end; 336 const int mi_col_start = tile_info->mi_col_start; 337 const int mi_col_end = tile_info->mi_col_end; 338 const int sb_size_mi = cm->seq_params->mib_size; 339 BLOCK_SIZE sb_size = cm->seq_params->sb_size; 340 for (int mi_row = mi_row_start; mi_row < mi_row_end; mi_row += sb_size_mi) { 341 for (int mi_col = mi_col_start; mi_col < mi_col_end; mi_col += sb_size_mi) { 342 collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, sb_size); 343 } 344 } 345 } 346 347 void av1_collect_mv_stats(AV1_COMP *cpi, int current_q) { 348 MV_STATS *mv_stats = &cpi->mv_stats; 349 const AV1_COMMON *cm = &cpi->common; 350 const int tile_cols = cm->tiles.cols; 351 const int tile_rows = cm->tiles.rows; 352 353 for (int tile_row = 0; tile_row < tile_rows; tile_row++) { 354 TileInfo tile_info; 355 av1_tile_set_row(&tile_info, cm, tile_row); 356 for (int tile_col = 0; tile_col < tile_cols; tile_col++) { 357 const int tile_idx = tile_row * tile_cols + tile_col; 358 av1_tile_set_col(&tile_info, cm, tile_col); 359 cpi->tile_data[tile_idx].tctx = *cm->fc; 360 cpi->td.mb.e_mbd.tile_ctx = &cpi->tile_data[tile_idx].tctx; 361 collect_mv_stats_tile(mv_stats, cpi, &tile_info); 362 } 363 } 364 365 mv_stats->q = current_q; 366 mv_stats->order = cpi->common.current_frame.order_hint; 367 mv_stats->valid = 1; 368 } 369 370 static inline int get_smart_mv_prec(AV1_COMP *cpi, const MV_STATS *mv_stats, 371 int current_q) { 372 const AV1_COMMON *cm = &cpi->common; 373 const int order_hint = cpi->common.current_frame.order_hint; 374 const int order_diff = order_hint - mv_stats->order; 375 const float area = (float)(cm->width * cm->height); 376 float features[MV_PREC_FEATURE_SIZE] = { 377 (float)current_q, 378 (float)mv_stats->q, 379 (float)order_diff, 380 mv_stats->inter_count / area, 381 mv_stats->intra_count / area, 382 mv_stats->default_mvs / area, 383 mv_stats->mv_joint_count[0] / area, 384 mv_stats->mv_joint_count[1] / area, 385 mv_stats->mv_joint_count[2] / area, 386 mv_stats->mv_joint_count[3] / area, 387 mv_stats->last_bit_zero / area, 388 mv_stats->last_bit_nonzero / area, 389 mv_stats->total_mv_rate / area, 390 mv_stats->hp_total_mv_rate / area, 391 mv_stats->lp_total_mv_rate / area, 392 mv_stats->horz_text / area, 393 mv_stats->vert_text / area, 394 mv_stats->diag_text / area, 395 }; 396 397 for (int f_idx = 0; f_idx < MV_PREC_FEATURE_SIZE; f_idx++) { 398 features[f_idx] = 399 (features[f_idx] - av1_mv_prec_mean[f_idx]) / av1_mv_prec_std[f_idx]; 400 } 401 float score = 0.0f; 402 403 av1_nn_predict(features, &av1_mv_prec_dnn_config, 1, &score); 404 405 const int use_high_hp = score >= 0.0f; 406 return use_high_hp; 407 } 408 #endif // !CONFIG_REALTIME_ONLY 409 410 void av1_pick_and_set_high_precision_mv(AV1_COMP *cpi, int qindex) { 411 int use_hp = qindex < HIGH_PRECISION_MV_QTHRESH; 412 #if !CONFIG_REALTIME_ONLY 413 MV_STATS *mv_stats = &cpi->mv_stats; 414 #endif // !CONFIG_REALTIME_ONLY 415 416 if (cpi->sf.hl_sf.high_precision_mv_usage == QTR_ONLY) { 417 use_hp = 0; 418 } 419 #if !CONFIG_REALTIME_ONLY 420 else if (cpi->sf.hl_sf.high_precision_mv_usage == LAST_MV_DATA && 421 av1_frame_allows_smart_mv(cpi) && mv_stats->valid) { 422 use_hp = get_smart_mv_prec(cpi, mv_stats, qindex); 423 } 424 #endif // !CONFIG_REALTIME_ONLY 425 426 av1_set_high_precision_mv(cpi, use_hp, 427 cpi->common.features.cur_frame_force_integer_mv); 428 }