SIMD_avx2.cpp (10081B)
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "mozilla/SIMD.h"

#include "mozilla/SSE.h"
#include "mozilla/Assertions.h"

// Restricting to x86_64 simplifies things, and we're not particularly
// worried about slightly degraded performance on 32 bit processors which
// support AVX2, as this should be quite a minority.
#if defined(MOZILLA_MAY_SUPPORT_AVX2) && defined(__x86_64__)

# include <cstring>
# include <immintrin.h>
# include <stdint.h>
# include <type_traits>

namespace mozilla {

// Reinterpret an integer address as a pointer usable with the 256-bit
// (unaligned) vector load intrinsics below.
const __m256i* Cast256(uintptr_t ptr) {
  return reinterpret_cast<const __m256i*>(ptr);
}

// Read a value of type T from the given address.
template <typename T>
T GetAs(uintptr_t ptr) {
  return *reinterpret_cast<const T*>(ptr);
}

// Round an address down to the previous 32-byte boundary.
uintptr_t AlignDown32(uintptr_t ptr) { return ptr & ~0x1f; }

// Round an address up to the next 32-byte boundary.
uintptr_t AlignUp32(uintptr_t ptr) { return AlignDown32(ptr + 0x1f); }

// Lane-wise equality compare of two 128-bit vectors, with the lane width
// matching sizeof(TValue) (1 or 2 bytes).
template <typename TValue>
__m128i CmpEq128(__m128i a, __m128i b) {
  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2);
  if (sizeof(TValue) == 1) {
    return _mm_cmpeq_epi8(a, b);
  }
  return _mm_cmpeq_epi16(a, b);
}

// Lane-wise equality compare of two 256-bit vectors, with the lane width
// matching sizeof(TValue) (1, 2, 4 or 8 bytes).
template <typename TValue>
__m256i CmpEq256(__m256i a, __m256i b) {
  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2 ||
                sizeof(TValue) == 4 || sizeof(TValue) == 8);
  if (sizeof(TValue) == 1) {
    return _mm256_cmpeq_epi8(a, b);
  }
  if (sizeof(TValue) == 2) {
    return _mm256_cmpeq_epi16(a, b);
  }
  if (sizeof(TValue) == 4) {
    return _mm256_cmpeq_epi32(a, b);
  }

  return _mm256_cmpeq_epi64(a, b);
}

# if defined(__GNUC__) && !defined(__clang__)

// Load 8 bytes from (possibly unaligned) memory into the low half of an XMM
// register, zeroing the upper 8 bytes.
//
// See the comment in SIMD.cpp over Load32BitsIntoXMM. This is just adapted
// from that workaround. Testing this, it also yields the correct instructions
// across all tested compilers.
__m128i Load64BitsIntoXMM(uintptr_t ptr) {
  int64_t tmp;
  memcpy(&tmp, reinterpret_cast<const void*>(ptr), sizeof(tmp));
  return _mm_cvtsi64_si128(tmp);
}

# else

// Load 8 bytes from (possibly unaligned) memory into the low half of an XMM
// register, zeroing the upper 8 bytes.
__m128i Load64BitsIntoXMM(uintptr_t ptr) {
  return _mm_loadu_si64(reinterpret_cast<const __m128i*>(ptr));
}

# endif

// Check four (possibly overlapping) 8-byte windows at addresses a, b, c and d
// for the needle, returning a pointer to the first match in a, then b, then c,
// then d, or nullptr if none match. Each haystack load only fills the low 8
// bytes of its XMM register (the upper 8 bytes are zeroed), which is why every
// movemask result is narrowed with `& 0xff`: a needle of value 0 would
// otherwise spuriously "match" the zeroed upper lanes.
template <typename TValue>
const TValue* Check4x8Bytes(__m128i needle, uintptr_t a, uintptr_t b,
                            uintptr_t c, uintptr_t d) {
  __m128i haystackA = Load64BitsIntoXMM(a);
  __m128i cmpA = CmpEq128<TValue>(needle, haystackA);
  __m128i haystackB = Load64BitsIntoXMM(b);
  __m128i cmpB = CmpEq128<TValue>(needle, haystackB);
  __m128i haystackC = Load64BitsIntoXMM(c);
  __m128i cmpC = CmpEq128<TValue>(needle, haystackC);
  __m128i haystackD = Load64BitsIntoXMM(d);
  __m128i cmpD = CmpEq128<TValue>(needle, haystackD);
  // OR the four compare results together so the common no-match case costs
  // only one movemask and one branch.
  __m128i or_ab = _mm_or_si128(cmpA, cmpB);
  __m128i or_cd = _mm_or_si128(cmpC, cmpD);
  __m128i or_abcd = _mm_or_si128(or_ab, or_cd);
  int orMask = _mm_movemask_epi8(or_abcd);
  if (orMask & 0xff) {
    int cmpMask;
    cmpMask = _mm_movemask_epi8(cmpA);
    if (cmpMask & 0xff) {
      // __builtin_ctz gives the byte offset of the first matching lane.
      return reinterpret_cast<const TValue*>(a + __builtin_ctz(cmpMask));
    }
    cmpMask = _mm_movemask_epi8(cmpB);
    if (cmpMask & 0xff) {
      return reinterpret_cast<const TValue*>(b + __builtin_ctz(cmpMask));
    }
    cmpMask = _mm_movemask_epi8(cmpC);
    if (cmpMask & 0xff) {
      return reinterpret_cast<const TValue*>(c + __builtin_ctz(cmpMask));
    }
    cmpMask = _mm_movemask_epi8(cmpD);
    if (cmpMask & 0xff) {
      return reinterpret_cast<const TValue*>(d + __builtin_ctz(cmpMask));
    }
  }

  return nullptr;
}

