rdopt_avx2.c (9735B)
1 /* 2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <assert.h> 13 #include <immintrin.h> 14 #include "aom_dsp/x86/mem_sse2.h" 15 #include "aom_dsp/x86/synonyms_avx2.h" 16 17 #include "config/av1_rtcd.h" 18 #include "av1/encoder/rdopt.h" 19 20 // Process horizontal and vertical correlations in a 4x4 block of pixels. 21 // We actually use the 4x4 pixels to calculate correlations corresponding to 22 // the top-left 3x3 pixels, so this function must be called with 1x1 overlap, 23 // moving the window along/down by 3 pixels at a time. 
static inline void horver_correlation_4x4(const int16_t *diff, int stride,
                                          __m256i *xy_sum_32,
                                          __m256i *xz_sum_32, __m256i *x_sum_32,
                                          __m256i *x2_sum_32) {
  // Pixels in this 4x4 [ a b c d ]
  // are referred to as: [ e f g h ]
  //                     [ i j k l ]
  //                     [ m n o p ]
  //
  // In the lane diagrams below, elements are listed from most- to
  // least-significant: row 0 sits in the top 64-bit lane, row 3 in the bottom.
  // All four accumulators are updated in place; the caller reduces the 32-bit
  // lanes and discards the row-3 contributions (row 3 is the next window's
  // row 0).

  const __m256i pixels = _mm256_set_epi64x(
      loadu_int64(&diff[0 * stride]), loadu_int64(&diff[1 * stride]),
      loadu_int64(&diff[2 * stride]), loadu_int64(&diff[3 * stride]));
  // pixels = [d c b a h g f e] [l k j i p o n m] as i16

  // Shifting each 64-bit row left by one i16 drops column d, so slli holds
  // only the three "current" pixels of each row plus a zero in the low slot.
  const __m256i slli = _mm256_slli_epi64(pixels, 16);
  // slli = [c b a 0 g f e 0] [k j i 0 o n m 0] as i16

  // Each pixel times its right neighbour, adjacent products summed.
  const __m256i madd_xy = _mm256_madd_epi16(pixels, slli);
  // madd_xy = [bc+cd ab fg+gh ef] [jk+kl ij no+op mn] as i32
  *xy_sum_32 = _mm256_add_epi32(*xy_sum_32, madd_xy);

  // Line each row of slli up with the row below it; the bottom lane pairs
  // row 3 with itself and that product is ignored by the caller.
  // Permute control [3 2] [1 0] => [2 1] [0 0], 0b10010000 = 0x90
  const __m256i perm = _mm256_permute4x64_epi64(slli, 0x90);
  // perm = [g f e 0 k j i 0] [o n m 0 o n m 0] as i16

  const __m256i madd_xz = _mm256_madd_epi16(slli, perm);
  // madd_xz = [cg+bf ae gk+fj ei] [ko+jn im oo+nn mm] as i32
  *xz_sum_32 = _mm256_add_epi32(*xz_sum_32, madd_xz);

  // Sum every element in slli (and then also their squares)
  const __m256i madd1_slli = _mm256_madd_epi16(slli, _mm256_set1_epi16(1));
  // madd1_slli = [c+b a g+f e] [k+j i o+n m] as i32
  *x_sum_32 = _mm256_add_epi32(*x_sum_32, madd1_slli);

  const __m256i madd_slli = _mm256_madd_epi16(slli, slli);
  // madd_slli = [cc+bb aa gg+ff ee] [kk+jj ii oo+nn mm] as i32
  *x2_sum_32 = _mm256_add_epi32(*x2_sum_32, madd_slli);
}

// Computes horizontal and vertical correlation coefficients of a
// width x height block of int16 samples.
//
// diff   - top-left sample; consecutive rows are `stride` elements apart.
// width  - block width. Must not be a multiple of 3 and must be >= 2
//          (powers of two >= 2 satisfy this).
// height - block height. Must not be a multiple of 3 and must be >= 4.
// hcorr  - receives the correlation between each pixel and its right
//          neighbour, over all height * (width - 1) such pairs.
// vcorr  - receives the correlation between each pixel and its lower
//          neighbour, over all (height - 1) * width such pairs.
//
// Negative correlations are clamped to 0. If either variance of a pair of
// series is not positive, the corresponding output is set to 1.
void av1_get_horver_correlation_full_avx2(const int16_t *diff, int stride,
                                          int width, int height, float *hcorr,
                                          float *vcorr) {
  // Preconditions of the tail handling below:
  //  * If a dimension were a multiple of 3, the 3-pixel-step body loop would
  //    leave three trailing rows/cols, but the tails only handle one or two,
  //    so some pixels would be skipped.
  //  * With height 1 or 2 the "final rows" include row 0, so the top pixels
  //    of the final column(s) would be added to x_sum/x2_sum twice.
  //  * width >= 2 guarantees at least one horizontal pair (num_hor > 0).
  assert(width % 3 != 0 && height % 3 != 0);
  assert(width >= 2 && height >= 4);

  // The following notation is used:
  // x - current pixel
  // y - right neighbour pixel
  // z - below neighbour pixel
  // w - down-right neighbour pixel
  int64_t xy_sum = 0, xz_sum = 0;
  int64_t x_sum = 0, x2_sum = 0;

  // Process horizontal and vertical correlations through the body in 4x4
  // blocks. This excludes the final row and column and possibly one extra
  // row/column depending how 3 divides into height and width.
  int32_t xy_xz_tmp[8] = { 0 }, x_x2_tmp[8] = { 0 };
  __m256i xy_sum_32 = _mm256_setzero_si256();
  __m256i xz_sum_32 = _mm256_setzero_si256();
  __m256i x_sum_32 = _mm256_setzero_si256();
  __m256i x2_sum_32 = _mm256_setzero_si256();
  for (int i = 0; i <= height - 4; i += 3) {
    for (int j = 0; j <= width - 4; j += 3) {
      horver_correlation_4x4(&diff[i * stride + j], stride, &xy_sum_32,
                             &xz_sum_32, &x_sum_32, &x2_sum_32);
    }
    // Reduce this 3-row strip into the 64-bit totals. Only lanes holding
    // rows 0-2 of each window are kept (lanes 5, 4, 1 for the first operand
    // of each hadd, 7, 6, 3 for the second); row 3 is the next strip's row 0.
    const __m256i hadd_xy_xz = _mm256_hadd_epi32(xy_sum_32, xz_sum_32);
    // hadd_xy_xz = [ae+bf+cg ei+fj+gk ab+bc+cd ef+fg+gh]
    //              [im+jn+ko mm+nn+oo ij+jk+kl mn+no+op] as i32
    yy_storeu_256(xy_xz_tmp, hadd_xy_xz);
    xy_sum += (int64_t)xy_xz_tmp[5] + xy_xz_tmp[4] + xy_xz_tmp[1];
    xz_sum += (int64_t)xy_xz_tmp[7] + xy_xz_tmp[6] + xy_xz_tmp[3];

    const __m256i hadd_x_x2 = _mm256_hadd_epi32(x_sum_32, x2_sum_32);
    // hadd_x_x2 = [aa+bb+cc ee+ff+gg a+b+c e+f+g]
    //             [ii+jj+kk mm+nn+oo i+j+k m+n+o] as i32
    yy_storeu_256(x_x2_tmp, hadd_x_x2);
    x_sum += (int64_t)x_x2_tmp[5] + x_x2_tmp[4] + x_x2_tmp[1];
    x2_sum += (int64_t)x_x2_tmp[7] + x_x2_tmp[6] + x_x2_tmp[3];

    // Reset per strip, presumably to bound how much each 32-bit lane
    // accumulates before being widened to 64 bits.
    xy_sum_32 = _mm256_setzero_si256();
    xz_sum_32 = _mm256_setzero_si256();
    x_sum_32 = _mm256_setzero_si256();
    x2_sum_32 = _mm256_setzero_si256();
  }

  // x_sum now covers every pixel except the final 1-2 rows and 1-2 cols
  int64_t x_finalrow = 0, x_finalcol = 0, x2_finalrow = 0, x2_finalcol = 0;

  // Number of trailing rows the body loop left unprocessed: 1 when
  // height % 3 == 1, otherwise 2 (height % 3 == 2 by the precondition).
  const int rows_left = (height % 3 == 1) ? 1 : 2;

  if (rows_left == 1) {  // Just horiz corrs on the final row
    const int16_t x0 = diff[(height - 1) * stride];
    x_sum += x0;
    x_finalrow += x0;
    x2_sum += x0 * x0;
    x2_finalrow += x0 * x0;
    for (int j = 0; j < width - 1; ++j) {
      const int16_t x = diff[(height - 1) * stride + j];
      const int16_t y = diff[(height - 1) * stride + j + 1];
      xy_sum += x * y;
      x_sum += y;
      x2_sum += y * y;
      x_finalrow += y;
      x2_finalrow += y * y;
    }
  } else {  // Two rows remaining to do
    const int16_t x0 = diff[(height - 2) * stride];
    const int16_t z0 = diff[(height - 1) * stride];
    x_sum += x0 + z0;
    x2_sum += x0 * x0 + z0 * z0;
    x_finalrow += z0;
    x2_finalrow += z0 * z0;
    for (int j = 0; j < width - 1; ++j) {
      const int16_t x = diff[(height - 2) * stride + j];
      const int16_t y = diff[(height - 2) * stride + j + 1];
      const int16_t z = diff[(height - 1) * stride + j];
      const int16_t w = diff[(height - 1) * stride + j + 1];

      // Horizontal and vertical correlations for the penultimate row:
      xy_sum += x * y;
      xz_sum += x * z;

      // Now just horizontal correlations for the final row:
      xy_sum += z * w;

      x_sum += y + w;
      x2_sum += y * y + w * w;
      x_finalrow += w;
      x2_finalrow += w * w;
    }
  }

  // Do we have 2 columns remaining or just the one?
  if (width % 3 == 1) {  // Just vert corrs on the final col
    const int16_t x0 = diff[width - 1];
    x_sum += x0;
    x_finalcol += x0;
    x2_sum += x0 * x0;
    x2_finalcol += x0 * x0;
    for (int i = 0; i < height - 1; ++i) {
      const int16_t x = diff[i * stride + width - 1];
      const int16_t z = diff[(i + 1) * stride + width - 1];
      xz_sum += x * z;
      x_finalcol += z;
      x2_finalcol += z * z;
      // Rows from height - rows_left onward were already fully added to
      // x_sum/x2_sum by the row tail, so skip them to avoid double counting.
      if (i + 1 < height - rows_left) {
        x_sum += z;
        x2_sum += z * z;
      }
    }
  } else {  // Two cols remaining
    const int16_t x0 = diff[width - 2];
    const int16_t y0 = diff[width - 1];
    x_sum += x0 + y0;
    x2_sum += x0 * x0 + y0 * y0;
    x_finalcol += y0;
    x2_finalcol += y0 * y0;
    for (int i = 0; i < height - 1; ++i) {
      const int16_t x = diff[i * stride + width - 2];
      const int16_t y = diff[i * stride + width - 1];
      const int16_t z = diff[(i + 1) * stride + width - 2];
      const int16_t w = diff[(i + 1) * stride + width - 1];

      // Horizontal and vertical correlations for the penultimate col. When
      // row i is one of the final rows, the row tail already processed these
      // pairs, so skip them to avoid counting them twice.
      if (i < height - rows_left) {
        xy_sum += x * y;
        xz_sum += x * z;
      }

      x_finalcol += w;
      x2_finalcol += w * w;
      // So the bottom-right elements don't get counted twice:
      if (i + 1 < height - rows_left) {
        x_sum += z + w;
        x2_sum += z * z + w * w;
      }

      // Now just vertical correlations for the final column:
      xz_sum += y * w;
    }
  }

  // Calculate the simple sums and squared-sums
  int64_t x_firstrow = 0, x_firstcol = 0;
  int64_t x2_firstrow = 0, x2_firstcol = 0;

  for (int j = 0; j < width; ++j) {
    x_firstrow += diff[j];
    x2_firstrow += diff[j] * diff[j];
  }
  for (int i = 0; i < height; ++i) {
    x_firstcol += diff[i * stride];
    x2_firstcol += diff[i * stride] * diff[i * stride];
  }

  // Derive the sums over each series from the whole-block sums: the "x" of a
  // horizontal pair excludes the last column, its "y" excludes the first
  // column; likewise rows for the vertical pairs.
  int64_t xhor_sum = x_sum - x_finalcol;
  int64_t xver_sum = x_sum - x_finalrow;
  int64_t y_sum = x_sum - x_firstcol;
  int64_t z_sum = x_sum - x_firstrow;
  int64_t x2hor_sum = x2_sum - x2_finalcol;
  int64_t x2ver_sum = x2_sum - x2_finalrow;
  int64_t y2_sum = x2_sum - x2_firstcol;
  int64_t z2_sum = x2_sum - x2_firstrow;

  const float num_hor = (float)(height * (width - 1));
  const float num_ver = (float)((height - 1) * width);

  // Each *_var_n is n times a (co)variance: sum(ab) - sum(a) * sum(b) / n.
  // The common factor n cancels in the correlation ratio below.
  const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
  const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;

  const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
  const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;

  const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
  const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;

  if (xhor_var_n > 0 && y_var_n > 0) {
    *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
    *hcorr = *hcorr < 0 ? 0 : *hcorr;
  } else {
    *hcorr = 1.0f;
  }
  if (xver_var_n > 0 && z_var_n > 0) {
    *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
    *vcorr = *vcorr < 0 ? 0 : *vcorr;
  } else {
    *vcorr = 1.0f;
  }
}