tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

wedge_utils_avx2.c (8118B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <immintrin.h>
     14 #include <smmintrin.h>
     15 
     16 #include "aom_dsp/x86/synonyms.h"
     17 #include "aom_dsp/x86/synonyms_avx2.h"
     18 #include "aom/aom_integer.h"
     19 
     20 #include "av1/common/reconinter.h"
     21 
     22 #define MAX_MASK_VALUE (1 << WEDGE_WEIGHT_BITS)
     23 
/**
 * See av1_wedge_sse_from_residuals_c
 *
 * Per element the kernel forms t = d[i] * m[i] + r1[i] * MAX_MASK_VALUE
 * (via interleave + madd below), saturates t to int16 (the packs step),
 * accumulates t * t, and finally rounds the total down by
 * 2 * WEDGE_WEIGHT_BITS to undo the mask scaling.
 *
 * r1 : residual block, N int16 elements
 * d  : difference block, N int16 elements
 * m  : blending mask, N bytes in [0, MAX_MASK_VALUE]
 * N  : element count; must be a multiple of 64 (asserted)
 */
uint64_t av1_wedge_sse_from_residuals_avx2(const int16_t *r1, const int16_t *d,
                                          const uint8_t *m, int N) {
 // Loop counter runs from -N up to 0 so termination is a plain "!= 0";
 // the data pointers are pre-advanced by N below to compensate.
 int n = -N;

 uint64_t csse;

 const __m256i v_mask_max_w = _mm256_set1_epi16(MAX_MASK_VALUE);
 // ~0u zero-extends to 0x00000000ffffffff in each 64-bit lane: a mask
 // that keeps only the even (low) 32-bit half of every qword.
 const __m256i v_zext_q = _mm256_set1_epi64x(~0u);

 __m256i v_acc0_q = _mm256_setzero_si256();

 assert(N % 64 == 0);

 r1 += N;
 d += N;
 m += N;

 // Each iteration processes 16 elements.
 do {
   const __m256i v_r0_w = _mm256_lddqu_si256((__m256i *)(r1 + n));
   const __m256i v_d0_w = _mm256_lddqu_si256((__m256i *)(d + n));
   const __m128i v_m01_b = _mm_lddqu_si128((__m128i *)(m + n));

   // Interleave d with r1, and m with MAX_MASK_VALUE, so that the madd
   // below yields d[i]*m[i] + r1[i]*MAX_MASK_VALUE in each 32-bit lane.
   const __m256i v_rd0l_w = _mm256_unpacklo_epi16(v_d0_w, v_r0_w);
   const __m256i v_rd0h_w = _mm256_unpackhi_epi16(v_d0_w, v_r0_w);
   const __m256i v_m0_w = _mm256_cvtepu8_epi16(v_m01_b);

   const __m256i v_m0l_w = _mm256_unpacklo_epi16(v_m0_w, v_mask_max_w);
   const __m256i v_m0h_w = _mm256_unpackhi_epi16(v_m0_w, v_mask_max_w);

   const __m256i v_t0l_d = _mm256_madd_epi16(v_rd0l_w, v_m0l_w);
   const __m256i v_t0h_d = _mm256_madd_epi16(v_rd0h_w, v_m0h_w);

   // Saturating pack back to int16: mirrors the clamp in the C reference.
   const __m256i v_t0_w = _mm256_packs_epi32(v_t0l_d, v_t0h_d);

   // Square and pairwise-add: each 32-bit lane now holds t²[2k] + t²[2k+1].
   const __m256i v_sq0_d = _mm256_madd_epi16(v_t0_w, v_t0_w);

   // Widen dword partial sums to qwords without overflow: zero-extended
   // even lanes plus the odd lanes shifted down into the low half.
   const __m256i v_sum0_q = _mm256_add_epi64(
       _mm256_and_si256(v_sq0_d, v_zext_q), _mm256_srli_epi64(v_sq0_d, 32));

   v_acc0_q = _mm256_add_epi64(v_acc0_q, v_sum0_q);

   n += 16;
 } while (n);

 // Horizontal reduction of the four 64-bit partial sums to a scalar.
 v_acc0_q = _mm256_add_epi64(v_acc0_q, _mm256_srli_si256(v_acc0_q, 8));
 __m128i v_acc_q_0 = _mm256_castsi256_si128(v_acc0_q);
 __m128i v_acc_q_1 = _mm256_extracti128_si256(v_acc0_q, 1);
 v_acc_q_0 = _mm_add_epi64(v_acc_q_0, v_acc_q_1);
#if AOM_ARCH_X86_64
 csse = (uint64_t)_mm_extract_epi64(v_acc_q_0, 0);
#else
 // 32-bit builds lack a 64-bit extract; spill the low qword through memory.
 xx_storel_64(&csse, v_acc_q_0);
#endif

 return ROUND_POWER_OF_TWO(csse, 2 * WEDGE_WEIGHT_BITS);
}
     83 
/**
 * See av1_wedge_sign_from_residuals_c
 *
 * Accumulates acc = sum(ds[i] * m[i]) over all N elements and returns
 * 1 if acc > limit, 0 otherwise.
 *
 * ds    : per-element residual difference values, N int16 elements
 * m     : blending mask, N bytes in [0, 64]
 * N     : element count; must be a multiple of 64 and < 8192 (asserted)
 * limit : signed threshold the accumulated sum is compared against
 */
int8_t av1_wedge_sign_from_residuals_avx2(const int16_t *ds, const uint8_t *m,
                                         int N, int64_t limit) {
 int64_t acc;
 __m256i v_acc0_d = _mm256_setzero_si256();

 // Input size limited to 8192 by the use of 32 bit accumulators and m
 // being between [0, 64]. Overflow might happen at larger sizes,
 // though it is practically impossible on real video input.
 assert(N < 8192);
 assert(N % 64 == 0);

 // Each iteration processes 64 elements: 64 mask bytes and 4x16 int16.
 do {
   const __m256i v_m01_b = _mm256_lddqu_si256((__m256i *)(m));
   const __m256i v_m23_b = _mm256_lddqu_si256((__m256i *)(m + 32));

   const __m256i v_d0_w = _mm256_lddqu_si256((__m256i *)(ds));
   const __m256i v_d1_w = _mm256_lddqu_si256((__m256i *)(ds + 16));
   const __m256i v_d2_w = _mm256_lddqu_si256((__m256i *)(ds + 32));
   const __m256i v_d3_w = _mm256_lddqu_si256((__m256i *)(ds + 48));

   // Zero-extend the mask bytes to 16 bits, one 128-bit half at a time.
   const __m256i v_m0_w =
       _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_m01_b));
   const __m256i v_m1_w =
       _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_m01_b, 1));
   const __m256i v_m2_w =
       _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_m23_b));
   const __m256i v_m3_w =
       _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_m23_b, 1));

   // madd: ds[2k]*m[2k] + ds[2k+1]*m[2k+1] per 32-bit lane.
   const __m256i v_p0_d = _mm256_madd_epi16(v_d0_w, v_m0_w);
   const __m256i v_p1_d = _mm256_madd_epi16(v_d1_w, v_m1_w);
   const __m256i v_p2_d = _mm256_madd_epi16(v_d2_w, v_m2_w);
   const __m256i v_p3_d = _mm256_madd_epi16(v_d3_w, v_m3_w);

   const __m256i v_p01_d = _mm256_add_epi32(v_p0_d, v_p1_d);
   const __m256i v_p23_d = _mm256_add_epi32(v_p2_d, v_p3_d);

   const __m256i v_p0123_d = _mm256_add_epi32(v_p01_d, v_p23_d);

   v_acc0_d = _mm256_add_epi32(v_acc0_d, v_p0123_d);

   ds += 64;
   m += 64;

   N -= 64;
 } while (N);

 // Sign-extend each 32-bit partial sum to 64 bits (interleave with its
 // replicated sign bit), then add the resulting qword pairs.
 __m256i v_sign_d = _mm256_srai_epi32(v_acc0_d, 31);
 v_acc0_d = _mm256_add_epi64(_mm256_unpacklo_epi32(v_acc0_d, v_sign_d),
                             _mm256_unpackhi_epi32(v_acc0_d, v_sign_d));

 // Horizontal reduction of the 64-bit partial sums to a scalar.
 __m256i v_acc_q = _mm256_add_epi64(v_acc0_d, _mm256_srli_si256(v_acc0_d, 8));

 __m128i v_acc_q_0 = _mm256_castsi256_si128(v_acc_q);
 __m128i v_acc_q_1 = _mm256_extracti128_si256(v_acc_q, 1);
 v_acc_q_0 = _mm_add_epi64(v_acc_q_0, v_acc_q_1);

