wedge_utils_avx2.c (8118B)
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>
#include <smmintrin.h>

#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom/aom_integer.h"

#include "av1/common/reconinter.h"

// Maximum wedge mask weight. A mask value m is paired with the constant
// MAX_MASK_VALUE below so one madd computes m*d + MAX_MASK_VALUE*r1.
#define MAX_MASK_VALUE (1 << WEDGE_WEIGHT_BITS)

/**
 * See av1_wedge_sse_from_residuals_c
 *
 * For each of the N samples this computes t = m*d + MAX_MASK_VALUE*r1 with a
 * single 16-bit madd on interleaved (d, r1) / (m, MAX_MASK_VALUE) pairs,
 * saturates t back to int16 (matching the C reference), squares and
 * accumulates it in 64-bit lanes, then rounds the total down by
 * 2 * WEDGE_WEIGHT_BITS to undo the mask scaling.
 *
 * r1 : first residual block, N int16 samples.
 * d  : residual difference, N int16 samples.
 * m  : wedge mask weights, N bytes.
 * N  : sample count; must be a multiple of 64.
 */
uint64_t av1_wedge_sse_from_residuals_avx2(const int16_t *r1, const int16_t *d,
                                           const uint8_t *m, int N) {
  // Negative index counting up to zero: the loop condition is just `n`.
  int n = -N;

  uint64_t csse;

  const __m256i v_mask_max_w = _mm256_set1_epi16(MAX_MASK_VALUE);
  // 0x00000000FFFFFFFF per 64-bit lane: keeps the even 32-bit elements when
  // widening the 32-bit squares to 64-bit accumulators.
  const __m256i v_zext_q = _mm256_set1_epi64x(~0u);

  __m256i v_acc0_q = _mm256_setzero_si256();

  assert(N % 64 == 0);

  // Point past the end; (ptr + n) walks forward as n rises from -N to 0.
  r1 += N;
  d += N;
  m += N;

  do {
    // 16 samples per iteration.
    const __m256i v_r0_w = _mm256_lddqu_si256((__m256i *)(r1 + n));
    const __m256i v_d0_w = _mm256_lddqu_si256((__m256i *)(d + n));
    const __m128i v_m01_b = _mm_lddqu_si128((__m128i *)(m + n));

    // Interleave (d, r1) and (m, MAX_MASK_VALUE) identically; the 256-bit
    // unpacks work per 128-bit lane, but since both operand sets are permuted
    // the same way, corresponding elements still line up for the madd.
    const __m256i v_rd0l_w = _mm256_unpacklo_epi16(v_d0_w, v_r0_w);
    const __m256i v_rd0h_w = _mm256_unpackhi_epi16(v_d0_w, v_r0_w);
    const __m256i v_m0_w = _mm256_cvtepu8_epi16(v_m01_b);

    const __m256i v_m0l_w = _mm256_unpacklo_epi16(v_m0_w, v_mask_max_w);
    const __m256i v_m0h_w = _mm256_unpackhi_epi16(v_m0_w, v_mask_max_w);

    // Per 32-bit element: m*d + MAX_MASK_VALUE*r1.
    const __m256i v_t0l_d = _mm256_madd_epi16(v_rd0l_w, v_m0l_w);
    const __m256i v_t0h_d = _mm256_madd_epi16(v_rd0h_w, v_m0h_w);

    // Saturating pack to int16 before squaring, as in the C reference.
    const __m256i v_t0_w = _mm256_packs_epi32(v_t0l_d, v_t0h_d);

    // t*t summed in adjacent pairs -> 32-bit partial sums of squares.
    const __m256i v_sq0_d = _mm256_madd_epi16(v_t0_w, v_t0_w);

    // Widen to 64 bits (mask keeps even elements, shift brings down odd
    // elements) so the running sum cannot overflow.
    const __m256i v_sum0_q = _mm256_add_epi64(
        _mm256_and_si256(v_sq0_d, v_zext_q), _mm256_srli_epi64(v_sq0_d, 32));

    v_acc0_q = _mm256_add_epi64(v_acc0_q, v_sum0_q);

    n += 16;
  } while (n);

  // Horizontal reduction of the four 64-bit accumulator lanes.
  v_acc0_q = _mm256_add_epi64(v_acc0_q, _mm256_srli_si256(v_acc0_q, 8));
  __m128i v_acc_q_0 = _mm256_castsi256_si128(v_acc0_q);
  __m128i v_acc_q_1 = _mm256_extracti128_si256(v_acc0_q, 1);
  v_acc_q_0 = _mm_add_epi64(v_acc_q_0, v_acc_q_1);
#if AOM_ARCH_X86_64
  csse = (uint64_t)_mm_extract_epi64(v_acc_q_0, 0);
#else
  // 32-bit builds have no _mm_extract_epi64; spill through memory instead.
  xx_storel_64(&csse, v_acc_q_0);
#endif

  return ROUND_POWER_OF_TWO(csse, 2 * WEDGE_WEIGHT_BITS);
}

