x86_256-inl.h (346425B)
1 // Copyright 2019 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when 17 // compiling for that target. 18 // External include guard in highway.h - see comment there. 19 20 // WARNING: most operations do not cross 128-bit block boundaries. In 21 // particular, "Broadcast", pack and zip behavior may be surprising. 22 23 // Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL 24 #include "hwy/base.h" 25 26 // Avoid uninitialized warnings in GCC's avx512fintrin.h - see 27 // https://github.com/google/highway/issues/710) 28 HWY_DIAGNOSTICS(push) 29 #if HWY_COMPILER_GCC_ACTUAL 30 HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") 31 HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, 32 ignored "-Wmaybe-uninitialized") 33 #endif 34 35 // Must come before HWY_COMPILER_CLANGCL 36 #include <immintrin.h> // AVX2+ 37 38 #if HWY_COMPILER_CLANGCL 39 // Including <immintrin.h> should be enough, but Clang's headers helpfully skip 40 // including these headers when _MSC_VER is defined, like when using clang-cl. 41 // Include these directly here. 42 #include <avxintrin.h> 43 // avxintrin defines __m256i and must come before avx2intrin. 
44 #include <avx2intrin.h> 45 #include <bmi2intrin.h> // _pext_u64 46 #include <f16cintrin.h> 47 #include <fmaintrin.h> 48 #include <smmintrin.h> 49 50 #if HWY_TARGET <= HWY_AVX10_2 51 #include <avx512bitalgintrin.h> 52 #include <avx512bwintrin.h> 53 #include <avx512cdintrin.h> 54 #include <avx512dqintrin.h> 55 #include <avx512fintrin.h> 56 #include <avx512vbmi2intrin.h> 57 #include <avx512vbmiintrin.h> 58 #include <avx512vbmivlintrin.h> 59 #include <avx512vlbitalgintrin.h> 60 #include <avx512vlbwintrin.h> 61 #include <avx512vlcdintrin.h> 62 #include <avx512vldqintrin.h> 63 #include <avx512vlintrin.h> 64 #include <avx512vlvbmi2intrin.h> 65 #include <avx512vlvnniintrin.h> 66 #include <avx512vnniintrin.h> 67 #include <avx512vpopcntdqintrin.h> 68 #include <avx512vpopcntdqvlintrin.h> 69 // Must come after avx512fintrin, else will not define 512-bit intrinsics. 70 #include <avx512fp16intrin.h> 71 #include <avx512vlfp16intrin.h> 72 #include <gfniintrin.h> 73 #include <vaesintrin.h> 74 #include <vpclmulqdqintrin.h> 75 76 #endif // HWY_TARGET <= HWY_AVX10_2 77 78 // clang-format on 79 #endif // HWY_COMPILER_CLANGCL 80 81 // For half-width vectors. Already includes base.h. 82 #include "hwy/ops/shared-inl.h" 83 // Already included by shared-inl, but do it again to avoid IDE warnings. 
#include "hwy/ops/x86_128-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {

// Maps a lane type T to the native 256-bit register type that stores it.
// All integer lane types (and, below, f16/bf16 without native support) share
// __m256i.
template <typename T>
struct Raw256 {
  using type = __m256i;
};
#if HWY_HAVE_FLOAT16
template <>
struct Raw256<float16_t> {
  using type = __m256h;
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct Raw256<float> {
  using type = __m256;
};
template <>
struct Raw256<double> {
  using type = __m256d;
};

}  // namespace detail

// 256-bit vector of lanes of type T: a thin wrapper over the native register.
// The lane type only selects overloads; storage is the raw register.
template <typename T>
class Vec256 {
  using Raw = typename detail::Raw256<T>::type;

 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 32 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec256& operator*=(const Vec256 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec256& operator/=(const Vec256 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec256& operator+=(const Vec256 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec256& operator-=(const Vec256 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec256& operator%=(const Vec256 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec256& operator&=(const Vec256 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec256& operator|=(const Vec256 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec256& operator^=(const Vec256 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

namespace detail {

#if HWY_TARGET <= HWY_AVX3

// Template arg: sizeof(lane type). Selects the AVX-512 mask register type
// with one bit per lane, e.g. 32 one-byte lanes -> __mmask32.
template <size_t size>
struct RawMask256T {};
template <>
struct RawMask256T<1> {
  using type = __mmask32;
};
template <>
struct RawMask256T<2> {
  using type = __mmask16;
};
template <>
struct RawMask256T<4> {
  using type = __mmask8;
};
template <>
struct RawMask256T<8> {
  using type = __mmask8;
};

template <typename T>
using RawMask256 = typename RawMask256T<sizeof(T)>::type;

#else  // AVX2 or earlier

// No AVX-512 mask registers: a mask is stored as a full vector register.
template <typename T>
using RawMask256 = typename Raw256<T>::type;

#endif  // HWY_TARGET <= HWY_AVX3

}  // namespace detail

// Mask for Vec256<T>: a compressed per-lane bitmask on AVX3, otherwise a full
// vector whose lanes are all-ones or all-zeros.
template <typename T>
struct Mask256 {
  using Raw = typename detail::RawMask256<T>;

  using PrivateT = T;                                  // only for DFromM
  static constexpr size_t kPrivateN = 32 / sizeof(T);  // only for DFromM

#if HWY_TARGET <= HWY_AVX3
  // Constructs a mask directly from its bit representation (bit i = lane i).
  static Mask256<T> FromBits(uint64_t mask_bits) {
    return Mask256<T>{static_cast<Raw>(mask_bits)};
  }
#else
  // Lanes are either FF..FF or 0.
#endif  // HWY_TARGET <= HWY_AVX3

  Raw raw;
};

// Descriptor (tag) type for a full 256-bit vector of T.
template <typename T>
using Full256 = Simd<T, 32 / sizeof(T), 0>;

// ------------------------------ Zero

// Cannot use VFromD here because it is defined in terms of Zero.
// Returns an all-zero vector; overloads select the setzero intrinsic matching
// the register type of TFromD<D>.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec256<TFromD<D>> Zero(D /* tag */) {
  return Vec256<TFromD<D>>{_mm256_setzero_si256()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
HWY_API Vec256<bfloat16_t> Zero(D /* tag */) {
  return Vec256<bfloat16_t>{_mm256_setzero_si256()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Zero(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec256<float16_t>{_mm256_setzero_ph()};
#else
  // Without native f16 support, float16_t is stored in an integer register.
  return Vec256<float16_t>{_mm256_setzero_si256()};
#endif  // HWY_HAVE_FLOAT16
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Zero(D /* tag */) {
  return Vec256<float>{_mm256_setzero_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Zero(D /* tag */) {
  return Vec256<double>{_mm256_setzero_pd()};
}

// ------------------------------ BitCast

namespace detail {

// Casts any 256-bit register to __m256i; these are free (no instructions).
HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; }
#if HWY_HAVE_FLOAT16
HWY_INLINE __m256i BitCastToInteger(__m256h v) {
  return _mm256_castph_si256(v);
}
#endif  // HWY_HAVE_FLOAT16
HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); }
HWY_INLINE __m256i BitCastToInteger(__m256d v) {
  return _mm256_castpd_si256(v);
}

#if HWY_AVX3_HAVE_F32_TO_BF16C
HWY_INLINE __m256i BitCastToInteger(__m256bh v) {
  // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to
  // bit cast a __m256bh to a __m256i as there is currently no intrinsic
  // available (as of GCC 13 and Clang 17) that can bit cast a __m256bh vector
  // to a __m256i vector

#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG
  // On GCC or Clang, use reinterpret_cast to bit cast a __m256bh to a __m256i
  return reinterpret_cast<__m256i>(v);
#else
  // On MSVC, use BitCastScalar to bit cast a __m256bh to a __m256i as MSVC does
  // not allow reinterpret_cast, static_cast, or a C-style cast to be used to
  // bit cast from one AVX vector type to a different AVX vector type
  return BitCastScalar<__m256i>(v);
#endif  // HWY_COMPILER_GCC || HWY_COMPILER_CLANG
}
#endif  // HWY_AVX3_HAVE_F32_TO_BF16C

// Reinterprets any Vec256<T> as bytes (u8 lanes).
template <typename T>
HWY_INLINE Vec256<uint8_t> BitCastToByte(Vec256<T> v) {
  return Vec256<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template <typename T>
struct BitCastFromInteger256 {
  HWY_INLINE __m256i operator()(__m256i v) { return v; }
};
#if HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger256<float16_t> {
  HWY_INLINE __m256h operator()(__m256i v) { return _mm256_castsi256_ph(v); }
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger256<float> {
  HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); }
};
template <>
struct BitCastFromInteger256<double> {
  HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); }
};

// Reinterprets a byte vector as TFromD<D> lanes.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec256<uint8_t> v) {
  return VFromD<D>{BitCastFromInteger256<TFromD<D>>()(v.raw)};
}

}  // namespace detail

// Reinterprets the bits of a vector of FromT as a vector of TFromD<D>;
// both must be 32 bytes. Implemented as a round trip through u8.
template <class D, HWY_IF_V_SIZE_D(D, 32), typename FromT>
HWY_API VFromD<D> BitCast(D d, Vec256<FromT> v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}

// ------------------------------ Set

// Broadcasts the scalar t to all lanes.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi8(static_cast<char>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi16(static_cast<short>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi32(static_cast<int>(t))};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi64x(static_cast<long long>(t))};  // NOLINT
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Set(D /* tag */, float16_t t) {
  return Vec256<float16_t>{_mm256_set1_ph(t)};
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Set(D /* tag */, float t) {
  return Vec256<float>{_mm256_set1_ps(t)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Set(D /* tag */, double t) {
  return Vec256<double>{_mm256_set1_pd(t)};
}

HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")

// Returns a vector with uninitialized elements.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> Undefined(D /* tag */) {
  // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC
  // generate an XOR instruction.
  return VFromD<D>{_mm256_undefined_si256()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
HWY_API Vec256<bfloat16_t> Undefined(D /* tag */) {
  return Vec256<bfloat16_t>{_mm256_undefined_si256()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Undefined(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec256<float16_t>{_mm256_undefined_ph()};
#else
  return Vec256<float16_t>{_mm256_undefined_si256()};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Undefined(D /* tag */) {
  return Vec256<float>{_mm256_undefined_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Undefined(D /* tag */) {
  return Vec256<double>{_mm256_undefined_pd()};
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ ResizeBitCast

// 32-byte vector to 32-byte vector (or 64-byte vector to 64-byte vector on
// AVX3): same size, so this is a plain BitCast.
template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, HWY_MAX_LANES_V(FromV) * sizeof(TFromV<FromV>))>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, v);
}

// 32-byte vector to 16-byte vector (or 64-byte vector to 32-byte vector on
// AVX3): truncate to the lower half, then BitCast.
template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
          HWY_IF_V_SIZE_D(D,
                          (HWY_MAX_LANES_V(FromV) * sizeof(TFromV<FromV>)) / 2)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const DFromV<decltype(v)> d_from;
  const Half<decltype(d_from)> dh_from;
  return BitCast(d, LowerHalf(dh_from, v));
}

// 32-byte vector (or 64-byte vector on AVX3) to <= 8-byte vector: first
// shrink to 128 bits, then reuse that register for the small vector.
template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
          HWY_IF_V_SIZE_LE_D(D, 8)>
HWY_API VFromD<D> ResizeBitCast(D /*d*/, FromV v) {
  return VFromD<D>{ResizeBitCast(Full128<TFromD<D>>(), v).raw};
}
// <= 16-byte vector to 32-byte vector: widen via _mm256_castsi128_si256
// (upper 16 bytes are unspecified).
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec256<uint8_t>{_mm256_castsi128_si256(
                        ResizeBitCast(Full128<uint8_t>(), v).raw)});
}

// ------------------------------ Dup128VecFromValues

// Builds the lower 128-bit block from the given lane values and duplicates it
// into the upper block (note the repeated t0..tN argument lists below).
template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
                                      TFromD<D> t11, TFromD<D> t12,
                                      TFromD<D> t13, TFromD<D> t14,
                                      TFromD<D> t15) {
  return VFromD<D>{_mm256_setr_epi8(
      static_cast<char>(t0), static_cast<char>(t1), static_cast<char>(t2),
      static_cast<char>(t3), static_cast<char>(t4), static_cast<char>(t5),
      static_cast<char>(t6), static_cast<char>(t7), static_cast<char>(t8),
      static_cast<char>(t9), static_cast<char>(t10), static_cast<char>(t11),
      static_cast<char>(t12), static_cast<char>(t13), static_cast<char>(t14),
      static_cast<char>(t15), static_cast<char>(t0), static_cast<char>(t1),
      static_cast<char>(t2), static_cast<char>(t3), static_cast<char>(t4),
      static_cast<char>(t5), static_cast<char>(t6), static_cast<char>(t7),
      static_cast<char>(t8), static_cast<char>(t9), static_cast<char>(t10),
      static_cast<char>(t11), static_cast<char>(t12), static_cast<char>(t13),
      static_cast<char>(t14), static_cast<char>(t15))};
}

template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{
      _mm256_setr_epi16(static_cast<int16_t>(t0), static_cast<int16_t>(t1),
                        static_cast<int16_t>(t2), static_cast<int16_t>(t3),
                        static_cast<int16_t>(t4), static_cast<int16_t>(t5),
                        static_cast<int16_t>(t6), static_cast<int16_t>(t7),
                        static_cast<int16_t>(t0), static_cast<int16_t>(t1),
                        static_cast<int16_t>(t2), static_cast<int16_t>(t3),
                        static_cast<int16_t>(t4), static_cast<int16_t>(t5),
                        static_cast<int16_t>(t6), static_cast<int16_t>(t7))};
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{_mm256_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2,
                                  t3, t4, t5, t6, t7)};
}
#endif

template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{
      _mm256_setr_epi32(static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3))};
}

template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm256_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3)};
}

template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{
      _mm256_setr_epi64x(static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                         static_cast<int64_t>(t0), static_cast<int64_t>(t1))};
}

template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{_mm256_setr_pd(t0, t1, t0, t1)};
}
// ================================================== LOGICAL

// ------------------------------ And

// Bitwise AND. Integer (and f16) lanes route through the unsigned domain so a
// single si256 intrinsic covers all of them; f32/f64 use the float-domain
// intrinsics.
template <typename T>
HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_and_si256(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}

HWY_API Vec256<float> And(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_and_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> And(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_and_pd(a.raw, b.raw)};
}

// ------------------------------ AndNot

// Returns ~not_mask & mask.
template <typename T>
HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) {
  const DFromV<decltype(mask)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_andnot_si256(
                        BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
}
HWY_API Vec256<float> AndNot(Vec256<float> not_mask, Vec256<float> mask) {
  return Vec256<float>{_mm256_andnot_ps(not_mask.raw, mask.raw)};
}
HWY_API Vec256<double> AndNot(Vec256<double> not_mask, Vec256<double> mask) {
  return Vec256<double>{_mm256_andnot_pd(not_mask.raw, mask.raw)};
}

// ------------------------------ Or

// Bitwise OR; same domain handling as And.
template <typename T>
HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_or_si256(BitCast(du, a).raw,
                                                         BitCast(du, b).raw)});
}

HWY_API Vec256<float> Or(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_or_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Or(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_or_pd(a.raw, b.raw)};
}

// ------------------------------ Xor

// Bitwise XOR; same domain handling as And.
template <typename T>
HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_xor_si256(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}

HWY_API Vec256<float> Xor(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_xor_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Xor(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_xor_pd(a.raw, b.raw)};
}

// ------------------------------ Not
// Bitwise complement. On AVX3 this is a single ternary-logic instruction
// (truth table 0x55 = NOT of the first operand); otherwise XOR with all-ones.
// The ternarylogic path is skipped under MSAN, which lacks a handler for it.
template <typename T>
HWY_API Vec256<T> Not(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  using TU = MakeUnsigned<T>;
#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
  const __m256i vu = BitCast(RebindToUnsigned<decltype(d)>(), v).raw;
  return BitCast(d, Vec256<TU>{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)});
#else
  return Xor(v, BitCast(d, Vec256<TU>{_mm256_set1_epi32(-1)}));
#endif
}

// ------------------------------ Xor3
// x1 ^ x2 ^ x3 in one instruction on AVX3 (truth table 0x96).
template <typename T>
HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) {
#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
  const DFromV<decltype(x1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m256i ret = _mm256_ternarylogic_epi64(
      BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
  return BitCast(d, VU{ret});
#else
  return Xor(x1, Xor(x2, x3));
#endif
}

// ------------------------------ Or3
// o1 | o2 | o3 in one instruction on AVX3 (truth table 0xFE).
template <typename T>
HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
  const DFromV<decltype(o1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m256i ret = _mm256_ternarylogic_epi64(
      BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
  return BitCast(d, VU{ret});
#else
  return Or(o1, Or(o2, o3));
#endif
}

// ------------------------------ OrAnd
// o | (a1 & a2) in one instruction on AVX3 (truth table 0xF8).
template <typename T>
HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) {
#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
  const DFromV<decltype(o)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m256i ret = _mm256_ternarylogic_epi64(
      BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
  return BitCast(d, VU{ret});
#else
  return Or(o, And(a1, a2));
#endif
}

// ------------------------------ IfVecThenElse
// Bitwise select, equivalent to IfThenElse(MaskFromVec(mask), yes, no) as in
// the fallback below; on AVX3 a single ternary-logic op (truth table 0xCA).
template <typename T>
HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) {
#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
  const DFromV<decltype(yes)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw,
                                                 BitCast(du, yes).raw,
                                                 BitCast(du, no).raw, 0xCA)});
#else
  return IfThenElse(MaskFromVec(mask), yes, no);
#endif
}

// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) {
  return Xor(a, b);
}

// ------------------------------ PopulationCount

// 8/16 require BITALG, 32/64 require VPOPCNTDQ.
#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

namespace detail {

// Per-lane popcount, dispatched on lane size via SizeTag.
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi8(v.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi16(v.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi32(v.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi64(v.raw)};
}

}  // namespace detail

// Counts the set bits within each lane.
template <typename T>
HWY_API Vec256<T> PopulationCount(Vec256<T> v) {
  return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
}

#endif  // HWY_TARGET <= HWY_AVX3_DL

// ================================================== MASK

#if HWY_TARGET <= HWY_AVX3

// ------------------------------ IfThenElse

// Returns mask ? yes : no (note: the blend intrinsics take `no` first).

namespace detail {

// Templates for signed/unsigned integer of a particular size.
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<1> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi8(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<2> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi16(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<4> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi32(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<8> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi64(mask.raw, no.raw, yes.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) {
  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
}
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> IfThenElse(Mask256<float16_t> mask,
                                     Vec256<float16_t> yes,
                                     Vec256<float16_t> no) {
  return Vec256<float16_t>{_mm256_mask_blend_ph(mask.raw, no.raw, yes.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> IfThenElse(Mask256<float> mask, Vec256<float> yes,
                                 Vec256<float> no) {
  return Vec256<float>{_mm256_mask_blend_ps(mask.raw, no.raw, yes.raw)};
}
HWY_API Vec256<double> IfThenElse(Mask256<double> mask, Vec256<double> yes,
                                  Vec256<double> no) {
  return Vec256<double>{_mm256_mask_blend_pd(mask.raw, no.raw, yes.raw)};
}

namespace detail {

// mask ? yes : 0, via zero-masked moves, dispatched on lane size.
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi8(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi16(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi32(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi64(mask.raw, yes.raw)};
}

}  // namespace detail

// Returns mask ? yes : 0.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) {
  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
}
HWY_API Vec256<float> IfThenElseZero(Mask256<float> mask, Vec256<float> yes) {
  return Vec256<float>{_mm256_maskz_mov_ps(mask.raw, yes.raw)};
}
HWY_API Vec256<double> IfThenElseZero(Mask256<double> mask,
                                      Vec256<double> yes) {
  return Vec256<double>{_mm256_maskz_mov_pd(mask.raw, yes.raw)};
}

namespace detail {

// mask ? 0 : no, via a masked self-sub/xor (x - x == x ^ x == 0) that zeroes
// only the selected lanes.
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16.
  return Vec256<T>{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  return Vec256<T>{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  return Vec256<T>{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  return Vec256<T>{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
}

}  // namespace detail

// Returns mask ? 0 : no.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) {
  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
}
HWY_API Vec256<float> IfThenZeroElse(Mask256<float> mask, Vec256<float> no) {
  return Vec256<float>{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
}
HWY_API Vec256<double> IfThenZeroElse(Mask256<double> mask, Vec256<double> no) {
  return Vec256<double>{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
}

// ------------------------------ Mask logical

namespace detail {

// Logic on AVX-512 mask registers, dispatched on lane size (which determines
// the __mmask width). The _k*_mask* intrinsics are used where the compiler
// provides them; otherwise plain integer ops on the raw mask bits.
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw & b.raw)};
#endif
}

// Returns ~a & b on the mask bits.
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(~a.raw & b.raw)};
#endif
}

template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw | b.raw)};
#endif
}

template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask8(a.raw, b.raw)};
#else
  return
Mask256<T>{static_cast<__mmask8>(a.raw ^ b.raw)}; 955 #endif 956 } 957 template <typename T> 958 HWY_INLINE Mask256<T> Xor(hwy::SizeTag<8> /*tag*/, const Mask256<T> a, 959 const Mask256<T> b) { 960 #if HWY_COMPILER_HAS_MASK_INTRINSICS 961 return Mask256<T>{_kxor_mask8(a.raw, b.raw)}; 962 #else 963 return Mask256<T>{static_cast<__mmask8>(a.raw ^ b.raw)}; 964 #endif 965 } 966 967 template <typename T> 968 HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<1> /*tag*/, 969 const Mask256<T> a, const Mask256<T> b) { 970 #if HWY_COMPILER_HAS_MASK_INTRINSICS 971 return Mask256<T>{_kxnor_mask32(a.raw, b.raw)}; 972 #else 973 return Mask256<T>{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; 974 #endif 975 } 976 template <typename T> 977 HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<2> /*tag*/, 978 const Mask256<T> a, const Mask256<T> b) { 979 #if HWY_COMPILER_HAS_MASK_INTRINSICS 980 return Mask256<T>{_kxnor_mask16(a.raw, b.raw)}; 981 #else 982 return Mask256<T>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; 983 #endif 984 } 985 template <typename T> 986 HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<4> /*tag*/, 987 const Mask256<T> a, const Mask256<T> b) { 988 #if HWY_COMPILER_HAS_MASK_INTRINSICS 989 return Mask256<T>{_kxnor_mask8(a.raw, b.raw)}; 990 #else 991 return Mask256<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; 992 #endif 993 } 994 template <typename T> 995 HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<8> /*tag*/, 996 const Mask256<T> a, const Mask256<T> b) { 997 #if HWY_COMPILER_HAS_MASK_INTRINSICS 998 return Mask256<T>{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; 999 #else 1000 return Mask256<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; 1001 #endif 1002 } 1003 1004 // UnmaskedNot returns ~m.raw without zeroing out any invalid bits 1005 template <typename T, HWY_IF_T_SIZE(T, 1)> 1006 HWY_INLINE Mask256<T> UnmaskedNot(const Mask256<T> m) { 1007 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1008 return 
Mask256<T>{static_cast<__mmask32>(_knot_mask32(m.raw))}; 1009 #else 1010 return Mask256<T>{static_cast<__mmask32>(~m.raw)}; 1011 #endif 1012 } 1013 1014 template <typename T, HWY_IF_T_SIZE(T, 2)> 1015 HWY_INLINE Mask256<T> UnmaskedNot(const Mask256<T> m) { 1016 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1017 return Mask256<T>{static_cast<__mmask16>(_knot_mask16(m.raw))}; 1018 #else 1019 return Mask256<T>{static_cast<__mmask16>(~m.raw)}; 1020 #endif 1021 } 1022 1023 template <typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))> 1024 HWY_INLINE Mask256<T> UnmaskedNot(const Mask256<T> m) { 1025 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1026 return Mask256<T>{static_cast<__mmask8>(_knot_mask8(m.raw))}; 1027 #else 1028 return Mask256<T>{static_cast<__mmask8>(~m.raw)}; 1029 #endif 1030 } 1031 1032 template <typename T> 1033 HWY_INLINE Mask256<T> Not(hwy::SizeTag<1> /*tag*/, const Mask256<T> m) { 1034 // sizeof(T) == 1: simply return ~m as all 32 bits of m are valid 1035 return UnmaskedNot(m); 1036 } 1037 template <typename T> 1038 HWY_INLINE Mask256<T> Not(hwy::SizeTag<2> /*tag*/, const Mask256<T> m) { 1039 // sizeof(T) == 2: simply return ~m as all 16 bits of m are valid 1040 return UnmaskedNot(m); 1041 } 1042 template <typename T> 1043 HWY_INLINE Mask256<T> Not(hwy::SizeTag<4> /*tag*/, const Mask256<T> m) { 1044 // sizeof(T) == 4: simply return ~m as all 8 bits of m are valid 1045 return UnmaskedNot(m); 1046 } 1047 template <typename T> 1048 HWY_INLINE Mask256<T> Not(hwy::SizeTag<8> /*tag*/, const Mask256<T> m) { 1049 // sizeof(T) == 8: need to zero out the upper 4 bits of ~m as only the lower 1050 // 4 bits of m are valid 1051 1052 // Return (~m) & 0x0F 1053 return AndNot(hwy::SizeTag<8>(), m, Mask256<T>::FromBits(uint64_t{0x0F})); 1054 } 1055 1056 } // namespace detail 1057 1058 template <typename T> 1059 HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) { 1060 return detail::And(hwy::SizeTag<sizeof(T)>(), a, b); 1061 } 1062 1063 template <typename T> 1064 HWY_API 
Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) { 1065 return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b); 1066 } 1067 1068 template <typename T> 1069 HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) { 1070 return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b); 1071 } 1072 1073 template <typename T> 1074 HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) { 1075 return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b); 1076 } 1077 1078 template <typename T> 1079 HWY_API Mask256<T> Not(const Mask256<T> m) { 1080 // Flip only the valid bits. 1081 return detail::Not(hwy::SizeTag<sizeof(T)>(), m); 1082 } 1083 1084 template <typename T> 1085 HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) { 1086 return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b); 1087 } 1088 1089 template <class D, HWY_IF_LANES_D(D, 32)> 1090 HWY_API MFromD<D> CombineMasks(D /*d*/, MFromD<Half<D>> hi, 1091 MFromD<Half<D>> lo) { 1092 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1093 const __mmask32 combined_mask = _mm512_kunpackw( 1094 static_cast<__mmask32>(hi.raw), static_cast<__mmask32>(lo.raw)); 1095 #else 1096 const auto combined_mask = 1097 ((static_cast<uint32_t>(hi.raw) << 16) | (lo.raw & 0xFFFFu)); 1098 #endif 1099 1100 return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(combined_mask)}; 1101 } 1102 1103 template <class D, HWY_IF_LANES_D(D, 16)> 1104 HWY_API MFromD<D> UpperHalfOfMask(D /*d*/, MFromD<Twice<D>> m) { 1105 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1106 const auto shifted_mask = _kshiftri_mask32(static_cast<__mmask32>(m.raw), 16); 1107 #else 1108 const auto shifted_mask = static_cast<uint32_t>(m.raw) >> 16; 1109 #endif 1110 1111 return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(shifted_mask)}; 1112 } 1113 1114 template <class D, HWY_IF_LANES_D(D, 32)> 1115 HWY_API MFromD<D> SlideMask1Up(D /*d*/, MFromD<D> m) { 1116 using RawM = decltype(MFromD<D>().raw); 1117 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1118 return MFromD<D>{ 1119 
static_cast<RawM>(_kshiftli_mask32(static_cast<__mmask32>(m.raw), 1))}; 1120 #else 1121 return MFromD<D>{static_cast<RawM>(static_cast<uint32_t>(m.raw) << 1)}; 1122 #endif 1123 } 1124 1125 template <class D, HWY_IF_LANES_D(D, 32)> 1126 HWY_API MFromD<D> SlideMask1Down(D /*d*/, MFromD<D> m) { 1127 using RawM = decltype(MFromD<D>().raw); 1128 #if HWY_COMPILER_HAS_MASK_INTRINSICS 1129 return MFromD<D>{ 1130 static_cast<RawM>(_kshiftri_mask32(static_cast<__mmask32>(m.raw), 1))}; 1131 #else 1132 return MFromD<D>{static_cast<RawM>(static_cast<uint32_t>(m.raw) >> 1)}; 1133 #endif 1134 } 1135 1136 #else // AVX2 1137 1138 // ------------------------------ Mask 1139 1140 // Mask and Vec are the same (true = FF..FF). 1141 template <typename T> 1142 HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) { 1143 return Mask256<T>{v.raw}; 1144 } 1145 1146 template <typename T> 1147 HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { 1148 return Vec256<T>{v.raw}; 1149 } 1150 1151 // ------------------------------ IfThenElse 1152 1153 // mask ? yes : no 1154 template <typename T, HWY_IF_NOT_FLOAT3264(T)> 1155 HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) { 1156 return Vec256<T>{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; 1157 } 1158 HWY_API Vec256<float> IfThenElse(Mask256<float> mask, Vec256<float> yes, 1159 Vec256<float> no) { 1160 return Vec256<float>{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; 1161 } 1162 HWY_API Vec256<double> IfThenElse(Mask256<double> mask, Vec256<double> yes, 1163 Vec256<double> no) { 1164 return Vec256<double>{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; 1165 } 1166 1167 // mask ? yes : 0 1168 template <typename T> 1169 HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) { 1170 const DFromV<decltype(yes)> d; 1171 return yes & VecFromMask(d, mask); 1172 } 1173 1174 // mask ? 
0 : no 1175 template <typename T> 1176 HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) { 1177 const DFromV<decltype(no)> d; 1178 return AndNot(VecFromMask(d, mask), no); 1179 } 1180 1181 template <typename T> 1182 HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) { 1183 static_assert(IsSigned<T>(), "Only for float"); 1184 const DFromV<decltype(v)> d; 1185 const auto zero = Zero(d); 1186 // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes 1187 return IfThenElse(MaskFromVec(v), zero, v); 1188 } 1189 1190 // ------------------------------ Mask logical 1191 1192 template <typename T> 1193 HWY_API Mask256<T> Not(const Mask256<T> m) { 1194 const Full256<T> d; 1195 return MaskFromVec(Not(VecFromMask(d, m))); 1196 } 1197 1198 template <typename T> 1199 HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) { 1200 const Full256<T> d; 1201 return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); 1202 } 1203 1204 template <typename T> 1205 HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) { 1206 const Full256<T> d; 1207 return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); 1208 } 1209 1210 template <typename T> 1211 HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) { 1212 const Full256<T> d; 1213 return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); 1214 } 1215 1216 template <typename T> 1217 HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) { 1218 const Full256<T> d; 1219 return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); 1220 } 1221 1222 template <typename T> 1223 HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) { 1224 const Full256<T> d; 1225 return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); 1226 } 1227 1228 #endif // HWY_TARGET <= HWY_AVX3 1229 1230 // ================================================== COMPARE 1231 1232 #if HWY_TARGET <= HWY_AVX3 1233 1234 // Comparisons set a mask bit to 1 if the condition is true, else 0. 
// Masks have the same bit layout for all lane types of equal size, so a
// rebind is a no-op on the raw kmask.
template <class DTo, HWY_IF_V_SIZE_D(DTo, 32), typename TFrom>
HWY_API MFromD<DTo> RebindMask(DTo /*tag*/, Mask256<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size");
  return MFromD<DTo>{m.raw};
}

namespace detail {

// Per-lane (v & bit) != 0, dispatched on lane size.
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<1> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi8_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<2> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi16_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<4> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi32_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<8> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi64_mask(v.raw, bit.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit);
}

// ------------------------------ Equality

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi8_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi16_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi32_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator==(Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Mask256<float> operator==(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}

HWY_API Mask256<double> operator==(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}

// ------------------------------ Inequality

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi8_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi16_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi32_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator!=(Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Mask256<float> operator!=(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
}

HWY_API Mask256<double> operator!=(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
}

// ------------------------------ Strict inequality

HWY_API Mask256<int8_t> operator>(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Mask256<int8_t>{_mm256_cmpgt_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask256<int16_t> operator>(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Mask256<int16_t>{_mm256_cmpgt_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask256<int32_t> operator>(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Mask256<int32_t>{_mm256_cmpgt_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask256<int64_t> operator>(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Mask256<int64_t>{_mm256_cmpgt_epi64_mask(a.raw, b.raw)};
}

HWY_API Mask256<uint8_t> operator>(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Mask256<uint8_t>{_mm256_cmpgt_epu8_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint16_t> operator>(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Mask256<uint16_t>{_mm256_cmpgt_epu16_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint32_t> operator>(Vec256<uint32_t> a, Vec256<uint32_t> b) {
  return Mask256<uint32_t>{_mm256_cmpgt_epu32_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint64_t> operator>(Vec256<uint64_t> a, Vec256<uint64_t> b) {
  return Mask256<uint64_t>{_mm256_cmpgt_epu64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator>(Vec256<float16_t> a, Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Mask256<float> operator>(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)};
}
HWY_API Mask256<double> operator>(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)};
}

// ------------------------------ Weak inequality

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator>=(Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask256<float> operator>=(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)};
}
HWY_API Mask256<double> operator>=(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)};
}

HWY_API Mask256<int8_t> operator>=(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Mask256<int8_t>{_mm256_cmpge_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask256<int16_t> operator>=(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Mask256<int16_t>{_mm256_cmpge_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask256<int32_t> operator>=(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Mask256<int32_t>{_mm256_cmpge_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask256<int64_t> operator>=(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Mask256<int64_t>{_mm256_cmpge_epi64_mask(a.raw, b.raw)};
}

HWY_API Mask256<uint8_t> operator>=(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Mask256<uint8_t>{_mm256_cmpge_epu8_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint16_t> operator>=(const Vec256<uint16_t> a,
                                     const Vec256<uint16_t> b) {
  return Mask256<uint16_t>{_mm256_cmpge_epu16_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint32_t> operator>=(const Vec256<uint32_t> a,
                                     const Vec256<uint32_t> b) {
  return Mask256<uint32_t>{_mm256_cmpge_epu32_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint64_t> operator>=(const Vec256<uint64_t> a,
                                     const Vec256<uint64_t> b) {
  return Mask256<uint64_t>{_mm256_cmpge_epu64_mask(a.raw, b.raw)};
}

// ------------------------------ Mask

namespace detail {

// Copies each lane's sign (MSB) into the corresponding kmask bit.
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi8_mask(v.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi16_mask(v.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi32_mask(v.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi64_mask(v.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT(T)>
HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
  return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v);
}
// There do not seem to be native floating-point versions of these instructions.
1468 template <typename T, HWY_IF_FLOAT(T)> 1469 HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) { 1470 const RebindToSigned<DFromV<decltype(v)>> di; 1471 return Mask256<T>{MaskFromVec(BitCast(di, v)).raw}; 1472 } 1473 1474 template <typename T, HWY_IF_T_SIZE(T, 1)> 1475 HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { 1476 return Vec256<T>{_mm256_movm_epi8(v.raw)}; 1477 } 1478 1479 template <typename T, HWY_IF_UI16(T)> 1480 HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { 1481 return Vec256<T>{_mm256_movm_epi16(v.raw)}; 1482 } 1483 1484 template <typename T, HWY_IF_UI32(T)> 1485 HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { 1486 return Vec256<T>{_mm256_movm_epi32(v.raw)}; 1487 } 1488 1489 template <typename T, HWY_IF_UI64(T)> 1490 HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { 1491 return Vec256<T>{_mm256_movm_epi64(v.raw)}; 1492 } 1493 1494 #if HWY_HAVE_FLOAT16 1495 HWY_API Vec256<float16_t> VecFromMask(const Mask256<float16_t> v) { 1496 return Vec256<float16_t>{_mm256_castsi256_ph(_mm256_movm_epi16(v.raw))}; 1497 } 1498 #endif // HWY_HAVE_FLOAT16 1499 1500 HWY_API Vec256<float> VecFromMask(const Mask256<float> v) { 1501 return Vec256<float>{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; 1502 } 1503 1504 HWY_API Vec256<double> VecFromMask(const Mask256<double> v) { 1505 return Vec256<double>{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; 1506 } 1507 1508 #else // AVX2 1509 1510 // Comparisons fill a lane with 1-bits if the condition is true, else 0. 
// On AVX2 masks are full vectors, so rebinding requires a round-trip through
// a vector bitcast rather than a raw copy.
template <class DTo, HWY_IF_V_SIZE_D(DTo, 32), typename TFrom>
HWY_API MFromD<DTo> RebindMask(DTo d_to, Mask256<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size");
  const Full256<TFrom> dfrom;
  return MaskFromVec(BitCast(d_to, VecFromMask(dfrom, m)));
}

template <typename T>
HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return (v & bit) == bit;
}

// ------------------------------ Equality

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi8(a.raw, b.raw)};
}

template <typename T, HWY_IF_UI16(T)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi16(a.raw, b.raw)};
}

template <typename T, HWY_IF_UI32(T)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi32(a.raw, b.raw)};
}

template <typename T, HWY_IF_UI64(T)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi64(a.raw, b.raw)};
}

HWY_API Mask256<float> operator==(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)};
}

HWY_API Mask256<double> operator==(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)};
}

// ------------------------------ Inequality

// AVX2 has no integer cmpneq; derive it from the equality comparison.
template <typename T, HWY_IF_NOT_FLOAT3264(T)>
HWY_API Mask256<T> operator!=(Vec256<T> a, Vec256<T> b) {
  return Not(a == b);
}
HWY_API Mask256<float> operator!=(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)};
}
HWY_API Mask256<double> operator!=(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)};
}

// ------------------------------ Strict inequality

// Tag dispatch instead of SFINAE for MSVC 2017 compatibility
namespace detail {

// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8
// to perform an unsigned comparison instead of the intended signed. Workaround
// is to cast to an explicitly signed type. See https://godbolt.org/z/PL7Ujy
#if HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 903
#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1
#else
#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0
#endif

HWY_API Mask256<int8_t> Gt(hwy::SignedTag /*tag*/, Vec256<int8_t> a,
                           Vec256<int8_t> b) {
#if HWY_AVX2_GCC_CMPGT8_WORKAROUND
  using i8x32 = signed char __attribute__((__vector_size__(32)));
  return Mask256<int8_t>{static_cast<__m256i>(reinterpret_cast<i8x32>(a.raw) >
                                              reinterpret_cast<i8x32>(b.raw))};
#else
  return Mask256<int8_t>{_mm256_cmpgt_epi8(a.raw, b.raw)};
#endif
}
HWY_API Mask256<int16_t> Gt(hwy::SignedTag /*tag*/, Vec256<int16_t> a,
                            Vec256<int16_t> b) {
  return Mask256<int16_t>{_mm256_cmpgt_epi16(a.raw, b.raw)};
}
HWY_API Mask256<int32_t> Gt(hwy::SignedTag /*tag*/, Vec256<int32_t> a,
                            Vec256<int32_t> b) {
  return Mask256<int32_t>{_mm256_cmpgt_epi32(a.raw, b.raw)};
}
HWY_API Mask256<int64_t> Gt(hwy::SignedTag /*tag*/, Vec256<int64_t> a,
                            Vec256<int64_t> b) {
  return Mask256<int64_t>{_mm256_cmpgt_epi64(a.raw, b.raw)};
}

// AVX2 only has signed integer compares; flipping the MSB maps an unsigned
// comparison onto the signed one.
template <typename T>
HWY_INLINE Mask256<T> Gt(hwy::UnsignedTag /*tag*/, Vec256<T> a, Vec256<T> b) {
  const Full256<T> du;
  const RebindToSigned<decltype(du)> di;
  const Vec256<T> msb = Set(du, (LimitsMax<T>() >> 1) + 1);
  return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb)));
}

HWY_API Mask256<float> Gt(hwy::FloatTag /*tag*/, Vec256<float> a,
                          Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)};
}
HWY_API Mask256<double> Gt(hwy::FloatTag /*tag*/, Vec256<double> a,
                           Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)};
}

}  // namespace detail

template <typename T>
HWY_API Mask256<T> operator>(Vec256<T> a, Vec256<T> b) {
  return detail::Gt(hwy::TypeTag<T>(), a, b);
}

// ------------------------------ Weak inequality

namespace detail {

// For integers, a >= b is exactly !(b > a).
template <typename T>
HWY_INLINE Mask256<T> Ge(hwy::SignedTag tag, Vec256<T> a, Vec256<T> b) {
  return Not(Gt(tag, b, a));
}

template <typename T>
HWY_INLINE Mask256<T> Ge(hwy::UnsignedTag tag, Vec256<T> a, Vec256<T> b) {
  return Not(Gt(tag, b, a));
}

HWY_INLINE Mask256<float> Ge(hwy::FloatTag /*tag*/, Vec256<float> a,
                             Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)};
}
HWY_INLINE Mask256<double> Ge(hwy::FloatTag /*tag*/, Vec256<double> a,
                              Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)};
}

}  // namespace detail

