av1_k_means_avx2.c (5167B)
1 /* 2 * Copyright (c) 2020, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 #include <immintrin.h> // AVX2 12 13 #include "config/av1_rtcd.h" 14 #include "aom_dsp/x86/synonyms.h" 15 16 static int64_t k_means_horizontal_sum_avx2(__m256i a) { 17 const __m128i low = _mm256_castsi256_si128(a); 18 const __m128i high = _mm256_extracti128_si256(a, 1); 19 const __m128i sum = _mm_add_epi64(low, high); 20 const __m128i sum_high = _mm_unpackhi_epi64(sum, sum); 21 int64_t res; 22 _mm_storel_epi64((__m128i *)&res, _mm_add_epi64(sum, sum_high)); 23 return res; 24 } 25 26 void av1_calc_indices_dim1_avx2(const int16_t *data, const int16_t *centroids, 27 uint8_t *indices, int64_t *total_dist, int n, 28 int k) { 29 const __m256i v_zero = _mm256_setzero_si256(); 30 __m256i sum = _mm256_setzero_si256(); 31 __m256i cents[PALETTE_MAX_SIZE]; 32 for (int j = 0; j < k; ++j) { 33 cents[j] = _mm256_set1_epi16(centroids[j]); 34 } 35 36 for (int i = 0; i < n; i += 16) { 37 const __m256i in = _mm256_loadu_si256((__m256i *)data); 38 __m256i ind = _mm256_setzero_si256(); 39 // Compute the distance to the first centroid. 40 __m256i d1 = _mm256_sub_epi16(in, cents[0]); 41 __m256i dist_min = _mm256_abs_epi16(d1); 42 43 for (int j = 1; j < k; ++j) { 44 // Compute the distance to the centroid. 45 d1 = _mm256_sub_epi16(in, cents[j]); 46 const __m256i dist = _mm256_abs_epi16(d1); 47 // Compare to the minimal one. 48 const __m256i cmp = _mm256_cmpgt_epi16(dist_min, dist); 49 dist_min = _mm256_min_epi16(dist_min, dist); 50 const __m256i ind1 = _mm256_set1_epi16(j); 51 ind = _mm256_or_si256(_mm256_andnot_si256(cmp, ind), 52 _mm256_and_si256(cmp, ind1)); 53 } 54 55 const __m256i p1 = _mm256_packus_epi16(ind, v_zero); 56 const __m256i px = _mm256_permute4x64_epi64(p1, 0x58); 57 const __m128i d2 = _mm256_extracti128_si256(px, 0); 58 59 _mm_storeu_si128((__m128i *)indices, d2); 60 61 if (total_dist) { 62 // Square, convert to 32 bit and add together. 63 dist_min = _mm256_madd_epi16(dist_min, dist_min); 64 // Convert to 64 bit and add to sum. 65 const __m256i dist1 = _mm256_unpacklo_epi32(dist_min, v_zero); 66 const __m256i dist2 = _mm256_unpackhi_epi32(dist_min, v_zero); 67 sum = _mm256_add_epi64(sum, dist1); 68 sum = _mm256_add_epi64(sum, dist2); 69 } 70 71 indices += 16; 72 data += 16; 73 } 74 if (total_dist) { 75 *total_dist = k_means_horizontal_sum_avx2(sum); 76 } 77 } 78 79 void av1_calc_indices_dim2_avx2(const int16_t *data, const int16_t *centroids, 80 uint8_t *indices, int64_t *total_dist, int n, 81 int k) { 82 const __m256i v_zero = _mm256_setzero_si256(); 83 const __m256i permute = _mm256_set_epi32(0, 0, 0, 0, 5, 1, 4, 0); 84 __m256i sum = _mm256_setzero_si256(); 85 __m256i ind[2]; 86 __m256i cents[PALETTE_MAX_SIZE]; 87 for (int j = 0; j < k; ++j) { 88 const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1]; 89 cents[j] = _mm256_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx, cy, cx, cy, cx, 90 cy, cx, cy, cx); 91 } 92 93 for (int i = 0; i < n; i += 16) { 94 for (int l = 0; l < 2; ++l) { 95 const __m256i in = _mm256_loadu_si256((__m256i *)data); 96 ind[l] = _mm256_setzero_si256(); 97 // Compute the distance to the first centroid. 98 __m256i d1 = _mm256_sub_epi16(in, cents[0]); 99 __m256i dist_min = _mm256_madd_epi16(d1, d1); 100 101 for (int j = 1; j < k; ++j) { 102 // Compute the distance to the centroid. 103 d1 = _mm256_sub_epi16(in, cents[j]); 104 const __m256i dist = _mm256_madd_epi16(d1, d1); 105 // Compare to the minimal one. 106 const __m256i cmp = _mm256_cmpgt_epi32(dist_min, dist); 107 dist_min = _mm256_min_epi32(dist_min, dist); 108 const __m256i ind1 = _mm256_set1_epi32(j); 109 ind[l] = _mm256_or_si256(_mm256_andnot_si256(cmp, ind[l]), 110 _mm256_and_si256(cmp, ind1)); 111 } 112 if (total_dist) { 113 // Convert to 64 bit and add to sum. 114 const __m256i dist1 = _mm256_unpacklo_epi32(dist_min, v_zero); 115 const __m256i dist2 = _mm256_unpackhi_epi32(dist_min, v_zero); 116 sum = _mm256_add_epi64(sum, dist1); 117 sum = _mm256_add_epi64(sum, dist2); 118 } 119 data += 16; 120 } 121 // Cast to 8 bit and store. 122 const __m256i d2 = _mm256_packus_epi32(ind[0], ind[1]); 123 const __m256i d3 = _mm256_packus_epi16(d2, v_zero); 124 const __m256i d4 = _mm256_permutevar8x32_epi32(d3, permute); 125 const __m128i d5 = _mm256_extracti128_si256(d4, 0); 126 _mm_storeu_si128((__m128i *)indices, d5); 127 indices += 16; 128 } 129 if (total_dist) { 130 *total_dist = k_means_horizontal_sum_avx2(sum); 131 } 132 }