tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

mv_prec.c (16531B)


      1 /*
      2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include "config/aom_config.h"
     13 
     14 #include "av1/encoder/encodemv.h"
     15 #if !CONFIG_REALTIME_ONLY
     16 #include "av1/encoder/misc_model_weights.h"
     17 #endif  // !CONFIG_REALTIME_ONLY
     18 #include "av1/encoder/mv_prec.h"
     19 
     20 #if !CONFIG_REALTIME_ONLY
     21 static inline int_mv get_ref_mv_for_mv_stats(
     22    const MB_MODE_INFO *mbmi, const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame,
     23    int ref_idx) {
     24  int ref_mv_idx = mbmi->ref_mv_idx;
     25  if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEW_NEARMV) {
     26    assert(has_second_ref(mbmi));
     27    ref_mv_idx += 1;
     28  }
     29 
     30  const MV_REFERENCE_FRAME *ref_frames = mbmi->ref_frame;
     31  const int8_t ref_frame_type = av1_ref_frame_type(ref_frames);
     32  const CANDIDATE_MV *curr_ref_mv_stack = mbmi_ext_frame->ref_mv_stack;
     33 
     34  if (ref_frames[1] > INTRA_FRAME) {
     35    assert(ref_idx == 0 || ref_idx == 1);
     36    return ref_idx ? curr_ref_mv_stack[ref_mv_idx].comp_mv
     37                   : curr_ref_mv_stack[ref_mv_idx].this_mv;
     38  }
     39 
     40  assert(ref_idx == 0);
     41  return ref_mv_idx < mbmi_ext_frame->ref_mv_count
     42             ? curr_ref_mv_stack[ref_mv_idx].this_mv
     43             : mbmi_ext_frame->global_mvs[ref_frame_type];
     44 }
     45 
     46 static inline int get_symbol_cost(const aom_cdf_prob *cdf, int symbol) {
     47  const aom_cdf_prob cur_cdf = AOM_ICDF(cdf[symbol]);
     48  const aom_cdf_prob prev_cdf = symbol ? AOM_ICDF(cdf[symbol - 1]) : 0;
     49  const aom_cdf_prob p15 = AOMMAX(cur_cdf - prev_cdf, EC_MIN_PROB);
     50 
     51  return av1_cost_symbol(p15);
     52 }
     53 
     54 static inline int keep_one_comp_stat(MV_STATS *mv_stats, int comp_val,
     55                                     int comp_idx, const AV1_COMP *cpi,
     56                                     int *rates) {
     57  assert(comp_val != 0 && "mv component should not have zero value!");
     58  const int sign = comp_val < 0;
     59  const int mag = sign ? -comp_val : comp_val;
     60  const int mag_minus_1 = mag - 1;
     61  int offset;
     62  const int mv_class = av1_get_mv_class(mag_minus_1, &offset);
     63  const int int_part = offset >> 3;         // int mv data
     64  const int frac_part = (offset >> 1) & 3;  // fractional mv data
     65  const int high_part = offset & 1;         // high precision mv data
     66  const int use_hp = cpi->common.features.allow_high_precision_mv;
     67  int r_idx = 0;
     68 
     69  const MACROBLOCK *const x = &cpi->td.mb;
     70  const MACROBLOCKD *const xd = &x->e_mbd;
     71  FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     72  nmv_context *nmvc = &ec_ctx->nmvc;
     73  nmv_component *mvcomp_ctx = nmvc->comps;
     74  nmv_component *cur_mvcomp_ctx = &mvcomp_ctx[comp_idx];
     75  aom_cdf_prob *sign_cdf = cur_mvcomp_ctx->sign_cdf;
     76  aom_cdf_prob *class_cdf = cur_mvcomp_ctx->classes_cdf;
     77  aom_cdf_prob *class0_cdf = cur_mvcomp_ctx->class0_cdf;
     78  aom_cdf_prob(*bits_cdf)[3] = cur_mvcomp_ctx->bits_cdf;
     79  aom_cdf_prob *frac_part_cdf = mv_class
     80                                    ? (cur_mvcomp_ctx->fp_cdf)
     81                                    : (cur_mvcomp_ctx->class0_fp_cdf[int_part]);
     82  aom_cdf_prob *high_part_cdf =
     83      mv_class ? (cur_mvcomp_ctx->hp_cdf) : (cur_mvcomp_ctx->class0_hp_cdf);
     84 
     85  const int sign_rate = get_symbol_cost(sign_cdf, sign);
     86  rates[r_idx++] = sign_rate;
     87  update_cdf(sign_cdf, sign, 2);
     88 
     89  const int class_rate = get_symbol_cost(class_cdf, mv_class);
     90  rates[r_idx++] = class_rate;
     91  update_cdf(class_cdf, mv_class, MV_CLASSES);
     92 
     93  int int_bit_rate = 0;
     94  if (mv_class == MV_CLASS_0) {
     95    int_bit_rate = get_symbol_cost(class0_cdf, int_part);
     96    update_cdf(class0_cdf, int_part, CLASS0_SIZE);
     97  } else {
     98    const int n = mv_class + CLASS0_BITS - 1;  // number of bits
     99    for (int i = 0; i < n; ++i) {
    100      int_bit_rate += get_symbol_cost(bits_cdf[i], (int_part >> i) & 1);
    101      update_cdf(bits_cdf[i], (int_part >> i) & 1, 2);
    102    }
    103  }
    104  rates[r_idx++] = int_bit_rate;
    105  const int frac_part_rate = get_symbol_cost(frac_part_cdf, frac_part);
    106  rates[r_idx++] = frac_part_rate;
    107  update_cdf(frac_part_cdf, frac_part, MV_FP_SIZE);
    108  const int high_part_rate =
    109      use_hp ? get_symbol_cost(high_part_cdf, high_part) : 0;
    110  if (use_hp) {
    111    update_cdf(high_part_cdf, high_part, 2);
    112  }
    113  rates[r_idx++] = high_part_rate;
    114 
    115  mv_stats->last_bit_zero += !high_part;
    116  mv_stats->last_bit_nonzero += high_part;
    117  const int total_rate =
    118      (sign_rate + class_rate + int_bit_rate + frac_part_rate + high_part_rate);
    119  return total_rate;
    120 }
    121 
    122 static inline void keep_one_mv_stat(MV_STATS *mv_stats, const MV *ref_mv,
    123                                    const MV *cur_mv, const AV1_COMP *cpi) {
    124  const MACROBLOCK *const x = &cpi->td.mb;
    125  const MACROBLOCKD *const xd = &x->e_mbd;
    126  FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
    127  nmv_context *nmvc = &ec_ctx->nmvc;
    128  aom_cdf_prob *joint_cdf = nmvc->joints_cdf;
    129  const int use_hp = cpi->common.features.allow_high_precision_mv;
    130 
    131  const MV diff = { cur_mv->row - ref_mv->row, cur_mv->col - ref_mv->col };
    132  const int mv_joint = av1_get_mv_joint(&diff);
    133  // TODO(chiyotsai@google.com): Estimate hp_diff when we are using lp
    134  const MV hp_diff = diff;
    135  const int hp_mv_joint = av1_get_mv_joint(&hp_diff);
    136  const MV truncated_diff = { (diff.row / 2) * 2, (diff.col / 2) * 2 };
    137  const MV lp_diff = use_hp ? truncated_diff : diff;
    138  const int lp_mv_joint = av1_get_mv_joint(&lp_diff);
    139 
    140  const int mv_joint_rate = get_symbol_cost(joint_cdf, mv_joint);
    141  const int hp_mv_joint_rate = get_symbol_cost(joint_cdf, hp_mv_joint);
    142  const int lp_mv_joint_rate = get_symbol_cost(joint_cdf, lp_mv_joint);
    143 
    144  update_cdf(joint_cdf, mv_joint, MV_JOINTS);
    145 
    146  mv_stats->total_mv_rate += mv_joint_rate;
    147  mv_stats->hp_total_mv_rate += hp_mv_joint_rate;
    148  mv_stats->lp_total_mv_rate += lp_mv_joint_rate;
    149  mv_stats->mv_joint_count[mv_joint]++;
    150 
    151  for (int comp_idx = 0; comp_idx < 2; comp_idx++) {
    152    const int comp_val = comp_idx ? diff.col : diff.row;
    153    const int hp_comp_val = comp_idx ? hp_diff.col : hp_diff.row;
    154    const int lp_comp_val = comp_idx ? lp_diff.col : lp_diff.row;
    155    int rates[5];
    156    av1_zero_array(rates, 5);
    157 
    158    const int comp_rate =
    159        comp_val ? keep_one_comp_stat(mv_stats, comp_val, comp_idx, cpi, rates)
    160                 : 0;
    161    // TODO(chiyotsai@google.com): Properly get hp rate when use_hp is false
    162    const int hp_rate =
    163        hp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] + rates[4] : 0;
    164    const int lp_rate =
    165        lp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] : 0;
    166 
    167    mv_stats->total_mv_rate += comp_rate;
    168    mv_stats->hp_total_mv_rate += hp_rate;
    169    mv_stats->lp_total_mv_rate += lp_rate;
    170  }
    171 }
    172 
    173 static inline void collect_mv_stats_b(MV_STATS *mv_stats, const AV1_COMP *cpi,
    174                                      int mi_row, int mi_col) {
    175  const AV1_COMMON *cm = &cpi->common;
    176  const CommonModeInfoParams *const mi_params = &cm->mi_params;
    177 
    178  if (mi_row >= mi_params->mi_rows || mi_col >= mi_params->mi_cols) {
    179    return;
    180  }
    181 
    182  const MB_MODE_INFO *mbmi =
    183      mi_params->mi_grid_base[mi_row * mi_params->mi_stride + mi_col];
    184  const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame =
    185      cpi->mbmi_ext_info.frame_base +
    186      get_mi_ext_idx(mi_row, mi_col, cm->mi_params.mi_alloc_bsize,
    187                     cpi->mbmi_ext_info.stride);
    188 
    189  if (!is_inter_block(mbmi)) {
    190    mv_stats->intra_count++;
    191    return;
    192  }
    193  mv_stats->inter_count++;
    194 
    195  const PREDICTION_MODE mode = mbmi->mode;
    196  const int is_compound = has_second_ref(mbmi);
    197 
    198  if (mode == NEWMV || mode == NEW_NEWMV) {
    199    // All mvs are new
    200    for (int ref_idx = 0; ref_idx < 1 + is_compound; ++ref_idx) {
    201      const MV ref_mv =
    202          get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv;
    203      const MV cur_mv = mbmi->mv[ref_idx].as_mv;
    204      keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi);
    205    }
    206  } else if (mode == NEAREST_NEWMV || mode == NEAR_NEWMV ||
    207             mode == NEW_NEARESTMV || mode == NEW_NEARMV) {
    208    // has exactly one new_mv
    209    mv_stats->default_mvs += 1;
    210 
    211    const int ref_idx = (mode == NEAREST_NEWMV || mode == NEAR_NEWMV);
    212    const MV ref_mv =
    213        get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv;
    214    const MV cur_mv = mbmi->mv[ref_idx].as_mv;
    215 
    216    keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi);
    217  } else {
    218    // No new_mv
    219    mv_stats->default_mvs += 1 + is_compound;
    220  }
    221 
    222  // Add texture information
    223  const BLOCK_SIZE bsize = mbmi->bsize;
    224  const int num_rows = block_size_high[bsize];
    225  const int num_cols = block_size_wide[bsize];
    226  const int y_stride = cpi->source->y_stride;
    227  const int px_row = 4 * mi_row, px_col = 4 * mi_col;
    228  const int buf_is_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
    229  const int bd = cm->seq_params->bit_depth;
    230  if (buf_is_hbd) {
    231    uint16_t *source_buf =
    232        CONVERT_TO_SHORTPTR(cpi->source->y_buffer) + px_row * y_stride + px_col;
    233    for (int row = 0; row < num_rows - 1; row++) {
    234      for (int col = 0; col < num_cols - 1; col++) {
    235        const int offset = row * y_stride + col;
    236        const int horz_diff =
    237            abs(source_buf[offset + 1] - source_buf[offset]) >> (bd - 8);
    238        const int vert_diff =
    239            abs(source_buf[offset + y_stride] - source_buf[offset]) >> (bd - 8);
    240        mv_stats->horz_text += horz_diff;
    241        mv_stats->vert_text += vert_diff;
    242        mv_stats->diag_text += horz_diff * vert_diff;
    243      }
    244    }
    245  } else {
    246    uint8_t *source_buf = cpi->source->y_buffer + px_row * y_stride + px_col;
    247    for (int row = 0; row < num_rows - 1; row++) {
    248      for (int col = 0; col < num_cols - 1; col++) {
    249        const int offset = row * y_stride + col;
    250        const int horz_diff = abs(source_buf[offset + 1] - source_buf[offset]);
    251        const int vert_diff =
    252            abs(source_buf[offset + y_stride] - source_buf[offset]);
    253        mv_stats->horz_text += horz_diff;
    254        mv_stats->vert_text += vert_diff;
    255        mv_stats->diag_text += horz_diff * vert_diff;
    256      }
    257    }
    258  }
    259 }
    260 
    261 // Split block
    262 static inline void collect_mv_stats_sb(MV_STATS *mv_stats, const AV1_COMP *cpi,
    263                                       int mi_row, int mi_col,
    264                                       BLOCK_SIZE bsize) {
    265  assert(bsize < BLOCK_SIZES_ALL);
    266  const AV1_COMMON *cm = &cpi->common;
    267 
    268  if (mi_row >= cm->mi_params.mi_rows || mi_col >= cm->mi_params.mi_cols)
    269    return;
    270 
    271  const PARTITION_TYPE partition = get_partition(cm, mi_row, mi_col, bsize);
    272  const BLOCK_SIZE subsize = get_partition_subsize(bsize, partition);
    273 
    274  const int hbs = mi_size_wide[bsize] / 2;
    275  const int qbs = mi_size_wide[bsize] / 4;
    276  switch (partition) {
    277    case PARTITION_NONE:
    278      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    279      break;
    280    case PARTITION_HORZ:
    281      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    282      collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
    283      break;
    284    case PARTITION_VERT:
    285      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    286      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
    287      break;
    288    case PARTITION_SPLIT:
    289      collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, subsize);
    290      collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col + hbs, subsize);
    291      collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col, subsize);
    292      collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col + hbs, subsize);
    293      break;
    294    case PARTITION_HORZ_A:
    295      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    296      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
    297      collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
    298      break;
    299    case PARTITION_HORZ_B:
    300      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    301      collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
    302      collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs);
    303      break;
    304    case PARTITION_VERT_A:
    305      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    306      collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
    307      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
    308      break;
    309    case PARTITION_VERT_B:
    310      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
    311      collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
    312      collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs);
    313      break;
    314    case PARTITION_HORZ_4:
    315      for (int i = 0; i < 4; ++i) {
    316        const int this_mi_row = mi_row + i * qbs;
    317        collect_mv_stats_b(mv_stats, cpi, this_mi_row, mi_col);
    318      }
    319      break;
    320    case PARTITION_VERT_4:
    321      for (int i = 0; i < 4; ++i) {
    322        const int this_mi_col = mi_col + i * qbs;
    323        collect_mv_stats_b(mv_stats, cpi, mi_row, this_mi_col);
    324      }
    325      break;
    326    default: assert(0);
    327  }
    328 }
    329 
    330 static inline void collect_mv_stats_tile(MV_STATS *mv_stats,
    331                                         const AV1_COMP *cpi,
    332                                         const TileInfo *tile_info) {
    333  const AV1_COMMON *cm = &cpi->common;
    334  const int mi_row_start = tile_info->mi_row_start;
    335  const int mi_row_end = tile_info->mi_row_end;
    336  const int mi_col_start = tile_info->mi_col_start;
    337  const int mi_col_end = tile_info->mi_col_end;
    338  const int sb_size_mi = cm->seq_params->mib_size;
    339  BLOCK_SIZE sb_size = cm->seq_params->sb_size;
    340  for (int mi_row = mi_row_start; mi_row < mi_row_end; mi_row += sb_size_mi) {
    341    for (int mi_col = mi_col_start; mi_col < mi_col_end; mi_col += sb_size_mi) {
    342      collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, sb_size);
    343    }
    344  }
    345 }
    346 
    347 void av1_collect_mv_stats(AV1_COMP *cpi, int current_q) {
    348  MV_STATS *mv_stats = &cpi->mv_stats;
    349  const AV1_COMMON *cm = &cpi->common;
    350  const int tile_cols = cm->tiles.cols;
    351  const int tile_rows = cm->tiles.rows;
    352 
    353  for (int tile_row = 0; tile_row < tile_rows; tile_row++) {
    354    TileInfo tile_info;
    355    av1_tile_set_row(&tile_info, cm, tile_row);
    356    for (int tile_col = 0; tile_col < tile_cols; tile_col++) {
    357      const int tile_idx = tile_row * tile_cols + tile_col;
    358      av1_tile_set_col(&tile_info, cm, tile_col);
    359      cpi->tile_data[tile_idx].tctx = *cm->fc;
    360      cpi->td.mb.e_mbd.tile_ctx = &cpi->tile_data[tile_idx].tctx;
    361      collect_mv_stats_tile(mv_stats, cpi, &tile_info);
    362    }
    363  }
    364 
    365  mv_stats->q = current_q;
    366  mv_stats->order = cpi->common.current_frame.order_hint;
    367  mv_stats->valid = 1;
    368 }
    369 
    370 static inline int get_smart_mv_prec(AV1_COMP *cpi, const MV_STATS *mv_stats,
    371                                    int current_q) {
    372  const AV1_COMMON *cm = &cpi->common;
    373  const int order_hint = cpi->common.current_frame.order_hint;
    374  const int order_diff = order_hint - mv_stats->order;
    375  const float area = (float)(cm->width * cm->height);
    376  float features[MV_PREC_FEATURE_SIZE] = {
    377    (float)current_q,
    378    (float)mv_stats->q,
    379    (float)order_diff,
    380    mv_stats->inter_count / area,
    381    mv_stats->intra_count / area,
    382    mv_stats->default_mvs / area,
    383    mv_stats->mv_joint_count[0] / area,
    384    mv_stats->mv_joint_count[1] / area,
    385    mv_stats->mv_joint_count[2] / area,
    386    mv_stats->mv_joint_count[3] / area,
    387    mv_stats->last_bit_zero / area,
    388    mv_stats->last_bit_nonzero / area,
    389    mv_stats->total_mv_rate / area,
    390    mv_stats->hp_total_mv_rate / area,
    391    mv_stats->lp_total_mv_rate / area,
    392    mv_stats->horz_text / area,
    393    mv_stats->vert_text / area,
    394    mv_stats->diag_text / area,
    395  };
    396 
    397  for (int f_idx = 0; f_idx < MV_PREC_FEATURE_SIZE; f_idx++) {
    398    features[f_idx] =
    399        (features[f_idx] - av1_mv_prec_mean[f_idx]) / av1_mv_prec_std[f_idx];
    400  }
    401  float score = 0.0f;
    402 
    403  av1_nn_predict(features, &av1_mv_prec_dnn_config, 1, &score);
    404 
    405  const int use_high_hp = score >= 0.0f;
    406  return use_high_hp;
    407 }
    408 #endif  // !CONFIG_REALTIME_ONLY
    409 
    410 void av1_pick_and_set_high_precision_mv(AV1_COMP *cpi, int qindex) {
    411  int use_hp = qindex < HIGH_PRECISION_MV_QTHRESH;
    412 #if !CONFIG_REALTIME_ONLY
    413  MV_STATS *mv_stats = &cpi->mv_stats;
    414 #endif  // !CONFIG_REALTIME_ONLY
    415 
    416  if (cpi->sf.hl_sf.high_precision_mv_usage == QTR_ONLY) {
    417    use_hp = 0;
    418  }
    419 #if !CONFIG_REALTIME_ONLY
    420  else if (cpi->sf.hl_sf.high_precision_mv_usage == LAST_MV_DATA &&
    421           av1_frame_allows_smart_mv(cpi) && mv_stats->valid) {
    422    use_hp = get_smart_mv_prec(cpi, mv_stats, qindex);
    423  }
    424 #endif  // !CONFIG_REALTIME_ONLY
    425 
    426  av1_set_high_precision_mv(cpi, use_hp,
    427                            cpi->common.features.cur_frame_force_integer_mv);
    428 }