template <typename T>
HWY_API Mask256<T> operator>=(Vec256<T> a, Vec256<T> b) {
  return detail::Ge(hwy::TypeTag<T>(), a, b);
}

#endif  // HWY_TARGET <= HWY_AVX3

// ------------------------------ Reversed comparisons

template <typename T>
HWY_API Mask256<T> operator<(const Vec256<T> a, const Vec256<T> b) {
  return b > a;
}

template <typename T>
HWY_API Mask256<T> operator<=(const Vec256<T> a, const Vec256<T> b) {
  return b >= a;
}

// ------------------------------ Min (Gt, IfThenElse)

// Unsigned
HWY_API Vec256<uint8_t> Min(const Vec256<uint8_t> a, const Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_min_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> Min(const Vec256<uint16_t> a,
                             const Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_min_epu16(a.raw, b.raw)};
}
HWY_API Vec256<uint32_t> Min(const Vec256<uint32_t> a,
                             const Vec256<uint32_t> b) {
  return Vec256<uint32_t>{_mm256_min_epu32(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> Min(const Vec256<uint64_t> a,
                             const Vec256<uint64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint64_t>{_mm256_min_epu64(a.raw, b.raw)};
#else
  // AVX2 has no 64-bit min; compare via the signed MSB-flip trick and select.
  const Full256<uint64_t> du;
  const Full256<int64_t> di;
  const auto msb = Set(du, 1ull << 63);
  const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb));
  return IfThenElse(gt, b, a);
#endif
}

// Signed
HWY_API Vec256<int8_t> Min(const Vec256<int8_t> a, const Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_min_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> Min(const Vec256<int16_t> a, const Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_min_epi16(a.raw, b.raw)};
}
HWY_API Vec256<int32_t> Min(const Vec256<int32_t> a, const Vec256<int32_t> b) {
  return Vec256<int32_t>{_mm256_min_epi32(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> Min(const Vec256<int64_t> a, const Vec256<int64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<int64_t>{_mm256_min_epi64(a.raw, b.raw)};
#else
  return IfThenElse(a < b, a, b);
#endif
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Min(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_min_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Min(const Vec256<float> a, const Vec256<float> b) {
  return Vec256<float>{_mm256_min_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Min(const Vec256<double> a, const Vec256<double> b) {
  return Vec256<double>{_mm256_min_pd(a.raw, b.raw)};
}

// ------------------------------ Max (Gt, IfThenElse)

// Unsigned
HWY_API Vec256<uint8_t> Max(const Vec256<uint8_t> a, const Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_max_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> Max(const Vec256<uint16_t> a,
                             const Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_max_epu16(a.raw, b.raw)};
}
HWY_API Vec256<uint32_t> Max(const Vec256<uint32_t> a,
                             const Vec256<uint32_t> b) {
  return Vec256<uint32_t>{_mm256_max_epu32(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> Max(const Vec256<uint64_t> a,
                             const Vec256<uint64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint64_t>{_mm256_max_epu64(a.raw, b.raw)};
#else
  // AVX2 has no 64-bit max; compare via the signed MSB-flip trick and select.
  const Full256<uint64_t> du;
  const Full256<int64_t> di;
  const auto msb = Set(du, 1ull << 63);
  const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb));
  return IfThenElse(gt, a, b);
#endif
}

// Signed
HWY_API Vec256<int8_t> Max(const Vec256<int8_t> a, const Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_max_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> Max(const Vec256<int16_t> a, const Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_max_epi16(a.raw, b.raw)};
}
HWY_API Vec256<int32_t> Max(const Vec256<int32_t> a, const Vec256<int32_t> b) {
  return Vec256<int32_t>{_mm256_max_epi32(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> Max(const Vec256<int64_t> a, const Vec256<int64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<int64_t>{_mm256_max_epi64(a.raw, b.raw)};
#else
  return IfThenElse(a < b, b, a);
#endif
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Max(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_max_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Max(const Vec256<float> a, const Vec256<float> b) {
  return Vec256<float>{_mm256_max_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Max(const Vec256<double> a, const Vec256<double> b) {
  return Vec256<double>{_mm256_max_pd(a.raw, b.raw)};
}

// ------------------------------ MinNumber and MaxNumber

#if HWY_X86_HAVE_AVX10_2_OPS

// The imm8 passed to vminmax selects the operation: 0x14 = minNumber,
// 0x15 = maxNumber (NaN inputs are replaced by the other operand).
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> MinNumber(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x14)};
}
#endif
HWY_API Vec256<float> MinNumber(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x14)};
}
HWY_API Vec256<double> MinNumber(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x14)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> MaxNumber(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x15)};
}
#endif
HWY_API Vec256<float> MaxNumber(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x15)};
}
HWY_API Vec256<double> MaxNumber(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x15)};
}

#endif

// ------------------------------ MinMagnitude and MaxMagnitude

#if HWY_X86_HAVE_AVX10_2_OPS

// imm8 0x16 selects the minMagnitude operation of vminmax.
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> MinMagnitude(Vec256<float16_t> a,
                                       Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x16)};
}
#endif
HWY_API Vec256<float> MinMagnitude(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x16)};
}
HWY_API Vec256<double>
MinMagnitude(Vec256<double> a, Vec256<double> b) { 1833 return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x16)}; 1834 } 1835 1836 #if HWY_HAVE_FLOAT16 1837 HWY_API Vec256<float16_t> MaxMagnitude(Vec256<float16_t> a, 1838 Vec256<float16_t> b) { 1839 return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x17)}; 1840 } 1841 #endif 1842 HWY_API Vec256<float> MaxMagnitude(Vec256<float> a, Vec256<float> b) { 1843 return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x17)}; 1844 } 1845 HWY_API Vec256<double> MaxMagnitude(Vec256<double> a, Vec256<double> b) { 1846 return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x17)}; 1847 } 1848 1849 #endif 1850 1851 // ------------------------------ Iota 1852 1853 namespace detail { 1854 1855 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)> 1856 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1857 return VFromD<D>{_mm256_set_epi8( 1858 static_cast<char>(31), static_cast<char>(30), static_cast<char>(29), 1859 static_cast<char>(28), static_cast<char>(27), static_cast<char>(26), 1860 static_cast<char>(25), static_cast<char>(24), static_cast<char>(23), 1861 static_cast<char>(22), static_cast<char>(21), static_cast<char>(20), 1862 static_cast<char>(19), static_cast<char>(18), static_cast<char>(17), 1863 static_cast<char>(16), static_cast<char>(15), static_cast<char>(14), 1864 static_cast<char>(13), static_cast<char>(12), static_cast<char>(11), 1865 static_cast<char>(10), static_cast<char>(9), static_cast<char>(8), 1866 static_cast<char>(7), static_cast<char>(6), static_cast<char>(5), 1867 static_cast<char>(4), static_cast<char>(3), static_cast<char>(2), 1868 static_cast<char>(1), static_cast<char>(0))}; 1869 } 1870 1871 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)> 1872 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1873 return VFromD<D>{_mm256_set_epi16( 1874 int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, int16_t{11}, 1875 int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, int16_t{5}, 1876 int16_t{4}, 
int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})}; 1877 } 1878 1879 #if HWY_HAVE_FLOAT16 1880 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)> 1881 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1882 return VFromD<D>{ 1883 _mm256_set_ph(float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12}, 1884 float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, 1885 float16_t{7}, float16_t{6}, float16_t{5}, float16_t{4}, 1886 float16_t{3}, float16_t{2}, float16_t{1}, float16_t{0})}; 1887 } 1888 #endif // HWY_HAVE_FLOAT16 1889 1890 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 1891 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1892 return VFromD<D>{_mm256_set_epi32(int32_t{7}, int32_t{6}, int32_t{5}, 1893 int32_t{4}, int32_t{3}, int32_t{2}, 1894 int32_t{1}, int32_t{0})}; 1895 } 1896 1897 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)> 1898 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1899 return VFromD<D>{ 1900 _mm256_set_epi64x(int64_t{3}, int64_t{2}, int64_t{1}, int64_t{0})}; 1901 } 1902 1903 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 1904 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1905 return VFromD<D>{ 1906 _mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f)}; 1907 } 1908 1909 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 1910 HWY_INLINE VFromD<D> Iota0(D /*d*/) { 1911 return VFromD<D>{_mm256_set_pd(3.0, 2.0, 1.0, 0.0)}; 1912 } 1913 1914 } // namespace detail 1915 1916 template <class D, HWY_IF_V_SIZE_D(D, 32), typename T2> 1917 HWY_API VFromD<D> Iota(D d, const T2 first) { 1918 return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first)); 1919 } 1920 1921 // ------------------------------ FirstN (Iota, Lt) 1922 1923 template <class D, HWY_IF_V_SIZE_D(D, 32), class M = MFromD<D>> 1924 HWY_API M FirstN(const D d, size_t n) { 1925 constexpr size_t kN = MaxLanes(d); 1926 // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks 1927 // at the lower 8 bits; for AVX2 and below, this ensures `num` 
fits in TI. 1928 n = HWY_MIN(n, kN); 1929 1930 #if HWY_TARGET <= HWY_AVX3 1931 #if HWY_ARCH_X86_64 1932 const uint64_t all = (1ull << kN) - 1; 1933 return M::FromBits(_bzhi_u64(all, n)); 1934 #else 1935 const uint32_t all = static_cast<uint32_t>((1ull << kN) - 1); 1936 return M::FromBits(_bzhi_u32(all, static_cast<uint32_t>(n))); 1937 #endif // HWY_ARCH_X86_64 1938 #else 1939 const RebindToSigned<decltype(d)> di; // Signed comparisons are cheaper. 1940 using TI = TFromD<decltype(di)>; 1941 return RebindMask(d, detail::Iota0(di) < Set(di, static_cast<TI>(n))); 1942 #endif 1943 } 1944 1945 // ================================================== ARITHMETIC 1946 1947 // ------------------------------ Addition 1948 1949 // Unsigned 1950 HWY_API Vec256<uint8_t> operator+(Vec256<uint8_t> a, Vec256<uint8_t> b) { 1951 return Vec256<uint8_t>{_mm256_add_epi8(a.raw, b.raw)}; 1952 } 1953 HWY_API Vec256<uint16_t> operator+(Vec256<uint16_t> a, Vec256<uint16_t> b) { 1954 return Vec256<uint16_t>{_mm256_add_epi16(a.raw, b.raw)}; 1955 } 1956 HWY_API Vec256<uint32_t> operator+(Vec256<uint32_t> a, Vec256<uint32_t> b) { 1957 return Vec256<uint32_t>{_mm256_add_epi32(a.raw, b.raw)}; 1958 } 1959 HWY_API Vec256<uint64_t> operator+(Vec256<uint64_t> a, Vec256<uint64_t> b) { 1960 return Vec256<uint64_t>{_mm256_add_epi64(a.raw, b.raw)}; 1961 } 1962 1963 // Signed 1964 HWY_API Vec256<int8_t> operator+(Vec256<int8_t> a, Vec256<int8_t> b) { 1965 return Vec256<int8_t>{_mm256_add_epi8(a.raw, b.raw)}; 1966 } 1967 HWY_API Vec256<int16_t> operator+(Vec256<int16_t> a, Vec256<int16_t> b) { 1968 return Vec256<int16_t>{_mm256_add_epi16(a.raw, b.raw)}; 1969 } 1970 HWY_API Vec256<int32_t> operator+(Vec256<int32_t> a, Vec256<int32_t> b) { 1971 return Vec256<int32_t>{_mm256_add_epi32(a.raw, b.raw)}; 1972 } 1973 HWY_API Vec256<int64_t> operator+(Vec256<int64_t> a, Vec256<int64_t> b) { 1974 return Vec256<int64_t>{_mm256_add_epi64(a.raw, b.raw)}; 1975 } 1976 1977 // Float 1978 #if HWY_HAVE_FLOAT16 1979 HWY_API 
Vec256<float16_t> operator+(Vec256<float16_t> a, Vec256<float16_t> b) { 1980 return Vec256<float16_t>{_mm256_add_ph(a.raw, b.raw)}; 1981 } 1982 #endif // HWY_HAVE_FLOAT16 1983 HWY_API Vec256<float> operator+(Vec256<float> a, Vec256<float> b) { 1984 return Vec256<float>{_mm256_add_ps(a.raw, b.raw)}; 1985 } 1986 HWY_API Vec256<double> operator+(Vec256<double> a, Vec256<double> b) { 1987 return Vec256<double>{_mm256_add_pd(a.raw, b.raw)}; 1988 } 1989 1990 // ------------------------------ Subtraction 1991 1992 // Unsigned 1993 HWY_API Vec256<uint8_t> operator-(Vec256<uint8_t> a, Vec256<uint8_t> b) { 1994 return Vec256<uint8_t>{_mm256_sub_epi8(a.raw, b.raw)}; 1995 } 1996 HWY_API Vec256<uint16_t> operator-(Vec256<uint16_t> a, Vec256<uint16_t> b) { 1997 return Vec256<uint16_t>{_mm256_sub_epi16(a.raw, b.raw)}; 1998 } 1999 HWY_API Vec256<uint32_t> operator-(Vec256<uint32_t> a, Vec256<uint32_t> b) { 2000 return Vec256<uint32_t>{_mm256_sub_epi32(a.raw, b.raw)}; 2001 } 2002 HWY_API Vec256<uint64_t> operator-(Vec256<uint64_t> a, Vec256<uint64_t> b) { 2003 return Vec256<uint64_t>{_mm256_sub_epi64(a.raw, b.raw)}; 2004 } 2005 2006 // Signed 2007 HWY_API Vec256<int8_t> operator-(Vec256<int8_t> a, Vec256<int8_t> b) { 2008 return Vec256<int8_t>{_mm256_sub_epi8(a.raw, b.raw)}; 2009 } 2010 HWY_API Vec256<int16_t> operator-(Vec256<int16_t> a, Vec256<int16_t> b) { 2011 return Vec256<int16_t>{_mm256_sub_epi16(a.raw, b.raw)}; 2012 } 2013 HWY_API Vec256<int32_t> operator-(Vec256<int32_t> a, Vec256<int32_t> b) { 2014 return Vec256<int32_t>{_mm256_sub_epi32(a.raw, b.raw)}; 2015 } 2016 HWY_API Vec256<int64_t> operator-(Vec256<int64_t> a, Vec256<int64_t> b) { 2017 return Vec256<int64_t>{_mm256_sub_epi64(a.raw, b.raw)}; 2018 } 2019 2020 // Float 2021 #if HWY_HAVE_FLOAT16 2022 HWY_API Vec256<float16_t> operator-(Vec256<float16_t> a, Vec256<float16_t> b) { 2023 return Vec256<float16_t>{_mm256_sub_ph(a.raw, b.raw)}; 2024 } 2025 #endif // HWY_HAVE_FLOAT16 2026 HWY_API Vec256<float> 
operator-(Vec256<float> a, Vec256<float> b) { 2027 return Vec256<float>{_mm256_sub_ps(a.raw, b.raw)}; 2028 } 2029 HWY_API Vec256<double> operator-(Vec256<double> a, Vec256<double> b) { 2030 return Vec256<double>{_mm256_sub_pd(a.raw, b.raw)}; 2031 } 2032 2033 // ------------------------------ AddSub 2034 2035 HWY_API Vec256<float> AddSub(Vec256<float> a, Vec256<float> b) { 2036 return Vec256<float>{_mm256_addsub_ps(a.raw, b.raw)}; 2037 } 2038 HWY_API Vec256<double> AddSub(Vec256<double> a, Vec256<double> b) { 2039 return Vec256<double>{_mm256_addsub_pd(a.raw, b.raw)}; 2040 } 2041 2042 // ------------------------------ PairwiseAdd128/PairwiseSub128 2043 2044 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)> 2045 HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2046 return VFromD<D>{_mm256_hadd_epi16(a.raw, b.raw)}; 2047 } 2048 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)> 2049 HWY_API VFromD<D> PairwiseSub128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2050 const DFromV<decltype(a)> d; 2051 const RebindToSigned<decltype(d)> di; 2052 return BitCast(d, 2053 Neg(BitCast(di, VFromD<D>{_mm256_hsub_epi16(a.raw, b.raw)}))); 2054 } 2055 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 2056 HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2057 return VFromD<D>{_mm256_hadd_epi32(a.raw, b.raw)}; 2058 } 2059 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 2060 HWY_API VFromD<D> PairwiseSub128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2061 const DFromV<decltype(a)> d; 2062 const RebindToSigned<decltype(d)> di; 2063 return BitCast(d, 2064 Neg(BitCast(di, VFromD<D>{_mm256_hsub_epi32(a.raw, b.raw)}))); 2065 } 2066 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 2067 HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2068 return VFromD<D>{_mm256_hadd_ps(a.raw, b.raw)}; 2069 } 2070 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 2071 HWY_API VFromD<D> 
PairwiseSub128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2072 return Neg(VFromD<D>{_mm256_hsub_ps(a.raw, b.raw)}); 2073 } 2074 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 2075 HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2076 return VFromD<D>{_mm256_hadd_pd(a.raw, b.raw)}; 2077 } 2078 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 2079 HWY_API VFromD<D> PairwiseSub128(D /*d*/, VFromD<D> a, VFromD<D> b) { 2080 return Neg(VFromD<D>{_mm256_hsub_pd(a.raw, b.raw)}); 2081 } 2082 2083 // ------------------------------ SumsOf8 2084 HWY_API Vec256<uint64_t> SumsOf8(Vec256<uint8_t> v) { 2085 return Vec256<uint64_t>{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; 2086 } 2087 2088 HWY_API Vec256<uint64_t> SumsOf8AbsDiff(Vec256<uint8_t> a, Vec256<uint8_t> b) { 2089 return Vec256<uint64_t>{_mm256_sad_epu8(a.raw, b.raw)}; 2090 } 2091 2092 // ------------------------------ SumsOf4 2093 #if HWY_TARGET <= HWY_AVX3 2094 namespace detail { 2095 2096 HWY_INLINE Vec256<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/, 2097 hwy::SizeTag<1> /*lane_size_tag*/, 2098 Vec256<uint8_t> v) { 2099 const DFromV<decltype(v)> d; 2100 2101 // _mm256_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be 2102 // zeroed out and the sums of the 4 consecutive lanes are already in the 2103 // even uint16_t lanes of the _mm256_maskz_dbsad_epu8 result. 
2104 return Vec256<uint32_t>{_mm256_maskz_dbsad_epu8( 2105 static_cast<__mmask16>(0x5555), v.raw, Zero(d).raw, 0)}; 2106 } 2107 2108 // detail::SumsOf4 for Vec256<int8_t> on AVX3 is implemented in x86_512-inl.h 2109 2110 } // namespace detail 2111 #endif // HWY_TARGET <= HWY_AVX3 2112 2113 // ------------------------------ SumsOfAdjQuadAbsDiff 2114 2115 template <int kAOffset, int kBOffset> 2116 HWY_API Vec256<uint16_t> SumsOfAdjQuadAbsDiff(Vec256<uint8_t> a, 2117 Vec256<uint8_t> b) { 2118 static_assert(0 <= kAOffset && kAOffset <= 1, 2119 "kAOffset must be between 0 and 1"); 2120 static_assert(0 <= kBOffset && kBOffset <= 3, 2121 "kBOffset must be between 0 and 3"); 2122 return Vec256<uint16_t>{_mm256_mpsadbw_epu8( 2123 a.raw, b.raw, 2124 (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)}; 2125 } 2126 2127 // ------------------------------ SumsOfShuffledQuadAbsDiff 2128 2129 #if HWY_TARGET <= HWY_AVX3 2130 template <int kIdx3, int kIdx2, int kIdx1, int kIdx0> 2131 static Vec256<uint16_t> SumsOfShuffledQuadAbsDiff(Vec256<uint8_t> a, 2132 Vec256<uint8_t> b) { 2133 static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); 2134 static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); 2135 static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); 2136 static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); 2137 return Vec256<uint16_t>{ 2138 _mm256_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; 2139 } 2140 #endif 2141 2142 // ------------------------------ SaturatedAdd 2143 2144 // Returns a + b clamped to the destination range. 

// Unsigned
HWY_API Vec256<uint8_t> SaturatedAdd(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_adds_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> SaturatedAdd(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_adds_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec256<int8_t> SaturatedAdd(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_adds_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> SaturatedAdd(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_adds_epi16(a.raw, b.raw)};
}

#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
// There is no native 32/64-bit saturating add; detect signed overflow with a
// single vpternlog and, on overflow, substitute MAX (a >= 0) or MIN (a < 0).
HWY_API Vec256<int32_t> SaturatedAdd(Vec256<int32_t> a, Vec256<int32_t> b) {
  const DFromV<decltype(a)> d;
  const auto sum = a + b;
  // NOTE(review): ternlog imm 0x42 presumably sets the MSB when a and b have
  // equal signs that differ from sum's (signed overflow) — verify truth table.
  const auto overflow_mask = MaskFromVec(
      Vec256<int32_t>{_mm256_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)});
  const auto i32_max = Set(d, LimitsMax<int32_t>());
  // imm 0x55 = NOT of the first operand: lanes where a is negative become
  // LimitsMin, others stay LimitsMax.
  const Vec256<int32_t> overflow_result{_mm256_mask_ternarylogic_epi32(
      i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, sum);
}

HWY_API Vec256<int64_t> SaturatedAdd(Vec256<int64_t> a, Vec256<int64_t> b) {
  const DFromV<decltype(a)> d;
  const auto sum = a + b;
  const auto overflow_mask = MaskFromVec(
      Vec256<int64_t>{_mm256_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)});
  const auto i64_max = Set(d, LimitsMax<int64_t>());
  const Vec256<int64_t> overflow_result{_mm256_mask_ternarylogic_epi64(
      i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, sum);
}
#endif  // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN

// ------------------------------ SaturatedSub

// Returns a - b clamped to the destination range.

// Unsigned
HWY_API Vec256<uint8_t> SaturatedSub(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_subs_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> SaturatedSub(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_subs_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec256<int8_t> SaturatedSub(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_subs_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> SaturatedSub(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_subs_epi16(a.raw, b.raw)};
}

#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
// No native 32/64-bit saturating sub; same structure as SaturatedAdd but
// with a different vpternlog overflow predicate (imm 0x18).
HWY_API Vec256<int32_t> SaturatedSub(Vec256<int32_t> a, Vec256<int32_t> b) {
  const DFromV<decltype(a)> d;
  const auto diff = a - b;
  // NOTE(review): imm 0x18 presumably flags lanes where a and b have
  // differing signs and diff's sign differs from a's — verify truth table.
  const auto overflow_mask = MaskFromVec(
      Vec256<int32_t>{_mm256_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)});
  const auto i32_max = Set(d, LimitsMax<int32_t>());
  // On overflow, produce LimitsMin when a < 0, else LimitsMax (imm 0x55 = NOT).
  const Vec256<int32_t> overflow_result{_mm256_mask_ternarylogic_epi32(
      i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, diff);
}

HWY_API Vec256<int64_t> SaturatedSub(Vec256<int64_t> a, Vec256<int64_t> b) {
  const DFromV<decltype(a)> d;
  const auto diff = a - b;
  const auto overflow_mask = MaskFromVec(
      Vec256<int64_t>{_mm256_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)});
  const auto i64_max = Set(d, LimitsMax<int64_t>());
  const Vec256<int64_t> overflow_result{_mm256_mask_ternarylogic_epi64(
      i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, diff);
}
#endif  // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN

// ------------------------------ Average

// Returns (a + b + 1) / 2

// Unsigned
HWY_API Vec256<uint8_t> AverageRound(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_avg_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> AverageRound(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_avg_epu16(a.raw, b.raw)};
}

// ------------------------------ Abs (Sub)

// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
HWY_API Vec256<int8_t> Abs(Vec256<int8_t> v) {
#if HWY_COMPILER_MSVC
  // Workaround for incorrect codegen? (wrong result)
  const DFromV<decltype(v)> d;
  const auto zero = Zero(d);
  return Vec256<int8_t>{_mm256_max_epi8(v.raw, (zero - v).raw)};
#else
  return Vec256<int8_t>{_mm256_abs_epi8(v.raw)};
#endif
}
HWY_API Vec256<int16_t> Abs(const Vec256<int16_t> v) {
  return Vec256<int16_t>{_mm256_abs_epi16(v.raw)};
}
HWY_API Vec256<int32_t> Abs(const Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_abs_epi32(v.raw)};
}

#if HWY_TARGET <= HWY_AVX3
HWY_API Vec256<int64_t> Abs(const Vec256<int64_t> v) {
  return Vec256<int64_t>{_mm256_abs_epi64(v.raw)};
}
#endif

// ------------------------------ Integer multiplication

// Unsigned
HWY_API Vec256<uint16_t> operator*(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec256<uint32_t> operator*(Vec256<uint32_t> a, Vec256<uint32_t> b) {
  return Vec256<uint32_t>{_mm256_mullo_epi32(a.raw, b.raw)};
}
#if HWY_TARGET <= HWY_AVX3
HWY_API Vec256<uint64_t> operator*(Vec256<uint64_t> a, Vec256<uint64_t> b) {
  return Vec256<uint64_t>{_mm256_mullo_epi64(a.raw, b.raw)};
}
#endif

// Signed
HWY_API Vec256<int16_t> operator*(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec256<int32_t> operator*(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Vec256<int32_t>{_mm256_mullo_epi32(a.raw, b.raw)};
}
#if HWY_TARGET <= HWY_AVX3
HWY_API Vec256<int64_t> operator*(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Vec256<int64_t>{_mm256_mullo_epi64(a.raw, b.raw)};
}
#endif

// Returns the upper 16 bits of a * b in each lane.
HWY_API Vec256<uint16_t> MulHigh(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_mulhi_epu16(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> MulHigh(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_mulhi_epi16(a.raw, b.raw)};
}

// Rounding fixed-point multiply: (a * b + (1 << 14)) >> 15.
HWY_API Vec256<int16_t> MulFixedPoint15(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_mulhrs_epi16(a.raw, b.raw)};
}

// Multiplies even lanes (0, 2 ..) and places the double-wide result into
// even and the upper half into its odd neighbor lane.
HWY_API Vec256<int64_t> MulEven(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Vec256<int64_t>{_mm256_mul_epi32(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> MulEven(Vec256<uint32_t> a, Vec256<uint32_t> b) {
  return Vec256<uint64_t>{_mm256_mul_epu32(a.raw, b.raw)};
}

// ------------------------------ ShiftLeft

#if HWY_TARGET <= HWY_AVX3_DL
namespace detail {
// Applies an 8x8 bit-matrix (GF(2) affine transform, imm b=0) to each byte;
// used to implement 8-bit shifts/rotates elsewhere.
template <typename T>
HWY_API Vec256<T> GaloisAffine(Vec256<T> v, Vec256<uint64_t> matrix) {
  return Vec256<T>{_mm256_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)};
}
}  // namespace detail
#endif  // HWY_TARGET <= HWY_AVX3_DL

template <int kBits>
HWY_API Vec256<uint16_t> ShiftLeft(Vec256<uint16_t> v) {
  return Vec256<uint16_t>{_mm256_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<uint32_t> ShiftLeft(Vec256<uint32_t> v) {
  return Vec256<uint32_t>{_mm256_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<uint64_t> ShiftLeft(Vec256<uint64_t> v) {
  return Vec256<uint64_t>{_mm256_slli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<int16_t> ShiftLeft(Vec256<int16_t> v) {
  return Vec256<int16_t>{_mm256_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<int32_t> ShiftLeft(Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<int64_t> ShiftLeft(Vec256<int64_t> v) {
  return Vec256<int64_t>{_mm256_slli_epi64(v.raw, kBits)};
}

#if HWY_TARGET > HWY_AVX3_DL

// 8-bit shift: no native instruction; shift as 16-bit lanes, then mask off
// the bits that crossed byte boundaries.
template <int kBits, typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> ShiftLeft(const Vec256<T> v) {
  const Full256<T> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v)));
  return kBits == 1
             ? (v + v)
             : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF)));
}

#endif  // HWY_TARGET > HWY_AVX3_DL

// ------------------------------ ShiftRight

template <int kBits>
HWY_API Vec256<uint16_t> ShiftRight(Vec256<uint16_t> v) {
  return Vec256<uint16_t>{_mm256_srli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<uint32_t> ShiftRight(Vec256<uint32_t> v) {
  return Vec256<uint32_t>{_mm256_srli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<uint64_t> ShiftRight(Vec256<uint64_t> v) {
  return Vec256<uint64_t>{_mm256_srli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<int16_t> ShiftRight(Vec256<int16_t> v) {
  return Vec256<int16_t>{_mm256_srai_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec256<int32_t> ShiftRight(Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_srai_epi32(v.raw, kBits)};
}

#if HWY_TARGET > HWY_AVX3_DL

template <int kBits>
HWY_API Vec256<uint8_t> ShiftRight(Vec256<uint8_t> v) {
  const Full256<uint8_t> d8;
  // Use raw instead of BitCast to support N=1.
  const Vec256<uint8_t> shifted{ShiftRight<kBits>(Vec256<uint16_t>{v.raw}).raw};
  return shifted & Set(d8, 0xFF >> kBits);
}

template <int kBits>
HWY_API Vec256<int8_t> ShiftRight(Vec256<int8_t> v) {
  const Full256<int8_t> di;
  const Full256<uint8_t> du;
  // Logical shift, then sign-extend: XOR/subtract propagates the shifted
  // sign bit into the vacated high bits.
  const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits));
  return (shifted ^ shifted_sign) - shifted_sign;
}

#endif  // HWY_TARGET > HWY_AVX3_DL

// i64 is implemented after BroadcastSignBit.

// ------------------------------ RotateRight

// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8
// RotateRight uses detail::GaloisAffine on AVX3_DL

#if HWY_TARGET > HWY_AVX3_DL
template <int kBits>
HWY_API Vec256<uint8_t> RotateRight(const Vec256<uint8_t> v) {
  static_assert(0 <= kBits && kBits < 8, "Invalid shift count");
  if (kBits == 0) return v;
  // AVX3 does not support 8-bit.
  // HWY_MIN avoids an invalid shift count of 8 when kBits == 0 (dead branch).
  return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(7, 8 - kBits)>(v));
}
#endif

template <int kBits>
HWY_API Vec256<uint16_t> RotateRight(const Vec256<uint16_t> v) {
  static_assert(0 <= kBits && kBits < 16, "Invalid shift count");
  if (kBits == 0) return v;
#if HWY_TARGET <= HWY_AVX3_DL
  // Double-shift of v with itself is a rotate.
  return Vec256<uint16_t>{_mm256_shrdi_epi16(v.raw, v.raw, kBits)};
#else
  // AVX3 does not support 16-bit.
  return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(15, 16 - kBits)>(v));
#endif
}

template <int kBits>
HWY_API Vec256<uint32_t> RotateRight(const Vec256<uint32_t> v) {
  static_assert(0 <= kBits && kBits < 32, "Invalid shift count");
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint32_t>{_mm256_ror_epi32(v.raw, kBits)};
#else
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(31, 32 - kBits)>(v));
#endif
}

template <int kBits>
HWY_API Vec256<uint64_t> RotateRight(const Vec256<uint64_t> v) {
  static_assert(0 <= kBits && kBits < 64, "Invalid shift count");
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint64_t>{_mm256_ror_epi64(v.raw, kBits)};
#else
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(63, 64 - kBits)>(v));
#endif
}

// ------------------------------ Rol/Ror (variable rotate)
#if HWY_TARGET <= HWY_AVX3_DL
template <class T, HWY_IF_UI16(T)>
HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_shrdv_epi16(a.raw, a.raw, b.raw)};
}
#endif  // HWY_TARGET <= HWY_AVX3_DL

#if HWY_TARGET <= HWY_AVX3

template <class T, HWY_IF_UI32(T)>
HWY_API Vec256<T> Rol(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_rolv_epi32(a.raw, b.raw)};
}

template <class T, HWY_IF_UI32(T)>
HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_rorv_epi32(a.raw, b.raw)};
}

template <class T, HWY_IF_UI64(T)>
HWY_API Vec256<T> Rol(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_rolv_epi64(a.raw, b.raw)};
}

template <class T, HWY_IF_UI64(T)>
HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_rorv_epi64(a.raw, b.raw)};
}

#endif

// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask)

HWY_API Vec256<int8_t> BroadcastSignBit(const Vec256<int8_t> v) { 2506 const DFromV<decltype(v)> d; 2507 return VecFromMask(v < Zero(d)); 2508 } 2509 2510 HWY_API Vec256<int16_t> BroadcastSignBit(const Vec256<int16_t> v) { 2511 return ShiftRight<15>(v); 2512 } 2513 2514 HWY_API Vec256<int32_t> BroadcastSignBit(const Vec256<int32_t> v) { 2515 return ShiftRight<31>(v); 2516 } 2517 2518 #if HWY_TARGET <= HWY_AVX3 2519 2520 template <int kBits> 2521 HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) { 2522 return Vec256<int64_t>{ 2523 _mm256_srai_epi64(v.raw, static_cast<Shift64Count>(kBits))}; 2524 } 2525 2526 HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) { 2527 return ShiftRight<63>(v); 2528 } 2529 2530 #else // AVX2 2531 2532 // Unlike above, this will be used to implement int64_t ShiftRight. 2533 HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) { 2534 const DFromV<decltype(v)> d; 2535 return VecFromMask(v < Zero(d)); 2536 } 2537 2538 template <int kBits> 2539 HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) { 2540 const Full256<int64_t> di; 2541 const Full256<uint64_t> du; 2542 const auto right = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); 2543 const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); 2544 return right | sign; 2545 } 2546 2547 #endif // #if HWY_TARGET <= HWY_AVX3 2548 2549 // ------------------------------ IfNegativeThenElse (BroadcastSignBit) 2550 HWY_API Vec256<int8_t> IfNegativeThenElse(Vec256<int8_t> v, Vec256<int8_t> yes, 2551 Vec256<int8_t> no) { 2552 // int8: AVX2 IfThenElse only looks at the MSB. 
  return IfThenElse(MaskFromVec(v), yes, no);
}

// 16-bit IfNegativeThenElse.
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes, Vec256<T> no) {
  static_assert(IsSigned<T>(), "Only works for signed/float");

#if HWY_TARGET <= HWY_AVX3
  const auto mask = MaskFromVec(v);
#else
  // 16-bit: no native blendv on AVX2, so copy sign to lower byte's MSB.
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v))));
#endif

  return IfThenElse(mask, yes, no);
}

// 32/64-bit IfNegativeThenElse.
template <typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))>
HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes, Vec256<T> no) {
  static_assert(IsSigned<T>(), "Only works for signed/float");

#if HWY_TARGET <= HWY_AVX3
  // No need to cast to float on AVX3 as IfThenElse only looks at the MSB on
  // AVX3
  return IfThenElse(MaskFromVec(v), yes, no);
#else
  const DFromV<decltype(v)> d;
  const RebindToFloat<decltype(d)> df;
  // 32/64-bit: use float IfThenElse, which only looks at the MSB.
  const MFromD<decltype(df)> msb = MaskFromVec(BitCast(df, v));
  return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no)));
#endif
}

// ------------------------------ IfNegativeThenNegOrUndefIfZero

// psign: negates v where mask < 0, passes v through where mask > 0, and
// zeroes the lane where mask == 0 (the contract leaves that case undefined).
HWY_API Vec256<int8_t> IfNegativeThenNegOrUndefIfZero(Vec256<int8_t> mask,
                                                      Vec256<int8_t> v) {
  return Vec256<int8_t>{_mm256_sign_epi8(v.raw, mask.raw)};
}

HWY_API Vec256<int16_t> IfNegativeThenNegOrUndefIfZero(Vec256<int16_t> mask,
                                                       Vec256<int16_t> v) {
  return Vec256<int16_t>{_mm256_sign_epi16(v.raw, mask.raw)};
}

