tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

blk_sse_sum_avx2.c (7480B)


      1 /*
      2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <immintrin.h>
     13 
     14 #include "config/aom_dsp_rtcd.h"
     15 
     16 static inline void accumulate_sse_sum(__m256i regx_sum, __m256i regx2_sum,
     17                                      int *x_sum, int64_t *x2_sum) {
     18  __m256i sum_buffer, sse_buffer;
     19  __m128i out_buffer;
     20 
     21  // Accumulate the various elements of register into first element.
     22  sum_buffer = _mm256_permute2f128_si256(regx_sum, regx_sum, 1);
     23  regx_sum = _mm256_add_epi32(sum_buffer, regx_sum);
     24  regx_sum = _mm256_add_epi32(regx_sum, _mm256_srli_si256(regx_sum, 8));
     25  regx_sum = _mm256_add_epi32(regx_sum, _mm256_srli_si256(regx_sum, 4));
     26 
     27  sse_buffer = _mm256_permute2f128_si256(regx2_sum, regx2_sum, 1);
     28  regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum);
     29  regx2_sum = _mm256_add_epi64(regx2_sum, _mm256_srli_si256(regx2_sum, 8));
     30 
     31  out_buffer = _mm256_castsi256_si128(regx_sum);
     32  *x_sum += _mm_cvtsi128_si32(out_buffer);
     33  out_buffer = _mm256_castsi256_si128(regx2_sum);
     34 #if AOM_ARCH_X86_64
     35  *x2_sum += _mm_cvtsi128_si64(out_buffer);
     36 #else
     37  {
     38    int64_t tmp;
     39    _mm_storel_epi64((__m128i *)&tmp, out_buffer);
     40    *x2_sum += tmp;
     41  }
     42 #endif
     43 }
     44 
     45 static inline void sse_sum_wd4_avx2(const int16_t *data, int stride, int bh,
     46                                    int *x_sum, int64_t *x2_sum) {
     47  __m128i row1, row2, row3;
     48  __m256i regx_sum, regx2_sum, load_pixels, sum_buffer, sse_buffer,
     49      temp_buffer1, temp_buffer2, row_sum_buffer, row_sse_buffer;
     50  const int16_t *data_tmp = data;
     51  __m256i one = _mm256_set1_epi16(1);
     52  regx_sum = _mm256_setzero_si256();
     53  regx2_sum = regx_sum;
     54  sum_buffer = _mm256_setzero_si256();
     55  sse_buffer = sum_buffer;
     56 
     57  for (int j = 0; j < (bh >> 2); ++j) {
     58    // Load 4 rows at a time.
     59    row1 = _mm_loadl_epi64((__m128i const *)(data_tmp));
     60    row2 = _mm_loadl_epi64((__m128i const *)(data_tmp + stride));
     61    row1 = _mm_unpacklo_epi64(row1, row2);
     62    row2 = _mm_loadl_epi64((__m128i const *)(data_tmp + 2 * stride));
     63    row3 = _mm_loadl_epi64((__m128i const *)(data_tmp + 3 * stride));
     64    row2 = _mm_unpacklo_epi64(row2, row3);
     65    load_pixels =
     66        _mm256_insertf128_si256(_mm256_castsi128_si256(row1), row2, 1);
     67 
     68    row_sum_buffer = _mm256_madd_epi16(load_pixels, one);
     69    row_sse_buffer = _mm256_madd_epi16(load_pixels, load_pixels);
     70    sum_buffer = _mm256_add_epi32(row_sum_buffer, sum_buffer);
     71    sse_buffer = _mm256_add_epi32(row_sse_buffer, sse_buffer);
     72    data_tmp += 4 * stride;
     73  }
     74 
     75  // To prevent 32-bit variable overflow, unpack the elements to 64-bit.
     76  temp_buffer1 = _mm256_unpacklo_epi32(sse_buffer, _mm256_setzero_si256());
     77  temp_buffer2 = _mm256_unpackhi_epi32(sse_buffer, _mm256_setzero_si256());
     78  sse_buffer = _mm256_add_epi64(temp_buffer1, temp_buffer2);
     79  regx_sum = _mm256_add_epi32(sum_buffer, regx_sum);
     80  regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum);
     81 
     82  accumulate_sse_sum(regx_sum, regx2_sum, x_sum, x2_sum);
     83 }
     84 
     85 static inline void sse_sum_wd8_avx2(const int16_t *data, int stride, int bh,
     86                                    int *x_sum, int64_t *x2_sum) {
     87  __m128i load_128bit, load_next_128bit;
     88  __m256i regx_sum, regx2_sum, load_pixels, sum_buffer, sse_buffer,
     89      temp_buffer1, temp_buffer2, row_sum_buffer, row_sse_buffer;
     90  const int16_t *data_tmp = data;
     91  __m256i one = _mm256_set1_epi16(1);
     92  regx_sum = _mm256_setzero_si256();
     93  regx2_sum = regx_sum;
     94  sum_buffer = _mm256_setzero_si256();
     95  sse_buffer = sum_buffer;
     96 
     97  for (int j = 0; j < (bh >> 1); ++j) {
     98    // Load 2 rows at a time.
     99    load_128bit = _mm_loadu_si128((__m128i const *)(data_tmp));
    100    load_next_128bit = _mm_loadu_si128((__m128i const *)(data_tmp + stride));
    101    load_pixels = _mm256_insertf128_si256(_mm256_castsi128_si256(load_128bit),
    102                                          load_next_128bit, 1);
    103 
    104    row_sum_buffer = _mm256_madd_epi16(load_pixels, one);
    105    row_sse_buffer = _mm256_madd_epi16(load_pixels, load_pixels);
    106    sum_buffer = _mm256_add_epi32(row_sum_buffer, sum_buffer);
    107    sse_buffer = _mm256_add_epi32(row_sse_buffer, sse_buffer);
    108    data_tmp += 2 * stride;
    109  }
    110 
    111  temp_buffer1 = _mm256_unpacklo_epi32(sse_buffer, _mm256_setzero_si256());
    112  temp_buffer2 = _mm256_unpackhi_epi32(sse_buffer, _mm256_setzero_si256());
    113  sse_buffer = _mm256_add_epi64(temp_buffer1, temp_buffer2);
    114  regx_sum = _mm256_add_epi32(sum_buffer, regx_sum);
    115  regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum);
    116 
    117  accumulate_sse_sum(regx_sum, regx2_sum, x_sum, x2_sum);
    118 }
    119 
    120 static inline void sse_sum_wd16_avx2(const int16_t *data, int stride, int bh,
    121                                     int *x_sum, int64_t *x2_sum,
    122                                     int loop_count) {
    123  __m256i regx_sum, regx2_sum, load_pixels, sum_buffer, sse_buffer,
    124      temp_buffer1, temp_buffer2, row_sum_buffer, row_sse_buffer;
    125  const int16_t *data_tmp = data;
    126  __m256i one = _mm256_set1_epi16(1);
    127  regx_sum = _mm256_setzero_si256();
    128  regx2_sum = regx_sum;
    129  sum_buffer = _mm256_setzero_si256();
    130  sse_buffer = sum_buffer;
    131 
    132  for (int i = 0; i < loop_count; ++i) {
    133    data_tmp = data + 16 * i;
    134    for (int j = 0; j < bh; ++j) {
    135      load_pixels = _mm256_lddqu_si256((__m256i const *)(data_tmp));
    136 
    137      row_sum_buffer = _mm256_madd_epi16(load_pixels, one);
    138      row_sse_buffer = _mm256_madd_epi16(load_pixels, load_pixels);
    139      sum_buffer = _mm256_add_epi32(row_sum_buffer, sum_buffer);
    140      sse_buffer = _mm256_add_epi32(row_sse_buffer, sse_buffer);
    141      data_tmp += stride;
    142    }
    143  }
    144 
    145  temp_buffer1 = _mm256_unpacklo_epi32(sse_buffer, _mm256_setzero_si256());
    146  temp_buffer2 = _mm256_unpackhi_epi32(sse_buffer, _mm256_setzero_si256());
    147  sse_buffer = _mm256_add_epi64(temp_buffer1, temp_buffer2);
    148  regx_sum = _mm256_add_epi32(sum_buffer, regx_sum);
    149  regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum);
    150 
    151  accumulate_sse_sum(regx_sum, regx2_sum, x_sum, x2_sum);
    152 }
    153 
    154 void aom_get_blk_sse_sum_avx2(const int16_t *data, int stride, int bw, int bh,
    155                              int *x_sum, int64_t *x2_sum) {
    156  *x_sum = 0;
    157  *x2_sum = 0;
    158 
    159  if ((bh & 3) == 0) {
    160    switch (bw) {
    161        // For smaller block widths, compute multiple rows simultaneously.
    162      case 4: sse_sum_wd4_avx2(data, stride, bh, x_sum, x2_sum); break;
    163      case 8: sse_sum_wd8_avx2(data, stride, bh, x_sum, x2_sum); break;
    164      case 16:
    165      case 32:
    166        sse_sum_wd16_avx2(data, stride, bh, x_sum, x2_sum, bw >> 4);
    167        break;
    168      case 64:
    169        // 32-bit variables will overflow for 64 rows at a single time, so
    170        // compute 32 rows at a time.
    171        if (bh <= 32) {
    172          sse_sum_wd16_avx2(data, stride, bh, x_sum, x2_sum, bw >> 4);
    173        } else {
    174          sse_sum_wd16_avx2(data, stride, 32, x_sum, x2_sum, bw >> 4);
    175          sse_sum_wd16_avx2(data + 32 * stride, stride, 32, x_sum, x2_sum,
    176                            bw >> 4);
    177        }
    178        break;
    179 
    180      default: aom_get_blk_sse_sum_c(data, stride, bw, bh, x_sum, x2_sum);
    181    }
    182  } else {
    183    aom_get_blk_sse_sum_c(data, stride, bw, bh, x_sum, x2_sum);
    184  }
    185 }