/**
 * See av1_wedge_sign_from_residuals_c
 *
 * Returns 1 if sum(m[i] * ds[i]) > limit, else 0. ds is accumulated as
 * 16x16 -> 32-bit dot products with the zero-extended mask bytes, then the
 * eight 32-bit lane sums are sign-extended to 64 bits and reduced.
 */
int8_t av1_wedge_sign_from_residuals_avx2(const int16_t *ds, const uint8_t *m,
                                          int N, int64_t limit) {
  int64_t acc;
  __m256i v_acc0_d = _mm256_setzero_si256();

  // Input size limited to 8192 by the use of 32 bit accumulators and m
  // being between [0, 64]. Overflow might happen at larger sizes,
  // though it is practically impossible on real video input.
  assert(N < 8192);
  assert(N % 64 == 0);

  do {
    // 64 samples per iteration: two 32-byte mask loads, four int16 loads.
    const __m256i v_m01_b = _mm256_lddqu_si256((__m256i *)(m));
    const __m256i v_m23_b = _mm256_lddqu_si256((__m256i *)(m + 32));

    const __m256i v_d0_w = _mm256_lddqu_si256((__m256i *)(ds));
    const __m256i v_d1_w = _mm256_lddqu_si256((__m256i *)(ds + 16));
    const __m256i v_d2_w = _mm256_lddqu_si256((__m256i *)(ds + 32));
    const __m256i v_d3_w = _mm256_lddqu_si256((__m256i *)(ds + 48));

    // Zero-extend the mask bytes to 16 bits, one 128-bit half at a time.
    const __m256i v_m0_w =
        _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_m01_b));
    const __m256i v_m1_w =
        _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_m01_b, 1));
    const __m256i v_m2_w =
        _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_m23_b));
    const __m256i v_m3_w =
        _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_m23_b, 1));

    // m[i]*ds[i] summed in adjacent pairs -> 32-bit partial dot products.
    const __m256i v_p0_d = _mm256_madd_epi16(v_d0_w, v_m0_w);
    const __m256i v_p1_d = _mm256_madd_epi16(v_d1_w, v_m1_w);
    const __m256i v_p2_d = _mm256_madd_epi16(v_d2_w, v_m2_w);
    const __m256i v_p3_d = _mm256_madd_epi16(v_d3_w, v_m3_w);

    const __m256i v_p01_d = _mm256_add_epi32(v_p0_d, v_p1_d);
    const __m256i v_p23_d = _mm256_add_epi32(v_p2_d, v_p3_d);

    const __m256i v_p0123_d = _mm256_add_epi32(v_p01_d, v_p23_d);

    v_acc0_d = _mm256_add_epi32(v_acc0_d, v_p0123_d);

    ds += 64;
    m += 64;

    N -= 64;
  } while (N);

  // Sign-extend the eight 32-bit lane sums to 64 bits (interleave each
  // element with its arithmetic sign) and add the resulting pairs.
  __m256i v_sign_d = _mm256_srai_epi32(v_acc0_d, 31);
  v_acc0_d = _mm256_add_epi64(_mm256_unpacklo_epi32(v_acc0_d, v_sign_d),
                              _mm256_unpackhi_epi32(v_acc0_d, v_sign_d));

  // Reduce four 64-bit lanes down to one.
  __m256i v_acc_q = _mm256_add_epi64(v_acc0_d, _mm256_srli_si256(v_acc0_d, 8));

  __m128i v_acc_q_0 = _mm256_castsi256_si128(v_acc_q);
  __m128i v_acc_q_1 = _mm256_extracti128_si256(v_acc_q, 1);
  v_acc_q_0 = _mm_add_epi64(v_acc_q_0, v_acc_q_1);

#if AOM_ARCH_X86_64
  acc = _mm_extract_epi64(v_acc_q_0, 0);
#else
  // 32-bit builds have no _mm_extract_epi64; spill through memory instead.
  xx_storel_64(&acc, v_acc_q_0);
#endif

  return acc > limit;
}
/**
 * av1_wedge_compute_delta_squares_c
 *
 * AVX2 version: writes d[i] = a[i]^2 - b[i]^2 (saturated to int16) for N
 * samples. Each (a, b) pair is multiplied by (a, -b) via one madd, giving
 * a*a - b*b per 32-bit element; the results are then packed back to int16
 * with saturation.
 *
 * d : output, N int16 values. NOTE: stored with _mm256_store_si256, so d
 *     must be 32-byte aligned.
 * a, b : inputs, N int16 values each (unaligned loads are used).
 * N : sample count; must be a multiple of 64.
 */