HWY_API Vec256<int32_t> IfNegativeThenNegOrUndefIfZero(Vec256<int32_t> mask,
                                                       Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_sign_epi32(v.raw, mask.raw)};
}

// ------------------------------ ShiftLeftSame

// Disable sign conversion warnings for GCC debug intrinsics.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

// Shifts all lanes left by the same run-time bit count. When GCC can prove
// the count is a compile-time constant, the immediate (slli) form avoids the
// GPR->XMM transfer required by the sll form.
HWY_API Vec256<uint16_t> ShiftLeftSame(const Vec256<uint16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<uint16_t>{_mm256_slli_epi16(v.raw, bits)};
  }
#endif
  return Vec256<uint16_t>{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec256<uint32_t> ShiftLeftSame(const Vec256<uint32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<uint32_t>{_mm256_slli_epi32(v.raw, bits)};
  }
#endif
  return Vec256<uint32_t>{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec256<uint64_t> ShiftLeftSame(const Vec256<uint64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<uint64_t>{_mm256_slli_epi64(v.raw, bits)};
  }
#endif
  return Vec256<uint64_t>{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec256<int16_t> ShiftLeftSame(const Vec256<int16_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<int16_t>{_mm256_slli_epi16(v.raw, bits)};
  }
#endif
  return Vec256<int16_t>{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec256<int32_t> ShiftLeftSame(const Vec256<int32_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<int32_t>{_mm256_slli_epi32(v.raw, bits)};
  }
#endif
  return Vec256<int32_t>{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec256<int64_t> ShiftLeftSame(const Vec256<int64_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<int64_t>{_mm256_slli_epi64(v.raw, bits)};
  }
#endif
  return Vec256<int64_t>{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

// 8-bit: no native instruction; shift as 16-bit lanes, then mask off the
// bits that crossed into the neighboring byte.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> ShiftLeftSame(const Vec256<T> v, const int bits) {
  const Full256<T> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF));
}

// ------------------------------ ShiftRightSame (BroadcastSignBit)

HWY_API Vec256<uint16_t> ShiftRightSame(const Vec256<uint16_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<uint16_t>{_mm256_srli_epi16(v.raw, bits)};
  }
#endif
  return Vec256<uint16_t>{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec256<uint32_t> ShiftRightSame(const Vec256<uint32_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<uint32_t>{_mm256_srli_epi32(v.raw, bits)};
  }
#endif
  return Vec256<uint32_t>{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec256<uint64_t> ShiftRightSame(const Vec256<uint64_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<uint64_t>{_mm256_srli_epi64(v.raw, bits)};
  }
#endif
  return Vec256<uint64_t>{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

// 8-bit: shift as 16-bit lanes, then mask off bits shifted in from the
// neighboring high byte.
HWY_API Vec256<uint8_t> ShiftRightSame(Vec256<uint8_t> v, const int bits) {
  const Full256<uint8_t> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits));
}

HWY_API Vec256<int16_t> ShiftRightSame(const Vec256<int16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<int16_t>{_mm256_srai_epi16(v.raw, bits)};
  }
#endif
  return Vec256<int16_t>{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec256<int32_t> ShiftRightSame(const Vec256<int32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<int32_t>{_mm256_srai_epi32(v.raw, bits)};
  }
#endif
  return Vec256<int32_t>{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
// int64: native on AVX3; on AVX2, emulated via unsigned shift plus OR-ing in
// the sign bits (see BroadcastSignBit above).
HWY_API Vec256<int64_t> ShiftRightSame(const Vec256<int64_t> v,
                                       const int bits) {
#if HWY_TARGET <= HWY_AVX3
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec256<int64_t>{
        _mm256_srai_epi64(v.raw, static_cast<Shift64Count>(bits))};
  }
#endif
  return Vec256<int64_t>{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))};
#else
  const Full256<int64_t> di;
  const Full256<uint64_t> du;
  const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
  const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits);
  return right |
         sign;
#endif
}

// int8 arithmetic (sign-extending) right shift by a run-time count: shift as
// unsigned, then restore the sign with the XOR-and-subtract trick.
HWY_API Vec256<int8_t> ShiftRightSame(Vec256<int8_t> v, const int bits) {
  const Full256<int8_t> di;
  const Full256<uint8_t> du;
  const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
  const auto shifted_sign =
      BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits)));
  return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Neg (Xor, Sub)

// Tag dispatch instead of SFINAE for MSVC 2017 compatibility
namespace detail {

// Floating-point negation: just flip the sign bit.
template <typename T>
HWY_INLINE Vec256<T> Neg(hwy::FloatTag /*tag*/, const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  return Xor(v, SignBit(d));
}

// Special types (e.g. float16/bfloat16 when emulated): same sign-bit flip.
template <typename T>
HWY_INLINE Vec256<T> Neg(hwy::SpecialTag /*tag*/, const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  return Xor(v, SignBit(d));
}

// Not floating-point
template <typename T>
HWY_INLINE Vec256<T> Neg(hwy::SignedTag /*tag*/, const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  return Zero(d) - v;
}

}  // namespace detail

// Public entry point: dispatches on the type category of T.
template <typename T>
HWY_API Vec256<T> Neg(const Vec256<T> v) {
  return detail::Neg(hwy::TypeTag<T>(), v);
}

// ------------------------------ Floating-point mul / div

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> operator*(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_mul_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> operator*(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_mul_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> operator*(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_mul_pd(a.raw, b.raw)};
}

#if HWY_TARGET <= HWY_AVX3

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t>
MulByFloorPow2(Vec256<float16_t> a,
               Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_scalef_ph(a.raw, b.raw)};
}
#endif

// Computes a * 2^floor(b) via the AVX-512 vscalef instruction.
HWY_API Vec256<float> MulByFloorPow2(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_scalef_ps(a.raw, b.raw)};
}

HWY_API Vec256<double> MulByFloorPow2(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_scalef_pd(a.raw, b.raw)};
}

#endif  // HWY_TARGET <= HWY_AVX3

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> operator/(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_div_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> operator/(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_div_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> operator/(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_div_pd(a.raw, b.raw)};
}

// Approximate reciprocal
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> ApproximateReciprocal(Vec256<float16_t> v) {
  return Vec256<float16_t>{_mm256_rcp_ph(v.raw)};
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Vec256<float> ApproximateReciprocal(Vec256<float> v) {
  return Vec256<float>{_mm256_rcp_ps(v.raw)};
}

#if HWY_TARGET <= HWY_AVX3
// double has no rcp on AVX2; AVX-512VL adds rcp14 (~14 bits of precision).
HWY_API Vec256<double> ApproximateReciprocal(Vec256<double> v) {
  return Vec256<double>{_mm256_rcp14_pd(v.raw)};
}
#endif

// ------------------------------ GetExponent

#if HWY_TARGET <= HWY_AVX3

#if HWY_HAVE_FLOAT16
template <class V, HWY_IF_F16(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
HWY_API V GetExponent(V v) {
  return V{_mm256_getexp_ph(v.raw)};
}
#endif
template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
HWY_API V GetExponent(V v) {
  return V{_mm256_getexp_ps(v.raw)};
}
template <class V,
          HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
HWY_API V GetExponent(V v) {
  return V{_mm256_getexp_pd(v.raw)};
}

#endif

// ------------------------------ MaskedMinOr

#if HWY_TARGET <= HWY_AVX3

// MaskedMinOr: returns Min(a, b) in lanes selected by m, otherwise `no`.
// One overload per lane type, each mapping to the AVX-512 masked intrinsic.
template <typename T, HWY_IF_U8(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I8(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U16(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I16(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U32(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I32(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U64(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I64(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F32(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_ps(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F64(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_pd(no.raw, m.raw, a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
template <typename T, HWY_IF_F16(T)>
HWY_API Vec256<T> MaskedMinOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_min_ph(no.raw, m.raw, a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// ------------------------------ MaskedMaxOr

// MaskedMaxOr: returns Max(a, b) in lanes selected by m, otherwise `no`.
template <typename T, HWY_IF_U8(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I8(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U16(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I16(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U32(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I32(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no,
                              Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U64(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)};
}
template <typename T, HWY_IF_I64(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F32(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_ps(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F64(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_pd(no.raw, m.raw, a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
template <typename T, HWY_IF_F16(T)>
HWY_API Vec256<T> MaskedMaxOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_max_ph(no.raw, m.raw, a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// ------------------------------ MaskedAddOr

// MaskedAddOr: returns a + b in lanes selected by m, otherwise `no`. The
// integer add intrinsics are sign-agnostic, hence the UI (signed-or-unsigned)
// constraints.
template <typename T, HWY_IF_UI8(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_UI16(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_UI64(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F32(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_ps(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F64(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_pd(no.raw, m.raw, a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
template <typename T, HWY_IF_F16(T)>
HWY_API Vec256<T> MaskedAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_add_ph(no.raw, m.raw, a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// ------------------------------ MaskedSubOr

// MaskedSubOr: returns a - b in lanes selected by m, otherwise `no`.
template <typename T, HWY_IF_UI8(T)>
HWY_API Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_UI16(T)>
HWY_API Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_UI64(T)>
HWY_API Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F32(T)>
HWY_API
    Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                          Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_F64(T)>
HWY_API Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
template <typename T, HWY_IF_F16(T)>
HWY_API Vec256<T> MaskedSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                              Vec256<T> b) {
  return Vec256<T>{_mm256_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// ------------------------------ MaskedMulOr

// MaskedMulOr: returns a * b in lanes selected by m, otherwise `no`.
// Floating-point only (no masked integer multiply wrapper here).
HWY_API Vec256<float> MaskedMulOr(Vec256<float> no, Mask256<float> m,
                                  Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)};
}

HWY_API Vec256<double> MaskedMulOr(Vec256<double> no, Mask256<double> m,
                                   Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> MaskedMulOr(Vec256<float16_t> no,
                                      Mask256<float16_t> m, Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// ------------------------------ MaskedDivOr

// MaskedDivOr: returns a / b in lanes selected by m, otherwise `no`.
HWY_API Vec256<float> MaskedDivOr(Vec256<float> no, Mask256<float> m,
                                  Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_mask_div_ps(no.raw, m.raw, a.raw, b.raw)};
}

HWY_API Vec256<double> MaskedDivOr(Vec256<double> no, Mask256<double> m,
                                   Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_mask_div_pd(no.raw, m.raw, a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> MaskedDivOr(Vec256<float16_t> no,
                                      Mask256<float16_t> m, Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_mask_div_ph(no.raw, m.raw, a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// ------------------------------ MaskedSatAddOr

// MaskedSatAddOr: saturating a + b in lanes selected by m, otherwise `no`.
// Saturating adds exist only for 8/16-bit lanes.
template <typename T, HWY_IF_I8(T)>
HWY_API Vec256<T> MaskedSatAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U8(T)>
HWY_API Vec256<T> MaskedSatAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_I16(T)>
HWY_API Vec256<T> MaskedSatAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U16(T)>
HWY_API Vec256<T> MaskedSatAddOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)};
}

// ------------------------------ MaskedSatSubOr

// MaskedSatSubOr: saturating a - b in lanes selected by m, otherwise `no`.
template <typename T, HWY_IF_I8(T)>
HWY_API Vec256<T> MaskedSatSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U8(T)>
HWY_API Vec256<T> MaskedSatSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_I16(T)>
HWY_API Vec256<T> MaskedSatSubOr(Vec256<T> no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)};
}

template <typename T, HWY_IF_U16(T)>
HWY_API Vec256<T> MaskedSatSubOr(Vec256<T>
                                     no, Mask256<T> m, Vec256<T> a,
                                 Vec256<T> b) {
  return Vec256<T>{_mm256_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)};
}

#endif  // HWY_TARGET <= HWY_AVX3

// ------------------------------ Floating-point multiply-add variants

#if HWY_HAVE_FLOAT16

// mul * x + add
HWY_API Vec256<float16_t> MulAdd(Vec256<float16_t> mul, Vec256<float16_t> x,
                                 Vec256<float16_t> add) {
  return Vec256<float16_t>{_mm256_fmadd_ph(mul.raw, x.raw, add.raw)};
}

// add - mul * x
HWY_API Vec256<float16_t> NegMulAdd(Vec256<float16_t> mul, Vec256<float16_t> x,
                                    Vec256<float16_t> add) {
  return Vec256<float16_t>{_mm256_fnmadd_ph(mul.raw, x.raw, add.raw)};
}

// mul * x - sub
HWY_API Vec256<float16_t> MulSub(Vec256<float16_t> mul, Vec256<float16_t> x,
                                 Vec256<float16_t> sub) {
  return Vec256<float16_t>{_mm256_fmsub_ph(mul.raw, x.raw, sub.raw)};
}

// -(mul * x) - sub
HWY_API Vec256<float16_t> NegMulSub(Vec256<float16_t> mul, Vec256<float16_t> x,
                                    Vec256<float16_t> sub) {
  return Vec256<float16_t>{_mm256_fnmsub_ph(mul.raw, x.raw, sub.raw)};
}

#endif  // HWY_HAVE_FLOAT16

// Returns mul * x + add. HWY_DISABLE_BMI2_FMA falls back to separate
// multiply and add (different rounding: two roundings instead of one).
HWY_API Vec256<float> MulAdd(Vec256<float> mul, Vec256<float> x,
                             Vec256<float> add) {
#ifdef HWY_DISABLE_BMI2_FMA
  return mul * x + add;
#else
  return Vec256<float>{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)};
#endif
}
HWY_API Vec256<double> MulAdd(Vec256<double> mul, Vec256<double> x,
                              Vec256<double> add) {
#ifdef HWY_DISABLE_BMI2_FMA
  return mul * x + add;
#else
  return Vec256<double>{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)};
#endif
}

// Returns add - mul * x.
HWY_API Vec256<float> NegMulAdd(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> add) {
#ifdef HWY_DISABLE_BMI2_FMA
  return add - mul * x;
#else
  return Vec256<float>{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)};
#endif
}
HWY_API Vec256<double> NegMulAdd(Vec256<double> mul, Vec256<double> x,
                                 Vec256<double> add) {
#ifdef HWY_DISABLE_BMI2_FMA
  return add - mul * x;
#else
  return Vec256<double>{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)};
#endif
}

// Returns mul * x - sub.
HWY_API Vec256<float> MulSub(Vec256<float> mul, Vec256<float> x,
                             Vec256<float> sub) {
#ifdef HWY_DISABLE_BMI2_FMA
  return mul * x - sub;
#else
  return Vec256<float>{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)};
#endif
}
HWY_API Vec256<double> MulSub(Vec256<double> mul, Vec256<double> x,
                              Vec256<double> sub) {
#ifdef HWY_DISABLE_BMI2_FMA
  return mul * x - sub;
#else
  return Vec256<double>{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)};
#endif
}

// Returns -(mul * x) - sub.
HWY_API Vec256<float> NegMulSub(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> sub) {
#ifdef HWY_DISABLE_BMI2_FMA
  return Neg(mul * x) - sub;
#else
  return Vec256<float>{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)};
#endif
}
HWY_API Vec256<double> NegMulSub(Vec256<double> mul, Vec256<double> x,
                                 Vec256<double> sub) {
#ifdef HWY_DISABLE_BMI2_FMA
  return Neg(mul * x) - sub;
#else
  return Vec256<double>{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)};
#endif
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> MulAddSub(Vec256<float16_t> mul, Vec256<float16_t> x,
                                    Vec256<float16_t> sub_or_add) {
  return Vec256<float16_t>{_mm256_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)};
}
#endif  // HWY_HAVE_FLOAT16

// Alternates subtract (even lanes) and add (odd lanes) of sub_or_add to
// mul * x.
HWY_API Vec256<float> MulAddSub(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> sub_or_add) {
#ifdef HWY_DISABLE_BMI2_FMA
  return AddSub(mul * x, sub_or_add);
#else
  return Vec256<float>{_mm256_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)};
#endif
}

HWY_API Vec256<double> MulAddSub(Vec256<double> mul, Vec256<double> x,
                                 Vec256<double> sub_or_add) {
#ifdef HWY_DISABLE_BMI2_FMA
  return AddSub(mul * x, sub_or_add);
#else
  return Vec256<double>{_mm256_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)};
#endif
}

// ------------------------------ Floating-point square root

// Full precision square root
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Sqrt(Vec256<float16_t> v) {
  return Vec256<float16_t>{_mm256_sqrt_ph(v.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Sqrt(Vec256<float> v) {
  return Vec256<float>{_mm256_sqrt_ps(v.raw)};
}
HWY_API Vec256<double> Sqrt(Vec256<double> v) {
  return Vec256<double>{_mm256_sqrt_pd(v.raw)};
}

// Approximate reciprocal square root
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> ApproximateReciprocalSqrt(Vec256<float16_t> v) {
  return Vec256<float16_t>{_mm256_rsqrt_ph(v.raw)};
}
#endif
HWY_API Vec256<float> ApproximateReciprocalSqrt(Vec256<float> v) {
  return Vec256<float>{_mm256_rsqrt_ps(v.raw)};
}

#if HWY_TARGET <= HWY_AVX3
HWY_API Vec256<double> ApproximateReciprocalSqrt(Vec256<double> v) {
#if HWY_COMPILER_MSVC
  // MSVC lacks the unmasked _mm256_rsqrt14_pd; use the masked form with an
  // all-ones mask instead.
  const DFromV<decltype(v)> d;
  return Vec256<double>{_mm256_mask_rsqrt14_pd(
      Undefined(d).raw, static_cast<__mmask8>(0xFF), v.raw)};
#else
  return Vec256<double>{_mm256_rsqrt14_pd(v.raw)};
#endif
}
#endif

// ------------------------------ Floating-point rounding

// Toward nearest integer, tie to even
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Round(Vec256<float16_t> v) {
  return Vec256<float16_t>{_mm256_roundscale_ph(
      v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Round(Vec256<float> v) {
  return Vec256<float>{
      _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
}
HWY_API Vec256<double> Round(Vec256<double> v) {
  return Vec256<double>{
      _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
}

// Toward zero, aka truncate
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Trunc(Vec256<float16_t> v) {
  return Vec256<float16_t>{
      _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Trunc(Vec256<float> v) {
  return Vec256<float>{
      _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
}
HWY_API Vec256<double> Trunc(Vec256<double> v) {
  return Vec256<double>{
      _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
}

// Toward +infinity, aka ceiling
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Ceil(Vec256<float16_t> v) {
  return Vec256<float16_t>{
      _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Ceil(Vec256<float> v) {
  return Vec256<float>{
      _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
}
HWY_API Vec256<double> Ceil(Vec256<double> v) {
  return Vec256<double>{
      _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
}

// Toward -infinity, aka floor
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Floor(Vec256<float16_t> v) {
  return Vec256<float16_t>{
      _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Floor(Vec256<float> v) {
  return Vec256<float>{
      _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
}
HWY_API Vec256<double> Floor(Vec256<double> v) {
  return Vec256<double>{
      _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
}

// ------------------------------ Floating-point classification

#if HWY_HAVE_FLOAT16 || HWY_IDE

HWY_API Mask256<float16_t>
IsNaN(Vec256<float16_t> v) {
  return Mask256<float16_t>{_mm256_fpclass_ph_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)};
}

// True in lanes where a or b (or both) is NaN (unordered compare).
HWY_API Mask256<float16_t> IsEitherNaN(Vec256<float16_t> a,
                                       Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)};
  HWY_DIAGNOSTICS(pop)
}

HWY_API Mask256<float16_t> IsInf(Vec256<float16_t> v) {
  return Mask256<float16_t>{_mm256_fpclass_ph_mask(
      v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)};
}

HWY_API Mask256<float16_t> IsFinite(Vec256<float16_t> v) {
  // fpclass doesn't have a flag for positive, so we have to check for inf/NaN
  // and negate the mask.
  return Not(Mask256<float16_t>{_mm256_fpclass_ph_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
                 HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)});
}

#endif  // HWY_HAVE_FLOAT16

// On AVX3, fpclass classifies directly; on AVX2, a self-compare with
// _CMP_UNORD_Q is true only for NaN lanes.
HWY_API Mask256<float> IsNaN(Vec256<float> v) {
#if HWY_TARGET <= HWY_AVX3
  return Mask256<float>{_mm256_fpclass_ps_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)};
#else
  return Mask256<float>{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)};
#endif
}
HWY_API Mask256<double> IsNaN(Vec256<double> v) {
#if HWY_TARGET <= HWY_AVX3
  return Mask256<double>{_mm256_fpclass_pd_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)};
#else
  return Mask256<double>{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)};
#endif
}

HWY_API Mask256<float> IsEitherNaN(Vec256<float> a, Vec256<float> b) {
#if HWY_TARGET <= HWY_AVX3
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)};
#else
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_UNORD_Q)};
#endif
}

HWY_API Mask256<double> IsEitherNaN(Vec256<double> a, Vec256<double> b) {
#if HWY_TARGET <= HWY_AVX3
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)};
#else
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_UNORD_Q)};
#endif
}

#if HWY_TARGET <= HWY_AVX3

HWY_API Mask256<float> IsInf(Vec256<float> v) {
  return Mask256<float>{_mm256_fpclass_ps_mask(
      v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)};
}
HWY_API Mask256<double> IsInf(Vec256<double> v) {
  return Mask256<double>{_mm256_fpclass_pd_mask(
      v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)};
}

HWY_API Mask256<float> IsFinite(Vec256<float> v) {
  // fpclass doesn't have a flag for positive, so we have to check for inf/NaN
  // and negate the mask.
  return Not(Mask256<float>{_mm256_fpclass_ps_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
                 HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)});
}
HWY_API Mask256<double> IsFinite(Vec256<double> v) {
  return Not(Mask256<double>{_mm256_fpclass_pd_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
                 HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)});
}

#endif  // HWY_TARGET <= HWY_AVX3

// ================================================== MEMORY

// ------------------------------ Load

// Aligned load of a full 256-bit vector of non-float lanes.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> Load(D /* tag */, const TFromD<D>* HWY_RESTRICT aligned) {
  return VFromD<D>{
      _mm256_load_si256(reinterpret_cast<const __m256i*>(aligned))};
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Load(D /* tag */,
                               const float16_t* HWY_RESTRICT aligned) {
  return Vec256<float16_t>{_mm256_load_ph(aligned)};
}
#endif
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Load(D /* tag */, const float* HWY_RESTRICT aligned) {
  return Vec256<float>{_mm256_load_ps(aligned)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Load(D /* tag */, const double* HWY_RESTRICT aligned) {
  return Vec256<double>{_mm256_load_pd(aligned)};
}

// Unaligned load; no alignment requirement on `p`.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> LoadU(D /* tag */, const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_loadu_si256(reinterpret_cast<const __m256i*>(p))};
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) {
  return Vec256<float16_t>{_mm256_loadu_ph(p)};
}
#endif
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> LoadU(D /* tag */, const float* HWY_RESTRICT p) {
  return Vec256<float>{_mm256_loadu_ps(p)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> LoadU(D /* tag */, const double* HWY_RESTRICT p) {
  return Vec256<double>{_mm256_loadu_pd(p)};
}

// ------------------------------ MaskedLoad

#if HWY_TARGET <= HWY_AVX3

// AVX-512VL: maskz_loadu zeroes lanes whose mask bit is clear and does not
// fault on masked-off addresses.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */,
                             const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_maskz_loadu_epi8(m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d,
                             const TFromD<D>* HWY_RESTRICT p) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{_mm256_maskz_loadu_epi16(m.raw, p)});
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */,
                             const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_maskz_loadu_epi32(m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */,
                             const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_maskz_loadu_epi64(m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> MaskedLoad(Mask256<float> m, D /* tag */,
                                 const float* HWY_RESTRICT p) {
  return Vec256<float>{_mm256_maskz_loadu_ps(m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> MaskedLoad(Mask256<double> m, D /* tag */,
                                  const double* HWY_RESTRICT p) {
  return Vec256<double>{_mm256_maskz_loadu_pd(m.raw, p)};
}

// MaskedLoadOr: lanes whose mask bit is clear take the value from `v`.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */,
                               const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_mask_loadu_epi8(v.raw, m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d,
                               const TFromD<D>* HWY_RESTRICT p) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{
                        _mm256_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)});
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */,
                               const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_mask_loadu_epi32(v.raw, m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */,
                               const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm256_mask_loadu_epi64(v.raw, m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> MaskedLoadOr(VFromD<D> v, Mask256<float> m, D /* tag */,
                                   const float* HWY_RESTRICT p) {
  return Vec256<float>{_mm256_mask_loadu_ps(v.raw, m.raw, p)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> MaskedLoadOr(VFromD<D> v, Mask256<double> m, D /* tag */,
                                    const double* HWY_RESTRICT p) {
  return Vec256<double>{_mm256_mask_loadu_pd(v.raw, m.raw, p)};
}

#else  // AVX2

// There is no maskload_epi8/16, so blend instead.
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d,
                             const TFromD<D>* HWY_RESTRICT p) {
  return IfThenElseZero(m, LoadU(d, p));
}

// AVX2 vmaskmov: the mask is a vector (sign bit per lane), not a k-register.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */,
                             const TFromD<D>* HWY_RESTRICT p) {
  auto pi = reinterpret_cast<const int*>(p);  // NOLINT
  return VFromD<D>{_mm256_maskload_epi32(pi, m.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */,
                             const TFromD<D>* HWY_RESTRICT p) {
  auto pi = reinterpret_cast<const long long*>(p);  // NOLINT
  return VFromD<D>{_mm256_maskload_epi64(pi, m.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> MaskedLoad(Mask256<float> m, D d,
                                 const float* HWY_RESTRICT p) {
  const Vec256<int32_t> mi =
      BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m));
  return Vec256<float>{_mm256_maskload_ps(p, mi.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> MaskedLoad(Mask256<double> m, D d,
                                  const double* HWY_RESTRICT p) {
  const Vec256<int64_t> mi =
      BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m));
  return Vec256<double>{_mm256_maskload_pd(p, mi.raw)};
}

#endif

// ------------------------------ LoadDup128

// Loads 128 bit and duplicates into both 128-bit halves. This avoids the
// 3-cycle cost of moving data between 128-bit halves and avoids port 5.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
  const RebindToUnsigned<decltype(d)> du;
  const Full128<TFromD<D>> d128;
  const RebindToUnsigned<decltype(d128)> du128;
  const __m128i v128 = BitCast(du128, LoadU(d128, p)).raw;
#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931
  // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note
  // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the
  // upper half undefined) is fine because we're overwriting that anyway.
  // This workaround seems in turn to generate incorrect code in MSVC 2022
  // (19.31), so use broadcastsi128 there.
  return BitCast(d, VFromD<decltype(du)>{_mm256_inserti128_si256(
                        _mm256_castsi128_si256(v128), v128, 1)});
#else
  // The preferred path. This is perhaps surprising, because vbroadcasti128
  // with xmm input has 7 cycle latency on Intel, but Clang >= 7 is able to
  // pattern-match this to vbroadcastf128 with a memory operand as desired.
  return BitCast(d, VFromD<decltype(du)>{_mm256_broadcastsi128_si256(v128)});
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> LoadDup128(D /* tag */, const float* HWY_RESTRICT p) {
#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931
  // MSVC < 19.31: same broadcast workaround as the integer path above.
  const Full128<float> d128;
  const __m128 v128 = LoadU(d128, p).raw;
  return Vec256<float>{
      _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)};
#else
  return Vec256<float>{_mm256_broadcast_ps(reinterpret_cast<const __m128*>(p))};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> LoadDup128(D /* tag */, const double* HWY_RESTRICT p) {
#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931
  const Full128<double> d128;
  const __m128d v128 = LoadU(d128, p).raw;
  return Vec256<double>{
      _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)};
#else
  return Vec256<double>{
      _mm256_broadcast_pd(reinterpret_cast<const __m128d*>(p))};
#endif
}

// ------------------------------ Store

// Aligned store; `aligned` must be 32-byte aligned.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API void Store(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT aligned) {
  _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw);
}
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API void Store(Vec256<float16_t> v, D /* tag */,
                   float16_t* HWY_RESTRICT aligned) {
  _mm256_store_ph(aligned, v.raw);
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API void Store(Vec256<float> v, D /* tag */, float* HWY_RESTRICT aligned) {
  _mm256_store_ps(aligned, v.raw);
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API void Store(Vec256<double> v, D /* tag */,
                   double* HWY_RESTRICT aligned) {
  _mm256_store_pd(aligned, v.raw);
}

// Unaligned store.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API void StoreU(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT p) {
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw);
}
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API void StoreU(Vec256<float16_t> v, D /* tag */,
                    float16_t* HWY_RESTRICT p) {
  _mm256_storeu_ph(p, v.raw);
}
#endif
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API void StoreU(Vec256<float> v, D /* tag */, float* HWY_RESTRICT p) {
  _mm256_storeu_ps(p, v.raw);
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API void StoreU(Vec256<double> v, D /* tag */, double* HWY_RESTRICT p) {
  _mm256_storeu_pd(p, v.raw);
}

// ------------------------------ BlendedStore

#if HWY_TARGET <= HWY_AVX3

// AVX-512VL: masked store writes only lanes whose mask bit is set.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */,
                          TFromD<D>* HWY_RESTRICT p) {
  _mm256_mask_storeu_epi8(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
                          TFromD<D>* HWY_RESTRICT p) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  _mm256_mask_storeu_epi16(reinterpret_cast<uint16_t*>(p),
                           RebindMask(du, m).raw, BitCast(du, v).raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */,
                          TFromD<D>* HWY_RESTRICT p) {
  _mm256_mask_storeu_epi32(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */,
                          TFromD<D>* HWY_RESTRICT p) {
  _mm256_mask_storeu_epi64(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API void BlendedStore(Vec256<float> v, Mask256<float> m, D /* tag */,
                          float* HWY_RESTRICT p) {
  _mm256_mask_storeu_ps(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API void BlendedStore(Vec256<double> v, Mask256<double> m, D /* tag */,
                          double* HWY_RESTRICT p) {
  _mm256_mask_storeu_pd(p, m.raw, v.raw);
}

#else  // AVX2

// Intel SDM says "No AC# reported for any mask bit combinations". However, AMD
// allows AC# if "Alignment checking enabled and: 256-bit memory operand not
// 32-byte aligned". Fortunately AC# is not enabled by default and requires both
// OS support (CR0) and the application to set rflags.AC. We assume these remain
// disabled because x86/x64 code and compiler output often contain misaligned
// scalar accesses, which would also fault.
//
// Caveat: these are slow on AMD Jaguar/Bulldozer.

template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
                          TFromD<D>* HWY_RESTRICT p) {
  // There is no maskload_epi8/16. Blending is also unsafe because loading a
  // full vector that crosses the array end causes asan faults. Resort to scalar
  // code; the caller should instead use memcpy, assuming m is FirstN(d, n).
3828 const RebindToUnsigned<decltype(d)> du; 3829 using TU = TFromD<decltype(du)>; 3830 alignas(32) TU buf[MaxLanes(d)]; 3831 alignas(32) TU mask[MaxLanes(d)]; 3832 Store(BitCast(du, v), du, buf); 3833 Store(BitCast(du, VecFromMask(d, m)), du, mask); 3834 for (size_t i = 0; i < MaxLanes(d); ++i) { 3835 if (mask[i]) { 3836 CopySameSize(buf + i, p + i); 3837 } 3838 } 3839 } 3840 3841 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 3842 HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */, 3843 TFromD<D>* HWY_RESTRICT p) { 3844 auto pi = reinterpret_cast<int*>(p); // NOLINT 3845 _mm256_maskstore_epi32(pi, m.raw, v.raw); 3846 } 3847 3848 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)> 3849 HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */, 3850 TFromD<D>* HWY_RESTRICT p) { 3851 auto pi = reinterpret_cast<long long*>(p); // NOLINT 3852 _mm256_maskstore_epi64(pi, m.raw, v.raw); 3853 } 3854 3855 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 3856 HWY_API void BlendedStore(Vec256<float> v, Mask256<float> m, D d, 3857 float* HWY_RESTRICT p) { 3858 const Vec256<int32_t> mi = 3859 BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); 3860 _mm256_maskstore_ps(p, mi.raw, v.raw); 3861 } 3862 3863 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 3864 HWY_API void BlendedStore(Vec256<double> v, Mask256<double> m, D d, 3865 double* HWY_RESTRICT p) { 3866 const Vec256<int64_t> mi = 3867 BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); 3868 _mm256_maskstore_pd(p, mi.raw, v.raw); 3869 } 3870 3871 #endif 3872 3873 // ------------------------------ Non-temporal stores 3874 3875 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)> 3876 HWY_API void Stream(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT aligned) { 3877 const RebindToUnsigned<decltype(d)> du; // for float16_t 3878 _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), BitCast(du, v).raw); 3879 } 3880 template <class D, 
HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 3881 HWY_API void Stream(Vec256<float> v, D /* tag */, float* HWY_RESTRICT aligned) { 3882 _mm256_stream_ps(aligned, v.raw); 3883 } 3884 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 3885 HWY_API void Stream(Vec256<double> v, D /* tag */, 3886 double* HWY_RESTRICT aligned) { 3887 _mm256_stream_pd(aligned, v.raw); 3888 } 3889 3890 // ------------------------------ ScatterOffset 3891 3892 // Work around warnings in the intrinsic definitions (passing -1 as a mask). 3893 HWY_DIAGNOSTICS(push) 3894 HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") 3895 3896 #if HWY_TARGET <= HWY_AVX3 3897 3898 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 3899 HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, 3900 TFromD<D>* HWY_RESTRICT base, 3901 Vec256<int32_t> offset) { 3902 _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); 3903 } 3904 3905 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)> 3906 HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, 3907 TFromD<D>* HWY_RESTRICT base, 3908 Vec256<int64_t> offset) { 3909 _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); 3910 } 3911 3912 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 3913 HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, float* HWY_RESTRICT base, 3914 const Vec256<int32_t> offset) { 3915 _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); 3916 } 3917 3918 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 3919 HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, double* HWY_RESTRICT base, 3920 const Vec256<int64_t> offset) { 3921 _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); 3922 } 3923 3924 // ------------------------------ ScatterIndex 3925 3926 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 3927 HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, 3928 TFromD<D>* HWY_RESTRICT base, 3929 VFromD<RebindToSigned<D>> index) { 3930 _mm256_i32scatter_epi32(base, index.raw, 
v.raw, 4); 3931 } 3932 3933 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)> 3934 HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, 3935 TFromD<D>* HWY_RESTRICT base, 3936 VFromD<RebindToSigned<D>> index) { 3937 _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); 3938 } 3939 3940 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 3941 HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, float* HWY_RESTRICT base, 3942 VFromD<RebindToSigned<D>> index) { 3943 _mm256_i32scatter_ps(base, index.raw, v.raw, 4); 3944 } 3945 3946 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 3947 HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, double* HWY_RESTRICT base, 3948 VFromD<RebindToSigned<D>> index) { 3949 _mm256_i64scatter_pd(base, index.raw, v.raw, 8); 3950 } 3951 3952 // ------------------------------ MaskedScatterIndex 3953 3954 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 3955 HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, 3956 TFromD<D>* HWY_RESTRICT base, 3957 VFromD<RebindToSigned<D>> index) { 3958 _mm256_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4); 3959 } 3960 3961 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)> 3962 HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, 3963 TFromD<D>* HWY_RESTRICT base, 3964 VFromD<RebindToSigned<D>> index) { 3965 _mm256_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8); 3966 } 3967 3968 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 3969 HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, 3970 float* HWY_RESTRICT base, 3971 VFromD<RebindToSigned<D>> index) { 3972 _mm256_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4); 3973 } 3974 3975 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 3976 HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */, 3977 double* HWY_RESTRICT base, 3978 VFromD<RebindToSigned<D>> index) { 3979 _mm256_mask_i64scatter_pd(base, 
m.raw, index.raw, v.raw, 8); 3980 } 3981 3982 #endif // HWY_TARGET <= HWY_AVX3 3983 3984 // ------------------------------ Gather 3985 3986 namespace detail { 3987 3988 template <int kScale, typename T, HWY_IF_UI32(T)> 3989 HWY_INLINE Vec256<T> NativeGather256(const T* HWY_RESTRICT base, 3990 Vec256<int32_t> indices) { 3991 return Vec256<T>{_mm256_i32gather_epi32( 3992 reinterpret_cast<const int32_t*>(base), indices.raw, kScale)}; 3993 } 3994 3995 template <int kScale, typename T, HWY_IF_UI64(T)> 3996 HWY_INLINE Vec256<T> NativeGather256(const T* HWY_RESTRICT base, 3997 Vec256<int64_t> indices) { 3998 return Vec256<T>{_mm256_i64gather_epi64( 3999 reinterpret_cast<const GatherIndex64*>(base), indices.raw, kScale)}; 4000 } 4001 4002 template <int kScale> 4003 HWY_API Vec256<float> NativeGather256(const float* HWY_RESTRICT base, 4004 Vec256<int32_t> indices) { 4005 return Vec256<float>{_mm256_i32gather_ps(base, indices.raw, kScale)}; 4006 } 4007 4008 template <int kScale> 4009 HWY_API Vec256<double> NativeGather256(const double* HWY_RESTRICT base, 4010 Vec256<int64_t> indices) { 4011 return Vec256<double>{_mm256_i64gather_pd(base, indices.raw, kScale)}; 4012 } 4013 4014 } // namespace detail 4015 4016 template <class D, HWY_IF_V_SIZE_D(D, 32)> 4017 HWY_API VFromD<D> GatherOffset(D /*d*/, const TFromD<D>* HWY_RESTRICT base, 4018 VFromD<RebindToSigned<D>> offsets) { 4019 return detail::NativeGather256<1>(base, offsets); 4020 } 4021 4022 template <class D, HWY_IF_V_SIZE_D(D, 32)> 4023 HWY_API VFromD<D> GatherIndex(D /*d*/, const TFromD<D>* HWY_RESTRICT base, 4024 VFromD<RebindToSigned<D>> indices) { 4025 return detail::NativeGather256<sizeof(TFromD<D>)>(base, indices); 4026 } 4027 4028 // ------------------------------ MaskedGatherIndexOr 4029 4030 namespace detail { 4031 4032 template <int kScale, typename T, HWY_IF_UI32(T)> 4033 HWY_INLINE Vec256<T> NativeMaskedGatherOr256(Vec256<T> no, Mask256<T> m, 4034 const T* HWY_RESTRICT base, 4035 Vec256<int32_t> indices) { 4036 
#if HWY_TARGET <= HWY_AVX3 4037 return Vec256<T>{_mm256_mmask_i32gather_epi32( 4038 no.raw, m.raw, indices.raw, reinterpret_cast<const int32_t*>(base), 4039 kScale)}; 4040 #else 4041 return Vec256<T>{_mm256_mask_i32gather_epi32( 4042 no.raw, reinterpret_cast<const int32_t*>(base), indices.raw, m.raw, 4043 kScale)}; 4044 #endif 4045 } 4046 4047 template <int kScale, typename T, HWY_IF_UI64(T)> 4048 HWY_INLINE Vec256<T> NativeMaskedGatherOr256(Vec256<T> no, Mask256<T> m, 4049 const T* HWY_RESTRICT base, 4050 Vec256<int64_t> indices) { 4051 #if HWY_TARGET <= HWY_AVX3 4052 return Vec256<T>{_mm256_mmask_i64gather_epi64( 4053 no.raw, m.raw, indices.raw, reinterpret_cast<const GatherIndex64*>(base), 4054 kScale)}; 4055 #else 4056 // For reasons unknown, _mm256_mask_i64gather_epi64 returns all-zeros. 4057 const Full256<T> d; 4058 const Full256<double> dd; 4059 return BitCast(d, 4060 Vec256<double>{_mm256_mask_i64gather_pd( 4061 BitCast(dd, no).raw, reinterpret_cast<const double*>(base), 4062 indices.raw, RebindMask(dd, m).raw, kScale)}); 4063 #endif 4064 } 4065 4066 template <int kScale> 4067 HWY_API Vec256<float> NativeMaskedGatherOr256(Vec256<float> no, 4068 Mask256<float> m, 4069 const float* HWY_RESTRICT base, 4070 Vec256<int32_t> indices) { 4071 #if HWY_TARGET <= HWY_AVX3 4072 return Vec256<float>{ 4073 _mm256_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; 4074 #else 4075 return Vec256<float>{ 4076 _mm256_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; 4077 #endif 4078 } 4079 4080 template <int kScale> 4081 HWY_API Vec256<double> NativeMaskedGatherOr256(Vec256<double> no, 4082 Mask256<double> m, 4083 const double* HWY_RESTRICT base, 4084 Vec256<int64_t> indices) { 4085 #if HWY_TARGET <= HWY_AVX3 4086 return Vec256<double>{ 4087 _mm256_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; 4088 #else 4089 return Vec256<double>{ 4090 _mm256_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; 4091 #endif 4092 } 4093 4094 } 
// namespace detail 4095 4096 template <class D, HWY_IF_V_SIZE_D(D, 32)> 4097 HWY_API VFromD<D> MaskedGatherIndexOr(VFromD<D> no, MFromD<D> m, D /*d*/, 4098 const TFromD<D>* HWY_RESTRICT base, 4099 VFromD<RebindToSigned<D>> indices) { 4100 return detail::NativeMaskedGatherOr256<sizeof(TFromD<D>)>(no, m, base, 4101 indices); 4102 } 4103 4104 HWY_DIAGNOSTICS(pop) 4105 4106 // ================================================== SWIZZLE 4107 4108 // ------------------------------ LowerHalf 4109 4110 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> 4111 HWY_API VFromD<D> LowerHalf(D /* tag */, VFromD<Twice<D>> v) { 4112 return VFromD<D>{_mm256_castsi256_si128(v.raw)}; 4113 } 4114 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_BF16_D(D)> 4115 HWY_API Vec128<bfloat16_t> LowerHalf(D /* tag */, Vec256<bfloat16_t> v) { 4116 return Vec128<bfloat16_t>{_mm256_castsi256_si128(v.raw)}; 4117 } 4118 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F16_D(D)> 4119 HWY_API Vec128<float16_t> LowerHalf(D /* tag */, Vec256<float16_t> v) { 4120 #if HWY_HAVE_FLOAT16 4121 return Vec128<float16_t>{_mm256_castph256_ph128(v.raw)}; 4122 #else 4123 return Vec128<float16_t>{_mm256_castsi256_si128(v.raw)}; 4124 #endif // HWY_HAVE_FLOAT16 4125 } 4126 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)> 4127 HWY_API Vec128<float> LowerHalf(D /* tag */, Vec256<float> v) { 4128 return Vec128<float>{_mm256_castps256_ps128(v.raw)}; 4129 } 4130 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F64_D(D)> 4131 HWY_API Vec128<double> LowerHalf(D /* tag */, Vec256<double> v) { 4132 return Vec128<double>{_mm256_castpd256_pd128(v.raw)}; 4133 } 4134 4135 template <typename T> 4136 HWY_API Vec128<T> LowerHalf(Vec256<T> v) { 4137 const Full128<T> dh; 4138 return LowerHalf(dh, v); 4139 } 4140 4141 // ------------------------------ UpperHalf 4142 4143 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_NOT_FLOAT3264_D(D)> 4144 HWY_API VFromD<D> UpperHalf(D d, VFromD<Twice<D>> v) 
{ 4145 const RebindToUnsigned<decltype(d)> du; // for float16_t 4146 const Twice<decltype(du)> dut; 4147 return BitCast(d, VFromD<decltype(du)>{ 4148 _mm256_extracti128_si256(BitCast(dut, v).raw, 1)}); 4149 } 4150 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)> 4151 HWY_API VFromD<D> UpperHalf(D /* tag */, Vec256<float> v) { 4152 return VFromD<D>{_mm256_extractf128_ps(v.raw, 1)}; 4153 } 4154 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F64_D(D)> 4155 HWY_API VFromD<D> UpperHalf(D /* tag */, Vec256<double> v) { 4156 return VFromD<D>{_mm256_extractf128_pd(v.raw, 1)}; 4157 } 4158 4159 // ------------------------------ ExtractLane (Store) 4160 template <typename T> 4161 HWY_API T ExtractLane(const Vec256<T> v, size_t i) { 4162 const DFromV<decltype(v)> d; 4163 HWY_DASSERT(i < Lanes(d)); 4164 4165 #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang 4166 constexpr size_t kLanesPerBlock = 16 / sizeof(T); 4167 if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { 4168 return ExtractLane(LowerHalf(Half<decltype(d)>(), v), i); 4169 } 4170 #endif 4171 4172 alignas(32) T lanes[32 / sizeof(T)]; 4173 Store(v, d, lanes); 4174 return lanes[i]; 4175 } 4176 4177 // ------------------------------ InsertLane (Store) 4178 template <typename T> 4179 HWY_API Vec256<T> InsertLane(const Vec256<T> v, size_t i, T t) { 4180 return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); 4181 } 4182 4183 // ------------------------------ GetLane (LowerHalf) 4184 template <typename T> 4185 HWY_API T GetLane(const Vec256<T> v) { 4186 return GetLane(LowerHalf(v)); 4187 } 4188 4189 // ------------------------------ ExtractBlock (LowerHalf, UpperHalf) 4190 4191 template <int kBlockIdx, class T> 4192 HWY_API Vec128<T> ExtractBlock(Vec256<T> v) { 4193 static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); 4194 const Half<DFromV<decltype(v)>> dh; 4195 return (kBlockIdx == 0) ? 
LowerHalf(dh, v) : UpperHalf(dh, v); 4196 } 4197 4198 // ------------------------------ ZeroExtendVector 4199 4200 // Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper 4201 // bits undefined. Although it makes sense for them to be zero (VEX encoded 4202 // 128-bit instructions zero the upper lanes to avoid large penalties), a 4203 // compiler could decide to optimize out code that relies on this. 4204 // 4205 // The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the 4206 // zeroing, but it is not available on MSVC until 1920 nor GCC until 10.1. 4207 // Unfortunately as of 2023-08 it still seems to cause internal compiler errors 4208 // on MSVC, so we consider it unavailable there. 4209 // 4210 // Without zext we can still possibly obtain the desired code thanks to pattern 4211 // recognition; note that the expensive insert instruction might not actually be 4212 // generated, see https://gcc.godbolt.org/z/1MKGaP. 4213 4214 #if !defined(HWY_HAVE_ZEXT) 4215 #if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ 4216 (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) 4217 #define HWY_HAVE_ZEXT 1 4218 #else 4219 #define HWY_HAVE_ZEXT 0 4220 #endif 4221 #endif // defined(HWY_HAVE_ZEXT) 4222 4223 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> 4224 HWY_API VFromD<D> ZeroExtendVector(D /* tag */, VFromD<Half<D>> lo) { 4225 #if HWY_HAVE_ZEXT 4226 return VFromD<D>{_mm256_zextsi128_si256(lo.raw)}; 4227 #elif HWY_COMPILER_MSVC 4228 // Workaround: _mm256_inserti128_si256 does not actually zero the hi part. 
  // Tail of the integer ZeroExtendVector (head precedes this view): the upper
  // 128 bits are zeroed either by a zero-latency zext move or by inserting
  // into an all-zero vector.
  return VFromD<D>{_mm256_set_m128i(_mm_setzero_si128(), lo.raw)};
#else
  return VFromD<D>{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)};
#endif
}
#if HWY_HAVE_FLOAT16
// float16 ZeroExtendVector: native zext when available, otherwise delegate to
// the unsigned-integer overload via BitCast (same bit pattern).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> ZeroExtendVector(D d, Vec128<float16_t> lo) {
#if HWY_HAVE_ZEXT
  (void)d;
  return Vec256<float16_t>{_mm256_zextph128_ph256(lo.raw)};
#else
  const RebindToUnsigned<D> du;
  return BitCast(d, ZeroExtendVector(du, BitCast(du, lo)));
#endif  // HWY_HAVE_ZEXT
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> ZeroExtendVector(D /* tag */, Vec128<float> lo) {
#if HWY_HAVE_ZEXT
  return Vec256<float>{_mm256_zextps128_ps256(lo.raw)};
#else
  return Vec256<float>{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ZeroExtendVector(D /* tag */, Vec128<double> lo) {
#if HWY_HAVE_ZEXT
  return Vec256<double>{_mm256_zextpd128_pd256(lo.raw)};
#else
  return Vec256<double>{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)};
#endif
}

// ------------------------------ ZeroExtendResizeBitCast

namespace detail {

// 8-byte -> 32-byte resize: zero-extend twice (8->16, then 16->32) and
// reinterpret as the destination lane type.
template <class DTo, class DFrom>
HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
    hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<32> /* to_size_tag */,
    DTo d_to, DFrom d_from, VFromD<DFrom> v) {
  const Twice<decltype(d_from)> dt_from;
  const Twice<decltype(dt_from)> dq_from;
  return BitCast(d_to, ZeroExtendVector(dq_from, ZeroExtendVector(dt_from, v)));
}

}  // namespace detail

// ------------------------------ Combine

// Concatenates hi (upper 128 bits) and lo (lower 128 bits) into one 256-bit
// vector. Integer/float16 path works in the unsigned domain for float16_t.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  const Half<decltype(du)> dh_u;
  const auto lo256 = ZeroExtendVector(du, BitCast(dh_u, lo));
  return BitCast(d, VFromD<decltype(du)>{_mm256_inserti128_si256(
                        lo256.raw, BitCast(dh_u, hi).raw, 1)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Combine(D d, Vec128<float> hi, Vec128<float> lo) {
  const auto lo256 = ZeroExtendVector(d, lo);
  return Vec256<float>{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Combine(D d, Vec128<double> hi, Vec128<double> lo) {
  const auto lo256 = ZeroExtendVector(d, lo);
  return Vec256<double>{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)};
}

// ------------------------------ ShiftLeftBytes

// Shifts each 128-bit block left by kBytes, shifting in zeros. Note: does NOT
// cross the 128-bit block boundary (see file-level warning).
template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ShiftLeftBytes(D /* tag */, VFromD<D> v) {
  static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
  // This is the same operation as _mm256_bslli_epi128.
  return VFromD<D>{_mm256_slli_si256(v.raw, kBytes)};
}

// ------------------------------ ShiftRightBytes

// Per-128-bit-block byte shift right, shifting in zeros.
template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ShiftRightBytes(D /* tag */, VFromD<D> v) {
  static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
  // This is the same operation as _mm256_bsrli_epi128.
  return VFromD<D>{_mm256_srli_si256(v.raw, kBytes)};
}

// ------------------------------ CombineShiftRightBytes

// Per 128-bit block: result = lowest kBytes of hi concatenated above
// (lo >> kBytes); implemented via PALIGNR in the byte domain.
template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> CombineShiftRightBytes(D d, VFromD<D> hi, VFromD<D> lo) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, Vec256<uint8_t>{_mm256_alignr_epi8(
                        BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)});
}

// ------------------------------ Broadcast

// Broadcasts lane kLane within each 128-bit block (does not cross blocks).
// 16-bit lanes: 0x55 * kLane replicates the 2-bit in-quad index into all four
// fields of the shufflelo/shufflehi immediate, then unpack copies the chosen
// 64-bit quad across the block.
template <int kLane, typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const VU vu = BitCast(du, v);  // for float16_t
  static_assert(0 <= kLane && kLane < 8, "Invalid lane");
  if (kLane < 4) {
    const __m256i lo = _mm256_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF);
    return BitCast(d, VU{_mm256_unpacklo_epi64(lo, lo)});
  } else {
    const __m256i hi =
        _mm256_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF);
    return BitCast(d, VU{_mm256_unpackhi_epi64(hi, hi)});
  }
}
// 32-bit lanes: 0x55 * kLane replicates kLane into all four index fields.
template <int kLane, typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
  return Vec256<T>{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)};
}

// 64-bit lanes: 0x44 selects 32-bit lanes 1,0,1,0 (lower u64), 0xEE selects
// 3,2,3,2 (upper u64) within each block.
template <int kLane, typename T, HWY_IF_UI64(T)>
HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
  return Vec256<T>{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)};
}

template <int kLane>
HWY_API Vec256<float> Broadcast(Vec256<float> v) {
  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
  return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)};
}

// double: shuffle_pd uses one selector bit per lane; 15 * kLane yields mask 0
// (both lanes from element 0) or 0xF (both from element 1) per block.
template <int kLane>
HWY_API Vec256<double> Broadcast(const Vec256<double> v) {
  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
  return Vec256<double>{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)};
}

// ------------------------------ Concat blocks (LowerHalf, ZeroExtendVector)

// _mm256_broadcastsi128_si256 has 7 cycle latency on ICL.
// _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no
// extra cost) for LowerLower and UpperLower.

// hiH,hiL loH,loL |-> hiL,loL (= lower halves)
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatLowerLower(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  const Half<decltype(d)> d2;
  const RebindToUnsigned<decltype(d2)> du2;  // for float16_t
  return BitCast(
      d, VFromD<decltype(du)>{_mm256_inserti128_si256(
             BitCast(du, lo).raw, BitCast(du2, LowerHalf(d2, hi)).raw, 1)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> ConcatLowerLower(D d, Vec256<float> hi,
                                       Vec256<float> lo) {
  const Half<decltype(d)> d2;
  return Vec256<float>{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ConcatLowerLower(D d, Vec256<double> hi,
                                        Vec256<double> lo) {
  const Half<decltype(d)> d2;
  return Vec256<double>{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)};
}

// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks)
// permute2x128 selector 0x21: result blocks are {src1 block 0, src2 block 1}.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatLowerUpper(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_permute2x128_si256(
                        BitCast(du, lo).raw, BitCast(du, hi).raw, 0x21)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> ConcatLowerUpper(D /* tag */, Vec256<float> hi,
                                       Vec256<float> lo) {
  return Vec256<float>{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ConcatLowerUpper(D /* tag */, Vec256<double> hi,
                                        Vec256<double> lo) {
  return Vec256<double>{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)};
}

// hiH,hiL loH,loL |-> hiH,loL (= outer halves)
// Blend mask 0x0F takes the four low 32-bit lanes from lo, the rest from hi.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatUpperLower(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{_mm256_blend_epi32(
                        BitCast(du, hi).raw, BitCast(du, lo).raw, 0x0F)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> ConcatUpperLower(D /* tag */, Vec256<float> hi,
                                       Vec256<float> lo) {
  return Vec256<float>{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ConcatUpperLower(D /* tag */, Vec256<double> hi,
                                        Vec256<double> lo) {
  return Vec256<double>{_mm256_blend_pd(hi.raw, lo.raw, 3)};
}

// hiH,hiL loH,loL |-> hiH,loH (= upper halves)
// permute2x128 selector 0x31: result blocks are {src1 block 1, src2 block 1}.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatUpperUpper(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{_mm256_permute2x128_si256(
                        BitCast(du, lo).raw, BitCast(du, hi).raw, 0x31)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> ConcatUpperUpper(D /* tag */, Vec256<float> hi,
                                       Vec256<float> lo) {
  return Vec256<float>{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ConcatUpperUpper(D /* tag */, Vec256<double> hi,
                                        Vec256<double> lo) {
  return Vec256<double>{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)};
}

// ------------------------------ BroadcastBlock

// Copies 128-bit block kBlockIdx of v into both blocks of the result.
template <int kBlockIdx, class T>
HWY_API Vec256<T> BroadcastBlock(Vec256<T> v) {
  static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
  const DFromV<decltype(v)> d;
  return (kBlockIdx == 0) ? ConcatLowerLower(d, v, v)
                          : ConcatUpperUpper(d, v, v);
}

// ------------------------------ BroadcastLane

namespace detail {

// SizeTag<0> overloads: lane 0 has dedicated single-instruction broadcasts.

template <class T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec256<T> v) {
  const Half<DFromV<decltype(v)>> dh;
  return Vec256<T>{_mm256_broadcastb_epi8(LowerHalf(dh, v).raw)};
}

template <class T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  const Half<decltype(d)> dh;
  const RebindToUnsigned<decltype(dh)> dh_u;
  return BitCast(d, VFromD<decltype(du)>{_mm256_broadcastw_epi16(
                        BitCast(dh_u, LowerHalf(dh, v)).raw)});
}

template <class T, HWY_IF_UI32(T)>
HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec256<T> v) {
  const Half<DFromV<decltype(v)>> dh;
  return Vec256<T>{_mm256_broadcastd_epi32(LowerHalf(dh, v).raw)};
}

template <class T, HWY_IF_UI64(T)>
HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec256<T> v) {
  const Half<DFromV<decltype(v)>> dh;
  return Vec256<T>{_mm256_broadcastq_epi64(LowerHalf(dh, v).raw)};
}

HWY_INLINE Vec256<float> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                       Vec256<float> v) {
  const Half<DFromV<decltype(v)>> dh;
  return Vec256<float>{_mm256_broadcastss_ps(LowerHalf(dh, v).raw)};
}

HWY_INLINE Vec256<double> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                        Vec256<double> v) {
  const Half<DFromV<decltype(v)>> dh;
  return Vec256<double>{_mm256_broadcastsd_pd(LowerHalf(dh, v).raw)};
}

// General lane (non-zero index), lane size < 8: first replicate the 128-bit
// block containing the lane, then broadcast within the block.
template <size_t kLaneIdx, class T, hwy::EnableIf<kLaneIdx != 0>* = nullptr,
          HWY_IF_NOT_T_SIZE(T, 8)>
HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<kLaneIdx> /* lane_idx_tag */,
                                   Vec256<T> v) {
  constexpr size_t kLanesPerBlock = 16 / sizeof(T);
  constexpr int kBlockIdx = static_cast<int>(kLaneIdx / kLanesPerBlock);
  constexpr int kLaneInBlkIdx =
      static_cast<int>(kLaneIdx) & (kLanesPerBlock - 1);
  return Broadcast<kLaneInBlkIdx>(BroadcastBlock<kBlockIdx>(v));
}

// 64-bit lanes have a cross-block permute: 0x55 * kLaneIdx replicates the
// 2-bit lane index into all four selector fields.
template <size_t kLaneIdx, class T, hwy::EnableIf<kLaneIdx != 0>* = nullptr,
          HWY_IF_UI64(T)>
HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<kLaneIdx> /* lane_idx_tag */,
                                   Vec256<T> v) {
  static_assert(kLaneIdx <= 3, "Invalid lane");
  return Vec256<T>{
      _mm256_permute4x64_epi64(v.raw, static_cast<int>(0x55 * kLaneIdx))};
}

template <size_t kLaneIdx, hwy::EnableIf<kLaneIdx != 0>* = nullptr>
HWY_INLINE Vec256<double> BroadcastLane(
    hwy::SizeTag<kLaneIdx> /* lane_idx_tag */, Vec256<double> v) {
  static_assert(kLaneIdx <= 3, "Invalid lane");
  return Vec256<double>{
      _mm256_permute4x64_pd(v.raw, static_cast<int>(0x55 * kLaneIdx))};
}

}  // namespace detail

// Broadcasts lane kLaneIdx (counted across the whole vector, unlike
// Broadcast) to all lanes; dispatches on the index via SizeTag.
template <int kLaneIdx, class T>
HWY_API Vec256<T> BroadcastLane(Vec256<T> v) {
  static_assert(kLaneIdx >= 0, "Invalid lane");
  return detail::BroadcastLane(hwy::SizeTag<static_cast<size_t>(kLaneIdx)>(),
                               v);
}

// ------------------------------ Hard-coded shuffles

// Notation: let Vec256<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is
// least-significant). Shuffle0321 rotates four-lane blocks one lane to the
// right (the previous least-significant lane is now most-significant =>
// 47650321). These could also be implemented via CombineShiftRightBytes but
// the shuffle_abcd notation is more convenient.

// Swap 32-bit halves in 64-bit halves.
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> Shuffle2301(const Vec256<T> v) {
  return Vec256<T>{_mm256_shuffle_epi32(v.raw, 0xB1)};
}
HWY_API Vec256<float> Shuffle2301(const Vec256<float> v) {
  return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)};
}

// Used by generic_ops-inl.h
namespace detail {

// Two-input variants: lower two output lanes come from a, upper two from b
// (per 128-bit block), in the order given by the _MM_SHUFFLE immediate.
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> ShuffleTwo2301(const Vec256<T> a, const Vec256<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  constexpr int m = _MM_SHUFFLE(2, 3, 0, 1);
  return BitCast(d, Vec256<float>{_mm256_shuffle_ps(BitCast(df, a).raw,
                                                    BitCast(df, b).raw, m)});
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> ShuffleTwo1230(const Vec256<T> a, const Vec256<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  constexpr int m = _MM_SHUFFLE(1, 2, 3, 0);
  return BitCast(d, Vec256<float>{_mm256_shuffle_ps(BitCast(df, a).raw,
                                                    BitCast(df, b).raw, m)});
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> ShuffleTwo3012(const Vec256<T> a, const Vec256<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  constexpr int m = _MM_SHUFFLE(3, 0, 1, 2);
  return BitCast(d, Vec256<float>{_mm256_shuffle_ps(BitCast(df, a).raw,
                                                    BitCast(df, b).raw, m)});
}

}  // namespace detail

// Swap 64-bit halves
HWY_API Vec256<uint32_t> Shuffle1032(const Vec256<uint32_t> v) {
  return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x4E)};
}
HWY_API Vec256<int32_t> Shuffle1032(const Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x4E)};
}
HWY_API Vec256<float> Shuffle1032(const Vec256<float> v) {
  // Shorter encoding than _mm256_permute_ps.
  return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)};
}
// Swap the two 64-bit lanes within each 128-bit block.
HWY_API Vec256<uint64_t> Shuffle01(const Vec256<uint64_t> v) {
  return Vec256<uint64_t>{_mm256_shuffle_epi32(v.raw, 0x4E)};
}
HWY_API Vec256<int64_t> Shuffle01(const Vec256<int64_t> v) {
  return Vec256<int64_t>{_mm256_shuffle_epi32(v.raw, 0x4E)};
}
HWY_API Vec256<double> Shuffle01(const Vec256<double> v) {
  // Shorter encoding than _mm256_permute_pd.
  return Vec256<double>{_mm256_shuffle_pd(v.raw, v.raw, 5)};
}

// Rotate right 32 bits
HWY_API Vec256<uint32_t> Shuffle0321(const Vec256<uint32_t> v) {
  return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x39)};
}
HWY_API Vec256<int32_t> Shuffle0321(const Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x39)};
}
HWY_API Vec256<float> Shuffle0321(const Vec256<float> v) {
  return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x39)};
}
// Rotate left 32 bits
HWY_API Vec256<uint32_t> Shuffle2103(const Vec256<uint32_t> v) {
  return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x93)};
}
HWY_API Vec256<int32_t> Shuffle2103(const Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x93)};
}
HWY_API Vec256<float> Shuffle2103(const Vec256<float> v) {
  return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x93)};
}

