tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

SIMD_avx2.cpp (10081B)


      1 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
      2 /* This Source Code Form is subject to the terms of the Mozilla Public
      3 * License, v. 2.0. If a copy of the MPL was not distributed with this
      4 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      5 
      6 #include "mozilla/SIMD.h"
      7 
      8 #include "mozilla/SSE.h"
      9 #include "mozilla/Assertions.h"
     10 
     11 // Restricting to x86_64 simplifies things, and we're not particularly
     12 // worried about slightly degraded performance on 32 bit processors which
     13 // support AVX2, as this should be quite a minority.
     14 #if defined(MOZILLA_MAY_SUPPORT_AVX2) && defined(__x86_64__)
     15 
     16 #  include <cstring>
     17 #  include <immintrin.h>
     18 #  include <stdint.h>
     19 #  include <type_traits>
     20 
     21 namespace mozilla {
     22 
     23 const __m256i* Cast256(uintptr_t ptr) {
     24  return reinterpret_cast<const __m256i*>(ptr);
     25 }
     26 
     27 template <typename T>
     28 T GetAs(uintptr_t ptr) {
     29  return *reinterpret_cast<const T*>(ptr);
     30 }
     31 
     32 uintptr_t AlignDown32(uintptr_t ptr) { return ptr & ~0x1f; }
     33 
     34 uintptr_t AlignUp32(uintptr_t ptr) { return AlignDown32(ptr + 0x1f); }
     35 
     36 template <typename TValue>
     37 __m128i CmpEq128(__m128i a, __m128i b) {
     38  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2);
     39  if (sizeof(TValue) == 1) {
     40    return _mm_cmpeq_epi8(a, b);
     41  }
     42  return _mm_cmpeq_epi16(a, b);
     43 }
     44 
     45 template <typename TValue>
     46 __m256i CmpEq256(__m256i a, __m256i b) {
     47  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2 ||
     48                sizeof(TValue) == 4 || sizeof(TValue) == 8);
     49  if (sizeof(TValue) == 1) {
     50    return _mm256_cmpeq_epi8(a, b);
     51  }
     52  if (sizeof(TValue) == 2) {
     53    return _mm256_cmpeq_epi16(a, b);
     54  }
     55  if (sizeof(TValue) == 4) {
     56    return _mm256_cmpeq_epi32(a, b);
     57  }
     58 
     59  return _mm256_cmpeq_epi64(a, b);
     60 }
     61 
     62 #  if defined(__GNUC__) && !defined(__clang__)
     63 
     64 // See the comment in SIMD.cpp over Load32BitsIntoXMM. This is just adapted
     65 // from that workaround. Testing this, it also yields the correct instructions
     66 // across all tested compilers.
     67 __m128i Load64BitsIntoXMM(uintptr_t ptr) {
     68  int64_t tmp;
     69  memcpy(&tmp, reinterpret_cast<const void*>(ptr), sizeof(tmp));
     70  return _mm_cvtsi64_si128(tmp);
     71 }
     72 
     73 #  else
     74 
     75 __m128i Load64BitsIntoXMM(uintptr_t ptr) {
     76  return _mm_loadu_si64(reinterpret_cast<const __m128i*>(ptr));
     77 }
     78 
     79 #  endif
     80 
     81 template <typename TValue>
     82 const TValue* Check4x8Bytes(__m128i needle, uintptr_t a, uintptr_t b,
     83                            uintptr_t c, uintptr_t d) {
     84  __m128i haystackA = Load64BitsIntoXMM(a);
     85  __m128i cmpA = CmpEq128<TValue>(needle, haystackA);
     86  __m128i haystackB = Load64BitsIntoXMM(b);
     87  __m128i cmpB = CmpEq128<TValue>(needle, haystackB);
     88  __m128i haystackC = Load64BitsIntoXMM(c);
     89  __m128i cmpC = CmpEq128<TValue>(needle, haystackC);
     90  __m128i haystackD = Load64BitsIntoXMM(d);
     91  __m128i cmpD = CmpEq128<TValue>(needle, haystackD);
     92  __m128i or_ab = _mm_or_si128(cmpA, cmpB);
     93  __m128i or_cd = _mm_or_si128(cmpC, cmpD);
     94  __m128i or_abcd = _mm_or_si128(or_ab, or_cd);
     95  int orMask = _mm_movemask_epi8(or_abcd);
     96  if (orMask & 0xff) {
     97    int cmpMask;
     98    cmpMask = _mm_movemask_epi8(cmpA);
     99    if (cmpMask & 0xff) {
    100      return reinterpret_cast<const TValue*>(a + __builtin_ctz(cmpMask));
    101    }
    102    cmpMask = _mm_movemask_epi8(cmpB);
    103    if (cmpMask & 0xff) {
    104      return reinterpret_cast<const TValue*>(b + __builtin_ctz(cmpMask));
    105    }
    106    cmpMask = _mm_movemask_epi8(cmpC);
    107    if (cmpMask & 0xff) {
    108      return reinterpret_cast<const TValue*>(c + __builtin_ctz(cmpMask));
    109    }
    110    cmpMask = _mm_movemask_epi8(cmpD);
    111    if (cmpMask & 0xff) {
    112      return reinterpret_cast<const TValue*>(d + __builtin_ctz(cmpMask));
    113    }
    114  }
    115 
    116  return nullptr;
    117 }
    118 
    119 template <typename TValue>
    120 const TValue* Check4x32Bytes(__m256i needle, uintptr_t a, uintptr_t b,
    121                             uintptr_t c, uintptr_t d) {
    122  __m256i haystackA = _mm256_loadu_si256(Cast256(a));
    123  __m256i cmpA = CmpEq256<TValue>(needle, haystackA);
    124  __m256i haystackB = _mm256_loadu_si256(Cast256(b));
    125  __m256i cmpB = CmpEq256<TValue>(needle, haystackB);
    126  __m256i haystackC = _mm256_loadu_si256(Cast256(c));
    127  __m256i cmpC = CmpEq256<TValue>(needle, haystackC);
    128  __m256i haystackD = _mm256_loadu_si256(Cast256(d));
    129  __m256i cmpD = CmpEq256<TValue>(needle, haystackD);
    130  __m256i or_ab = _mm256_or_si256(cmpA, cmpB);
    131  __m256i or_cd = _mm256_or_si256(cmpC, cmpD);
    132  __m256i or_abcd = _mm256_or_si256(or_ab, or_cd);
    133  int orMask = _mm256_movemask_epi8(or_abcd);
    134  if (orMask) {
    135    int cmpMask;
    136    cmpMask = _mm256_movemask_epi8(cmpA);
    137    if (cmpMask) {
    138      return reinterpret_cast<const TValue*>(a + __builtin_ctz(cmpMask));
    139    }
    140    cmpMask = _mm256_movemask_epi8(cmpB);
    141    if (cmpMask) {
    142      return reinterpret_cast<const TValue*>(b + __builtin_ctz(cmpMask));
    143    }
    144    cmpMask = _mm256_movemask_epi8(cmpC);
    145    if (cmpMask) {
    146      return reinterpret_cast<const TValue*>(c + __builtin_ctz(cmpMask));
    147    }
    148    cmpMask = _mm256_movemask_epi8(cmpD);
    149    if (cmpMask) {
    150      return reinterpret_cast<const TValue*>(d + __builtin_ctz(cmpMask));
    151    }
    152  }
    153 
    154  return nullptr;
    155 }
    156 
    157 template <typename TValue>
    158 const TValue* FindInBufferAVX2(const TValue* ptr, TValue value, size_t length) {
    159  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2 ||
    160                sizeof(TValue) == 4 || sizeof(TValue) == 8);
    161  static_assert(std::is_unsigned<TValue>::value);
    162 
    163  // Load our needle into a 32-byte register
    164  __m256i needle;
    165  if (sizeof(TValue) == 1) {
    166    needle = _mm256_set1_epi8(value);
    167  } else if (sizeof(TValue) == 2) {
    168    needle = _mm256_set1_epi16(value);
    169  } else if (sizeof(TValue) == 4) {
    170    needle = _mm256_set1_epi32(value);
    171  } else {
    172    needle = _mm256_set1_epi64x(value);
    173  }
    174 
    175  size_t numBytes = length * sizeof(TValue);
    176  uintptr_t cur = reinterpret_cast<uintptr_t>(ptr);
    177  uintptr_t end = cur + numBytes;
    178 
    179  if (numBytes < 8 || (sizeof(TValue) >= 4 && numBytes < 32)) {
    180    while (cur < end) {
    181      if (GetAs<TValue>(cur) == value) {
    182        return reinterpret_cast<const TValue*>(cur);
    183      }
    184      cur += sizeof(TValue);
    185    }
    186    return nullptr;
    187  }
    188 
    189  if constexpr (sizeof(TValue) < 4) {
    190    if (numBytes < 32) {
    191      __m128i needle_narrow;
    192      if (sizeof(TValue) == 1) {
    193        needle_narrow = _mm_set1_epi8(value);
    194      } else {
    195        needle_narrow = _mm_set1_epi16(value);
    196      }
    197      uintptr_t a = cur;
    198      uintptr_t b = cur + ((numBytes & 16) >> 1);
    199      uintptr_t c = end - 8 - ((numBytes & 16) >> 1);
    200      uintptr_t d = end - 8;
    201      return Check4x8Bytes<TValue>(needle_narrow, a, b, c, d);
    202    }
    203  }
    204 
    205  if (numBytes < 128) {
    206    // NOTE: here and below, we have some bit fiddling which could look a
    207    // little weird. The important thing to note though is it's just a trick
    208    // for getting the number 32 if numBytes is greater than or equal to 64,
    209    // and 0 otherwise. This lets us fully cover the range without any
    210    // branching for the case where numBytes is in [32,64), and [64,128). We get
    211    // four ranges from this - if numbytes > 64, we get:
    212    //   [0,32), [32,64], [end - 64), [end - 32)
    213    // and if numbytes < 64, we get
    214    //   [0,32), [0,32), [end - 32), [end - 32)
    215    uintptr_t a = cur;
    216    uintptr_t b = cur + ((numBytes & 64) >> 1);
    217    uintptr_t c = end - 32 - ((numBytes & 64) >> 1);
    218    uintptr_t d = end - 32;
    219    return Check4x32Bytes<TValue>(needle, a, b, c, d);
    220  }
    221 
    222  // Get the initial unaligned load out of the way. This will overlap with the
    223  // aligned stuff below, but the overlapped part should effectively be free
    224  // (relative to a mispredict from doing a byte-by-byte loop).
    225  __m256i haystack = _mm256_loadu_si256(Cast256(cur));
    226  __m256i cmp = CmpEq256<TValue>(needle, haystack);
    227  int cmpMask = _mm256_movemask_epi8(cmp);
    228  if (cmpMask) {
    229    return reinterpret_cast<const TValue*>(cur + __builtin_ctz(cmpMask));
    230  }
    231 
    232  // Now we're working with aligned memory. Hooray! \o/
    233  cur = AlignUp32(cur);
    234 
    235  uintptr_t tailStartPtr = AlignDown32(end - 96);
    236  uintptr_t tailEndPtr = end - 32;
    237 
    238  while (cur < tailStartPtr) {
    239    uintptr_t a = cur;
    240    uintptr_t b = cur + 32;
    241    uintptr_t c = cur + 64;
    242    uintptr_t d = cur + 96;
    243    const TValue* result = Check4x32Bytes<TValue>(needle, a, b, c, d);
    244    if (result) {
    245      return result;
    246    }
    247    cur += 128;
    248  }
    249 
    250  uintptr_t a = tailStartPtr;
    251  uintptr_t b = tailStartPtr + 32;
    252  uintptr_t c = tailStartPtr + 64;
    253  uintptr_t d = tailEndPtr;
    254  return Check4x32Bytes<TValue>(needle, a, b, c, d);
    255 }
    256 
    257 const char* SIMD::memchr8AVX2(const char* ptr, char value, size_t length) {
    258  const unsigned char* uptr = reinterpret_cast<const unsigned char*>(ptr);
    259  unsigned char uvalue = static_cast<unsigned char>(value);
    260  const unsigned char* uresult =
    261      FindInBufferAVX2<unsigned char>(uptr, uvalue, length);
    262  return reinterpret_cast<const char*>(uresult);
    263 }
    264 
    265 const char16_t* SIMD::memchr16AVX2(const char16_t* ptr, char16_t value,
    266                                   size_t length) {
    267  return FindInBufferAVX2<char16_t>(ptr, value, length);
    268 }
    269 
    270 const uint32_t* SIMD::memchr32AVX2(const uint32_t* ptr, uint32_t value,
    271                                   size_t length) {
    272  return FindInBufferAVX2<uint32_t>(ptr, value, length);
    273 }
    274 
    275 const uint64_t* SIMD::memchr64AVX2(const uint64_t* ptr, uint64_t value,
    276                                   size_t length) {
    277  return FindInBufferAVX2<uint64_t>(ptr, value, length);
    278 }
    279 
    280 }  // namespace mozilla
    281 
    282 #else
    283 
    284 namespace mozilla {
    285 
    286 const char* SIMD::memchr8AVX2(const char* ptr, char value, size_t length) {
    287  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
    288 }
    289 
    290 const char16_t* SIMD::memchr16AVX2(const char16_t* ptr, char16_t value,
    291                                   size_t length) {
    292  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
    293 }
    294 
    295 const uint32_t* SIMD::memchr32AVX2(const uint32_t* ptr, uint32_t value,
    296                                   size_t length) {
    297  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
    298 }
    299 
    300 const uint64_t* SIMD::memchr64AVX2(const uint64_t* ptr, uint64_t value,
    301                                   size_t length) {
    302  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
    303 }
    304 
    305 }  // namespace mozilla
    306 
    307 #endif