tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

tune_vmaf.c (47498B)


      1 /*
      2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include "av1/encoder/tune_vmaf.h"
     13 
     14 #include "aom_dsp/psnr.h"
     15 #include "av1/encoder/extend.h"
     16 #include "av1/encoder/rdopt.h"
     17 #include "config/aom_scale_rtcd.h"
     18 
     19 static const double kBaselineVmaf = 97.42773;
     20 
     // Returns the value stored for `layer` in `array`, clamped to be
     // non-negative. A negative entry means "nothing here": the lookup then falls
     // back to the next lower layer, stopping at layer 0. If layer 0 itself is
     // negative the result is 0.0.
     static double get_layer_value(const double *array, int layer) {
       // Step down until a non-negative entry is found or layer 0 is reached.
       for (; array[layer] < 0.0 && layer > 0; --layer) {
       }
       const double value = array[layer];
       return value > 0.0 ? value : 0.0;
     }
     25 
     26 static void motion_search(AV1_COMP *cpi, const YV12_BUFFER_CONFIG *src,
     27                          const YV12_BUFFER_CONFIG *ref,
     28                          const BLOCK_SIZE block_size, const int mb_row,
     29                          const int mb_col, FULLPEL_MV *ref_mv) {
     30  // Block information (ONLY Y-plane is used for motion search).
     31  const int mb_height = block_size_high[block_size];
     32  const int mb_width = block_size_wide[block_size];
     33  const int y_stride = src->y_stride;
     34  assert(y_stride == ref->y_stride);
     35  const int y_offset = mb_row * mb_height * y_stride + mb_col * mb_width;
     36 
     37  // Save input state.
     38  MACROBLOCK *const mb = &cpi->td.mb;
     39  MACROBLOCKD *const mbd = &mb->e_mbd;
     40  const struct buf_2d ori_src_buf = mb->plane[0].src;
     41  const struct buf_2d ori_pre_buf = mbd->plane[0].pre[0];
     42 
     43  // Parameters used for motion search.
     44  FULLPEL_MOTION_SEARCH_PARAMS full_ms_params;
     45  FULLPEL_MV_STATS best_mv_stats;
     46  const SEARCH_METHODS search_method = NSTEP;
     47  const search_site_config *search_site_cfg =
     48      cpi->mv_search_params.search_site_cfg[SS_CFG_FPF];
     49  const int step_param =
     50      av1_init_search_range(AOMMAX(src->y_crop_width, src->y_crop_height));
     51 
     52  // Baseline position for motion search (used for rate distortion comparison).
     53  const MV baseline_mv = kZeroMv;
     54 
     55  // Setup.
     56  mb->plane[0].src.buf = src->y_buffer + y_offset;
     57  mb->plane[0].src.stride = y_stride;
     58  mbd->plane[0].pre[0].buf = ref->y_buffer + y_offset;
     59  mbd->plane[0].pre[0].stride = y_stride;
     60 
     61  // Unused intermediate results for motion search.
     62  int cost_list[5];
     63 
     64  // Do motion search.
     65  // Only do full search on the entire block.
     66  av1_make_default_fullpel_ms_params(&full_ms_params, cpi, mb, block_size,
     67                                     &baseline_mv, *ref_mv, search_site_cfg,
     68                                     search_method,
     69                                     /*fine_search_interval=*/0);
     70  av1_full_pixel_search(*ref_mv, &full_ms_params, step_param,
     71                        cond_cost_list(cpi, cost_list), ref_mv, &best_mv_stats,
     72                        NULL);
     73 
     74  // Restore input state.
     75  mb->plane[0].src = ori_src_buf;
     76  mbd->plane[0].pre[0] = ori_pre_buf;
     77 }
     78 
     79 static unsigned int residual_variance(const AV1_COMP *cpi,
     80                                      const YV12_BUFFER_CONFIG *src,
     81                                      const YV12_BUFFER_CONFIG *ref,
     82                                      const BLOCK_SIZE block_size,
     83                                      const int mb_row, const int mb_col,
     84                                      FULLPEL_MV ref_mv, unsigned int *sse) {
     85  const int mb_height = block_size_high[block_size];
     86  const int mb_width = block_size_wide[block_size];
     87  const int y_stride = src->y_stride;
     88  assert(y_stride == ref->y_stride);
     89  const int y_offset = mb_row * mb_height * y_stride + mb_col * mb_width;
     90  const int mv_offset = ref_mv.row * y_stride + ref_mv.col;
     91  const unsigned int var = cpi->ppi->fn_ptr[block_size].vf(
     92      ref->y_buffer + y_offset + mv_offset, y_stride, src->y_buffer + y_offset,
     93      y_stride, sse);
     94  return var;
     95 }
     96 
     97 static double frame_average_variance(const AV1_COMP *const cpi,
     98                                     const YV12_BUFFER_CONFIG *const frame) {
     99  const MACROBLOCKD *const xd = &cpi->td.mb.e_mbd;
    100  const uint8_t *const y_buffer = frame->y_buffer;
    101  const int y_stride = frame->y_stride;
    102  const BLOCK_SIZE block_size = BLOCK_64X64;
    103 
    104  const int block_w = mi_size_wide[block_size] * 4;
    105  const int block_h = mi_size_high[block_size] * 4;
    106  int row, col;
    107  double var = 0.0, var_count = 0.0;
    108  const int use_hbd = frame->flags & YV12_FLAG_HIGHBITDEPTH;
    109 
    110  // Loop through each block.
    111  for (row = 0; row < frame->y_height / block_h; ++row) {
    112    for (col = 0; col < frame->y_width / block_w; ++col) {
    113      struct buf_2d buf;
    114      const int row_offset_y = row * block_h;
    115      const int col_offset_y = col * block_w;
    116 
    117      buf.buf = (uint8_t *)y_buffer + row_offset_y * y_stride + col_offset_y;
    118      buf.stride = y_stride;
    119 
    120      var += av1_get_perpixel_variance(cpi, xd, &buf, block_size, AOM_PLANE_Y,
    121                                       use_hbd);
    122      var_count += 1.0;
    123    }
    124  }
    125  var /= var_count;
    126  return var;
    127 }
    128 
    129 static double residual_frame_average_variance(AV1_COMP *cpi,
    130                                              const YV12_BUFFER_CONFIG *src,
    131                                              const YV12_BUFFER_CONFIG *ref,
    132                                              FULLPEL_MV *mvs) {
    133  if (ref == NULL) return frame_average_variance(cpi, src);
    134  const BLOCK_SIZE block_size = BLOCK_16X16;
    135  const int frame_height = src->y_height;
    136  const int frame_width = src->y_width;
    137  const int mb_height = block_size_high[block_size];
    138  const int mb_width = block_size_wide[block_size];
    139  const int mb_rows = (frame_height + mb_height - 1) / mb_height;
    140  const int mb_cols = (frame_width + mb_width - 1) / mb_width;
    141  const int num_planes = av1_num_planes(&cpi->common);
    142  const int mi_h = mi_size_high_log2[block_size];
    143  const int mi_w = mi_size_wide_log2[block_size];
    144  assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
    145 
    146  // Save input state.
    147  MACROBLOCK *const mb = &cpi->td.mb;
    148  MACROBLOCKD *const mbd = &mb->e_mbd;
    149  uint8_t *input_buffer[MAX_MB_PLANE];
    150  for (int i = 0; i < num_planes; i++) {
    151    input_buffer[i] = mbd->plane[i].pre[0].buf;
    152  }
    153  MB_MODE_INFO **input_mb_mode_info = mbd->mi;
    154 
    155  bool do_motion_search = false;
    156  if (mvs == NULL) {
    157    do_motion_search = true;
    158    CHECK_MEM_ERROR(&cpi->common, mvs,
    159                    (FULLPEL_MV *)aom_calloc(mb_rows * mb_cols, sizeof(*mvs)));
    160  }
    161 
    162  unsigned int variance = 0;
    163  // Perform temporal filtering block by block.
    164  for (int mb_row = 0; mb_row < mb_rows; mb_row++) {
    165    av1_set_mv_row_limits(&cpi->common.mi_params, &mb->mv_limits,
    166                          (mb_row << mi_h), (mb_height >> MI_SIZE_LOG2),
    167                          cpi->oxcf.border_in_pixels);
    168    for (int mb_col = 0; mb_col < mb_cols; mb_col++) {
    169      av1_set_mv_col_limits(&cpi->common.mi_params, &mb->mv_limits,
    170                            (mb_col << mi_w), (mb_width >> MI_SIZE_LOG2),
    171                            cpi->oxcf.border_in_pixels);
    172      FULLPEL_MV *ref_mv = &mvs[mb_col + mb_row * mb_cols];
    173      if (do_motion_search) {
    174        motion_search(cpi, src, ref, block_size, mb_row, mb_col, ref_mv);
    175      }
    176      unsigned int mv_sse;
    177      const unsigned int blk_var = residual_variance(
    178          cpi, src, ref, block_size, mb_row, mb_col, *ref_mv, &mv_sse);
    179      variance += blk_var;
    180    }
    181  }
    182 
    183  // Restore input state
    184  for (int i = 0; i < num_planes; i++) {
    185    mbd->plane[i].pre[0].buf = input_buffer[i];
    186  }
    187  mbd->mi = input_mb_mode_info;
    188  return (double)variance / (double)(mb_rows * mb_cols);
    189 }
    190 
    // TODO(sdeng): Add the SIMD implementation.
    // Unsharp mask over a w x h rectangle of 16-bit samples:
    //   dst = source + amount * (source - blurred)
    // rounded by adding 0.5 and truncating, then limited to
    // [0, (1 << bit_depth) - 1]. Strides are in samples. `dst` may be the same
    // buffer as `source`: each output sample depends only on the input samples
    // at the same position.
    static inline void highbd_unsharp_rect(const uint16_t *source,
                                           int source_stride,
                                           const uint16_t *blurred,
                                           int blurred_stride, uint16_t *dst,
                                           int dst_stride, int w, int h,
                                           double amount, int bit_depth) {
      const int max_value = (1 << bit_depth) - 1;
      for (int row = 0; row < h; ++row) {
        const uint16_t *const src_row = source + row * source_stride;
        const uint16_t *const blur_row = blurred + row * blurred_stride;
        uint16_t *const dst_row = dst + row * dst_stride;
        for (int col = 0; col < w; ++col) {
          const double sample = (double)src_row[col];
          const double val = sample + amount * (sample - (double)blur_row[col]);
          // Round, then limit to the valid sample range.
          int rounded = (int)(val + 0.5);
          if (rounded < 0) {
            rounded = 0;
          } else if (rounded > max_value) {
            rounded = max_value;
          }
          dst_row[col] = (uint16_t)rounded;
        }
      }
    }
    210 
    // Unsharp mask over a w x h rectangle of 8-bit samples:
    //   dst = source + amount * (source - blurred)
    // rounded by adding 0.5 and truncating, then limited to [0, 255]. `dst` may
    // be the same buffer as `source`: each output sample depends only on the
    // input samples at the same position.
    static inline void unsharp_rect(const uint8_t *source, int source_stride,
                                    const uint8_t *blurred, int blurred_stride,
                                    uint8_t *dst, int dst_stride, int w, int h,
                                    double amount) {
      for (int row = 0; row < h; ++row) {
        const uint8_t *const src_row = source + row * source_stride;
        const uint8_t *const blur_row = blurred + row * blurred_stride;
        uint8_t *const dst_row = dst + row * dst_stride;
        for (int col = 0; col < w; ++col) {
          const double sample = (double)src_row[col];
          const double val = sample + amount * (sample - (double)blur_row[col]);
          // Round, then limit to the 8-bit sample range.
          int rounded = (int)(val + 0.5);
          if (rounded < 0) {
            rounded = 0;
          } else if (rounded > 255) {
            rounded = 255;
          }
          dst_row[col] = (uint8_t)rounded;
        }
      }
    }
    226 
    227 static inline void unsharp(const AV1_COMP *const cpi,
    228                           const YV12_BUFFER_CONFIG *source,
    229                           const YV12_BUFFER_CONFIG *blurred,
    230                           const YV12_BUFFER_CONFIG *dst, double amount) {
    231  const int bit_depth = cpi->td.mb.e_mbd.bd;
    232  if (cpi->common.seq_params->use_highbitdepth) {
    233    assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
    234    assert(blurred->flags & YV12_FLAG_HIGHBITDEPTH);
    235    assert(dst->flags & YV12_FLAG_HIGHBITDEPTH);
    236    highbd_unsharp_rect(CONVERT_TO_SHORTPTR(source->y_buffer), source->y_stride,
    237                        CONVERT_TO_SHORTPTR(blurred->y_buffer),
    238                        blurred->y_stride, CONVERT_TO_SHORTPTR(dst->y_buffer),
    239                        dst->y_stride, source->y_width, source->y_height,
    240                        amount, bit_depth);
    241  } else {
    242    unsharp_rect(source->y_buffer, source->y_stride, blurred->y_buffer,
    243                 blurred->y_stride, dst->y_buffer, dst->y_stride,
    244                 source->y_width, source->y_height, amount);
    245  }
    246 }
    247 
    248 // 8-tap Gaussian convolution filter with sigma = 1.0, sums to 128,
    249 // all co-efficients must be even.
    250 // The array is of size 9 to allow passing gauss_filter + 1 to
    251 // _mm_loadu_si128() in prepare_coeffs_6t().
    252 DECLARE_ALIGNED(16, static const int16_t, gauss_filter[9]) = { 0,  8, 30, 52,
    253                                                               30, 8, 0,  0 };
    254 static inline void gaussian_blur(const int bit_depth,
    255                                 const YV12_BUFFER_CONFIG *source,
    256                                 const YV12_BUFFER_CONFIG *dst) {
    257  const int block_size = BLOCK_128X128;
    258  const int block_w = mi_size_wide[block_size] * 4;
    259  const int block_h = mi_size_high[block_size] * 4;
    260  const int num_cols = (source->y_width + block_w - 1) / block_w;
    261  const int num_rows = (source->y_height + block_h - 1) / block_h;
    262  int row, col;
    263 
    264  ConvolveParams conv_params = get_conv_params(0, 0, bit_depth);
    265  InterpFilterParams filter = { .filter_ptr = gauss_filter,
    266                                .taps = 8,
    267                                .interp_filter = EIGHTTAP_REGULAR };
    268 
    269  for (row = 0; row < num_rows; ++row) {
    270    for (col = 0; col < num_cols; ++col) {
    271      const int row_offset_y = row * block_h;
    272      const int col_offset_y = col * block_w;
    273 
    274      uint8_t *src_buf =
    275          source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
    276      uint8_t *dst_buf =
    277          dst->y_buffer + row_offset_y * dst->y_stride + col_offset_y;
    278 
    279      if (source->flags & YV12_FLAG_HIGHBITDEPTH) {
    280        av1_highbd_convolve_2d_sr(
    281            CONVERT_TO_SHORTPTR(src_buf), source->y_stride,
    282            CONVERT_TO_SHORTPTR(dst_buf), dst->y_stride, block_w, block_h,
    283            &filter, &filter, 0, 0, &conv_params, bit_depth);
    284      } else {
    285        av1_convolve_2d_sr(src_buf, source->y_stride, dst_buf, dst->y_stride,
    286                           block_w, block_h, &filter, &filter, 0, 0,
    287                           &conv_params);
    288      }
    289    }
    290  }
    291 }
    292 
    293 static inline double cal_approx_vmaf(
    294    const AV1_COMP *const cpi, double source_variance,
    295    const YV12_BUFFER_CONFIG *const source,
    296    const YV12_BUFFER_CONFIG *const sharpened) {
    297  const int bit_depth = cpi->td.mb.e_mbd.bd;
    298  const bool cal_vmaf_neg =
    299      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
    300  double new_vmaf;
    301 
    302  aom_calc_vmaf(cpi->vmaf_info.vmaf_model, source, sharpened, bit_depth,
    303                cal_vmaf_neg, &new_vmaf);
    304 
    305  const double sharpened_var = frame_average_variance(cpi, sharpened);
    306  return source_variance / sharpened_var * (new_vmaf - kBaselineVmaf);
    307 }
    308 
    309 static double find_best_frame_unsharp_amount_loop(
    310    const AV1_COMP *const cpi, const YV12_BUFFER_CONFIG *const source,
    311    const YV12_BUFFER_CONFIG *const blurred,
    312    const YV12_BUFFER_CONFIG *const sharpened, double best_vmaf,
    313    const double baseline_variance, const double unsharp_amount_start,
    314    const double step_size, const int max_loop_count, const double max_amount) {
    315  const double min_amount = 0.0;
    316  int loop_count = 0;
    317  double approx_vmaf = best_vmaf;
    318  double unsharp_amount = unsharp_amount_start;
    319  do {
    320    best_vmaf = approx_vmaf;
    321    unsharp_amount += step_size;
    322    if (unsharp_amount > max_amount || unsharp_amount < min_amount) break;
    323    unsharp(cpi, source, blurred, sharpened, unsharp_amount);
    324    approx_vmaf = cal_approx_vmaf(cpi, baseline_variance, source, sharpened);
    325 
    326    loop_count++;
    327  } while (approx_vmaf > best_vmaf && loop_count < max_loop_count);
    328  unsharp_amount =
    329      approx_vmaf > best_vmaf ? unsharp_amount : unsharp_amount - step_size;
    330  return fclamp(unsharp_amount, min_amount, max_amount);
    331 }
    332 
    333 static double find_best_frame_unsharp_amount(
    334    const AV1_COMP *const cpi, const YV12_BUFFER_CONFIG *const source,
    335    const YV12_BUFFER_CONFIG *const blurred, const double unsharp_amount_start,
    336    const double step_size, const int max_loop_count,
    337    const double max_filter_amount) {
    338  const AV1_COMMON *const cm = &cpi->common;
    339  const int width = source->y_width;
    340  const int height = source->y_height;
    341  YV12_BUFFER_CONFIG sharpened;
    342  memset(&sharpened, 0, sizeof(sharpened));
    343  aom_alloc_frame_buffer(
    344      &sharpened, width, height, source->subsampling_x, source->subsampling_y,
    345      cm->seq_params->use_highbitdepth, cpi->oxcf.border_in_pixels,
    346      cm->features.byte_alignment, false, 0);
    347 
    348  const double baseline_variance = frame_average_variance(cpi, source);
    349  double unsharp_amount;
    350  if (unsharp_amount_start <= step_size) {
    351    unsharp_amount = find_best_frame_unsharp_amount_loop(
    352        cpi, source, blurred, &sharpened, 0.0, baseline_variance, 0.0,
    353        step_size, max_loop_count, max_filter_amount);
    354  } else {
    355    double a0 = unsharp_amount_start - step_size, a1 = unsharp_amount_start;
    356    double v0, v1;
    357    unsharp(cpi, source, blurred, &sharpened, a0);
    358    v0 = cal_approx_vmaf(cpi, baseline_variance, source, &sharpened);
    359    unsharp(cpi, source, blurred, &sharpened, a1);
    360    v1 = cal_approx_vmaf(cpi, baseline_variance, source, &sharpened);
    361    if (fabs(v0 - v1) < 0.01) {
    362      unsharp_amount = a0;
    363    } else if (v0 > v1) {
    364      unsharp_amount = find_best_frame_unsharp_amount_loop(
    365          cpi, source, blurred, &sharpened, v0, baseline_variance, a0,
    366          -step_size, max_loop_count, max_filter_amount);
    367    } else {
    368      unsharp_amount = find_best_frame_unsharp_amount_loop(
    369          cpi, source, blurred, &sharpened, v1, baseline_variance, a1,
    370          step_size, max_loop_count, max_filter_amount);
    371    }
    372  }
    373 
    374  aom_free_frame_buffer(&sharpened);
    375  return unsharp_amount;
    376 }
    377 
    378 void av1_vmaf_neg_preprocessing(AV1_COMP *const cpi,
    379                                const YV12_BUFFER_CONFIG *const source) {
    380  const AV1_COMMON *const cm = &cpi->common;
    381  const int bit_depth = cpi->td.mb.e_mbd.bd;
    382  const int width = source->y_width;
    383  const int height = source->y_height;
    384 
    385  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
    386  const int layer_depth =
    387      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
    388  const double best_frame_unsharp_amount =
    389      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);
    390 
    391  if (best_frame_unsharp_amount <= 0.0) return;
    392 
    393  YV12_BUFFER_CONFIG blurred;
    394  memset(&blurred, 0, sizeof(blurred));
    395  aom_alloc_frame_buffer(
    396      &blurred, width, height, source->subsampling_x, source->subsampling_y,
    397      cm->seq_params->use_highbitdepth, cpi->oxcf.border_in_pixels,
    398      cm->features.byte_alignment, false, 0);
    399 
    400  gaussian_blur(bit_depth, source, &blurred);
    401  unsharp(cpi, source, &blurred, source, best_frame_unsharp_amount);
    402  aom_free_frame_buffer(&blurred);
    403 }
    404 
    405 void av1_vmaf_frame_preprocessing(AV1_COMP *const cpi,
    406                                  const YV12_BUFFER_CONFIG *const source) {
    407  const AV1_COMMON *const cm = &cpi->common;
    408  const int bit_depth = cpi->td.mb.e_mbd.bd;
    409  const int width = source->y_width;
    410  const int height = source->y_height;
    411 
    412  YV12_BUFFER_CONFIG source_extended, blurred;
    413  memset(&source_extended, 0, sizeof(source_extended));
    414  memset(&blurred, 0, sizeof(blurred));
    415  aom_alloc_frame_buffer(
    416      &source_extended, width, height, source->subsampling_x,
    417      source->subsampling_y, cm->seq_params->use_highbitdepth,
    418      cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0);
    419  aom_alloc_frame_buffer(
    420      &blurred, width, height, source->subsampling_x, source->subsampling_y,
    421      cm->seq_params->use_highbitdepth, cpi->oxcf.border_in_pixels,
    422      cm->features.byte_alignment, false, 0);
    423 
    424  av1_copy_and_extend_frame(source, &source_extended);
    425  gaussian_blur(bit_depth, &source_extended, &blurred);
    426  aom_free_frame_buffer(&source_extended);
    427 
    428  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
    429  const int layer_depth =
    430      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
    431  const double last_frame_unsharp_amount =
    432      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);
    433 
    434  const double best_frame_unsharp_amount = find_best_frame_unsharp_amount(
    435      cpi, source, &blurred, last_frame_unsharp_amount, 0.05, 20, 1.01);
    436 
    437  cpi->vmaf_info.last_frame_unsharp_amount[layer_depth] =
    438      best_frame_unsharp_amount;
    439 
    440  unsharp(cpi, source, &blurred, source, best_frame_unsharp_amount);
    441  aom_free_frame_buffer(&blurred);
    442 }
    443 
    444 void av1_vmaf_blk_preprocessing(AV1_COMP *const cpi,
    445                                const YV12_BUFFER_CONFIG *const source) {
    446  const AV1_COMMON *const cm = &cpi->common;
    447  const int width = source->y_width;
    448  const int height = source->y_height;
    449  const int bit_depth = cpi->td.mb.e_mbd.bd;
    450  const int ss_x = source->subsampling_x;
    451  const int ss_y = source->subsampling_y;
    452 
    453  YV12_BUFFER_CONFIG source_extended, blurred;
    454  memset(&blurred, 0, sizeof(blurred));
    455  memset(&source_extended, 0, sizeof(source_extended));
    456  aom_alloc_frame_buffer(
    457      &blurred, width, height, ss_x, ss_y, cm->seq_params->use_highbitdepth,
    458      cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0);
    459  aom_alloc_frame_buffer(&source_extended, width, height, ss_x, ss_y,
    460                         cm->seq_params->use_highbitdepth,
    461                         cpi->oxcf.border_in_pixels,
    462                         cm->features.byte_alignment, false, 0);
    463 
    464  av1_copy_and_extend_frame(source, &source_extended);
    465  gaussian_blur(bit_depth, &source_extended, &blurred);
    466  aom_free_frame_buffer(&source_extended);
    467 
    468  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
    469  const int layer_depth =
    470      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
    471  const double last_frame_unsharp_amount =
    472      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);
    473 
    474  const double best_frame_unsharp_amount = find_best_frame_unsharp_amount(
    475      cpi, source, &blurred, last_frame_unsharp_amount, 0.05, 20, 1.01);
    476 
    477  cpi->vmaf_info.last_frame_unsharp_amount[layer_depth] =
    478      best_frame_unsharp_amount;
    479 
    480  const int block_size = BLOCK_64X64;
    481  const int block_w = mi_size_wide[block_size] * 4;
    482  const int block_h = mi_size_high[block_size] * 4;
    483  const int num_cols = (source->y_width + block_w - 1) / block_w;
    484  const int num_rows = (source->y_height + block_h - 1) / block_h;
    485  double *best_unsharp_amounts =
    486      aom_calloc(num_cols * num_rows, sizeof(*best_unsharp_amounts));
    487  if (!best_unsharp_amounts) {
    488    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
    489                       "Error allocating vmaf data");
    490  }
    491 
    492  YV12_BUFFER_CONFIG source_block, blurred_block;
    493  memset(&source_block, 0, sizeof(source_block));
    494  memset(&blurred_block, 0, sizeof(blurred_block));
    495  aom_alloc_frame_buffer(&source_block, block_w, block_h, ss_x, ss_y,
    496                         cm->seq_params->use_highbitdepth,
    497                         cpi->oxcf.border_in_pixels,
    498                         cm->features.byte_alignment, false, 0);
    499  aom_alloc_frame_buffer(&blurred_block, block_w, block_h, ss_x, ss_y,
    500                         cm->seq_params->use_highbitdepth,
    501                         cpi->oxcf.border_in_pixels,
    502                         cm->features.byte_alignment, false, 0);
    503 
    504  for (int row = 0; row < num_rows; ++row) {
    505    for (int col = 0; col < num_cols; ++col) {
    506      const int row_offset_y = row * block_h;
    507      const int col_offset_y = col * block_w;
    508      const int block_width = AOMMIN(width - col_offset_y, block_w);
    509      const int block_height = AOMMIN(height - row_offset_y, block_h);
    510      const int index = col + row * num_cols;
    511 
    512      if (cm->seq_params->use_highbitdepth) {
    513        assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
    514        assert(blurred.flags & YV12_FLAG_HIGHBITDEPTH);
    515        uint16_t *frame_src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
    516                                  row_offset_y * source->y_stride +
    517                                  col_offset_y;
    518        uint16_t *frame_blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
    519                                      row_offset_y * blurred.y_stride +
    520                                      col_offset_y;
    521        uint16_t *blurred_dst = CONVERT_TO_SHORTPTR(blurred_block.y_buffer);
    522        uint16_t *src_dst = CONVERT_TO_SHORTPTR(source_block.y_buffer);
    523 
    524        // Copy block from source frame.
    525        for (int i = 0; i < block_h; ++i) {
    526          for (int j = 0; j < block_w; ++j) {
    527            if (i >= block_height || j >= block_width) {
    528              src_dst[j] = 0;
    529              blurred_dst[j] = 0;
    530            } else {
    531              src_dst[j] = frame_src_buf[j];
    532              blurred_dst[j] = frame_blurred_buf[j];
    533            }
    534          }
    535          frame_src_buf += source->y_stride;
    536          frame_blurred_buf += blurred.y_stride;
    537          src_dst += source_block.y_stride;
    538          blurred_dst += blurred_block.y_stride;
    539        }
    540      } else {
    541        uint8_t *frame_src_buf =
    542            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
    543        uint8_t *frame_blurred_buf =
    544            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
    545        uint8_t *blurred_dst = blurred_block.y_buffer;
    546        uint8_t *src_dst = source_block.y_buffer;
    547 
    548        // Copy block from source frame.
    549        for (int i = 0; i < block_h; ++i) {
    550          for (int j = 0; j < block_w; ++j) {
    551            if (i >= block_height || j >= block_width) {
    552              src_dst[j] = 0;
    553              blurred_dst[j] = 0;
    554            } else {
    555              src_dst[j] = frame_src_buf[j];
    556              blurred_dst[j] = frame_blurred_buf[j];
    557            }
    558          }
    559          frame_src_buf += source->y_stride;
    560          frame_blurred_buf += blurred.y_stride;
    561          src_dst += source_block.y_stride;
    562          blurred_dst += blurred_block.y_stride;
    563        }
    564      }
    565 
    566      best_unsharp_amounts[index] = find_best_frame_unsharp_amount(
    567          cpi, &source_block, &blurred_block, best_frame_unsharp_amount, 0.1, 3,
    568          1.5);
    569    }
    570  }
    571 
    572  // Apply best blur amounts
    573  for (int row = 0; row < num_rows; ++row) {
    574    for (int col = 0; col < num_cols; ++col) {
    575      const int row_offset_y = row * block_h;
    576      const int col_offset_y = col * block_w;
    577      const int block_width = AOMMIN(source->y_width - col_offset_y, block_w);
    578      const int block_height = AOMMIN(source->y_height - row_offset_y, block_h);
    579      const int index = col + row * num_cols;
    580 
    581      if (cm->seq_params->use_highbitdepth) {
    582        assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
    583        assert(blurred.flags & YV12_FLAG_HIGHBITDEPTH);
    584        uint16_t *src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
    585                            row_offset_y * source->y_stride + col_offset_y;
    586        uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
    587                                row_offset_y * blurred.y_stride + col_offset_y;
    588        highbd_unsharp_rect(src_buf, source->y_stride, blurred_buf,
    589                            blurred.y_stride, src_buf, source->y_stride,
    590                            block_width, block_height,
    591                            best_unsharp_amounts[index], bit_depth);
    592      } else {
    593        uint8_t *src_buf =
    594            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
    595        uint8_t *blurred_buf =
    596            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
    597        unsharp_rect(src_buf, source->y_stride, blurred_buf, blurred.y_stride,
    598                     src_buf, source->y_stride, block_width, block_height,
    599                     best_unsharp_amounts[index]);
    600      }
    601    }
    602  }
    603 
    604  aom_free_frame_buffer(&source_block);
    605  aom_free_frame_buffer(&blurred_block);
    606  aom_free_frame_buffer(&blurred);
    607  aom_free(best_unsharp_amounts);
    608 }
    609 
    610 void av1_set_mb_vmaf_rdmult_scaling(AV1_COMP *cpi) {
    611  AV1_COMMON *cm = &cpi->common;
    612  const int y_width = cpi->source->y_width;
    613  const int y_height = cpi->source->y_height;
    614  const int resized_block_size = BLOCK_32X32;
    615  const int resize_factor = 2;
    616  const int bit_depth = cpi->td.mb.e_mbd.bd;
    617  const int ss_x = cpi->source->subsampling_x;
    618  const int ss_y = cpi->source->subsampling_y;
    619 
    620  YV12_BUFFER_CONFIG resized_source;
    621  memset(&resized_source, 0, sizeof(resized_source));
    622  aom_alloc_frame_buffer(
    623      &resized_source, y_width / resize_factor, y_height / resize_factor, ss_x,
    624      ss_y, cm->seq_params->use_highbitdepth, cpi->oxcf.border_in_pixels,
    625      cm->features.byte_alignment, false, 0);
    626  if (!av1_resize_and_extend_frame_nonnormative(
    627          cpi->source, &resized_source, bit_depth, av1_num_planes(cm))) {
    628    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
    629                       "Error allocating buffers during resize");
    630  }
    631 
    632  const int resized_y_width = resized_source.y_width;
    633  const int resized_y_height = resized_source.y_height;
    634  const int resized_block_w = mi_size_wide[resized_block_size] * 4;
    635  const int resized_block_h = mi_size_high[resized_block_size] * 4;
    636  const int num_cols =
    637      (resized_y_width + resized_block_w - 1) / resized_block_w;
    638  const int num_rows =
    639      (resized_y_height + resized_block_h - 1) / resized_block_h;
    640 
    641  YV12_BUFFER_CONFIG blurred;
    642  memset(&blurred, 0, sizeof(blurred));
    643  aom_alloc_frame_buffer(&blurred, resized_y_width, resized_y_height, ss_x,
    644                         ss_y, cm->seq_params->use_highbitdepth,
    645                         cpi->oxcf.border_in_pixels,
    646                         cm->features.byte_alignment, false, 0);
    647  gaussian_blur(bit_depth, &resized_source, &blurred);
    648 
    649  YV12_BUFFER_CONFIG recon;
    650  memset(&recon, 0, sizeof(recon));
    651  aom_alloc_frame_buffer(&recon, resized_y_width, resized_y_height, ss_x, ss_y,
    652                         cm->seq_params->use_highbitdepth,
    653                         cpi->oxcf.border_in_pixels,
    654                         cm->features.byte_alignment, false, 0);
    655  aom_yv12_copy_frame(&resized_source, &recon, 1);
    656 
    657  VmafContext *vmaf_context;
    658  const bool cal_vmaf_neg =
    659      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
    660  aom_init_vmaf_context(&vmaf_context, cpi->vmaf_info.vmaf_model, cal_vmaf_neg);
    661  unsigned int *sses = aom_calloc(num_rows * num_cols, sizeof(*sses));
    662  if (!sses) {
    663    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
    664                       "Error allocating vmaf data");
    665  }
    666 
    667  // Loop through each 'block_size' block.
    668  for (int row = 0; row < num_rows; ++row) {
    669    for (int col = 0; col < num_cols; ++col) {
    670      const int index = row * num_cols + col;
    671      const int row_offset_y = row * resized_block_h;
    672      const int col_offset_y = col * resized_block_w;
    673 
    674      uint8_t *const orig_buf = resized_source.y_buffer +
    675                                row_offset_y * resized_source.y_stride +
    676                                col_offset_y;
    677      uint8_t *const blurred_buf =
    678          blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
    679 
    680      cpi->ppi->fn_ptr[resized_block_size].vf(orig_buf, resized_source.y_stride,
    681                                              blurred_buf, blurred.y_stride,
    682                                              &sses[index]);
    683 
    684      uint8_t *const recon_buf =
    685          recon.y_buffer + row_offset_y * recon.y_stride + col_offset_y;
    686      // Set recon buf
    687      if (cpi->common.seq_params->use_highbitdepth) {
    688        highbd_unsharp_rect(CONVERT_TO_SHORTPTR(blurred_buf), blurred.y_stride,
    689                            CONVERT_TO_SHORTPTR(blurred_buf), blurred.y_stride,
    690                            CONVERT_TO_SHORTPTR(recon_buf), recon.y_stride,
    691                            resized_block_w, resized_block_h, 0.0, bit_depth);
    692      } else {
    693        unsharp_rect(blurred_buf, blurred.y_stride, blurred_buf,
    694                     blurred.y_stride, recon_buf, recon.y_stride,
    695                     resized_block_w, resized_block_h, 0.0);
    696      }
    697 
    698      aom_read_vmaf_image(vmaf_context, &resized_source, &recon, bit_depth,
    699                          index);
    700 
    701      // Restore recon buf
    702      if (cpi->common.seq_params->use_highbitdepth) {
    703        highbd_unsharp_rect(
    704            CONVERT_TO_SHORTPTR(orig_buf), resized_source.y_stride,
    705            CONVERT_TO_SHORTPTR(orig_buf), resized_source.y_stride,
    706            CONVERT_TO_SHORTPTR(recon_buf), recon.y_stride, resized_block_w,
    707            resized_block_h, 0.0, bit_depth);
    708      } else {
    709        unsharp_rect(orig_buf, resized_source.y_stride, orig_buf,
    710                     resized_source.y_stride, recon_buf, recon.y_stride,
    711                     resized_block_w, resized_block_h, 0.0);
    712      }
    713    }
    714  }
    715  aom_flush_vmaf_context(vmaf_context);
    716  for (int row = 0; row < num_rows; ++row) {
    717    for (int col = 0; col < num_cols; ++col) {
    718      const int index = row * num_cols + col;
    719      const double vmaf = aom_calc_vmaf_at_index(
    720          vmaf_context, cpi->vmaf_info.vmaf_model, index);
    721      const double dvmaf = kBaselineVmaf - vmaf;
    722 
    723      const double mse =
    724          (double)sses[index] / (double)(resized_y_width * resized_y_height);
    725      double weight;
    726      const double eps = 0.01 / (num_rows * num_cols);
    727      if (dvmaf < eps || mse < eps) {
    728        weight = 1.0;
    729      } else {
    730        weight = mse / dvmaf;
    731      }
    732 
    733      // Normalize it with a data fitted model.
    734      weight = 6.0 * (1.0 - exp(-0.05 * weight)) + 0.8;
    735      cpi->vmaf_info.rdmult_scaling_factors[index] = weight;
    736    }
    737  }
    738 
    739  aom_free_frame_buffer(&resized_source);
    740  aom_free_frame_buffer(&blurred);
    741  aom_close_vmaf_context(vmaf_context);
    742  aom_free(sses);
    743 }
    744 
    745 void av1_set_vmaf_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
    746                         const BLOCK_SIZE bsize, const int mi_row,
    747                         const int mi_col, int *const rdmult) {
    748  const AV1_COMMON *const cm = &cpi->common;
    749 
    750  const int bsize_base = BLOCK_64X64;
    751  const int num_mi_w = mi_size_wide[bsize_base];
    752  const int num_mi_h = mi_size_high[bsize_base];
    753  const int num_cols = (cm->mi_params.mi_cols + num_mi_w - 1) / num_mi_w;
    754  const int num_rows = (cm->mi_params.mi_rows + num_mi_h - 1) / num_mi_h;
    755  const int num_bcols = (mi_size_wide[bsize] + num_mi_w - 1) / num_mi_w;
    756  const int num_brows = (mi_size_high[bsize] + num_mi_h - 1) / num_mi_h;
    757  int row, col;
    758  double num_of_mi = 0.0;
    759  double geom_mean_of_scale = 0.0;
    760 
    761  for (row = mi_row / num_mi_w;
    762       row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
    763    for (col = mi_col / num_mi_h;
    764         col < num_cols && col < mi_col / num_mi_h + num_bcols; ++col) {
    765      const int index = row * num_cols + col;
    766      geom_mean_of_scale += log(cpi->vmaf_info.rdmult_scaling_factors[index]);
    767      num_of_mi += 1.0;
    768    }
    769  }
    770  geom_mean_of_scale = exp(geom_mean_of_scale / num_of_mi);
    771 
    772  *rdmult = (int)((double)(*rdmult) * geom_mean_of_scale + 0.5);
    773  *rdmult = AOMMAX(*rdmult, 0);
    774  av1_set_error_per_bit(&x->errorperbit, *rdmult);
    775 }
    776 
    777 // TODO(sdeng): replace them with the SIMD versions.
    778 static inline double highbd_image_sad_c(const uint16_t *src, int src_stride,
    779                                        const uint16_t *ref, int ref_stride,
    780                                        int w, int h) {
    781  double accum = 0.0;
    782  int i, j;
    783 
    784  for (i = 0; i < h; ++i) {
    785    for (j = 0; j < w; ++j) {
    786      double img1px = src[i * src_stride + j];
    787      double img2px = ref[i * ref_stride + j];
    788 
    789      accum += fabs(img1px - img2px);
    790    }
    791  }
    792 
    793  return accum / (double)(h * w);
    794 }
    795 
    796 static inline double image_sad_c(const uint8_t *src, int src_stride,
    797                                 const uint8_t *ref, int ref_stride, int w,
    798                                 int h) {
    799  double accum = 0.0;
    800  int i, j;
    801 
    802  for (i = 0; i < h; ++i) {
    803    for (j = 0; j < w; ++j) {
    804      double img1px = src[i * src_stride + j];
    805      double img2px = ref[i * ref_stride + j];
    806 
    807      accum += fabs(img1px - img2px);
    808    }
    809  }
    810 
    811  return accum / (double)(h * w);
    812 }
    813 
    814 static double calc_vmaf_motion_score(const AV1_COMP *const cpi,
    815                                     const AV1_COMMON *const cm,
    816                                     const YV12_BUFFER_CONFIG *const cur,
    817                                     const YV12_BUFFER_CONFIG *const last,
    818                                     const YV12_BUFFER_CONFIG *const next) {
    819  const int y_width = cur->y_width;
    820  const int y_height = cur->y_height;
    821  YV12_BUFFER_CONFIG blurred_cur, blurred_last, blurred_next;
    822  const int bit_depth = cpi->td.mb.e_mbd.bd;
    823  const int ss_x = cur->subsampling_x;
    824  const int ss_y = cur->subsampling_y;
    825 
    826  memset(&blurred_cur, 0, sizeof(blurred_cur));
    827  memset(&blurred_last, 0, sizeof(blurred_last));
    828  memset(&blurred_next, 0, sizeof(blurred_next));
    829 
    830  aom_alloc_frame_buffer(&blurred_cur, y_width, y_height, ss_x, ss_y,
    831                         cm->seq_params->use_highbitdepth,
    832                         cpi->oxcf.border_in_pixels,
    833                         cm->features.byte_alignment, false, 0);
    834  aom_alloc_frame_buffer(&blurred_last, y_width, y_height, ss_x, ss_y,
    835                         cm->seq_params->use_highbitdepth,
    836                         cpi->oxcf.border_in_pixels,
    837                         cm->features.byte_alignment, false, 0);
    838  aom_alloc_frame_buffer(&blurred_next, y_width, y_height, ss_x, ss_y,
    839                         cm->seq_params->use_highbitdepth,
    840                         cpi->oxcf.border_in_pixels,
    841                         cm->features.byte_alignment, false, 0);
    842 
    843  gaussian_blur(bit_depth, cur, &blurred_cur);
    844  gaussian_blur(bit_depth, last, &blurred_last);
    845  if (next) gaussian_blur(bit_depth, next, &blurred_next);
    846 
    847  double motion1, motion2 = 65536.0;
    848  if (cm->seq_params->use_highbitdepth) {
    849    assert(blurred_cur.flags & YV12_FLAG_HIGHBITDEPTH);
    850    assert(blurred_last.flags & YV12_FLAG_HIGHBITDEPTH);
    851    const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
    852    motion1 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
    853                                 blurred_cur.y_stride,
    854                                 CONVERT_TO_SHORTPTR(blurred_last.y_buffer),
    855                                 blurred_last.y_stride, y_width, y_height) *
    856              scale_factor;
    857    if (next) {
    858      assert(blurred_next.flags & YV12_FLAG_HIGHBITDEPTH);
    859      motion2 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
    860                                   blurred_cur.y_stride,
    861                                   CONVERT_TO_SHORTPTR(blurred_next.y_buffer),
    862                                   blurred_next.y_stride, y_width, y_height) *
    863                scale_factor;
    864    }
    865  } else {
    866    motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
    867                          blurred_last.y_buffer, blurred_last.y_stride, y_width,
    868                          y_height);
    869    if (next) {
    870      motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
    871                            blurred_next.y_buffer, blurred_next.y_stride,
    872                            y_width, y_height);
    873    }
    874  }
    875 
    876  aom_free_frame_buffer(&blurred_cur);
    877  aom_free_frame_buffer(&blurred_last);
    878  aom_free_frame_buffer(&blurred_next);
    879 
    880  return AOMMIN(motion1, motion2);
    881 }
    882 
    883 static inline void get_neighbor_frames(const AV1_COMP *const cpi,
    884                                       const YV12_BUFFER_CONFIG **last,
    885                                       const YV12_BUFFER_CONFIG **next) {
    886  const AV1_COMMON *const cm = &cpi->common;
    887  const GF_GROUP *gf_group = &cpi->ppi->gf_group;
    888  const int src_index =
    889      cm->show_frame != 0 ? 0 : gf_group->arf_src_offset[cpi->gf_frame_index];
    890  struct lookahead_entry *last_entry = av1_lookahead_peek(
    891      cpi->ppi->lookahead, src_index - 1, cpi->compressor_stage);
    892  struct lookahead_entry *next_entry = av1_lookahead_peek(
    893      cpi->ppi->lookahead, src_index + 1, cpi->compressor_stage);
    894  *next = &next_entry->img;
    895  *last = cm->show_frame ? cpi->last_source : &last_entry->img;
    896 }
    897 
    898 // Calculates the new qindex from the VMAF motion score. This is based on the
    899 // observation: when the motion score becomes higher, the VMAF score of the
    900 // same source and distorted frames would become higher.
    901 int av1_get_vmaf_base_qindex(const AV1_COMP *const cpi, int current_qindex) {
    902  const AV1_COMMON *const cm = &cpi->common;
    903  if (cm->current_frame.frame_number == 0 || cpi->oxcf.pass == 1) {
    904    return current_qindex;
    905  }
    906  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
    907  const int layer_depth =
    908      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
    909  const double last_frame_ysse =
    910      get_layer_value(cpi->vmaf_info.last_frame_ysse, layer_depth);
    911  const double last_frame_vmaf =
    912      get_layer_value(cpi->vmaf_info.last_frame_vmaf, layer_depth);
    913  const int bit_depth = cpi->td.mb.e_mbd.bd;
    914  const double approx_sse = last_frame_ysse / (double)((1 << (bit_depth - 8)) *
    915                                                       (1 << (bit_depth - 8)));
    916  const double approx_dvmaf = kBaselineVmaf - last_frame_vmaf;
    917  const double sse_threshold =
    918      0.01 * cpi->source->y_width * cpi->source->y_height;
    919  const double vmaf_threshold = 0.01;
    920  if (approx_sse < sse_threshold || approx_dvmaf < vmaf_threshold) {
    921    return current_qindex;
    922  }
    923  const YV12_BUFFER_CONFIG *cur_buf = cpi->source;
    924  if (cm->show_frame == 0) {
    925    const int src_index = gf_group->arf_src_offset[cpi->gf_frame_index];
    926    struct lookahead_entry *cur_entry = av1_lookahead_peek(
    927        cpi->ppi->lookahead, src_index, cpi->compressor_stage);
    928    cur_buf = &cur_entry->img;
    929  }
    930  assert(cur_buf);
    931 
    932  const YV12_BUFFER_CONFIG *next_buf, *last_buf;
    933  get_neighbor_frames(cpi, &last_buf, &next_buf);
    934  assert(last_buf);
    935 
    936  const double motion =
    937      calc_vmaf_motion_score(cpi, cm, cur_buf, last_buf, next_buf);
    938 
    939  // Get dVMAF through a data fitted model.
    940  const double dvmaf = 26.11 * (1.0 - exp(-0.06 * motion));
    941  const double dsse = dvmaf * approx_sse / approx_dvmaf;
    942 
    943  // Clamping beta to address VQ issue (aomedia:3170).
    944  const double beta = AOMMAX(approx_sse / (dsse + approx_sse), 0.5);
    945  const int offset =
    946      av1_get_deltaq_offset(cm->seq_params->bit_depth, current_qindex, beta);
    947  const int qindex = clamp(current_qindex + offset, MINQ, MAXQ);
    948 
    949  return qindex;
    950 }
    951 
    952 static inline double cal_approx_score(
    953    AV1_COMP *const cpi, double src_variance, double new_variance,
    954    double src_score, const YV12_BUFFER_CONFIG *const src,
    955    const YV12_BUFFER_CONFIG *const recon_sharpened) {
    956  double score;
    957  const uint32_t bit_depth = cpi->td.mb.e_mbd.bd;
    958  const bool cal_vmaf_neg =
    959      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
    960  aom_calc_vmaf(cpi->vmaf_info.vmaf_model, src, recon_sharpened, bit_depth,
    961                cal_vmaf_neg, &score);
    962  return src_variance / new_variance * (score - src_score);
    963 }
    964 
    965 static double find_best_frame_unsharp_amount_loop_neg(
    966    AV1_COMP *const cpi, double src_variance, double base_score,
    967    const YV12_BUFFER_CONFIG *const src, const YV12_BUFFER_CONFIG *const recon,
    968    const YV12_BUFFER_CONFIG *const ref,
    969    const YV12_BUFFER_CONFIG *const src_blurred,
    970    const YV12_BUFFER_CONFIG *const recon_blurred,
    971    const YV12_BUFFER_CONFIG *const src_sharpened,
    972    const YV12_BUFFER_CONFIG *const recon_sharpened, FULLPEL_MV *mvs,
    973    double best_score, const double unsharp_amount_start,
    974    const double step_size, const int max_loop_count, const double max_amount) {
    975  const double min_amount = 0.0;
    976  int loop_count = 0;
    977  double approx_score = best_score;
    978  double unsharp_amount = unsharp_amount_start;
    979 
    980  do {
    981    best_score = approx_score;
    982    unsharp_amount += step_size;
    983    if (unsharp_amount > max_amount || unsharp_amount < min_amount) break;
    984    unsharp(cpi, recon, recon_blurred, recon_sharpened, unsharp_amount);
    985    unsharp(cpi, src, src_blurred, src_sharpened, unsharp_amount);
    986    const double new_variance =
    987        residual_frame_average_variance(cpi, src_sharpened, ref, mvs);
    988    approx_score = cal_approx_score(cpi, src_variance, new_variance, base_score,
    989                                    src, recon_sharpened);
    990 
    991    loop_count++;
    992  } while (approx_score > best_score && loop_count < max_loop_count);
    993  unsharp_amount =
    994      approx_score > best_score ? unsharp_amount : unsharp_amount - step_size;
    995 
    996  return fclamp(unsharp_amount, min_amount, max_amount);
    997 }
    998 
    999 static double find_best_frame_unsharp_amount_neg(
   1000    AV1_COMP *const cpi, const YV12_BUFFER_CONFIG *const src,
   1001    const YV12_BUFFER_CONFIG *const recon, const YV12_BUFFER_CONFIG *const ref,
   1002    double base_score, const double unsharp_amount_start,
   1003    const double step_size, const int max_loop_count,
   1004    const double max_filter_amount) {
   1005  FULLPEL_MV *mvs = NULL;
   1006  const double src_variance =
   1007      residual_frame_average_variance(cpi, src, ref, mvs);
   1008 
   1009  const AV1_COMMON *const cm = &cpi->common;
   1010  const int width = recon->y_width;
   1011  const int height = recon->y_height;
   1012  const int bit_depth = cpi->td.mb.e_mbd.bd;
   1013  const int ss_x = recon->subsampling_x;
   1014  const int ss_y = recon->subsampling_y;
   1015 
   1016  YV12_BUFFER_CONFIG src_blurred, recon_blurred, src_sharpened, recon_sharpened;
   1017  memset(&recon_sharpened, 0, sizeof(recon_sharpened));
   1018  memset(&src_sharpened, 0, sizeof(src_sharpened));
   1019  memset(&recon_blurred, 0, sizeof(recon_blurred));
   1020  memset(&src_blurred, 0, sizeof(src_blurred));
   1021  aom_alloc_frame_buffer(&recon_sharpened, width, height, ss_x, ss_y,
   1022                         cm->seq_params->use_highbitdepth,
   1023                         cpi->oxcf.border_in_pixels,
   1024                         cm->features.byte_alignment, false, 0);
   1025  aom_alloc_frame_buffer(&src_sharpened, width, height, ss_x, ss_y,
   1026                         cm->seq_params->use_highbitdepth,
   1027                         cpi->oxcf.border_in_pixels,
   1028                         cm->features.byte_alignment, false, 0);
   1029  aom_alloc_frame_buffer(&recon_blurred, width, height, ss_x, ss_y,
   1030                         cm->seq_params->use_highbitdepth,
   1031                         cpi->oxcf.border_in_pixels,
   1032                         cm->features.byte_alignment, false, 0);
   1033  aom_alloc_frame_buffer(
   1034      &src_blurred, width, height, ss_x, ss_y, cm->seq_params->use_highbitdepth,
   1035      cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0);
   1036 
   1037  gaussian_blur(bit_depth, recon, &recon_blurred);
   1038  gaussian_blur(bit_depth, src, &src_blurred);
   1039 
   1040  unsharp(cpi, recon, &recon_blurred, &recon_sharpened, unsharp_amount_start);
   1041  unsharp(cpi, src, &src_blurred, &src_sharpened, unsharp_amount_start);
   1042  const double variance_start =
   1043      residual_frame_average_variance(cpi, &src_sharpened, ref, mvs);
   1044  const double score_start = cal_approx_score(
   1045      cpi, src_variance, variance_start, base_score, src, &recon_sharpened);
   1046 
   1047  const double unsharp_amount_next = unsharp_amount_start + step_size;
   1048  unsharp(cpi, recon, &recon_blurred, &recon_sharpened, unsharp_amount_next);
   1049  unsharp(cpi, src, &src_blurred, &src_sharpened, unsharp_amount_next);
   1050  const double variance_next =
   1051      residual_frame_average_variance(cpi, &src_sharpened, ref, mvs);
   1052  const double score_next = cal_approx_score(cpi, src_variance, variance_next,
   1053                                             base_score, src, &recon_sharpened);
   1054 
   1055  double unsharp_amount;
   1056  if (score_next > score_start) {
   1057    unsharp_amount = find_best_frame_unsharp_amount_loop_neg(
   1058        cpi, src_variance, base_score, src, recon, ref, &src_blurred,
   1059        &recon_blurred, &src_sharpened, &recon_sharpened, mvs, score_next,
   1060        unsharp_amount_next, step_size, max_loop_count, max_filter_amount);
   1061  } else {
   1062    unsharp_amount = find_best_frame_unsharp_amount_loop_neg(
   1063        cpi, src_variance, base_score, src, recon, ref, &src_blurred,
   1064        &recon_blurred, &src_sharpened, &recon_sharpened, mvs, score_start,
   1065        unsharp_amount_start, -step_size, max_loop_count, max_filter_amount);
   1066  }
   1067 
   1068  aom_free_frame_buffer(&recon_sharpened);
   1069  aom_free_frame_buffer(&src_sharpened);
   1070  aom_free_frame_buffer(&recon_blurred);
   1071  aom_free_frame_buffer(&src_blurred);
   1072  aom_free(mvs);
   1073  return unsharp_amount;
   1074 }
   1075 
   1076 void av1_update_vmaf_curve(AV1_COMP *cpi) {
   1077  const YV12_BUFFER_CONFIG *source = cpi->source;
   1078  const YV12_BUFFER_CONFIG *recon = &cpi->common.cur_frame->buf;
   1079  const int bit_depth = cpi->td.mb.e_mbd.bd;
   1080  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
   1081  const int layer_depth =
   1082      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
   1083  double base_score;
   1084  const bool cal_vmaf_neg =
   1085      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
   1086  aom_calc_vmaf(cpi->vmaf_info.vmaf_model, source, recon, bit_depth,
   1087                cal_vmaf_neg, &base_score);
   1088  cpi->vmaf_info.last_frame_vmaf[layer_depth] = base_score;
   1089  if (cpi->common.seq_params->use_highbitdepth) {
   1090    assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
   1091    assert(recon->flags & YV12_FLAG_HIGHBITDEPTH);
   1092    cpi->vmaf_info.last_frame_ysse[layer_depth] =
   1093        (double)aom_highbd_get_y_sse(source, recon);
   1094  } else {
   1095    cpi->vmaf_info.last_frame_ysse[layer_depth] =
   1096        (double)aom_get_y_sse(source, recon);
   1097  }
   1098 
   1099  if (cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN) {
   1100    const YV12_BUFFER_CONFIG *last, *next;
   1101    get_neighbor_frames(cpi, &last, &next);
   1102    double best_unsharp_amount_start =
   1103        get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);
   1104    const int max_loop_count = 5;
   1105    cpi->vmaf_info.last_frame_unsharp_amount[layer_depth] =
   1106        find_best_frame_unsharp_amount_neg(cpi, source, recon, last, base_score,
   1107                                           best_unsharp_amount_start, 0.025,
   1108                                           max_loop_count, 1.01);
   1109  }
   1110 }