// Reverse (within each 128-bit block)
HWY_API Vec256<uint32_t> Shuffle0123(const Vec256<uint32_t> v) {
  return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x1B)};
}
HWY_API Vec256<int32_t> Shuffle0123(const Vec256<int32_t> v) {
  return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x1B)};
}
HWY_API Vec256<float> Shuffle0123(const Vec256<float> v) {
  return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)};
}

// ------------------------------ TableLookupLanes

// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes.
// Opaque lane-index wrapper; the stored representation depends on target and
// lane size (see the IndicesFromVec overloads below).
template <typename T>
struct Indices256 {
  __m256i raw;
};

// 8-bit lanes: indices remain unchanged
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1), typename TI>
HWY_API Indices256<TFromD<D>> IndicesFromVec(D /* tag */, Vec256<TI> vec) {
  static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
#if HWY_IS_DEBUG_BUILD
  // Valid range is [0, 2 * Lanes): TwoTablesLookupLanes allows indices into
  // the second table as well.
  const Full256<TI> di;
  HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) &&
              AllTrue(di, Lt(vec, Set(di, static_cast<TI>(2 * Lanes(di))))));
#endif
  return Indices256<TFromD<D>>{vec.raw};
}

// 16-bit lanes: convert indices to 32x8 unless AVX3 is available
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2), typename TI>
HWY_API Indices256<TFromD<D>> IndicesFromVec(D /* tag */, Vec256<TI> vec) {
  static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
  const Full256<TI> di;
#if HWY_IS_DEBUG_BUILD
  HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) &&
              AllTrue(di, Lt(vec, Set(di, static_cast<TI>(2 * Lanes(di))))));
#endif

#if HWY_TARGET <= HWY_AVX3
  (void)di;
  return Indices256<TFromD<D>>{vec.raw};
#else
  // AVX2 has no 16-bit lane permute; precompute byte-shuffle indices:
  // replicate each lane index into both of its bytes, double it (lane index ->
  // byte index), then add 0/1 offsets so the pair addresses both bytes.
  const Repartition<uint8_t, decltype(di)> d8;
  using V8 = VFromD<decltype(d8)>;
  alignas(32) static constexpr uint8_t kByteOffsets[32] = {
      0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
      0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};

  // Broadcast each lane index to all 2 bytes of T
  alignas(32) static constexpr uint8_t kBroadcastLaneBytes[32] = {
      0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14,
      0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14};
  const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes));

  // Shift to bytes
  const Repartition<uint16_t, decltype(di)> d16;
  const V8 byte_indices = BitCast(d8, ShiftLeft<1>(BitCast(d16, lane_indices)));

  return Indices256<TFromD<D>>{Add(byte_indices, Load(d8, kByteOffsets)).raw};
#endif  // HWY_TARGET <= HWY_AVX3
}

// Native 8x32 instruction: indices remain unchanged
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4), typename TI>
HWY_API Indices256<TFromD<D>> IndicesFromVec(D /* tag */, Vec256<TI> vec) {
  static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
#if HWY_IS_DEBUG_BUILD
  const Full256<TI> di;
  HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) &&
              AllTrue(di, Lt(vec, Set(di, static_cast<TI>(2 * Lanes(di))))));
#endif
  return Indices256<TFromD<D>>{vec.raw};
}

// 64-bit lanes: convert indices to 8x32 unless AVX3 is available
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8), typename TI>
HWY_API Indices256<TFromD<D>> IndicesFromVec(D d, Vec256<TI> idx64) {
  static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
  const Rebind<TI, decltype(d)> di;
  (void)di;  // potentially unused
#if HWY_IS_DEBUG_BUILD
  HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) &&
              AllTrue(di, Lt(idx64, Set(di, static_cast<TI>(2 * Lanes(di))))));
#endif

#if HWY_TARGET <= HWY_AVX3
  (void)d;
  return Indices256<TFromD<D>>{idx64.raw};
#else
  const Repartition<float, decltype(d)> df;  // 32-bit!
  // Replicate 64-bit index into upper 32 bits
  const Vec256<TI> dup =
      BitCast(di, Vec256<float>{_mm256_moveldup_ps(BitCast(df, idx64).raw)});
  // For each idx64 i, idx32 are 2*i and 2*i+1.
  // dup+dup doubles both halves; adding 1<<32 turns the upper half into 2i+1.
  const Vec256<TI> idx32 = dup + dup + Set(di, TI(1) << 32);
  return Indices256<TFromD<D>>{idx32.raw};
#endif
}

// Loads lane indices from memory and converts them via IndicesFromVec.
template <class D, HWY_IF_V_SIZE_D(D, 32), typename TI>
HWY_API Indices256<TFromD<D>> SetTableIndices(D d, const TI* idx) {
  const Rebind<TI, decltype(d)> di;
  return IndicesFromVec(d, LoadU(di, idx));
}

// Full-vector (cross-block) lane gather for 8-bit lanes.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<T>{_mm256_permutexvar_epi8(idx.raw, v.raw)};
#else
  // AVX2 PSHUFB only reads within each 128-bit block. Build both "all lanes
  // from block 0" and "all lanes from block 1" tables and select per lane.
  // ShiftLeft<3> moves index bit 4 (block selector) into each byte's MSB,
  // which MaskFromVec reads.
  const Vec256<T> idx_vec{idx.raw};
  const DFromV<decltype(v)> d;
  const Repartition<uint16_t, decltype(d)> du16;
  const auto sel_hi_mask =
      MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec))));

  const auto a = ConcatLowerLower(d, v, v);
  const auto b = ConcatUpperUpper(d, v, v);
  const auto lo_lookup_result = TableLookupBytes(a, idx_vec);

#if HWY_TARGET <= HWY_AVX3
  // Merge the second lookup into the first under the mask in one instruction.
  return Vec256<T>{_mm256_mask_shuffle_epi8(
      lo_lookup_result.raw, sel_hi_mask.raw, b.raw, idx_vec.raw)};
#else
  const auto hi_lookup_result = TableLookupBytes(b, idx_vec);
  return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result);
#endif  // HWY_TARGET <= HWY_AVX3
#endif  // HWY_TARGET <= HWY_AVX3_DL
}

// 16-bit lanes: native on AVX3; otherwise the indices were already converted
// to byte indices by IndicesFromVec, so reuse the 8-bit path.
template <typename T, HWY_IF_T_SIZE(T, 2), HWY_IF_NOT_SPECIAL_FLOAT(T)>
HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<T>{_mm256_permutexvar_epi16(idx.raw, v.raw)};
#else
  const DFromV<decltype(v)> d;
  const Repartition<uint8_t, decltype(d)> du8;
  return BitCast(
      d, TableLookupLanes(BitCast(du8, v), Indices256<uint8_t>{idx.raw}));
#endif
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> TableLookupLanes(Vec256<float16_t> v,
                                           Indices256<float16_t> idx) {
  return Vec256<float16_t>{_mm256_permutexvar_ph(idx.raw, v.raw)};
}
#endif  // HWY_HAVE_FLOAT16

template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
  return Vec256<T>{_mm256_permutevar8x32_epi32(v.raw, idx.raw)};
}

// 64-bit lanes: on AVX2 the indices were converted to 2i,2i+1 32-bit pairs by
// IndicesFromVec, so the 8x32 permute moves whole 64-bit lanes.
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<T>{_mm256_permutexvar_epi64(idx.raw, v.raw)};
#else
  return Vec256<T>{_mm256_permutevar8x32_epi32(v.raw, idx.raw)};
#endif
}

HWY_API Vec256<float> TableLookupLanes(const Vec256<float> v,
                                       const Indices256<float> idx) {
  return Vec256<float>{_mm256_permutevar8x32_ps(v.raw, idx.raw)};
}

HWY_API Vec256<double> TableLookupLanes(const Vec256<double> v,
                                        const Indices256<double> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<double>{_mm256_permutexvar_pd(idx.raw, v.raw)};
#else
  const Full256<double> df;
  const Full256<uint64_t> du;
  return BitCast(df, Vec256<uint64_t>{_mm256_permutevar8x32_epi32(
                         BitCast(du, v).raw, idx.raw)});
#endif
}

// Lane gather from the concatenation of two tables; index bit log2(Lanes)
// selects which table.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> TwoTablesLookupLanes(Vec256<T> a, Vec256<T> b,
                                       Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<T>{_mm256_permutex2var_epi8(a.raw, idx.raw, b.raw)};
#else
  // ShiftLeft<2> moves index bit 5 (table selector for 32 lanes) into each
  // byte's MSB for MaskFromVec.
  const DFromV<decltype(a)> d;
  const auto sel_hi_mask =
      MaskFromVec(BitCast(d, ShiftLeft<2>(Vec256<uint16_t>{idx.raw})));
  const auto lo_lookup_result = TableLookupLanes(a, idx);
  const auto hi_lookup_result = TableLookupLanes(b, idx);
  return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result);
#endif
}

template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec256<T> TwoTablesLookupLanes(Vec256<T> a, Vec256<T> b,
                                       Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<T>{_mm256_permutex2var_epi16(a.raw, idx.raw, b.raw)};
#else
  // Indices were already converted to byte indices; reuse the 8-bit path.
  const DFromV<decltype(a)> d;
  const Repartition<uint8_t, decltype(d)> du8;
  return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b),
                                         Indices256<uint8_t>{idx.raw}));
#endif
}

template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> TwoTablesLookupLanes(Vec256<T> a, Vec256<T> b,
                                       Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<T>{_mm256_permutex2var_epi32(a.raw, idx.raw, b.raw)};
#else
  // ShiftLeft<28> moves index bit 3 (table selector for 8 lanes) into the
  // sign bit, which the float-domain mask/blend reads.
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  const Vec256<T> idx_vec{idx.raw};

  const auto sel_hi_mask = MaskFromVec(BitCast(df, ShiftLeft<28>(idx_vec)));
  const auto lo_lookup_result = BitCast(df, TableLookupLanes(a, idx));
  const auto hi_lookup_result = BitCast(df, TableLookupLanes(b, idx));
  return BitCast(d,
                 IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result));
#endif
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> TwoTablesLookupLanes(Vec256<float16_t> a,
                                               Vec256<float16_t> b,
                                               Indices256<float16_t> idx) {
  return Vec256<float16_t>{_mm256_permutex2var_ph(a.raw, idx.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> TwoTablesLookupLanes(Vec256<float> a, Vec256<float> b,
                                           Indices256<float> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<float>{_mm256_permutex2var_ps(a.raw, idx.raw, b.raw)};
#else
  const DFromV<decltype(a)> d;
  const auto sel_hi_mask =
      MaskFromVec(BitCast(d, ShiftLeft<28>(Vec256<uint32_t>{idx.raw})));
  const auto lo_lookup_result = TableLookupLanes(a, idx);
  const auto hi_lookup_result = TableLookupLanes(b, idx);
  return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result);
#endif
}

template <typename T, HWY_IF_UI64(T)>
HWY_API Vec256<T> TwoTablesLookupLanes(Vec256<T> a, Vec256<T> b,
                                       Indices256<T> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<T>{_mm256_permutex2var_epi64(a.raw, idx.raw, b.raw)};
#else
  // Indices were already converted to 32-bit pairs; reuse the 32-bit path.
  const DFromV<decltype(a)> d;
  const Repartition<uint32_t, decltype(d)> du32;
  return BitCast(d, TwoTablesLookupLanes(BitCast(du32, a), BitCast(du32, b),
                                         Indices256<uint32_t>{idx.raw}));
#endif
}

HWY_API Vec256<double> TwoTablesLookupLanes(Vec256<double> a, Vec256<double> b,
                                            Indices256<double> idx) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<double>{_mm256_permutex2var_pd(a.raw, idx.raw, b.raw)};
#else
  const DFromV<decltype(a)> d;
  const Repartition<uint32_t, decltype(d)> du32;
  return BitCast(d, TwoTablesLookupLanes(BitCast(du32, a), BitCast(du32, b),
                                         Indices256<uint32_t>{idx.raw}));
#endif
}

// ------------------------------ SwapAdjacentBlocks

// Swaps the two 128-bit blocks of v.
template <typename T>
HWY_API Vec256<T> SwapAdjacentBlocks(Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{_mm256_permute4x64_epi64(
                        BitCast(du, v).raw, _MM_SHUFFLE(1, 0, 3, 2))});
}

HWY_API Vec256<double> SwapAdjacentBlocks(Vec256<double> v) {
  return Vec256<double>{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(1, 0, 3, 2))};
}

HWY_API Vec256<float> SwapAdjacentBlocks(Vec256<float> v) {
  // Assume no domain-crossing penalty between float/double (true on SKX).
  const DFromV<decltype(v)> d;
  const RepartitionToWide<decltype(d)> dw;
  return BitCast(d, SwapAdjacentBlocks(BitCast(dw, v)));
}

// ------------------------------ InterleaveEvenBlocks (ConcatLowerLower)
template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
HWY_API V InterleaveEvenBlocks(D d, V a, V b) {
  return ConcatLowerLower(d, b, a);
}

// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper)
template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
HWY_API V InterleaveOddBlocks(D d, V a, V b) {
  return ConcatUpperUpper(d, b, a);
}

// ------------------------------ Reverse (RotateRight)

// Reverses all lanes of the full 256-bit vector (crosses blocks).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
  alignas(32) static constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0};
  return TableLookupLanes(v, SetTableIndices(d, kReverse));
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
  alignas(32) static constexpr int64_t kReverse[4] = {3, 2, 1, 0};
  return TableLookupLanes(v, SetTableIndices(d, kReverse));
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
#if HWY_TARGET <= HWY_AVX3
  const RebindToSigned<decltype(d)> di;
  alignas(32) static constexpr int16_t kReverse[16] = {
      15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  const Vec256<int16_t> idx = Load(di, kReverse);
  return BitCast(d, Vec256<int16_t>{
                        _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)});
