tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

sse_avx2.c (14122B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <smmintrin.h>
     13 #include <immintrin.h>
     14 
     15 #include "config/aom_config.h"
     16 #include "config/aom_dsp_rtcd.h"
     17 
     18 #include "aom_ports/mem.h"
     19 #include "aom_dsp/x86/synonyms.h"
     20 #include "aom_dsp/x86/synonyms_avx2.h"
     21 
     22 static inline void sse_w32_avx2(__m256i *sum, const uint8_t *a,
     23                                const uint8_t *b) {
     24  const __m256i v_a0 = yy_loadu_256(a);
     25  const __m256i v_b0 = yy_loadu_256(b);
     26  const __m256i zero = _mm256_setzero_si256();
     27  const __m256i v_a00_w = _mm256_unpacklo_epi8(v_a0, zero);
     28  const __m256i v_a01_w = _mm256_unpackhi_epi8(v_a0, zero);
     29  const __m256i v_b00_w = _mm256_unpacklo_epi8(v_b0, zero);
     30  const __m256i v_b01_w = _mm256_unpackhi_epi8(v_b0, zero);
     31  const __m256i v_d00_w = _mm256_sub_epi16(v_a00_w, v_b00_w);
     32  const __m256i v_d01_w = _mm256_sub_epi16(v_a01_w, v_b01_w);
     33  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d00_w, v_d00_w));
     34  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d01_w, v_d01_w));
     35 }
     36 
     37 static inline int64_t summary_all_avx2(const __m256i *sum_all) {
     38  int64_t sum;
     39  __m256i zero = _mm256_setzero_si256();
     40  const __m256i sum0_4x64 = _mm256_unpacklo_epi32(*sum_all, zero);
     41  const __m256i sum1_4x64 = _mm256_unpackhi_epi32(*sum_all, zero);
     42  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
     43  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
     44                                         _mm256_extracti128_si256(sum_4x64, 1));
     45  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
     46  xx_storel_64(&sum, sum_1x64);
     47  return sum;
     48 }
     49 
     50 #if CONFIG_AV1_HIGHBITDEPTH
     51 static inline void summary_32_avx2(const __m256i *sum32, __m256i *sum) {
     52  const __m256i sum0_4x64 =
     53      _mm256_cvtepu32_epi64(_mm256_castsi256_si128(*sum32));
     54  const __m256i sum1_4x64 =
     55      _mm256_cvtepu32_epi64(_mm256_extracti128_si256(*sum32, 1));
     56  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
     57  *sum = _mm256_add_epi64(*sum, sum_4x64);
     58 }
     59 
     60 static inline int64_t summary_4x64_avx2(const __m256i sum_4x64) {
     61  int64_t sum;
     62  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
     63                                         _mm256_extracti128_si256(sum_4x64, 1));
     64  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
     65 
     66  xx_storel_64(&sum, sum_1x64);
     67  return sum;
     68 }
     69 #endif
     70 
     71 static inline void sse_w4x4_avx2(const uint8_t *a, int a_stride,
     72                                 const uint8_t *b, int b_stride, __m256i *sum) {
     73  const __m128i v_a0 = xx_loadl_32(a);
     74  const __m128i v_a1 = xx_loadl_32(a + a_stride);
     75  const __m128i v_a2 = xx_loadl_32(a + a_stride * 2);
     76  const __m128i v_a3 = xx_loadl_32(a + a_stride * 3);
     77  const __m128i v_b0 = xx_loadl_32(b);
     78  const __m128i v_b1 = xx_loadl_32(b + b_stride);
     79  const __m128i v_b2 = xx_loadl_32(b + b_stride * 2);
     80  const __m128i v_b3 = xx_loadl_32(b + b_stride * 3);
     81  const __m128i v_a0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_a0, v_a1),
     82                                             _mm_unpacklo_epi32(v_a2, v_a3));
     83  const __m128i v_b0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_b0, v_b1),
     84                                             _mm_unpacklo_epi32(v_b2, v_b3));
     85  const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0123);
     86  const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0123);
     87  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
     88  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
     89 }
     90 
     91 static inline void sse_w8x2_avx2(const uint8_t *a, int a_stride,
     92                                 const uint8_t *b, int b_stride, __m256i *sum) {
     93  const __m128i v_a0 = xx_loadl_64(a);
     94  const __m128i v_a1 = xx_loadl_64(a + a_stride);
     95  const __m128i v_b0 = xx_loadl_64(b);
     96  const __m128i v_b1 = xx_loadl_64(b + b_stride);
     97  const __m256i v_a_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_a0, v_a1));
     98  const __m256i v_b_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_b0, v_b1));
     99  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
    100  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
    101 }
    102 
    103 int64_t aom_sse_avx2(const uint8_t *a, int a_stride, const uint8_t *b,
    104                     int b_stride, int width, int height) {
    105  int32_t y = 0;
    106  int64_t sse = 0;
    107  __m256i sum = _mm256_setzero_si256();
    108  __m256i zero = _mm256_setzero_si256();
    109  switch (width) {
    110    case 4:
    111      do {
    112        sse_w4x4_avx2(a, a_stride, b, b_stride, &sum);
    113        a += a_stride << 2;
    114        b += b_stride << 2;
    115        y += 4;
    116      } while (y < height);
    117      sse = summary_all_avx2(&sum);
    118      break;
    119    case 8:
    120      do {
    121        sse_w8x2_avx2(a, a_stride, b, b_stride, &sum);
    122        a += a_stride << 1;
    123        b += b_stride << 1;
    124        y += 2;
    125      } while (y < height);
    126      sse = summary_all_avx2(&sum);
    127      break;
    128    case 16:
    129      do {
    130        const __m128i v_a0 = xx_loadu_128(a);
    131        const __m128i v_a1 = xx_loadu_128(a + a_stride);
    132        const __m128i v_b0 = xx_loadu_128(b);
    133        const __m128i v_b1 = xx_loadu_128(b + b_stride);
    134        const __m256i v_a =
    135            _mm256_insertf128_si256(_mm256_castsi128_si256(v_a0), v_a1, 0x01);
    136        const __m256i v_b =
    137            _mm256_insertf128_si256(_mm256_castsi128_si256(v_b0), v_b1, 0x01);
    138        const __m256i v_al = _mm256_unpacklo_epi8(v_a, zero);
    139        const __m256i v_au = _mm256_unpackhi_epi8(v_a, zero);
    140        const __m256i v_bl = _mm256_unpacklo_epi8(v_b, zero);
    141        const __m256i v_bu = _mm256_unpackhi_epi8(v_b, zero);
    142        const __m256i v_asub = _mm256_sub_epi16(v_al, v_bl);
    143        const __m256i v_bsub = _mm256_sub_epi16(v_au, v_bu);
    144        const __m256i temp =
    145            _mm256_add_epi32(_mm256_madd_epi16(v_asub, v_asub),
    146                             _mm256_madd_epi16(v_bsub, v_bsub));
    147        sum = _mm256_add_epi32(sum, temp);
    148        a += a_stride << 1;
    149        b += b_stride << 1;
    150        y += 2;
    151      } while (y < height);
    152      sse = summary_all_avx2(&sum);
    153      break;
    154    case 32:
    155      do {
    156        sse_w32_avx2(&sum, a, b);
    157        a += a_stride;
    158        b += b_stride;
    159        y += 1;
    160      } while (y < height);
    161      sse = summary_all_avx2(&sum);
    162      break;
    163    case 64:
    164      do {
    165        sse_w32_avx2(&sum, a, b);
    166        sse_w32_avx2(&sum, a + 32, b + 32);
    167        a += a_stride;
    168        b += b_stride;
    169        y += 1;
    170      } while (y < height);
    171      sse = summary_all_avx2(&sum);
    172      break;
    173    case 128:
    174      do {
    175        sse_w32_avx2(&sum, a, b);
    176        sse_w32_avx2(&sum, a + 32, b + 32);
    177        sse_w32_avx2(&sum, a + 64, b + 64);
    178        sse_w32_avx2(&sum, a + 96, b + 96);
    179        a += a_stride;
    180        b += b_stride;
    181        y += 1;
    182      } while (y < height);
    183      sse = summary_all_avx2(&sum);
    184      break;
    185    default:
    186      if ((width & 0x07) == 0) {
    187        do {
    188          int i = 0;
    189          do {
    190            sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum);
    191            i += 8;
    192          } while (i < width);
    193          a += a_stride << 1;
    194          b += b_stride << 1;
    195          y += 2;
    196        } while (y < height);
    197      } else {
    198        do {
    199          int i = 0;
    200          do {
    201            sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum);
    202            const uint8_t *a2 = a + i + (a_stride << 1);
    203            const uint8_t *b2 = b + i + (b_stride << 1);
    204            sse_w8x2_avx2(a2, a_stride, b2, b_stride, &sum);
    205            i += 8;
    206          } while (i + 4 < width);
    207          sse_w4x4_avx2(a + i, a_stride, b + i, b_stride, &sum);
    208          a += a_stride << 2;
    209          b += b_stride << 2;
    210          y += 4;
    211        } while (y < height);
    212      }
    213      sse = summary_all_avx2(&sum);
    214      break;
    215  }
    216 
    217  return sse;
    218 }
    219 
    220 #if CONFIG_AV1_HIGHBITDEPTH
    221 static inline void highbd_sse_w16_avx2(__m256i *sum, const uint16_t *a,
    222                                       const uint16_t *b) {
    223  const __m256i v_a_w = yy_loadu_256(a);
    224  const __m256i v_b_w = yy_loadu_256(b);
    225  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
    226  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
    227 }
    228 
    229 static inline void highbd_sse_w4x4_avx2(__m256i *sum, const uint16_t *a,
    230                                        int a_stride, const uint16_t *b,
    231                                        int b_stride) {
    232  const __m128i v_a0 = xx_loadl_64(a);
    233  const __m128i v_a1 = xx_loadl_64(a + a_stride);
    234  const __m128i v_a2 = xx_loadl_64(a + a_stride * 2);
    235  const __m128i v_a3 = xx_loadl_64(a + a_stride * 3);
    236  const __m128i v_b0 = xx_loadl_64(b);
    237  const __m128i v_b1 = xx_loadl_64(b + b_stride);
    238  const __m128i v_b2 = xx_loadl_64(b + b_stride * 2);
    239  const __m128i v_b3 = xx_loadl_64(b + b_stride * 3);
    240  const __m256i v_a_w = yy_set_m128i(_mm_unpacklo_epi64(v_a0, v_a1),
    241                                     _mm_unpacklo_epi64(v_a2, v_a3));
    242  const __m256i v_b_w = yy_set_m128i(_mm_unpacklo_epi64(v_b0, v_b1),
    243                                     _mm_unpacklo_epi64(v_b2, v_b3));
    244  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
    245  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
    246 }
    247 
    248 static inline void highbd_sse_w8x2_avx2(__m256i *sum, const uint16_t *a,
    249                                        int a_stride, const uint16_t *b,
    250                                        int b_stride) {
    251  const __m256i v_a_w = yy_loadu2_128(a + a_stride, a);
    252  const __m256i v_b_w = yy_loadu2_128(b + b_stride, b);
    253  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
    254  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
    255 }
    256 
    257 int64_t aom_highbd_sse_avx2(const uint8_t *a8, int a_stride, const uint8_t *b8,
    258                            int b_stride, int width, int height) {
    259  int32_t y = 0;
    260  int64_t sse = 0;
    261  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
    262  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
    263  __m256i sum = _mm256_setzero_si256();
    264  switch (width) {
    265    case 4:
    266      do {
    267        highbd_sse_w4x4_avx2(&sum, a, a_stride, b, b_stride);
    268        a += a_stride << 2;
    269        b += b_stride << 2;
    270        y += 4;
    271      } while (y < height);
    272      sse = summary_all_avx2(&sum);
    273      break;
    274    case 8:
    275      do {
    276        highbd_sse_w8x2_avx2(&sum, a, a_stride, b, b_stride);
    277        a += a_stride << 1;
    278        b += b_stride << 1;
    279        y += 2;
    280      } while (y < height);
    281      sse = summary_all_avx2(&sum);
    282      break;
    283    case 16:
    284      do {
    285        highbd_sse_w16_avx2(&sum, a, b);
    286        a += a_stride;
    287        b += b_stride;
    288        y += 1;
    289      } while (y < height);
    290      sse = summary_all_avx2(&sum);
    291      break;
    292    case 32:
    293      do {
    294        int l = 0;
    295        __m256i sum32 = _mm256_setzero_si256();
    296        do {
    297          highbd_sse_w16_avx2(&sum32, a, b);
    298          highbd_sse_w16_avx2(&sum32, a + 16, b + 16);
    299          a += a_stride;
    300          b += b_stride;
    301          l += 1;
    302        } while (l < 64 && l < (height - y));
    303        summary_32_avx2(&sum32, &sum);
    304        y += 64;
    305      } while (y < height);
    306      sse = summary_4x64_avx2(sum);
    307      break;
    308    case 64:
    309      do {
    310        int l = 0;
    311        __m256i sum32 = _mm256_setzero_si256();
    312        do {
    313          highbd_sse_w16_avx2(&sum32, a, b);
    314          highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1);
    315          highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2);
    316          highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3);
    317          a += a_stride;
    318          b += b_stride;
    319          l += 1;
    320        } while (l < 32 && l < (height - y));
    321        summary_32_avx2(&sum32, &sum);
    322        y += 32;
    323      } while (y < height);
    324      sse = summary_4x64_avx2(sum);
    325      break;
    326    case 128:
    327      do {
    328        int l = 0;
    329        __m256i sum32 = _mm256_setzero_si256();
    330        do {
    331          highbd_sse_w16_avx2(&sum32, a, b);
    332          highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1);
    333          highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2);
    334          highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3);
    335          highbd_sse_w16_avx2(&sum32, a + 16 * 4, b + 16 * 4);
    336          highbd_sse_w16_avx2(&sum32, a + 16 * 5, b + 16 * 5);
    337          highbd_sse_w16_avx2(&sum32, a + 16 * 6, b + 16 * 6);
    338          highbd_sse_w16_avx2(&sum32, a + 16 * 7, b + 16 * 7);
    339          a += a_stride;
    340          b += b_stride;
    341          l += 1;
    342        } while (l < 16 && l < (height - y));
    343        summary_32_avx2(&sum32, &sum);
    344        y += 16;
    345      } while (y < height);
    346      sse = summary_4x64_avx2(sum);
    347      break;
    348    default:
    349      if (width & 0x7) {
    350        do {
    351          int i = 0;
    352          __m256i sum32 = _mm256_setzero_si256();
    353          do {
    354            highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride);
    355            const uint16_t *a2 = a + i + (a_stride << 1);
    356            const uint16_t *b2 = b + i + (b_stride << 1);
    357            highbd_sse_w8x2_avx2(&sum32, a2, a_stride, b2, b_stride);
    358            i += 8;
    359          } while (i + 4 < width);
    360          highbd_sse_w4x4_avx2(&sum32, a + i, a_stride, b + i, b_stride);
    361          summary_32_avx2(&sum32, &sum);
    362          a += a_stride << 2;
    363          b += b_stride << 2;
    364          y += 4;
    365        } while (y < height);
    366      } else {
    367        do {
    368          int l = 0;
    369          __m256i sum32 = _mm256_setzero_si256();
    370          do {
    371            int i = 0;
    372            do {
    373              highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride);
    374              i += 8;
    375            } while (i < width);
    376            a += a_stride << 1;
    377            b += b_stride << 1;
    378            l += 2;
    379          } while (l < 8 && l < (height - y));
    380          summary_32_avx2(&sum32, &sum);
    381          y += 8;
    382        } while (y < height);
    383      }
    384      sse = summary_4x64_avx2(sum);
    385      break;
    386  }
    387  return sse;
    388 }
    389 #endif  // CONFIG_AV1_HIGHBITDEPTH