tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

sum_squares_avx2.c (12712B)


/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>
#include <smmintrin.h>

#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_dsp/x86/sum_squares_sse2.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

static uint64_t aom_sum_squares_2d_i16_nxn_avx2(const int16_t *src, int stride,
                                                int width, int height) {
  uint64_t result;
  __m256i v_acc_q = _mm256_setzero_si256();
  // ~0u widens to 0x00000000ffffffff in each 64-bit lane: a mask that keeps
  // only the low 32 bits, used below to zero-extend the 32-bit partial sums.
  const __m256i v_zext_mask_q = _mm256_set1_epi64x(~0u);
  // NB: the loop names are inherited from upstream; `col` steps down rows
  // (4 at a time) while `row` steps across columns (16 samples at a time).
  for (int col = 0; col < height; col += 4) {
    __m256i v_acc_d = _mm256_setzero_si256();
    for (int row = 0; row < width; row += 16) {
      const int16_t *tempsrc = src + row;
      const __m256i v_val_0_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 0 * stride));
      const __m256i v_val_1_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 1 * stride));
      const __m256i v_val_2_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 2 * stride));
      const __m256i v_val_3_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 3 * stride));

      // madd(v, v) squares each 16-bit sample and sums adjacent pairs of
      // products into 32-bit lanes.
      const __m256i v_sq_0_d = _mm256_madd_epi16(v_val_0_w, v_val_0_w);
      const __m256i v_sq_1_d = _mm256_madd_epi16(v_val_1_w, v_val_1_w);
      const __m256i v_sq_2_d = _mm256_madd_epi16(v_val_2_w, v_val_2_w);
      const __m256i v_sq_3_d = _mm256_madd_epi16(v_val_3_w, v_val_3_w);

      const __m256i v_sum_01_d = _mm256_add_epi32(v_sq_0_d, v_sq_1_d);
      const __m256i v_sum_23_d = _mm256_add_epi32(v_sq_2_d, v_sq_3_d);
      const __m256i v_sum_0123_d = _mm256_add_epi32(v_sum_01_d, v_sum_23_d);

      v_acc_d = _mm256_add_epi32(v_acc_d, v_sum_0123_d);
    }
    // Fold the 32-bit row-group accumulator into the 64-bit total: low halves
    // via the zero-extension mask, high halves via a 32-bit right shift.
    v_acc_q =
        _mm256_add_epi64(v_acc_q, _mm256_and_si256(v_acc_d, v_zext_mask_q));
    v_acc_q = _mm256_add_epi64(v_acc_q, _mm256_srli_epi64(v_acc_d, 32));
    src += 4 * stride;
  }
  __m128i lower_64_2_Value = _mm256_castsi256_si128(v_acc_q);
  __m128i higher_64_2_Value = _mm256_extracti128_si256(v_acc_q, 1);
  __m128i result_64_2_int = _mm_add_epi64(lower_64_2_Value, higher_64_2_Value);

  result_64_2_int = _mm_add_epi64(
      result_64_2_int, _mm_unpackhi_epi64(result_64_2_int, result_64_2_int));

  xx_storel_64(&result, result_64_2_int);

  return result;
}

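/*
 * Editorial note: a minimal scalar sketch of what the kernel above computes
 * (the helper name is illustrative, not part of libaom): the sum of squared
 * 16-bit samples over a width x height block.
 */
static uint64_t sum_squares_2d_i16_scalar_sketch(const int16_t *src,
                                                 int stride, int width,
                                                 int height) {
  uint64_t ss = 0;
  for (int r = 0; r < height; ++r)
    for (int c = 0; c < width; ++c) {
      const int v = src[r * stride + c];  // promoted to int; v*v fits
      ss += (uint64_t)(v * v);
    }
  return ss;
}
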
uint64_t aom_sum_squares_2d_i16_avx2(const int16_t *src, int stride, int width,
                                     int height) {
  if (LIKELY(width == 4 && height == 4)) {
    return aom_sum_squares_2d_i16_4x4_sse2(src, stride);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    return aom_sum_squares_2d_i16_4xn_sse2(src, stride, height);
  } else if (LIKELY(width == 8 && (height & 3) == 0)) {
    return aom_sum_squares_2d_i16_nxn_sse2(src, stride, width, height);
  } else if (LIKELY(((width & 15) == 0) && ((height & 3) == 0))) {
    return aom_sum_squares_2d_i16_nxn_avx2(src, stride, width, height);
  } else {
    return aom_sum_squares_2d_i16_c(src, stride, width, height);
  }
}

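/*
 * Dispatch sketch (hypothetical sizes): a 4x8 block takes the 4xn SSE2 path,
 * an 8x16 block the 8-wide nxn SSE2 path, and a 32x16 block the AVX2 kernel
 * (width % 16 == 0, height % 4 == 0); anything else, e.g. 6x5, falls back to
 * the C reference aom_sum_squares_2d_i16_c. The sum+SSE dispatcher below
 * follows the same pattern.
 */
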
static uint64_t aom_sum_sse_2d_i16_nxn_avx2(const int16_t *src, int stride,
                                            int width, int height, int *sum) {
  uint64_t result;
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i one_reg = _mm256_set1_epi16(1);

  __m256i v_sse_total = zero_reg;
  __m256i v_sum_total = zero_reg;

  for (int col = 0; col < height; col += 4) {
    __m256i v_sse_row = zero_reg;
    for (int row = 0; row < width; row += 16) {
      const int16_t *tempsrc = src + row;
      const __m256i v_val_0_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 0 * stride));
      const __m256i v_val_1_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 1 * stride));
      const __m256i v_val_2_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 2 * stride));
      const __m256i v_val_3_w =
          _mm256_loadu_si256((const __m256i *)(tempsrc + 3 * stride));

      // Sum the four rows in 16 bits, then madd against a vector of ones to
      // widen the per-lane sums into 32-bit lanes.
      const __m256i v_sum_01 = _mm256_add_epi16(v_val_0_w, v_val_1_w);
      const __m256i v_sum_23 = _mm256_add_epi16(v_val_2_w, v_val_3_w);
      __m256i v_sum_0123 = _mm256_add_epi16(v_sum_01, v_sum_23);
      v_sum_0123 = _mm256_madd_epi16(v_sum_0123, one_reg);
      v_sum_total = _mm256_add_epi32(v_sum_total, v_sum_0123);

      const __m256i v_sq_0_d = _mm256_madd_epi16(v_val_0_w, v_val_0_w);
      const __m256i v_sq_1_d = _mm256_madd_epi16(v_val_1_w, v_val_1_w);
      const __m256i v_sq_2_d = _mm256_madd_epi16(v_val_2_w, v_val_2_w);
      const __m256i v_sq_3_d = _mm256_madd_epi16(v_val_3_w, v_val_3_w);
      const __m256i v_sq_01_d = _mm256_add_epi32(v_sq_0_d, v_sq_1_d);
      const __m256i v_sq_23_d = _mm256_add_epi32(v_sq_2_d, v_sq_3_d);
      const __m256i v_sq_0123_d = _mm256_add_epi32(v_sq_01_d, v_sq_23_d);
      v_sse_row = _mm256_add_epi32(v_sse_row, v_sq_0123_d);
    }
    // Widen the 32-bit SSE accumulator to 64-bit lanes by interleaving with
    // zero before adding it to the running total.
    const __m256i v_sse_row_low = _mm256_unpacklo_epi32(v_sse_row, zero_reg);
    const __m256i v_sse_row_hi = _mm256_unpackhi_epi32(v_sse_row, zero_reg);
    v_sse_row = _mm256_add_epi64(v_sse_row_low, v_sse_row_hi);
    v_sse_total = _mm256_add_epi64(v_sse_total, v_sse_row);
    src += 4 * stride;
  }

  const __m128i v_sum_total_low = _mm256_castsi256_si128(v_sum_total);
  const __m128i v_sum_total_hi = _mm256_extracti128_si256(v_sum_total, 1);
  __m128i sum_128bit = _mm_add_epi32(v_sum_total_hi, v_sum_total_low);
  sum_128bit = _mm_add_epi32(sum_128bit, _mm_srli_si128(sum_128bit, 8));
  sum_128bit = _mm_add_epi32(sum_128bit, _mm_srli_si128(sum_128bit, 4));
  *sum += _mm_cvtsi128_si32(sum_128bit);

  __m128i v_sse_total_lo = _mm256_castsi256_si128(v_sse_total);
  __m128i v_sse_total_hi = _mm256_extracti128_si256(v_sse_total, 1);
  __m128i sse_128bit = _mm_add_epi64(v_sse_total_lo, v_sse_total_hi);

  sse_128bit =
      _mm_add_epi64(sse_128bit, _mm_unpackhi_epi64(sse_128bit, sse_128bit));

  xx_storel_64(&result, sse_128bit);

  return result;
}

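/*
 * Editorial note: a scalar equivalent of the kernel above (illustrative name,
 * not part of libaom): it returns the sum of squares and accumulates the
 * plain signed sum of the samples into *sum.
 */
static uint64_t sum_sse_2d_i16_scalar_sketch(const int16_t *src, int stride,
                                             int width, int height, int *sum) {
  uint64_t sse = 0;
  int s = 0;
  for (int r = 0; r < height; ++r)
    for (int c = 0; c < width; ++c) {
      const int v = src[r * stride + c];
      s += v;                    // running signed sum
      sse += (uint64_t)(v * v);  // running sum of squares
    }
  *sum += s;
  return sse;
}
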
uint64_t aom_sum_sse_2d_i16_avx2(const int16_t *src, int src_stride, int width,
                                 int height, int *sum) {
  if (LIKELY(width == 4 && height == 4)) {
    return aom_sum_sse_2d_i16_4x4_sse2(src, src_stride, sum);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    return aom_sum_sse_2d_i16_4xn_sse2(src, src_stride, height, sum);
  } else if (LIKELY(width == 8 && (height & 3) == 0)) {
    return aom_sum_sse_2d_i16_nxn_sse2(src, src_stride, width, height, sum);
  } else if (LIKELY(((width & 15) == 0) && ((height & 3) == 0))) {
    return aom_sum_sse_2d_i16_nxn_avx2(src, src_stride, width, height, sum);
  } else {
    return aom_sum_sse_2d_i16_c(src, src_stride, width, height, sum);
  }
}

// Accumulate sum of 16-bit elements in the vector
static inline int32_t mm256_accumulate_epi16(__m256i vec_a) {
  __m128i vtmp1 = _mm256_extracti128_si256(vec_a, 1);
  __m128i vtmp2 = _mm256_castsi256_si128(vec_a);
  vtmp1 = _mm_add_epi16(vtmp1, vtmp2);
  vtmp2 = _mm_srli_si128(vtmp1, 8);
  vtmp1 = _mm_add_epi16(vtmp1, vtmp2);
  vtmp2 = _mm_srli_si128(vtmp1, 4);
  vtmp1 = _mm_add_epi16(vtmp1, vtmp2);
  vtmp2 = _mm_srli_si128(vtmp1, 2);
  vtmp1 = _mm_add_epi16(vtmp1, vtmp2);
  return _mm_extract_epi16(vtmp1, 0);
}

// Accumulate sum of 32-bit elements in the vector
static inline int32_t mm256_accumulate_epi32(__m256i vec_a) {
  __m128i vtmp1 = _mm256_extracti128_si256(vec_a, 1);
  __m128i vtmp2 = _mm256_castsi256_si128(vec_a);
  vtmp1 = _mm_add_epi32(vtmp1, vtmp2);
  vtmp2 = _mm_srli_si128(vtmp1, 8);
  vtmp1 = _mm_add_epi32(vtmp1, vtmp2);
  vtmp2 = _mm_srli_si128(vtmp1, 4);
  vtmp1 = _mm_add_epi32(vtmp1, vtmp2);
  return _mm_cvtsi128_si32(vtmp1);
}

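/*
 * Editorial note: both helpers reduce by repeated halving -- add the upper
 * 128-bit half onto the lower, then fold the survivors with byte shifts of
 * 8, 4 (and, for epi16, 2) -- so lane 0 holds the grand total after
 * log2(lanes) adds. The epi16 variant stays exact in this file because its
 * callers flush the accumulator every 8 rows: each 16-bit lane then holds at
 * most 16 * 255 = 4080, the 16-lane total (at most 65280) still fits in
 * 16 unsigned bits, and _mm_extract_epi16 zero-extends it intact.
 */
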
uint64_t aom_var_2d_u8_avx2(uint8_t *src, int src_stride, int width,
                            int height) {
  uint8_t *srcp;
  uint64_t s = 0, ss = 0;
  __m256i vzero = _mm256_setzero_si256();
  __m256i v_acc_sum = vzero;
  __m256i v_acc_sqs = vzero;
  int i, j;

  // Process 32 elements in a row
  for (i = 0; i < width - 31; i += 32) {
    srcp = src + i;
    // Process 8 rows at a time
    for (j = 0; j < height - 7; j += 8) {
      __m256i vsrc[8];
      for (int k = 0; k < 8; k++) {
        vsrc[k] = _mm256_loadu_si256((__m256i *)srcp);
        srcp += src_stride;
      }
      for (int k = 0; k < 8; k++) {
        __m256i vsrc0 = _mm256_unpacklo_epi8(vsrc[k], vzero);
        __m256i vsrc1 = _mm256_unpackhi_epi8(vsrc[k], vzero);
        v_acc_sum = _mm256_add_epi16(v_acc_sum, vsrc0);
        v_acc_sum = _mm256_add_epi16(v_acc_sum, vsrc1);

        __m256i vsqs0 = _mm256_madd_epi16(vsrc0, vsrc0);
        __m256i vsqs1 = _mm256_madd_epi16(vsrc1, vsrc1);
        v_acc_sqs = _mm256_add_epi32(v_acc_sqs, vsqs0);
        v_acc_sqs = _mm256_add_epi32(v_acc_sqs, vsqs1);
      }

      // Update total sum and clear the vectors
      s += mm256_accumulate_epi16(v_acc_sum);
      ss += mm256_accumulate_epi32(v_acc_sqs);
      v_acc_sum = vzero;
      v_acc_sqs = vzero;
    }

    // Process remaining rows (height not a multiple of 8)
    for (; j < height; j++) {
      __m256i vsrc = _mm256_loadu_si256((__m256i *)srcp);
      __m256i vsrc0 = _mm256_unpacklo_epi8(vsrc, vzero);
      __m256i vsrc1 = _mm256_unpackhi_epi8(vsrc, vzero);
      v_acc_sum = _mm256_add_epi16(v_acc_sum, vsrc0);
      v_acc_sum = _mm256_add_epi16(v_acc_sum, vsrc1);

      __m256i vsqs0 = _mm256_madd_epi16(vsrc0, vsrc0);
      __m256i vsqs1 = _mm256_madd_epi16(vsrc1, vsrc1);
      v_acc_sqs = _mm256_add_epi32(v_acc_sqs, vsqs0);
      v_acc_sqs = _mm256_add_epi32(v_acc_sqs, vsqs1);

      srcp += src_stride;
    }

    // Update total sum and clear the vectors
    s += mm256_accumulate_epi16(v_acc_sum);
    ss += mm256_accumulate_epi32(v_acc_sqs);
    v_acc_sum = vzero;
    v_acc_sqs = vzero;
  }

  // Process the remaining columns (width not a multiple of 32) in scalar C
  srcp = src;
  for (int k = 0; k < height; k++) {
    for (int m = i; m < width; m++) {
      uint8_t val = srcp[m];
      s += val;
      ss += val * val;
    }
    srcp += src_stride;
  }
  return (ss - s * s / (width * height));
}

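/*
 * Editorial note: the return value is N * variance rather than the variance
 * itself: ss - s*s/N = sum(x^2) - (sum(x))^2 / N with N = width * height,
 * i.e. the sum of squared deviations from the mean (up to the integer
 * division). The same identity applies to the high-bitdepth variant below.
 */
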
#if CONFIG_AV1_HIGHBITDEPTH
uint64_t aom_var_2d_u16_avx2(uint8_t *src, int src_stride, int width,
                             int height) {
  uint16_t *srcp1 = CONVERT_TO_SHORTPTR(src), *srcp;
  uint64_t s = 0, ss = 0;
  __m256i vzero = _mm256_setzero_si256();
  __m256i v_acc_sum = vzero;
  __m256i v_acc_sqs = vzero;
  int i, j;

  // Process 16 elements in a row
  for (i = 0; i < width - 15; i += 16) {
    srcp = srcp1 + i;
    // Process 8 rows at a time
    for (j = 0; j < height - 8; j += 8) {
      __m256i vsrc[8];
      for (int k = 0; k < 8; k++) {
        vsrc[k] = _mm256_loadu_si256((__m256i *)srcp);
        srcp += src_stride;
      }
      for (int k = 0; k < 8; k++) {
        __m256i vsrc0 = _mm256_unpacklo_epi16(vsrc[k], vzero);
        __m256i vsrc1 = _mm256_unpackhi_epi16(vsrc[k], vzero);
        v_acc_sum = _mm256_add_epi32(vsrc0, v_acc_sum);
        v_acc_sum = _mm256_add_epi32(vsrc1, v_acc_sum);

        __m256i vsqs0 = _mm256_madd_epi16(vsrc[k], vsrc[k]);
        v_acc_sqs = _mm256_add_epi32(v_acc_sqs, vsqs0);
      }

      // Update total sum and clear the vectors
      s += mm256_accumulate_epi32(v_acc_sum);
      ss += mm256_accumulate_epi32(v_acc_sqs);
      v_acc_sum = vzero;
      v_acc_sqs = vzero;
    }

    // Process remaining rows (height not a multiple of 8)
    for (; j < height; j++) {
      __m256i vsrc = _mm256_loadu_si256((__m256i *)srcp);
      __m256i vsrc0 = _mm256_unpacklo_epi16(vsrc, vzero);
      __m256i vsrc1 = _mm256_unpackhi_epi16(vsrc, vzero);
      v_acc_sum = _mm256_add_epi32(vsrc0, v_acc_sum);
      v_acc_sum = _mm256_add_epi32(vsrc1, v_acc_sum);

      __m256i vsqs0 = _mm256_madd_epi16(vsrc, vsrc);
      v_acc_sqs = _mm256_add_epi32(v_acc_sqs, vsqs0);
      srcp += src_stride;
    }

    // Update total sum and clear the vectors
    s += mm256_accumulate_epi32(v_acc_sum);
    ss += mm256_accumulate_epi32(v_acc_sqs);
    v_acc_sum = vzero;
    v_acc_sqs = vzero;
  }

  // Process the remaining columns (width not a multiple of 16) in scalar C
  srcp = srcp1;
  for (int k = 0; k < height; k++) {
    for (int m = i; m < width; m++) {
      uint16_t val = srcp[m];
      s += val;
      ss += val * val;
    }
    srcp += src_stride;
  }
  return (ss - s * s / (width * height));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH
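
/*
 * Usage sketch (hypothetical buffer and sizes), illustrating the
 * N * variance contract noted above:
 *
 *   uint8_t plane[64 * 64];  // filled by the caller
 *   uint64_t nvar = aom_var_2d_u8_avx2(plane, 64, 64, 64);
 *   double variance = (double)nvar / (64.0 * 64.0);
 */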