#else
  // AVX2: reverse within each block via byte shuffle, then swap the blocks.
  const RebindToSigned<decltype(d)> di;
  const VFromD<decltype(di)> shuffle = Dup128VecFromValues(
      di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100);
  const auto rev128 = TableLookupBytes(v, shuffle);
  return VFromD<D>{
      _mm256_permute4x64_epi64(rev128.raw, _MM_SHUFFLE(1, 0, 3, 2))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
#if HWY_TARGET <= HWY_AVX3_DL
  alignas(32) static constexpr TFromD<D> kReverse[32] = {
      31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
      15, 14, 13, 12, 11, 10, 9,  8,  7,  6,  5,  4,  3,  2,  1,  0};
  return TableLookupLanes(v, SetTableIndices(d, kReverse));
#else
  // First reverse bytes within blocks via PSHUFB, then swap blocks.
  alignas(32) static constexpr TFromD<D> kReverse[32] = {
      15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
      15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  return SwapAdjacentBlocks(TableLookupBytes(v, Load(d, kReverse)));
#endif
}

// ------------------------------ Reverse2 (in x86_128)

// ------------------------------ Reverse4 (SwapAdjacentBlocks)

// Reverses lanes within each group of four.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
  const RebindToSigned<decltype(d)> di;
  const VFromD<decltype(di)> shuffle = Dup128VecFromValues(
      di, 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908);
  return BitCast(d, TableLookupBytes(v, shuffle));
}

// 32 bit Reverse4 defined in x86_128.

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Reverse4(D /* tag */, const VFromD<D> v) {
  // Could also use _mm256_permute4x64_epi64.
  return SwapAdjacentBlocks(Shuffle01(v));
}

// ------------------------------ Reverse8

// Reverses lanes within each group of eight.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
  const RebindToSigned<decltype(d)> di;
  const VFromD<decltype(di)> shuffle = Dup128VecFromValues(
      di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100);
  return BitCast(d, TableLookupBytes(v, shuffle));
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
  // Eight 32-bit lanes fill the whole vector, so this is a full Reverse.
  return Reverse(d, v);
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Reverse8(D /* tag */, const VFromD<D> /* v */) {
  HWY_ASSERT(0);  // AVX2 does not have 8 64-bit lanes
}

// ------------------------------ ReverseBits in x86_512

// ------------------------------ InterleaveLower

// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides
// the least-significant lane) and "b". To concatenate two half-width integers
// into one, use ZipLower/Upper instead (also works with scalar).
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_unpacklo_epi8(a.raw, b.raw)};
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
  // Route through the unsigned type so the same code handles float16_t,
  // for which there is no dedicated unpack intrinsic here.
  const DFromV<decltype(a)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;  // for float16_t
  return BitCast(
      d, VU{_mm256_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)});
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_unpacklo_epi32(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_unpacklo_epi64(a.raw, b.raw)};
}

HWY_API Vec256<float> InterleaveLower(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_unpacklo_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> InterleaveLower(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_unpacklo_pd(a.raw, b.raw)};
}

// ------------------------------ InterleaveUpper

// As InterleaveLower, but for the upper halves of each 128-bit block.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_unpackhi_epi8(a.raw, b.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;  // for float16_t
  return BitCast(
      d, VU{_mm256_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_unpackhi_epi32(a.raw, b.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_unpackhi_epi64(a.raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_unpackhi_ps(a.raw, b.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_unpackhi_pd(a.raw, b.raw)};
}

// ---------------------------- InsertBlock (ConcatLowerLower, ConcatUpperLower)

// Replaces 128-bit block kBlockIdx (0 = lower, 1 = upper) of v with
// blk_to_insert; the other block of v is kept.
template <int kBlockIdx, class T>
HWY_API Vec256<T> InsertBlock(Vec256<T> v, Vec128<T> blk_to_insert) {
  static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");

  const DFromV<decltype(v)> d;
  const auto vec_to_insert = ResizeBitCast(d, blk_to_insert);
  return (kBlockIdx == 0) ? ConcatUpperLower(d, v, vec_to_insert)
                          : ConcatLowerLower(d, vec_to_insert, v);
}

// ------------------------------ ConcatOdd

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3_DL
  // Two-table permute: indices 0..31 select from lo, 32..63 from hi; the odd
  // indices pick the odd-numbered byte lanes of the lo:hi concatenation.
  alignas(32) static constexpr uint8_t kIdx[32] = {
      1,  3,  5,  7,  9,  11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
      33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63};
  return BitCast(
      d, Vec256<uint16_t>{_mm256_permutex2var_epi8(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  const RepartitionToWide<decltype(du)> dw;
  // Unsigned 8-bit shift so we can pack.
  const Vec256<uint16_t> uH = ShiftRight<8>(BitCast(dw, hi));
  const Vec256<uint16_t> uL = ShiftRight<8>(BitCast(dw, lo));
  const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw);
  // packus interleaves per 128-bit block; permute4x64 restores whole-vector
  // order (lo results in the lower half, hi results in the upper half).
  return VFromD<D>{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  // Odd 16-bit lanes of the lo:hi concatenation (0..15 = lo, 16..31 = hi).
  alignas(32) static constexpr uint16_t kIdx[16] = {
      1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
  return BitCast(
      d, Vec256<uint16_t>{_mm256_permutex2var_epi16(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  const RepartitionToWide<decltype(du)> dw;
  // Unsigned 16-bit shift so we can pack.
  const Vec256<uint32_t> uH = ShiftRight<16>(BitCast(dw, hi));
  const Vec256<uint32_t> uL = ShiftRight<16>(BitCast(dw, lo));
  const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw);
  // Fix up the per-block interleaving produced by packus.
  return BitCast(d, VFromD<decltype(du)>{_mm256_permute4x64_epi64(
                        u16, _MM_SHUFFLE(3, 1, 2, 0))});
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  alignas(32) static constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15};
  return BitCast(
      d, Vec256<uint32_t>{_mm256_permutex2var_epi32(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  // shuffle_ps gathers odd lanes per 128-bit block; then reorder blocks.
  const RebindToFloat<decltype(d)> df;
  const Vec256<float> v3131{_mm256_shuffle_ps(
      BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))};
  return VFromD<D>{_mm256_permute4x64_epi64(BitCast(du, v3131).raw,
                                            _MM_SHUFFLE(3, 1, 2, 0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  alignas(32) static constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15};
  return VFromD<D>{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)};
#else
  const VFromD<D> v3131{
      _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))};
  return BitCast(d, Vec256<uint32_t>{_mm256_permute4x64_epi64(
                        BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))});
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  alignas(64) static constexpr uint64_t kIdx[4] = {1, 3, 5, 7};
  return BitCast(
      d, Vec256<uint64_t>{_mm256_permutex2var_epi64(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  // shuffle_pd with imm 15 (all bits set) selects the upper (odd) 64-bit lane
  // from each source per block; then reorder blocks.
  const RebindToFloat<decltype(d)> df;
  const Vec256<double> v31{
      _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)};
  return VFromD<D>{
      _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ConcatOdd(D d, Vec256<double> hi, Vec256<double> lo) {
#if HWY_TARGET <= HWY_AVX3
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[4] = {1, 3, 5, 7};
  return Vec256<double>{
      _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)};
#else
  (void)d;
  const Vec256<double> v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)};
  return Vec256<double>{
      _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))};
#endif
}

// ------------------------------ ConcatEven

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D,
1)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3_DL
  // Two-table permute selecting the even byte lanes of the lo:hi
  // concatenation (indices 0..31 = lo, 32..63 = hi).
  alignas(64) static constexpr uint8_t kIdx[32] = {
      0,  2,  4,  6,  8,  10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
      32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62};
  return BitCast(
      d, Vec256<uint32_t>{_mm256_permutex2var_epi8(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  const RepartitionToWide<decltype(du)> dw;
  // Isolate lower 8 bits per u16 so we can pack.
  const Vec256<uint16_t> mask = Set(dw, 0x00FF);
  const Vec256<uint16_t> uH = And(BitCast(dw, hi), mask);
  const Vec256<uint16_t> uL = And(BitCast(dw, lo), mask);
  const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw);
  // packus interleaves per 128-bit block; permute4x64 restores whole-vector
  // order.
  return VFromD<D>{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  // Even 16-bit lanes of the lo:hi concatenation.
  alignas(64) static constexpr uint16_t kIdx[16] = {
      0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30};
  return BitCast(
      d, Vec256<uint32_t>{_mm256_permutex2var_epi16(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  const RepartitionToWide<decltype(du)> dw;
  // Isolate lower 16 bits per u32 so we can pack.
  const Vec256<uint32_t> mask = Set(dw, 0x0000FFFF);
  const Vec256<uint32_t> uH = And(BitCast(dw, hi), mask);
  const Vec256<uint32_t> uL = And(BitCast(dw, lo), mask);
  const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw);
  return BitCast(d, VFromD<decltype(du)>{_mm256_permute4x64_epi64(
                        u16, _MM_SHUFFLE(3, 1, 2, 0))});
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14};
  return BitCast(
      d, Vec256<uint32_t>{_mm256_permutex2var_epi32(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  // shuffle_ps gathers even lanes per 128-bit block; then reorder blocks.
  const RebindToFloat<decltype(d)> df;
  const Vec256<float> v2020{_mm256_shuffle_ps(
      BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))};
  return VFromD<D>{_mm256_permute4x64_epi64(BitCast(du, v2020).raw,
                                            _MM_SHUFFLE(3, 1, 2, 0))};

#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14};
  return VFromD<D>{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)};
#else
  const VFromD<D> v2020{
      _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))};
  return BitCast(d, Vec256<uint32_t>{_mm256_permute4x64_epi64(
                        BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))});

#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3
  alignas(64) static constexpr
uint64_t kIdx[4] = {0, 2, 4, 6};
  return BitCast(
      d, Vec256<uint64_t>{_mm256_permutex2var_epi64(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  // shuffle_pd with imm 0 selects the lower (even) 64-bit lane from each
  // source per block; then reorder blocks.
  const RebindToFloat<decltype(d)> df;
  const Vec256<double> v20{
      _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)};
  return VFromD<D>{
      _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))};

#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> ConcatEven(D d, Vec256<double> hi, Vec256<double> lo) {
#if HWY_TARGET <= HWY_AVX3
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6};
  return Vec256<double>{
      _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)};
#else
  (void)d;
  const Vec256<double> v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)};
  return Vec256<double>{
      _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))};
#endif
}

// ------------------------------ InterleaveWholeLower

// Interleaves the lower halves of the WHOLE vectors a and b (crosses 128-bit
// blocks, unlike InterleaveLower). AVX3 uses two-table permutes with
// precomputed interleaved indices; AVX2 recombines the per-block interleaves.
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
#if HWY_TARGET <= HWY_AVX3_DL
  const RebindToUnsigned<decltype(d)> du;
  // Indices 0..31 = a, 32..63 = b, alternating a-lane/b-lane.
  alignas(32) static constexpr uint8_t kIdx[32] = {
      0, 32, 1, 33, 2,  34, 3,  35, 4,  36, 5,  37, 6,  38, 7,  39,
      8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47};
  return VFromD<D>{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)};
#else
  // Lower block comes from InterleaveLower, upper from InterleaveUpper.
  return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b));
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr
uint16_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19,
                     4, 20, 5, 21, 6, 22, 7, 23};
  return BitCast(
      d, VFromD<decltype(du)>{_mm256_permutex2var_epi16(
             BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11};
  return VFromD<D>{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11};
  return VFromD<D>{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5};
  return VFromD<D>{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5};
  return VFromD<D>{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)};
}
#else   // AVX2
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b));
}
#endif

// ------------------------------ InterleaveWholeUpper

// As InterleaveWholeLower, but for the upper halves of the whole vectors.
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
#if HWY_TARGET <= HWY_AVX3_DL
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint8_t kIdx[32] = {
      16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55,
      24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63};
  return VFromD<D>{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)};
#else
  return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b));
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint16_t kIdx[16] = {
      8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
  return BitCast(
      d, VFromD<decltype(du)>{_mm256_permutex2var_epi16(
             BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15};
  return VFromD<D>{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15};
  return VFromD<D>{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7};
  return VFromD<D>{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7};
  return VFromD<D>{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)};
}
#else   // AVX2
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b));
}
#endif

// ------------------------------ DupEven (InterleaveLower)

// Duplicates each even-indexed lane into the following odd lane.
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> DupEven(Vec256<T> v) {
  return Vec256<T>{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))};
}
HWY_API Vec256<float> DupEven(Vec256<float> v) {
  return Vec256<float>{
      _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))};
}

template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> DupEven(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  return InterleaveLower(d, v, v);
}

// ------------------------------ DupOdd (InterleaveUpper)

// Duplicates each odd-indexed lane into the preceding even lane.
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> DupOdd(Vec256<T> v) {
  return Vec256<T>{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))};
}
HWY_API Vec256<float> DupOdd(Vec256<float> v) {
  return Vec256<float>{
      _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))};
}

template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> DupOdd(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  return InterleaveUpper(d, v, v);
}

// ------------------------------ OddEven
// Returns a vector with b's lanes in the even positions and a's lanes in the
// odd positions.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;
  const Full256<uint8_t> d8;
  // 0xFF in even byte positions selects b there; no byte-granularity blend
  // immediate exists, hence the mask + IfThenElse.
  const VFromD<decltype(d8)> mask =
      Dup128VecFromValues(d8, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF,
                          0, 0xFF, 0, 0xFF, 0);
  return IfThenElse(MaskFromVec(BitCast(d, mask)), b, a);
}

template <typename T, HWY_IF_UI16(T)>
HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  // Blend imm 0x55 = even lanes (per 128-bit block) from b.
  return BitCast(d, VFromD<decltype(du)>{_mm256_blend_epi16(
                        BitCast(du, a).raw, BitCast(du, b).raw, 0x55)});
}

#if HWY_HAVE_FLOAT16
HWY_INLINE Vec256<float16_t> OddEven(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{
      _mm256_mask_blend_ph(static_cast<__mmask16>(0x5555), a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16

template <typename T, HWY_IF_UI32(T)>
HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
  return Vec256<T>{_mm256_blend_epi32(a.raw, b.raw, 0x55)};
}

template <typename T, HWY_IF_UI64(T)>
HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
  // 0x33: two 32-bit halves per even 64-bit lane come from b.
  return Vec256<T>{_mm256_blend_epi32(a.raw, b.raw, 0x33)};
}

HWY_API Vec256<float> OddEven(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_blend_ps(a.raw, b.raw, 0x55)};
}

HWY_API Vec256<double> OddEven(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_blend_pd(a.raw, b.raw, 5)};
}

// -------------------------- InterleaveEven

#if HWY_TARGET <= HWY_AVX3
// Single masked shuffle: duplicate b's even lanes into odd positions (mask
// 0xAA writes only odd lanes), keeping a's even lanes.
template <class D, HWY_IF_LANES_D(D, 8), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_mask_shuffle_epi32(
      a.raw, static_cast<__mmask8>(0xAA), b.raw,
      static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))};
}
template <class D, HWY_IF_LANES_D(D, 8), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0xAA),
                                          b.raw, b.raw,
                                          _MM_SHUFFLE(2, 2, 0, 0))};
}
#else
// AVX2: gather the even lanes of a and b into one vector, then reorder within
// each block to a0 b0 a2 b2.
template <class D, HWY_IF_LANES_D(D, 8), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> InterleaveEven(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToFloat<decltype(d)> df;
  const VFromD<decltype(df)> b2_b0_a2_a0{_mm256_shuffle_ps(
      BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(2, 0, 2, 0))};
  return BitCast(
      d, VFromD<decltype(df)>{_mm256_shuffle_ps(
             b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, _MM_SHUFFLE(3, 1, 2, 0))});
}
#endif

// I64/U64/F64 InterleaveEven is generic for vector lengths >= 32 bytes
template <class D, HWY_IF_LANES_GT_D(D, 2), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return InterleaveLower(a, b);
}

// -------------------------- InterleaveOdd

#if HWY_TARGET <= HWY_AVX3
// Mirror of InterleaveEven: mask 0x55 writes a's odd lanes into b's even
// positions.
template <class D, HWY_IF_LANES_D(D, 8), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_mask_shuffle_epi32(
      b.raw, static_cast<__mmask8>(0x55), a.raw,
      static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))};
}
template <class D, HWY_IF_LANES_D(D, 8), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm256_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x55),
                                          a.raw, a.raw,
                                          _MM_SHUFFLE(3, 3, 1, 1))};
}
#else
template <class D, HWY_IF_LANES_D(D, 8), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> InterleaveOdd(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToFloat<decltype(d)> df;
  const VFromD<decltype(df)>
b3_b1_a3_a3{_mm256_shuffle_ps(
      BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(3, 1, 3, 1))};
  return BitCast(
      d, VFromD<decltype(df)>{_mm256_shuffle_ps(
             b3_b1_a3_a3.raw, b3_b1_a3_a3.raw, _MM_SHUFFLE(3, 1, 2, 0))});
}
#endif

// I64/U64/F64 InterleaveOdd is generic for vector lengths >= 32 bytes
template <class D, HWY_IF_LANES_GT_D(D, 2), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> InterleaveOdd(D d, VFromD<D> a, VFromD<D> b) {
  return InterleaveUpper(d, a, b);
}

// ------------------------------ OddEvenBlocks

// Returns even's lower 128-bit block and odd's upper block.
template <typename T, HWY_IF_NOT_FLOAT3264(T)>
Vec256<T> OddEvenBlocks(Vec256<T> odd, Vec256<T> even) {
  const DFromV<decltype(odd)> d;
  const RebindToUnsigned<decltype(d)> du;
  // Blend imm 0xF = the four low 32-bit lanes (lower block) from even.
  return BitCast(d, VFromD<decltype(du)>{_mm256_blend_epi32(
                        BitCast(du, odd).raw, BitCast(du, even).raw, 0xFu)});
}

HWY_API Vec256<float> OddEvenBlocks(Vec256<float> odd, Vec256<float> even) {
  return Vec256<float>{_mm256_blend_ps(odd.raw, even.raw, 0xFu)};
}

HWY_API Vec256<double> OddEvenBlocks(Vec256<double> odd, Vec256<double> even) {
  return Vec256<double>{_mm256_blend_pd(odd.raw, even.raw, 0x3u)};
}

// ------------------------------ ReverseBlocks (SwapAdjacentBlocks)

// With only two 128-bit blocks, reversing them is a swap.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ReverseBlocks(D /*d*/, VFromD<D> v) {
  return SwapAdjacentBlocks(v);
}

// ------------------------------ TableLookupBytes (ZeroExtendVector)

// Both full
template <typename T, typename TI>
HWY_API Vec256<TI> TableLookupBytes(Vec256<T> bytes, Vec256<TI> from) {
  const DFromV<decltype(from)> d;
  // Note: pshufb indexes within each 128-bit block independently.
  return BitCast(d, Vec256<uint8_t>{_mm256_shuffle_epi8(
                        BitCast(Full256<uint8_t>(), bytes).raw,
                        BitCast(Full256<uint8_t>(), from).raw)});
}

// Partial index vector
template <typename T, typename TI, size_t NI>
HWY_API Vec128<TI, NI> TableLookupBytes(Vec256<T> bytes, Vec128<TI, NI> from) {
  const Full256<TI> di;
  const Half<decltype(di)> dih;
  // First expand to full 128, then 256.
  const auto from_256 = ZeroExtendVector(di, Vec128<TI>{from.raw});
  const auto tbl_full = TableLookupBytes(bytes, from_256);
  // Shrink to 128, then partial.
  return Vec128<TI, NI>{LowerHalf(dih, tbl_full).raw};
}

// Partial table vector
template <typename T, size_t N, typename TI>
HWY_API Vec256<TI> TableLookupBytes(Vec128<T, N> bytes, Vec256<TI> from) {
  const Full256<T> d;
  // First expand to full 128, then 256.
  const auto bytes_256 = ZeroExtendVector(d, Vec128<T>{bytes.raw});
  return TableLookupBytes(bytes_256, from);
}

// Partial both are handled by x86_128.

// ------------------------------ I8/U8 Broadcast (TableLookupBytes)

// Broadcasts byte lane kLane of each 128-bit block to all lanes of that block.
template <int kLane, class T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
  static_assert(0 <= kLane && kLane < 16, "Invalid lane");
  return TableLookupBytes(v, Set(Full256<T>(), static_cast<T>(kLane)));
}

// ------------------------------ Per4LaneBlockShuffle

namespace detail {

// Builds a vector whose two 128-bit blocks each hold {x0, x1, x2, x3}.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_INLINE VFromD<D> Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3,
                                                const uint32_t x2,
                                                const uint32_t x1,
                                                const uint32_t x0) {
  return BitCast(d, Vec256<uint32_t>{_mm256_set_epi32(
                        static_cast<int32_t>(x3), static_cast<int32_t>(x2),
                        static_cast<int32_t>(x1), static_cast<int32_t>(x0),
                        static_cast<int32_t>(x3), static_cast<int32_t>(x2),
                        static_cast<int32_t>(x1), static_cast<int32_t>(x0))});
}

// 32-bit lanes: kIdx3210 is used directly as the shuffle immediate.
template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<4> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  return V{_mm256_shuffle_epi32(v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

template <size_t kIdx3210, class V, HWY_IF_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<4> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  return V{_mm256_shuffle_ps(v.raw, v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

// 64-bit lanes, pattern 0x44 = {1,0,1,0}: duplicate the lower block.
template <class V>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  return ConcatLowerLower(d, v, v);
}

// 64-bit lanes, pattern 0xEE = {3,2,3,2}: duplicate the upper block.
template <class V>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  return ConcatUpperUpper(d, v, v);
}

// General 64-bit case: full cross-block 4x64 permute.
template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  return V{_mm256_permute4x64_epi64(v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

template <size_t kIdx3210, class V, HWY_IF_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  return V{_mm256_permute4x64_pd(v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

}  // namespace detail

// ------------------------------ SlideUpLanes

namespace detail {

#if HWY_TARGET <= HWY_AVX3
// Shifts the lo:hi concatenation right by kI32Lanes 32-bit lanes (valign).
template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 32)>
HWY_INLINE V CombineShiftRightI32Lanes(V hi, V lo) {
  const DFromV<decltype(hi)> d;
  const Repartition<uint32_t, decltype(d)> du32;
  return BitCast(d,
                 Vec256<uint32_t>{_mm256_alignr_epi32(
                     BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)});
}

template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32)>
HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) {
  const DFromV<decltype(hi)> d;
  const Repartition<uint64_t, decltype(d)> du64;
  return BitCast(d,
                 Vec256<uint64_t>{_mm256_alignr_epi64(
                     BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)});
}

// Slide up by kI64Lanes 64-bit lanes, shifting in zeros from below.
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32)>
HWY_INLINE V SlideUpI64Lanes(V v) {
  static_assert(0 <= kI64Lanes && kI64Lanes <= 3,
                "kI64Lanes must be between 0 and 3");
  const DFromV<decltype(v)> d;
  return CombineShiftRightI64Lanes<4 - kI64Lanes>(v, Zero(d));
}
#else   // AVX2
// AVX2 lacks valign: rotate the four 64-bit lanes with permute4x64, then zero
// the kI64Lanes low lanes via a 32-bit blend (two 32-bit lanes per 64-bit).
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32),
          HWY_IF_NOT_FLOAT_D(DFromV<V>)>
HWY_INLINE V SlideUpI64Lanes(V v) {
  static_assert(0 <= kI64Lanes && kI64Lanes <= 3,
                "kI64Lanes must be between 0 and 3");
  constexpr int kIdx0 = (-kI64Lanes) & 3;
  constexpr int kIdx1 = (-kI64Lanes + 1) & 3;
  constexpr int kIdx2 = (-kI64Lanes + 2) & 3;
  constexpr int kIdx3 = (-kI64Lanes + 3) & 3;
  constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0);
  constexpr int kBlendMask = (1 << (kI64Lanes * 2)) - 1;

  const DFromV<decltype(v)> d;
  return V{_mm256_blend_epi32(_mm256_permute4x64_epi64(v.raw, kIdx3210),
                              Zero(d).raw, kBlendMask)};
}

// Float flavor: same rotation via permute4x64_pd + blend_pd (one mask bit per
// 64-bit lane).
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32),
          HWY_IF_FLOAT_D(DFromV<V>)>
HWY_INLINE V SlideUpI64Lanes(V v) {
  static_assert(0 <= kI64Lanes && kI64Lanes <= 3,
                "kI64Lanes must be between 0 and 3");
  constexpr int kIdx0 = (-kI64Lanes) & 3;
  constexpr int kIdx1 = (-kI64Lanes + 1) & 3;
  constexpr int kIdx2 = (-kI64Lanes + 2) & 3;
  constexpr int kIdx3 = (-kI64Lanes + 3) & 3;
  constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3,
kIdx2, kIdx1, kIdx0);
  constexpr int kBlendMask = (1 << kI64Lanes) - 1;

  const DFromV<decltype(v)> d;
  const Repartition<double, decltype(d)> dd;
  return BitCast(d, Vec256<double>{_mm256_blend_pd(
                        _mm256_permute4x64_pd(BitCast(dd, v).raw, kIdx3210),
                        Zero(dd).raw, kBlendMask)});
}
#endif  // HWY_TARGET <= HWY_AVX3

// Variable-amount slide up via table lookup. Indices wrap from a negative
// start; lanes whose index wrapped must read zero.
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(
              D, (1 << 1) | ((HWY_TARGET > HWY_AVX3) ? (1 << 2) : 0))>
HWY_INLINE VFromD<D> TableLookupSlideUpLanes(D d, VFromD<D> v, size_t amt) {
  const Repartition<uint8_t, decltype(d)> du8;

  // Iota starting at -amt bytes; lanes below amt get out-of-range indices.
  const auto idx_vec =
      Iota(du8, static_cast<uint8_t>(size_t{0} - amt * sizeof(TFromD<D>)));
  const Indices256<TFromD<D>> idx{idx_vec.raw};

#if HWY_TARGET <= HWY_AVX3_DL
  // Second table is Zero, so wrapped indices produce zero.
  return TwoTablesLookupLanes(v, Zero(d), idx);
#else
  return TableLookupLanes(v, idx);
#endif
}

template <class D, HWY_IF_V_SIZE_GT_D(D, 16),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | ((HWY_TARGET <= HWY_AVX3)
                                                    ? ((1 << 2) | (1 << 8))
                                                    : 0))>
HWY_INLINE VFromD<D> TableLookupSlideUpLanes(D d, VFromD<D> v, size_t amt) {
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;

  const auto idx = Iota(du, static_cast<TU>(size_t{0} - amt));
#if HWY_TARGET <= HWY_AVX3
  // Two-table lookup: wrapped indices land in the Zero table.
  const auto masked_idx =
      And(idx, Set(du, static_cast<TU>(MaxLanes(d) * 2 - 1)));
  return TwoTablesLookupLanes(v, Zero(d), IndicesFromVec(d, masked_idx));
#else
  // AVX2: mask indices into range, then zero lanes whose index wrapped.
  const auto masked_idx = And(idx, Set(du, static_cast<TU>(MaxLanes(d) - 1)));
  return IfThenElseZero(RebindMask(d, idx == masked_idx),
                        TableLookupLanes(v, IndicesFromVec(d, masked_idx)));
#endif
}

#if HWY_TARGET > HWY_AVX3
// AVX2 has no 64-bit variable permute; reuse the 32-bit path with doubled amt.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
HWY_INLINE VFromD<D> TableLookupSlideUpLanes(D d, VFromD<D> v, size_t amt) {
  const RepartitionToNarrow<D> dn;
  return BitCast(d, TableLookupSlideUpLanes(dn, BitCast(dn, v), amt * 2));
}
#endif  // HWY_TARGET > HWY_AVX3

}  // namespace detail

// Slides v up by kBlocks whole 128-bit blocks, shifting in zeros.
template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> SlideUpBlocks(D d, VFromD<D> v) {
  static_assert(0 <= kBlocks && kBlocks <= 1,
                "kBlocks must be between 0 and 1");
  return (kBlocks == 1) ?
ConcatLowerLower(d, v, Zero(d)) : v;
}

// Slides v up by amt lanes (runtime value), shifting in zeros. When the
// compiler can prove amt constant, dispatch to cheaper fixed-shift sequences;
// otherwise fall back to a table lookup.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD<D>);
  if (__builtin_constant_p(amt)) {
    const auto v_lo = ConcatLowerLower(d, v, Zero(d));
    // Switch on the byte shift; CombineShiftRightBytes<16 - N> of (v, v_lo)
    // slides up by N bytes.
    switch (amt * sizeof(TFromD<D>)) {
      case 0:
        return v;
      case 1:
        return CombineShiftRightBytes<15>(d, v, v_lo);
      case 2:
        return CombineShiftRightBytes<14>(d, v, v_lo);
      case 3:
        return CombineShiftRightBytes<13>(d, v, v_lo);
      case 4:
#if HWY_TARGET <= HWY_AVX3
        return detail::CombineShiftRightI32Lanes<7>(v, Zero(d));
#else
        return CombineShiftRightBytes<12>(d, v, v_lo);
#endif
      case 5:
        return CombineShiftRightBytes<11>(d, v, v_lo);
      case 6:
        return CombineShiftRightBytes<10>(d, v, v_lo);
      case 7:
        return CombineShiftRightBytes<9>(d, v, v_lo);
      case 8:
        return detail::SlideUpI64Lanes<1>(v);
      case 9:
        return CombineShiftRightBytes<7>(d, v, v_lo);
      case 10:
        return CombineShiftRightBytes<6>(d, v, v_lo);
      case 11:
        return CombineShiftRightBytes<5>(d, v, v_lo);
      case 12:
#if HWY_TARGET <= HWY_AVX3
        return detail::CombineShiftRightI32Lanes<5>(v, Zero(d));
#else
        return CombineShiftRightBytes<4>(d, v, v_lo);
#endif
      case 13:
        return CombineShiftRightBytes<3>(d, v, v_lo);
      case 14:
        return CombineShiftRightBytes<2>(d, v, v_lo);
      case 15:
        return CombineShiftRightBytes<1>(d, v, v_lo);
      case 16:
        return ConcatLowerLower(d, v, Zero(d));
#if HWY_TARGET <= HWY_AVX3
      case 20:
        return detail::CombineShiftRightI32Lanes<3>(v, Zero(d));
#endif
      case 24:
        return detail::SlideUpI64Lanes<3>(v);
#if HWY_TARGET <= HWY_AVX3
      case 28:
        return detail::CombineShiftRightI32Lanes<1>(v, Zero(d));
#endif
    }
  }

  // Known to shift by at least one whole block: slide the lower half and
  // combine with a zero lower block.
  if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) {
    const Half<decltype(d)> dh;
    return Combine(d, SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock),
                   Zero(dh));
  }
#endif

  return detail::TableLookupSlideUpLanes(d, v, amt);
}

// ------------------------------ Slide1Up

// Slide up by exactly one lane; specializations mirror the constant cases of
// SlideUpLanes.
template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
  const auto v_lo = ConcatLowerLower(d, v, Zero(d));
  return CombineShiftRightBytes<15>(d, v, v_lo);
}

template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
  const auto v_lo = ConcatLowerLower(d, v, Zero(d));
  return CombineShiftRightBytes<14>(d, v, v_lo);
}

template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
#if HWY_TARGET <= HWY_AVX3
  return detail::CombineShiftRightI32Lanes<7>(v, Zero(d));
#else
  const auto v_lo = ConcatLowerLower(d, v, Zero(d));
  return CombineShiftRightBytes<12>(d, v, v_lo);
#endif
}

template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Slide1Up(D /*d*/, VFromD<D> v) {
  return detail::SlideUpI64Lanes<1>(v);
}

// ------------------------------ SlideDownLanes

namespace detail {

#if HWY_TARGET <= HWY_AVX3
// Slide down by kI64Lanes 64-bit lanes, shifting in zeros from above.
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32)>
HWY_INLINE V SlideDownI64Lanes(V v) {
  static_assert(0 <= kI64Lanes && kI64Lanes <= 3,
                "kI64Lanes must be between 0 and 3");
  const DFromV<decltype(v)> d;
  return CombineShiftRightI64Lanes<kI64Lanes>(Zero(d), v);
}
#else   // AVX2
// AVX2: rotate lanes down with permute4x64, then zero the top kI64Lanes lanes
// via a 32-bit blend.
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32),
          HWY_IF_NOT_FLOAT_D(DFromV<V>)>
HWY_INLINE V
SlideDownI64Lanes(V v) { 5955 static_assert(0 <= kI64Lanes && kI64Lanes <= 3, 5956 "kI64Lanes must be between 0 and 3"); 5957 constexpr int kIdx1 = (kI64Lanes + 1) & 3; 5958 constexpr int kIdx2 = (kI64Lanes + 2) & 3; 5959 constexpr int kIdx3 = (kI64Lanes + 3) & 3; 5960 constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kI64Lanes); 5961 constexpr int kBlendMask = 5962 static_cast<int>((0xFFu << ((4 - kI64Lanes) * 2)) & 0xFFu); 5963 5964 const DFromV<decltype(v)> d; 5965 return V{_mm256_blend_epi32(_mm256_permute4x64_epi64(v.raw, kIdx3210), 5966 Zero(d).raw, kBlendMask)}; 5967 } 5968 5969 template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 32), 5970 HWY_IF_FLOAT_D(DFromV<V>)> 5971 HWY_INLINE V SlideDownI64Lanes(V v) { 5972 static_assert(0 <= kI64Lanes && kI64Lanes <= 3, 5973 "kI64Lanes must be between 0 and 3"); 5974 constexpr int kIdx1 = (kI64Lanes + 1) & 3; 5975 constexpr int kIdx2 = (kI64Lanes + 2) & 3; 5976 constexpr int kIdx3 = (kI64Lanes + 3) & 3; 5977 constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kI64Lanes); 5978 constexpr int kBlendMask = (0x0F << (4 - kI64Lanes)) & 0x0F; 5979 5980 const DFromV<decltype(v)> d; 5981 const Repartition<double, decltype(d)> dd; 5982 return BitCast(d, Vec256<double>{_mm256_blend_pd( 5983 _mm256_permute4x64_pd(BitCast(dd, v).raw, kIdx3210), 5984 Zero(dd).raw, kBlendMask)}); 5985 } 5986 #endif // HWY_TARGET <= HWY_AVX3 5987 5988 template <class D, HWY_IF_V_SIZE_D(D, 32), 5989 HWY_IF_T_SIZE_ONE_OF_D( 5990 D, (1 << 1) | ((HWY_TARGET > HWY_AVX3) ? 
(1 << 2) : 0))> 5991 HWY_INLINE VFromD<D> TableLookupSlideDownLanes(D d, VFromD<D> v, size_t amt) { 5992 const Repartition<uint8_t, decltype(d)> du8; 5993 5994 auto idx_vec = Iota(du8, static_cast<uint8_t>(amt * sizeof(TFromD<D>))); 5995 5996 #if HWY_TARGET <= HWY_AVX3_DL 5997 const auto result_mask = idx_vec < Set(du8, uint8_t{32}); 5998 return VFromD<D>{ 5999 _mm256_maskz_permutexvar_epi8(result_mask.raw, idx_vec.raw, v.raw)}; 6000 #else 6001 const RebindToSigned<decltype(du8)> di8; 6002 idx_vec = 6003 Or(idx_vec, BitCast(du8, VecFromMask(di8, BitCast(di8, idx_vec) > 6004 Set(di8, int8_t{31})))); 6005 return TableLookupLanes(v, Indices256<TFromD<D>>{idx_vec.raw}); 6006 #endif 6007 } 6008 6009 template <class D, HWY_IF_V_SIZE_GT_D(D, 16), 6010 HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | ((HWY_TARGET <= HWY_AVX3) 6011 ? ((1 << 2) | (1 << 8)) 6012 : 0))> 6013 HWY_INLINE VFromD<D> TableLookupSlideDownLanes(D d, VFromD<D> v, size_t amt) { 6014 const RebindToUnsigned<decltype(d)> du; 6015 using TU = TFromD<decltype(du)>; 6016 6017 const auto idx = Iota(du, static_cast<TU>(amt)); 6018 const auto masked_idx = And(idx, Set(du, static_cast<TU>(MaxLanes(d) - 1))); 6019 6020 return IfThenElseZero(RebindMask(d, idx == masked_idx), 6021 TableLookupLanes(v, IndicesFromVec(d, masked_idx))); 6022 } 6023 6024 #if HWY_TARGET > HWY_AVX3 6025 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)> 6026 HWY_INLINE VFromD<D> TableLookupSlideDownLanes(D d, VFromD<D> v, size_t amt) { 6027 const RepartitionToNarrow<D> dn; 6028 return BitCast(d, TableLookupSlideDownLanes(dn, BitCast(dn, v), amt * 2)); 6029 } 6030 #endif // HWY_TARGET > HWY_AVX3 6031 6032 } // namespace detail 6033 6034 template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 32)> 6035 HWY_API VFromD<D> SlideDownBlocks(D d, VFromD<D> v) { 6036 static_assert(0 <= kBlocks && kBlocks <= 1, 6037 "kBlocks must be between 0 and 1"); 6038 const Half<decltype(d)> dh; 6039 return (kBlocks == 1) ? 
ZeroExtendVector(d, UpperHalf(dh, v)) : v; 6040 } 6041 6042 template <class D, HWY_IF_V_SIZE_D(D, 32)> 6043 HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { 6044 #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang 6045 constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD<D>); 6046 const Half<decltype(d)> dh; 6047 if (__builtin_constant_p(amt)) { 6048 const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); 6049 switch (amt * sizeof(TFromD<D>)) { 6050 case 0: 6051 return v; 6052 case 1: 6053 return CombineShiftRightBytes<1>(d, v_hi, v); 6054 case 2: 6055 return CombineShiftRightBytes<2>(d, v_hi, v); 6056 case 3: 6057 return CombineShiftRightBytes<3>(d, v_hi, v); 6058 case 4: 6059 #if HWY_TARGET <= HWY_AVX3 6060 return detail::CombineShiftRightI32Lanes<1>(Zero(d), v); 6061 #else 6062 return CombineShiftRightBytes<4>(d, v_hi, v); 6063 #endif 6064 case 5: 6065 return CombineShiftRightBytes<5>(d, v_hi, v); 6066 case 6: 6067 return CombineShiftRightBytes<6>(d, v_hi, v); 6068 case 7: 6069 return CombineShiftRightBytes<7>(d, v_hi, v); 6070 case 8: 6071 return detail::SlideDownI64Lanes<1>(v); 6072 case 9: 6073 return CombineShiftRightBytes<9>(d, v_hi, v); 6074 case 10: 6075 return CombineShiftRightBytes<10>(d, v_hi, v); 6076 case 11: 6077 return CombineShiftRightBytes<11>(d, v_hi, v); 6078 case 12: 6079 #if HWY_TARGET <= HWY_AVX3 6080 return detail::CombineShiftRightI32Lanes<3>(Zero(d), v); 6081 #else 6082 return CombineShiftRightBytes<12>(d, v_hi, v); 6083 #endif 6084 case 13: 6085 return CombineShiftRightBytes<13>(d, v_hi, v); 6086 case 14: 6087 return CombineShiftRightBytes<14>(d, v_hi, v); 6088 case 15: 6089 return CombineShiftRightBytes<15>(d, v_hi, v); 6090 case 16: 6091 return v_hi; 6092 #if HWY_TARGET <= HWY_AVX3 6093 case 20: 6094 return detail::CombineShiftRightI32Lanes<5>(Zero(d), v); 6095 #endif 6096 case 24: 6097 return detail::SlideDownI64Lanes<3>(v); 6098 #if HWY_TARGET <= HWY_AVX3 6099 case 28: 6100 return 
detail::CombineShiftRightI32Lanes<7>(Zero(d), v); 6101 #endif 6102 } 6103 } 6104 6105 if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { 6106 return ZeroExtendVector( 6107 d, SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock)); 6108 } 6109 #endif 6110 6111 return detail::TableLookupSlideDownLanes(d, v, amt); 6112 } 6113 6114 // ------------------------------ Slide1Down 6115 6116 template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)> 6117 HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { 6118 const Half<decltype(d)> dh; 6119 const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); 6120 return CombineShiftRightBytes<1>(d, v_hi, v); 6121 } 6122 6123 template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)> 6124 HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { 6125 const Half<decltype(d)> dh; 6126 const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); 6127 return CombineShiftRightBytes<2>(d, v_hi, v); 6128 } 6129 6130 template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)> 6131 HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { 6132 #if HWY_TARGET <= HWY_AVX3 6133 return detail::CombineShiftRightI32Lanes<1>(Zero(d), v); 6134 #else 6135 const Half<decltype(d)> dh; 6136 const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); 6137 return CombineShiftRightBytes<4>(d, v_hi, v); 6138 #endif 6139 } 6140 6141 template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)> 6142 HWY_API VFromD<D> Slide1Down(D /*d*/, VFromD<D> v) { 6143 return detail::SlideDownI64Lanes<1>(v); 6144 } 6145 6146 // ------------------------------ Shl (Mul, ZipLower) 6147 6148 namespace detail { 6149 6150 #if HWY_TARGET > HWY_AVX3 && !HWY_IDE // AVX2 or older 6151 template <class V> 6152 HWY_INLINE V AVX2ShlU16Vec256(V v, V bits) { 6153 const DFromV<decltype(v)> d; 6154 const Half<decltype(d)> dh; 6155 const Rebind<uint32_t, decltype(dh)> du32; 6156 6157 const auto lo_shl_result = PromoteTo(du32, LowerHalf(dh, v)) 6158 << 
PromoteTo(du32, LowerHalf(dh, bits)); 6159 const auto hi_shl_result = PromoteTo(du32, UpperHalf(dh, v)) 6160 << PromoteTo(du32, UpperHalf(dh, bits)); 6161 return ConcatEven(d, BitCast(d, hi_shl_result), BitCast(d, lo_shl_result)); 6162 } 6163 #endif 6164 6165 HWY_INLINE Vec256<uint16_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint16_t> v, 6166 Vec256<uint16_t> bits) { 6167 #if HWY_TARGET <= HWY_AVX3 || HWY_IDE 6168 return Vec256<uint16_t>{_mm256_sllv_epi16(v.raw, bits.raw)}; 6169 #else 6170 return AVX2ShlU16Vec256(v, bits); 6171 #endif 6172 } 6173 6174 // 8-bit: may use the Shl overload for uint16_t. 6175 HWY_API Vec256<uint8_t> Shl(hwy::UnsignedTag tag, Vec256<uint8_t> v, 6176 Vec256<uint8_t> bits) { 6177 const DFromV<decltype(v)> d; 6178 #if HWY_TARGET <= HWY_AVX3_DL 6179 (void)tag; 6180 // masks[i] = 0xFF >> i 6181 const VFromD<decltype(d)> masks = 6182 Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, 6183 0, 0, 0, 0, 0, 0, 0); 6184 // kShl[i] = 1 << i 6185 const VFromD<decltype(d)> shl = Dup128VecFromValues( 6186 d, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 0, 0, 0, 0, 0, 0, 0, 0); 6187 v = And(v, TableLookupBytes(masks, bits)); 6188 const VFromD<decltype(d)> mul = TableLookupBytes(shl, bits); 6189 return VFromD<decltype(d)>{_mm256_gf2p8mul_epi8(v.raw, mul.raw)}; 6190 #else 6191 const Repartition<uint16_t, decltype(d)> dw; 6192 using VW = VFromD<decltype(dw)>; 6193 const VW even_mask = Set(dw, 0x00FF); 6194 const VW odd_mask = Set(dw, 0xFF00); 6195 const VW vw = BitCast(dw, v); 6196 const VW bits16 = BitCast(dw, bits); 6197 // Shift even lanes in-place 6198 const VW evens = Shl(tag, vw, And(bits16, even_mask)); 6199 const VW odds = Shl(tag, And(vw, odd_mask), ShiftRight<8>(bits16)); 6200 return OddEven(BitCast(d, odds), BitCast(d, evens)); 6201 #endif 6202 } 6203 6204 HWY_INLINE Vec256<uint32_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint32_t> v, 6205 Vec256<uint32_t> bits) { 6206 return Vec256<uint32_t>{_mm256_sllv_epi32(v.raw, bits.raw)}; 6207 } 