// Check four (possibly overlapping) 32-byte windows at addresses a, b, c and d
// for the needle, returning a pointer to the first match in a, then b, then c,
// then d, or nullptr if none match.
template <typename TValue>
const TValue* Check4x32Bytes(__m256i needle, uintptr_t a, uintptr_t b,
                             uintptr_t c, uintptr_t d) {
  __m256i haystackA = _mm256_loadu_si256(Cast256(a));
  __m256i cmpA = CmpEq256<TValue>(needle, haystackA);
  __m256i haystackB = _mm256_loadu_si256(Cast256(b));
  __m256i cmpB = CmpEq256<TValue>(needle, haystackB);
  __m256i haystackC = _mm256_loadu_si256(Cast256(c));
  __m256i cmpC = CmpEq256<TValue>(needle, haystackC);
  __m256i haystackD = _mm256_loadu_si256(Cast256(d));
  __m256i cmpD = CmpEq256<TValue>(needle, haystackD);
  // OR the four compare results together so the common no-match case costs
  // only one movemask and one branch.
  __m256i or_ab = _mm256_or_si256(cmpA, cmpB);
  __m256i or_cd = _mm256_or_si256(cmpC, cmpD);
  __m256i or_abcd = _mm256_or_si256(or_ab, or_cd);
  int orMask = _mm256_movemask_epi8(or_abcd);
  if (orMask) {
    int cmpMask;
    cmpMask = _mm256_movemask_epi8(cmpA);
    if (cmpMask) {
      // __builtin_ctz gives the byte offset of the first matching lane.
      return reinterpret_cast<const TValue*>(a + __builtin_ctz(cmpMask));
    }
    cmpMask = _mm256_movemask_epi8(cmpB);
    if (cmpMask) {
      return reinterpret_cast<const TValue*>(b + __builtin_ctz(cmpMask));
    }
    cmpMask = _mm256_movemask_epi8(cmpC);
    if (cmpMask) {
      return reinterpret_cast<const TValue*>(c + __builtin_ctz(cmpMask));
    }
    cmpMask = _mm256_movemask_epi8(cmpD);
    if (cmpMask) {
      return reinterpret_cast<const TValue*>(d + __builtin_ctz(cmpMask));
    }
  }

  return nullptr;
}