void av1_wedge_compute_delta_squares_avx2(int16_t *d, const int16_t *a,
                                          const int16_t *b, int N) {
  // Per 16-bit lane this is the repeating pattern (+1, -1): used with
  // _mm256_sign_epi16 to negate every odd (b) element of the interleaved
  // pairs below.
  const __m256i v_neg_w = _mm256_set1_epi32((int)0xffff0001);

  assert(N % 64 == 0);

  do {
    // 64 samples per iteration: four 16-element loads from each input.
    const __m256i v_a0_w = _mm256_lddqu_si256((__m256i *)(a));
    const __m256i v_b0_w = _mm256_lddqu_si256((__m256i *)(b));
    const __m256i v_a1_w = _mm256_lddqu_si256((__m256i *)(a + 16));
    const __m256i v_b1_w = _mm256_lddqu_si256((__m256i *)(b + 16));
    const __m256i v_a2_w = _mm256_lddqu_si256((__m256i *)(a + 32));
    const __m256i v_b2_w = _mm256_lddqu_si256((__m256i *)(b + 32));
    const __m256i v_a3_w = _mm256_lddqu_si256((__m256i *)(a + 48));
    const __m256i v_b3_w = _mm256_lddqu_si256((__m256i *)(b + 48));

    // Interleave into (a, b) pairs. The per-lane shuffle of unpacklo/hi is
    // undone by the matching packs_epi32 at the end, so elements come back
    // out in their original order.
    const __m256i v_ab0l_w = _mm256_unpacklo_epi16(v_a0_w, v_b0_w);
    const __m256i v_ab0h_w = _mm256_unpackhi_epi16(v_a0_w, v_b0_w);
    const __m256i v_ab1l_w = _mm256_unpacklo_epi16(v_a1_w, v_b1_w);
    const __m256i v_ab1h_w = _mm256_unpackhi_epi16(v_a1_w, v_b1_w);
    const __m256i v_ab2l_w = _mm256_unpacklo_epi16(v_a2_w, v_b2_w);
    const __m256i v_ab2h_w = _mm256_unpackhi_epi16(v_a2_w, v_b2_w);
    const __m256i v_ab3l_w = _mm256_unpacklo_epi16(v_a3_w, v_b3_w);
    const __m256i v_ab3h_w = _mm256_unpackhi_epi16(v_a3_w, v_b3_w);

    // Negate top word of pairs
    const __m256i v_abl0n_w = _mm256_sign_epi16(v_ab0l_w, v_neg_w);
    const __m256i v_abh0n_w = _mm256_sign_epi16(v_ab0h_w, v_neg_w);
    const __m256i v_abl1n_w = _mm256_sign_epi16(v_ab1l_w, v_neg_w);
    const __m256i v_abh1n_w = _mm256_sign_epi16(v_ab1h_w, v_neg_w);
    const __m256i v_abl2n_w = _mm256_sign_epi16(v_ab2l_w, v_neg_w);
    const __m256i v_abh2n_w = _mm256_sign_epi16(v_ab2h_w, v_neg_w);
    const __m256i v_abl3n_w = _mm256_sign_epi16(v_ab3l_w, v_neg_w);
    const __m256i v_abh3n_w = _mm256_sign_epi16(v_ab3h_w, v_neg_w);

    // madd of (a, b) with (a, -b): a*a + b*(-b) = a^2 - b^2 per element.
    const __m256i v_r0l_w = _mm256_madd_epi16(v_ab0l_w, v_abl0n_w);
    const __m256i v_r0h_w = _mm256_madd_epi16(v_ab0h_w, v_abh0n_w);
    const __m256i v_r1l_w = _mm256_madd_epi16(v_ab1l_w, v_abl1n_w);
    const __m256i v_r1h_w = _mm256_madd_epi16(v_ab1h_w, v_abh1n_w);
    const __m256i v_r2l_w = _mm256_madd_epi16(v_ab2l_w, v_abl2n_w);
    const __m256i v_r2h_w = _mm256_madd_epi16(v_ab2h_w, v_abh2n_w);
    const __m256i v_r3l_w = _mm256_madd_epi16(v_ab3l_w, v_abl3n_w);
    const __m256i v_r3h_w = _mm256_madd_epi16(v_ab3h_w, v_abh3n_w);

    // Saturating pack of the 32-bit deltas back to int16.
    const __m256i v_r0_w = _mm256_packs_epi32(v_r0l_w, v_r0h_w);
    const __m256i v_r1_w = _mm256_packs_epi32(v_r1l_w, v_r1h_w);
    const __m256i v_r2_w = _mm256_packs_epi32(v_r2l_w, v_r2h_w);
    const __m256i v_r3_w = _mm256_packs_epi32(v_r3l_w, v_r3h_w);

    // Aligned stores: d is required to be 32-byte aligned.
    _mm256_store_si256((__m256i *)(d), v_r0_w);
    _mm256_store_si256((__m256i *)(d + 16), v_r1_w);
    _mm256_store_si256((__m256i *)(d + 32), v_r2_w);
    _mm256_store_si256((__m256i *)(d + 48), v_r3_w);

    a += 64;
    b += 64;
    d += 64;
    N -= 64;
  } while (N);
}