// Per-lane variable left shift, u64 lanes (native since AVX2).
HWY_INLINE Vec256<uint64_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint64_t> v,
                                Vec256<uint64_t> bits) {
  return Vec256<uint64_t>{_mm256_sllv_epi64(v.raw, bits.raw)};
}

template <typename T>
HWY_INLINE Vec256<T> Shl(hwy::SignedTag /*tag*/, Vec256<T> v, Vec256<T> bits) {
  // Signed left shifts are the same as unsigned.
  const Full256<T> di;
  const Full256<MakeUnsigned<T>> du;
  return BitCast(di,
                 Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits)));
}

}  // namespace detail

// Public per-lane variable left shift; dispatches on signedness tag.
template <typename T>
HWY_API Vec256<T> operator<<(Vec256<T> v, Vec256<T> bits) {
  return detail::Shl(hwy::TypeTag<T>(), v, bits);
}

// ------------------------------ Shr (MulHigh, IfThenElse, Not)

#if HWY_TARGET > HWY_AVX3  // AVX2
namespace detail {

// AVX2 lacks a 16-bit variable right shift: widen both halves to u32, shift,
// then demote back (values fit in 16 bits after a logical right shift).
template <class V>
HWY_INLINE V AVX2ShrU16Vec256(V v, V bits) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  const Rebind<int32_t, decltype(dh)> di32;
  const Rebind<uint32_t, decltype(dh)> du32;

  const auto lo_shr_result =
      PromoteTo(du32, LowerHalf(dh, v)) >> PromoteTo(du32, LowerHalf(dh, bits));
  const auto hi_shr_result =
      PromoteTo(du32, UpperHalf(dh, v)) >> PromoteTo(du32, UpperHalf(dh, bits));
  return OrderedDemote2To(d, BitCast(di32, lo_shr_result),
                          BitCast(di32, hi_shr_result));
}

}  // namespace detail
#endif

// Per-lane variable logical right shift, u16 lanes.
HWY_API Vec256<uint16_t> operator>>(Vec256<uint16_t> v, Vec256<uint16_t> bits) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint16_t>{_mm256_srlv_epi16(v.raw, bits.raw)};
#else
  return detail::AVX2ShrU16Vec256(v, bits);
#endif
}

// 8-bit uses 16-bit shifts.

// Per-lane variable logical right shift, u8 lanes, via two 16-bit shifts:
// even bytes are isolated and shifted by their (masked) counts; odd bytes are
// shifted in place, since right shift can only pull in zero bits from above.
HWY_API Vec256<uint8_t> operator>>(Vec256<uint8_t> v, Vec256<uint8_t> bits) {
  const DFromV<decltype(v)> d;
  const RepartitionToWide<decltype(d)> dw;
  using VW = VFromD<decltype(dw)>;
  const VW mask = Set(dw, 0x00FF);
  const VW vw = BitCast(dw, v);
  const VW bits16 = BitCast(dw, bits);
  const VW evens = And(vw, mask) >> And(bits16, mask);
  // Shift odd lanes in-place
  const VW odds = vw >> ShiftRight<8>(bits16);
  return OddEven(BitCast(d, odds), BitCast(d, evens));
}

// Per-lane variable logical right shift, u32/u64 lanes (native since AVX2).
HWY_API Vec256<uint32_t> operator>>(Vec256<uint32_t> v, Vec256<uint32_t> bits) {
  return Vec256<uint32_t>{_mm256_srlv_epi32(v.raw, bits.raw)};
}

HWY_API Vec256<uint64_t> operator>>(Vec256<uint64_t> v, Vec256<uint64_t> bits) {
  return Vec256<uint64_t>{_mm256_srlv_epi64(v.raw, bits.raw)};
}

#if HWY_TARGET > HWY_AVX3  // AVX2
namespace detail {

// AVX2 lacks a 16-bit variable arithmetic shift: sign-extend both halves to
// i32, shift there, then demote back.
template <class V>
HWY_INLINE V AVX2ShrI16Vec256(V v, V bits) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  const Rebind<int32_t, decltype(dh)> di32;

  const auto lo_shr_result =
      PromoteTo(di32, LowerHalf(dh, v)) >> PromoteTo(di32, LowerHalf(dh, bits));
  const auto hi_shr_result =
      PromoteTo(di32, UpperHalf(dh, v)) >> PromoteTo(di32, UpperHalf(dh, bits));
  return OrderedDemote2To(d, lo_shr_result, hi_shr_result);
}

}  // namespace detail
#endif

// Per-lane variable arithmetic right shift, i16 lanes.
HWY_API Vec256<int16_t> operator>>(Vec256<int16_t> v, Vec256<int16_t> bits) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<int16_t>{_mm256_srav_epi16(v.raw, bits.raw)};
#else
  return detail::AVX2ShrI16Vec256(v, bits);
#endif
}

// 8-bit uses 16-bit shifts.

// Per-lane variable arithmetic right shift, i8 lanes. Even bytes are
// sign-extended into 16 bits (shift left then arithmetic shift right) before
// shifting; odd bytes shift in place. The odd shift counts use a *logical*
// 8-bit shift (via the unsigned repartition) so the even-lane counts cannot
// contaminate them with sign bits.
HWY_API Vec256<int8_t> operator>>(Vec256<int8_t> v, Vec256<int8_t> bits) {
  const DFromV<decltype(v)> d;
  const RepartitionToWide<decltype(d)> dw;
  const RebindToUnsigned<decltype(dw)> dw_u;
  using VW = VFromD<decltype(dw)>;
  const VW mask = Set(dw, 0x00FF);
  const VW vw = BitCast(dw, v);
  const VW bits16 = BitCast(dw, bits);
  const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask);
  // Shift odd lanes in-place
  const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16)));
  return OddEven(BitCast(d, odds), BitCast(d, evens));
}

// Per-lane variable arithmetic right shift, i32 lanes (native since AVX2).
HWY_API Vec256<int32_t> operator>>(Vec256<int32_t> v, Vec256<int32_t> bits) {
  return Vec256<int32_t>{_mm256_srav_epi32(v.raw, bits.raw)};
}

// Per-lane variable arithmetic right shift, i64 lanes; AVX2 has no vpsravq,
// so fall back to the generic emulation.
HWY_API Vec256<int64_t> operator>>(Vec256<int64_t> v, Vec256<int64_t> bits) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<int64_t>{_mm256_srav_epi64(v.raw, bits.raw)};
#else
  const DFromV<decltype(v)> d;
  return detail::SignedShr(d, v, bits);
#endif
}

// ------------------------------ WidenMulPairwiseAdd

#if HWY_NATIVE_DOT_BF16

// f32[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1] (bf16 inputs), starting from zero.
template <class DF, HWY_IF_F32_D(DF), HWY_IF_V_SIZE_D(DF, 32),
          class VBF = VFromD<Repartition<bfloat16_t, DF>>>
HWY_API VFromD<DF> WidenMulPairwiseAdd(DF df, VBF a, VBF b) {
  return VFromD<DF>{_mm256_dpbf16_ps(Zero(df).raw,
                                     reinterpret_cast<__m256bh>(a.raw),
                                     reinterpret_cast<__m256bh>(b.raw))};
}

#endif  // HWY_NATIVE_DOT_BF16

// i32[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1] (i16 inputs), via pmaddwd.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> WidenMulPairwiseAdd(D /*d32*/, Vec256<int16_t> a,
                                      Vec256<int16_t> b) {
  return VFromD<D>{_mm256_madd_epi16(a.raw, b.raw)};
}

// ------------------------------ SatWidenMulPairwiseAdd

// i16[i] = saturated a[2i]*b[2i] + a[2i+1]*b[2i+1], u8 x i8 inputs
// (pmaddubsw semantics).
template <class DI16, HWY_IF_V_SIZE_D(DI16, 32), HWY_IF_I16_D(DI16)>
HWY_API VFromD<DI16> SatWidenMulPairwiseAdd(
    DI16 /* tag */, VFromD<Repartition<uint8_t, DI16>> a,
    VFromD<Repartition<int8_t, DI16>> b) {
  return VFromD<DI16>{_mm256_maddubs_epi16(a.raw, b.raw)};
}

// ------------------------------ SatWidenMulPairwiseAccumulate

#if HWY_TARGET <= HWY_AVX3_DL
// sum[i] += saturated a[2i]*b[2i] + a[2i+1]*b[2i+1] (i16 inputs, VNNI).
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 32)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
    DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a,
    VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
  return VFromD<DI32>{_mm256_dpwssds_epi32(sum.raw, a.raw, b.raw)};
}
#endif  // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ ReorderWidenMulAccumulate

#if HWY_NATIVE_DOT_BF16
// sum0 += pairwise bf16 dot products; sum1 is unused because the single
// instruction already produces the final per-lane sums.
template <class DF, HWY_IF_F32_D(DF), HWY_IF_V_SIZE_D(DF, 32),
          class VBF = VFromD<Repartition<bfloat16_t, DF>>>
HWY_API VFromD<DF> ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b,
                                             const VFromD<DF> sum0,
                                             VFromD<DF>& /*sum1*/) {
  return VFromD<DF>{_mm256_dpbf16_ps(sum0.raw,
                                     reinterpret_cast<__m256bh>(a.raw),
                                     reinterpret_cast<__m256bh>(b.raw))};
}
#endif  // HWY_NATIVE_DOT_BF16

// sum0 += pairwise i16 dot products; VNNI fuses multiply-add-accumulate,
// otherwise accumulate the pmaddwd result manually.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec256<int16_t> a,
                                            Vec256<int16_t> b,
                                            const VFromD<D> sum0,
                                            VFromD<D>& /*sum1*/) {
  (void)d;
#if HWY_TARGET <= HWY_AVX3_DL
  return VFromD<D>{_mm256_dpwssd_epi32(sum0.raw, a.raw, b.raw)};
#else
  return sum0 + WidenMulPairwiseAdd(d, a, b);
#endif
}

// ------------------------------ RearrangeToOddPlusEven
HWY_API Vec256<int32_t> RearrangeToOddPlusEven(const Vec256<int32_t> sum0,
                                               Vec256<int32_t> /*sum1*/) {
  return sum0;  // invariant already holds
}

HWY_API Vec256<uint32_t> RearrangeToOddPlusEven(const Vec256<uint32_t> sum0,
                                                Vec256<uint32_t> /*sum1*/) {
  return sum0;  // invariant already holds
}

// ------------------------------ SumOfMulQuadAccumulate

#if HWY_TARGET <= HWY_AVX3_DL

// sum[i] += dot product of four u8 x i8 pairs (VNNI vpdpbusd).
template <class DI32, HWY_IF_V_SIZE_D(DI32, 32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
    DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u,
    VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
  return VFromD<DI32>{_mm256_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)};
}

#if HWY_X86_HAVE_AVX10_2_OPS
// AVX10.2 adds signed x signed and unsigned x unsigned byte-quad dot
// products.
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
                                            VFromD<Repartition<int8_t, DI32>> a,
                                            VFromD<Repartition<int8_t, DI32>> b,
                                            VFromD<DI32> sum) {
  return VFromD<DI32>{_mm256_dpbssd_epi32(sum.raw, a.raw, b.raw)};
}

template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_D(DU32, 32)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
    DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
    VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
  return VFromD<DU32>{_mm256_dpbuud_epi32(sum.raw, a.raw, b.raw)};
}
#endif  // HWY_X86_HAVE_AVX10_2_OPS

#endif  // HWY_TARGET <= HWY_AVX3_DL

// ================================================== CONVERT

// ------------------------------ Promotions (part w/ narrow lanes -> full)

// float -> double
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<float> v) {
  return VFromD<D>{_mm256_cvtps_pd(v.raw)};
}

// i32 -> double
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int32_t> v) {
  return VFromD<D>{_mm256_cvtepi32_pd(v.raw)};
}

#if HWY_TARGET <= HWY_AVX3
// u32 -> double (AVX-512VL only)
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> PromoteTo(D /* tag */, Vec128<uint32_t> v) {
  return Vec256<double>{_mm256_cvtepu32_pd(v.raw)};
}
#endif

// Unsigned: zero-extend.
// Note: these have 3 cycle latency; if inputs are already split across the
// 128 bit blocks (in their upper/lower halves), then Zip* would be faster.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint8_t> v) {
  return VFromD<D>{_mm256_cvtepu8_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint8_t, 8> v) {
  return VFromD<D>{_mm256_cvtepu8_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint16_t> v) {
  return VFromD<D>{_mm256_cvtepu16_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint32_t> v) {
  return VFromD<D>{_mm256_cvtepu32_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<uint16_t> v) {
  return VFromD<D>{_mm256_cvtepu16_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec32<uint8_t> v) {
  return VFromD<D>{_mm256_cvtepu8_epi64(v.raw)};
}

// Signed: replicate sign bit.
// Note: these have 3 cycle latency; if inputs are already split across the
// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by
// signed shift would be faster.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int8_t> v) {
  return VFromD<D>{_mm256_cvtepi8_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int8_t, 8> v) {
  return VFromD<D>{_mm256_cvtepi8_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int16_t> v) {
  return VFromD<D>{_mm256_cvtepi16_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int32_t> v) {
  return VFromD<D>{_mm256_cvtepi32_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<int16_t> v) {
  return VFromD<D>{_mm256_cvtepi16_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec32<int8_t> v) {
  return VFromD<D>{_mm256_cvtepi8_epi64(v.raw)};
}

#if HWY_TARGET <= HWY_AVX3
// float -> i64 truncating conversion; result is only specified for inputs
// representable in the destination ("in range").
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
#if HWY_X86_HAVE_AVX10_2_OPS
  return VFromD<D>{_mm256_cvtts_ps_epi64(v.raw)};
#elif HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior with GCC if any values of v[i] are not
  // within the range of an int64_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  // Constant-propagation path: convert scalar-by-scalar at compile time.
  if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) {
    typedef float GccF32RawVectType __attribute__((__vector_size__(16)));
    const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw);
    return VFromD<D>{_mm256_setr_epi64x(
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[0]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[1]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[2]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[3]))};
  }
#endif

  // Inline asm hides the conversion from GCC's UB-based optimizations.
  __m256i raw_result;
  __asm__("vcvttps2qq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<D>{raw_result};
#else   // !HWY_COMPILER_GCC_ACTUAL
  return VFromD<D>{_mm256_cvttps_epi64(v.raw)};
#endif  // HWY_COMPILER_GCC_ACTUAL
}
// float -> u64 truncating conversion; same structure as the i64 overload.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /* tag */, VFromD<Rebind<float, D>> v) {
#if HWY_X86_HAVE_AVX10_2_OPS
  return VFromD<D>{_mm256_cvtts_ps_epu64(v.raw)};
#elif HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior with GCC if any values of v[i] are not
  // within the range of an uint64_t
#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  if (detail::IsConstantX86VecForF2IConv<uint64_t>(v)) {
    typedef float GccF32RawVectType __attribute__((__vector_size__(16)));
    const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw);
    return VFromD<D>{_mm256_setr_epi64x(
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[0])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[1])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[2])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[3])))};
  }
#endif

  __m256i raw_result;
  __asm__("vcvttps2uqq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<D>{raw_result};
#else   // !HWY_COMPILER_GCC_ACTUAL
  return VFromD<D>{_mm256_cvttps_epu64(v.raw)};
#endif  // HWY_COMPILER_GCC_ACTUAL
}
#endif  // HWY_TARGET <= HWY_AVX3

// ------------------------------ PromoteEvenTo/PromoteOddTo
#if HWY_TARGET > HWY_AVX3
namespace detail {

// I32->I64 PromoteEvenTo/PromoteOddTo

// Sign-extend the even i32 lanes in place: keep them as the low halves and
// interleave the broadcast sign bits as the high halves.
template <class D, HWY_IF_LANES_D(D, 4)>
HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/,
                                   hwy::SizeTag<8> /*to_lane_size_tag*/,
                                   hwy::SignedTag /*from_type_tag*/, D d_to,
                                   Vec256<int32_t> v) {
  return BitCast(d_to, OddEven(DupEven(BroadcastSignBit(v)), v));
}

// Sign-extend the odd i32 lanes: duplicate them into the even positions and
// pair with their sign bits.
template <class D, HWY_IF_LANES_D(D, 4)>
HWY_INLINE VFromD<D> PromoteOddTo(hwy::SignedTag /*to_type_tag*/,
                                  hwy::SizeTag<8> /*to_lane_size_tag*/,
                                  hwy::SignedTag /*from_type_tag*/, D d_to,
                                  Vec256<int32_t> v) {
  return BitCast(d_to, OddEven(BroadcastSignBit(v), DupOdd(v)));
}

}  // namespace detail
#endif

// ------------------------------ Demotions (full -> part w/ narrow lanes)

// i32 -> u16 with unsigned saturation.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw);
  // Concatenating lower halves of both 128-bit blocks afterward is more
  // efficient than an extra input with low block = high block of v.
  return VFromD<D>{_mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))};
}

// u32 -> u16: clamp to i32 range first so the signed-input pack saturates
// correctly.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D dn, Vec256<uint32_t> v) {
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu))));
}

// i32 -> i16 with signed saturation.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw);
  return VFromD<D>{_mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))};
}

// i32 -> u8: pack to i16 first (signed saturation), then to u8.
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw);
  // Concatenate lower 64 bits of each 128-bit block
  const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88);
  const __m128i i16 = _mm256_castsi256_si128(i16_concat);
  return VFromD<D>{_mm_packus_epi16(i16, i16)};
}

// u32 -> u8 with unsigned saturation.
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D dn, Vec256<uint32_t> v) {
#if HWY_TARGET <= HWY_AVX3
  (void)dn;
  return VFromD<D>{_mm256_cvtusepi32_epi8(v.raw)};
#else
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu))));
#endif
}

// i16 -> u8 with unsigned saturation.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int16_t> v) {
  const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw);
  return VFromD<D>{_mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))};
}

// u16 -> u8: clamp to i16 range first, then reuse the signed-input path.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D dn, Vec256<uint16_t> v) {
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu))));
}

// i32 -> i8 with signed saturation (two pack steps).
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_I8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw);
  // Concatenate lower 64 bits of each 128-bit block
  const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88);
  const __m128i i16 = _mm256_castsi256_si128(i16_concat);
  return VFromD<D>{_mm_packs_epi16(i16, i16)};
}

// i16 -> i8 with signed saturation.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int16_t> v) {
  const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw);
  return VFromD<D>{_mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))};
}

#if HWY_TARGET <= HWY_AVX3
// i64 -> narrower signed types, with signed saturation.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int64_t> v) {
  return VFromD<D>{_mm256_cvtsepi64_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_I16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int64_t> v) {
  return VFromD<D>{_mm256_cvtsepi64_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_I8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int64_t> v) {
  return VFromD<D>{_mm256_cvtsepi64_epi8(v.raw)};
}

// i64 -> narrower unsigned types: zero-mask negative lanes, then saturate
// the remaining (non-negative) values with the unsigned converter.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int64_t> v) {
  const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw;
  return VFromD<D>{_mm256_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int64_t> v) {
  const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw;
  return VFromD<D>{_mm256_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)};
}
// i64 -> u8: zero-mask negative lanes, then unsigned saturation.
template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<int64_t> v) {
  const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw;
  return VFromD<D>{_mm256_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)};
}

// u64 -> narrower unsigned types, with unsigned saturation.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<uint64_t> v) {
  return VFromD<D>{_mm256_cvtusepi64_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<uint64_t> v) {
  return VFromD<D>{_mm256_cvtusepi64_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<uint64_t> v) {
  return VFromD<D>{_mm256_cvtusepi64_epi8(v.raw)};
}
#endif  // HWY_TARGET <= HWY_AVX3

#ifndef HWY_DISABLE_F16C

// Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'".
// 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here.
6735 HWY_DIAGNOSTICS(push) 6736 HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") 6737 6738 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F16_D(D)> 6739 HWY_API VFromD<D> DemoteTo(D df16, Vec256<float> v) { 6740 const RebindToUnsigned<decltype(df16)> du16; 6741 return BitCast( 6742 df16, VFromD<decltype(du16)>{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); 6743 } 6744 6745 HWY_DIAGNOSTICS(pop) 6746 6747 #endif // HWY_DISABLE_F16C 6748 6749 #if HWY_HAVE_FLOAT16 6750 template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_F16_D(D)> 6751 HWY_API VFromD<D> DemoteTo(D /*df16*/, Vec256<double> v) { 6752 return VFromD<D>{_mm256_cvtpd_ph(v.raw)}; 6753 } 6754 #endif // HWY_HAVE_FLOAT16 6755 6756 #if HWY_AVX3_HAVE_F32_TO_BF16C 6757 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_BF16_D(D)> 6758 HWY_API VFromD<D> DemoteTo(D /*dbf16*/, Vec256<float> v) { 6759 #if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 6760 // Inline assembly workaround for LLVM codegen bug 6761 __m128i raw_result; 6762 __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); 6763 return VFromD<D>{raw_result}; 6764 #else 6765 // The _mm256_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be 6766 // bit casted to a __m128i vector 6767 return VFromD<D>{detail::BitCastToInteger(_mm256_cvtneps_pbh(v.raw))}; 6768 #endif 6769 } 6770 6771 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)> 6772 HWY_API VFromD<D> ReorderDemote2To(D /*dbf16*/, Vec256<float> a, 6773 Vec256<float> b) { 6774 #if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 6775 // Inline assembly workaround for LLVM codegen bug 6776 __m256i raw_result; 6777 __asm__("vcvtne2ps2bf16 %2, %1, %0" 6778 : "=v"(raw_result) 6779 : "v"(b.raw), "v"(a.raw)); 6780 return VFromD<D>{raw_result}; 6781 #else 6782 // The _mm256_cvtne2ps_pbh intrinsic returns a __m256bh vector that needs to 6783 // be bit casted to a __m256i vector 6784 return VFromD<D>{detail::BitCastToInteger(_mm256_cvtne2ps_pbh(b.raw, 
a.raw))}; 6785 #endif 6786 } 6787 #endif // HWY_AVX3_HAVE_F32_TO_BF16C 6788 6789 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)> 6790 HWY_API VFromD<D> ReorderDemote2To(D /*d16*/, Vec256<int32_t> a, 6791 Vec256<int32_t> b) { 6792 return VFromD<D>{_mm256_packs_epi32(a.raw, b.raw)}; 6793 } 6794 6795 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> 6796 HWY_API VFromD<D> ReorderDemote2To(D /*d16*/, Vec256<int32_t> a, 6797 Vec256<int32_t> b) { 6798 return VFromD<D>{_mm256_packus_epi32(a.raw, b.raw)}; 6799 } 6800 6801 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> 6802 HWY_API VFromD<D> ReorderDemote2To(D dn, Vec256<uint32_t> a, 6803 Vec256<uint32_t> b) { 6804 const DFromV<decltype(a)> d; 6805 const RebindToSigned<decltype(d)> di; 6806 const auto max_i32 = Set(d, 0x7FFFFFFFu); 6807 return ReorderDemote2To(dn, BitCast(di, Min(a, max_i32)), 6808 BitCast(di, Min(b, max_i32))); 6809 } 6810 6811 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I8_D(D)> 6812 HWY_API VFromD<D> ReorderDemote2To(D /*d16*/, Vec256<int16_t> a, 6813 Vec256<int16_t> b) { 6814 return VFromD<D>{_mm256_packs_epi16(a.raw, b.raw)}; 6815 } 6816 6817 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)> 6818 HWY_API VFromD<D> ReorderDemote2To(D /*d16*/, Vec256<int16_t> a, 6819 Vec256<int16_t> b) { 6820 return VFromD<D>{_mm256_packus_epi16(a.raw, b.raw)}; 6821 } 6822 6823 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)> 6824 HWY_API VFromD<D> ReorderDemote2To(D dn, Vec256<uint16_t> a, 6825 Vec256<uint16_t> b) { 6826 const DFromV<decltype(a)> d; 6827 const RebindToSigned<decltype(d)> di; 6828 const auto max_i16 = Set(d, 0x7FFFu); 6829 return ReorderDemote2To(dn, BitCast(di, Min(a, max_i16)), 6830 BitCast(di, Min(b, max_i16))); 6831 } 6832 6833 #if HWY_TARGET > HWY_AVX3 6834 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)> 6835 HWY_API Vec256<int32_t> ReorderDemote2To(D dn, Vec256<int64_t> a, 6836 Vec256<int64_t> b) { 6837 const 
DFromV<decltype(a)> di64; 6838 const RebindToUnsigned<decltype(di64)> du64; 6839 const Half<decltype(dn)> dnh; 6840 const Repartition<float, decltype(dn)> dn_f; 6841 6842 // Negative values are saturated by first saturating their bitwise inverse 6843 // and then inverting the saturation result 6844 const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); 6845 const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); 6846 const auto saturated_a = Xor( 6847 invert_mask_a, 6848 detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); 6849 const auto saturated_b = Xor( 6850 invert_mask_b, 6851 detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, b)))); 6852 6853 return BitCast(dn, 6854 Vec256<float>{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, 6855 BitCast(dn_f, saturated_b).raw, 6856 _MM_SHUFFLE(2, 0, 2, 0))}); 6857 } 6858 6859 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> 6860 HWY_API Vec256<uint32_t> ReorderDemote2To(D dn, Vec256<int64_t> a, 6861 Vec256<int64_t> b) { 6862 const DFromV<decltype(a)> di64; 6863 const RebindToUnsigned<decltype(di64)> du64; 6864 const Half<decltype(dn)> dnh; 6865 const Repartition<float, decltype(dn)> dn_f; 6866 6867 const auto saturated_a = detail::DemoteFromU64Saturate( 6868 dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); 6869 const auto saturated_b = detail::DemoteFromU64Saturate( 6870 dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); 6871 6872 return BitCast(dn, 6873 Vec256<float>{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, 6874 BitCast(dn_f, saturated_b).raw, 6875 _MM_SHUFFLE(2, 0, 2, 0))}); 6876 } 6877 6878 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)> 6879 HWY_API VFromD<D> ReorderDemote2To(D dn, Vec256<uint64_t> a, 6880 Vec256<uint64_t> b) { 6881 const Half<decltype(dn)> dnh; 6882 const Repartition<float, decltype(dn)> dn_f; 6883 6884 const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); 6885 const auto saturated_b = 
detail::DemoteFromU64Saturate(dnh, b); 6886 6887 return BitCast(dn, 6888 Vec256<float>{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, 6889 BitCast(dn_f, saturated_b).raw, 6890 _MM_SHUFFLE(2, 0, 2, 0))}); 6891 } 6892 #endif // HWY_TARGET > HWY_AVX3 6893 6894 template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>), 6895 HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 6896 HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2), 6897 HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>) * 2), 6898 HWY_IF_T_SIZE_ONE_OF_V(V, 6899 (1 << 1) | (1 << 2) | (1 << 4) | 6900 ((HWY_TARGET > HWY_AVX3) ? (1 << 8) : 0))> 6901 HWY_API VFromD<D> OrderedDemote2To(D d, V a, V b) { 6902 return VFromD<D>{_mm256_permute4x64_epi64(ReorderDemote2To(d, a, b).raw, 6903 _MM_SHUFFLE(3, 1, 2, 0))}; 6904 } 6905 6906 #if HWY_TARGET <= HWY_AVX3 6907 template <class D, HWY_IF_V_SIZE_D(D, HWY_MAX_BYTES), HWY_IF_UI32_D(D)> 6908 HWY_API VFromD<D> ReorderDemote2To(D dn, VFromD<Repartition<int64_t, D>> a, 6909 VFromD<Repartition<int64_t, D>> b) { 6910 const Half<decltype(dn)> dnh; 6911 return Combine(dn, DemoteTo(dnh, b), DemoteTo(dnh, a)); 6912 } 6913 6914 template <class D, HWY_IF_V_SIZE_D(D, HWY_MAX_BYTES), HWY_IF_U32_D(D)> 6915 HWY_API VFromD<D> ReorderDemote2To(D dn, VFromD<Repartition<uint64_t, D>> a, 6916 VFromD<Repartition<uint64_t, D>> b) { 6917 const Half<decltype(dn)> dnh; 6918 return Combine(dn, DemoteTo(dnh, b), DemoteTo(dnh, a)); 6919 } 6920 6921 template <class D, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>), 6922 HWY_IF_V_SIZE_GT_D(D, 16), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 6923 HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2), 6924 HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>) * 2), 6925 HWY_IF_T_SIZE_V(V, 8)> 6926 HWY_API VFromD<D> OrderedDemote2To(D d, V a, V b) { 6927 return ReorderDemote2To(d, a, b); 6928 } 6929 #endif 6930 6931 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)> 6932 HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<double> v) { 6933 return 
VFromD<D>{_mm256_cvtpd_ps(v.raw)}; 6934 } 6935 6936 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I32_D(D)> 6937 HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec256<double> v) { 6938 #if HWY_X86_HAVE_AVX10_2_OPS 6939 return VFromD<D>{_mm256_cvtts_pd_epi32(v.raw)}; 6940 #elif HWY_COMPILER_GCC_ACTUAL 6941 // Workaround for undefined behavior in _mm256_cvttpd_epi32 with GCC if any 6942 // values of v[i] are not within the range of an int32_t 6943 6944 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6945 if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { 6946 typedef double GccF64RawVectType __attribute__((__vector_size__(32))); 6947 const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); 6948 return Dup128VecFromValues( 6949 D(), detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]), 6950 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]), 6951 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]), 6952 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[3])); 6953 } 6954 #endif 6955 6956 __m128i raw_result; 6957 __asm__("vcvttpd2dq {%1, %0|%0, %1}" 6958 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6959 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6960 :); 6961 return VFromD<D>{raw_result}; 6962 #else 6963 return VFromD<D>{_mm256_cvttpd_epi32(v.raw)}; 6964 #endif 6965 } 6966 6967 #if HWY_TARGET <= HWY_AVX3 6968 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)> 6969 HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec256<double> v) { 6970 #if HWY_X86_HAVE_AVX10_2_OPS 6971 return VFromD<D>{_mm256_cvtts_pd_epu32(v.raw)}; 6972 #elif HWY_COMPILER_GCC_ACTUAL 6973 // Workaround for undefined behavior in _mm256_cvttpd_epu32 with GCC if any 6974 // values of v[i] are not within the range of an uint32_t 6975 6976 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6977 if (detail::IsConstantX86VecForF2IConv<uint32_t>(v)) { 6978 typedef double GccF64RawVectType __attribute__((__vector_size__(32))); 6979 const auto raw_v = 
reinterpret_cast<GccF64RawVectType>(v.raw); 6980 return Dup128VecFromValues( 6981 D(), detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[0]), 6982 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[1]), 6983 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[2]), 6984 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[3])); 6985 } 6986 #endif 6987 6988 __m128i raw_result; 6989 __asm__("vcvttpd2udq {%1, %0|%0, %1}" 6990 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6991 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6992 :); 6993 return VFromD<D>{raw_result}; 6994 #else 6995 return VFromD<D>{_mm256_cvttpd_epu32(v.raw)}; 6996 #endif 6997 } 6998 6999 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)> 7000 HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int64_t, D>> v) { 7001 return VFromD<D>{_mm256_cvtepi64_ps(v.raw)}; 7002 } 7003 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)> 7004 HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint64_t, D>> v) { 7005 return VFromD<D>{_mm256_cvtepu64_ps(v.raw)}; 7006 } 7007 #endif 7008 7009 // For already range-limited input [0, 255]. 7010 HWY_API Vec128<uint8_t, 8> U8FromU32(const Vec256<uint32_t> v) { 7011 const Full256<uint32_t> d32; 7012 const Full64<uint8_t> d8; 7013 alignas(32) static constexpr uint32_t k8From32[8] = { 7014 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; 7015 // Place first four bytes in lo[0], remaining 4 in hi[1]. 7016 const auto quad = TableLookupBytes(v, Load(d32, k8From32)); 7017 // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. 7018 const auto lo = LowerHalf(quad); 7019 const auto hi = UpperHalf(Half<decltype(d32)>(), quad); 7020 return BitCast(d8, LowerHalf(lo | hi)); 7021 } 7022 7023 // ------------------------------ Truncations 7024 7025 namespace detail { 7026 7027 // LO and HI each hold four indices of bytes within a 128-bit block. 
template <uint32_t LO, uint32_t HI, typename T>
HWY_INLINE Vec128<uint32_t> LookupAndConcatHalves(Vec256<T> v) {
  const Full256<uint32_t> d32;

#if HWY_TARGET <= HWY_AVX3_DL
  // Cross-block byte permute: +0x10101010 selects from the upper block.
  alignas(32) static constexpr uint32_t kMap[8] = {
      LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0};
  const auto result = _mm256_permutexvar_epi8(Load(d32, kMap).raw, v.raw);
#else
  // TableLookupBytes stays within each 128-bit block, so place the selected
  // bytes in both blocks and merge with a 64-bit lane permute (0xCC).
  alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u,
                                                   ~0u, ~0u, LO, HI};
  const auto quad = TableLookupBytes(v, Load(d32, kMap));
  const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC);
  // Possible alternative:
  // const auto lo = LowerHalf(quad);
  // const auto hi = UpperHalf(Half<decltype(d32)>(), quad);
  // const auto result = lo | hi;