// memchr-style search: return a pointer to the first element of
// ptr[0..length) equal to value, or nullptr if there is none. Strategy:
// buffers too small for a single vector window use a scalar loop; small
// buffers are covered by four (possibly overlapping) vector windows; large
// buffers do one unaligned head load, then an aligned main loop of 128 bytes
// per iteration, then an overlapping tail.
template <typename TValue>
const TValue* FindInBufferAVX2(const TValue* ptr, TValue value, size_t length) {
  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2 ||
                sizeof(TValue) == 4 || sizeof(TValue) == 8);
  static_assert(std::is_unsigned<TValue>::value);

  // Load our needle into a 32-byte register
  __m256i needle;
  if (sizeof(TValue) == 1) {
    needle = _mm256_set1_epi8(value);
  } else if (sizeof(TValue) == 2) {
    needle = _mm256_set1_epi16(value);
  } else if (sizeof(TValue) == 4) {
    needle = _mm256_set1_epi32(value);
  } else {
    needle = _mm256_set1_epi64x(value);
  }

  size_t numBytes = length * sizeof(TValue);
  uintptr_t cur = reinterpret_cast<uintptr_t>(ptr);
  uintptr_t end = cur + numBytes;

  // Scalar fallback for buffers too small for even one 8-byte window (or,
  // for 4/8-byte elements, one 32-byte window).
  if (numBytes < 8 || (sizeof(TValue) >= 4 && numBytes < 32)) {
    while (cur < end) {
      if (GetAs<TValue>(cur) == value) {
        return reinterpret_cast<const TValue*>(cur);
      }
      cur += sizeof(TValue);
    }
    return nullptr;
  }

  if constexpr (sizeof(TValue) < 4) {
    if (numBytes < 32) {
      __m128i needle_narrow;
      if (sizeof(TValue) == 1) {
        needle_narrow = _mm_set1_epi8(value);
      } else {
        needle_narrow = _mm_set1_epi16(value);
      }
      // Same overlapping-window trick as the 32-byte case below, with 8-byte
      // windows: `(numBytes & 16) >> 1` is 8 if numBytes >= 16 and 0
      // otherwise, so a-d branchlessly cover any length in [8,32).
      uintptr_t a = cur;
      uintptr_t b = cur + ((numBytes & 16) >> 1);
      uintptr_t c = end - 8 - ((numBytes & 16) >> 1);
      uintptr_t d = end - 8;
      return Check4x8Bytes<TValue>(needle_narrow, a, b, c, d);
    }
  }

  if (numBytes < 128) {
    // NOTE: here and below, we have some bit fiddling which could look a
    // little weird. The important thing to note though is it's just a trick
    // for getting the number 32 if numBytes is greater than or equal to 64,
    // and 0 otherwise. This lets us fully cover the range without any
    // branching for the case where numBytes is in [32,64), and [64,128). We
    // get four 32-byte windows from this - if numBytes >= 64, they start at:
    //   0, 32, end - 64, end - 32
    // and if numBytes < 64, they start at:
    //   0, 0, end - 32, end - 32
    uintptr_t a = cur;
    uintptr_t b = cur + ((numBytes & 64) >> 1);
    uintptr_t c = end - 32 - ((numBytes & 64) >> 1);
    uintptr_t d = end - 32;
    return Check4x32Bytes<TValue>(needle, a, b, c, d);
  }

  // Get the initial unaligned load out of the way. This will overlap with the
  // aligned stuff below, but the overlapped part should effectively be free
  // (relative to a mispredict from doing a byte-by-byte loop).
  __m256i haystack = _mm256_loadu_si256(Cast256(cur));
  __m256i cmp = CmpEq256<TValue>(needle, haystack);
  int cmpMask = _mm256_movemask_epi8(cmp);
  if (cmpMask) {
    return reinterpret_cast<const TValue*>(cur + __builtin_ctz(cmpMask));
  }

  // Now we're working with aligned memory. Hooray! \o/
  cur = AlignUp32(cur);

  // Tail windows: three consecutive 32-byte windows starting at the aligned
  // tailStartPtr, plus one final window ending exactly at `end`. Together with
  // the main loop below these cover every byte of the buffer (numBytes >= 128
  // here, so end - 96 cannot underflow below `cur`).
  uintptr_t tailStartPtr = AlignDown32(end - 96);
  uintptr_t tailEndPtr = end - 32;

  // Main loop: four aligned 32-byte windows (128 bytes) per iteration.
  while (cur < tailStartPtr) {
    uintptr_t a = cur;
    uintptr_t b = cur + 32;
    uintptr_t c = cur + 64;
    uintptr_t d = cur + 96;
    const TValue* result = Check4x32Bytes<TValue>(needle, a, b, c, d);
    if (result) {
      return result;
    }
    cur += 128;
  }

  uintptr_t a = tailStartPtr;
  uintptr_t b = tailStartPtr + 32;
  uintptr_t c = tailStartPtr + 64;
  uintptr_t d = tailEndPtr;
  return Check4x32Bytes<TValue>(needle, a, b, c, d);
}

// AVX2 memchr over 8-bit elements. The search runs on unsigned char; the
// char round trip below does not change which bytes compare equal.
const char* SIMD::memchr8AVX2(const char* ptr, char value, size_t length) {
  const unsigned char* uptr = reinterpret_cast<const unsigned char*>(ptr);
  unsigned char uvalue = static_cast<unsigned char>(value);
  const unsigned char* uresult =
      FindInBufferAVX2<unsigned char>(uptr, uvalue, length);
  return reinterpret_cast<const char*>(uresult);
}

// AVX2 memchr over 16-bit elements.
const char16_t* SIMD::memchr16AVX2(const char16_t* ptr, char16_t value,
                                   size_t length) {
  return FindInBufferAVX2<char16_t>(ptr, value, length);
}

// AVX2 memchr over 32-bit elements.
const uint32_t* SIMD::memchr32AVX2(const uint32_t* ptr, uint32_t value,
                                   size_t length) {
  return FindInBufferAVX2<uint32_t>(ptr, value, length);
}

// AVX2 memchr over 64-bit elements.
const uint64_t* SIMD::memchr64AVX2(const uint64_t* ptr, uint64_t value,
                                   size_t length) {
  return FindInBufferAVX2<uint64_t>(ptr, value, length);
}

}  // namespace mozilla

#else

namespace mozilla {

// Stubs for builds without AVX2 compiled in. Callers are expected to check
// for AVX2 support and never reach these; each one crashes deliberately.

const char* SIMD::memchr8AVX2(const char* ptr, char value, size_t length) {
  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
}

const char16_t* SIMD::memchr16AVX2(const char16_t* ptr, char16_t value,
                                   size_t length) {
  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
}

const uint32_t* SIMD::memchr32AVX2(const uint32_t* ptr, uint32_t value,
                                   size_t length) {
  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
}

const uint64_t* SIMD::memchr64AVX2(const uint64_t* ptr, uint64_t value,
                                   size_t length) {
  MOZ_RELEASE_ASSERT(false, "AVX2 not supported in this binary.");
}

}  // namespace mozilla

#endif