#if AOM_ARCH_X86_64
 acc = _mm_extract_epi64(v_acc_q_0, 0);
#else
 // 32-bit builds lack a 64-bit extract; spill the low qword through memory.
 xx_storel_64(&acc, v_acc_q_0);
#endif

 return acc > limit;
}
    152 
/**
 * See av1_wedge_compute_delta_squares_c
 *
 * Computes d[i] = a[i]*a[i] - b[i]*b[i] for all N elements, with the
 * result saturated to int16 by the packs step.
 *
 * Trick: interleave (a, b) pairs, then build (a, -b) with sign_epi16,
 * so that madd of the two yields a*a + b*(-b) = a² - b² per 32-bit lane.
 *
 * d : output buffer, N int16 elements; must be 32-byte aligned
 *     (written with _mm256_store_si256)
 * a : first input block, N int16 elements
 * b : second input block, N int16 elements
 * N : element count; must be a multiple of 64 (asserted)
 */
void av1_wedge_compute_delta_squares_avx2(int16_t *d, const int16_t *a,
                                         const int16_t *b, int N) {
 // 0xffff0001 per dword = alternating int16 lanes of +1, -1: the
 // sign_epi16 pattern that keeps the 'a' half and negates the 'b' half.
 const __m256i v_neg_w = _mm256_set1_epi32((int)0xffff0001);

 assert(N % 64 == 0);

 // Each iteration processes 64 elements (4x16 from each input).
 do {
   const __m256i v_a0_w = _mm256_lddqu_si256((__m256i *)(a));
   const __m256i v_b0_w = _mm256_lddqu_si256((__m256i *)(b));
   const __m256i v_a1_w = _mm256_lddqu_si256((__m256i *)(a + 16));
   const __m256i v_b1_w = _mm256_lddqu_si256((__m256i *)(b + 16));
   const __m256i v_a2_w = _mm256_lddqu_si256((__m256i *)(a + 32));
   const __m256i v_b2_w = _mm256_lddqu_si256((__m256i *)(b + 32));
   const __m256i v_a3_w = _mm256_lddqu_si256((__m256i *)(a + 48));
   const __m256i v_b3_w = _mm256_lddqu_si256((__m256i *)(b + 48));

   // Interleave a and b into (a, b) int16 pairs, low and high halves.
   const __m256i v_ab0l_w = _mm256_unpacklo_epi16(v_a0_w, v_b0_w);
   const __m256i v_ab0h_w = _mm256_unpackhi_epi16(v_a0_w, v_b0_w);
   const __m256i v_ab1l_w = _mm256_unpacklo_epi16(v_a1_w, v_b1_w);
   const __m256i v_ab1h_w = _mm256_unpackhi_epi16(v_a1_w, v_b1_w);
   const __m256i v_ab2l_w = _mm256_unpacklo_epi16(v_a2_w, v_b2_w);
   const __m256i v_ab2h_w = _mm256_unpackhi_epi16(v_a2_w, v_b2_w);
   const __m256i v_ab3l_w = _mm256_unpacklo_epi16(v_a3_w, v_b3_w);
   const __m256i v_ab3h_w = _mm256_unpackhi_epi16(v_a3_w, v_b3_w);

   // Negate top word of pairs: (a, b) -> (a, -b).
   const __m256i v_abl0n_w = _mm256_sign_epi16(v_ab0l_w, v_neg_w);
   const __m256i v_abh0n_w = _mm256_sign_epi16(v_ab0h_w, v_neg_w);
   const __m256i v_abl1n_w = _mm256_sign_epi16(v_ab1l_w, v_neg_w);
   const __m256i v_abh1n_w = _mm256_sign_epi16(v_ab1h_w, v_neg_w);
   const __m256i v_abl2n_w = _mm256_sign_epi16(v_ab2l_w, v_neg_w);
   const __m256i v_abh2n_w = _mm256_sign_epi16(v_ab2h_w, v_neg_w);
   const __m256i v_abl3n_w = _mm256_sign_epi16(v_ab3l_w, v_neg_w);
   const __m256i v_abh3n_w = _mm256_sign_epi16(v_ab3h_w, v_neg_w);

   // madd((a, b), (a, -b)) = a*a - b*b in each 32-bit lane.
   const __m256i v_r0l_w = _mm256_madd_epi16(v_ab0l_w, v_abl0n_w);
   const __m256i v_r0h_w = _mm256_madd_epi16(v_ab0h_w, v_abh0n_w);
   const __m256i v_r1l_w = _mm256_madd_epi16(v_ab1l_w, v_abl1n_w);
   const __m256i v_r1h_w = _mm256_madd_epi16(v_ab1h_w, v_abh1n_w);
   const __m256i v_r2l_w = _mm256_madd_epi16(v_ab2l_w, v_abl2n_w);
   const __m256i v_r2h_w = _mm256_madd_epi16(v_ab2h_w, v_abh2n_w);
   const __m256i v_r3l_w = _mm256_madd_epi16(v_ab3l_w, v_abl3n_w);
   const __m256i v_r3h_w = _mm256_madd_epi16(v_ab3h_w, v_abh3n_w);

   // Saturating pack back to int16; since unpack lo/hi and packs both
   // operate per 128-bit lane, this restores the original element order.
   const __m256i v_r0_w = _mm256_packs_epi32(v_r0l_w, v_r0h_w);
   const __m256i v_r1_w = _mm256_packs_epi32(v_r1l_w, v_r1h_w);
   const __m256i v_r2_w = _mm256_packs_epi32(v_r2l_w, v_r2h_w);
   const __m256i v_r3_w = _mm256_packs_epi32(v_r3l_w, v_r3h_w);

   _mm256_store_si256((__m256i *)(d), v_r0_w);
   _mm256_store_si256((__m256i *)(d + 16), v_r1_w);
   _mm256_store_si256((__m256i *)(d + 32), v_r2_w);
   _mm256_store_si256((__m256i *)(d + 48), v_r3_w);

   a += 64;
   b += 64;
   d += 64;
   N -= 64;
 } while (N);
}
    215 }