#endif

  return Vec128<uint32_t>{_mm256_castsi256_si128(result)};
}

// LO and HI each hold two indices of bytes within a 128-bit block.
template <uint16_t LO, uint16_t HI, typename T>
HWY_INLINE Vec128<uint32_t, 2> LookupAndConcatQuarters(Vec256<T> v) {
  const Full256<uint16_t> d16;

#if HWY_TARGET <= HWY_AVX3_DL
  alignas(32) static constexpr uint16_t kMap[16] = {
      LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  const auto result = _mm256_permutexvar_epi8(Load(d16, kMap).raw, v.raw);
  return LowerHalf(Vec128<uint32_t>{_mm256_castsi256_si128(result)});
#else
  constexpr uint16_t ff = static_cast<uint16_t>(~0u);
  alignas(32) static constexpr uint16_t kMap[16] = {
      LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff};
  const auto quad = TableLookupBytes(v, Load(d16, kMap));
  const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC);
  const auto half = _mm256_castsi256_si128(mixed);
  return LowerHalf(Vec128<uint32_t>{_mm_packus_epi32(half, half)});
#endif
}

}  // namespace detail

// u64x4 -> u8x4: keep the low byte of each lane.
template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_U8_D(D)>
HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  const Full256<uint32_t> d32;
#if HWY_TARGET <= HWY_AVX3_DL
  alignas(32) static constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0,
                                                   0, 0, 0, 0};
  const auto result = _mm256_permutexvar_epi8(Load(d32, kMap).raw, v.raw);
  return LowerHalf(LowerHalf(LowerHalf(Vec256<uint8_t>{result})));
#else
  alignas(32) static constexpr uint32_t kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u,
                                                   0x0800FFFFu, ~0u, ~0u, ~0u};
  const auto quad = TableLookupBytes(v, Load(d32, kMap));
  const auto lo = LowerHalf(quad);
  const auto hi = UpperHalf(Half<decltype(d32)>(), quad);
  const auto result = lo | hi;
  return LowerHalf(LowerHalf(Vec128<uint8_t>{result.raw}));
#endif
}

// u64x4 -> u16x4: keep the low 16 bits of each lane.
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U16_D(D)>
HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v);
  return VFromD<D>{result.raw};
}

// u64x4 -> u32x4: keep the low 32 bits of each lane (even u32 lanes).
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  const Full256<uint32_t> d32;
  alignas(32) static constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6};
  const auto v32 =
      TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven));
  return LowerHalf(Vec256<uint32_t>{v32.raw});
}

// u32x8 -> u8x8.
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)>
HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
  const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v);
  return VFromD<D>{full.raw};
}

// u32x8 -> u16x8.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
  const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v);
  return VFromD<D>{full.raw};
}

// u16x16 -> u8x16.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint16_t> v) {
  const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v);
  return VFromD<D>{full.raw};
}

// ------------------------------ Integer <=> fp (ShiftRight, OddEven)

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API VFromD<D> ConvertTo(D /* tag */, Vec256<uint16_t> v) {
  return VFromD<D>{_mm256_cvtepu16_ph(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API VFromD<D> ConvertTo(D /* tag */, Vec256<int16_t> v) {
  return VFromD<D>{_mm256_cvtepi16_ph(v.raw)};
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConvertTo(D /* tag */, Vec256<int32_t> v) {
  return VFromD<D>{_mm256_cvtepi32_ps(v.raw)};
}

#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConvertTo(D /*df*/, Vec256<uint32_t> v) {
  return VFromD<D>{_mm256_cvtepu32_ps(v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> ConvertTo(D /*dd*/, Vec256<int64_t> v) {
  return VFromD<D>{_mm256_cvtepi64_pd(v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> ConvertTo(D /*dd*/, Vec256<uint64_t> v) {
  return VFromD<D>{_mm256_cvtepu64_pd(v.raw)};
}
#endif  // HWY_TARGET <= HWY_AVX3

// Truncates (rounds toward zero).

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec256<float16_t> v) {
#if HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior in _mm256_cvttph_epi16 with GCC if any
  // values of v[i] are not within the range of an int16_t

#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \
    HWY_HAVE_SCALAR_F16_TYPE
  // Constant inputs: convert lane-by-lane so GCC can fold at compile time.
  if (detail::IsConstantX86VecForF2IConv<int16_t>(v)) {
    typedef hwy::float16_t::Native GccF16RawVectType
        __attribute__((__vector_size__(32)));
    const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw);
    return VFromD<D>{_mm256_setr_epi16(
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[0]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[1]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[2]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[3]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[4]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[5]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[6]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[7]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[8]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[9]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[10]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[11]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[12]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[13]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[14]),
        detail::X86ConvertScalarFromFloat<int16_t>(raw_v[15]))};
  }
#endif

  // Non-constant inputs: emit the instruction directly, bypassing the
  // intrinsic whose out-of-range behavior is undefined under GCC.
  __m256i raw_result;
  __asm__("vcvttph2w {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<D>{raw_result};
#else  // !HWY_COMPILER_GCC_ACTUAL
  return
VFromD<D>{_mm256_cvttph_epi16(v.raw)}; 7203 #endif 7204 } 7205 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> 7206 HWY_API VFromD<D> ConvertInRangeTo(D /* tag */, VFromD<RebindToFloat<D>> v) { 7207 #if HWY_COMPILER_GCC_ACTUAL 7208 // Workaround for undefined behavior in _mm256_cvttph_epu16 with GCC if any 7209 // values of v[i] are not within the range of an uint16_t 7210 7211 #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ 7212 HWY_HAVE_SCALAR_F16_TYPE 7213 if (detail::IsConstantX86VecForF2IConv<uint16_t>(v)) { 7214 typedef hwy::float16_t::Native GccF16RawVectType 7215 __attribute__((__vector_size__(32))); 7216 const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); 7217 return VFromD<D>{_mm256_setr_epi16( 7218 static_cast<int16_t>( 7219 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[0])), 7220 static_cast<int16_t>( 7221 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[1])), 7222 static_cast<int16_t>( 7223 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[2])), 7224 static_cast<int16_t>( 7225 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[3])), 7226 static_cast<int16_t>( 7227 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[4])), 7228 static_cast<int16_t>( 7229 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[5])), 7230 static_cast<int16_t>( 7231 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[6])), 7232 static_cast<int16_t>( 7233 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[7])), 7234 static_cast<int16_t>( 7235 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[8])), 7236 static_cast<int16_t>( 7237 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[9])), 7238 static_cast<int16_t>( 7239 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[10])), 7240 static_cast<int16_t>( 7241 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[11])), 7242 static_cast<int16_t>( 7243 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[12])), 7244 static_cast<int16_t>( 7245 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[13])), 7246 
static_cast<int16_t>(
            detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[14])),
        static_cast<int16_t>(
            detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[15])))};
  }
#endif

  __m256i raw_result;
  __asm__("vcvttph2uw {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<D>{raw_result};
#else  // !HWY_COMPILER_GCC_ACTUAL
  return VFromD<D>{_mm256_cvttph_epu16(v.raw)};
#endif
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec256<float> v) {
#if HWY_X86_HAVE_AVX10_2_OPS
  // AVX10.2 truncating convert is defined for all inputs; no workaround.
  return VFromD<D>{_mm256_cvtts_ps_epi32(v.raw)};
#elif HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior in _mm256_cvttps_epi32 with GCC if any
  // values of v[i] are not within the range of an int32_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  // Constant inputs: convert lane-by-lane so GCC can fold at compile time.
  if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) {
    typedef float GccF32RawVectType __attribute__((__vector_size__(32)));
    const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw);
    return VFromD<D>{_mm256_setr_epi32(
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[3]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[4]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[5]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[6]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7]))};
  }
#endif

  __m256i raw_result;
  __asm__("vcvttps2dq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return
VFromD<D>{raw_result}; 7295 #else 7296 return VFromD<D>{_mm256_cvttps_epi32(v.raw)}; 7297 #endif 7298 } 7299 7300 #if HWY_TARGET <= HWY_AVX3 7301 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)> 7302 HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec256<double> v) { 7303 #if HWY_X86_HAVE_AVX10_2_OPS 7304 return VFromD<D>{_mm256_cvtts_pd_epi64(v.raw)}; 7305 #elif HWY_COMPILER_GCC_ACTUAL 7306 // Workaround for undefined behavior in _mm256_cvttpd_epi64 with GCC if any 7307 // values of v[i] are not within the range of an int64_t 7308 7309 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 7310 if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { 7311 typedef double GccF64RawVectType __attribute__((__vector_size__(32))); 7312 const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); 7313 return VFromD<D>{_mm256_setr_epi64x( 7314 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[0]), 7315 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[1]), 7316 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[2]), 7317 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[3]))}; 7318 } 7319 #endif 7320 7321 __m256i raw_result; 7322 __asm__("vcvttpd2qq {%1, %0|%0, %1}" 7323 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7324 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7325 :); 7326 return VFromD<D>{raw_result}; 7327 #else // !HWY_COMPILER_GCC_ACTUAL 7328 return VFromD<D>{_mm256_cvttpd_epi64(v.raw)}; 7329 #endif // HWY_COMPILER_GCC_ACTUAL 7330 } 7331 template <class DU, HWY_IF_V_SIZE_D(DU, 32), HWY_IF_U32_D(DU)> 7332 HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) { 7333 #if HWY_X86_HAVE_AVX10_2_OPS 7334 return VFromD<DU>{_mm256_cvtts_ps_epu32(v.raw)}; 7335 #elif HWY_COMPILER_GCC_ACTUAL 7336 // Workaround for undefined behavior in _mm256_cvttps_epu32 with GCC if any 7337 // values of v[i] are not within the range of an uint32_t 7338 7339 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 7340 if 
(detail::IsConstantX86VecForF2IConv<uint32_t>(v)) { 7341 typedef float GccF32RawVectType __attribute__((__vector_size__(32))); 7342 const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); 7343 return VFromD<DU>{_mm256_setr_epi32( 7344 static_cast<int32_t>( 7345 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[0])), 7346 static_cast<int32_t>( 7347 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[1])), 7348 static_cast<int32_t>( 7349 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[2])), 7350 static_cast<int32_t>( 7351 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[3])), 7352 static_cast<int32_t>( 7353 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[4])), 7354 static_cast<int32_t>( 7355 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[5])), 7356 static_cast<int32_t>( 7357 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[6])), 7358 static_cast<int32_t>( 7359 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[7])))}; 7360 } 7361 #endif 7362 7363 __m256i raw_result; 7364 __asm__("vcvttps2udq {%1, %0|%0, %1}" 7365 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7366 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7367 :); 7368 return VFromD<DU>{raw_result}; 7369 #else // !HWY_COMPILER_GCC_ACTUAL 7370 return VFromD<DU>{_mm256_cvttps_epu32(v.raw)}; 7371 #endif // HWY_COMPILER_GCC_ACTUAL 7372 } 7373 template <class DU, HWY_IF_V_SIZE_D(DU, 32), HWY_IF_U64_D(DU)> 7374 HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) { 7375 #if HWY_X86_HAVE_AVX10_2_OPS 7376 return VFromD<DU>{_mm256_cvtts_pd_epu64(v.raw)}; 7377 #elif HWY_COMPILER_GCC_ACTUAL 7378 // Workaround for undefined behavior in _mm256_cvttpd_epu64 with GCC if any 7379 // values of v[i] are not within the range of an uint64_t 7380 7381 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 7382 if (detail::IsConstantX86VecForF2IConv<uint64_t>(v)) { 7383 typedef double GccF64RawVectType __attribute__((__vector_size__(32))); 7384 const auto raw_v = 
reinterpret_cast<GccF64RawVectType>(v.raw); 7385 return VFromD<DU>{_mm256_setr_epi64x( 7386 static_cast<int64_t>( 7387 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[0])), 7388 static_cast<int64_t>( 7389 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[1])), 7390 static_cast<int64_t>( 7391 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[2])), 7392 static_cast<int64_t>( 7393 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[3])))}; 7394 } 7395 #endif 7396 7397 __m256i raw_result; 7398 __asm__("vcvttpd2uqq {%1, %0|%0, %1}" 7399 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7400 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7401 :); 7402 return VFromD<DU>{raw_result}; 7403 #else // !HWY_COMPILER_GCC_ACTUAL 7404 return VFromD<DU>{_mm256_cvttpd_epu64(v.raw)}; 7405 #endif // HWY_COMPILER_GCC_ACTUAL 7406 } 7407 #endif // HWY_TARGET <= HWY_AVX3 7408 7409 template <class DI, HWY_IF_V_SIZE_D(DI, 32), HWY_IF_I32_D(DI)> 7410 static HWY_INLINE VFromD<DI> NearestIntInRange(DI, 7411 VFromD<RebindToFloat<DI>> v) { 7412 #if HWY_COMPILER_GCC_ACTUAL 7413 // Workaround for undefined behavior in _mm256_cvtps_epi32 if any values of 7414 // v[i] are not within the range of an int32_t 7415 7416 #if HWY_COMPILER_GCC >= 700 && !HWY_IS_DEBUG_BUILD 7417 if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { 7418 typedef float GccF32RawVectType __attribute__((__vector_size__(32))); 7419 const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); 7420 return VFromD<DI>{ 7421 _mm256_setr_epi32(detail::X86ScalarNearestInt<int32_t>(raw_v[0]), 7422 detail::X86ScalarNearestInt<int32_t>(raw_v[1]), 7423 detail::X86ScalarNearestInt<int32_t>(raw_v[2]), 7424 detail::X86ScalarNearestInt<int32_t>(raw_v[3]), 7425 detail::X86ScalarNearestInt<int32_t>(raw_v[4]), 7426 detail::X86ScalarNearestInt<int32_t>(raw_v[5]), 7427 detail::X86ScalarNearestInt<int32_t>(raw_v[6]), 7428 detail::X86ScalarNearestInt<int32_t>(raw_v[7]))}; 7429 } 7430 #endif 7431 7432 __m256i raw_result; 7433 
__asm__("vcvtps2dq {%1, %0|%0, %1}" 7434 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7435 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7436 :); 7437 return VFromD<DI>{raw_result}; 7438 #else // !HWY_COMPILER_GCC_ACTUAL 7439 return VFromD<DI>{_mm256_cvtps_epi32(v.raw)}; 7440 #endif // HWY_COMPILER_GCC_ACTUAL 7441 } 7442 7443 #if HWY_HAVE_FLOAT16 7444 template <class DI, HWY_IF_V_SIZE_D(DI, 32), HWY_IF_I16_D(DI)> 7445 static HWY_INLINE VFromD<DI> NearestIntInRange(DI /*d*/, Vec256<float16_t> v) { 7446 #if HWY_COMPILER_GCC_ACTUAL 7447 // Workaround for undefined behavior in _mm256_cvtph_epi16 with GCC if any 7448 // values of v[i] are not within the range of an int16_t 7449 7450 #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ 7451 HWY_HAVE_SCALAR_F16_TYPE 7452 if (detail::IsConstantX86VecForF2IConv<int16_t>(v)) { 7453 typedef hwy::float16_t::Native GccF16RawVectType 7454 __attribute__((__vector_size__(32))); 7455 const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); 7456 return VFromD<DI>{ 7457 _mm256_setr_epi16(detail::X86ScalarNearestInt<int16_t>(raw_v[0]), 7458 detail::X86ScalarNearestInt<int16_t>(raw_v[1]), 7459 detail::X86ScalarNearestInt<int16_t>(raw_v[2]), 7460 detail::X86ScalarNearestInt<int16_t>(raw_v[3]), 7461 detail::X86ScalarNearestInt<int16_t>(raw_v[4]), 7462 detail::X86ScalarNearestInt<int16_t>(raw_v[5]), 7463 detail::X86ScalarNearestInt<int16_t>(raw_v[6]), 7464 detail::X86ScalarNearestInt<int16_t>(raw_v[7]), 7465 detail::X86ScalarNearestInt<int16_t>(raw_v[8]), 7466 detail::X86ScalarNearestInt<int16_t>(raw_v[9]), 7467 detail::X86ScalarNearestInt<int16_t>(raw_v[10]), 7468 detail::X86ScalarNearestInt<int16_t>(raw_v[11]), 7469 detail::X86ScalarNearestInt<int16_t>(raw_v[12]), 7470 detail::X86ScalarNearestInt<int16_t>(raw_v[13]), 7471 detail::X86ScalarNearestInt<int16_t>(raw_v[14]), 7472 detail::X86ScalarNearestInt<int16_t>(raw_v[15]))}; 7473 } 7474 #endif 7475 7476 __m256i raw_result; 7477 __asm__("vcvtph2w {%1, %0|%0, 
%1}" 7478 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7479 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7480 :); 7481 return VFromD<DI>{raw_result}; 7482 #else // HWY_COMPILER_GCC_ACTUAL 7483 return VFromD<DI>{_mm256_cvtph_epi16(v.raw)}; 7484 #endif 7485 } 7486 #endif 7487 7488 #if HWY_TARGET <= HWY_AVX3 7489 template <class DI, HWY_IF_V_SIZE_D(DI, 32), HWY_IF_I64_D(DI)> 7490 static HWY_INLINE VFromD<DI> NearestIntInRange(DI, 7491 VFromD<RebindToFloat<DI>> v) { 7492 #if HWY_COMPILER_GCC_ACTUAL 7493 // Workaround for undefined behavior in _mm256_cvtpd_epi64 with GCC if any 7494 // values of v[i] are not within the range of an int64_t 7495 7496 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 7497 if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { 7498 typedef double GccF64RawVectType __attribute__((__vector_size__(32))); 7499 const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); 7500 return VFromD<DI>{ 7501 _mm256_setr_epi64x(detail::X86ScalarNearestInt<int64_t>(raw_v[0]), 7502 detail::X86ScalarNearestInt<int64_t>(raw_v[1]), 7503 detail::X86ScalarNearestInt<int64_t>(raw_v[2]), 7504 detail::X86ScalarNearestInt<int64_t>(raw_v[3]))}; 7505 } 7506 #endif 7507 7508 __m256i raw_result; 7509 __asm__("vcvtpd2qq {%1, %0|%0, %1}" 7510 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7511 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7512 :); 7513 return VFromD<DI>{raw_result}; 7514 #else // !HWY_COMPILER_GCC_ACTUAL 7515 return VFromD<DI>{_mm256_cvtpd_epi64(v.raw)}; 7516 #endif // HWY_COMPILER_GCC_ACTUAL 7517 } 7518 #endif // HWY_TARGET <= HWY_AVX3 7519 7520 template <class DI, HWY_IF_V_SIZE_D(DI, 16), HWY_IF_I32_D(DI)> 7521 static HWY_INLINE VFromD<DI> DemoteToNearestIntInRange( 7522 DI, VFromD<Rebind<double, DI>> v) { 7523 #if HWY_COMPILER_GCC_ACTUAL 7524 // Workaround for undefined behavior in _mm256_cvtpd_epi32 with GCC if any 7525 // values of v[i] are not within the range of an int32_t 7526 7527 #if HWY_COMPILER_GCC_ACTUAL >= 700 
&& !HWY_IS_DEBUG_BUILD 7528 if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { 7529 typedef double GccF32RawVectType __attribute__((__vector_size__(32))); 7530 const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); 7531 return Dup128VecFromValues(DI(), 7532 detail::X86ScalarNearestInt<int32_t>(raw_v[0]), 7533 detail::X86ScalarNearestInt<int32_t>(raw_v[1]), 7534 detail::X86ScalarNearestInt<int32_t>(raw_v[2]), 7535 detail::X86ScalarNearestInt<int32_t>(raw_v[3])); 7536 } 7537 #endif 7538 7539 __m128i raw_result; 7540 __asm__("vcvtpd2dq {%1, %0|%0, %1}" 7541 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 7542 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 7543 :); 7544 return VFromD<DI>{raw_result}; 7545 #else // !HWY_COMPILER_GCC_ACTUAL 7546 return VFromD<DI>{_mm256_cvtpd_epi32(v.raw)}; 7547 #endif 7548 } 7549 7550 #ifndef HWY_DISABLE_F16C 7551 7552 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 7553 HWY_API VFromD<D> PromoteTo(D df32, Vec128<float16_t> v) { 7554 (void)df32; 7555 #if HWY_HAVE_FLOAT16 7556 const RebindToUnsigned<DFromV<decltype(v)>> du16; 7557 return VFromD<D>{_mm256_cvtph_ps(BitCast(du16, v).raw)}; 7558 #else 7559 return VFromD<D>{_mm256_cvtph_ps(v.raw)}; 7560 #endif // HWY_HAVE_FLOAT16 7561 } 7562 7563 #endif // HWY_DISABLE_F16C 7564 7565 #if HWY_HAVE_FLOAT16 7566 7567 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)> 7568 HWY_INLINE VFromD<D> PromoteTo(D /*tag*/, Vec64<float16_t> v) { 7569 return VFromD<D>{_mm256_cvtph_pd(v.raw)}; 7570 } 7571 7572 #endif // HWY_HAVE_FLOAT16 7573 7574 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 7575 HWY_API VFromD<D> PromoteTo(D df32, Vec128<bfloat16_t> v) { 7576 const Rebind<uint16_t, decltype(df32)> du16; 7577 const RebindToSigned<decltype(df32)> di32; 7578 return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); 7579 } 7580 7581 // ================================================== CRYPTO 7582 7583 #if !defined(HWY_DISABLE_PCLMUL_AES) 7584 
// One AES encryption round (per 128-bit block). On AVX3_DL, VAES handles both
// blocks in one instruction; otherwise recurse on 128-bit halves.
HWY_API Vec256<uint8_t> AESRound(Vec256<uint8_t> state,
                                 Vec256<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<uint8_t>{_mm256_aesenc_epi128(state.raw, round_key.raw)};
#else
  const Full256<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESRound(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// Final AES encryption round (no MixColumns), same dispatch as AESRound.
HWY_API Vec256<uint8_t> AESLastRound(Vec256<uint8_t> state,
                                     Vec256<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<uint8_t>{_mm256_aesenclast_epi128(state.raw, round_key.raw)};
#else
  const Full256<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(d,
                 AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESLastRound(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// One AES decryption round.
HWY_API Vec256<uint8_t> AESRoundInv(Vec256<uint8_t> state,
                                    Vec256<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<uint8_t>{_mm256_aesdec_epi128(state.raw, round_key.raw)};
#else
  const Full256<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESRoundInv(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// Final AES decryption round.
HWY_API Vec256<uint8_t> AESLastRoundInv(Vec256<uint8_t> state,
                                        Vec256<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<uint8_t>{_mm256_aesdeclast_epi128(state.raw, round_key.raw)};
#else
  const Full256<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(
      d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)),
      AESLastRoundInv(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// AES InvMixColumns on each 128-bit block of a >=256-bit u8 vector.
template <class V, HWY_IF_V_SIZE_GT_V(V, 16), HWY_IF_U8_D(DFromV<V>)>
HWY_API V AESInvMixColumns(V state) {
  const DFromV<decltype(state)> d;
#if HWY_TARGET <= HWY_AVX3_DL
  // On AVX3_DL, it is more efficient to do an InvMixColumns operation for a
  // 256-bit or 512-bit vector by doing a AESLastRound operation
  // (_mm256_aesenclast_epi128/_mm512_aesenclast_epi128) followed by a
  // AESRoundInv operation (_mm256_aesdec_epi128/_mm512_aesdec_epi128) than to
  // split the vector into 128-bit vectors, carrying out multiple
  // _mm_aesimc_si128 operations, and then combining the _mm_aesimc_si128
  // results back into a 256-bit or 512-bit vector.
  const auto zero = Zero(d);
  return AESRoundInv(AESLastRound(state, zero), zero);
#else
  const Half<decltype(d)> dh;
  return Combine(d, AESInvMixColumns(UpperHalf(dh, state)),
                 AESInvMixColumns(LowerHalf(dh, state)));
#endif
}

// Equivalent of AESKEYGENASSIST with round constant kRcon, per 128-bit block.
// On AVX3_DL it is emulated via AESLastRound (for SubWord) plus byte shuffles.
template <uint8_t kRcon>
HWY_API Vec256<uint8_t> AESKeyGenAssist(Vec256<uint8_t> v) {
  const Full256<uint8_t> d;
#if HWY_TARGET <= HWY_AVX3_DL
  // XOR mask applying kRcon to the two result words that require it.
  const VFromD<decltype(d)> rconXorMask = Dup128VecFromValues(
      d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0);
  // Byte shuffle implementing the RotWord step of the key schedule.
  const VFromD<decltype(d)> rotWordShuffle = Dup128VecFromValues(
      d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12);
  const Repartition<uint32_t, decltype(d)> du32;
  // Duplicate odd u32 lanes (words 1 and 3) which feed SubWord/RotWord.
  const auto w13 = BitCast(d, DupOdd(BitCast(du32, v)));
  const auto sub_word_result = AESLastRound(w13, rconXorMask);
  return TableLookupBytes(sub_word_result, rotWordShuffle);
#else
  const Half<decltype(d)> d2;
  return Combine(d, AESKeyGenAssist<kRcon>(UpperHalf(d2, v)),
                 AESKeyGenAssist<kRcon>(LowerHalf(v)));
#endif
}

// Carry-less multiply of the lower u64 of each 128-bit block.
HWY_API Vec256<uint64_t> CLMulLower(Vec256<uint64_t> a, Vec256<uint64_t> b) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<uint64_t>{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)};
#else
  const Full256<uint64_t> d;
  const Half<decltype(d)> d2;
  return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)),
                 CLMulLower(LowerHalf(a), LowerHalf(b)));
#endif
}

// Carry-less multiply of the upper u64 of each 128-bit block.
HWY_API Vec256<uint64_t> CLMulUpper(Vec256<uint64_t> a, Vec256<uint64_t> b) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec256<uint64_t>{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)};
#else
  const Full256<uint64_t> d;
  const Half<decltype(d)> d2;
  return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)),
                 CLMulUpper(LowerHalf(a), LowerHalf(b)));
#endif
}

#endif  // HWY_DISABLE_PCLMUL_AES

// ================================================== MISC

#if HWY_TARGET <= HWY_AVX3

// ------------------------------ LoadMaskBits

// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  constexpr size_t kN = MaxLanes(d);
  constexpr size_t kNumBytes = (kN + 7) / 8;

  uint64_t mask_bits = 0;
  CopyBytes<kNumBytes>(bits, &mask_bits);

  // Fewer than 8 lanes: discard bits beyond the last lane.
  if (kN < 8) {
    mask_bits &= (1ull << kN) - 1;
  }

  return MFromD<D>::FromBits(mask_bits);
}

// ------------------------------ StoreMaskBits

// `p` points to at least 8 writable bytes. Returns the number of bytes
// written.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t StoreMaskBits(D d, MFromD<D> mask, uint8_t* bits) {
  constexpr size_t kN = MaxLanes(d);
  constexpr size_t kNumBytes = (kN + 7) / 8;

  CopyBytes<kNumBytes>(&mask.raw, bits);

  // Non-full byte, need to clear the undefined upper bits.
  if (kN < 8) {
    const int mask_bits = static_cast<int>((1ull << kN) - 1);
    bits[0] = static_cast<uint8_t>(bits[0] & mask_bits);
  }
  return kNumBytes;
}

// ------------------------------ Mask testing

// AVX3 masks are bitmasks in a mask register, so these reduce to scalar bit
// operations on mask.raw.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t CountTrue(D /* tag */, MFromD<D> mask) {
  return PopCount(static_cast<uint64_t>(mask.raw));
}

// Precondition: at least one mask bit is set.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD<D> mask) {
  return Num0BitsBelowLS1Bit_Nonzero32(mask.raw);
}

// Returns index of first true lane, or -1 if none.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) {
  return mask.raw ? static_cast<intptr_t>(FindKnownFirstTrue(d, mask))
                  : intptr_t{-1};
}

// Precondition: at least one mask bit is set.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD<D> mask) {
  return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw);
}

// Returns index of last true lane, or -1 if none.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
  return mask.raw ? static_cast<intptr_t>(FindKnownLastTrue(d, mask))
                  : intptr_t{-1};
}

// Beware: the suffix indicates the number of mask bits, not lane size!
namespace detail {

// AllFalse per lane size: 32/16/8-bit lane masks have 8-32 bits; use kortest
// where mask intrinsics are available, else compare the raw bits.
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask32_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask16_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask8_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256<T> mask) {
  // Only 4 lanes; mask off the unused upper bits before testing.
  return (uint64_t{mask.raw} & 0xF) == 0;
}

}  // namespace detail

// Dispatches to the lane-size-specific detail::AllFalse above.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API bool AllFalse(D /* tag */, MFromD<D> mask) {
  return detail::AllFalse(hwy::SizeTag<sizeof(TFromD<D>)>(), mask);
}

namespace detail {

// AllTrue per lane size: kortestc checks all bits set where available.
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask32_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFFFFFFFu;
#endif
}
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask16_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFFFu;
#endif
}
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask8_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFu;
#endif
}
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256<T> mask) {
  // Cannot use _kortestc because we have less than 8 mask bits.
  return mask.raw == 0xFu;
}

}  // namespace detail

// Dispatches to the lane-size-specific detail::AllTrue above.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API bool AllTrue(D /* tag */, const MFromD<D> mask) {
  return detail::AllTrue(hwy::SizeTag<sizeof(TFromD<D>)>(), mask);
}

// ------------------------------ Compress

// 16-bit is defined in x86_512 so we can use 512-bit vectors.

template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> mask) {
  return Vec256<T>{_mm256_maskz_compress_epi32(mask.raw, v.raw)};
}

HWY_API Vec256<float> Compress(Vec256<float> v, Mask256<float> mask) {
  return Vec256<float>{_mm256_maskz_compress_ps(mask.raw, v.raw)};
}

// 64-bit lanes: no compress instruction, so permute via a nibble-packed LUT
// of lane indices keyed by the 4-bit mask.
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> mask) {
  // See CompressIsPartition.
  alignas(16) static constexpr uint64_t packed_array[16] = {
      // PrintCompress64x4NibbleTables
      0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120,
      0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310,
      0x00001032, 0x00001320, 0x00000321, 0x00003210};

  // For lane i, shift the i-th 4-bit index down to bits [0, 2) -
  // _mm256_permutexvar_epi64 will ignore the upper bits.
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du64;
  const auto packed = Set(du64, packed_array[mask.raw]);
  alignas(64) static constexpr uint64_t shifts[4] = {0, 4, 8, 12};
  const auto indices = Indices256<T>{(packed >> Load(du64, shifts)).raw};
  return TableLookupLanes(v, indices);
}

// ------------------------------ CompressNot (Compress)

// Implemented in x86_512 for lane size != 8.
// Compress with the mask complemented: true lanes move to the end. Same
// LUT-permute scheme as 64-bit Compress, with the complementary table.
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> CompressNot(Vec256<T> v, Mask256<T> mask) {
  // See CompressIsPartition.
  alignas(16) static constexpr uint64_t packed_array[16] = {
      // PrintCompressNot64x4NibbleTables
      0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031,
      0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102,
      0x00003210, 0x00003201, 0x00003210, 0x00003210};

  // For lane i, shift the i-th 4-bit index down to bits [0, 2) -
  // _mm256_permutexvar_epi64 will ignore the upper bits.
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du64;
  const auto packed = Set(du64, packed_array[mask.raw]);
  alignas(32) static constexpr uint64_t shifts[4] = {0, 4, 8, 12};
  const auto indices = Indices256<T>{(packed >> Load(du64, shifts)).raw};
  return TableLookupLanes(v, indices);
}

// ------------------------------ CompressStore (defined in x86_512)
// ------------------------------ CompressBlendedStore (defined in x86_512)
// ------------------------------ CompressBitsStore (defined in x86_512)

#else  // AVX2

// ------------------------------ LoadMaskBits (TestBit)

namespace detail {

// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_V_SIZE.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
  const Full256<T> d;
  const RebindToUnsigned<decltype(d)> du;
  const Repartition<uint32_t, decltype(d)> du32;
  const auto vbits = BitCast(du, Set(du32, static_cast<uint32_t>(mask_bits)));

  // Replicate bytes 8x such that each byte contains the bit that governs it.
  const Repartition<uint64_t, decltype(d)> du64;
  alignas(32) static constexpr uint64_t kRep8[4] = {
      0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull,
      0x0303030303030303ull};
  const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8)));

  // Each byte tests its own bit within the replicated mask byte.
  const VFromD<decltype(du)> bit = Dup128VecFromValues(
      du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128);
  return RebindMask(d, TestBit(rep8, bit));
}

template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
  const Full256<T> d;
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint16_t kBit[16] = {
      1,     2,     4,     8,     16,     32,     64,     128,
      0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
  const auto vmask_bits = Set(du, static_cast<uint16_t>(mask_bits));
  return RebindMask(d, TestBit(vmask_bits, Load(du, kBit)));
}

template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
  const Full256<T> d;
  const RebindToUnsigned<decltype(d)> du;
  alignas(32) static constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128};
  const auto vmask_bits = Set(du, static_cast<uint32_t>(mask_bits));
  return RebindMask(d, TestBit(vmask_bits, Load(du, kBit)));
}

template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
  const Full256<T> d;
  const RebindToUnsigned<decltype(d)> du;
  // Only 4 lanes; the remaining array elements are zero-initialized.
  alignas(32) static constexpr uint64_t kBit[8] = {1, 2, 4, 8};
  return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit)));
}

}  // namespace detail

// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  constexpr size_t kN = MaxLanes(d);
  constexpr size_t kNumBytes = (kN + 7) / 8;

  uint64_t mask_bits = 0;
  CopyBytes<kNumBytes>(bits, &mask_bits);

  // Fewer than 8 lanes: discard bits beyond the last lane.
  if (kN < 8) {
    mask_bits &= (1ull << kN) - 1;
  }

  return detail::LoadMaskBits256<TFromD<D>>(mask_bits);
}

// ------------------------------ BitsFromMask

// Returns one bit per lane (bit i = lane i), extracted from the lane sign
// bits via movemask.
template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_D(D, 32)>
HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  const RebindToUnsigned<D> d8;
  const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw;
  // Prevent sign-extension of 32-bit masks because the intrinsic returns int.
  return static_cast<uint32_t>(_mm256_movemask_epi8(sign_bits));
}

template <class D, HWY_IF_T_SIZE_D(D, 2), HWY_IF_V_SIZE_D(D, 32)>
HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
#if !defined(HWY_DISABLE_BMI2_FMA) && !defined(HWY_DISABLE_PEXT_ON_AVX2)
  const Repartition<uint8_t, D> d8;
  const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask)));
  const uint64_t sign_bits8 = BitsFromMask(d8, mask8);
  // Skip the bits from the lower byte of each u16 (better not to use the
  // same packs_epi16 as SSE4, because that requires an extra swizzle here).
  return _pext_u32(static_cast<uint32_t>(sign_bits8), 0xAAAAAAAAu);
#else
  // Slow workaround for when BMI2 is disabled
  // Remove useless lower half of each u16 while preserving the sign bit.
  // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes.
  const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256());
  // Move odd qwords (value zero) to top so they don't affect the mask value.
  const auto compressed = _mm256_castsi256_si128(
      _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)));
  return static_cast<unsigned>(_mm_movemask_epi8(compressed));
#endif  // !HWY_DISABLE_BMI2_FMA && !HWY_DISABLE_PEXT_ON_AVX2
}

template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_V_SIZE_D(D, 32)>
HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  const RebindToFloat<D> df;
  const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw;
  return static_cast<unsigned>(_mm256_movemask_ps(sign_bits));
}

template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_V_SIZE_D(D, 32)>
HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  const RebindToFloat<D> df;
  const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw;
  return static_cast<unsigned>(_mm256_movemask_pd(sign_bits));
}

// ------------------------------ StoreMaskBits
// `p` points to at least 8 writable bytes.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t StoreMaskBits(D d, MFromD<D> mask, uint8_t* bits) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  HWY_LANES_CONSTEXPR size_t kNumBytes = (N + 7) / 8;

  const uint64_t mask_bits = BitsFromMask(d, mask);
  CopyBytes(&mask_bits, bits, kNumBytes);
  return kNumBytes;
}

// ------------------------------ Mask testing

// Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask
// lane is 0 or ~0.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API bool AllFalse(D d, MFromD<D> mask) {
  const Repartition<uint8_t, decltype(d)> d8;
  const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask)));
  return BitsFromMask(d8, mask8) == 0;
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 2)>
HWY_API bool AllFalse(D d, MFromD<D> mask) {
  // Cheaper than PTEST, which is 2 uop / 3L.
  return BitsFromMask(d, mask) == 0;
}

// 16-bit lanes: test the byte-level mask (2 bits per lane) to avoid pext.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API bool AllTrue(D d, MFromD<D> mask) {
  const Repartition<uint8_t, decltype(d)> d8;
  const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask)));
  return BitsFromMask(d8, mask8) == (1ull << 32) - 1;
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 2)>
HWY_API bool AllTrue(D d, MFromD<D> mask) {
  constexpr uint64_t kAllBits = (1ull << MaxLanes(d)) - 1;
  return BitsFromMask(d, mask) == kAllBits;
}

// 16-bit lanes: each lane contributes 2 byte-mask bits, hence the >> 1.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API size_t CountTrue(D d, MFromD<D> mask) {
  const Repartition<uint8_t, decltype(d)> d8;
  const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask)));
  return PopCount(BitsFromMask(d8, mask8)) >> 1;
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 2)>
HWY_API size_t CountTrue(D d, MFromD<D> mask) {
  return PopCount(BitsFromMask(d, mask));
}

// Precondition: at least one mask bit is set.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t FindKnownFirstTrue(D d, MFromD<D> mask) {
  const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
  return Num0BitsBelowLS1Bit_Nonzero32(mask_bits);
}

// Returns index of first true lane, or -1 if none.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) {
  const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
  return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1;
}

// Precondition: at least one mask bit is set.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API size_t FindKnownLastTrue(D d, MFromD<D> mask) {
  const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
  return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits);
}

