tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

av1_convolve_scale_sse4.c (21556B)


      1 /*
      2 * Copyright (c) 2017, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <smmintrin.h>
     14 
     15 #include "config/av1_rtcd.h"
     16 
     17 #include "aom_dsp/aom_dsp_common.h"
     18 #include "aom_dsp/aom_filter.h"
     19 #include "av1/common/convolve.h"
     20 
     21 // A specialised version of hfilter, the horizontal filter for
     22 // av1_convolve_2d_scale_sse4_1. This version only supports 8 tap filters.
     23 static void hfilter8(const uint8_t *src, int src_stride, int16_t *dst, int w,
     24                     int h, int subpel_x_qn, int x_step_qn,
     25                     const InterpFilterParams *filter_params, int round) {
     26  const int bd = 8;
     27  const int ntaps = 8;
     28 
     29  src -= ntaps / 2 - 1;
     30 
     31  int32_t round_add32 = (1 << round) / 2 + (1 << (bd + FILTER_BITS - 1));
     32  const __m128i round_add = _mm_set1_epi32(round_add32);
     33  const __m128i round_shift = _mm_cvtsi32_si128(round);
     34 
     35  int x_qn = subpel_x_qn;
     36  for (int x = 0; x < w; ++x, x_qn += x_step_qn) {
     37    const uint8_t *const src_col = src + (x_qn >> SCALE_SUBPEL_BITS);
     38    const int filter_idx = (x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
     39    assert(filter_idx < SUBPEL_SHIFTS);
     40    const int16_t *filter =
     41        av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);
     42 
     43    // Load the filter coefficients
     44    const __m128i coefflo = _mm_loadu_si128((__m128i *)filter);
     45    const __m128i zero = _mm_castps_si128(_mm_setzero_ps());
     46 
     47    int y;
     48    for (y = 0; y <= h - 4; y += 4) {
     49      const uint8_t *const src0 = src_col + y * src_stride;
     50      const uint8_t *const src1 = src0 + 1 * src_stride;
     51      const uint8_t *const src2 = src0 + 2 * src_stride;
     52      const uint8_t *const src3 = src0 + 3 * src_stride;
     53 
     54      // Load up source data. This is 8-bit input data; each load is just
     55      // loading the lower half of the register and gets 8 pixels
     56      const __m128i data08 = _mm_loadl_epi64((__m128i *)src0);
     57      const __m128i data18 = _mm_loadl_epi64((__m128i *)src1);
     58      const __m128i data28 = _mm_loadl_epi64((__m128i *)src2);
     59      const __m128i data38 = _mm_loadl_epi64((__m128i *)src3);
     60 
     61      // Now zero-extend up to 16-bit precision by interleaving with
     62      // zeros. Drop the upper half of each register (which just had zeros)
     63      const __m128i data0lo = _mm_unpacklo_epi8(data08, zero);
     64      const __m128i data1lo = _mm_unpacklo_epi8(data18, zero);
     65      const __m128i data2lo = _mm_unpacklo_epi8(data28, zero);
     66      const __m128i data3lo = _mm_unpacklo_epi8(data38, zero);
     67 
     68      // Multiply by coefficients
     69      const __m128i conv0lo = _mm_madd_epi16(data0lo, coefflo);
     70      const __m128i conv1lo = _mm_madd_epi16(data1lo, coefflo);
     71      const __m128i conv2lo = _mm_madd_epi16(data2lo, coefflo);
     72      const __m128i conv3lo = _mm_madd_epi16(data3lo, coefflo);
     73 
     74      // Reduce horizontally and add
     75      const __m128i conv01lo = _mm_hadd_epi32(conv0lo, conv1lo);
     76      const __m128i conv23lo = _mm_hadd_epi32(conv2lo, conv3lo);
     77      const __m128i conv = _mm_hadd_epi32(conv01lo, conv23lo);
     78 
     79      // Divide down by (1 << round), rounding to nearest.
     80      __m128i shifted =
     81          _mm_sra_epi32(_mm_add_epi32(conv, round_add), round_shift);
     82 
     83      shifted = _mm_packus_epi32(shifted, shifted);
     84      // Write transposed to the output
     85      _mm_storel_epi64((__m128i *)(dst + y + x * h), shifted);
     86    }
     87    for (; y < h; ++y) {
     88      const uint8_t *const src_row = src_col + y * src_stride;
     89 
     90      int32_t sum = (1 << (bd + FILTER_BITS - 1));
     91      for (int k = 0; k < ntaps; ++k) {
     92        sum += filter[k] * src_row[k];
     93      }
     94 
     95      dst[y + x * h] = ROUND_POWER_OF_TWO(sum, round);
     96    }
     97  }
     98 }
     99 
    100 static __m128i convolve_16_8(const int16_t *src, __m128i coeff) {
    101  __m128i data = _mm_loadu_si128((__m128i *)src);
    102  return _mm_madd_epi16(data, coeff);
    103 }
    104 
// A specialised version of vfilter, the vertical filter for
// av1_convolve_2d_scale_sse4_1. This version only supports 8 tap filters.
//
// src points at the intermediate image written TRANSPOSED by hfilter8, so a
// step of src_stride moves one image column and the 8 taps for one output
// pixel are 8 consecutive int16 values. Results go to dst (8-bit pixels),
// or to conv_params->dst (CONV_BUF_TYPE) when is_compound without averaging.
static void vfilter8(const int16_t *src, int src_stride, uint8_t *dst,
                    int dst_stride, int w, int h, int subpel_y_qn,
                    int y_step_qn, const InterpFilterParams *filter_params,
                    const ConvolveParams *conv_params, int bd) {
  // Precision of the intermediate values entering this (second) round.
  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  const int ntaps = 8;

  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);

  // Combined offset (res_add_const plus hfilter8's bias, both already
  // divided by 2^round_1) that is removed before the final shift.
  const int32_t sub32 = ((1 << (offset_bits - conv_params->round_1)) +
                         (1 << (offset_bits - conv_params->round_1 - 1)));
  // NOTE(review): sub32 is a 32-bit value broadcast into 16-bit lanes,
  // keeping only its low 16 bits — presumably in range for the bd==8 path
  // (the highbd variant uses _mm_set1_epi32 instead); confirm value ranges.
  const __m128i sub = _mm_set1_epi16(sub32);

  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int dst16_stride = conv_params->dst_stride;
  // Final shift taking the twice-filtered value back to pixel precision.
  const int bits =
      FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
  const __m128i bits_shift = _mm_cvtsi32_si128(bits);
  const __m128i bits_const = _mm_set1_epi16(((1 << bits) >> 1));
  // Rounding bias for the round_1 shift.
  const __m128i round_shift_add =
      _mm_set1_epi32(((1 << conv_params->round_1) >> 1));
  // Bias keeping intermediate values non-negative for the unsigned packs.
  const __m128i res_add_const = _mm_set1_epi32(1 << offset_bits);

  // Distance-weighted compound weights, interleaved so a single
  // _mm_madd_epi16 computes w0*prev + w1*cur per output lane.
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m128i wt0 = _mm_set1_epi16((short)w0);
  const __m128i wt1 = _mm_set1_epi16((short)w1);
  const __m128i wt = _mm_unpacklo_epi16(wt0, wt1);

  int y_qn = subpel_y_qn;
  for (int y = 0; y < h; ++y, y_qn += y_step_qn) {
    // Integer source row and subpel phase for this output row.
    const int16_t *src_y = src + (y_qn >> SCALE_SUBPEL_BITS);
    const int filter_idx = (y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
    assert(filter_idx < SUBPEL_SHIFTS);
    const int16_t *filter =
        av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);

    const __m128i coeff0716 = _mm_loadu_si128((__m128i *)filter);
    int x;
    // Vector path: four output pixels per iteration.
    for (x = 0; x <= w - 4; x += 4) {
      const int16_t *const src0 = src_y + x * src_stride;
      const int16_t *const src1 = src0 + 1 * src_stride;
      const int16_t *const src2 = src0 + 2 * src_stride;
      const int16_t *const src3 = src0 + 3 * src_stride;

      // Load the source data for the four columns, keeping one register of
      // convolved products per column (conv0..conv3) to avoid the
      // register pressure getting too high.
      const __m128i conv0 = convolve_16_8(src0, coeff0716);
      const __m128i conv1 = convolve_16_8(src1, coeff0716);
      const __m128i conv2 = convolve_16_8(src2, coeff0716);
      const __m128i conv3 = convolve_16_8(src3, coeff0716);

      // Now reduce horizontally to get one lane for each result
      const __m128i conv01 = _mm_hadd_epi32(conv0, conv1);
      const __m128i conv23 = _mm_hadd_epi32(conv2, conv3);
      __m128i conv = _mm_hadd_epi32(conv01, conv23);

      conv = _mm_add_epi32(conv, res_add_const);
      // Divide down by (1 << round_1), rounding to nearest and subtract sub32.
      __m128i shifted =
          _mm_sra_epi32(_mm_add_epi32(conv, round_shift_add), round_shift);

      uint8_t *dst_x = dst + y * dst_stride + x;
      __m128i result;
      // Narrow to 16 bits with unsigned saturation; values still carry the
      // res_add_const bias at this point.
      __m128i shifted_16 = _mm_packus_epi32(shifted, shifted);

      if (conv_params->is_compound) {
        CONV_BUF_TYPE *dst_16_x = dst16 + y * dst16_stride + x;
        if (conv_params->do_average) {
          // Second compound pass: blend with the first prediction that the
          // first pass stored in the CONV_BUF_TYPE buffer.
          const __m128i p_16 = _mm_loadl_epi64((__m128i *)dst_16_x);
          if (conv_params->use_dist_wtd_comp_avg) {
            // (w0 * prev + w1 * cur) >> DIST_PRECISION_BITS
            const __m128i p_16_lo = _mm_unpacklo_epi16(p_16, shifted_16);
            const __m128i wt_res_lo = _mm_madd_epi16(p_16_lo, wt);
            const __m128i shifted_32 =
                _mm_srai_epi32(wt_res_lo, DIST_PRECISION_BITS);
            shifted_16 = _mm_packus_epi32(shifted_32, shifted_32);
          } else {
            // Plain average of the two predictions.
            shifted_16 = _mm_srai_epi16(_mm_add_epi16(p_16, shifted_16), 1);
          }
          // Remove the bias, round, shift back to pixel precision, and pack
          // to 8 bits (packus clips to [0, 255]).
          const __m128i subbed = _mm_sub_epi16(shifted_16, sub);
          result = _mm_sra_epi16(_mm_add_epi16(subbed, bits_const), bits_shift);
          const __m128i result_8 = _mm_packus_epi16(result, result);
          // Store four packed output bytes at once.
          *(int *)dst_x = _mm_cvtsi128_si32(result_8);
        } else {
          // First compound pass: store biased 16-bit values for averaging.
          _mm_storel_epi64((__m128i *)dst_16_x, shifted_16);
        }
      } else {
        // Single prediction: remove bias, round, shift, pack to 8 bits.
        const __m128i subbed = _mm_sub_epi16(shifted_16, sub);
        result = _mm_sra_epi16(_mm_add_epi16(subbed, bits_const), bits_shift);
        const __m128i result_8 = _mm_packus_epi16(result, result);
        *(int *)dst_x = _mm_cvtsi128_si32(result_8);
      }
    }
    // Scalar tail for widths that are not a multiple of 4; mirrors the
    // vector path.
    for (; x < w; ++x) {
      const int16_t *src_x = src_y + x * src_stride;
      int32_t sum = 1 << offset_bits;
      for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
      CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);

      if (conv_params->is_compound) {
        if (conv_params->do_average) {
          int32_t tmp = dst16[y * dst16_stride + x];
          if (conv_params->use_dist_wtd_comp_avg) {
            tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
            tmp = tmp >> DIST_PRECISION_BITS;
          } else {
            tmp += res;
            tmp = tmp >> 1;
          }
          /* Subtract round offset and convolve round */
          tmp = tmp - sub32;
          dst[y * dst_stride + x] = clip_pixel(ROUND_POWER_OF_TWO(tmp, bits));
        } else {
          dst16[y * dst16_stride + x] = res;
        }
      } else {
        /* Subtract round offset and convolve round */
        int32_t tmp = res - ((1 << (offset_bits - conv_params->round_1)) +
                             (1 << (offset_bits - conv_params->round_1 - 1)));
        dst[y * dst_stride + x] = clip_pixel(ROUND_POWER_OF_TWO(tmp, bits));
      }
    }
  }
}
    232 void av1_convolve_2d_scale_sse4_1(const uint8_t *src, int src_stride,
    233                                  uint8_t *dst8, int dst8_stride, int w, int h,
    234                                  const InterpFilterParams *filter_params_x,
    235                                  const InterpFilterParams *filter_params_y,
    236                                  const int subpel_x_qn, const int x_step_qn,
    237                                  const int subpel_y_qn, const int y_step_qn,
    238                                  ConvolveParams *conv_params) {
    239  int16_t tmp[(2 * MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE];
    240  int im_h = (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) +
    241             filter_params_y->taps;
    242 
    243  const int xtaps = filter_params_x->taps;
    244  const int ytaps = filter_params_y->taps;
    245  const int fo_vert = ytaps / 2 - 1;
    246  assert((xtaps == 8) && (ytaps == 8));
    247  (void)xtaps;
    248 
    249  // horizontal filter
    250  hfilter8(src - fo_vert * src_stride, src_stride, tmp, w, im_h, subpel_x_qn,
    251           x_step_qn, filter_params_x, conv_params->round_0);
    252 
    253  // vertical filter (input is transposed)
    254  vfilter8(tmp, im_h, dst8, dst8_stride, w, h, subpel_y_qn, y_step_qn,
    255           filter_params_y, conv_params, 8);
    256 }
    257 
    258 #if CONFIG_AV1_HIGHBITDEPTH
    259 // A specialised version of hfilter, the horizontal filter for
    260 // av1_highbd_convolve_2d_scale_sse4_1. This version only supports 8 tap
    261 // filters.
    262 static void highbd_hfilter8(const uint16_t *src, int src_stride, int16_t *dst,
    263                            int w, int h, int subpel_x_qn, int x_step_qn,
    264                            const InterpFilterParams *filter_params, int round,
    265                            int bd) {
    266  const int ntaps = 8;
    267 
    268  src -= ntaps / 2 - 1;
    269 
    270  int32_t round_add32 = (1 << round) / 2 + (1 << (bd + FILTER_BITS - 1));
    271  const __m128i round_add = _mm_set1_epi32(round_add32);
    272  const __m128i round_shift = _mm_cvtsi32_si128(round);
    273 
    274  int x_qn = subpel_x_qn;
    275  for (int x = 0; x < w; ++x, x_qn += x_step_qn) {
    276    const uint16_t *const src_col = src + (x_qn >> SCALE_SUBPEL_BITS);
    277    const int filter_idx = (x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
    278    assert(filter_idx < SUBPEL_SHIFTS);
    279    const int16_t *filter =
    280        av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);
    281 
    282    // Load the filter coefficients
    283    const __m128i coefflo = _mm_loadu_si128((__m128i *)filter);
    284 
    285    int y;
    286    for (y = 0; y <= h - 4; y += 4) {
    287      const uint16_t *const src0 = src_col + y * src_stride;
    288      const uint16_t *const src1 = src0 + 1 * src_stride;
    289      const uint16_t *const src2 = src0 + 2 * src_stride;
    290      const uint16_t *const src3 = src0 + 3 * src_stride;
    291 
    292      // Load up source data. This is 16-bit input data, so each load gets the 8
    293      // pixels we need.
    294      const __m128i data0lo = _mm_loadu_si128((__m128i *)src0);
    295      const __m128i data1lo = _mm_loadu_si128((__m128i *)src1);
    296      const __m128i data2lo = _mm_loadu_si128((__m128i *)src2);
    297      const __m128i data3lo = _mm_loadu_si128((__m128i *)src3);
    298 
    299      // Multiply by coefficients
    300      const __m128i conv0lo = _mm_madd_epi16(data0lo, coefflo);
    301      const __m128i conv1lo = _mm_madd_epi16(data1lo, coefflo);
    302      const __m128i conv2lo = _mm_madd_epi16(data2lo, coefflo);
    303      const __m128i conv3lo = _mm_madd_epi16(data3lo, coefflo);
    304 
    305      // Reduce horizontally and add
    306      const __m128i conv01lo = _mm_hadd_epi32(conv0lo, conv1lo);
    307      const __m128i conv23lo = _mm_hadd_epi32(conv2lo, conv3lo);
    308      const __m128i conv = _mm_hadd_epi32(conv01lo, conv23lo);
    309 
    310      // Divide down by (1 << round), rounding to nearest.
    311      __m128i shifted =
    312          _mm_sra_epi32(_mm_add_epi32(conv, round_add), round_shift);
    313 
    314      shifted = _mm_packus_epi32(shifted, shifted);
    315      // Write transposed to the output
    316      _mm_storel_epi64((__m128i *)(dst + y + x * h), shifted);
    317    }
    318    for (; y < h; ++y) {
    319      const uint16_t *const src_row = src_col + y * src_stride;
    320 
    321      int32_t sum = (1 << (bd + FILTER_BITS - 1));
    322      for (int k = 0; k < ntaps; ++k) {
    323        sum += filter[k] * src_row[k];
    324      }
    325 
    326      dst[y + x * h] = ROUND_POWER_OF_TWO(sum, round);
    327    }
    328  }
    329 }
// A specialised version of vfilter, the vertical filter for
// av1_highbd_convolve_2d_scale_sse4_1. This version only supports 8 tap
// filters.
//
// src points at the intermediate image written TRANSPOSED by
// highbd_hfilter8, so a step of src_stride moves one image column and the 8
// taps for one output pixel are 8 consecutive int16 values. Results go to
// dst (bd-bit pixels), or to conv_params->dst (CONV_BUF_TYPE) when
// is_compound without averaging.
static void highbd_vfilter8(const int16_t *src, int src_stride, uint16_t *dst,
                           int dst_stride, int w, int h, int subpel_y_qn,
                           int y_step_qn,
                           const InterpFilterParams *filter_params,
                           const ConvolveParams *conv_params, int bd) {
  // Precision of the intermediate values entering this (second) round.
  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  const int ntaps = 8;

  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);

  // Combined offset (res_add_const plus the horizontal pass's bias, both
  // already divided by 2^round_1) removed again before the final shift.
  const int32_t sub32 = ((1 << (offset_bits - conv_params->round_1)) +
                         (1 << (offset_bits - conv_params->round_1 - 1)));
  const __m128i sub = _mm_set1_epi32(sub32);

  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int dst16_stride = conv_params->dst_stride;
  // Max pixel value for the target bit depth (used to clamp after packus).
  const __m128i clip_pixel_ =
      _mm_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));
  // Final shift taking the twice-filtered value back to pixel precision.
  const int bits =
      FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
  const __m128i bits_shift = _mm_cvtsi32_si128(bits);
  const __m128i bits_const = _mm_set1_epi32(((1 << bits) >> 1));
  // Rounding bias for the round_1 shift.
  const __m128i round_shift_add =
      _mm_set1_epi32(((1 << conv_params->round_1) >> 1));
  // Bias keeping intermediate values non-negative for the unsigned packs.
  const __m128i res_add_const = _mm_set1_epi32(1 << offset_bits);
  // NOTE(review): round_bits is numerically identical to 'bits' above; the
  // two shift/bias pairs could be unified.
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  __m128i round_bits_shift = _mm_cvtsi32_si128(round_bits);
  __m128i round_bits_const = _mm_set1_epi32(((1 << round_bits) >> 1));

  // Distance-weighted compound weights in 32-bit lanes (applied with
  // _mm_mullo_epi32 below).
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m128i wt0 = _mm_set1_epi32(w0);
  const __m128i wt1 = _mm_set1_epi32(w1);

  int y_qn = subpel_y_qn;
  for (int y = 0; y < h; ++y, y_qn += y_step_qn) {
    // Integer source row and subpel phase for this output row.
    const int16_t *src_y = src + (y_qn >> SCALE_SUBPEL_BITS);
    const int filter_idx = (y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
    assert(filter_idx < SUBPEL_SHIFTS);
    const int16_t *filter =
        av1_get_interp_filter_subpel_kernel(filter_params, filter_idx);

    const __m128i coeff0716 = _mm_loadu_si128((__m128i *)filter);
    int x;
    // Vector path: four output pixels per iteration.
    for (x = 0; x <= w - 4; x += 4) {
      const int16_t *const src0 = src_y + x * src_stride;
      const int16_t *const src1 = src0 + 1 * src_stride;
      const int16_t *const src2 = src0 + 2 * src_stride;
      const int16_t *const src3 = src0 + 3 * src_stride;

      // Load the source data for the four columns, keeping one register of
      // convolved products per column (conv0..conv3) to avoid the
      // register pressure getting too high.
      const __m128i conv0 = convolve_16_8(src0, coeff0716);
      const __m128i conv1 = convolve_16_8(src1, coeff0716);
      const __m128i conv2 = convolve_16_8(src2, coeff0716);
      const __m128i conv3 = convolve_16_8(src3, coeff0716);

      // Now reduce horizontally to get one lane for each result
      const __m128i conv01 = _mm_hadd_epi32(conv0, conv1);
      const __m128i conv23 = _mm_hadd_epi32(conv2, conv3);
      __m128i conv = _mm_hadd_epi32(conv01, conv23);
      conv = _mm_add_epi32(conv, res_add_const);

      // Divide down by (1 << round_1), rounding to nearest and subtract sub32.
      __m128i shifted =
          _mm_sra_epi32(_mm_add_epi32(conv, round_shift_add), round_shift);

      uint16_t *dst_x = dst + y * dst_stride + x;

      __m128i result;
      if (conv_params->is_compound) {
        CONV_BUF_TYPE *dst_16_x = dst16 + y * dst16_stride + x;
        if (conv_params->do_average) {
          // Second compound pass: blend with the first prediction stored in
          // the CONV_BUF_TYPE buffer, widened to 32 bits.
          __m128i p_32 =
              _mm_cvtepu16_epi32(_mm_loadl_epi64((__m128i *)dst_16_x));

          if (conv_params->use_dist_wtd_comp_avg) {
            // (w0 * prev + w1 * cur) >> DIST_PRECISION_BITS
            shifted = _mm_add_epi32(_mm_mullo_epi32(p_32, wt0),
                                    _mm_mullo_epi32(shifted, wt1));
            shifted = _mm_srai_epi32(shifted, DIST_PRECISION_BITS);
          } else {
            // Plain average of the two predictions.
            shifted = _mm_srai_epi32(_mm_add_epi32(p_32, shifted), 1);
          }
          // Remove bias, round, shift to pixel precision, pack to 16 bits
          // and clamp to the bit-depth maximum.
          result = _mm_sub_epi32(shifted, sub);
          result = _mm_sra_epi32(_mm_add_epi32(result, round_bits_const),
                                 round_bits_shift);

          result = _mm_packus_epi32(result, result);
          result = _mm_min_epi16(result, clip_pixel_);
          _mm_storel_epi64((__m128i *)dst_x, result);
        } else {
          // First compound pass: store biased 16-bit values for averaging.
          __m128i shifted_16 = _mm_packus_epi32(shifted, shifted);
          _mm_storel_epi64((__m128i *)dst_16_x, shifted_16);
        }
      } else {
        // Single prediction: remove bias, round, shift, pack and clamp.
        result = _mm_sub_epi32(shifted, sub);
        // NOTE(review): _mm_sra_epi16 applied to 32-bit lanes shifts each
        // 16-bit half independently (the compound branch above uses a true
        // 32-bit shift for the same step). Presumably harmless because
        // out-of-range lanes are clipped by packus/min below — confirm
        // against the scalar tail, which shifts in 32 bits.
        result = _mm_sra_epi16(_mm_add_epi32(result, bits_const), bits_shift);
        result = _mm_packus_epi32(result, result);
        result = _mm_min_epi16(result, clip_pixel_);
        _mm_storel_epi64((__m128i *)dst_x, result);
      }
    }

    // Scalar tail for widths that are not a multiple of 4; mirrors the
    // vector path using 32-bit arithmetic throughout.
    for (; x < w; ++x) {
      const int16_t *src_x = src_y + x * src_stride;
      int32_t sum = 1 << offset_bits;
      for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
      CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
      if (conv_params->is_compound) {
        if (conv_params->do_average) {
          int32_t tmp = dst16[y * dst16_stride + x];
          if (conv_params->use_dist_wtd_comp_avg) {
            tmp = tmp * conv_params->fwd_offset + res * conv_params->bck_offset;
            tmp = tmp >> DIST_PRECISION_BITS;
          } else {
            tmp += res;
            tmp = tmp >> 1;
          }
          /* Subtract round offset and convolve round */
          tmp = tmp - ((1 << (offset_bits - conv_params->round_1)) +
                       (1 << (offset_bits - conv_params->round_1 - 1)));
          dst[y * dst_stride + x] =
              clip_pixel_highbd(ROUND_POWER_OF_TWO(tmp, bits), bd);
        } else {
          dst16[y * dst16_stride + x] = res;
        }
      } else {
        /* Subtract round offset and convolve round */
        int32_t tmp = res - ((1 << (offset_bits - conv_params->round_1)) +
                             (1 << (offset_bits - conv_params->round_1 - 1)));
        dst[y * dst_stride + x] =
            clip_pixel_highbd(ROUND_POWER_OF_TWO(tmp, bits), bd);
      }
    }
  }
}
    471 
    472 void av1_highbd_convolve_2d_scale_sse4_1(
    473    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    474    int h, const InterpFilterParams *filter_params_x,
    475    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    476    const int x_step_qn, const int subpel_y_qn, const int y_step_qn,
    477    ConvolveParams *conv_params, int bd) {
    478  // TODO(yaowu): Move this out of stack
    479  DECLARE_ALIGNED(16, int16_t,
    480                  tmp[(2 * MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
    481  int im_h = (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) +
    482             filter_params_y->taps;
    483  const int xtaps = filter_params_x->taps;
    484  const int ytaps = filter_params_y->taps;
    485  const int fo_vert = ytaps / 2 - 1;
    486 
    487  memset(tmp, 0, sizeof(tmp));
    488  assert((xtaps == 8) && (ytaps == 8));
    489  (void)xtaps;
    490 
    491  // horizontal filter
    492  highbd_hfilter8(src - fo_vert * src_stride, src_stride, tmp, w, im_h,
    493                  subpel_x_qn, x_step_qn, filter_params_x, conv_params->round_0,
    494                  bd);
    495 
    496  // vertical filter (input is transposed)
    497  highbd_vfilter8(tmp, im_h, dst, dst_stride, w, h, subpel_y_qn, y_step_qn,
    498                  filter_params_y, conv_params, bd);
    499 }
    500 #endif  // CONFIG_AV1_HIGHBITDEPTH