txfm_common_avx2.h (12974B)
1 /* 2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #ifndef AOM_AOM_DSP_X86_TXFM_COMMON_AVX2_H_ 13 #define AOM_AOM_DSP_X86_TXFM_COMMON_AVX2_H_ 14 15 #include <emmintrin.h> 16 #include "aom/aom_integer.h" 17 #include "aom_dsp/x86/synonyms.h" 18 19 #ifdef __cplusplus 20 extern "C" { 21 #endif 22 23 static inline __m256i pair_set_w16_epi16(int16_t a, int16_t b) { 24 return _mm256_set1_epi32( 25 (int32_t)(((uint16_t)(a)) | (((uint32_t)(uint16_t)(b)) << 16))); 26 } 27 28 static inline void btf_16_w16_avx2(const __m256i w0, const __m256i w1, 29 __m256i *in0, __m256i *in1, const __m256i _r, 30 const int32_t cos_bit) { 31 __m256i t0 = _mm256_unpacklo_epi16(*in0, *in1); 32 __m256i t1 = _mm256_unpackhi_epi16(*in0, *in1); 33 __m256i u0 = _mm256_madd_epi16(t0, w0); 34 __m256i u1 = _mm256_madd_epi16(t1, w0); 35 __m256i v0 = _mm256_madd_epi16(t0, w1); 36 __m256i v1 = _mm256_madd_epi16(t1, w1); 37 38 __m256i a0 = _mm256_add_epi32(u0, _r); 39 __m256i a1 = _mm256_add_epi32(u1, _r); 40 __m256i b0 = _mm256_add_epi32(v0, _r); 41 __m256i b1 = _mm256_add_epi32(v1, _r); 42 43 __m256i c0 = _mm256_srai_epi32(a0, cos_bit); 44 __m256i c1 = _mm256_srai_epi32(a1, cos_bit); 45 __m256i d0 = _mm256_srai_epi32(b0, cos_bit); 46 __m256i d1 = _mm256_srai_epi32(b1, cos_bit); 47 48 *in0 = _mm256_packs_epi32(c0, c1); 49 *in1 = _mm256_packs_epi32(d0, d1); 50 } 51 52 static inline void btf_16_adds_subs_avx2(__m256i *in0, __m256i *in1) { 53 const __m256i _in0 = *in0; 54 const __m256i _in1 = *in1; 55 *in0 = _mm256_adds_epi16(_in0, _in1); 56 *in1 = _mm256_subs_epi16(_in0, _in1); 57 } 58 59 static inline void btf_32_add_sub_avx2(__m256i *in0, __m256i *in1) { 60 const __m256i _in0 = *in0; 61 const __m256i _in1 = *in1; 62 *in0 = _mm256_add_epi32(_in0, _in1); 63 *in1 = _mm256_sub_epi32(_in0, _in1); 64 } 65 66 static inline void btf_16_adds_subs_out_avx2(__m256i *out0, __m256i *out1, 67 __m256i in0, __m256i in1) { 68 const __m256i _in0 = in0; 69 const __m256i _in1 = in1; 70 *out0 = _mm256_adds_epi16(_in0, _in1); 71 *out1 = _mm256_subs_epi16(_in0, _in1); 72 } 73 74 static inline void btf_32_add_sub_out_avx2(__m256i *out0, __m256i *out1, 75 __m256i in0, __m256i in1) { 76 const __m256i _in0 = in0; 77 const __m256i _in1 = in1; 78 *out0 = _mm256_add_epi32(_in0, _in1); 79 *out1 = _mm256_sub_epi32(_in0, _in1); 80 } 81 82 static inline __m256i load_16bit_to_16bit_avx2(const int16_t *a) { 83 return _mm256_load_si256((const __m256i *)a); 84 } 85 86 static inline void load_buffer_16bit_to_16bit_avx2(const int16_t *in, 87 int stride, __m256i *out, 88 int out_size) { 89 for (int i = 0; i < out_size; ++i) { 90 out[i] = load_16bit_to_16bit_avx2(in + i * stride); 91 } 92 } 93 94 static inline void load_buffer_16bit_to_16bit_flip_avx2(const int16_t *in, 95 int stride, 96 __m256i *out, 97 int out_size) { 98 for (int i = 0; i < out_size; ++i) { 99 out[out_size - i - 1] = load_16bit_to_16bit_avx2(in + i * stride); 100 } 101 } 102 103 static inline __m256i load_32bit_to_16bit_w16_avx2(const int32_t *a) { 104 const __m256i a_low = _mm256_lddqu_si256((const __m256i *)a); 105 const __m256i b = _mm256_packs_epi32(a_low, *(const __m256i *)(a + 8)); 106 return _mm256_permute4x64_epi64(b, 0xD8); 107 } 108 109 static inline void load_buffer_32bit_to_16bit_w16_avx2(const int32_t *in, 110 int stride, __m256i *out, 111 int out_size) { 112 for (int i = 0; i < out_size; ++i) { 113 out[i] = load_32bit_to_16bit_w16_avx2(in + i * stride); 114 } 115 } 116 117 static inline void transpose2_8x8_avx2(const __m256i *const in, 118 __m256i *const out) { 119 __m256i t[16], u[16]; 120 // (1st, 2nd) ==> (lo, hi) 121 // (0, 1) ==> (0, 1) 122 // (2, 3) ==> (2, 3) 123 // (4, 5) ==> (4, 5) 124 // (6, 7) ==> (6, 7) 125 for (int i = 0; i < 4; i++) { 126 t[2 * i] = _mm256_unpacklo_epi16(in[2 * i], in[2 * i + 1]); 127 t[2 * i + 1] = _mm256_unpackhi_epi16(in[2 * i], in[2 * i + 1]); 128 } 129 130 // (1st, 2nd) ==> (lo, hi) 131 // (0, 2) ==> (0, 2) 132 // (1, 3) ==> (1, 3) 133 // (4, 6) ==> (4, 6) 134 // (5, 7) ==> (5, 7) 135 for (int i = 0; i < 2; i++) { 136 u[i] = _mm256_unpacklo_epi32(t[i], t[i + 2]); 137 u[i + 2] = _mm256_unpackhi_epi32(t[i], t[i + 2]); 138 139 u[i + 4] = _mm256_unpacklo_epi32(t[i + 4], t[i + 6]); 140 u[i + 6] = _mm256_unpackhi_epi32(t[i + 4], t[i + 6]); 141 } 142 143 // (1st, 2nd) ==> (lo, hi) 144 // (0, 4) ==> (0, 1) 145 // (1, 5) ==> (4, 5) 146 // (2, 6) ==> (2, 3) 147 // (3, 7) ==> (6, 7) 148 for (int i = 0; i < 2; i++) { 149 out[2 * i] = _mm256_unpacklo_epi64(u[2 * i], u[2 * i + 4]); 150 out[2 * i + 1] = _mm256_unpackhi_epi64(u[2 * i], u[2 * i + 4]); 151 152 out[2 * i + 4] = _mm256_unpacklo_epi64(u[2 * i + 1], u[2 * i + 5]); 153 out[2 * i + 5] = _mm256_unpackhi_epi64(u[2 * i + 1], u[2 * i + 5]); 154 } 155 } 156 157 static inline void transpose_16bit_16x16_avx2(const __m256i *const in, 158 __m256i *const out) { 159 __m256i t[16]; 160 161 #define LOADL(idx) \ 162 t[idx] = _mm256_castsi128_si256(_mm_load_si128((__m128i const *)&in[idx])); \ 163 t[idx] = _mm256_inserti128_si256( \ 164 t[idx], _mm_load_si128((__m128i const *)&in[idx + 8]), 1); 165 166 #define LOADR(idx) \ 167 t[8 + idx] = \ 168 _mm256_castsi128_si256(_mm_load_si128((__m128i const *)&in[idx] + 1)); \ 169 t[8 + idx] = _mm256_inserti128_si256( \ 170 t[8 + idx], _mm_load_si128((__m128i const *)&in[idx + 8] + 1), 1); 171 172 // load left 8x16 173 LOADL(0) 174 LOADL(1) 175 LOADL(2) 176 LOADL(3) 177 LOADL(4) 178 LOADL(5) 179 LOADL(6) 180 LOADL(7) 181 182 // load right 8x16 183 LOADR(0) 184 LOADR(1) 185 LOADR(2) 186 LOADR(3) 187 LOADR(4) 188 LOADR(5) 189 LOADR(6) 190 LOADR(7) 191 192 // get the top 16x8 result 193 transpose2_8x8_avx2(t, out); 194 // get the bottom 16x8 result 195 transpose2_8x8_avx2(&t[8], &out[8]); 196 } 197 198 static inline void transpose_16bit_16x8_avx2(const __m256i *const in, 199 __m256i *const out) { 200 const __m256i a0 = _mm256_unpacklo_epi16(in[0], in[1]); 201 const __m256i a1 = _mm256_unpacklo_epi16(in[2], in[3]); 202 const __m256i a2 = _mm256_unpacklo_epi16(in[4], in[5]); 203 const __m256i a3 = _mm256_unpacklo_epi16(in[6], in[7]); 204 const __m256i a4 = _mm256_unpackhi_epi16(in[0], in[1]); 205 const __m256i a5 = _mm256_unpackhi_epi16(in[2], in[3]); 206 const __m256i a6 = _mm256_unpackhi_epi16(in[4], in[5]); 207 const __m256i a7 = _mm256_unpackhi_epi16(in[6], in[7]); 208 209 const __m256i b0 = _mm256_unpacklo_epi32(a0, a1); 210 const __m256i b1 = _mm256_unpacklo_epi32(a2, a3); 211 const __m256i b2 = _mm256_unpacklo_epi32(a4, a5); 212 const __m256i b3 = _mm256_unpacklo_epi32(a6, a7); 213 const __m256i b4 = _mm256_unpackhi_epi32(a0, a1); 214 const __m256i b5 = _mm256_unpackhi_epi32(a2, a3); 215 const __m256i b6 = _mm256_unpackhi_epi32(a4, a5); 216 const __m256i b7 = _mm256_unpackhi_epi32(a6, a7); 217 218 out[0] = _mm256_unpacklo_epi64(b0, b1); 219 out[1] = _mm256_unpackhi_epi64(b0, b1); 220 out[2] = _mm256_unpacklo_epi64(b4, b5); 221 out[3] = _mm256_unpackhi_epi64(b4, b5); 222 out[4] = _mm256_unpacklo_epi64(b2, b3); 223 out[5] = _mm256_unpackhi_epi64(b2, b3); 224 out[6] = _mm256_unpacklo_epi64(b6, b7); 225 out[7] = _mm256_unpackhi_epi64(b6, b7); 226 } 227 228 static inline void flip_buf_avx2(__m256i *in, __m256i *out, int size) { 229 for (int i = 0; i < size; ++i) { 230 out[size - i - 1] = in[i]; 231 } 232 } 233 234 static inline void round_shift_16bit_w16_avx2(__m256i *in, int size, int bit) { 235 if (bit < 0) { 236 bit = -bit; 237 __m256i round = _mm256_set1_epi16(1 << (bit - 1)); 238 for (int i = 0; i < size; ++i) { 239 in[i] = _mm256_adds_epi16(in[i], round); 240 in[i] = _mm256_srai_epi16(in[i], bit); 241 } 242 } else if (bit > 0) { 243 for (int i = 0; i < size; ++i) { 244 in[i] = _mm256_slli_epi16(in[i], bit); 245 } 246 } 247 } 248 249 static inline __m256i round_shift_32_avx2(__m256i vec, int bit) { 250 __m256i tmp, round; 251 round = _mm256_set1_epi32(1 << (bit - 1)); 252 tmp = _mm256_add_epi32(vec, round); 253 return _mm256_srai_epi32(tmp, bit); 254 } 255 256 static inline void round_shift_array_32_avx2(__m256i *input, __m256i *output, 257 const int size, const int bit) { 258 if (bit > 0) { 259 int i; 260 for (i = 0; i < size; i++) { 261 output[i] = round_shift_32_avx2(input[i], bit); 262 } 263 } else { 264 int i; 265 for (i = 0; i < size; i++) { 266 output[i] = _mm256_slli_epi32(input[i], -bit); 267 } 268 } 269 } 270 271 static inline void round_shift_rect_array_32_avx2(__m256i *input, 272 __m256i *output, 273 const int size, const int bit, 274 const int val) { 275 const __m256i sqrt2 = _mm256_set1_epi32(val); 276 if (bit > 0) { 277 int i; 278 for (i = 0; i < size; i++) { 279 const __m256i r0 = round_shift_32_avx2(input[i], bit); 280 const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0); 281 output[i] = round_shift_32_avx2(r1, NewSqrt2Bits); 282 } 283 } else { 284 int i; 285 for (i = 0; i < size; i++) { 286 const __m256i r0 = _mm256_slli_epi32(input[i], -bit); 287 const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0); 288 output[i] = round_shift_32_avx2(r1, NewSqrt2Bits); 289 } 290 } 291 } 292 293 static inline __m256i scale_round_avx2(const __m256i a, const int scale) { 294 const __m256i scale_rounding = 295 pair_set_w16_epi16(scale, 1 << (NewSqrt2Bits - 1)); 296 const __m256i b = _mm256_madd_epi16(a, scale_rounding); 297 return _mm256_srai_epi32(b, NewSqrt2Bits); 298 } 299 300 static inline void store_rect_16bit_to_32bit_w8_avx2(const __m256i a, 301 int32_t *const b) { 302 const __m256i one = _mm256_set1_epi16(1); 303 const __m256i a_lo = _mm256_unpacklo_epi16(a, one); 304 const __m256i a_hi = _mm256_unpackhi_epi16(a, one); 305 const __m256i b_lo = scale_round_avx2(a_lo, NewSqrt2); 306 const __m256i b_hi = scale_round_avx2(a_hi, NewSqrt2); 307 const __m256i temp = _mm256_permute2f128_si256(b_lo, b_hi, 0x31); 308 _mm_store_si128((__m128i *)b, _mm256_castsi256_si128(b_lo)); 309 _mm_store_si128((__m128i *)(b + 4), _mm256_castsi256_si128(b_hi)); 310 _mm256_store_si256((__m256i *)(b + 64), temp); 311 } 312 313 static inline void store_rect_buffer_16bit_to_32bit_w8_avx2( 314 const __m256i *const in, int32_t *const out, const int stride, 315 const int out_size) { 316 for (int i = 0; i < out_size; ++i) { 317 store_rect_16bit_to_32bit_w8_avx2(in[i], out + i * stride); 318 } 319 } 320 321 static inline void pack_reg(const __m128i *in1, const __m128i *in2, 322 __m256i *out) { 323 out[0] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[0]), in2[0], 0x1); 324 out[1] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[1]), in2[1], 0x1); 325 out[2] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[2]), in2[2], 0x1); 326 out[3] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[3]), in2[3], 0x1); 327 out[4] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[4]), in2[4], 0x1); 328 out[5] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[5]), in2[5], 0x1); 329 out[6] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[6]), in2[6], 0x1); 330 out[7] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[7]), in2[7], 0x1); 331 } 332 333 static inline void extract_reg(const __m256i *in, __m128i *out1) { 334 out1[0] = _mm256_castsi256_si128(in[0]); 335 out1[1] = _mm256_castsi256_si128(in[1]); 336 out1[2] = _mm256_castsi256_si128(in[2]); 337 out1[3] = _mm256_castsi256_si128(in[3]); 338 out1[4] = _mm256_castsi256_si128(in[4]); 339 out1[5] = _mm256_castsi256_si128(in[5]); 340 out1[6] = _mm256_castsi256_si128(in[6]); 341 out1[7] = _mm256_castsi256_si128(in[7]); 342 343 out1[8] = _mm256_extracti128_si256(in[0], 0x01); 344 out1[9] = _mm256_extracti128_si256(in[1], 0x01); 345 out1[10] = _mm256_extracti128_si256(in[2], 0x01); 346 out1[11] = _mm256_extracti128_si256(in[3], 0x01); 347 out1[12] = _mm256_extracti128_si256(in[4], 0x01); 348 out1[13] = _mm256_extracti128_si256(in[5], 0x01); 349 out1[14] = _mm256_extracti128_si256(in[6], 0x01); 350 out1[15] = _mm256_extracti128_si256(in[7], 0x01); 351 } 352 353 #ifdef __cplusplus 354 } 355 #endif 356 357 #endif // AOM_AOM_DSP_X86_TXFM_COMMON_AVX2_H_