// Returns index of last true lane, or -1 if none.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
  const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
  return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits))
                   : -1;
}

// ------------------------------ Compress, CompressBits

namespace detail {

// Returns the permutation indices (one u32 per lane) that gather the
// mask-selected lanes to the front, for an 8-bit mask over 8 x 32-bit lanes.
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_INLINE Vec256<uint32_t> IndicesFromBits256(uint64_t mask_bits) {
  const Full256<uint32_t> d32;
  // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT
  // of SetTableIndices would require 8 KiB, a large part of L1D. The other
  // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles)
  // and unavailable in 32-bit builds. We instead compress each index into 4
  // bits, for a total of 1 KiB.
8111 alignas(16) static constexpr uint32_t packed_array[256] = { 8112 // PrintCompress32x8Tables 8113 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, 8114 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, 8115 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, 8116 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, 8117 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, 8118 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, 8119 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, 8120 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, 8121 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, 8122 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, 8123 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, 8124 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, 8125 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, 8126 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, 8127 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, 8128 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, 8129 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, 8130 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, 8131 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, 8132 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, 8133 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, 8134 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, 8135 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, 8136 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, 8137 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 
0x65310fca, 0x6531fca8, 8138 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, 8139 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, 8140 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, 8141 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, 8142 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, 8143 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, 8144 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, 8145 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, 8146 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, 8147 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, 8148 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, 8149 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, 8150 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, 8151 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, 8152 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, 8153 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, 8154 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, 8155 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; 8156 8157 // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. 8158 // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. 8159 // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing 8160 // latency, it may be faster to use LoadDup128 and PSHUFB. 
8161 const auto packed = Set(d32, packed_array[mask_bits]); 8162 alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, 8163 16, 20, 24, 28}; 8164 return packed >> Load(d32, shifts); 8165 } 8166 8167 template <typename T, HWY_IF_T_SIZE(T, 8)> 8168 HWY_INLINE Vec256<uint32_t> IndicesFromBits256(uint64_t mask_bits) { 8169 const Full256<uint32_t> d32; 8170 8171 // For 64-bit, we still need 32-bit indices because there is no 64-bit 8172 // permutevar, but there are only 4 lanes, so we can afford to skip the 8173 // unpacking and load the entire index vector directly. 8174 alignas(32) static constexpr uint32_t u32_indices[128] = { 8175 // PrintCompress64x4PairTables 8176 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8177 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, 8178 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, 8179 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, 8180 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, 8181 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, 8182 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, 8183 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; 8184 return Load(d32, u32_indices + 8 * mask_bits); 8185 } 8186 8187 template <typename T, HWY_IF_T_SIZE(T, 4)> 8188 HWY_INLINE Vec256<uint32_t> IndicesFromNotBits256(uint64_t mask_bits) { 8189 const Full256<uint32_t> d32; 8190 // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT 8191 // of SetTableIndices would require 8 KiB, a large part of L1D. The other 8192 // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) 8193 // and unavailable in 32-bit builds. We instead compress each index into 4 8194 // bits, for a total of 1 KiB. 
8195 alignas(16) static constexpr uint32_t packed_array[256] = { 8196 // PrintCompressNot32x8Tables 8197 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, 8198 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, 8199 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, 8200 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, 8201 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, 8202 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, 8203 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, 8204 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, 8205 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, 8206 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, 8207 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, 8208 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, 8209 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, 8210 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, 8211 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, 8212 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, 8213 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, 8214 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, 8215 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, 8216 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, 8217 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, 8218 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, 8219 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, 8220 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, 8221 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 
0xfc98edba, 0xfcaedb98, 0xfca8edb9, 8222 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, 8223 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, 8224 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, 8225 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, 8226 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, 8227 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, 8228 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, 8229 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, 8230 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, 8231 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, 8232 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, 8233 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, 8234 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, 8235 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, 8236 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, 8237 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, 8238 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, 8239 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; 8240 8241 // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. 8242 // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. 8243 // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing 8244 // latency, it may be faster to use LoadDup128 and PSHUFB. 
  // Tail of the 4-byte-lane index builder: each copy of the 32-bit LUT entry
  // is shifted so that lane i receives its own 4-bit index in the low nibble.
  const Vec256<uint32_t> packed = Set(d32, packed_array[mask_bits]);
  alignas(32) static constexpr uint32_t shifts[8] = {0,  4,  8,  12,
                                                     16, 20, 24, 28};
  return packed >> Load(d32, shifts);
}

// Returns u32 permutation indices (as lane pairs) that move the mask=false
// 64-bit lanes ahead of the mask=true lanes, i.e. the CompressNot ordering.
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_INLINE Vec256<uint32_t> IndicesFromNotBits256(uint64_t mask_bits) {
  const Full256<uint32_t> d32;

  // For 64-bit, we still need 32-bit indices because there is no 64-bit
  // permutevar, but there are only 4 lanes, so we can afford to skip the
  // unpacking and load the entire index vector directly.
  alignas(32) static constexpr uint32_t u32_indices[128] = {
      // PrintCompressNot64x4PairTables
      8, 9,  10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8,  9,
      8, 9,  12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8,  9,  10, 11,
      8, 9,  10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8,  9,  12, 13,
      8, 9,  14, 15, 10, 11, 12, 13, 14, 15, 8,  9,  10, 11, 12, 13,
      8, 9,  10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8,  9,  14, 15,
      8, 9,  12, 13, 10, 11, 14, 15, 12, 13, 8,  9,  10, 11, 14, 15,
      8, 9,  10, 11, 12, 13, 14, 15, 10, 11, 8,  9,  12, 13, 14, 15,
      8, 9,  10, 11, 12, 13, 14, 15, 8,  9,  10, 11, 12, 13, 14, 15};
  // Each of the 16 possible 4-bit masks selects one block of 8 u32 indices.
  return Load(d32, u32_indices + 8 * mask_bits);
}

// Compress for 4/8-byte lanes: one cross-lane u32 table lookup moves all
// mask=true lanes to the front.
template <typename T, HWY_IF_NOT_T_SIZE(T, 2)>
HWY_INLINE Vec256<T> Compress(Vec256<T> v, const uint64_t mask_bits) {
  const DFromV<decltype(v)> d;
  const Repartition<uint32_t, decltype(d)> du32;

  HWY_DASSERT(mask_bits < (1ull << Lanes(d)));
  // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is
  // no instruction for 4x64).
  const Indices256<uint32_t> indices{IndicesFromBits256<T>(mask_bits).raw};
  return BitCast(d, TableLookupLanes(BitCast(du32, v), indices));
}

// LUTs are infeasible for 2^16 possible masks, so splice together two
// half-vector Compress.
// Compress for 16-bit lanes: LUTs are infeasible for 2^16 masks, so compress
// each 128-bit half independently and splice the results through memory.
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE Vec256<T> Compress(Vec256<T> v, const uint64_t mask_bits) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const auto vu16 = BitCast(du, v);  // (required for float16_t inputs)
  const Half<decltype(du)> duh;
  const auto half0 = LowerHalf(duh, vu16);
  const auto half1 = UpperHalf(duh, vu16);

  // Low/high 8 mask bits correspond to the low/high 128-bit halves.
  const uint64_t mask_bits0 = mask_bits & 0xFF;
  const uint64_t mask_bits1 = mask_bits >> 8;
  const auto compressed0 = detail::CompressBits(half0, mask_bits0);
  const auto compressed1 = detail::CompressBits(half1, mask_bits1);

  alignas(32) uint16_t all_true[16] = {};
  // Store mask=true lanes, left to right.
  const size_t num_true0 = PopCount(mask_bits0);
  Store(compressed0, duh, all_true);
  StoreU(compressed1, duh, all_true + num_true0);

  if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value) {
    // Store mask=false lanes, right to left. The second vector fills the upper
    // half with right-aligned false lanes. The first vector is shifted
    // rightwards to overwrite the true lanes of the second.
    alignas(32) uint16_t all_false[16] = {};
    const size_t num_true1 = PopCount(mask_bits1);
    Store(compressed1, duh, all_false + 8);
    StoreU(compressed0, duh, all_false + num_true1);

    const auto mask = FirstN(du, num_true0 + num_true1);
    return BitCast(d,
                   IfThenElse(mask, Load(du, all_true), Load(du, all_false)));
  } else {
    // Only care about the mask=true lanes.
    return BitCast(d, Load(du, all_true));
  }
}

// CompressNot for 4/8-byte lanes, using the Not-variant index LUT.
template <typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))>
HWY_INLINE Vec256<T> CompressNot(Vec256<T> v, const uint64_t mask_bits) {
  const DFromV<decltype(v)> d;
  const Repartition<uint32_t, decltype(d)> du32;

  HWY_DASSERT(mask_bits < (1ull << Lanes(d)));
  // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is
  // no instruction for 4x64).
  const Indices256<uint32_t> indices{IndicesFromNotBits256<T>(mask_bits).raw};
  return BitCast(d, TableLookupLanes(BitCast(du32, v), indices));
}

// LUTs are infeasible for 2^16 possible masks, so splice together two
// half-vector Compress.
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE Vec256<T> CompressNot(Vec256<T> v, const uint64_t mask_bits) {
  // Compress ensures only the lower 16 bits are set, so flip those.
  return Compress(v, mask_bits ^ 0xFFFF);
}

}  // namespace detail

// Public Compress: converts the mask to bits, then dispatches on lane size.
template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> m) {
  const DFromV<decltype(v)> d;
  return detail::Compress(v, BitsFromMask(d, m));
}

template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_API Vec256<T> CompressNot(Vec256<T> v, Mask256<T> m) {
  const DFromV<decltype(v)> d;
  return detail::CompressNot(v, BitsFromMask(d, m));
}

// Delegates to CompressNot: its 8-byte table (PrintCompressNot64x4PairTables)
// moves u32 lanes in adjacent pairs, i.e. whole 64-bit lanes / block halves.
HWY_API Vec256<uint64_t> CompressBlocksNot(Vec256<uint64_t> v,
                                           Mask256<uint64_t> mask) {
  return CompressNot(v, mask);
}

// Compresses using mask bits packed into ceil(N/8) bytes of memory.
template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_API Vec256<T> CompressBits(Vec256<T> v, const uint8_t* HWY_RESTRICT bits) {
  constexpr size_t N = 32 / sizeof(T);
  constexpr size_t kNumBytes = (N + 7) / 8;

  uint64_t mask_bits = 0;
  CopyBytes<kNumBytes>(bits, &mask_bits);

  // Clear any bits past the lane count (e.g. only 4 lanes for 8-byte T).
  if (N < 8) {
    mask_bits &= (1ull << N) - 1;
  }

  return detail::Compress(v, mask_bits);
}

// ------------------------------ CompressStore, CompressBitsStore

// Stores the compressed mask=true lanes to unaligned; returns their count.
// May write a full vector; see CompressBlendedStore to avoid that.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressStore(VFromD<D> v, MFromD<D> m, D d,
                             TFromD<D>* HWY_RESTRICT unaligned) {
  const uint64_t mask_bits = BitsFromMask(d, m);
  const size_t count = PopCount(mask_bits);
  StoreU(detail::Compress(v, mask_bits), d, unaligned);
  detail::MaybeUnpoison(unaligned, count);
  return count;
}

// As CompressStore, but only the first `count` lanes are written (blended).
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  const uint64_t mask_bits = BitsFromMask(d, m);
  const size_t count = PopCount(mask_bits);

  const RebindToUnsigned<decltype(d)> du;
  const Repartition<uint32_t, decltype(d)> du32;
  HWY_DASSERT(mask_bits < (1ull << Lanes(d)));
  // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is
  // no instruction for 4x64). Nibble MSB encodes FirstN.
  const Vec256<uint32_t> idx_mask =
      detail::IndicesFromBits256<TFromD<D>>(mask_bits);
  // Shift nibble MSB into MSB
  const Mask256<uint32_t> mask32 = MaskFromVec(ShiftLeft<28>(idx_mask));
  // First cast to unsigned (RebindMask cannot change lane size)
  const MFromD<decltype(du)> mask_u{mask32.raw};
  const MFromD<D> mask = RebindMask(d, mask_u);
  const VFromD<D> compressed = BitCast(
      d,
      TableLookupLanes(BitCast(du32, v), Indices256<uint32_t>{idx_mask.raw}));

  BlendedStore(compressed, mask, d, unaligned);
  detail::MaybeUnpoison(unaligned, count);
  return count;
}

// 16-bit lanes: compress first, then store only the first `count` lanes.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  const uint64_t mask_bits = BitsFromMask(d, m);
  const size_t count = PopCount(mask_bits);
  const VFromD<D> compressed = detail::Compress(v, mask_bits);

#if HWY_MEM_OPS_MIGHT_FAULT  // true if HWY_IS_MSAN
  // BlendedStore tests mask for each lane, but we know that the mask is
  // FirstN, so we can just copy.
  alignas(32) TFromD<D> buf[16];
  Store(compressed, d, buf);
  CopyBytes(buf, unaligned, count * sizeof(TFromD<D>));
#else
  BlendedStore(compressed, FirstN(d, count), d, unaligned);
#endif
  return count;
}

// Reads packed mask bits from memory, then compresses and stores v.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
                                 D d, TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  HWY_LANES_CONSTEXPR size_t kNumBytes = (N + 7) / 8;

  uint64_t mask_bits = 0;
  CopyBytes(bits, &mask_bits, kNumBytes);

  // Ignore any bits past the number of lanes.
  if (N < 8) {
    mask_bits &= (1ull << N) - 1;
  }
  const size_t count = PopCount(mask_bits);

  StoreU(detail::Compress(v, mask_bits), d, unaligned);
  detail::MaybeUnpoison(unaligned, count);
  return count;
}

#endif  // HWY_TARGET <= HWY_AVX3

// ------------------------------ Dup128MaskFromMaskBits

// Generic for all vector lengths >= 32 bytes
template <class D, HWY_IF_V_SIZE_GT_D(D, 16)>
HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
  // Build the half-width mask once, then replicate it into both halves.
  const Half<decltype(d)> dh;
  const auto mh = Dup128MaskFromMaskBits(dh, mask_bits);
  return CombineMasks(d, mh, mh);
}

// ------------------------------ Expand

// Always define Expand/LoadExpand because generic_ops only does so for Vec128.
namespace detail {

#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE  // VBMI2

// Native VBMI2 expand: scatters lanes of v into the mask=true positions,
// zeroing mask=false lanes (maskz form).
HWY_INLINE Vec256<uint8_t> NativeExpand(Vec256<uint8_t> v,
                                        Mask256<uint8_t> mask) {
  return Vec256<uint8_t>{_mm256_maskz_expand_epi8(mask.raw, v.raw)};
}

HWY_INLINE Vec256<uint16_t> NativeExpand(Vec256<uint16_t> v,
                                         Mask256<uint16_t> mask) {
  return Vec256<uint16_t>{_mm256_maskz_expand_epi16(mask.raw, v.raw)};
}

// Native VBMI2 expanding load: reads consecutive elements from unaligned and
// places them into the mask=true lanes.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(MFromD<D> mask, D /* d */,
                                      const uint8_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm256_maskz_expandloadu_epi8(mask.raw, unaligned)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(MFromD<D> mask, D /* d */,
                                      const uint16_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm256_maskz_expandloadu_epi16(mask.raw, unaligned)};
}

#endif  // HWY_TARGET <= HWY_AVX3_DL
#if HWY_TARGET <= HWY_AVX3 || HWY_IDE

// 32/64-bit expand only requires AVX-512F/VL, not VBMI2.
HWY_INLINE Vec256<uint32_t> NativeExpand(Vec256<uint32_t> v,
                                         Mask256<uint32_t> mask) {
  return Vec256<uint32_t>{_mm256_maskz_expand_epi32(mask.raw, v.raw)};
}

HWY_INLINE Vec256<uint64_t> NativeExpand(Vec256<uint64_t> v,
                                         Mask256<uint64_t> mask) {
  return Vec256<uint64_t>{_mm256_maskz_expand_epi64(mask.raw, v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(MFromD<D> mask, D /* d */,
                                      const uint32_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm256_maskz_expandloadu_epi32(mask.raw, unaligned)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(MFromD<D> mask, D /* d */,
                                      const uint64_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm256_maskz_expandloadu_epi64(mask.raw, unaligned)};
}

#endif
// HWY_TARGET <= HWY_AVX3

}  // namespace detail

// 1-byte Expand: native on VBMI2, otherwise combine two half-vector Expands.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
  const DFromV<decltype(v)> d;
#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
  const RebindToUnsigned<decltype(d)> du;
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeExpand(BitCast(du, v), mu));
#else
  // LUTs are infeasible for so many mask combinations, so Combine two
  // half-vector Expand.
  const Half<decltype(d)> dh;
  const uint64_t mask_bits = BitsFromMask(d, mask);
  constexpr size_t N = 32 / sizeof(T);
  // Number of true lanes in the lower half = offset of the upper half's input.
  const size_t countL = PopCount(mask_bits & ((1 << (N / 2)) - 1));
  const Mask128<T> maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask)));
  const Vec128<T> expandL = Expand(LowerHalf(v), maskL);
  // We have to shift the input by a variable number of bytes, but there isn't
  // a table-driven option for that until VBMI, and CPUs with that likely also
  // have VBMI2 and thus native Expand.
  alignas(32) T lanes[N];
  Store(v, d, lanes);
  const Mask128<T> maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask)));
  const Vec128<T> expandH = Expand(LoadU(dh, lanes + countL), maskH);
  return Combine(d, expandH, expandL);
#endif
}

// If AVX3, this is already implemented by x86_512.
#if HWY_TARGET != HWY_AVX3

// 2-byte Expand: native on VBMI2, otherwise combine two half-vector Expands.
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
  const Full256<T> d;
#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, detail::NativeExpand(BitCast(du, v), RebindMask(du, mask)));
#else   // AVX2
  // LUTs are infeasible for 2^16 possible masks, so splice together two
  // half-vector Expand.
  const Half<decltype(d)> dh;
  const Mask128<T> maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask)));
  const Vec128<T> expandL = Expand(LowerHalf(v), maskL);
  // We have to shift the input by a variable number of u16. permutevar_epi16
  // requires AVX3 and if we had that, we'd use native u32 Expand. The only
  // alternative is re-loading, which incurs a store to load forwarding stall.
  alignas(32) T lanes[32 / sizeof(T)];
  Store(v, d, lanes);
  const Vec128<T> vH = LoadU(dh, lanes + CountTrue(dh, maskL));
  const Mask128<T> maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask)));
  const Vec128<T> expandH = Expand(vH, maskH);
  return Combine(d, expandH, expandL);
#endif  // AVX2
}

#endif  // HWY_TARGET != HWY_AVX3

// 4-byte Expand: native on AVX3; on AVX2, a 256-entry nibble LUT provides
// per-lane source indices for a cross-lane table lookup.
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
  const Full256<T> d;
#if HWY_TARGET <= HWY_AVX3
  const RebindToUnsigned<decltype(d)> du;
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeExpand(BitCast(du, v), mu));
#else
  const RebindToUnsigned<decltype(d)> du;
  const uint64_t mask_bits = BitsFromMask(d, mask);

  alignas(16) constexpr uint32_t packed_array[256] = {
      // PrintExpand32x8Nibble.
      // Nibble f denotes a mask=false lane (zeroed later by IfThenElseZero);
      // digits are the source lane for each mask=true position.
      0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0,
      0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10,
      0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0,
      0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210,
      0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0,
      0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10,
      0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0,
      0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210,
      0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0,
      0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10,
      0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0,
      0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210,
      0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0,
      0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10,
      0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0,
      0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210,
      0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0,
      0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10,
      0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0,
      0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210,
      0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0,
      0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10,
      0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0,
      0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210,
      0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0,
      0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10,
      0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0,
      0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210,
      0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0,
      0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10,
      0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0,
      0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210,
      0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0,
      0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10,
      0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0,
      0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210,
      0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0,
      0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10,
      0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0,
      0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210,
      0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0,
      0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10,
      0x543210ff, 0x654321f0, 0x6543210f, 0x76543210,
  };

  // For lane i, shift the i-th 4-bit index down to bits [0, 3).
  const Vec256<uint32_t> packed = Set(du, packed_array[mask_bits]);
  alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28};
  // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec.
  const Indices256<uint32_t> indices{(packed >> Load(du, shifts)).raw};
  const Vec256<uint32_t> expand = TableLookupLanes(BitCast(du, v), indices);
  // TableLookupLanes cannot also zero masked-off lanes, so do that now.
  return IfThenElseZero(mask, BitCast(d, expand));
#endif
}

// 8-byte Expand: native on AVX3; on AVX2, a 16-entry nibble LUT suffices
// because there are only 4 lanes (2^4 masks).
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
  const Full256<T> d;
#if HWY_TARGET <= HWY_AVX3
  const RebindToUnsigned<decltype(d)> du;
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeExpand(BitCast(du, v), mu));
#else
  const RebindToUnsigned<decltype(d)> du;
  const uint64_t mask_bits = BitsFromMask(d, mask);

  alignas(16) constexpr uint64_t packed_array[16] = {
      // PrintExpand64x4Nibble.
      0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0,
      0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10,
      0x000010ff, 0x000021f0, 0x0000210f, 0x00003210};

  // For lane i, shift the i-th 4-bit index down to bits [0, 2).
  const Vec256<uint64_t> packed = Set(du, packed_array[mask_bits]);
  alignas(32) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28};
#if HWY_TARGET <= HWY_AVX3  // native 64-bit TableLookupLanes
  // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec.
  const Indices256<uint64_t> indices{(packed >> Load(du, shifts)).raw};
#else
  // 64-bit TableLookupLanes on AVX2 requires IndicesFromVec, which checks
  // bounds, so clear the upper bits.
  const Vec256<uint64_t> masked = And(packed >> Load(du, shifts), Set(du, 3));
  const Indices256<uint64_t> indices = IndicesFromVec(du, masked);
#endif
  const Vec256<uint64_t> expand = TableLookupLanes(BitCast(du, v), indices);
  // TableLookupLanes cannot also zero masked-off lanes, so do that now.
  return IfThenElseZero(mask, BitCast(d, expand));
#endif
}

// ------------------------------ LoadExpand

// 1/2-byte lanes: native expanding load on VBMI2, else plain LoadU + Expand.
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  const TU* HWY_RESTRICT pu = reinterpret_cast<const TU*>(unaligned);
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeLoadExpand(mu, du, pu));
#else
  return Expand(LoadU(d, unaligned), mask);
#endif
}

// 4/8-byte lanes: native expanding load on AVX3, else plain LoadU + Expand.
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
#if HWY_TARGET <= HWY_AVX3
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  const TU* HWY_RESTRICT pu = reinterpret_cast<const TU*>(unaligned);
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeLoadExpand(mu, du, pu));
#else
  return Expand(LoadU(d, unaligned), mask);
#endif
}

// ------------------------------ LoadInterleaved3/4

// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4.
namespace detail {
// Gathers the 128-bit blocks of three consecutive vectors so that each output
// vector holds one block from the first half and one from the second.
// Input:
// 1 0 (<- first block of unaligned)
// 3 2
// 5 4
// Output:
// 3 0
// 4 1
// 5 2
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API void LoadTransposedBlocks3(D d, const TFromD<D>* HWY_RESTRICT unaligned,
                                   VFromD<D>& A, VFromD<D>& B, VFromD<D>& C) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  const VFromD<D> v10 = LoadU(d, unaligned + 0 * N);  // 1 0
  const VFromD<D> v32 = LoadU(d, unaligned + 1 * N);
  const VFromD<D> v54 = LoadU(d, unaligned + 2 * N);

  A = ConcatUpperLower(d, v32, v10);
  B = ConcatLowerUpper(d, v54, v10);
  C = ConcatUpperLower(d, v54, v32);
}

// Same idea for four vectors: pairs blocks i and i+4.
// Input (128-bit blocks):
// 1 0 (first block of unaligned)
// 3 2
// 5 4
// 7 6
// Output:
// 4 0 (LSB of vA)
// 5 1
// 6 2
// 7 3
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API void LoadTransposedBlocks4(D d, const TFromD<D>* HWY_RESTRICT unaligned,
                                   VFromD<D>& vA, VFromD<D>& vB, VFromD<D>& vC,
                                   VFromD<D>& vD) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  const VFromD<D> v10 = LoadU(d, unaligned + 0 * N);
  const VFromD<D> v32 = LoadU(d, unaligned + 1 * N);
  const VFromD<D> v54 = LoadU(d, unaligned + 2 * N);
  const VFromD<D> v76 = LoadU(d, unaligned + 3 * N);

  vA = ConcatLowerLower(d, v54, v10);
  vB = ConcatUpperUpper(d, v54, v10);
  vC = ConcatLowerLower(d, v76, v32);
  vD = ConcatUpperUpper(d, v76, v32);
}
}  // namespace detail

// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower)

// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4.
namespace detail {
// Inverse of LoadTransposedBlocks2: restore sequential 128-bit block order.
// Input (128-bit blocks):
// 2 0 (LSB of i)
// 3 1
// Output:
// 1 0
// 3 2
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API void StoreTransposedBlocks2(VFromD<D> i, VFromD<D> j, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  const auto out0 = ConcatLowerLower(d, j, i);
  const auto out1 = ConcatUpperUpper(d, j, i);
  StoreU(out0, d, unaligned + 0 * N);
  StoreU(out1, d, unaligned + 1 * N);
}

// Input (128-bit blocks):
// 3 0 (LSB of i)
// 4 1
// 5 2
// Output:
// 1 0
// 3 2
// 5 4
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API void StoreTransposedBlocks3(VFromD<D> i, VFromD<D> j, VFromD<D> k, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  const auto out0 = ConcatLowerLower(d, j, i);
  const auto out1 = ConcatUpperLower(d, i, k);
  const auto out2 = ConcatUpperUpper(d, k, j);
  StoreU(out0, d, unaligned + 0 * N);
  StoreU(out1, d, unaligned + 1 * N);
  StoreU(out2, d, unaligned + 2 * N);
}

// Input (128-bit blocks):
// 4 0 (LSB of i)
// 5 1
// 6 2
// 7 3
// Output:
// 1 0
// 3 2
// 5 4
// 7 6
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API void StoreTransposedBlocks4(VFromD<D> i, VFromD<D> j, VFromD<D> k,
                                    VFromD<D> l, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  // Write lower halves, then upper.
  const auto out0 = ConcatLowerLower(d, j, i);
  const auto out1 = ConcatLowerLower(d, l, k);
  StoreU(out0, d, unaligned + 0 * N);
  StoreU(out1, d, unaligned + 1 * N);
  const auto out2 = ConcatUpperUpper(d, j, i);
  const auto out3 = ConcatUpperUpper(d, l, k);
  StoreU(out2, d, unaligned + 2 * N);
  StoreU(out3, d, unaligned + 3 * N);
}
}  // namespace detail

// ------------------------------ Additional mask logical operations

#if HWY_TARGET <= HWY_AVX3
// NOTE(review): AVX3Blsi/AVX3Blsmsk are defined elsewhere; per the BLSI/BLSMSK
// naming they presumably compute x & -x and x ^ (x - 1) — confirm in
// x86_512-inl.h. kActiveElemMask clamps results to the N valid mask bits.
template <class T>
HWY_API Mask256<T> SetAtOrAfterFirst(Mask256<T> mask) {
  constexpr size_t N = MaxLanes(Full256<T>());
  constexpr uint32_t kActiveElemMask =
      static_cast<uint32_t>((uint64_t{1} << N) - 1);
  return Mask256<T>{static_cast<typename Mask256<T>::Raw>(
      (0u - detail::AVX3Blsi(mask.raw)) & kActiveElemMask)};
}
template <class T>
HWY_API Mask256<T> SetBeforeFirst(Mask256<T> mask) {
  constexpr size_t N = MaxLanes(Full256<T>());
  constexpr uint32_t kActiveElemMask =
      static_cast<uint32_t>((uint64_t{1} << N) - 1);
  return Mask256<T>{static_cast<typename Mask256<T>::Raw>(
      (detail::AVX3Blsi(mask.raw) - 1u) & kActiveElemMask)};
}
template <class T>
HWY_API Mask256<T> SetAtOrBeforeFirst(Mask256<T> mask) {
  constexpr size_t N = MaxLanes(Full256<T>());
  constexpr uint32_t kActiveElemMask =
      static_cast<uint32_t>((uint64_t{1} << N) - 1);
  return Mask256<T>{static_cast<typename Mask256<T>::Raw>(
      detail::AVX3Blsmsk(mask.raw) & kActiveElemMask)};
}
template <class T>
HWY_API Mask256<T> SetOnlyFirst(Mask256<T> mask) {
  return Mask256<T>{
      static_cast<typename Mask256<T>::Raw>(detail::AVX3Blsi(mask.raw))};
}
#else   // AVX2
// Vector-mask fallback: propagate the first set lane's "on" state to all
// higher lanes via Neg/sign-bit smearing across lane and block boundaries.
template <class T>
HWY_API Mask256<T> SetAtOrAfterFirst(Mask256<T> mask) {
  const Full256<T> d;
  const Repartition<int64_t, decltype(d)> di64;
  const Repartition<float, decltype(d)> df32;
const Repartition<int32_t, decltype(d)> di32; 8878 const Half<decltype(di64)> dh_i64; 8879 const Half<decltype(di32)> dh_i32; 8880 using VF32 = VFromD<decltype(df32)>; 8881 8882 auto vmask = BitCast(di64, VecFromMask(d, mask)); 8883 vmask = Or(vmask, Neg(vmask)); 8884 8885 // Copy the sign bit of the even int64_t lanes to the odd int64_t lanes 8886 const auto vmask2 = BitCast( 8887 di32, VF32{_mm256_shuffle_ps(Zero(df32).raw, BitCast(df32, vmask).raw, 8888 _MM_SHUFFLE(1, 1, 0, 0))}); 8889 vmask = Or(vmask, BitCast(di64, BroadcastSignBit(vmask2))); 8890 8891 // Copy the sign bit of the lower 128-bit half to the upper 128-bit half 8892 const auto vmask3 = 8893 BroadcastSignBit(Broadcast<3>(BitCast(dh_i32, LowerHalf(dh_i64, vmask)))); 8894 vmask = Or(vmask, BitCast(di64, Combine(di32, vmask3, Zero(dh_i32)))); 8895 return MaskFromVec(BitCast(d, vmask)); 8896 } 8897 8898 template <class T> 8899 HWY_API Mask256<T> SetBeforeFirst(Mask256<T> mask) { 8900 return Not(SetAtOrAfterFirst(mask)); 8901 } 8902 8903 template <class T> 8904 HWY_API Mask256<T> SetOnlyFirst(Mask256<T> mask) { 8905 const Full256<T> d; 8906 const RebindToSigned<decltype(d)> di; 8907 const Repartition<int64_t, decltype(d)> di64; 8908 const Half<decltype(di64)> dh_i64; 8909 8910 const auto zero = Zero(di64); 8911 const auto vmask = BitCast(di64, VecFromMask(d, mask)); 8912 8913 const auto vmask_eq_0 = VecFromMask(di64, vmask == zero); 8914 auto vmask2_lo = LowerHalf(dh_i64, vmask_eq_0); 8915 auto vmask2_hi = UpperHalf(dh_i64, vmask_eq_0); 8916 8917 vmask2_lo = And(vmask2_lo, InterleaveLower(vmask2_lo, vmask2_lo)); 8918 vmask2_hi = And(ConcatLowerUpper(dh_i64, vmask2_hi, vmask2_lo), 8919 InterleaveUpper(dh_i64, vmask2_lo, vmask2_lo)); 8920 vmask2_lo = InterleaveLower(Set(dh_i64, int64_t{-1}), vmask2_lo); 8921 8922 const auto vmask2 = Combine(di64, vmask2_hi, vmask2_lo); 8923 const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); 8924 return MaskFromVec(BitCast(d, And(only_first_vmask, 
BitCast(di, vmask2)))); 8925 } 8926 8927 template <class T> 8928 HWY_API Mask256<T> SetAtOrBeforeFirst(Mask256<T> mask) { 8929 const Full256<T> d; 8930 constexpr size_t kLanesPerBlock = MaxLanes(d) / 2; 8931 8932 const auto vmask = VecFromMask(d, mask); 8933 const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d)); 8934 return SetBeforeFirst( 8935 MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>( 8936 d, vmask, vmask_lo))); 8937 } 8938 #endif // HWY_TARGET <= HWY_AVX3 8939 8940 // ------------------------------ Reductions in generic_ops 8941 8942 // ------------------------------ BitShuffle 8943 #if HWY_TARGET <= HWY_AVX3_DL 8944 template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>), 8945 HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, 32)> 8946 HWY_API V BitShuffle(V v, VI idx) { 8947 const DFromV<decltype(v)> d64; 8948 const RebindToUnsigned<decltype(d64)> du64; 8949 const Rebind<uint8_t, decltype(d64)> du8; 8950 8951 int32_t i32_bit_shuf_result = 8952 static_cast<int32_t>(_mm256_bitshuffle_epi64_mask(v.raw, idx.raw)); 8953 8954 return BitCast(d64, PromoteTo(du64, VFromD<decltype(du8)>{_mm_cvtsi32_si128( 8955 i32_bit_shuf_result)})); 8956 } 8957 #endif // HWY_TARGET <= HWY_AVX3_DL 8958 8959 // ------------------------------ MultiRotateRight 8960 8961 #if HWY_TARGET <= HWY_AVX3_DL 8962 8963 #ifdef HWY_NATIVE_MULTIROTATERIGHT 8964 #undef HWY_NATIVE_MULTIROTATERIGHT 8965 #else 8966 #define HWY_NATIVE_MULTIROTATERIGHT 8967 #endif 8968 8969 template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>), 8970 HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> 8971 HWY_API V MultiRotateRight(V v, VI idx) { 8972 return V{_mm256_multishift_epi64_epi8(idx.raw, v.raw)}; 8973 } 8974 8975 #endif 8976 8977 // ------------------------------ LeadingZeroCount 8978 8979 #if HWY_TARGET <= HWY_AVX3 8980 template <class V, HWY_IF_UI32(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)> 8981 HWY_API V LeadingZeroCount(V v) { 8982 
return V{_mm256_lzcnt_epi32(v.raw)}; 8983 } 8984 8985 template <class V, HWY_IF_UI64(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)> 8986 HWY_API V LeadingZeroCount(V v) { 8987 return V{_mm256_lzcnt_epi64(v.raw)}; 8988 } 8989 8990 namespace detail { 8991 8992 template <class V, HWY_IF_UNSIGNED_V(V), 8993 HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), 8994 HWY_IF_LANES_LE_D(DFromV<V>, HWY_MAX_BYTES / 4)> 8995 static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { 8996 const DFromV<decltype(v)> d; 8997 const Rebind<int32_t, decltype(d)> di32; 8998 const Rebind<uint32_t, decltype(d)> du32; 8999 9000 const auto v_lz_count = LeadingZeroCount(PromoteTo(du32, v)); 9001 return DemoteTo(d, BitCast(di32, v_lz_count)); 9002 } 9003 9004 template <class V, HWY_IF_UNSIGNED_V(V), HWY_IF_T_SIZE_V(V, 4)> 9005 static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { 9006 return LeadingZeroCount(v); 9007 } 9008 9009 template <class V, HWY_IF_UNSIGNED_V(V), 9010 HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), 9011 HWY_IF_LANES_GT_D(DFromV<V>, HWY_MAX_BYTES / 4)> 9012 static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { 9013 const DFromV<decltype(v)> d; 9014 const RepartitionToWide<decltype(d)> dw; 9015 const RebindToSigned<decltype(dw)> dw_i; 9016 9017 const auto lo_v_lz_count = Lzcnt32ForU8OrU16OrU32(PromoteLowerTo(dw, v)); 9018 const auto hi_v_lz_count = Lzcnt32ForU8OrU16OrU32(PromoteUpperTo(dw, v)); 9019 return OrderedDemote2To(d, BitCast(dw_i, lo_v_lz_count), 9020 BitCast(dw_i, hi_v_lz_count)); 9021 } 9022 9023 } // namespace detail 9024 9025 template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 9026 HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))> 9027 HWY_API V LeadingZeroCount(V v) { 9028 const DFromV<decltype(v)> d; 9029 const RebindToUnsigned<decltype(d)> du; 9030 using TU = TFromD<decltype(du)>; 9031 9032 constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; 9033 const auto v_lzcnt32 = detail::Lzcnt32ForU8OrU16OrU32(BitCast(du, v)); 9034 return BitCast(d, 
Min(v_lzcnt32 - Set(du, TU{32 - kNumOfBitsInT}), 9035 Set(du, TU{kNumOfBitsInT}))); 9036 } 9037 9038 template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 9039 HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))> 9040 HWY_API V HighestSetBitIndex(V v) { 9041 const DFromV<decltype(v)> d; 9042 const RebindToUnsigned<decltype(d)> du; 9043 using TU = TFromD<decltype(du)>; 9044 return BitCast( 9045 d, Set(du, TU{31}) - detail::Lzcnt32ForU8OrU16OrU32(BitCast(du, v))); 9046 } 9047 9048 template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 9049 HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 4) | (1 << 8))> 9050 HWY_API V HighestSetBitIndex(V v) { 9051 const DFromV<decltype(v)> d; 9052 using T = TFromD<decltype(d)>; 9053 return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); 9054 } 9055 9056 template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> 9057 HWY_API V TrailingZeroCount(V v) { 9058 const DFromV<decltype(v)> d; 9059 const RebindToSigned<decltype(d)> di; 9060 using T = TFromD<decltype(d)>; 9061 9062 const auto vi = BitCast(di, v); 9063 const auto lowest_bit = BitCast(d, And(vi, Neg(vi))); 9064 constexpr T kNumOfBitsInT{sizeof(T) * 8}; 9065 const auto bit_idx = HighestSetBitIndex(lowest_bit); 9066 return IfThenElse(MaskFromVec(bit_idx), Set(d, kNumOfBitsInT), bit_idx); 9067 } 9068 #endif // HWY_TARGET <= HWY_AVX3 9069 9070 // NOLINTNEXTLINE(google-readability-namespace-comments) 9071 } // namespace HWY_NAMESPACE 9072 } // namespace hwy 9073 HWY_AFTER_NAMESPACE(); 9074 9075 // Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - 9076 // the warning seems to be issued at the call site of intrinsics, i.e. our code. 9077 HWY_DIAGNOSTICS(pop)