tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

rdopt_avx2.c (9735B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <immintrin.h>
     14 #include "aom_dsp/x86/mem_sse2.h"
     15 #include "aom_dsp/x86/synonyms_avx2.h"
     16 
     17 #include "config/av1_rtcd.h"
     18 #include "av1/encoder/rdopt.h"
     19 
     20 // Process horizontal and vertical correlations in a 4x4 block of pixels.
     21 // We actually use the 4x4 pixels to calculate correlations corresponding to
     22 // the top-left 3x3 pixels, so this function must be called with 1x1 overlap,
     23 // moving the window along/down by 3 pixels at a time.
     24 static inline void horver_correlation_4x4(const int16_t *diff, int stride,
     25                                          __m256i *xy_sum_32,
     26                                          __m256i *xz_sum_32, __m256i *x_sum_32,
     27                                          __m256i *x2_sum_32) {
     28  // Pixels in this 4x4   [ a b c d ]
     29  // are referred to as:  [ e f g h ]
     30  //                      [ i j k l ]
     31  //                      [ m n o p ]
     32 
     33  const __m256i pixels = _mm256_set_epi64x(
     34      loadu_int64(&diff[0 * stride]), loadu_int64(&diff[1 * stride]),
     35      loadu_int64(&diff[2 * stride]), loadu_int64(&diff[3 * stride]));
     36  // pixels = [d c b a h g f e] [l k j i p o n m] as i16
     37 
     38  const __m256i slli = _mm256_slli_epi64(pixels, 16);
     39  // slli = [c b a 0 g f e 0] [k j i 0 o n m 0] as i16
     40 
     41  const __m256i madd_xy = _mm256_madd_epi16(pixels, slli);
     42  // madd_xy = [bc+cd ab fg+gh ef] [jk+kl ij no+op mn] as i32
     43  *xy_sum_32 = _mm256_add_epi32(*xy_sum_32, madd_xy);
     44 
     45  // Permute control [3 2] [1 0] => [2 1] [0 0], 0b10010000 = 0x90
     46  const __m256i perm = _mm256_permute4x64_epi64(slli, 0x90);
     47  // perm = [g f e 0 k j i 0] [o n m 0 o n m 0] as i16
     48 
     49  const __m256i madd_xz = _mm256_madd_epi16(slli, perm);
     50  // madd_xz = [cg+bf ae gk+fj ei] [ko+jn im oo+nn mm] as i32
     51  *xz_sum_32 = _mm256_add_epi32(*xz_sum_32, madd_xz);
     52 
     53  // Sum every element in slli (and then also their squares)
     54  const __m256i madd1_slli = _mm256_madd_epi16(slli, _mm256_set1_epi16(1));
     55  // madd1_slli = [c+b a g+f e] [k+j i o+n m] as i32
     56  *x_sum_32 = _mm256_add_epi32(*x_sum_32, madd1_slli);
     57 
     58  const __m256i madd_slli = _mm256_madd_epi16(slli, slli);
     59  // madd_slli = [cc+bb aa gg+ff ee] [kk+jj ii oo+nn mm] as i32
     60  *x2_sum_32 = _mm256_add_epi32(*x2_sum_32, madd_slli);
     61 }
     62 
     63 void av1_get_horver_correlation_full_avx2(const int16_t *diff, int stride,
     64                                          int width, int height, float *hcorr,
     65                                          float *vcorr) {
     66  // The following notation is used:
     67  // x - current pixel
     68  // y - right neighbour pixel
     69  // z - below neighbour pixel
     70  // w - down-right neighbour pixel
     71  int64_t xy_sum = 0, xz_sum = 0;
     72  int64_t x_sum = 0, x2_sum = 0;
     73 
     74  // Process horizontal and vertical correlations through the body in 4x4
     75  // blocks.  This excludes the final row and column and possibly one extra
     76  // column depending how 3 divides into width and height
     77  int32_t xy_xz_tmp[8] = { 0 }, x_x2_tmp[8] = { 0 };
     78  __m256i xy_sum_32 = _mm256_setzero_si256();
     79  __m256i xz_sum_32 = _mm256_setzero_si256();
     80  __m256i x_sum_32 = _mm256_setzero_si256();
     81  __m256i x2_sum_32 = _mm256_setzero_si256();
     82  for (int i = 0; i <= height - 4; i += 3) {
     83    for (int j = 0; j <= width - 4; j += 3) {
     84      horver_correlation_4x4(&diff[i * stride + j], stride, &xy_sum_32,
     85                             &xz_sum_32, &x_sum_32, &x2_sum_32);
     86    }
     87    const __m256i hadd_xy_xz = _mm256_hadd_epi32(xy_sum_32, xz_sum_32);
     88    // hadd_xy_xz = [ae+bf+cg ei+fj+gk ab+bc+cd ef+fg+gh]
     89    //              [im+jn+ko mm+nn+oo ij+jk+kl mn+no+op] as i32
     90    yy_storeu_256(xy_xz_tmp, hadd_xy_xz);
     91    xy_sum += (int64_t)xy_xz_tmp[5] + xy_xz_tmp[4] + xy_xz_tmp[1];
     92    xz_sum += (int64_t)xy_xz_tmp[7] + xy_xz_tmp[6] + xy_xz_tmp[3];
     93 
     94    const __m256i hadd_x_x2 = _mm256_hadd_epi32(x_sum_32, x2_sum_32);
     95    // hadd_x_x2 = [aa+bb+cc ee+ff+gg a+b+c e+f+g]
     96    //             [ii+jj+kk mm+nn+oo i+j+k m+n+o] as i32
     97    yy_storeu_256(x_x2_tmp, hadd_x_x2);
     98    x_sum += (int64_t)x_x2_tmp[5] + x_x2_tmp[4] + x_x2_tmp[1];
     99    x2_sum += (int64_t)x_x2_tmp[7] + x_x2_tmp[6] + x_x2_tmp[3];
    100 
    101    xy_sum_32 = _mm256_setzero_si256();
    102    xz_sum_32 = _mm256_setzero_si256();
    103    x_sum_32 = _mm256_setzero_si256();
    104    x2_sum_32 = _mm256_setzero_si256();
    105  }
    106 
    107  // x_sum now covers every pixel except the final 1-2 rows and 1-2 cols
    108  int64_t x_finalrow = 0, x_finalcol = 0, x2_finalrow = 0, x2_finalcol = 0;
    109 
    110  // Do we have 2 rows remaining or just the one?  Note that width and height
    111  // are powers of 2, so each modulo 3 must be 1 or 2.
    112  if (height % 3 == 1) {  // Just horiz corrs on the final row
    113    const int16_t x0 = diff[(height - 1) * stride];
    114    x_sum += x0;
    115    x_finalrow += x0;
    116    x2_sum += x0 * x0;
    117    x2_finalrow += x0 * x0;
    118    for (int j = 0; j < width - 1; ++j) {
    119      const int16_t x = diff[(height - 1) * stride + j];
    120      const int16_t y = diff[(height - 1) * stride + j + 1];
    121      xy_sum += x * y;
    122      x_sum += y;
    123      x2_sum += y * y;
    124      x_finalrow += y;
    125      x2_finalrow += y * y;
    126    }
    127  } else {  // Two rows remaining to do
    128    const int16_t x0 = diff[(height - 2) * stride];
    129    const int16_t z0 = diff[(height - 1) * stride];
    130    x_sum += x0 + z0;
    131    x2_sum += x0 * x0 + z0 * z0;
    132    x_finalrow += z0;
    133    x2_finalrow += z0 * z0;
    134    for (int j = 0; j < width - 1; ++j) {
    135      const int16_t x = diff[(height - 2) * stride + j];
    136      const int16_t y = diff[(height - 2) * stride + j + 1];
    137      const int16_t z = diff[(height - 1) * stride + j];
    138      const int16_t w = diff[(height - 1) * stride + j + 1];
    139 
    140      // Horizontal and vertical correlations for the penultimate row:
    141      xy_sum += x * y;
    142      xz_sum += x * z;
    143 
    144      // Now just horizontal correlations for the final row:
    145      xy_sum += z * w;
    146 
    147      x_sum += y + w;
    148      x2_sum += y * y + w * w;
    149      x_finalrow += w;
    150      x2_finalrow += w * w;
    151    }
    152  }
    153 
    154  // Do we have 2 columns remaining or just the one?
    155  if (width % 3 == 1) {  // Just vert corrs on the final col
    156    const int16_t x0 = diff[width - 1];
    157    x_sum += x0;
    158    x_finalcol += x0;
    159    x2_sum += x0 * x0;
    160    x2_finalcol += x0 * x0;
    161    for (int i = 0; i < height - 1; ++i) {
    162      const int16_t x = diff[i * stride + width - 1];
    163      const int16_t z = diff[(i + 1) * stride + width - 1];
    164      xz_sum += x * z;
    165      x_finalcol += z;
    166      x2_finalcol += z * z;
    167      // So the bottom-right elements don't get counted twice:
    168      if (i < height - (height % 3 == 1 ? 2 : 3)) {
    169        x_sum += z;
    170        x2_sum += z * z;
    171      }
    172    }
    173  } else {  // Two cols remaining
    174    const int16_t x0 = diff[width - 2];
    175    const int16_t y0 = diff[width - 1];
    176    x_sum += x0 + y0;
    177    x2_sum += x0 * x0 + y0 * y0;
    178    x_finalcol += y0;
    179    x2_finalcol += y0 * y0;
    180    for (int i = 0; i < height - 1; ++i) {
    181      const int16_t x = diff[i * stride + width - 2];
    182      const int16_t y = diff[i * stride + width - 1];
    183      const int16_t z = diff[(i + 1) * stride + width - 2];
    184      const int16_t w = diff[(i + 1) * stride + width - 1];
    185 
    186      // Horizontal and vertical correlations for the penultimate col:
    187      // Skip these on the last iteration of this loop if we also had two
    188      // rows remaining, otherwise the final horizontal and vertical correlation
    189      // get erroneously processed twice
    190      if (i < height - 2 || height % 3 == 1) {
    191        xy_sum += x * y;
    192        xz_sum += x * z;
    193      }
    194 
    195      x_finalcol += w;
    196      x2_finalcol += w * w;
    197      // So the bottom-right elements don't get counted twice:
    198      if (i < height - (height % 3 == 1 ? 2 : 3)) {
    199        x_sum += z + w;
    200        x2_sum += z * z + w * w;
    201      }
    202 
    203      // Now just vertical correlations for the final column:
    204      xz_sum += y * w;
    205    }
    206  }
    207 
    208  // Calculate the simple sums and squared-sums
    209  int64_t x_firstrow = 0, x_firstcol = 0;
    210  int64_t x2_firstrow = 0, x2_firstcol = 0;
    211 
    212  for (int j = 0; j < width; ++j) {
    213    x_firstrow += diff[j];
    214    x2_firstrow += diff[j] * diff[j];
    215  }
    216  for (int i = 0; i < height; ++i) {
    217    x_firstcol += diff[i * stride];
    218    x2_firstcol += diff[i * stride] * diff[i * stride];
    219  }
    220 
    221  int64_t xhor_sum = x_sum - x_finalcol;
    222  int64_t xver_sum = x_sum - x_finalrow;
    223  int64_t y_sum = x_sum - x_firstcol;
    224  int64_t z_sum = x_sum - x_firstrow;
    225  int64_t x2hor_sum = x2_sum - x2_finalcol;
    226  int64_t x2ver_sum = x2_sum - x2_finalrow;
    227  int64_t y2_sum = x2_sum - x2_firstcol;
    228  int64_t z2_sum = x2_sum - x2_firstrow;
    229 
    230  const float num_hor = (float)(height * (width - 1));
    231  const float num_ver = (float)((height - 1) * width);
    232 
    233  const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
    234  const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;
    235 
    236  const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
    237  const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;
    238 
    239  const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
    240  const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;
    241 
    242  if (xhor_var_n > 0 && y_var_n > 0) {
    243    *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
    244    *hcorr = *hcorr < 0 ? 0 : *hcorr;
    245  } else {
    246    *hcorr = 1.0;
    247  }
    248  if (xver_var_n > 0 && z_var_n > 0) {
    249    *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
    250    *vcorr = *vcorr < 0 ? 0 : *vcorr;
    251  } else {
    252    *vcorr = 1.0;
    253  }
    254 }