x86_512-inl.h (302993B)
1 // Copyright 2019 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // 512-bit AVX512 vectors and operations. 17 // External include guard in highway.h - see comment there. 18 19 // WARNING: most operations do not cross 128-bit block boundaries. In 20 // particular, "Broadcast", pack and zip behavior may be surprising. 21 22 // Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL 23 #include "hwy/base.h" 24 25 // Avoid uninitialized warnings in GCC's avx512fintrin.h - see 26 // https://github.com/google/highway/issues/710) 27 HWY_DIAGNOSTICS(push) 28 #if HWY_COMPILER_GCC_ACTUAL 29 HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") 30 HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, 31 ignored "-Wmaybe-uninitialized") 32 #endif 33 34 #include <immintrin.h> // AVX2+ 35 36 #if HWY_COMPILER_CLANGCL 37 // Including <immintrin.h> should be enough, but Clang's headers helpfully skip 38 // including these headers when _MSC_VER is defined, like when using clang-cl. 39 // Include these directly here. 40 // clang-format off 41 #include <smmintrin.h> 42 43 #include <avxintrin.h> 44 // avxintrin defines __m256i and must come before avx2intrin. 
45 #include <avx2intrin.h> 46 #include <f16cintrin.h> 47 #include <fmaintrin.h> 48 49 #include <avx512fintrin.h> 50 #include <avx512vlintrin.h> 51 #include <avx512bwintrin.h> 52 #include <avx512vlbwintrin.h> 53 #include <avx512dqintrin.h> 54 #include <avx512vldqintrin.h> 55 #include <avx512cdintrin.h> 56 #include <avx512vlcdintrin.h> 57 58 #if HWY_TARGET <= HWY_AVX3_DL 59 #include <avx512bitalgintrin.h> 60 #include <avx512vlbitalgintrin.h> 61 #include <avx512vbmiintrin.h> 62 #include <avx512vbmivlintrin.h> 63 #include <avx512vbmi2intrin.h> 64 #include <avx512vlvbmi2intrin.h> 65 #include <avx512vpopcntdqintrin.h> 66 #include <avx512vpopcntdqvlintrin.h> 67 #include <avx512vnniintrin.h> 68 #include <avx512vlvnniintrin.h> 69 // Must come after avx512fintrin, else will not define 512-bit intrinsics. 70 #include <vaesintrin.h> 71 #include <vpclmulqdqintrin.h> 72 #include <gfniintrin.h> 73 #endif // HWY_TARGET <= HWY_AVX3_DL 74 75 #if HWY_TARGET <= HWY_AVX3_SPR 76 #include <avx512fp16intrin.h> 77 #include <avx512vlfp16intrin.h> 78 #endif // HWY_TARGET <= HWY_AVX3_SPR 79 80 // clang-format on 81 #endif // HWY_COMPILER_CLANGCL 82 83 // For half-width vectors. Already includes base.h and shared-inl.h. 
#include "hwy/ops/x86_256-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {

// Maps a lane type T to the native 512-bit register type that holds it.
// All integer lane types (and float16_t without native FP16 support) share
// __m512i; float/double (and float16_t with AVX512-FP16) get their own type.
template <typename T>
struct Raw512 {
  using type = __m512i;
};
#if HWY_HAVE_FLOAT16
template <>
struct Raw512<float16_t> {
  using type = __m512h;
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct Raw512<float> {
  using type = __m512;
};
template <>
struct Raw512<double> {
  using type = __m512d;
};

// Template arg: sizeof(lane type)
// Maps the lane size to the AVX-512 mask register type with one bit per lane:
// 64 one-byte lanes require __mmask64, down to 8 eight-byte lanes (__mmask8).
template <size_t size>
struct RawMask512 {};
template <>
struct RawMask512<1> {
  using type = __mmask64;
};
template <>
struct RawMask512<2> {
  using type = __mmask32;
};
template <>
struct RawMask512<4> {
  using type = __mmask16;
};
template <>
struct RawMask512<8> {
  using type = __mmask8;
};

}  // namespace detail

// Wrapper around the native 512-bit register for lane type T.
template <typename T>
class Vec512 {
  using Raw = typename detail::Raw512<T>::type;

 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 64 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec512& operator*=(const Vec512 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec512& operator/=(const Vec512 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec512& operator+=(const Vec512 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec512& operator-=(const Vec512 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec512& operator%=(const Vec512 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec512& operator&=(const Vec512 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec512& operator|=(const Vec512 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec512& operator^=(const Vec512 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

// Mask register: one bit per lane.
template <typename T>
struct Mask512 {
  using Raw = typename detail::RawMask512<sizeof(T)>::type;

  using PrivateT = T;                                  // only for DFromM
  static constexpr size_t kPrivateN = 64 / sizeof(T);  // only for DFromM

  Raw raw;
};

// Descriptor for a full (64-byte) vector of lane type T.
template <typename T>
using Full512 = Simd<T, 64 / sizeof(T), 0>;

// ------------------------------ BitCast

namespace detail {

// Reinterprets any 512-bit register type as __m512i (no-op for integers).
HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; }
#if HWY_HAVE_FLOAT16
HWY_INLINE __m512i BitCastToInteger(__m512h v) {
  return _mm512_castph_si512(v);
}
#endif  // HWY_HAVE_FLOAT16
HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); }
HWY_INLINE __m512i BitCastToInteger(__m512d v) {
  return _mm512_castpd_si512(v);
}

#if HWY_AVX3_HAVE_F32_TO_BF16C
HWY_INLINE __m512i BitCastToInteger(__m512bh v) {
  // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to
  // bit cast a __m512bh to a __m512i as there is currently no intrinsic
  // available (as of GCC 13 and Clang 17) that can bit cast a __m512bh vector
  // to a __m512i vector
#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG
  // On GCC or Clang, use reinterpret_cast to bit cast a __m512bh to a __m512i
  return reinterpret_cast<__m512i>(v);
#else
  // On MSVC, use BitCastScalar to bit cast a __m512bh to a __m512i as MSVC does
  // not allow reinterpret_cast, static_cast, or a C-style cast to be used to
  // bit cast from one AVX vector type to a different AVX vector type
  return BitCastScalar<__m512i>(v);
#endif  // HWY_COMPILER_GCC || HWY_COMPILER_CLANG
}
#endif  // HWY_AVX3_HAVE_F32_TO_BF16C

// First half of BitCast: view any Vec512<T> as bytes.
template <typename T>
HWY_INLINE Vec512<uint8_t> BitCastToByte(Vec512<T> v) {
  return Vec512<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template <typename T>
struct BitCastFromInteger512 {
  HWY_INLINE __m512i operator()(__m512i v) { return v; }
};
#if HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger512<float16_t> {
  HWY_INLINE __m512h operator()(__m512i v) { return _mm512_castsi512_ph(v); }
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger512<float> {
  HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); }
};
template <>
struct BitCastFromInteger512<double> {
  HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); }
};

// Second half of BitCast: view bytes as the lane type of D.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec512<uint8_t> v) {
  return VFromD<D>{BitCastFromInteger512<TFromD<D>>()(v.raw)};
}

}  // namespace detail

// Reinterprets the bits of v as lanes of type TFromD<D>; same vector width.
template <class D, HWY_IF_V_SIZE_D(D, 64), typename FromT>
HWY_API VFromD<D> BitCast(D d, Vec512<FromT> v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}

// ------------------------------ Set

// Returns a vector with all lanes equal to t.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi8(static_cast<char>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi16(static_cast<short>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi32(static_cast<int>(t))};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi64(static_cast<long long>(t))};  // NOLINT
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Set(D /* tag */, float16_t t) {
  return Vec512<float16_t>{_mm512_set1_ph(t)};
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Set(D /* tag */, float t) {
  return Vec512<float>{_mm512_set1_ps(t)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Set(D /* tag */, double t) {
  return Vec512<double>{_mm512_set1_pd(t)};
}

// ------------------------------ Zero (Set)

// GCC pre-9.1 lacked setzero, so use Set instead.
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900

// Cannot use VFromD here because it is defined in terms of Zero.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_SPECIAL_FLOAT_D(D)>
HWY_API Vec512<TFromD<D>> Zero(D d) {
  return Set(d, TFromD<D>{0});
}
// BitCast is defined below, but the Raw type is the same, so use that.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Zero(D /* tag */) {
  const RebindToUnsigned<D> du;
  return Vec512<bfloat16_t>{Set(du, 0).raw};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Zero(D /* tag */) {
  const RebindToUnsigned<D> du;
  return Vec512<float16_t>{Set(du, 0).raw};
}

#else

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec512<TFromD<D>> Zero(D /* tag */) {
  return Vec512<TFromD<D>>{_mm512_setzero_si512()};
}
// bf16 lanes are stored in an __m512i register, so integer setzero suffices.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Zero(D /* tag */) {
  return Vec512<bfloat16_t>{_mm512_setzero_si512()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Zero(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec512<float16_t>{_mm512_setzero_ph()};
#else
  return Vec512<float16_t>{_mm512_setzero_si512()};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Zero(D /* tag */) {
  return Vec512<float>{_mm512_setzero_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Zero(D /* tag */) {
  return Vec512<double>{_mm512_setzero_pd()};
}

#endif  // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900

// ------------------------------ Undefined

HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")

// Returns a vector with uninitialized elements.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec512<TFromD<D>> Undefined(D /* tag */) {
  // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC
  // generate an XOR instruction.
  return Vec512<TFromD<D>>{_mm512_undefined_epi32()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Undefined(D /* tag */) {
  return Vec512<bfloat16_t>{_mm512_undefined_epi32()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Undefined(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec512<float16_t>{_mm512_undefined_ph()};
#else
  return Vec512<float16_t>{_mm512_undefined_epi32()};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Undefined(D /* tag */) {
  return Vec512<float>{_mm512_undefined_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Undefined(D /* tag */) {
  return Vec512<double>{_mm512_undefined_pd()};
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ ResizeBitCast

// 64-byte vector to 16-byte vector: keeps the lowest 128 bits.
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 64),
          HWY_IF_V_SIZE_D(D, 16)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec128<uint8_t>{_mm512_castsi512_si128(
                        BitCast(Full512<uint8_t>(), v).raw)});
}

// <= 16-byte vector to 64-byte vector; upper lanes are unspecified.
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec512<uint8_t>{_mm512_castsi128_si512(
                        ResizeBitCast(Full128<uint8_t>(), v).raw)});
}

// 32-byte vector to 64-byte vector; upper lanes are unspecified.
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec512<uint8_t>{_mm512_castsi256_si512(
                        BitCast(Full256<uint8_t>(), v).raw)});
}

// ------------------------------ Dup128VecFromValues

// Returns a vector whose four 128-bit blocks each hold t0..t15 (t0 in the
// lowest lane of each block).
template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
                                      TFromD<D> t11, TFromD<D> t12,
                                      TFromD<D> t13, TFromD<D> t14,
                                      TFromD<D> t15) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  return BroadcastBlock<0>(ResizeBitCast(
      d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3, t4, t5, t6,
                             t7, t8, t9, t10, t11, t12, t13, t14, t15)));
#else
  (void)d;
  // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic
  // available. set_epi8 takes arguments highest lane first, so each of the
  // four blocks below lists t15 down to t0.
  return VFromD<D>{_mm512_set_epi8(
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0), static_cast<char>(t15), static_cast<char>(t14),
      static_cast<char>(t13), static_cast<char>(t12), static_cast<char>(t11),
      static_cast<char>(t10), static_cast<char>(t9), static_cast<char>(t8),
      static_cast<char>(t7), static_cast<char>(t6), static_cast<char>(t5),
      static_cast<char>(t4), static_cast<char>(t3), static_cast<char>(t2),
      static_cast<char>(t1), static_cast<char>(t0), static_cast<char>(t15),
      static_cast<char>(t14), static_cast<char>(t13), static_cast<char>(t12),
      static_cast<char>(t11), static_cast<char>(t10), static_cast<char>(t9),
      static_cast<char>(t8), static_cast<char>(t7), static_cast<char>(t6),
      static_cast<char>(t5), static_cast<char>(t4), static_cast<char>(t3),
      static_cast<char>(t2), static_cast<char>(t1), static_cast<char>(t0),
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0))};
#endif
}

// Returns a vector whose four 128-bit blocks each hold t0..t7.
template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  return BroadcastBlock<0>(
      ResizeBitCast(d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3,
                                           t4, t5, t6, t7)));
#else
  (void)d;
  // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic
  // available (arguments are therefore given highest lane first).
  return VFromD<D>{
      _mm512_set_epi16(static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0),
                       static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0),
                       static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0),
                       static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0))};
#endif
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2,
                                  t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5,
                                  t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)};
}
#endif

// Returns a vector whose four 128-bit blocks each hold t0..t3.
template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{
      _mm512_setr_epi32(static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3))};
}

template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2,
                                  t3, t0, t1, t2, t3)};
}

// Returns a vector whose four 128-bit blocks each hold t0, t1.
template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{
      _mm512_setr_epi64(static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                        static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                        static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                        static_cast<int64_t>(t0), static_cast<int64_t>(t1))};
}

template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)};
}

// ----------------------------- Iota

namespace detail {

// Returns a vector with lane i equal to i (0, 1, 2, ...).
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE VFromD<D> Iota0(D d) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  alignas(64) static constexpr TFromD<D> kIota[64] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
      32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
      48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
  return Load(d, kIota);
#else
  (void)d;
  // set_epi8 takes arguments highest lane first: 63 down to 0.
  return VFromD<D>{_mm512_set_epi8(
      static_cast<char>(63), static_cast<char>(62), static_cast<char>(61),
      static_cast<char>(60), static_cast<char>(59), static_cast<char>(58),
      static_cast<char>(57), static_cast<char>(56), static_cast<char>(55),
      static_cast<char>(54), static_cast<char>(53), static_cast<char>(52),
      static_cast<char>(51), static_cast<char>(50), static_cast<char>(49),
      static_cast<char>(48), static_cast<char>(47), static_cast<char>(46),
      static_cast<char>(45), static_cast<char>(44), static_cast<char>(43),
      static_cast<char>(42), static_cast<char>(41), static_cast<char>(40),
      static_cast<char>(39), static_cast<char>(38), static_cast<char>(37),
      static_cast<char>(36), static_cast<char>(35), static_cast<char>(34),
      static_cast<char>(33), static_cast<char>(32), static_cast<char>(31),
      static_cast<char>(30), static_cast<char>(29), static_cast<char>(28),
      static_cast<char>(27), static_cast<char>(26), static_cast<char>(25),
      static_cast<char>(24), static_cast<char>(23), static_cast<char>(22),
      static_cast<char>(21), static_cast<char>(20), static_cast<char>(19),
      static_cast<char>(18), static_cast<char>(17), static_cast<char>(16),
      static_cast<char>(15), static_cast<char>(14), static_cast<char>(13),
      static_cast<char>(12), static_cast<char>(11), static_cast<char>(10),
      static_cast<char>(9), static_cast<char>(8), static_cast<char>(7),
      static_cast<char>(6), static_cast<char>(5), static_cast<char>(4),
      static_cast<char>(3), static_cast<char>(2), static_cast<char>(1),
      static_cast<char>(0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)>
HWY_INLINE VFromD<D> Iota0(D d) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  alignas(64) static constexpr TFromD<D> kIota[32] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
  return Load(d, kIota);
#else
  (void)d;
  return VFromD<D>{_mm512_set_epi16(
      int16_t{31}, int16_t{30}, int16_t{29}, int16_t{28}, int16_t{27},
      int16_t{26}, int16_t{25}, int16_t{24}, int16_t{23}, int16_t{22},
      int16_t{21}, int16_t{20}, int16_t{19}, int16_t{18}, int16_t{17},
      int16_t{16}, int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12},
      int16_t{11}, int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6},
      int16_t{5}, int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})};
#endif
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_ph(
      float16_t{31}, float16_t{30}, float16_t{29}, float16_t{28}, float16_t{27},
      float16_t{26}, float16_t{25}, float16_t{24}, float16_t{23}, float16_t{22},
      float16_t{21}, float16_t{20}, float16_t{19}, float16_t{18}, float16_t{17},
      float16_t{16}, float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12},
      float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, float16_t{7},
      float16_t{6}, float16_t{5}, float16_t{4}, float16_t{3}, float16_t{2},
      float16_t{1}, float16_t{0})};
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_epi32(
      int32_t{15}, int32_t{14}, int32_t{13}, int32_t{12}, int32_t{11},
      int32_t{10}, int32_t{9}, int32_t{8}, int32_t{7}, int32_t{6}, int32_t{5},
      int32_t{4}, int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_epi64(int64_t{7}, int64_t{6}, int64_t{5},
                                    int64_t{4}, int64_t{3}, int64_t{2},
                                    int64_t{1}, int64_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
                                 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
                                 0.0f)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)};
}

}  // namespace detail

// Returns a vector with lane i equal to first + i.
template <class D, typename T2, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Iota(D d, const T2 first) {
  return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first));
}

// ================================================== LOGICAL

// ------------------------------ Not

template <typename T>
HWY_API Vec512<T> Not(const Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i vu = BitCast(du, v).raw;
  // Truth table 0x55 is NOT of the first operand (all three operands are vu).
  return BitCast(d, VU{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)});
}

// ------------------------------ And

template <typename T>
HWY_API Vec512<T> And(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_and_si512(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}

HWY_API Vec512<float> And(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_and_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> And(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_and_pd(a.raw, b.raw)};
}

// ------------------------------ AndNot

// Returns ~not_mask & mask.
template <typename T>
HWY_API Vec512<T> AndNot(const Vec512<T> not_mask, const Vec512<T> mask) {
  const DFromV<decltype(mask)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_andnot_si512(
                        BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
}
HWY_API Vec512<float> AndNot(const Vec512<float> not_mask,
                             const Vec512<float> mask) {
  return Vec512<float>{_mm512_andnot_ps(not_mask.raw, mask.raw)};
}
HWY_API Vec512<double> AndNot(const Vec512<double> not_mask,
                              const Vec512<double> mask) {
  return Vec512<double>{_mm512_andnot_pd(not_mask.raw, mask.raw)};
}

// ------------------------------ Or

template <typename T>
HWY_API Vec512<T> Or(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_or_si512(BitCast(du, a).raw,
                                                         BitCast(du, b).raw)});
}

HWY_API Vec512<float> Or(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_or_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Or(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_or_pd(a.raw, b.raw)};
}

// ------------------------------ Xor

template <typename T>
HWY_API Vec512<T> Xor(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_xor_si512(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}

HWY_API Vec512<float> Xor(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_xor_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Xor(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)};
}

// ------------------------------ Xor3
// Returns x1 ^ x2 ^ x3 in a single VPTERNLOG where possible.
template <typename T>
HWY_API Vec512<T> Xor3(Vec512<T> x1, Vec512<T> x2, Vec512<T> x3) {
#if !HWY_IS_MSAN
  const DFromV<decltype(x1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  // Truth table 0x96 is a ^ b ^ c.
  const __m512i ret = _mm512_ternarylogic_epi64(
      BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
  return BitCast(d, VU{ret});
#else
  return Xor(x1, Xor(x2, x3));
#endif
}

// ------------------------------ Or3
// Returns o1 | o2 | o3 in a single VPTERNLOG where possible.
template <typename T>
HWY_API Vec512<T> Or3(Vec512<T> o1, Vec512<T> o2, Vec512<T> o3) {
#if !HWY_IS_MSAN
  const DFromV<decltype(o1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  // Truth table 0xFE is a | b | c.
  const __m512i ret = _mm512_ternarylogic_epi64(
      BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
  return BitCast(d, VU{ret});
#else
  return Or(o1, Or(o2, o3));
#endif
}

// ------------------------------ OrAnd
// Returns o | (a1 & a2) in a single VPTERNLOG where possible.
template <typename T>
HWY_API Vec512<T> OrAnd(Vec512<T> o, Vec512<T> a1, Vec512<T> a2) {
#if !HWY_IS_MSAN
  const DFromV<decltype(o)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  // Truth table 0xF8 is a | (b & c).
  const __m512i ret = _mm512_ternarylogic_epi64(
      BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
  return BitCast(d, VU{ret});
#else
  return Or(o, And(a1, a2));
#endif
}

// ------------------------------ IfVecThenElse
// Bitwise select: for each bit, returns (mask & yes) | (~mask & no).
template <typename T>
HWY_API Vec512<T> IfVecThenElse(Vec512<T> mask, Vec512<T> yes, Vec512<T> no) {
#if !HWY_IS_MSAN
  const DFromV<decltype(yes)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  // Truth table 0xCA is a ? b : c (per bit).
  return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw,
                                                 BitCast(du, yes).raw,
                                                 BitCast(du, no).raw, 0xCA)});
#else
  return IfThenElse(MaskFromVec(mask), yes, no);
#endif
}

// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec512<T> operator&(const Vec512<T> a, const Vec512<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec512<T> operator|(const Vec512<T> a, const Vec512<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec512<T> operator^(const Vec512<T> a, const Vec512<T> b) {
  return Xor(a, b);
}

// ------------------------------ PopulationCount

// 8/16 require BITALG, 32/64 require VPOPCNTDQ.
#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

namespace detail {

template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi8(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi16(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi32(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi64(v.raw)};
}

}  // namespace detail

// Returns the number of set bits in each lane.
template <typename T>
HWY_API Vec512<T> PopulationCount(Vec512<T> v) {
  return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
}

#endif  // HWY_TARGET <= HWY_AVX3_DL

// ================================================== MASK

// ------------------------------ FirstN

// Possibilities for constructing a bitmask of N ones:
// - kshift* only consider the lowest byte of the shift count, so they would
//   not correctly handle large n.
// - Scalar shifts >= 64 are UB.
// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However,
//   we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds.

#if HWY_ARCH_X86_32
namespace detail {

// 32 bit mask is sufficient for lane size >= 2.
template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  Mask512<T> m;
  const uint32_t all = ~uint32_t{0};
  // BZHI only looks at the lower 8 bits of n, but it has been clamped to
  // MaxLanes, which is at most 32.
  m.raw = static_cast<decltype(m.raw)>(_bzhi_u32(all, n));
  return m;
}

#if HWY_COMPILER_MSVC >= 1920 || HWY_COMPILER_GCC_ACTUAL >= 900 || \
    HWY_COMPILER_CLANG || HWY_COMPILER_ICC
// 64-bit mask built from two 32-bit halves (no _bzhi_u64 on 32-bit targets).
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  uint32_t lo_mask;
  uint32_t hi_mask;
  uint32_t hi_mask_len;
#if HWY_COMPILER_GCC
  // Fast path when n is a compile-time constant >= 32.
  if (__builtin_constant_p(n >= 32) && n >= 32) {
    if (__builtin_constant_p(n >= 64) && n >= 64) {
      hi_mask_len = 32u;
    } else {
      hi_mask_len = static_cast<uint32_t>(n) - 32u;
    }
    lo_mask = hi_mask = 0xFFFFFFFFu;
  } else  // NOLINT(readability/braces)
#endif
  {
    const uint32_t lo_mask_len = static_cast<uint32_t>(n);
    lo_mask = _bzhi_u32(0xFFFFFFFFu, lo_mask_len);

#if HWY_COMPILER_GCC
    if (__builtin_constant_p(lo_mask_len <= 32) && lo_mask_len <= 32) {
      return Mask512<T>{static_cast<__mmask64>(lo_mask)};
    }
#endif

    // hi_mask_len = n - 32 (borrow-aware); hi_mask = all-ones iff n >= 32.
    _addcarry_u32(_subborrow_u32(0, lo_mask_len, 32u, &hi_mask_len),
                  0xFFFFFFFFu, 0u, &hi_mask);
  }
  hi_mask = _bzhi_u32(hi_mask, hi_mask_len);
#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC
  if (__builtin_constant_p((static_cast<uint64_t>(hi_mask) << 32) | lo_mask))
#endif
    return Mask512<T>{static_cast<__mmask64>(
        (static_cast<uint64_t>(hi_mask) << 32) | lo_mask)};
#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC
  else
    // KUNPCKDQ concatenates the two 32-bit halves within mask registers.
    return Mask512<T>{_mm512_kunpackd(static_cast<__mmask64>(hi_mask),
                                      static_cast<__mmask64>(lo_mask))};
#endif
}
#else   // HWY_COMPILER..
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0};
  return Mask512<T>{static_cast<__mmask64>(bits)};
}
#endif  // HWY_COMPILER..
}  // namespace detail
#endif  // HWY_ARCH_X86_32

// Returns a mask with the first n lanes set (all lanes if n >= Lanes(d)).
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API MFromD<D> FirstN(D d, size_t n) {
  // This ensures `n` <= 255 as required by bzhi, which only looks
  // at the lower 8 bits.
  n = HWY_MIN(n, MaxLanes(d));

#if HWY_ARCH_X86_64
  MFromD<D> m;
  const uint64_t all = ~uint64_t{0};
  m.raw = static_cast<decltype(m.raw)>(_bzhi_u64(all, n));
  return m;
#else
  return detail::FirstN<TFromD<D>>(n);
#endif  // HWY_ARCH_X86_64
}

// ------------------------------ IfThenElse

// Returns mask ? yes : no.

namespace detail {

// Templates for signed/unsigned integer of a particular size.
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi8(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi16(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi32(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi64(mask.raw, no.raw, yes.raw)};
}

}  // namespace detail

// Per-lane select: yes where the mask bit is set, else no.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes,
                             const Vec512<T> no) {
  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
}
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> IfThenElse(Mask512<float16_t> mask,
                                     Vec512<float16_t> yes,
                                     Vec512<float16_t> no) {
  return Vec512<float16_t>{_mm512_mask_blend_ph(mask.raw, no.raw, yes.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> IfThenElse(Mask512<float> mask, Vec512<float> yes,
                                 Vec512<float> no) {
  return Vec512<float>{_mm512_mask_blend_ps(mask.raw, no.raw, yes.raw)};
}
HWY_API Vec512<double> IfThenElse(Mask512<double> mask, Vec512<double> yes,
                                  Vec512<double> no) {
  return Vec512<double>{_mm512_mask_blend_pd(mask.raw, no.raw, yes.raw)};
}

namespace detail {

// yes where the mask bit is set, else zero; maskz_mov zeroes unselected
// lanes.
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) {
  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
}
HWY_API Vec512<float> IfThenElseZero(Mask512<float> mask, Vec512<float> yes) {
  return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)};
}
HWY_API Vec512<double> IfThenElseZero(Mask512<double> mask,
                                      Vec512<double> yes) {
  return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)};
}

namespace detail {

// Zero where the mask bit is set, else no. Implemented as a masked self-sub
// or self-xor: x - x == x ^ x == 0 in the selected lanes.
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16.
  return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) {
  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
}
HWY_API Vec512<float> IfThenZeroElse(Mask512<float> mask, Vec512<float> no) {
  return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
}
HWY_API Vec512<double> IfThenZeroElse(Mask512<double> mask, Vec512<double> no) {
  return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
}

// yes where v is negative (MSB set), else no.
template <typename T>
HWY_API Vec512<T> IfNegativeThenElse(Vec512<T> v, Vec512<T> yes, Vec512<T> no) {
  static_assert(IsSigned<T>(), "Only works for signed/float");
  // AVX3 MaskFromVec only looks at the MSB
  return IfThenElse(MaskFromVec(v), yes, no);
}

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
HWY_API Vec512<T> IfNegativeThenNegOrUndefIfZero(Vec512<T> mask, Vec512<T> v) {
  // AVX3 MaskFromVec only looks at the MSB
  const DFromV<decltype(v)> d;
  // Where the mask's MSB is set, compute 0 - v (negate); otherwise keep v.
  return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v);
}

// ================================================== ARITHMETIC

// ------------------------------ Addition

// Unsigned
HWY_API Vec512<uint8_t> operator+(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator+(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator+(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator+(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator+(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator+(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator+(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator+(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> operator+(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_add_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> operator+(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_add_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator+(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_add_pd(a.raw, b.raw)};
}

// ------------------------------ Subtraction

// Unsigned
HWY_API Vec512<uint8_t> operator-(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator-(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator-(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator-(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator-(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator-(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator-(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator-(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> operator-(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_sub_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> operator-(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator-(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)};
}

// ------------------------------ SumsOf8
// SAD against zero sums each group of 8 consecutive u8 lanes into a u64.
HWY_API Vec512<uint64_t> SumsOf8(const Vec512<uint8_t> v) {
  const Full512<uint8_t> d;
  return Vec512<uint64_t>{_mm512_sad_epu8(v.raw, Zero(d).raw)};
}

HWY_API Vec512<uint64_t> SumsOf8AbsDiff(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint64_t>{_mm512_sad_epu8(a.raw, b.raw)};
}

// ------------------------------ SumsOf4
namespace detail {

HWY_INLINE Vec512<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/,
                                    hwy::SizeTag<1> /*lane_size_tag*/,
                                    Vec512<uint8_t> v) {
  const DFromV<decltype(v)> d;

  // _mm512_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be
  // zeroed out and the sums of the 4 consecutive lanes are already in the
  // even uint16_t lanes of the _mm512_maskz_dbsad_epu8 result.
  return Vec512<uint32_t>{_mm512_maskz_dbsad_epu8(
      static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)};
}

// I8->I32 SumsOf4
// Generic for all vector lengths
template <class V>
HWY_INLINE VFromD<RepartitionToWideX2<DFromV<V>>> SumsOf4(
    hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const RepartitionToWideX2<decltype(d)> di32;

  // Adjust the values of v to be in the 0..255 range by adding 128 to each lane
  // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then
  // bitcasting the Xor result to an u8 vector.
  const auto v_adj = BitCast(du, Xor(v, SignBit(d)));

  // Need to add -512 to each i32 lane of the result of the
  // SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj) operation to account
  // for the adjustment made above (4 lanes, each biased by +128: 4*128 = 512).
  return BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj)) +
         Set(di32, int32_t{-512});
}

}  // namespace detail

// ------------------------------ SumsOfShuffledQuadAbsDiff

#if HWY_TARGET <= HWY_AVX3
// DBPSADBW: shuffles 32-bit groups of `a` by the given indices, then computes
// sums of absolute differences of quads of u8 lanes against `b`.
template <int kIdx3, int kIdx2, int kIdx1, int kIdx0>
static Vec512<uint16_t> SumsOfShuffledQuadAbsDiff(Vec512<uint8_t> a,
                                                  Vec512<uint8_t> b) {
  static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3");
  static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3");
  static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3");
  static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3");
  return Vec512<uint16_t>{
      _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))};
}
#endif

// ------------------------------ SaturatedAdd

// Returns a + b clamped to the destination range.

// Unsigned
HWY_API Vec512<uint8_t> SaturatedAdd(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedAdd(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedAdd(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedAdd(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)};
}

// ------------------------------ SaturatedSub

// Returns a - b clamped to the destination range.
// Unsigned
HWY_API Vec512<uint8_t> SaturatedSub(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedSub(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedSub(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedSub(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)};
}

// ------------------------------ Average

// Returns (a + b + 1) / 2

// Unsigned
HWY_API Vec512<uint8_t> AverageRound(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> AverageRound(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)};
}

// ------------------------------ Abs (Sub)

// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) {
#if HWY_COMPILER_MSVC
  // Workaround for incorrect codegen? (untested due to internal compiler
  // error). Emulate abs via max(v, 0 - v) instead of _mm512_abs_epi8.
  const DFromV<decltype(v)> d;
  const auto zero = Zero(d);
  return Vec512<int8_t>{_mm512_max_epi8(v.raw, (zero - v).raw)};
#else
  return Vec512<int8_t>{_mm512_abs_epi8(v.raw)};
#endif
}
HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_abs_epi16(v.raw)};
}
HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_abs_epi32(v.raw)};
}
HWY_API Vec512<int64_t> Abs(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_abs_epi64(v.raw)};
}

// ------------------------------ ShiftLeft

#if HWY_TARGET <= HWY_AVX3_DL
namespace detail {
// GFNI affine transform over GF(2): applies the given 8x8 bit matrix to each
// byte, which implements arbitrary per-byte bit shifts/permutations.
template <typename T>
HWY_API Vec512<T> GaloisAffine(Vec512<T> v, Vec512<uint64_t> matrix) {
  return Vec512<T>{_mm512_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)};
}
}  // namespace detail
#endif  // HWY_TARGET <= HWY_AVX3_DL

// Compile-time shift by kBits; one overload per lane type, each using the
// immediate-count instruction.
template <int kBits>
HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)};
}
#if HWY_TARGET > HWY_AVX3_DL

// 8-bit shift: perform a 16-bit shift, then clear the bits shifted in from
// the neighboring byte.
template <int kBits, typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v)));
  // kBits == 1 is a simple add, which avoids the mask constant.
  return kBits == 1
             ? (v + v)
             : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF)));
}

#endif  // HWY_TARGET > HWY_AVX3_DL

// ------------------------------ ShiftRight

template <int kBits>
HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)};
}

#if HWY_TARGET > HWY_AVX3_DL

template <int kBits>
HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) {
  const DFromV<decltype(v)> d8;
  // Use raw instead of BitCast to support N=1.
  const Vec512<uint8_t> shifted{ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw};
  // Clear the bits shifted in from the neighboring byte.
  return shifted & Set(d8, 0xFF >> kBits);
}

template <int kBits>
HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) {
  const DFromV<decltype(di)> di;
  const RebindToUnsigned<decltype(di)> du;
  const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
  // Sign-extend: XOR with the shifted sign-bit position, then subtract it,
  // which propagates the sign into the cleared upper bits.
  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits));
  return (shifted ^ shifted_sign) - shifted_sign;
}

#endif  // HWY_TARGET > HWY_AVX3_DL

// ------------------------------ RotateRight

#if HWY_TARGET > HWY_AVX3_DL
template <int kBits>
HWY_API Vec512<uint8_t> RotateRight(const Vec512<uint8_t> v) {
  static_assert(0 <= kBits && kBits < 8, "Invalid shift count");
  if (kBits == 0) return v;
  // AVX3 does not support 8-bit. HWY_MIN keeps the (dead when kBits == 0)
  // ShiftLeft instantiation within the valid 0..7 range.
  return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(7, 8 - kBits)>(v));
}
#endif  // HWY_TARGET > HWY_AVX3_DL

template <int kBits>
HWY_API Vec512<uint16_t> RotateRight(const Vec512<uint16_t> v) {
  static_assert(0 <= kBits && kBits < 16, "Invalid shift count");
  if (kBits == 0) return v;
#if HWY_TARGET <= HWY_AVX3_DL
  // Funnel shift of v with itself == rotate.
  return Vec512<uint16_t>{_mm512_shrdi_epi16(v.raw, v.raw, kBits)};
#else
  // AVX3 does not support 16-bit.
  return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(15, 16 - kBits)>(v));
#endif
}

template <int kBits>
HWY_API Vec512<uint32_t> RotateRight(const Vec512<uint32_t> v) {
  static_assert(0 <= kBits && kBits < 32, "Invalid shift count");
  if (kBits == 0) return v;
  return Vec512<uint32_t>{_mm512_ror_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> RotateRight(const Vec512<uint64_t> v) {
  static_assert(0 <= kBits && kBits < 64, "Invalid shift count");
  if (kBits == 0) return v;
  return Vec512<uint64_t>{_mm512_ror_epi64(v.raw, kBits)};
}

// ------------------------------ Rol/Ror
#if HWY_TARGET <= HWY_AVX3_DL
template <class T, HWY_IF_UI16(T)>
HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) {
  // Variable-count funnel shift of a with itself == variable rotate.
  return Vec512<T>{_mm512_shrdv_epi16(a.raw, a.raw, b.raw)};
}
#endif  // HWY_TARGET <= HWY_AVX3_DL

template <class T, HWY_IF_UI32(T)>
HWY_API Vec512<T> Rol(Vec512<T> a, Vec512<T> b) {
  return Vec512<T>{_mm512_rolv_epi32(a.raw, b.raw)};
}

template <class T, HWY_IF_UI32(T)>
HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) {
  return Vec512<T>{_mm512_rorv_epi32(a.raw, b.raw)};
}

template <class T, HWY_IF_UI64(T)>
HWY_API Vec512<T> Rol(Vec512<T> a, Vec512<T> b) {
  return Vec512<T>{_mm512_rolv_epi64(a.raw, b.raw)};
}

template <class T, HWY_IF_UI64(T)>
HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) {
  return Vec512<T>{_mm512_rorv_epi64(a.raw, b.raw)};
}

// ------------------------------ ShiftLeftSame

// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512
// shift-with-immediate: the counts should all be unsigned int. Despite
// casting, we still see warnings in GCC debug builds, hence disable.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100
using Shift16Count = int;
using Shift3264Count = int;
#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400
// GCC 11.0 requires these, prior versions used a macro+cast and don't care.
using Shift16Count = int;
using Shift3264Count = unsigned int;
#else
// Assume documented behavior. Clang 11, GCC 14 and MSVC 14.28.29910 match this.
using Shift16Count = unsigned int;
using Shift3264Count = unsigned int;
#endif

// Shifts all lanes by the same (runtime) count. When the count is a
// compile-time constant (GCC-family __builtin_constant_p), use the immediate
// form; otherwise pass the count in the low lane of an XMM register.
HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint16_t>{
        _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint32_t>{
        _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint64_t>{
        _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int16_t>{
        _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int32_t>{
        _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int64_t>{
        _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

// 8-bit: shift as 16-bit, then clear the bits shifted in from the neighbor.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF));
}

// ------------------------------ ShiftRightSame

HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint16_t>{
        _mm512_srli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint32_t>{
        _mm512_srli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint64_t>{
        _mm512_srli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits));
}

HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int16_t>{
        _mm512_srai_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int32_t>{
        _mm512_srai_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int64_t>{
        _mm512_srai_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
  // XOR then subtract the shifted sign-bit position to sign-extend.
  const auto shifted_sign =
      BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits)));
  return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Minimum

// Unsigned
HWY_API Vec512<uint8_t> Min(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Min(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Min(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Min(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Min(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Min(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Min(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Min(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> Min(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_min_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> Min(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_min_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Min(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_min_pd(a.raw, b.raw)};
}

// ------------------------------ Maximum

// Unsigned
HWY_API Vec512<uint8_t> Max(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Max(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Max(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Max(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Max(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Max(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Max(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Max(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> Max(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_max_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> Max(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_max_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Max(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_max_pd(a.raw, b.raw)};
}

// ------------------------------ MinNumber and MaxNumber

#if HWY_X86_HAVE_AVX10_2_OPS

// VMINMAX immediates: 0x14 = minNumber, 0x15 = maxNumber (prefer the numeric
// operand when one input is NaN).
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> MinNumber(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x14)};
}
#endif
HWY_API Vec512<float> MinNumber(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x14)};
}
HWY_API Vec512<double> MinNumber(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x14)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> MaxNumber(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x15)};
}
#endif
HWY_API Vec512<float> MaxNumber(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x15)};
}
HWY_API Vec512<double> MaxNumber(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x15)};
}

#endif

// ------------------------------ MinMagnitude and MaxMagnitude

#if HWY_X86_HAVE_AVX10_2_OPS

// VMINMAX immediates: 0x16 = minMagnitude, 0x17 = maxMagnitude.
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> MinMagnitude(Vec512<float16_t> a,
                                       Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x16)};
}
#endif
HWY_API Vec512<float> MinMagnitude(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x16)};
}
HWY_API Vec512<double> MinMagnitude(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x16)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> MaxMagnitude(Vec512<float16_t> a,
                                       Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x17)};
}
#endif
HWY_API Vec512<float> MaxMagnitude(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x17)};
}
HWY_API Vec512<double> MaxMagnitude(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x17)};
}

#endif

// ------------------------------ Integer multiplication

// Unsigned
HWY_API Vec512<uint16_t> operator*(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator*(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator*(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_mullo_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int16_t> operator*(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator*(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator*(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_mullo_epi64(a.raw, b.raw)};
}

// Returns the upper 16 bits of a * b in each lane.
HWY_API Vec512<uint16_t> MulHigh(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> MulHigh(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)};
}

// Rounded fixed-point multiply (PMULHRSW): (a * b + 2^14) >> 15.
HWY_API Vec512<int16_t> MulFixedPoint15(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhrs_epi16(a.raw, b.raw)};
}

// Multiplies even lanes (0, 2 ..) and places the double-wide result into
// even and the upper half into its odd neighbor lane.
1837 HWY_API Vec512<int64_t> MulEven(Vec512<int32_t> a, Vec512<int32_t> b) { 1838 return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)}; 1839 } 1840 HWY_API Vec512<uint64_t> MulEven(Vec512<uint32_t> a, Vec512<uint32_t> b) { 1841 return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)}; 1842 } 1843 1844 // ------------------------------ Neg (Sub) 1845 1846 template <typename T, HWY_IF_FLOAT_OR_SPECIAL(T)> 1847 HWY_API Vec512<T> Neg(const Vec512<T> v) { 1848 const DFromV<decltype(v)> d; 1849 return Xor(v, SignBit(d)); 1850 } 1851 1852 template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)> 1853 HWY_API Vec512<T> Neg(const Vec512<T> v) { 1854 const DFromV<decltype(v)> d; 1855 return Zero(d) - v; 1856 } 1857 1858 // ------------------------------ Floating-point mul / div 1859 1860 #if HWY_HAVE_FLOAT16 1861 HWY_API Vec512<float16_t> operator*(Vec512<float16_t> a, Vec512<float16_t> b) { 1862 return Vec512<float16_t>{_mm512_mul_ph(a.raw, b.raw)}; 1863 } 1864 #endif // HWY_HAVE_FLOAT16 1865 HWY_API Vec512<float> operator*(Vec512<float> a, Vec512<float> b) { 1866 return Vec512<float>{_mm512_mul_ps(a.raw, b.raw)}; 1867 } 1868 HWY_API Vec512<double> operator*(Vec512<double> a, Vec512<double> b) { 1869 return Vec512<double>{_mm512_mul_pd(a.raw, b.raw)}; 1870 } 1871 1872 #if HWY_HAVE_FLOAT16 1873 HWY_API Vec512<float16_t> MulByFloorPow2(Vec512<float16_t> a, 1874 Vec512<float16_t> b) { 1875 return Vec512<float16_t>{_mm512_scalef_ph(a.raw, b.raw)}; 1876 } 1877 #endif 1878 1879 HWY_API Vec512<float> MulByFloorPow2(Vec512<float> a, Vec512<float> b) { 1880 return Vec512<float>{_mm512_scalef_ps(a.raw, b.raw)}; 1881 } 1882 1883 HWY_API Vec512<double> MulByFloorPow2(Vec512<double> a, Vec512<double> b) { 1884 return Vec512<double>{_mm512_scalef_pd(a.raw, b.raw)}; 1885 } 1886 1887 #if HWY_HAVE_FLOAT16 1888 HWY_API Vec512<float16_t> operator/(Vec512<float16_t> a, Vec512<float16_t> b) { 1889 return Vec512<float16_t>{_mm512_div_ph(a.raw, b.raw)}; 1890 } 1891 #endif // HWY_HAVE_FLOAT16 1892 
HWY_API Vec512<float> operator/(Vec512<float> a, Vec512<float> b) { 1893 return Vec512<float>{_mm512_div_ps(a.raw, b.raw)}; 1894 } 1895 HWY_API Vec512<double> operator/(Vec512<double> a, Vec512<double> b) { 1896 return Vec512<double>{_mm512_div_pd(a.raw, b.raw)}; 1897 } 1898 1899 // Approximate reciprocal 1900 #if HWY_HAVE_FLOAT16 1901 HWY_API Vec512<float16_t> ApproximateReciprocal(const Vec512<float16_t> v) { 1902 return Vec512<float16_t>{_mm512_rcp_ph(v.raw)}; 1903 } 1904 #endif // HWY_HAVE_FLOAT16 1905 HWY_API Vec512<float> ApproximateReciprocal(const Vec512<float> v) { 1906 return Vec512<float>{_mm512_rcp14_ps(v.raw)}; 1907 } 1908 1909 HWY_API Vec512<double> ApproximateReciprocal(Vec512<double> v) { 1910 return Vec512<double>{_mm512_rcp14_pd(v.raw)}; 1911 } 1912 1913 // ------------------------------ GetExponent 1914 1915 #if HWY_HAVE_FLOAT16 1916 template <class V, HWY_IF_F16(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> 1917 HWY_API V GetExponent(V v) { 1918 return V{_mm512_getexp_ph(v.raw)}; 1919 } 1920 #endif 1921 template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> 1922 HWY_API V GetExponent(V v) { 1923 // Work around warnings in the intrinsic definitions (passing -1 as a mask). 1924 HWY_DIAGNOSTICS(push) 1925 HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") 1926 return V{_mm512_getexp_ps(v.raw)}; 1927 HWY_DIAGNOSTICS(pop) 1928 } 1929 template <class V, HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)> 1930 HWY_API V GetExponent(V v) { 1931 // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
1932 HWY_DIAGNOSTICS(push) 1933 HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") 1934 return V{_mm512_getexp_pd(v.raw)}; 1935 HWY_DIAGNOSTICS(pop) 1936 } 1937 1938 // ------------------------------ MaskedMinOr 1939 1940 template <typename T, HWY_IF_U8(T)> 1941 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1942 Vec512<T> b) { 1943 return Vec512<T>{_mm512_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; 1944 } 1945 template <typename T, HWY_IF_I8(T)> 1946 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1947 Vec512<T> b) { 1948 return Vec512<T>{_mm512_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; 1949 } 1950 1951 template <typename T, HWY_IF_U16(T)> 1952 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1953 Vec512<T> b) { 1954 return Vec512<T>{_mm512_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; 1955 } 1956 template <typename T, HWY_IF_I16(T)> 1957 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1958 Vec512<T> b) { 1959 return Vec512<T>{_mm512_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; 1960 } 1961 1962 template <typename T, HWY_IF_U32(T)> 1963 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1964 Vec512<T> b) { 1965 return Vec512<T>{_mm512_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; 1966 } 1967 template <typename T, HWY_IF_I32(T)> 1968 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1969 Vec512<T> b) { 1970 return Vec512<T>{_mm512_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; 1971 } 1972 1973 template <typename T, HWY_IF_U64(T)> 1974 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1975 Vec512<T> b) { 1976 return Vec512<T>{_mm512_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; 1977 } 1978 template <typename T, HWY_IF_I64(T)> 1979 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1980 Vec512<T> b) { 1981 return Vec512<T>{_mm512_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; 
1982 } 1983 1984 template <typename T, HWY_IF_F32(T)> 1985 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1986 Vec512<T> b) { 1987 return Vec512<T>{_mm512_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; 1988 } 1989 1990 template <typename T, HWY_IF_F64(T)> 1991 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1992 Vec512<T> b) { 1993 return Vec512<T>{_mm512_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; 1994 } 1995 1996 #if HWY_HAVE_FLOAT16 1997 template <typename T, HWY_IF_F16(T)> 1998 HWY_API Vec512<T> MaskedMinOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 1999 Vec512<T> b) { 2000 return Vec512<T>{_mm512_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; 2001 } 2002 #endif // HWY_HAVE_FLOAT16 2003 2004 // ------------------------------ MaskedMaxOr 2005 2006 template <typename T, HWY_IF_U8(T)> 2007 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2008 Vec512<T> b) { 2009 return Vec512<T>{_mm512_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; 2010 } 2011 template <typename T, HWY_IF_I8(T)> 2012 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2013 Vec512<T> b) { 2014 return Vec512<T>{_mm512_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; 2015 } 2016 2017 template <typename T, HWY_IF_U16(T)> 2018 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2019 Vec512<T> b) { 2020 return Vec512<T>{_mm512_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; 2021 } 2022 template <typename T, HWY_IF_I16(T)> 2023 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2024 Vec512<T> b) { 2025 return Vec512<T>{_mm512_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; 2026 } 2027 2028 template <typename T, HWY_IF_U32(T)> 2029 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2030 Vec512<T> b) { 2031 return Vec512<T>{_mm512_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; 2032 } 2033 template <typename T, HWY_IF_I32(T)> 2034 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> 
m, Vec512<T> a, 2035 Vec512<T> b) { 2036 return Vec512<T>{_mm512_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; 2037 } 2038 2039 template <typename T, HWY_IF_U64(T)> 2040 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2041 Vec512<T> b) { 2042 return Vec512<T>{_mm512_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; 2043 } 2044 template <typename T, HWY_IF_I64(T)> 2045 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2046 Vec512<T> b) { 2047 return Vec512<T>{_mm512_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; 2048 } 2049 2050 template <typename T, HWY_IF_F32(T)> 2051 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2052 Vec512<T> b) { 2053 return Vec512<T>{_mm512_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; 2054 } 2055 2056 template <typename T, HWY_IF_F64(T)> 2057 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2058 Vec512<T> b) { 2059 return Vec512<T>{_mm512_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; 2060 } 2061 2062 #if HWY_HAVE_FLOAT16 2063 template <typename T, HWY_IF_F16(T)> 2064 HWY_API Vec512<T> MaskedMaxOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2065 Vec512<T> b) { 2066 return Vec512<T>{_mm512_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; 2067 } 2068 #endif // HWY_HAVE_FLOAT16 2069 2070 // ------------------------------ MaskedAddOr 2071 2072 template <typename T, HWY_IF_UI8(T)> 2073 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2074 Vec512<T> b) { 2075 return Vec512<T>{_mm512_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; 2076 } 2077 2078 template <typename T, HWY_IF_UI16(T)> 2079 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2080 Vec512<T> b) { 2081 return Vec512<T>{_mm512_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; 2082 } 2083 2084 template <typename T, HWY_IF_UI32(T)> 2085 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2086 Vec512<T> b) { 2087 return Vec512<T>{_mm512_mask_add_epi32(no.raw, m.raw, 
a.raw, b.raw)}; 2088 } 2089 2090 template <typename T, HWY_IF_UI64(T)> 2091 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2092 Vec512<T> b) { 2093 return Vec512<T>{_mm512_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; 2094 } 2095 2096 template <typename T, HWY_IF_F32(T)> 2097 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2098 Vec512<T> b) { 2099 return Vec512<T>{_mm512_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; 2100 } 2101 2102 template <typename T, HWY_IF_F64(T)> 2103 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2104 Vec512<T> b) { 2105 return Vec512<T>{_mm512_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; 2106 } 2107 2108 #if HWY_HAVE_FLOAT16 2109 template <typename T, HWY_IF_F16(T)> 2110 HWY_API Vec512<T> MaskedAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2111 Vec512<T> b) { 2112 return Vec512<T>{_mm512_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; 2113 } 2114 #endif // HWY_HAVE_FLOAT16 2115 2116 // ------------------------------ MaskedSubOr 2117 2118 template <typename T, HWY_IF_UI8(T)> 2119 HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2120 Vec512<T> b) { 2121 return Vec512<T>{_mm512_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; 2122 } 2123 2124 template <typename T, HWY_IF_UI16(T)> 2125 HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2126 Vec512<T> b) { 2127 return Vec512<T>{_mm512_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; 2128 } 2129 2130 template <typename T, HWY_IF_UI32(T)> 2131 HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2132 Vec512<T> b) { 2133 return Vec512<T>{_mm512_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; 2134 } 2135 2136 template <typename T, HWY_IF_UI64(T)> 2137 HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2138 Vec512<T> b) { 2139 return Vec512<T>{_mm512_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; 2140 } 2141 2142 template <typename T, HWY_IF_F32(T)> 2143 HWY_API 
Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2144 Vec512<T> b) { 2145 return Vec512<T>{_mm512_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; 2146 } 2147 2148 template <typename T, HWY_IF_F64(T)> 2149 HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2150 Vec512<T> b) { 2151 return Vec512<T>{_mm512_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; 2152 } 2153 2154 #if HWY_HAVE_FLOAT16 2155 template <typename T, HWY_IF_F16(T)> 2156 HWY_API Vec512<T> MaskedSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2157 Vec512<T> b) { 2158 return Vec512<T>{_mm512_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; 2159 } 2160 #endif // HWY_HAVE_FLOAT16 2161 2162 // ------------------------------ MaskedMulOr 2163 2164 HWY_API Vec512<float> MaskedMulOr(Vec512<float> no, Mask512<float> m, 2165 Vec512<float> a, Vec512<float> b) { 2166 return Vec512<float>{_mm512_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; 2167 } 2168 2169 HWY_API Vec512<double> MaskedMulOr(Vec512<double> no, Mask512<double> m, 2170 Vec512<double> a, Vec512<double> b) { 2171 return Vec512<double>{_mm512_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; 2172 } 2173 2174 #if HWY_HAVE_FLOAT16 2175 HWY_API Vec512<float16_t> MaskedMulOr(Vec512<float16_t> no, 2176 Mask512<float16_t> m, Vec512<float16_t> a, 2177 Vec512<float16_t> b) { 2178 return Vec512<float16_t>{_mm512_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; 2179 } 2180 #endif // HWY_HAVE_FLOAT16 2181 2182 // ------------------------------ MaskedDivOr 2183 2184 HWY_API Vec512<float> MaskedDivOr(Vec512<float> no, Mask512<float> m, 2185 Vec512<float> a, Vec512<float> b) { 2186 return Vec512<float>{_mm512_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; 2187 } 2188 2189 HWY_API Vec512<double> MaskedDivOr(Vec512<double> no, Mask512<double> m, 2190 Vec512<double> a, Vec512<double> b) { 2191 return Vec512<double>{_mm512_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; 2192 } 2193 2194 #if HWY_HAVE_FLOAT16 2195 HWY_API Vec512<float16_t> MaskedDivOr(Vec512<float16_t> no, 2196 
Mask512<float16_t> m, Vec512<float16_t> a, 2197 Vec512<float16_t> b) { 2198 return Vec512<float16_t>{_mm512_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; 2199 } 2200 #endif // HWY_HAVE_FLOAT16 2201 2202 // ------------------------------ MaskedSatAddOr 2203 2204 template <typename T, HWY_IF_I8(T)> 2205 HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2206 Vec512<T> b) { 2207 return Vec512<T>{_mm512_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; 2208 } 2209 2210 template <typename T, HWY_IF_U8(T)> 2211 HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2212 Vec512<T> b) { 2213 return Vec512<T>{_mm512_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; 2214 } 2215 2216 template <typename T, HWY_IF_I16(T)> 2217 HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2218 Vec512<T> b) { 2219 return Vec512<T>{_mm512_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; 2220 } 2221 2222 template <typename T, HWY_IF_U16(T)> 2223 HWY_API Vec512<T> MaskedSatAddOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2224 Vec512<T> b) { 2225 return Vec512<T>{_mm512_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; 2226 } 2227 2228 // ------------------------------ MaskedSatSubOr 2229 2230 template <typename T, HWY_IF_I8(T)> 2231 HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2232 Vec512<T> b) { 2233 return Vec512<T>{_mm512_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; 2234 } 2235 2236 template <typename T, HWY_IF_U8(T)> 2237 HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2238 Vec512<T> b) { 2239 return Vec512<T>{_mm512_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; 2240 } 2241 2242 template <typename T, HWY_IF_I16(T)> 2243 HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> no, Mask512<T> m, Vec512<T> a, 2244 Vec512<T> b) { 2245 return Vec512<T>{_mm512_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; 2246 } 2247 2248 template <typename T, HWY_IF_U16(T)> 2249 HWY_API Vec512<T> MaskedSatSubOr(Vec512<T> 
no, Mask512<T> m, Vec512<T> a, 2250 Vec512<T> b) { 2251 return Vec512<T>{_mm512_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; 2252 } 2253 2254 // ------------------------------ Floating-point multiply-add variants 2255 2256 #if HWY_HAVE_FLOAT16 2257 2258 HWY_API Vec512<float16_t> MulAdd(Vec512<float16_t> mul, Vec512<float16_t> x, 2259 Vec512<float16_t> add) { 2260 return Vec512<float16_t>{_mm512_fmadd_ph(mul.raw, x.raw, add.raw)}; 2261 } 2262 2263 HWY_API Vec512<float16_t> NegMulAdd(Vec512<float16_t> mul, Vec512<float16_t> x, 2264 Vec512<float16_t> add) { 2265 return Vec512<float16_t>{_mm512_fnmadd_ph(mul.raw, x.raw, add.raw)}; 2266 } 2267 2268 HWY_API Vec512<float16_t> MulSub(Vec512<float16_t> mul, Vec512<float16_t> x, 2269 Vec512<float16_t> sub) { 2270 return Vec512<float16_t>{_mm512_fmsub_ph(mul.raw, x.raw, sub.raw)}; 2271 } 2272 2273 HWY_API Vec512<float16_t> NegMulSub(Vec512<float16_t> mul, Vec512<float16_t> x, 2274 Vec512<float16_t> sub) { 2275 return Vec512<float16_t>{_mm512_fnmsub_ph(mul.raw, x.raw, sub.raw)}; 2276 } 2277 2278 #endif // HWY_HAVE_FLOAT16 2279 2280 // Returns mul * x + add 2281 HWY_API Vec512<float> MulAdd(Vec512<float> mul, Vec512<float> x, 2282 Vec512<float> add) { 2283 return Vec512<float>{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; 2284 } 2285 HWY_API Vec512<double> MulAdd(Vec512<double> mul, Vec512<double> x, 2286 Vec512<double> add) { 2287 return Vec512<double>{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; 2288 } 2289 2290 // Returns add - mul * x 2291 HWY_API Vec512<float> NegMulAdd(Vec512<float> mul, Vec512<float> x, 2292 Vec512<float> add) { 2293 return Vec512<float>{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; 2294 } 2295 HWY_API Vec512<double> NegMulAdd(Vec512<double> mul, Vec512<double> x, 2296 Vec512<double> add) { 2297 return Vec512<double>{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; 2298 } 2299 2300 // Returns mul * x - sub 2301 HWY_API Vec512<float> MulSub(Vec512<float> mul, Vec512<float> x, 2302 Vec512<float> sub) { 2303 return 
Vec512<float>{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; 2304 } 2305 HWY_API Vec512<double> MulSub(Vec512<double> mul, Vec512<double> x, 2306 Vec512<double> sub) { 2307 return Vec512<double>{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; 2308 } 2309 2310 // Returns -mul * x - sub 2311 HWY_API Vec512<float> NegMulSub(Vec512<float> mul, Vec512<float> x, 2312 Vec512<float> sub) { 2313 return Vec512<float>{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; 2314 } 2315 HWY_API Vec512<double> NegMulSub(Vec512<double> mul, Vec512<double> x, 2316 Vec512<double> sub) { 2317 return Vec512<double>{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; 2318 } 2319 2320 #if HWY_HAVE_FLOAT16 2321 HWY_API Vec512<float16_t> MulAddSub(Vec512<float16_t> mul, Vec512<float16_t> x, 2322 Vec512<float16_t> sub_or_add) { 2323 return Vec512<float16_t>{_mm512_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; 2324 } 2325 #endif // HWY_HAVE_FLOAT16 2326 2327 HWY_API Vec512<float> MulAddSub(Vec512<float> mul, Vec512<float> x, 2328 Vec512<float> sub_or_add) { 2329 return Vec512<float>{_mm512_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; 2330 } 2331 2332 HWY_API Vec512<double> MulAddSub(Vec512<double> mul, Vec512<double> x, 2333 Vec512<double> sub_or_add) { 2334 return Vec512<double>{_mm512_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; 2335 } 2336 2337 // ------------------------------ Floating-point square root 2338 2339 // Full precision square root 2340 #if HWY_HAVE_FLOAT16 2341 HWY_API Vec512<float16_t> Sqrt(const Vec512<float16_t> v) { 2342 return Vec512<float16_t>{_mm512_sqrt_ph(v.raw)}; 2343 } 2344 #endif // HWY_HAVE_FLOAT16 2345 HWY_API Vec512<float> Sqrt(const Vec512<float> v) { 2346 return Vec512<float>{_mm512_sqrt_ps(v.raw)}; 2347 } 2348 HWY_API Vec512<double> Sqrt(const Vec512<double> v) { 2349 return Vec512<double>{_mm512_sqrt_pd(v.raw)}; 2350 } 2351 2352 // Approximate reciprocal square root 2353 #if HWY_HAVE_FLOAT16 2354 HWY_API Vec512<float16_t> ApproximateReciprocalSqrt(Vec512<float16_t> v) { 2355 return 
Vec512<float16_t>{_mm512_rsqrt_ph(v.raw)}; 2356 } 2357 #endif // HWY_HAVE_FLOAT16 2358 HWY_API Vec512<float> ApproximateReciprocalSqrt(Vec512<float> v) { 2359 return Vec512<float>{_mm512_rsqrt14_ps(v.raw)}; 2360 } 2361 2362 HWY_API Vec512<double> ApproximateReciprocalSqrt(Vec512<double> v) { 2363 return Vec512<double>{_mm512_rsqrt14_pd(v.raw)}; 2364 } 2365 2366 // ------------------------------ Floating-point rounding 2367 2368 // Work around warnings in the intrinsic definitions (passing -1 as a mask). 2369 HWY_DIAGNOSTICS(push) 2370 HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") 2371 2372 // Toward nearest integer, tie to even 2373 #if HWY_HAVE_FLOAT16 2374 HWY_API Vec512<float16_t> Round(Vec512<float16_t> v) { 2375 return Vec512<float16_t>{_mm512_roundscale_ph( 2376 v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; 2377 } 2378 #endif // HWY_HAVE_FLOAT16 2379 HWY_API Vec512<float> Round(Vec512<float> v) { 2380 return Vec512<float>{_mm512_roundscale_ps( 2381 v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; 2382 } 2383 HWY_API Vec512<double> Round(Vec512<double> v) { 2384 return Vec512<double>{_mm512_roundscale_pd( 2385 v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; 2386 } 2387 2388 // Toward zero, aka truncate 2389 #if HWY_HAVE_FLOAT16 2390 HWY_API Vec512<float16_t> Trunc(Vec512<float16_t> v) { 2391 return Vec512<float16_t>{ 2392 _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; 2393 } 2394 #endif // HWY_HAVE_FLOAT16 2395 HWY_API Vec512<float> Trunc(Vec512<float> v) { 2396 return Vec512<float>{ 2397 _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; 2398 } 2399 HWY_API Vec512<double> Trunc(Vec512<double> v) { 2400 return Vec512<double>{ 2401 _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; 2402 } 2403 2404 // Toward +infinity, aka ceiling 2405 #if HWY_HAVE_FLOAT16 2406 HWY_API Vec512<float16_t> Ceil(Vec512<float16_t> v) { 2407 return Vec512<float16_t>{ 2408 
_mm512_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; 2409 } 2410 #endif // HWY_HAVE_FLOAT16 2411 HWY_API Vec512<float> Ceil(Vec512<float> v) { 2412 return Vec512<float>{ 2413 _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; 2414 } 2415 HWY_API Vec512<double> Ceil(Vec512<double> v) { 2416 return Vec512<double>{ 2417 _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; 2418 } 2419 2420 // Toward -infinity, aka floor 2421 #if HWY_HAVE_FLOAT16 2422 HWY_API Vec512<float16_t> Floor(Vec512<float16_t> v) { 2423 return Vec512<float16_t>{ 2424 _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; 2425 } 2426 #endif // HWY_HAVE_FLOAT16 2427 HWY_API Vec512<float> Floor(Vec512<float> v) { 2428 return Vec512<float>{ 2429 _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; 2430 } 2431 HWY_API Vec512<double> Floor(Vec512<double> v) { 2432 return Vec512<double>{ 2433 _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; 2434 } 2435 2436 HWY_DIAGNOSTICS(pop) 2437 2438 // ================================================== COMPARE 2439 2440 // Comparisons set a mask bit to 1 if the condition is true, else 0. 
2441 2442 template <class DTo, typename TFrom> 2443 HWY_API MFromD<DTo> RebindMask(DTo /*tag*/, Mask512<TFrom> m) { 2444 static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size"); 2445 return MFromD<DTo>{m.raw}; 2446 } 2447 2448 namespace detail { 2449 2450 template <typename T> 2451 HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<1> /*tag*/, Vec512<T> v, 2452 Vec512<T> bit) { 2453 return Mask512<T>{_mm512_test_epi8_mask(v.raw, bit.raw)}; 2454 } 2455 template <typename T> 2456 HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<2> /*tag*/, Vec512<T> v, 2457 Vec512<T> bit) { 2458 return Mask512<T>{_mm512_test_epi16_mask(v.raw, bit.raw)}; 2459 } 2460 template <typename T> 2461 HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<4> /*tag*/, Vec512<T> v, 2462 Vec512<T> bit) { 2463 return Mask512<T>{_mm512_test_epi32_mask(v.raw, bit.raw)}; 2464 } 2465 template <typename T> 2466 HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<8> /*tag*/, Vec512<T> v, 2467 Vec512<T> bit) { 2468 return Mask512<T>{_mm512_test_epi64_mask(v.raw, bit.raw)}; 2469 } 2470 2471 } // namespace detail 2472 2473 template <typename T> 2474 HWY_API Mask512<T> TestBit(const Vec512<T> v, const Vec512<T> bit) { 2475 static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); 2476 return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit); 2477 } 2478 2479 // ------------------------------ Equality 2480 2481 template <typename T, HWY_IF_T_SIZE(T, 1)> 2482 HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { 2483 return Mask512<T>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; 2484 } 2485 template <typename T, HWY_IF_T_SIZE(T, 2)> 2486 HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { 2487 return Mask512<T>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; 2488 } 2489 template <typename T, HWY_IF_UI32(T)> 2490 HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { 2491 return Mask512<T>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; 2492 } 2493 template <typename T, HWY_IF_UI64(T)> 2494 HWY_API Mask512<T> 
operator==(Vec512<T> a, Vec512<T> b) {
  return Mask512<T>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask512<float16_t> operator==(Vec512<float16_t> a,
                                      Vec512<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  // Ordered quiet comparison: lanes with a NaN input yield false.
  return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask512<float> operator==(Vec512<float> a, Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}

HWY_API Mask512<double> operator==(Vec512<double> a, Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}

// ------------------------------ Inequality

// Integer != dispatches on lane size; signed/unsigned share the same
// bit-pattern comparison.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) {
  return Mask512<T>{_mm512_cmpneq_epi8_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) {
  return Mask512<T>{_mm512_cmpneq_epi16_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) {
  return Mask512<T>{_mm512_cmpneq_epi32_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) {
  return Mask512<T>{_mm512_cmpneq_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask512<float16_t> operator!=(Vec512<float16_t> a,
                                      Vec512<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask512<float> operator!=(Vec512<float> a, Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
}

HWY_API Mask512<double> operator!=(Vec512<double> a, Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
}

// ------------------------------ Strict inequality

// Unsigned > uses the epu* compares, signed > the epi* compares.
HWY_API Mask512<uint8_t> operator>(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Mask512<uint8_t>{_mm512_cmpgt_epu8_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint16_t> operator>(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Mask512<uint16_t>{_mm512_cmpgt_epu16_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint32_t> operator>(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Mask512<uint32_t>{_mm512_cmpgt_epu32_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint64_t> operator>(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Mask512<uint64_t>{_mm512_cmpgt_epu64_mask(a.raw, b.raw)};
}

HWY_API Mask512<int8_t> operator>(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask512<int16_t> operator>(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask512<int32_t> operator>(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask512<int64_t> operator>(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask512<float16_t> operator>(Vec512<float16_t> a,
                                     Vec512<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask512<float> operator>(Vec512<float> a, Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)};
}
HWY_API Mask512<double> operator>(Vec512<double> a, Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)};
}

// ------------------------------ Weak inequality

#if HWY_HAVE_FLOAT16
HWY_API Mask512<float16_t> operator>=(Vec512<float16_t> a,
                                      Vec512<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask512<float> operator>=(Vec512<float> a, Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)};
}
HWY_API Mask512<double> operator>=(Vec512<double> a, Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)};
}

HWY_API Mask512<uint8_t> operator>=(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Mask512<uint8_t>{_mm512_cmpge_epu8_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint16_t> operator>=(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Mask512<uint16_t>{_mm512_cmpge_epu16_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint32_t> operator>=(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Mask512<uint32_t>{_mm512_cmpge_epu32_mask(a.raw, b.raw)};
}
HWY_API
Mask512<uint64_t> operator>=(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Mask512<uint64_t>{_mm512_cmpge_epu64_mask(a.raw, b.raw)};
}

HWY_API Mask512<int8_t> operator>=(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Mask512<int8_t>{_mm512_cmpge_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask512<int16_t> operator>=(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Mask512<int16_t>{_mm512_cmpge_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask512<int32_t> operator>=(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Mask512<int32_t>{_mm512_cmpge_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask512<int64_t> operator>=(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Mask512<int64_t>{_mm512_cmpge_epi64_mask(a.raw, b.raw)};
}

// ------------------------------ Reversed comparisons

// < and <= are defined in terms of > and >= with swapped operands.
template <typename T>
HWY_API Mask512<T> operator<(Vec512<T> a, Vec512<T> b) {
  return b > a;
}

template <typename T>
HWY_API Mask512<T> operator<=(Vec512<T> a, Vec512<T> b) {
  return b >= a;
}

// ------------------------------ Mask

// Extracts the MSB of each lane into a compact k-mask register.
template <typename T, HWY_IF_UI8(T)>
HWY_API Mask512<T> MaskFromVec(Vec512<T> v) {
  return Mask512<T>{_mm512_movepi8_mask(v.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Mask512<T> MaskFromVec(Vec512<T> v) {
  return Mask512<T>{_mm512_movepi16_mask(v.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask512<T> MaskFromVec(Vec512<T> v) {
  return Mask512<T>{_mm512_movepi32_mask(v.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask512<T> MaskFromVec(Vec512<T> v) {
  return Mask512<T>{_mm512_movepi64_mask(v.raw)};
}
// Floating-point/special types reuse the same-size integer path.
template <typename T, HWY_IF_FLOAT_OR_SPECIAL(T)>
HWY_API Mask512<T> MaskFromVec(Vec512<T> v) {
  const RebindToSigned<DFromV<decltype(v)>> di;
  return Mask512<T>{MaskFromVec(BitCast(di, v)).raw};
}

// Expands a k-mask to a vector with all-ones lanes where the mask is set.
template
<typename T, HWY_IF_UI8(T)>
HWY_API Vec512<T> VecFromMask(Mask512<T> m) {
  return Vec512<T>{_mm512_movm_epi8(m.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Vec512<T> VecFromMask(Mask512<T> m) {
  return Vec512<T>{_mm512_movm_epi16(m.raw)};
}
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> VecFromMask(Mask512<float16_t> m) {
  return Vec512<float16_t>{_mm512_castsi512_ph(_mm512_movm_epi16(m.raw))};
}
#endif  // HWY_HAVE_FLOAT16
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec512<T> VecFromMask(Mask512<T> m) {
  return Vec512<T>{_mm512_movm_epi32(m.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Vec512<T> VecFromMask(Mask512<T> m) {
  return Vec512<T>{_mm512_movm_epi64(m.raw)};
}
template <typename T, HWY_IF_FLOAT_OR_SPECIAL(T)>
HWY_API Vec512<T> VecFromMask(Mask512<T> m) {
  const Full512<T> d;
  const Full512<MakeSigned<T>> di;
  return BitCast(d, VecFromMask(RebindMask(di, m)));
}

// ------------------------------ Mask logical

namespace detail {

// Per-lane-size dispatch: 1-byte lanes use 64-bit masks, 2-byte 32-bit,
// 4-byte 16-bit and 8-byte 8-bit. Without mask intrinsics, fall back to
// integer ops; narrower masks must clear the (undefined) upper bits after ~.
template <typename T>
HWY_INLINE Mask512<T> Not(hwy::SizeTag<1> /*tag*/, Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask64(m.raw)};
#else
  return Mask512<T>{~m.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Not(hwy::SizeTag<2> /*tag*/, Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask32(m.raw)};
#else
  return Mask512<T>{~m.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Not(hwy::SizeTag<4> /*tag*/, Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask16(m.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(~m.raw & 0xFFFF)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Not(hwy::SizeTag<8> /*tag*/, Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask8(m.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(~m.raw & 0xFF)};
#endif
}

template <typename T>
HWY_INLINE Mask512<T> And(hwy::SizeTag<1> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw & b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> And(hwy::SizeTag<2> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw & b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> And(hwy::SizeTag<4> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> And(hwy::SizeTag<8> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(a.raw & b.raw)};
#endif
}

// AndNot(a, b) computes ~a & b, matching the kandn semantics.
template <typename T>
HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<1> /*tag*/, Mask512<T> a,
                             Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{~a.raw & b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<2> /*tag*/, Mask512<T> a,
                             Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{~a.raw & b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<4> /*tag*/, Mask512<T> a,
                             Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<8> /*tag*/, Mask512<T> a,
                             Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(~a.raw & b.raw)};
#endif
}

template <typename T>
HWY_INLINE Mask512<T> Or(hwy::SizeTag<1> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw | b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Or(hwy::SizeTag<2> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw | b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Or(hwy::SizeTag<4> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Or(hwy::SizeTag<8> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(a.raw | b.raw)};
#endif
}

template <typename T>
HWY_INLINE Mask512<T> Xor(hwy::SizeTag<1> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw ^ b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T>
Xor(hwy::SizeTag<2> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw ^ b.raw};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Xor(hwy::SizeTag<4> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> Xor(hwy::SizeTag<8> /*tag*/, Mask512<T> a, Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(a.raw ^ b.raw)};
#endif
}

// XNOR: set where a and b are equal, i.e. neither-exclusively set.
template <typename T>
HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<1> /*tag*/, Mask512<T> a,
                                       Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxnor_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{~(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<2> /*tag*/, Mask512<T> a,
                                       Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxnor_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<4> /*tag*/, Mask512<T> a,
                                       Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxnor_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)};
#endif
}
template <typename T>
HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<8> /*tag*/, Mask512<T> a,
                                       Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return
Mask512<T>{_kxnor_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)};
#endif
}

}  // namespace detail

// Public mask-logical entry points dispatch on lane size via SizeTag.
template <typename T>
HWY_API Mask512<T> Not(Mask512<T> m) {
  return detail::Not(hwy::SizeTag<sizeof(T)>(), m);
}

template <typename T>
HWY_API Mask512<T> And(Mask512<T> a, Mask512<T> b) {
  return detail::And(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> AndNot(Mask512<T> a, Mask512<T> b) {
  return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> Or(Mask512<T> a, Mask512<T> b) {
  return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> Xor(Mask512<T> a, Mask512<T> b) {
  return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> ExclusiveNeither(Mask512<T> a, Mask512<T> b) {
  return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b);
}

// Concatenates two 32-bit half-masks into a 64-bit mask (hi in upper bits).
template <class D, HWY_IF_LANES_D(D, 64)>
HWY_API MFromD<D> CombineMasks(D /*d*/, MFromD<Half<D>> hi,
                               MFromD<Half<D>> lo) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  const __mmask64 combined_mask = _mm512_kunpackd(
      static_cast<__mmask64>(hi.raw), static_cast<__mmask64>(lo.raw));
#else
  const __mmask64 combined_mask = static_cast<__mmask64>(
      ((static_cast<uint64_t>(hi.raw) << 32) | (lo.raw & 0xFFFFFFFFULL)));
#endif

  return MFromD<D>{combined_mask};
}

// Returns the upper 32 bits of a 64-bit mask as a 32-bit mask.
template <class D, HWY_IF_LANES_D(D, 32)>
HWY_API MFromD<D> UpperHalfOfMask(D /*d*/, MFromD<Twice<D>> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  const auto shifted_mask = _kshiftri_mask64(static_cast<__mmask64>(m.raw), 32);
#else
  const auto shifted_mask = static_cast<uint64_t>(m.raw) >> 32;
#endif

  return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(shifted_mask)};
}

// Shifts the 64-bit mask one lane towards higher indices (bit 0 cleared).
template <class D, HWY_IF_LANES_D(D, 64)>
HWY_API MFromD<D> SlideMask1Up(D /*d*/, MFromD<D> m) {
  using RawM = decltype(MFromD<D>().raw);
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return MFromD<D>{
      static_cast<RawM>(_kshiftli_mask64(static_cast<__mmask64>(m.raw), 1))};
#else
  return MFromD<D>{static_cast<RawM>(static_cast<uint64_t>(m.raw) << 1)};
#endif
}

// Shifts the 64-bit mask one lane towards lower indices (top bit cleared).
template <class D, HWY_IF_LANES_D(D, 64)>
HWY_API MFromD<D> SlideMask1Down(D /*d*/, MFromD<D> m) {
  using RawM = decltype(MFromD<D>().raw);
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return MFromD<D>{
      static_cast<RawM>(_kshiftri_mask64(static_cast<__mmask64>(m.raw), 1))};
#else
  return MFromD<D>{static_cast<RawM>(static_cast<uint64_t>(m.raw) >> 1)};
#endif
}

// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask)

HWY_API Vec512<int8_t> BroadcastSignBit(Vec512<int8_t> v) {
#if HWY_TARGET <= HWY_AVX3_DL
  // GF2P8 affine with this matrix replicates the sign bit to all 8 bits.
  const Repartition<uint64_t, DFromV<decltype(v)>> du64;
  return detail::GaloisAffine(v, Set(du64, 0x8080808080808080ull));
#else
  // No 8-bit arithmetic shift; derive all-ones lanes from a signed compare.
  const DFromV<decltype(v)> d;
  return VecFromMask(v < Zero(d));
#endif
}

HWY_API Vec512<int16_t> BroadcastSignBit(Vec512<int16_t> v) {
  return ShiftRight<15>(v);
}

HWY_API Vec512<int32_t> BroadcastSignBit(Vec512<int32_t> v) {
  return ShiftRight<31>(v);
}

HWY_API Vec512<int64_t> BroadcastSignBit(Vec512<int64_t> v) {
  return ShiftRight<63>(v);
}

// ------------------------------ Floating-point classification (Not)

#if HWY_HAVE_FLOAT16 || HWY_IDE

namespace detail {

template <int kCategories>
__mmask32 Fix_mm512_fpclass_ph_mask(__m512h v) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500
  // GCC's _mm512_fpclass_ph_mask uses `__mmask8` instead of `__mmask32`,
  // hence only the first 8 lanes are set; call the builtin directly.
  return static_cast<__mmask32>(__builtin_ia32_fpclassph512_mask(
      static_cast<__v32hf>(v), kCategories, static_cast<__mmask32>(-1)));
#else
  return _mm512_fpclass_ph_mask(v, kCategories);
#endif
}

}  // namespace detail

HWY_API Mask512<float16_t> IsNaN(Vec512<float16_t> v) {
  constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN;
  return Mask512<float16_t>{
      detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)};
}

HWY_API Mask512<float16_t> IsEitherNaN(Vec512<float16_t> a,
                                       Vec512<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask512<float16_t>{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)};
  HWY_DIAGNOSTICS(pop)
}

HWY_API Mask512<float16_t> IsInf(Vec512<float16_t> v) {
  constexpr int kCategories = HWY_X86_FPCLASS_POS_INF | HWY_X86_FPCLASS_NEG_INF;
  return Mask512<float16_t>{
      detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)};
}

// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for
// positive, so we have to check for inf/NaN and negate.
3066 HWY_API Mask512<float16_t> IsFinite(Vec512<float16_t> v) { 3067 constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | 3068 HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF; 3069 return Not(Mask512<float16_t>{ 3070 detail::Fix_mm512_fpclass_ph_mask<kCategories>(v.raw)}); 3071 } 3072 3073 #endif // HWY_HAVE_FLOAT16 3074 3075 HWY_API Mask512<float> IsNaN(Vec512<float> v) { 3076 return Mask512<float>{_mm512_fpclass_ps_mask( 3077 v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; 3078 } 3079 HWY_API Mask512<double> IsNaN(Vec512<double> v) { 3080 return Mask512<double>{_mm512_fpclass_pd_mask( 3081 v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; 3082 } 3083 3084 HWY_API Mask512<float> IsEitherNaN(Vec512<float> a, Vec512<float> b) { 3085 return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; 3086 } 3087 3088 HWY_API Mask512<double> IsEitherNaN(Vec512<double> a, Vec512<double> b) { 3089 return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; 3090 } 3091 3092 HWY_API Mask512<float> IsInf(Vec512<float> v) { 3093 return Mask512<float>{_mm512_fpclass_ps_mask( 3094 v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; 3095 } 3096 HWY_API Mask512<double> IsInf(Vec512<double> v) { 3097 return Mask512<double>{_mm512_fpclass_pd_mask( 3098 v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; 3099 } 3100 3101 // Returns whether normal/subnormal/zero. fpclass doesn't have a flag for 3102 // positive, so we have to check for inf/NaN and negate. 
// Finite = not (NaN or +/-inf); computed as the negation of the fpclass mask
// because fpclass has no direct "finite" category.
HWY_API Mask512<float> IsFinite(Vec512<float> v) {
  return Not(Mask512<float>{_mm512_fpclass_ps_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
                 HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)});
}
HWY_API Mask512<double> IsFinite(Vec512<double> v) {
  return Not(Mask512<double>{_mm512_fpclass_pd_mask(
      v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN |
                 HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)});
}

// ================================================== MEMORY

// ------------------------------ Load

// `aligned` must be 64-byte aligned for the aligned-load intrinsics below.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> Load(D /* tag */, const TFromD<D>* HWY_RESTRICT aligned) {
  return VFromD<D>{_mm512_load_si512(aligned)};
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Load(D /* tag */,
                               const float16_t* HWY_RESTRICT aligned) {
  return Vec512<float16_t>{_mm512_load_ph(aligned)};
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Load(D /* tag */, const float* HWY_RESTRICT aligned) {
  return Vec512<float>{_mm512_load_ps(aligned)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> Load(D /* tag */, const double* HWY_RESTRICT aligned) {
  return VFromD<D>{_mm512_load_pd(aligned)};
}

// Unaligned load; `p` has no alignment requirement.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> LoadU(D /* tag */, const TFromD<D>* HWY_RESTRICT p) {
  return VFromD<D>{_mm512_loadu_si512(p)};
}

// bfloat16_t is handled by x86_128-inl.h.
3145 #if HWY_HAVE_FLOAT16 3146 template <class D, HWY_IF_V_SIZE_D(D, 64)> 3147 HWY_API Vec512<float16_t> LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { 3148 return Vec512<float16_t>{_mm512_loadu_ph(p)}; 3149 } 3150 #endif // HWY_HAVE_FLOAT16 3151 template <class D, HWY_IF_V_SIZE_D(D, 64)> 3152 HWY_API Vec512<float> LoadU(D /* tag */, const float* HWY_RESTRICT p) { 3153 return Vec512<float>{_mm512_loadu_ps(p)}; 3154 } 3155 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 3156 HWY_API VFromD<D> LoadU(D /* tag */, const double* HWY_RESTRICT p) { 3157 return VFromD<D>{_mm512_loadu_pd(p)}; 3158 } 3159 3160 // ------------------------------ MaskedLoad 3161 3162 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> 3163 HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */, 3164 const TFromD<D>* HWY_RESTRICT p) { 3165 return VFromD<D>{_mm512_maskz_loadu_epi8(m.raw, p)}; 3166 } 3167 3168 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> 3169 HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d, 3170 const TFromD<D>* HWY_RESTRICT p) { 3171 const RebindToUnsigned<D> du; // for float16_t 3172 return BitCast(d, VFromD<decltype(du)>{_mm512_maskz_loadu_epi16( 3173 m.raw, reinterpret_cast<const uint16_t*>(p))}); 3174 } 3175 3176 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> 3177 HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */, 3178 const TFromD<D>* HWY_RESTRICT p) { 3179 return VFromD<D>{_mm512_maskz_loadu_epi32(m.raw, p)}; 3180 } 3181 3182 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> 3183 HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D /* tag */, 3184 const TFromD<D>* HWY_RESTRICT p) { 3185 return VFromD<D>{_mm512_maskz_loadu_epi64(m.raw, p)}; 3186 } 3187 3188 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 3189 HWY_API Vec512<float> MaskedLoad(Mask512<float> m, D /* tag */, 3190 const float* HWY_RESTRICT p) { 3191 return Vec512<float>{_mm512_maskz_loadu_ps(m.raw, p)}; 3192 } 3193 3194 template 
<class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 3195 HWY_API Vec512<double> MaskedLoad(Mask512<double> m, D /* tag */, 3196 const double* HWY_RESTRICT p) { 3197 return Vec512<double>{_mm512_maskz_loadu_pd(m.raw, p)}; 3198 } 3199 3200 // ------------------------------ MaskedLoadOr 3201 3202 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> 3203 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */, 3204 const TFromD<D>* HWY_RESTRICT p) { 3205 return VFromD<D>{_mm512_mask_loadu_epi8(v.raw, m.raw, p)}; 3206 } 3207 3208 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> 3209 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d, 3210 const TFromD<D>* HWY_RESTRICT p) { 3211 const RebindToUnsigned<decltype(d)> du; // for float16_t 3212 return BitCast( 3213 d, VFromD<decltype(du)>{_mm512_mask_loadu_epi16( 3214 BitCast(du, v).raw, m.raw, reinterpret_cast<const uint16_t*>(p))}); 3215 } 3216 3217 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> 3218 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */, 3219 const TFromD<D>* HWY_RESTRICT p) { 3220 return VFromD<D>{_mm512_mask_loadu_epi32(v.raw, m.raw, p)}; 3221 } 3222 3223 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> 3224 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D /* tag */, 3225 const TFromD<D>* HWY_RESTRICT p) { 3226 return VFromD<D>{_mm512_mask_loadu_epi64(v.raw, m.raw, p)}; 3227 } 3228 3229 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 3230 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, Mask512<float> m, D /* tag */, 3231 const float* HWY_RESTRICT p) { 3232 return VFromD<D>{_mm512_mask_loadu_ps(v.raw, m.raw, p)}; 3233 } 3234 3235 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 3236 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, Mask512<double> m, D /* tag */, 3237 const double* HWY_RESTRICT p) { 3238 return VFromD<D>{_mm512_mask_loadu_pd(v.raw, m.raw, p)}; 3239 } 3240 3241 // 
------------------------------ LoadDup128 3242 3243 // Loads 128 bit and duplicates into both 128-bit halves. This avoids the 3244 // 3-cycle cost of moving data between 128-bit halves and avoids port 5. 3245 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> 3246 HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* const HWY_RESTRICT p) { 3247 const RebindToUnsigned<decltype(d)> du; 3248 const Full128<TFromD<D>> d128; 3249 const RebindToUnsigned<decltype(d128)> du128; 3250 return BitCast(d, VFromD<decltype(du)>{_mm512_broadcast_i32x4( 3251 BitCast(du128, LoadU(d128, p)).raw)}); 3252 } 3253 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 3254 HWY_API VFromD<D> LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { 3255 const __m128 x4 = _mm_loadu_ps(p); 3256 return VFromD<D>{_mm512_broadcast_f32x4(x4)}; 3257 } 3258 3259 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 3260 HWY_API VFromD<D> LoadDup128(D /* tag */, const double* HWY_RESTRICT p) { 3261 const __m128d x2 = _mm_loadu_pd(p); 3262 return VFromD<D>{_mm512_broadcast_f64x2(x2)}; 3263 } 3264 3265 // ------------------------------ Store 3266 3267 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)> 3268 HWY_API void Store(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT aligned) { 3269 _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); 3270 } 3271 // bfloat16_t is handled by x86_128-inl.h. 
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API void Store(Vec512<float16_t> v, D /* tag */,
                   float16_t* HWY_RESTRICT aligned) {
  _mm512_store_ph(aligned, v.raw);
}
#endif
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void Store(Vec512<float> v, D /* tag */, float* HWY_RESTRICT aligned) {
  _mm512_store_ps(aligned, v.raw);
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void Store(VFromD<D> v, D /* tag */, double* HWY_RESTRICT aligned) {
  _mm512_store_pd(aligned, v.raw);
}

// Unaligned store; `p` has no alignment requirement.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API void StoreU(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT p) {
  _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw);
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API void StoreU(Vec512<float16_t> v, D /* tag */,
                    float16_t* HWY_RESTRICT p) {
  _mm512_storeu_ph(p, v.raw);
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void StoreU(Vec512<float> v, D /* tag */, float* HWY_RESTRICT p) {
  _mm512_storeu_ps(p, v.raw);
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void StoreU(Vec512<double> v, D /* tag */, double* HWY_RESTRICT p) {
  _mm512_storeu_pd(p, v.raw);
}

// ------------------------------ BlendedStore

// Stores only the lanes whose mask bit is set; other memory is untouched and
// not written (masked-store semantics).
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */,
                          TFromD<D>* HWY_RESTRICT p) {
  _mm512_mask_storeu_epi8(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
                          TFromD<D>* HWY_RESTRICT p) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  _mm512_mask_storeu_epi16(reinterpret_cast<uint16_t*>(p), m.raw,
                           BitCast(du, v).raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */,
                          TFromD<D>* HWY_RESTRICT p) {
  _mm512_mask_storeu_epi32(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D /* tag */,
                          TFromD<D>* HWY_RESTRICT p) {
  _mm512_mask_storeu_epi64(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void BlendedStore(Vec512<float> v, Mask512<float> m, D /* tag */,
                          float* HWY_RESTRICT p) {
  _mm512_mask_storeu_ps(p, m.raw, v.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void BlendedStore(Vec512<double> v, Mask512<double> m, D /* tag */,
                          double* HWY_RESTRICT p) {
  _mm512_mask_storeu_pd(p, m.raw, v.raw);
}

// ------------------------------ Non-temporal stores

// Bypasses the cache; `aligned` must be 64-byte aligned.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API void Stream(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT aligned) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), BitCast(du, v).raw);
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void Stream(VFromD<D> v, D /* tag */, float* HWY_RESTRICT aligned) {
  _mm512_stream_ps(aligned, v.raw);
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void Stream(VFromD<D> v, D /* tag */, double* HWY_RESTRICT aligned) {
  _mm512_stream_pd(aligned, v.raw);
}

// ------------------------------ ScatterOffset

// Work around warnings in the intrinsic definitions (passing -1 as a mask).
// Work around warnings in the intrinsic definitions (passing -1 as a mask).
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

// ScatterOffset: `offset` is in bytes, hence the intrinsic scale factor is 1.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API void ScatterOffset(VFromD<D> v, D /* tag */,
                           TFromD<D>* HWY_RESTRICT base,
                           VFromD<RebindToSigned<D>> offset) {
  _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API void ScatterOffset(VFromD<D> v, D /* tag */,
                           TFromD<D>* HWY_RESTRICT base,
                           VFromD<RebindToSigned<D>> offset) {
  _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, float* HWY_RESTRICT base,
                           Vec512<int32_t> offset) {
  _mm512_i32scatter_ps(base, offset.raw, v.raw, 1);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void ScatterOffset(VFromD<D> v, D /* tag */, double* HWY_RESTRICT base,
                           Vec512<int64_t> offset) {
  _mm512_i64scatter_pd(base, offset.raw, v.raw, 1);
}

// ------------------------------ ScatterIndex

// ScatterIndex: `index` is in lanes, so scale by the lane size (4 or 8).
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API void ScatterIndex(VFromD<D> v, D /* tag */,
                          TFromD<D>* HWY_RESTRICT base,
                          VFromD<RebindToSigned<D>> index) {
  _mm512_i32scatter_epi32(base, index.raw, v.raw, 4);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API void ScatterIndex(VFromD<D> v, D /* tag */,
                          TFromD<D>* HWY_RESTRICT base,
                          VFromD<RebindToSigned<D>> index) {
  _mm512_i64scatter_epi64(base, index.raw, v.raw, 8);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, float* HWY_RESTRICT base,
                          Vec512<int32_t> index) {
  _mm512_i32scatter_ps(base, index.raw, v.raw, 4);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void ScatterIndex(VFromD<D> v, D /* tag */, double* HWY_RESTRICT base,
                          Vec512<int64_t> index) {
  _mm512_i64scatter_pd(base, index.raw, v.raw, 8);
}

// ------------------------------ MaskedScatterIndex

// As ScatterIndex, but only lanes whose mask bit is set are stored.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */,
                                TFromD<D>* HWY_RESTRICT base,
                                VFromD<RebindToSigned<D>> index) {
  _mm512_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */,
                                TFromD<D>* HWY_RESTRICT base,
                                VFromD<RebindToSigned<D>> index) {
  _mm512_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */,
                                float* HWY_RESTRICT base,
                                Vec512<int32_t> index) {
  _mm512_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API void MaskedScatterIndex(VFromD<D> v, MFromD<D> m, D /* tag */,
                                double* HWY_RESTRICT base,
                                Vec512<int64_t> index) {
  _mm512_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, 8);
}

// ------------------------------ Gather

namespace detail {

// NativeGather512: thin wrappers selecting the i32/i64 gather intrinsic by
// lane type; kScale is the byte multiplier applied to each index.
template <int kScale, typename T, HWY_IF_UI32(T)>
HWY_INLINE Vec512<T> NativeGather512(const T* HWY_RESTRICT base,
                                     Vec512<int32_t> indices) {
  return Vec512<T>{_mm512_i32gather_epi32(indices.raw, base, kScale)};
}

template <int kScale, typename T, HWY_IF_UI64(T)>
HWY_INLINE Vec512<T> NativeGather512(const T* HWY_RESTRICT base,
                                     Vec512<int64_t> indices) {
  return Vec512<T>{_mm512_i64gather_epi64(indices.raw, base, kScale)};
}

template <int kScale>
HWY_INLINE Vec512<float> NativeGather512(const float* HWY_RESTRICT base,
                                         Vec512<int32_t> indices) {
  return Vec512<float>{_mm512_i32gather_ps(indices.raw, base, kScale)};
}

template <int kScale>
HWY_INLINE Vec512<double> NativeGather512(const double* HWY_RESTRICT base,
                                          Vec512<int64_t> indices) {
  return Vec512<double>{_mm512_i64gather_pd(indices.raw, base, kScale)};
}

// NativeMaskedGatherOr512: lanes whose mask bit is clear come from `no`
// instead of memory (the intrinsic's src/merge operand).
template <int kScale, typename T, HWY_IF_UI32(T)>
HWY_INLINE Vec512<T> NativeMaskedGatherOr512(Vec512<T> no, Mask512<T> m,
                                             const T* HWY_RESTRICT base,
                                             Vec512<int32_t> indices) {
  return Vec512<T>{
      _mm512_mask_i32gather_epi32(no.raw, m.raw, indices.raw, base, kScale)};
}

template <int kScale, typename T, HWY_IF_UI64(T)>
HWY_INLINE Vec512<T> NativeMaskedGatherOr512(Vec512<T> no, Mask512<T> m,
                                             const T* HWY_RESTRICT base,
                                             Vec512<int64_t> indices) {
  return Vec512<T>{
      _mm512_mask_i64gather_epi64(no.raw, m.raw, indices.raw, base, kScale)};
}

template <int kScale>
HWY_INLINE Vec512<float> NativeMaskedGatherOr512(Vec512<float> no,
                                                 Mask512<float> m,
                                                 const float* HWY_RESTRICT base,
                                                 Vec512<int32_t> indices) {
  return Vec512<float>{
      _mm512_mask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)};
}

template <int kScale>
HWY_INLINE Vec512<double> NativeMaskedGatherOr512(
    Vec512<double> no, Mask512<double> m, const double* HWY_RESTRICT base,
    Vec512<int64_t> indices) {
  return Vec512<double>{
      _mm512_mask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)};
}
}  // namespace detail

// GatherOffset: byte offsets (scale 1).
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> GatherOffset(D /*d*/, const TFromD<D>* HWY_RESTRICT base,
                               VFromD<RebindToSigned<D>> offsets) {
  return detail::NativeGather512<1>(base, offsets);
}

// GatherIndex: lane indices (scale = sizeof(lane)).
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> GatherIndex(D /*d*/, const TFromD<D>* HWY_RESTRICT base,
                              VFromD<RebindToSigned<D>> indices) {
  return detail::NativeGather512<sizeof(TFromD<D>)>(base, indices);
}

// MaskedGatherIndexOr: as GatherIndex, but masked-off lanes are taken from
// `no` rather than loaded.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> MaskedGatherIndexOr(VFromD<D> no, MFromD<D> m, D /*d*/,
                                      const TFromD<D>* HWY_RESTRICT base,
                                      VFromD<RebindToSigned<D>> indices) {
  return detail::NativeMaskedGatherOr512<sizeof(TFromD<D>)>(no, m, base,
                                                            indices);
}

HWY_DIAGNOSTICS(pop)

// ================================================== SWIZZLE

// ------------------------------ LowerHalf

// LowerHalf: returns the low 256 bits; the cast intrinsics are no-ops (no
// instruction emitted), per-type variants pick the matching register class.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> LowerHalf(D /* tag */, VFromD<Twice<D>> v) {
  return VFromD<D>{_mm512_castsi512_si256(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<bfloat16_t> v) {
  return VFromD<D>{_mm512_castsi512_si256(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<float16_t> v) {
#if HWY_HAVE_FLOAT16
  return VFromD<D>{_mm512_castph512_ph256(v.raw)};
#else
  return VFromD<D>{_mm512_castsi512_si256(v.raw)};
#endif  // HWY_HAVE_FLOAT16
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<float> v) {
  return VFromD<D>{_mm512_castps512_ps256(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> LowerHalf(D /* tag */, Vec512<double> v) {
  return VFromD<D>{_mm512_castpd512_pd256(v.raw)};
}

// Tag-free convenience overload.
template <typename T>
HWY_API Vec256<T> LowerHalf(Vec512<T> v) {
  const Half<DFromV<decltype(v)>> dh;
  return LowerHalf(dh, v);
}
// ------------------------------ UpperHalf

// UpperHalf: extracts the upper 256 bits. Non-float types go through the
// unsigned rebind because there is no epi16 extract for float16_t.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> UpperHalf(D d, VFromD<Twice<D>> v) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  const Twice<decltype(du)> dut;
  return BitCast(d, VFromD<decltype(du)>{
                        _mm512_extracti32x8_epi32(BitCast(dut, v).raw, 1)});
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> UpperHalf(D /* tag */, VFromD<Twice<D>> v) {
  return VFromD<D>{_mm512_extractf32x8_ps(v.raw, 1)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API VFromD<D> UpperHalf(D /* tag */, VFromD<Twice<D>> v) {
  return VFromD<D>{_mm512_extractf64x4_pd(v.raw, 1)};
}

// ------------------------------ ExtractLane (Store)

// Returns lane `i` (a runtime index). Fast path: if the compiler can prove
// i lies within the first 128-bit block, delegate to the cheaper 128-bit
// ExtractLane; otherwise spill the whole vector to the stack and index it.
template <typename T>
HWY_API T ExtractLane(const Vec512<T> v, size_t i) {
  const DFromV<decltype(v)> d;
  HWY_DASSERT(i < Lanes(d));

#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  constexpr size_t kLanesPerBlock = 16 / sizeof(T);
  if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) {
    return ExtractLane(ResizeBitCast(Full128<T>(), v), i);
  }
#endif

  alignas(64) T lanes[MaxLanes(d)];
  Store(v, d, lanes);
  return lanes[i];
}

// ------------------------------ ExtractBlock

// Blocks 0/1 only need the lower 256 bits; recurse into the 256-bit impl.
template <int kBlockIdx, class T, hwy::EnableIf<(kBlockIdx <= 1)>* = nullptr>
HWY_API Vec128<T> ExtractBlock(Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  return ExtractBlock<kBlockIdx>(LowerHalf(dh, v));
}

// Blocks 2/3 require a 512-bit extract.
template <int kBlockIdx, class T, hwy::EnableIf<(kBlockIdx > 1)>* = nullptr>
HWY_API Vec128<T> ExtractBlock(Vec512<T> v) {
  static_assert(kBlockIdx <= 3, "Invalid block index");
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(Full128<T>(),
                 Vec128<MakeUnsigned<T>>{
                     _mm512_extracti32x4_epi32(BitCast(du, v).raw, kBlockIdx)});
}

template <int kBlockIdx, hwy::EnableIf<(kBlockIdx > 1)>* = nullptr>
HWY_API Vec128<float> ExtractBlock(Vec512<float> v) {
  static_assert(kBlockIdx <= 3, "Invalid block index");
  return Vec128<float>{_mm512_extractf32x4_ps(v.raw, kBlockIdx)};
}

template <int kBlockIdx, hwy::EnableIf<(kBlockIdx > 1)>* = nullptr>
HWY_API Vec128<double> ExtractBlock(Vec512<double> v) {
  static_assert(kBlockIdx <= 3, "Invalid block index");
  return Vec128<double>{_mm512_extractf64x2_pd(v.raw, kBlockIdx)};
}

// ------------------------------ InsertLane (Store)

// Replaces lane `i` with `t` via broadcast + mask blend (helper in x86_128).
template <typename T>
HWY_API Vec512<T> InsertLane(const Vec512<T> v, size_t i, T t) {
  return detail::InsertLaneUsingBroadcastAndBlend(v, i, t);
}

// ------------------------------ InsertBlock
namespace detail {

// Block 0: cheaper to blend the resized 128-bit vector under a FirstN mask
// than to use the insert instruction.
template <typename T>
HWY_INLINE Vec512<T> InsertBlock(hwy::SizeTag<0> /* blk_idx_tag */, Vec512<T> v,
                                 Vec128<T> blk_to_insert) {
  const DFromV<decltype(v)> d;
  const auto insert_mask = FirstN(d, 16 / sizeof(T));
  return IfThenElse(insert_mask, ResizeBitCast(d, blk_to_insert), v);
}

// Blocks 1..3: use the 128-bit insert instruction via the unsigned rebind.
template <size_t kBlockIdx, typename T>
HWY_INLINE Vec512<T> InsertBlock(hwy::SizeTag<kBlockIdx> /* blk_idx_tag */,
                                 Vec512<T> v, Vec128<T> blk_to_insert) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  const Full128<MakeUnsigned<T>> du_blk_to_insert;
  return BitCast(
      d, VFromD<decltype(du)>{_mm512_inserti32x4(
             BitCast(du, v).raw, BitCast(du_blk_to_insert, blk_to_insert).raw,
             static_cast<int>(kBlockIdx & 3))});
}

template <size_t kBlockIdx, hwy::EnableIf<kBlockIdx != 0>* = nullptr>
HWY_INLINE Vec512<float> InsertBlock(hwy::SizeTag<kBlockIdx> /* blk_idx_tag */,
                                     Vec512<float> v,
                                     Vec128<float> blk_to_insert) {
  return Vec512<float>{_mm512_insertf32x4(v.raw, blk_to_insert.raw,
                                          static_cast<int>(kBlockIdx & 3))};
}

template <size_t kBlockIdx, hwy::EnableIf<kBlockIdx != 0>* = nullptr>
HWY_INLINE Vec512<double> InsertBlock(hwy::SizeTag<kBlockIdx> /* blk_idx_tag */,
                                      Vec512<double> v,
                                      Vec128<double> blk_to_insert) {
  return Vec512<double>{_mm512_insertf64x2(v.raw, blk_to_insert.raw,
                                           static_cast<int>(kBlockIdx & 3))};
}

}  // namespace detail

// Replaces 128-bit block kBlockIdx (0..3) of `v` with `blk_to_insert`.
template <int kBlockIdx, class T>
HWY_API Vec512<T> InsertBlock(Vec512<T> v, Vec128<T> blk_to_insert) {
  static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index");
  return detail::InsertBlock(hwy::SizeTag<static_cast<size_t>(kBlockIdx)>(), v,
                             blk_to_insert);
}

// ------------------------------ GetLane (LowerHalf)

// Returns lane 0, via the (free) 256-bit lower half.
template <typename T>
HWY_API T GetLane(const Vec512<T> v) {
  return GetLane(LowerHalf(v));
}

// ------------------------------ ZeroExtendVector

// ZeroExtendVector: low 256 bits = lo, upper 256 bits = zero. Uses the zext
// intrinsics when available (guaranteed-zero upper), else insert into Zero.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) {
#if HWY_HAVE_ZEXT  // See definition/comment in x86_256-inl.h.
  (void)d;
  return VFromD<D>{_mm512_zextsi256_si512(lo.raw)};
#else
  return VFromD<D>{_mm512_inserti32x8(Zero(d).raw, lo.raw, 0)};
#endif
}
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) {
#if HWY_HAVE_ZEXT
  (void)d;
  return VFromD<D>{_mm512_zextph256_ph512(lo.raw)};
#else
  const RebindToUnsigned<D> du;
  return BitCast(d, ZeroExtendVector(du, BitCast(du, lo)));
#endif
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) {
#if HWY_HAVE_ZEXT
  (void)d;
  return VFromD<D>{_mm512_zextps256_ps512(lo.raw)};
#else
  return VFromD<D>{_mm512_insertf32x8(Zero(d).raw, lo.raw, 0)};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) {
#if HWY_HAVE_ZEXT
  (void)d;
  return VFromD<D>{_mm512_zextpd256_pd512(lo.raw)};
#else
  return VFromD<D>{_mm512_insertf64x4(Zero(d).raw, lo.raw, 0)};
#endif
}

// ------------------------------ ZeroExtendResizeBitCast

namespace detail {

// 128 -> 512 bits: bit-cast the source to bytes, zero-extend, cast to DTo.
template <class DTo, class DFrom, HWY_IF_NOT_FLOAT3264_D(DTo)>
HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
    hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */,
    DTo d_to, DFrom d_from, VFromD<DFrom> v) {
  const Repartition<uint8_t, decltype(d_from)> du8_from;
  const auto vu8 = BitCast(du8_from, v);
  const RebindToUnsigned<decltype(d_to)> du_to;
#if HWY_HAVE_ZEXT
  return BitCast(d_to,
                 VFromD<decltype(du_to)>{_mm512_zextsi128_si512(vu8.raw)});
#else
  return BitCast(d_to, VFromD<decltype(du_to)>{
                           _mm512_inserti32x4(Zero(du_to).raw, vu8.raw, 0)});
#endif
}

template <class DTo, class DFrom, HWY_IF_F32_D(DTo)>
HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
    hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */,
    DTo d_to, DFrom d_from, VFromD<DFrom> v) {
  const Repartition<float, decltype(d_from)> df32_from;
  const auto vf32 = BitCast(df32_from, v);
#if HWY_HAVE_ZEXT
  (void)d_to;
  return Vec512<float>{_mm512_zextps128_ps512(vf32.raw)};
#else
  return Vec512<float>{_mm512_insertf32x4(Zero(d_to).raw, vf32.raw, 0)};
#endif
}

template <class DTo, class DFrom, HWY_IF_F64_D(DTo)>
HWY_INLINE Vec512<double> ZeroExtendResizeBitCast(
    hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */,
    DTo d_to, DFrom d_from, VFromD<DFrom> v) {
  const Repartition<double, decltype(d_from)> df64_from;
  const auto vf64 = BitCast(df64_from, v);
#if HWY_HAVE_ZEXT
  (void)d_to;
  return Vec512<double>{_mm512_zextpd128_pd512(vf64.raw)};
#else
  return Vec512<double>{_mm512_insertf64x2(Zero(d_to).raw, vf64.raw, 0)};
#endif
}

// 64 -> 512 bits: first zero-extend to 128, then reuse the 128 -> 512 path.
template <class DTo, class DFrom>
HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
    hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */,
    DTo d_to, DFrom d_from, VFromD<DFrom> v) {
  const Twice<decltype(d_from)> dt_from;
  return ZeroExtendResizeBitCast(hwy::SizeTag<16>(), hwy::SizeTag<64>(), d_to,
                                 dt_from, ZeroExtendVector(dt_from, v));
}

}  // namespace detail

// ------------------------------ Combine

// Combine: result = [hi, lo] — zero-extend lo, then insert hi as the upper
// 256 bits.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  const Half<decltype(du)> duh;
  const __m512i lo512 = ZeroExtendVector(du, BitCast(duh, lo)).raw;
  return BitCast(d, VFromD<decltype(du)>{
                        _mm512_inserti32x8(lo512, BitCast(duh, hi).raw, 1)});
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) {
  return VFromD<D>{_mm512_insertf32x8(ZeroExtendVector(d, lo).raw, hi.raw, 1)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) {
  return VFromD<D>{_mm512_insertf64x4(ZeroExtendVector(d, lo).raw, hi.raw, 1)};
}

// ------------------------------ ShiftLeftBytes

// Shifts bytes left independently within each 128-bit block.
template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ShiftLeftBytes(D /* tag */, const VFromD<D> v) {
  static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
  return VFromD<D>{_mm512_bslli_epi128(v.raw, kBytes)};
}

// ------------------------------ ShiftRightBytes

// Shifts bytes right independently within each 128-bit block.
template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ShiftRightBytes(D /* tag */, const VFromD<D> v) {
  static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
  return VFromD<D>{_mm512_bsrli_epi128(v.raw, kBytes)};
}

// ------------------------------ CombineShiftRightBytes

// Per 128-bit block: concatenates hi:lo and shifts right by kBytes.
template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> CombineShiftRightBytes(D d, VFromD<D> hi, VFromD<D> lo) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, Vec512<uint8_t>{_mm512_alignr_epi8(
                        BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)});
}

// ------------------------------ Broadcast/splat any lane

// 16-bit lanes: shufflelo/hi only permute within 64-bit halves, so first
// replicate the chosen lane across its half, then unpack to fill the block.
template <int kLane, typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec512<T> Broadcast(const Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const VU vu = BitCast(du, v);  // for float16_t
  static_assert(0 <= kLane && kLane < 8, "Invalid lane");
  if (kLane < 4) {
    const __m512i lo = _mm512_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF);
    return BitCast(d, VU{_mm512_unpacklo_epi64(lo, lo)});
  } else {
    const __m512i hi =
        _mm512_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF);
    return BitCast(d, VU{_mm512_unpackhi_epi64(hi, hi)});
  }
}

// 32-bit lanes: 0x55 * kLane replicates the 2-bit lane index into all four
// selector fields of the shuffle immediate.
template <int kLane, typename T, HWY_IF_UI32(T)>
HWY_API Vec512<T> Broadcast(const Vec512<T> v) {
  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane);
  return Vec512<T>{_mm512_shuffle_epi32(v.raw, perm)};
}

template <int kLane, typename T, HWY_IF_UI64(T)>
HWY_API Vec512<T> Broadcast(const Vec512<T> v) {
  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
  constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA;
  return Vec512<T>{_mm512_shuffle_epi32(v.raw, perm)};
}

template <int kLane>
HWY_API Vec512<float> Broadcast(const Vec512<float> v) {
  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane);
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, perm)};
}

template <int kLane>
HWY_API Vec512<double> Broadcast(const Vec512<double> v) {
  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane);
  return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, perm)};
}

// ------------------------------ BroadcastBlock

// Replicates 128-bit block kBlockIdx into all four blocks (0x55 * kBlockIdx
// repeats the 2-bit block index in each selector field).
template <int kBlockIdx, class T>
HWY_API Vec512<T> BroadcastBlock(Vec512<T> v) {
  static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index");
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(
      d, VFromD<decltype(du)>{_mm512_shuffle_i32x4(
             BitCast(du, v).raw, BitCast(du, v).raw, 0x55 * kBlockIdx)});
}

template <int kBlockIdx>
HWY_API Vec512<float> BroadcastBlock(Vec512<float> v) {
  static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index");
  return Vec512<float>{_mm512_shuffle_f32x4(v.raw, v.raw, 0x55 * kBlockIdx)};
}

template <int kBlockIdx>
HWY_API Vec512<double> BroadcastBlock(Vec512<double> v) {
  static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index");
  return Vec512<double>{_mm512_shuffle_f64x2(v.raw, v.raw, 0x55 * kBlockIdx)};
}

// ------------------------------ BroadcastLane

namespace detail {

// Lane 0: single broadcast instruction from the (free) low 128 bits.
template <class T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec512<T> v) {
  return Vec512<T>{_mm512_broadcastb_epi8(ResizeBitCast(Full128<T>(), v).raw)};
}

template <class T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{_mm512_broadcastw_epi16(
                        ResizeBitCast(Full128<uint16_t>(), v).raw)});
}

template <class T, HWY_IF_UI32(T)>
HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec512<T> v) {
  return Vec512<T>{_mm512_broadcastd_epi32(ResizeBitCast(Full128<T>(), v).raw)};
}

template <class T, HWY_IF_UI64(T)>
HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                   Vec512<T> v) {
  return Vec512<T>{_mm512_broadcastq_epi64(ResizeBitCast(Full128<T>(), v).raw)};
}

HWY_INLINE Vec512<float> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                       Vec512<float> v) {
  return Vec512<float>{
      _mm512_broadcastss_ps(ResizeBitCast(Full128<float>(), v).raw)};
}

HWY_INLINE Vec512<double> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
                                        Vec512<double> v) {
  return Vec512<double>{
      _mm512_broadcastsd_pd(ResizeBitCast(Full128<double>(), v).raw)};
}

// Nonzero lane: broadcast the containing 128-bit block, then broadcast the
// lane's position within that block.
template <size_t kLaneIdx, class T, hwy::EnableIf<kLaneIdx != 0>* = nullptr>
HWY_INLINE Vec512<T> BroadcastLane(hwy::SizeTag<kLaneIdx> /* lane_idx_tag */,
                                   Vec512<T> v) {
  constexpr size_t kLanesPerBlock = 16 / sizeof(T);
  constexpr int kBlockIdx = static_cast<int>(kLaneIdx / kLanesPerBlock);
  constexpr int kLaneInBlkIdx =
      static_cast<int>(kLaneIdx) & (kLanesPerBlock - 1);
  return Broadcast<kLaneInBlkIdx>(BroadcastBlock<kBlockIdx>(v));
}

}  // namespace detail

// Broadcasts lane kLaneIdx (across the whole vector, unlike Broadcast which
// operates per-block).
template <int kLaneIdx, class T>
HWY_API Vec512<T> BroadcastLane(Vec512<T> v) {
  static_assert(0 <= kLaneIdx, "Invalid lane");
  return detail::BroadcastLane(hwy::SizeTag<static_cast<size_t>(kLaneIdx)>(),
                               v);
}

// ------------------------------ Hard-coded shuffles

// Notation: let Vec512<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is
// least-significant). Shuffle0321 rotates four-lane blocks one lane to the
// right (the previous least-significant lane is now most-significant =>
// 47650321). These could also be implemented via CombineShiftRightBytes but
// the shuffle_abcd notation is more convenient.

// Swap 32-bit halves in 64-bit halves.
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec512<T> Shuffle2301(const Vec512<T> v) {
  return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)};
}
HWY_API Vec512<float> Shuffle2301(const Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)};
}

namespace detail {

// Two-input variants: lower two lanes of each output block come from `a`,
// upper two from `b` (shuffle_ps semantics), all via the float domain.
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec512<T> ShuffleTwo2301(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  return BitCast(
      d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw,
                                         _MM_PERM_CDAB)});
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec512<T> ShuffleTwo1230(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  return BitCast(
      d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw,
                                         _MM_PERM_BCDA)});
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec512<T> ShuffleTwo3012(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;
  const RebindToFloat<decltype(d)> df;
  return BitCast(
      d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw,
                                         _MM_PERM_DABC)});
}

}  // namespace detail

// Swap 64-bit halves
HWY_API Vec512<uint32_t> Shuffle1032(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
}
HWY_API Vec512<int32_t> Shuffle1032(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
}
HWY_API Vec512<float> Shuffle1032(const Vec512<float> v) {
  // Shorter encoding than _mm512_permute_ps.
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)};
}
HWY_API Vec512<uint64_t> Shuffle01(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
}
HWY_API Vec512<int64_t> Shuffle01(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
}
HWY_API Vec512<double> Shuffle01(const Vec512<double> v) {
  // Shorter encoding than _mm512_permute_pd.
  return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)};
}

// Rotate right 32 bits
HWY_API Vec512<uint32_t> Shuffle0321(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)};
}
HWY_API Vec512<int32_t> Shuffle0321(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)};
}
HWY_API Vec512<float> Shuffle0321(const Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)};
}
// Rotate left 32 bits
HWY_API Vec512<uint32_t> Shuffle2103(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)};
}
HWY_API Vec512<int32_t> Shuffle2103(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)};
}
HWY_API Vec512<float> Shuffle2103(const Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)};
}

// Reverse (within each 128-bit block)
HWY_API Vec512<uint32_t> Shuffle0123(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)};
}
HWY_API Vec512<int32_t> Shuffle0123(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)};
}
HWY_API Vec512<float> Shuffle0123(const Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)};
}
// ------------------------------ TableLookupLanes

// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes.
template <typename T>
struct Indices512 {
  __m512i raw;
};

// Validates (debug builds only) that all indices are < 2 * Lanes, i.e. also
// usable by TwoTablesLookupLanes (128 bytes = two 64-byte vectors).
template <class D, typename T = TFromD<D>, typename TI>
HWY_API Indices512<T> IndicesFromVec(D /* tag */, Vec512<TI> vec) {
  static_assert(sizeof(T) == sizeof(TI), "Index size must match lane");
#if HWY_IS_DEBUG_BUILD
  const DFromV<decltype(vec)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TU = MakeUnsigned<T>;
  const auto vec_u = BitCast(du, vec);
  HWY_DASSERT(
      AllTrue(du, Lt(vec_u, Set(du, static_cast<TU>(128 / sizeof(T))))));
#endif
  return Indices512<T>{vec.raw};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), typename TI>
HWY_API Indices512<TFromD<D>> SetTableIndices(D d, const TI* idx) {
  const Rebind<TI, decltype(d)> di;
  return IndicesFromVec(d, LoadU(di, idx));
}

// Byte lanes: native vpermb on AVX3_DL. Otherwise emulate a full 64-byte
// permutation: broadcast each 128-bit block (a..d), do per-block byte
// shuffles, and select between them using index bits 4..5 (exposed as the
// MSB of each byte after the 16-bit shifts, then read via MaskFromVec).
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<T>{_mm512_permutexvar_epi8(idx.raw, v.raw)};
#else
  const DFromV<decltype(v)> d;
  const Repartition<uint16_t, decltype(d)> du16;
  const Vec512<T> idx_vec{idx.raw};

  const auto bd_sel_mask =
      MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec))));
  const auto cd_sel_mask =
      MaskFromVec(BitCast(d, ShiftLeft<2>(BitCast(du16, idx_vec))));

  const Vec512<T> v_a{_mm512_shuffle_i32x4(v.raw, v.raw, 0x00)};
  const Vec512<T> v_b{_mm512_shuffle_i32x4(v.raw, v.raw, 0x55)};
  const Vec512<T> v_c{_mm512_shuffle_i32x4(v.raw, v.raw, 0xAA)};
  const Vec512<T> v_d{_mm512_shuffle_i32x4(v.raw, v.raw, 0xFF)};

  const auto shuf_a = TableLookupBytes(v_a, idx_vec);
  const auto shuf_c = TableLookupBytes(v_c, idx_vec);
  const Vec512<T> shuf_ab{_mm512_mask_shuffle_epi8(shuf_a.raw, bd_sel_mask.raw,
                                                   v_b.raw, idx_vec.raw)};
  const Vec512<T> shuf_cd{_mm512_mask_shuffle_epi8(shuf_c.raw, bd_sel_mask.raw,
                                                   v_d.raw, idx_vec.raw)};
  return IfThenElse(cd_sel_mask, shuf_cd, shuf_ab);
#endif
}

// Wider lanes have native full-vector permutes (vpermw/vpermd/vpermq etc.).
template <typename T, HWY_IF_T_SIZE(T, 2), HWY_IF_NOT_SPECIAL_FLOAT(T)>
HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) {
  return Vec512<T>{_mm512_permutexvar_epi16(idx.raw, v.raw)};
}
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> TableLookupLanes(Vec512<float16_t> v,
                                           Indices512<float16_t> idx) {
  return Vec512<float16_t>{_mm512_permutexvar_ph(idx.raw, v.raw)};
}
#endif  // HWY_HAVE_FLOAT16
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) {
  return Vec512<T>{_mm512_permutexvar_epi32(idx.raw, v.raw)};
}

template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) {
  return Vec512<T>{_mm512_permutexvar_epi64(idx.raw, v.raw)};
}

HWY_API Vec512<float> TableLookupLanes(Vec512<float> v, Indices512<float> idx) {
  return Vec512<float>{_mm512_permutexvar_ps(idx.raw, v.raw)};
}

HWY_API Vec512<double> TableLookupLanes(Vec512<double> v,
                                        Indices512<double> idx) {
  return Vec512<double>{_mm512_permutexvar_pd(idx.raw, v.raw)};
}

// Byte lanes: native two-table permute on AVX3_DL. Otherwise select per lane
// between single-table lookups of a and b, using index bit 6 (moved to the
// byte MSB by the 16-bit left shift) as the table selector.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b,
                                       Indices512<T> idx) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<T>{_mm512_permutex2var_epi8(a.raw, idx.raw, b.raw)};
#else
  const DFromV<decltype(a)> d;
  const auto b_sel_mask =
      MaskFromVec(BitCast(d, ShiftLeft<1>(Vec512<uint16_t>{idx.raw})));
  return IfThenElse(b_sel_mask, TableLookupLanes(b, idx),
                    TableLookupLanes(a, idx));
#endif
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b,
                                       Indices512<T> idx) {
  return Vec512<T>{_mm512_permutex2var_epi16(a.raw, idx.raw, b.raw)};
}

template <typename T, HWY_IF_UI32(T)>
HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b,
                                       Indices512<T> idx) {
  return Vec512<T>{_mm512_permutex2var_epi32(a.raw, idx.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> TwoTablesLookupLanes(Vec512<float16_t> a,
                                               Vec512<float16_t> b,
                                               Indices512<float16_t> idx) {
  return Vec512<float16_t>{_mm512_permutex2var_ph(a.raw, idx.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> TwoTablesLookupLanes(Vec512<float> a, Vec512<float> b,
                                           Indices512<float> idx) {
  return Vec512<float>{_mm512_permutex2var_ps(a.raw, idx.raw, b.raw)};
}

template <typename T, HWY_IF_UI64(T)>
HWY_API Vec512<T> TwoTablesLookupLanes(Vec512<T> a, Vec512<T> b,
                                       Indices512<T> idx) {
  return Vec512<T>{_mm512_permutex2var_epi64(a.raw, idx.raw, b.raw)};
}

HWY_API Vec512<double> TwoTablesLookupLanes(Vec512<double> a, Vec512<double> b,
                                            Indices512<double> idx) {
  return Vec512<double>{_mm512_permutex2var_pd(a.raw, idx.raw, b.raw)};
}

// ------------------------------ Reverse

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
#if HWY_TARGET <= HWY_AVX3_DL
  // VBMI byte permute with a descending index table.
  const RebindToSigned<decltype(d)> di;
  alignas(64) static constexpr int8_t kReverse[64] = {
      63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48,
      47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32,
      31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
      15, 14, 13, 12, 11, 10, 9,  8,  7,  6,  5,  4,  3,  2,  1,  0};
  const Vec512<int8_t> idx = Load(di, kReverse);
  return BitCast(
      d, Vec512<int8_t>{_mm512_permutexvar_epi8(idx.raw, BitCast(di, v).raw)});
#else
  // Without VBMI: reverse the 16-bit units, then swap the two bytes within
  // each unit via RotateRight<8>.
  const RepartitionToWide<decltype(d)> d16;
  return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v))));
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
  const RebindToSigned<decltype(d)> di;
  alignas(64) static constexpr int16_t kReverse[32] = {
      31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
      15, 14, 13, 12, 11, 10, 9,  8,  7,  6,  5,  4,  3,  2,  1,  0};
  const Vec512<int16_t> idx = Load(di, kReverse);
  return BitCast(d, Vec512<int16_t>{
                        _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
  alignas(64) static constexpr int32_t kReverse[16] = {
      15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  return TableLookupLanes(v, SetTableIndices(d, kReverse));
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
  alignas(64) static constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0};
  return TableLookupLanes(v, SetTableIndices(d, kReverse));
}

// ------------------------------ Reverse2 (in x86_128)

// ------------------------------ Reverse4

// Reverses each group of 4 consecutive 16-bit lanes.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
  const RebindToSigned<decltype(d)> di;
  alignas(64) static constexpr int16_t kReverse4[32] = {
      3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8,  15, 14, 13, 12,
      19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28};
  const Vec512<int16_t> idx = Load(di, kReverse4);
  return BitCast(d, Vec512<int16_t>{
                        _mm512_permutexvar_epi16(idx.raw, BitCast(di,
v).raw)}); 4268 } 4269 4270 // 32 bit Reverse4 defined in x86_128. 4271 4272 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> 4273 HWY_API VFromD<D> Reverse4(D /* tag */, const VFromD<D> v) { 4274 return VFromD<D>{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; 4275 } 4276 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 4277 HWY_API VFromD<D> Reverse4(D /* tag */, VFromD<D> v) { 4278 return VFromD<D>{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; 4279 } 4280 4281 // ------------------------------ Reverse8 4282 4283 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> 4284 HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { 4285 const RebindToSigned<decltype(d)> di; 4286 alignas(64) static constexpr int16_t kReverse8[32] = { 4287 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 4288 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; 4289 const Vec512<int16_t> idx = Load(di, kReverse8); 4290 return BitCast(d, Vec512<int16_t>{ 4291 _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); 4292 } 4293 4294 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)> 4295 HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { 4296 const RebindToSigned<decltype(d)> di; 4297 alignas(64) static constexpr int32_t kReverse8[16] = { 4298 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; 4299 const Vec512<int32_t> idx = Load(di, kReverse8); 4300 return BitCast(d, Vec512<int32_t>{ 4301 _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); 4302 } 4303 4304 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)> 4305 HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { 4306 return Reverse(d, v); 4307 } 4308 4309 // ------------------------------ ReverseBits (GaloisAffine) 4310 4311 #if HWY_TARGET <= HWY_AVX3_DL 4312 4313 #ifdef HWY_NATIVE_REVERSE_BITS_UI8 4314 #undef HWY_NATIVE_REVERSE_BITS_UI8 4315 #else 4316 #define HWY_NATIVE_REVERSE_BITS_UI8 4317 #endif 4318 4319 // Generic for all 
vector lengths. Must be defined after all GaloisAffine. 4320 template <class V, HWY_IF_T_SIZE_V(V, 1)> 4321 HWY_API V ReverseBits(V v) { 4322 const Repartition<uint64_t, DFromV<V>> du64; 4323 return detail::GaloisAffine(v, Set(du64, 0x8040201008040201u)); 4324 } 4325 4326 #endif // HWY_TARGET <= HWY_AVX3_DL 4327 4328 // ------------------------------ InterleaveLower 4329 4330 template <typename T, HWY_IF_T_SIZE(T, 1)> 4331 HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { 4332 return Vec512<T>{_mm512_unpacklo_epi8(a.raw, b.raw)}; 4333 } 4334 template <typename T, HWY_IF_T_SIZE(T, 2)> 4335 HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { 4336 const DFromV<decltype(a)> d; 4337 const RebindToUnsigned<decltype(d)> du; 4338 using VU = VFromD<decltype(du)>; // for float16_t 4339 return BitCast( 4340 d, VU{_mm512_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); 4341 } 4342 template <typename T, HWY_IF_T_SIZE(T, 4)> 4343 HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { 4344 return Vec512<T>{_mm512_unpacklo_epi32(a.raw, b.raw)}; 4345 } 4346 template <typename T, HWY_IF_T_SIZE(T, 8)> 4347 HWY_API Vec512<T> InterleaveLower(Vec512<T> a, Vec512<T> b) { 4348 return Vec512<T>{_mm512_unpacklo_epi64(a.raw, b.raw)}; 4349 } 4350 HWY_API Vec512<float> InterleaveLower(Vec512<float> a, Vec512<float> b) { 4351 return Vec512<float>{_mm512_unpacklo_ps(a.raw, b.raw)}; 4352 } 4353 HWY_API Vec512<double> InterleaveLower(Vec512<double> a, Vec512<double> b) { 4354 return Vec512<double>{_mm512_unpacklo_pd(a.raw, b.raw)}; 4355 } 4356 4357 // ------------------------------ InterleaveUpper 4358 4359 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)> 4360 HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { 4361 return VFromD<D>{_mm512_unpackhi_epi8(a.raw, b.raw)}; 4362 } 4363 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)> 4364 HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) { 4365 
const RebindToUnsigned<decltype(d)> du; 4366 using VU = VFromD<decltype(du)>; // for float16_t 4367 return BitCast( 4368 d, VU{_mm512_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); 4369 } 4370 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)> 4371 HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { 4372 return VFromD<D>{_mm512_unpackhi_epi32(a.raw, b.raw)}; 4373 } 4374 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)> 4375 HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { 4376 return VFromD<D>{_mm512_unpackhi_epi64(a.raw, b.raw)}; 4377 } 4378 4379 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 4380 HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { 4381 return VFromD<D>{_mm512_unpackhi_ps(a.raw, b.raw)}; 4382 } 4383 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 4384 HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) { 4385 return VFromD<D>{_mm512_unpackhi_pd(a.raw, b.raw)}; 4386 } 4387 4388 // ------------------------------ Concat* halves 4389 4390 // hiH,hiL loH,loL |-> hiL,loL (= lower halves) 4391 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)> 4392 HWY_API VFromD<D> ConcatLowerLower(D d, VFromD<D> hi, VFromD<D> lo) { 4393 const RebindToUnsigned<decltype(d)> du; // for float16_t 4394 return BitCast(d, 4395 VFromD<decltype(du)>{_mm512_shuffle_i32x4( 4396 BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BABA)}); 4397 } 4398 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 4399 HWY_API VFromD<D> ConcatLowerLower(D /* tag */, VFromD<D> hi, VFromD<D> lo) { 4400 return VFromD<D>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; 4401 } 4402 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 4403 HWY_API Vec512<double> ConcatLowerLower(D /* tag */, Vec512<double> hi, 4404 Vec512<double> lo) { 4405 return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; 
}

// hiH,hiL loH,loL |-> hiH,loH (= upper halves)
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatUpperUpper(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d,
                 VFromD<decltype(du)>{_mm512_shuffle_i32x4(
                     BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_DCDC)});
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatUpperUpper(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
  return VFromD<D>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> ConcatUpperUpper(D /* tag */, Vec512<double> hi,
                                        Vec512<double> lo) {
  return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)};
}

// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks)
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatLowerUpper(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d,
                 VFromD<decltype(du)>{_mm512_shuffle_i32x4(
                     BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BADC)});
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatLowerUpper(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
  return VFromD<D>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> ConcatLowerUpper(D /* tag */, Vec512<double> hi,
                                        Vec512<double> lo) {
  return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)};
}

// hiH,hiL loH,loL |-> hiH,loL (= outer halves)
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ConcatUpperLower(D d, VFromD<D> hi, VFromD<D> lo) {
  // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks
  // are efficiently loaded from 32-bit regs.
  const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF);
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d, VFromD<decltype(du)>{_mm512_mask_blend_epi16(
                        mask, BitCast(du, hi).raw, BitCast(du, lo).raw)});
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatUpperLower(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
  // Low 8 of 16 float lanes come from lo.
  const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF);
  return VFromD<D>{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> ConcatUpperLower(D /* tag */, Vec512<double> hi,
                                        Vec512<double> lo) {
  // Low 4 of 8 double lanes come from lo.
  const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F);
  return Vec512<double>{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)};
}

// ------------------------------ ConcatOdd

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3_DL
  // Odd byte indices across the lo:hi concatenation.
  alignas(64) static constexpr uint8_t kIdx[64] = {
      1,   3,   5,   7,   9,   11,  13,  15,  17,  19,  21,  23,  25,
      27,  29,  31,  33,  35,  37,  39,  41,  43,  45,  47,  49,  51,
      53,  55,  57,  59,  61,  63,  65,  67,  69,  71,  73,  75,  77,
      79,  81,  83,  85,  87,  89,  91,  93,  95,  97,  99,  101, 103,
      105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127};
  return BitCast(
      d, Vec512<uint8_t>{_mm512_permutex2var_epi8(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  const RepartitionToWide<decltype(du)> dw;
  // Right-shift 8 bits per u16 so we can pack.
  const Vec512<uint16_t> uH = ShiftRight<8>(BitCast(dw, hi));
  const Vec512<uint16_t> uL = ShiftRight<8>(BitCast(dw, lo));
  // packus interleaves per 128-bit block; fix up below.
  const Vec512<uint64_t> u8{_mm512_packus_epi16(uL.raw, uH.raw)};
  // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes.
  const Full512<uint64_t> du64;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx)));
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint16_t kIdx[32] = {
      1,  3,  5,  7,  9,  11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
      33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63};
  return BitCast(
      d, Vec512<uint16_t>{_mm512_permutex2var_epi16(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {
      1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
  return BitCast(
      d, Vec512<uint32_t>{_mm512_permutex2var_epi32(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {
      1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
  return VFromD<D>{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15};
  return BitCast(
      d, Vec512<uint64_t>{_mm512_permutex2var_epi64(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15};
  return VFromD<D>{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)};
}

// ------------------------------ ConcatEven

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
#if HWY_TARGET <= HWY_AVX3_DL
  // Even byte indices across the lo:hi concatenation.
  alignas(64) static constexpr uint8_t kIdx[64] = {
      0,   2,   4,   6,   8,   10,  12,  14,  16,  18,  20,  22,  24,
      26,  28,  30,  32,  34,  36,  38,  40,  42,  44,  46,  48,  50,
      52,  54,  56,  58,  60,  62,  64,  66,  68,  70,  72,  74,  76,
      78,  80,  82,  84,  86,  88,  90,  92,  94,  96,  98,  100, 102,
      104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126};
  return BitCast(
      d, Vec512<uint32_t>{_mm512_permutex2var_epi8(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
#else
  const RepartitionToWide<decltype(du)> dw;
  // Isolate lower 8 bits per u16 so we can pack.
  const Vec512<uint16_t> mask = Set(dw, 0x00FF);
  const Vec512<uint16_t> uH = And(BitCast(dw, hi), mask);
  const Vec512<uint16_t> uL = And(BitCast(dw, lo), mask);
  const Vec512<uint64_t> u8{_mm512_packus_epi16(uL.raw, uH.raw)};
  // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes.
  const Full512<uint64_t> du64;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx)));
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint16_t kIdx[32] = {
      0,  2,  4,  6,  8,  10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
      32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62};
  return BitCast(
      d, Vec512<uint32_t>{_mm512_permutex2var_epi16(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {
      0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30};
  return BitCast(
      d, Vec512<uint32_t>{_mm512_permutex2var_epi32(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {
      0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30};
  return VFromD<D>{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14};
  return BitCast(
      d, Vec512<uint64_t>{_mm512_permutex2var_epi64(
             BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14};
  return VFromD<D>{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)};
}

// ------------------------------ InterleaveWholeLower

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
#if HWY_TARGET <= HWY_AVX3_DL
  // Pairs (i, 64+i): alternate a[i] and b[i] for i in [0, 32).
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint8_t kIdx[64] = {
      0,  64, 1,  65, 2,  66, 3,  67, 4,  68, 5,  69, 6,  70, 7,  71,
      8,  72, 9,  73, 10, 74, 11, 75, 12, 76, 13, 77, 14, 78, 15, 79,
      16, 80, 17, 81, 18, 82, 19, 83, 20, 84, 21, 85, 22, 86, 23, 87,
      24, 88, 25, 89, 26, 90, 27, 91, 28, 92, 29, 93, 30, 94, 31, 95};
  return VFromD<D>{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)};
#else
  // Without VBMI: unpacklo/hi interleave per 128-bit block; reorder the
  // resulting u64 halves to obtain the whole-vector interleave.
  alignas(64) static constexpr uint64_t kIdx2[8] = {0, 1, 8, 9, 2, 3, 10, 11};
  const Repartition<uint64_t, decltype(d)> du64;
  return VFromD<D>{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw,
                                             Load(du64, kIdx2).raw,
                                             InterleaveUpper(d, a, b).raw)};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint16_t kIdx[32] = {
      0, 32, 1, 33, 2,  34, 3,  35, 4,  36, 5,  37, 6,  38, 7,  39,
      8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47};
  return BitCast(
      d, VFromD<decltype(du)>{_mm512_permutex2var_epi16(
             BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19,
                                                    4, 20, 5, 21, 6, 22, 7, 23};
  return VFromD<D>{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19,
                                                    4, 20, 5, 21, 6, 22, 7, 23};
  return VFromD<D>{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11};
  return VFromD<D>{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11};
  return VFromD<D>{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)};
}

// ------------------------------ InterleaveWholeUpper

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
#if HWY_TARGET <= HWY_AVX3_DL
  // Pairs (i, 64+i): alternate a[i] and b[i] for i in [32, 64).
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint8_t kIdx[64] = {
      32, 96,  33, 97,  34, 98,  35, 99,  36, 100, 37, 101, 38, 102, 39, 103,
      40, 104, 41, 105, 42, 106, 43, 107, 44, 108, 45, 109, 46, 110, 47, 111,
      48, 112, 49, 113, 50, 114, 51, 115, 52, 116, 53, 117, 54, 118, 55, 119,
      56, 120, 57, 121, 58, 122, 59, 123, 60, 124, 61, 125, 62, 126, 63, 127};
  return VFromD<D>{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)};
#else
  // Same approach as InterleaveWholeLower, selecting the upper-half blocks.
  alignas(64) static constexpr uint64_t kIdx2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
  const Repartition<uint64_t, decltype(d)> du64;
  return VFromD<D>{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw,
                                             Load(du64, kIdx2).raw,
                                             InterleaveUpper(d, a, b).raw)};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint16_t kIdx[32] = {
      16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55,
      24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63};
  return BitCast(
      d, VFromD<decltype(du)>{_mm512_permutex2var_epi16(
             BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)});
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {
      8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
  return VFromD<D>{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint32_t kIdx[16] = {
      8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
  return VFromD<D>{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const
RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15};
  return VFromD<D>{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToUnsigned<decltype(d)> du;
  alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15};
  return VFromD<D>{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)};
}

// ------------------------------ DupEven (InterleaveLower)

template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec512<T> DupEven(Vec512<T> v) {
  return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)};
}
HWY_API Vec512<float> DupEven(Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)};
}

template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec512<T> DupEven(const Vec512<T> v) {
  // Interleaving v with itself duplicates the even (lower) lane of each pair.
  const DFromV<decltype(v)> d;
  return InterleaveLower(d, v, v);
}

// ------------------------------ DupOdd (InterleaveUpper)

template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec512<T> DupOdd(Vec512<T> v) {
  return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)};
}
HWY_API Vec512<float> DupOdd(Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)};
}

template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec512<T> DupOdd(const Vec512<T> v) {
  const DFromV<decltype(v)> d;
  return InterleaveUpper(d, v, v);
}

// ------------------------------ OddEven (IfThenElse)

template <typename T>
HWY_API Vec512<T> OddEven(const Vec512<T> a, const Vec512<T> b) {
  // One mask bit per lane: 0x5555.. selects even-indexed lanes (from b).
  // Right-shift so only the low (lane count) bits of the pattern remain:
  // 64 lanes for 1-byte T, 32 for 2-byte, 16 for 4-byte, 8 for 8-byte.
  constexpr size_t s = sizeof(T);
  constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56;
  return IfThenElse(Mask512<T>{0x5555555555555555ull >> shift}, b, a);
}

// -------------------------- InterleaveEven

template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
  // Odd lanes (mask 0xAAAA) receive b's even lanes, shifted up by one.
  return VFromD<D>{_mm512_mask_shuffle_epi32(
      a.raw, static_cast<__mmask16>(0xAAAA), b.raw,
      static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))};
}
template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm512_mask_shuffle_ps(a.raw, static_cast<__mmask16>(0xAAAA),
                                          b.raw, b.raw,
                                          _MM_SHUFFLE(2, 2, 0, 0))};
}
// -------------------------- InterleaveOdd

template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
  // Even lanes (mask 0x5555) receive a's odd lanes, shifted down by one.
  return VFromD<D>{_mm512_mask_shuffle_epi32(
      b.raw, static_cast<__mmask16>(0x5555), a.raw,
      static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))};
}
template <class D, HWY_IF_LANES_D(D, 16), HWY_IF_F32_D(D)>
HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm512_mask_shuffle_ps(b.raw, static_cast<__mmask16>(0x5555),
                                          a.raw, a.raw,
                                          _MM_SHUFFLE(3, 3, 1, 1))};
}

// ------------------------------ OddEvenBlocks

template <typename T>
HWY_API Vec512<T> OddEvenBlocks(Vec512<T> odd, Vec512<T> even) {
  const DFromV<decltype(odd)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  // 0x33 = u64 lanes 0,1 and 4,5, i.e. 128-bit blocks 0 and 2 (from even).
  return BitCast(
      d, VFromD<decltype(du)>{_mm512_mask_blend_epi64(
             __mmask8{0x33u}, BitCast(du, odd).raw, BitCast(du, even).raw)});
}

HWY_API Vec512<float> OddEvenBlocks(Vec512<float> odd, Vec512<float> even) {
  return Vec512<float>{
      _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)};
}

HWY_API
Vec512<double> OddEvenBlocks(Vec512<double> odd, Vec512<double> even) {
  return Vec512<double>{
      _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)};
}

// ------------------------------ SwapAdjacentBlocks

template <typename T>
HWY_API Vec512<T> SwapAdjacentBlocks(Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  // CDAB: exchange blocks 0<->1 and 2<->3.
  return BitCast(d,
                 VFromD<decltype(du)>{_mm512_shuffle_i32x4(
                     BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_CDAB)});
}

HWY_API Vec512<float> SwapAdjacentBlocks(Vec512<float> v) {
  return Vec512<float>{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)};
}

HWY_API Vec512<double> SwapAdjacentBlocks(Vec512<double> v) {
  return Vec512<double>{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)};
}

// ------------------------------ InterleaveEvenBlocks
template <typename T>
HWY_API Vec512<T> InterleaveEvenBlocks(Full512<T> d, Vec512<T> a, Vec512<T> b) {
  return OddEvenBlocks(SlideUpBlocks<1>(d, b), a);
}

// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper)
template <typename T>
HWY_API Vec512<T> InterleaveOddBlocks(Full512<T> d, Vec512<T> a, Vec512<T> b) {
  return OddEvenBlocks(b, SlideDownBlocks<1>(d, a));
}

// ------------------------------ ReverseBlocks

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
HWY_API VFromD<D> ReverseBlocks(D d, VFromD<D> v) {
  const RebindToUnsigned<decltype(d)> du;  // for float16_t
  return BitCast(d,
                 VFromD<decltype(du)>{_mm512_shuffle_i32x4(
                     BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_ABCD)});
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> ReverseBlocks(D /* tag */, VFromD<D> v) {
  return VFromD<D>{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> ReverseBlocks(D /* tag */, VFromD<D> v) {
  return VFromD<D>{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)};
}

// ------------------------------ TableLookupBytes (ZeroExtendVector)

// Both full
template <typename T, typename TI>
HWY_API Vec512<TI> TableLookupBytes(Vec512<T> bytes, Vec512<TI> indices) {
  const DFromV<decltype(indices)> d;
  // Note: per-128-bit-block byte shuffle (VPSHUFB semantics).
  return BitCast(d, Vec512<uint8_t>{_mm512_shuffle_epi8(
                        BitCast(Full512<uint8_t>(), bytes).raw,
                        BitCast(Full512<uint8_t>(), indices).raw)});
}

// Partial index vector
template <typename T, typename TI, size_t NI>
HWY_API Vec128<TI, NI> TableLookupBytes(Vec512<T> bytes, Vec128<TI, NI> from) {
  const Full512<TI> d512;
  const Half<decltype(d512)> d256;
  const Half<decltype(d256)> d128;
  // First expand to full 128, then 256, then 512.
  const Vec128<TI> from_full{from.raw};
  const auto from_512 =
      ZeroExtendVector(d512, ZeroExtendVector(d256, from_full));
  const auto tbl_full = TableLookupBytes(bytes, from_512);
  // Shrink to 256, then 128, then partial.
  return Vec128<TI, NI>{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw};
}
template <typename T, typename TI>
HWY_API Vec256<TI> TableLookupBytes(Vec512<T> bytes, Vec256<TI> from) {
  const DFromV<decltype(from)> dih;
  const Twice<decltype(dih)> di;
  const auto from_512 = ZeroExtendVector(di, from);
  return LowerHalf(dih, TableLookupBytes(bytes, from_512));
}

// Partial table vector
template <typename T, size_t N, typename TI>
HWY_API Vec512<TI> TableLookupBytes(Vec128<T, N> bytes, Vec512<TI> from) {
  const DFromV<decltype(from)> d512;
  const Half<decltype(d512)> d256;
  const Half<decltype(d256)> d128;
  // First expand to full 128, then 256, then 512.
4917 const Vec128<T> bytes_full{bytes.raw}; 4918 const auto bytes_512 = 4919 ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); 4920 return TableLookupBytes(bytes_512, from); 4921 } 4922 template <typename T, typename TI> 4923 HWY_API Vec512<TI> TableLookupBytes(Vec256<T> bytes, Vec512<TI> from) { 4924 const Full512<T> d; 4925 return TableLookupBytes(ZeroExtendVector(d, bytes), from); 4926 } 4927 4928 // Partial both are handled by x86_128/256. 4929 4930 // ------------------------------ I8/U8 Broadcast (TableLookupBytes) 4931 4932 template <int kLane, class T, HWY_IF_T_SIZE(T, 1)> 4933 HWY_API Vec512<T> Broadcast(const Vec512<T> v) { 4934 static_assert(0 <= kLane && kLane < 16, "Invalid lane"); 4935 return TableLookupBytes(v, Set(Full512<T>(), static_cast<T>(kLane))); 4936 } 4937 4938 // ------------------------------ Per4LaneBlockShuffle 4939 4940 namespace detail { 4941 4942 template <class D, HWY_IF_V_SIZE_D(D, 64)> 4943 HWY_INLINE VFromD<D> Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, 4944 const uint32_t x2, 4945 const uint32_t x1, 4946 const uint32_t x0) { 4947 return BitCast(d, Vec512<uint32_t>{_mm512_set_epi32( 4948 static_cast<int32_t>(x3), static_cast<int32_t>(x2), 4949 static_cast<int32_t>(x1), static_cast<int32_t>(x0), 4950 static_cast<int32_t>(x3), static_cast<int32_t>(x2), 4951 static_cast<int32_t>(x1), static_cast<int32_t>(x0), 4952 static_cast<int32_t>(x3), static_cast<int32_t>(x2), 4953 static_cast<int32_t>(x1), static_cast<int32_t>(x0), 4954 static_cast<int32_t>(x3), static_cast<int32_t>(x2), 4955 static_cast<int32_t>(x1), static_cast<int32_t>(x0))}); 4956 } 4957 4958 template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)> 4959 HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/, 4960 hwy::SizeTag<4> /*lane_size_tag*/, 4961 hwy::SizeTag<64> /*vect_size_tag*/, V v) { 4962 return V{ 4963 _mm512_shuffle_epi32(v.raw, static_cast<_MM_PERM_ENUM>(kIdx3210 & 0xFF))}; 4964 } 4965 4966 template <size_t kIdx3210, 
class V, HWY_IF_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<4> /*lane_size_tag*/,
                                  hwy::SizeTag<64> /*vect_size_tag*/, V v) {
  // 32-bit float lanes: vshufps with both sources equal to v.
  return V{_mm512_shuffle_ps(v.raw, v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

// 64-bit non-float lanes: vpermq permutes 4 lanes within each 256-bit half.
template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<64> /*vect_size_tag*/, V v) {
  return V{_mm512_permutex_epi64(v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

// 64-bit float lanes: vpermpd, same per-256-bit-half permutation.
template <size_t kIdx3210, class V, HWY_IF_FLOAT(TFromV<V>)>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<64> /*vect_size_tag*/, V v) {
  return V{_mm512_permutex_pd(v.raw, static_cast<int>(kIdx3210 & 0xFF))};
}

}  // namespace detail

// ------------------------------ SlideUpLanes

namespace detail {

// valignd: concatenates hi:lo (as 32 u32 lanes) and returns the 16 lanes
// starting at lane kI32Lanes of the concatenation.
template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V CombineShiftRightI32Lanes(V hi, V lo) {
  const DFromV<decltype(hi)> d;
  const Repartition<uint32_t, decltype(d)> du32;
  return BitCast(d,
                 Vec512<uint32_t>{_mm512_alignr_epi32(
                     BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)});
}

// valignq: same as above but in units of 64-bit lanes.
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) {
  const DFromV<decltype(hi)> d;
  const Repartition<uint64_t, decltype(d)> du64;
  return BitCast(d,
                 Vec512<uint64_t>{_mm512_alignr_epi64(
                     BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)});
}

// Slides v up (towards higher lanes) by kI32Lanes u32 lanes, inserting zeros:
// alignr of v:Zero with count 16 - kI32Lanes.
template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V SlideUpI32Lanes(V v) {
  static_assert(0 <= kI32Lanes && kI32Lanes <= 15,
                "kI32Lanes must be between 0 and 15");
  const DFromV<decltype(v)> d;
  return CombineShiftRightI32Lanes<16 - kI32Lanes>(v, Zero(d));
}

// Slides v up by kI64Lanes u64 lanes, inserting zeros.
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V SlideUpI64Lanes(V v) {
  static_assert(0 <= kI64Lanes && kI64Lanes <= 7,
                "kI64Lanes must be between 0 and 7");
  const DFromV<decltype(v)> d;
  return CombineShiftRightI64Lanes<8 - kI64Lanes>(v, Zero(d));
}

// Runtime-variable byte slide-up for 8-bit lanes.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE VFromD<D> TableLookupSlideUpLanes(D d, VFromD<D> v, size_t amt) {
  const Repartition<uint8_t, decltype(d)> du8;

#if HWY_TARGET <= HWY_AVX3_DL
  // VBMI provides a full 64-byte two-table permute (vpermi2b); out-of-range
  // (wrapped-negative) indices pick from the Zero table.
  const auto byte_idx = Iota(du8, static_cast<uint8_t>(size_t{0} - amt));
  return TwoTablesLookupLanes(v, Zero(d), Indices512<TFromD<D>>{byte_idx.raw});
#else
  // Without VBMI: vpshufb only shuffles within 128-bit blocks, so combine a
  // per-block byte shuffle (amt & 15) with a whole-block (u64-pair) slide.
  const Repartition<uint16_t, decltype(d)> du16;
  const Repartition<uint64_t, decltype(d)> du64;
  const auto byte_idx = Iota(du8, static_cast<uint8_t>(size_t{0} - (amt & 15)));
  const auto blk_u64_idx =
      Iota(du64, static_cast<uint64_t>(uint64_t{0} - ((amt >> 4) << 1)));

  // Duplicate even/odd 128-bit blocks so each block can read from itself or
  // its lower neighbor.
  const VFromD<D> even_blocks{
      _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))};
  const VFromD<D> odd_blocks{
      _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 1, 1, 3))};
  // Bit 4 of each byte index (shifted to the sign bit of u16 pairs) selects
  // the odd-block source.
  const auto odd_sel_mask =
      MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx))));
  const auto even_blk_lookup_result =
      BitCast(d, TableLookupBytes(even_blocks, byte_idx));
  const VFromD<D> blockwise_slide_up_result{
      _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw,
                               odd_blocks.raw, byte_idx.raw)};
  // Finally slide whole 128-bit blocks (as u64 pairs), zero-filling below.
  return BitCast(d, TwoTablesLookupLanes(
                        BitCast(du64, blockwise_slide_up_result), Zero(du64),
                        Indices512<uint64_t>{blk_u64_idx.raw}));
#endif
}

}  // namespace detail

template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> SlideUpBlocks(D d,
VFromD<D> v) {
  // Slides v up by kBlocks whole 128-bit blocks, inserting zeros.
  static_assert(0 <= kBlocks && kBlocks <= 3,
                "kBlocks must be between 0 and 3");
  switch (kBlocks) {
    case 0:
      return v;
    case 1:
      return detail::SlideUpI64Lanes<2>(v);
    case 2:
      // Two blocks == move the lower 256-bit half into the upper half.
      return ConcatLowerLower(d, v, Zero(d));
    case 3:
      return detail::SlideUpI64Lanes<6>(v);
  }

  return v;
}

// 32-bit lanes: when amt is a compile-time constant, a single valignd/valignq
// suffices; otherwise fall back to the table-lookup path.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt)) {
    switch (amt) {
      case 0:
        return v;
      case 1:
        return detail::SlideUpI32Lanes<1>(v);
      case 2:
        return detail::SlideUpI64Lanes<1>(v);
      case 3:
        return detail::SlideUpI32Lanes<3>(v);
      case 4:
        return detail::SlideUpI64Lanes<2>(v);
      case 5:
        return detail::SlideUpI32Lanes<5>(v);
      case 6:
        return detail::SlideUpI64Lanes<3>(v);
      case 7:
        return detail::SlideUpI32Lanes<7>(v);
      case 8:
        return ConcatLowerLower(d, v, Zero(d));
      case 9:
        return detail::SlideUpI32Lanes<9>(v);
      case 10:
        return detail::SlideUpI64Lanes<5>(v);
      case 11:
        return detail::SlideUpI32Lanes<11>(v);
      case 12:
        return detail::SlideUpI64Lanes<6>(v);
      case 13:
        return detail::SlideUpI32Lanes<13>(v);
      case 14:
        return detail::SlideUpI64Lanes<7>(v);
      case 15:
        return detail::SlideUpI32Lanes<15>(v);
    }
  }
#endif

  return detail::TableLookupSlideUpLanes(d, v, amt);
}

// 64-bit lanes: same constant-amt fast path in units of u64 lanes.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt)) {
    switch (amt) {
      case 0:
        return v;
      case 1:
        return detail::SlideUpI64Lanes<1>(v);
      case 2:
        return detail::SlideUpI64Lanes<2>(v);
      case 3:
        return detail::SlideUpI64Lanes<3>(v);
      case 4:
        return ConcatLowerLower(d, v, Zero(d));
      case 5:
        return detail::SlideUpI64Lanes<5>(v);
      case 6:
        return detail::SlideUpI64Lanes<6>(v);
      case 7:
        return detail::SlideUpI64Lanes<7>(v);
    }
  }
#endif

  return detail::TableLookupSlideUpLanes(d, v, amt);
}

// 8-bit lanes: reduce constant even amounts to the wider-lane overloads;
// otherwise (odd amt, no VBMI) stitch u64-lane slides with byte-level
// CombineShiftRightBytes.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt)) {
    if ((amt & 3) == 0) {
      // Multiple of 4 bytes: reuse the u32 implementation.
      const Repartition<uint32_t, decltype(d)> du32;
      return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 2));
    } else if ((amt & 1) == 0) {
      // Multiple of 2 bytes: reuse the u16 table-lookup path.
      const Repartition<uint16_t, decltype(d)> du16;
      return BitCast(
          d, detail::TableLookupSlideUpLanes(du16, BitCast(du16, v), amt >> 1));
    }
#if HWY_TARGET > HWY_AVX3_DL
    else if (amt <= 63) {  // NOLINT(readability/braces)
      // Odd byte counts: slide whole blocks (u64 pairs), then shift the
      // remaining 1..15 bytes across block boundaries.
      const Repartition<uint64_t, decltype(d)> du64;
      const size_t blk_u64_slideup_amt = (amt >> 4) << 1;
      const auto vu64 = BitCast(du64, v);
      const auto v_hi =
          BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt));
      // One block further up supplies the bytes shifted in from below; when
      // already at the top (> 4 lanes), only zeros remain.
      const auto v_lo =
          (blk_u64_slideup_amt <= 4)
              ? BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt + 2))
              : Zero(d);
      switch (amt & 15) {
        case 1:
          return CombineShiftRightBytes<15>(d, v_hi, v_lo);
        case 3:
          return CombineShiftRightBytes<13>(d, v_hi, v_lo);
        case 5:
          return CombineShiftRightBytes<11>(d, v_hi, v_lo);
        case 7:
          return CombineShiftRightBytes<9>(d, v_hi, v_lo);
        case 9:
          return CombineShiftRightBytes<7>(d, v_hi, v_lo);
        case 11:
          return CombineShiftRightBytes<5>(d, v_hi, v_lo);
        case 13:
          return CombineShiftRightBytes<3>(d, v_hi, v_lo);
        case 15:
          return CombineShiftRightBytes<1>(d, v_hi, v_lo);
      }
    }
#endif  // HWY_TARGET > HWY_AVX3_DL
  }
#endif

  return detail::TableLookupSlideUpLanes(d, v, amt);
}

// 16-bit lanes: constant even amounts reduce to the u32 overload.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt) && (amt & 1) == 0) {
    const Repartition<uint32_t, decltype(d)> du32;
    return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 1));
  }
#endif

  return detail::TableLookupSlideUpLanes(d, v, amt);
}

// ------------------------------ Slide1Up

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
#if HWY_TARGET <= HWY_AVX3_DL
  return detail::TableLookupSlideUpLanes(d, v, 1);
#else
  // Without VBMI: bring each block's lower neighbor up, then shift one byte.
  const auto v_lo = detail::SlideUpI64Lanes<2>(v);
  return CombineShiftRightBytes<15>(d, v, v_lo);
#endif
}

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
  return detail::TableLookupSlideUpLanes(d, v, 1);
}

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Slide1Up(D /*d*/, VFromD<D> v) {
return detail::SlideUpI32Lanes<1>(v);
}

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Slide1Up(D /*d*/, VFromD<D> v) {
  return detail::SlideUpI64Lanes<1>(v);
}

// ------------------------------ SlideDownLanes

namespace detail {

// Slides v down (towards lane 0) by kI32Lanes u32 lanes, inserting zeros:
// alignr of Zero:v with count kI32Lanes.
template <int kI32Lanes, class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V SlideDownI32Lanes(V v) {
  static_assert(0 <= kI32Lanes && kI32Lanes <= 15,
                "kI32Lanes must be between 0 and 15");
  const DFromV<decltype(v)> d;
  return CombineShiftRightI32Lanes<kI32Lanes>(Zero(d), v);
}

// Slides v down by kI64Lanes u64 lanes, inserting zeros.
template <int kI64Lanes, class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V SlideDownI64Lanes(V v) {
  static_assert(0 <= kI64Lanes && kI64Lanes <= 7,
                "kI64Lanes must be between 0 and 7");
  const DFromV<decltype(v)> d;
  return CombineShiftRightI64Lanes<kI64Lanes>(Zero(d), v);
}

// Runtime-variable byte slide-down for 8-bit lanes (mirror of the slide-up
// helper above).
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE VFromD<D> TableLookupSlideDownLanes(D d, VFromD<D> v, size_t amt) {
  const Repartition<uint8_t, decltype(d)> du8;

#if HWY_TARGET <= HWY_AVX3_DL
  // vpermi2b: indices >= 64 (after wrap) select from the Zero table.
  auto byte_idx = Iota(du8, static_cast<uint8_t>(amt));
  return TwoTablesLookupLanes(v, Zero(d), Indices512<TFromD<D>>{byte_idx.raw});
#else
  const Repartition<uint16_t, decltype(d)> du16;
  const Repartition<uint64_t, decltype(d)> du64;
  const auto byte_idx = Iota(du8, static_cast<uint8_t>(amt & 15));
  const auto blk_u64_idx = Iota(du64, static_cast<uint64_t>(((amt >> 4) << 1)));

  // Each block reads from itself or its upper neighbor.
  const VFromD<D> even_blocks{
      _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(0, 2, 2, 0))};
  const VFromD<D> odd_blocks{
      _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))};
  // Bit 4 of each byte index selects the odd-block source.
  const auto odd_sel_mask =
      MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx))));
  // Zero the top block's shifted-in bytes (mask covers the lower 48 bytes).
  const VFromD<D> even_blk_lookup_result{
      _mm512_maskz_shuffle_epi8(static_cast<__mmask64>(0x0000FFFFFFFFFFFFULL),
                                even_blocks.raw, byte_idx.raw)};
  const VFromD<D> blockwise_slide_up_result{
      _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw,
                               odd_blocks.raw, byte_idx.raw)};
  // Slide whole 128-bit blocks down; out-of-range u64 indices give zeros.
  return BitCast(d, TwoTablesLookupLanes(
                        BitCast(du64, blockwise_slide_up_result), Zero(du64),
                        Indices512<uint64_t>{blk_u64_idx.raw}));
#endif
}

}  // namespace detail

// Slides v down by kBlocks whole 128-bit blocks, inserting zeros.
template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> SlideDownBlocks(D d, VFromD<D> v) {
  static_assert(0 <= kBlocks && kBlocks <= 3,
                "kBlocks must be between 0 and 3");
  const Half<decltype(d)> dh;
  switch (kBlocks) {
    case 0:
      return v;
    case 1:
      return detail::SlideDownI64Lanes<2>(v);
    case 2:
      // Two blocks == upper 256-bit half moved down, upper half zeroed.
      return ZeroExtendVector(d, UpperHalf(dh, v));
    case 3:
      return detail::SlideDownI64Lanes<6>(v);
  }

  return v;
}

// 32-bit lanes: constant-amt fast path via valignd/valignq.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt)) {
    const Half<decltype(d)> dh;
    switch (amt) {
      case 1:
        return detail::SlideDownI32Lanes<1>(v);
      case 2:
        return detail::SlideDownI64Lanes<1>(v);
      case 3:
        return detail::SlideDownI32Lanes<3>(v);
      case 4:
        return detail::SlideDownI64Lanes<2>(v);
      case 5:
        return detail::SlideDownI32Lanes<5>(v);
      case 6:
        return detail::SlideDownI64Lanes<3>(v);
      case 7:
        return detail::SlideDownI32Lanes<7>(v);
      case 8:
        return ZeroExtendVector(d, UpperHalf(dh, v));
      case 9:
        return detail::SlideDownI32Lanes<9>(v);
      case 10:
        return detail::SlideDownI64Lanes<5>(v);
      case 11:
        return detail::SlideDownI32Lanes<11>(v);
      case 12:
        return detail::SlideDownI64Lanes<6>(v);
case 13:
        return detail::SlideDownI32Lanes<13>(v);
      case 14:
        return detail::SlideDownI64Lanes<7>(v);
      case 15:
        return detail::SlideDownI32Lanes<15>(v);
    }
  }
#endif

  return detail::TableLookupSlideDownLanes(d, v, amt);
}

// 64-bit lanes: constant-amt fast path in units of u64 lanes.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt)) {
    const Half<decltype(d)> dh;
    switch (amt) {
      case 0:
        return v;
      case 1:
        return detail::SlideDownI64Lanes<1>(v);
      case 2:
        return detail::SlideDownI64Lanes<2>(v);
      case 3:
        return detail::SlideDownI64Lanes<3>(v);
      case 4:
        return ZeroExtendVector(d, UpperHalf(dh, v));
      case 5:
        return detail::SlideDownI64Lanes<5>(v);
      case 6:
        return detail::SlideDownI64Lanes<6>(v);
      case 7:
        return detail::SlideDownI64Lanes<7>(v);
    }
  }
#endif

  return detail::TableLookupSlideDownLanes(d, v, amt);
}

// 8-bit lanes: reduce constant even amounts to wider lanes; otherwise (odd
// amt, no VBMI) stitch u64-lane slides with byte-level shifts.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt)) {
    if ((amt & 3) == 0) {
      // Multiple of 4 bytes: reuse the u32 implementation.
      const Repartition<uint32_t, decltype(d)> du32;
      return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 2));
    } else if ((amt & 1) == 0) {
      // Multiple of 2 bytes: reuse the u16 table-lookup path.
      const Repartition<uint16_t, decltype(d)> du16;
      return BitCast(d, detail::TableLookupSlideDownLanes(
                            du16, BitCast(du16, v), amt >> 1));
    }
#if HWY_TARGET > HWY_AVX3_DL
    else if (amt <= 63) {  // NOLINT(readability/braces)
      const Repartition<uint64_t, decltype(d)> du64;
      const size_t blk_u64_slidedown_amt = (amt >> 4) << 1;
      const auto vu64 = BitCast(du64, v);
      const auto v_lo =
          BitCast(d, SlideDownLanes(du64, vu64, blk_u64_slidedown_amt));
      // One block further down supplies the bytes shifted in from above;
      // beyond the top (> 4 lanes), only zeros remain.
      const auto v_hi =
          (blk_u64_slidedown_amt <= 4)
              ? BitCast(d,
                        SlideDownLanes(du64, vu64, blk_u64_slidedown_amt + 2))
              : Zero(d);
      switch (amt & 15) {
        case 1:
          return CombineShiftRightBytes<1>(d, v_hi, v_lo);
        case 3:
          return CombineShiftRightBytes<3>(d, v_hi, v_lo);
        case 5:
          return CombineShiftRightBytes<5>(d, v_hi, v_lo);
        case 7:
          return CombineShiftRightBytes<7>(d, v_hi, v_lo);
        case 9:
          return CombineShiftRightBytes<9>(d, v_hi, v_lo);
        case 11:
          return CombineShiftRightBytes<11>(d, v_hi, v_lo);
        case 13:
          return CombineShiftRightBytes<13>(d, v_hi, v_lo);
        case 15:
          return CombineShiftRightBytes<15>(d, v_hi, v_lo);
      }
    }
#endif  // HWY_TARGET > HWY_AVX3_DL
  }
#endif

  return detail::TableLookupSlideDownLanes(d, v, amt);
}

// 16-bit lanes: constant even amounts reduce to the u32 overload.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC  // includes clang
  if (__builtin_constant_p(amt) && (amt & 1) == 0) {
    const Repartition<uint32_t, decltype(d)> du32;
    return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 1));
  }
#endif

  return detail::TableLookupSlideDownLanes(d, v, amt);
}

// ------------------------------ Slide1Down

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
#if HWY_TARGET <= HWY_AVX3_DL
  return detail::TableLookupSlideDownLanes(d, v, 1);
#else
  // Without VBMI: bring each block's upper neighbor down, then shift one byte.
  const auto v_hi = detail::SlideDownI64Lanes<2>(v);
  return CombineShiftRightBytes<1>(d, v_hi, v);
#endif
}

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
  return detail::TableLookupSlideDownLanes(d, v, 1);
}

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Slide1Down(D /*d*/, VFromD<D> v) {
  return detail::SlideDownI32Lanes<1>(v);
}

template <typename D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Slide1Down(D /*d*/, VFromD<D> v) {
  return detail::SlideDownI64Lanes<1>(v);
}

// ================================================== CONVERT

// ------------------------------ Promotions (part w/ narrow lanes -> full)

// Unsigned: zero-extend.
// Note: these have 3 cycle latency; if inputs are already split across the
// 128 bit blocks (in their upper/lower halves), then Zip* would be faster.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint8_t> v) {
  return VFromD<D>{_mm512_cvtepu8_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint8_t> v) {
  return VFromD<D>{_mm512_cvtepu8_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint16_t> v) {
  return VFromD<D>{_mm512_cvtepu16_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint32_t> v) {
  return VFromD<D>{_mm512_cvtepu32_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint16_t> v) {
  return VFromD<D>{_mm512_cvtepu16_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<uint8_t> v) {
  return VFromD<D>{_mm512_cvtepu8_epi64(v.raw)};
}

// Signed: replicate sign bit.
// Note: these have 3 cycle latency; if inputs are already split across the
// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by
// signed shift would be faster.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I16_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int8_t> v) {
  return VFromD<D>{_mm512_cvtepi8_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int8_t> v) {
  return VFromD<D>{_mm512_cvtepi8_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int16_t> v) {
  return VFromD<D>{_mm512_cvtepi16_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int32_t> v) {
  return VFromD<D>{_mm512_cvtepi32_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int16_t> v) {
  return VFromD<D>{_mm512_cvtepi16_epi64(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<int8_t> v) {
  return VFromD<D>{_mm512_cvtepi8_epi64(v.raw)};
}

// Float
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<float16_t> v) {
#if HWY_HAVE_FLOAT16
  // With native f16, the raw type is __m256h; vcvtph2ps wants the u16 bits.
  const RebindToUnsigned<DFromV<decltype(v)>> du16;
  return VFromD<D>{_mm512_cvtph_ps(BitCast(du16, v).raw)};
#else
  return VFromD<D>{_mm512_cvtph_ps(v.raw)};
#endif  // HWY_HAVE_FLOAT16
}

#if HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_INLINE VFromD<D> PromoteTo(D /*tag*/, Vec128<float16_t> v) {
  return VFromD<D>{_mm512_cvtph_pd(v.raw)};
}

#endif  // HWY_HAVE_FLOAT16

// bf16 -> f32: bf16 is the upper half of an f32, so widen and shift left 16.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API VFromD<D> PromoteTo(D df32, Vec256<bfloat16_t> v) {
  const Rebind<uint16_t, decltype(df32)> du16;
  const RebindToSigned<decltype(df32)> di32;
  return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v))));
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<float> v) {
  return VFromD<D>{_mm512_cvtps_pd(v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<int32_t> v) {
  return VFromD<D>{_mm512_cvtepi32_pd(v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint32_t> v) {
  return VFromD<D>{_mm512_cvtepu32_pd(v.raw)};
}

// f32 -> i64 truncation; results are only defined for inputs representable in
// int64_t (hence "InRange").
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
#if HWY_X86_HAVE_AVX10_2_OPS
  return VFromD<D>{_mm512_cvtts_ps_epi64(v.raw)};
#elif HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior with GCC if any values of v[i] are not
  // within the range of an int64_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  // Constant inputs: convert per-scalar at compile time to keep constant
  // folding while avoiding the UB of the intrinsic.
  if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) {
    typedef float GccF32RawVectType __attribute__((__vector_size__(32)));
    const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw);
    return VFromD<D>{_mm512_setr_epi64(
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[0]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[1]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[2]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[3]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[4]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[5]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[6]),
        detail::X86ConvertScalarFromFloat<int64_t>(raw_v[7]))};
  }
#endif

  // Emit vcvttps2qq directly so GCC cannot reason about (and miscompile)
  // out-of-range inputs.
  __m512i raw_result;
  __asm__("vcvttps2qq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<D>{raw_result};
#else
  return VFromD<D>{_mm512_cvttps_epi64(v.raw)};
#endif
}
// f32 -> u64 truncation, same structure as the i64 overload above.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /* tag */, VFromD<Rebind<float, D>> v) {
#if HWY_X86_HAVE_AVX10_2_OPS
  return VFromD<D>{_mm512_cvtts_ps_epu64(v.raw)};
#elif HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior with GCC if any values of v[i] are not
  // within the range of an uint64_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) {
    typedef float GccF32RawVectType __attribute__((__vector_size__(32)));
    const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw);
    return VFromD<D>{_mm512_setr_epi64(
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[0])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[1])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[2])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[3])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[4])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[5])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[6])),
        static_cast<int64_t>(
            detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[7])))};
  }
#endif

  __m512i raw_result;
  __asm__("vcvttps2uqq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<D>{raw_result};
#else
  return VFromD<D>{_mm512_cvttps_epu64(v.raw)};
#endif
}

// ------------------------------ Demotions (full -> part w/ narrow lanes)

// i32 -> u16 with unsigned saturation; vpackusdw works per 128-bit block, so
// compact the even u64 lanes afterwards.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) {
  const Full512<uint64_t> du64;
  const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)};

  // Compress even u64 lanes into 256 bit.
  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
  const auto idx64 = Load(du64, kLanes);
  const Vec512<uint16_t> even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)};
  return LowerHalf(even);
}

// u32 -> u16: clamp to i32 max first so the signed-input overload saturates
// correctly, then reuse it.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D dn, Vec512<uint32_t> v) {
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu))));
}

// i32 -> i16 with signed saturation (vpackssdw per block + compaction).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) {
  const Full512<uint64_t> du64;
  const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)};

  // Compress even u64 lanes into 256 bit.
alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
  const auto idx64 = Load(du64, kLanes);
  const Vec512<int16_t> even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)};
  return LowerHalf(even);
}

// i32 -> u8: two per-block packs (i32->i16 signed, i16->u8 unsigned), then
// gather the first u32 of each block.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) {
  const Full512<uint32_t> du32;
  const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)};
  const Vec512<uint8_t> u8{_mm512_packus_epi16(i16.raw, i16.raw)};

  const VFromD<decltype(du32)> idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12);
  const Vec512<uint8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)};
  return LowerHalf(LowerHalf(fixed));
}

// u32 -> u8 with unsigned saturation in a single instruction.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint32_t> v) {
  return VFromD<D>{_mm512_cvtusepi32_epi8(v.raw)};
}

// i16 -> u8 with unsigned saturation (vpackuswb per block + compaction).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int16_t> v) {
  const Full512<uint64_t> du64;
  const Vec512<uint8_t> u8{_mm512_packus_epi16(v.raw, v.raw)};

  // Compress even u64 lanes into 256 bit.
  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
  const auto idx64 = Load(du64, kLanes);
  const Vec512<uint8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)};
  return LowerHalf(even);
}

// u16 -> u8: clamp to i16 max, then reuse the signed-input overload.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D dn, Vec512<uint16_t> v) {
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu))));
}

// i32 -> i8: two signed per-block packs, then gather the first u32 per block.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int32_t> v) {
  const Full512<uint32_t> du32;
  const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)};
  const Vec512<int8_t> i8{_mm512_packs_epi16(i16.raw, i16.raw)};

  const VFromD<decltype(du32)> idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12);
  const Vec512<int8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)};
  return LowerHalf(LowerHalf(fixed));
}

// i16 -> i8 with signed saturation (vpacksswb per block + compaction).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int16_t> v) {
  const Full512<uint64_t> du64;
  const Vec512<int8_t> u8{_mm512_packs_epi16(v.raw, v.raw)};

  // Compress even u64 lanes into 256 bit.
  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
  const auto idx64 = Load(du64, kLanes);
  const Vec512<int8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)};
  return LowerHalf(even);
}

// i64 -> narrower signed types via saturating conversions.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) {
  return VFromD<D>{_mm512_cvtsepi64_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) {
  return VFromD<D>{_mm512_cvtsepi64_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_I8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) {
  return VFromD<D>{_mm512_cvtsepi64_epi8(v.raw)};
}

// i64 -> narrower unsigned types: the unsigned saturating conversion would
// treat negative inputs as huge, so zero those lanes via the mask.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) {
  const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw;
  return VFromD<D>{_mm512_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) {
  const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw;
  return VFromD<D>{_mm512_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<int64_t> v) {
  const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw;
  return VFromD<D>{_mm512_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)};
}

// u64 -> narrower unsigned types via saturating conversions.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint64_t> v) {
  return VFromD<D>{_mm512_cvtusepi64_epi32(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint64_t> v) {
  return VFromD<D>{_mm512_cvtusepi64_epi16(v.raw)};
}
template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<uint64_t> v) {
  return VFromD<D>{_mm512_cvtusepi64_epi8(v.raw)};
}

// f32 -> f16 (round to nearest even, exceptions suppressed).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API VFromD<D> DemoteTo(D df16, Vec512<float> v) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  const RebindToUnsigned<decltype(df16)> du16;
  return BitCast(
      df16, VFromD<decltype(du16)>{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)});
  HWY_DIAGNOSTICS(pop)
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F16_D(D)>
HWY_API VFromD<D> DemoteTo(D /*df16*/, Vec512<double> v) {
  return VFromD<D>{_mm512_cvtpd_ph(v.raw)};
}
#endif  // HWY_HAVE_FLOAT16

#if HWY_AVX3_HAVE_F32_TO_BF16C
// f32 -> bf16 via vcvtneps2bf16.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
HWY_API VFromD<D> DemoteTo(D /*dbf16*/, Vec512<float> v) {
#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000
  // Inline assembly workaround for LLVM codegen bug
  __m256i raw_result;
  __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw));
  return VFromD<D>{raw_result};
#else
  // The _mm512_cvtneps_pbh intrinsic returns a __m256bh vector that needs to be
  // bit casted to a __m256i vector
  return VFromD<D>{detail::BitCastToInteger(_mm512_cvtneps_pbh(v.raw))};
#endif
}

// Two f32 vectors -> one bf16 vector (a in even blocks' lower halves, b in
// upper, per vcvtne2ps2bf16 semantics).
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D /*dbf16*/, Vec512<float> a,
                                   Vec512<float> b) {
#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000
  // Inline assembly workaround for LLVM codegen bug
  __m512i raw_result;
  __asm__("vcvtne2ps2bf16 %2, %1, %0"
          : "=v"(raw_result)
          : "v"(b.raw), "v"(a.raw));
  return VFromD<D>{raw_result};
#else
  // The _mm512_cvtne2ps_pbh intrinsic returns a __m512bh vector that needs to
  // be bit casted to a __m512i vector
  return VFromD<D>{detail::BitCastToInteger(_mm512_cvtne2ps_pbh(b.raw, a.raw))};
#endif
}
#endif  // HWY_AVX3_HAVE_F32_TO_BF16C

// Saturating pack of two vectors; result is block-interleaved (a/b alternate
// per 128-bit block), hence "Reorder".
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I16_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int32_t> a,
                                   Vec512<int32_t> b) {
  return VFromD<D>{_mm512_packs_epi32(a.raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int32_t> a,
                                   Vec512<int32_t> b) {
  return VFromD<D>{_mm512_packus_epi32(a.raw, b.raw)};
}

// u32 inputs: clamp to i32 max so the signed-input pack saturates correctly.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D dn, Vec512<uint32_t> a,
                                   Vec512<uint32_t> b) {
  const DFromV<decltype(a)> du32;
  const RebindToSigned<decltype(du32)> di32;
  const auto max_i32 = Set(du32, 0x7FFFFFFFu);

  return ReorderDemote2To(dn, BitCast(di32, Min(a, max_i32)),
                          BitCast(di32, Min(b, max_i32)));
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I8_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int16_t> a,
                                   Vec512<int16_t> b) {
  return VFromD<D>{_mm512_packs_epi16(a.raw, b.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U8_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec512<int16_t> a,
                                   Vec512<int16_t> b) {
  return VFromD<D>{_mm512_packus_epi16(a.raw, b.raw)};
}

// u16 inputs: clamp to i16 max, then reuse the signed-input overload.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U8_D(D)>
HWY_API VFromD<D> ReorderDemote2To(D dn, Vec512<uint16_t> a,
                                   Vec512<uint16_t> b) {
  const DFromV<decltype(a)> du16;
  const RebindToSigned<decltype(du16)> di16;
  const auto max_i16 = Set(du16, 0x7FFFu);

  return ReorderDemote2To(dn, BitCast(di16, Min(a, max_i16)),
                          BitCast(di16, Min(b, max_i16)));
}

// In-order demote of two vectors: undo ReorderDemote2To's block interleaving
// with a u64-lane permute (even lanes = a, odd lanes = b).
template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>),
          HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
          HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2),
          HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>) * 2),
          HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))>
HWY_API VFromD<D> OrderedDemote2To(D d, V a, V b) {
  const Full512<uint64_t> du64;
  alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  return BitCast(d, TableLookupLanes(BitCast(du64, ReorderDemote2To(d, a, b)),
                                     SetTableIndices(du64, kIdx)));
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<double> v) {
  return VFromD<D>{_mm512_cvtpd_ps(v.raw)};
}

// f64 -> i32 truncation; results only defined for in-range inputs.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) {
#if HWY_X86_HAVE_AVX10_2_OPS
  return VFromD<D>{_mm512_cvtts_pd_epi32(v.raw)};
#elif HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior in _mm512_cvttpd_epi32 with GCC if any
  // values of v[i] are not within the range of an int32_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  // Constant inputs: convert per-scalar to preserve constant folding.
  if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) {
    typedef double GccF64RawVectType __attribute__((__vector_size__(64)));
    const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw);
    return VFromD<D>{_mm256_setr_epi32(
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[3]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[4]),
        detail::X86ConvertScalarFromFloat<int32_t>(raw_v[5]),
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[6]), 5919 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7])) 5920 }; 5921 } 5922 #endif 5923 5924 __m256i raw_result; 5925 __asm__("vcvttpd2dq {%1, %0|%0, %1}" 5926 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 5927 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 5928 :); 5929 return VFromD<D>{raw_result}; 5930 #else 5931 return VFromD<D>{_mm512_cvttpd_epi32(v.raw)}; 5932 #endif 5933 } 5934 5935 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> 5936 HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) { 5937 #if HWY_X86_HAVE_AVX10_2_OPS 5938 return VFromD<D>{_mm512_cvtts_pd_epu32(v.raw)}; 5939 #elif HWY_COMPILER_GCC_ACTUAL 5940 // Workaround for undefined behavior in _mm512_cvttpd_epu32 with GCC if any 5941 // values of v[i] are not within the range of an uint32_t 5942 5943 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 5944 if (detail::IsConstantX86VecForF2IConv<uint32_t>(v)) { 5945 typedef double GccF64RawVectType __attribute__((__vector_size__(64))); 5946 const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); 5947 return VFromD<D>{_mm256_setr_epi32( 5948 static_cast<int32_t>( 5949 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[0])), 5950 static_cast<int32_t>( 5951 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[1])), 5952 static_cast<int32_t>( 5953 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[2])), 5954 static_cast<int32_t>( 5955 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[3])), 5956 static_cast<int32_t>( 5957 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[4])), 5958 static_cast<int32_t>( 5959 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[5])), 5960 static_cast<int32_t>( 5961 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[6])), 5962 static_cast<int32_t>( 5963 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[7])))}; 5964 } 5965 #endif 5966 5967 __m256i raw_result; 5968 __asm__("vcvttpd2udq {%1, %0|%0, %1}" 5969 : "=" 
HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 5970 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 5971 :); 5972 return VFromD<D>{raw_result}; 5973 #else 5974 return VFromD<D>{_mm512_cvttpd_epu32(v.raw)}; 5975 #endif 5976 } 5977 5978 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 5979 HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int64_t, D>> v) { 5980 return VFromD<D>{_mm512_cvtepi64_ps(v.raw)}; 5981 } 5982 5983 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)> 5984 HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint64_t, D>> v) { 5985 return VFromD<D>{_mm512_cvtepu64_ps(v.raw)}; 5986 } 5987 5988 // For already range-limited input [0, 255]. 5989 HWY_API Vec128<uint8_t> U8FromU32(const Vec512<uint32_t> v) { 5990 const DFromV<decltype(v)> d32; 5991 // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the 5992 // lowest 4 bytes. 5993 const VFromD<decltype(d32)> v8From32 = 5994 Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); 5995 const auto quads = TableLookupBytes(v, v8From32); 5996 // Gather the lowest 4 bytes of 4 128-bit blocks. 
5997 const VFromD<decltype(d32)> index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); 5998 const Vec512<uint8_t> bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; 5999 return LowerHalf(LowerHalf(bytes)); 6000 } 6001 6002 // ------------------------------ Truncations 6003 6004 template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)> 6005 HWY_API VFromD<D> TruncateTo(D d, const Vec512<uint64_t> v) { 6006 #if HWY_TARGET <= HWY_AVX3_DL 6007 (void)d; 6008 const Full512<uint8_t> d8; 6009 const VFromD<decltype(d8)> v8From64 = Dup128VecFromValues( 6010 d8, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56); 6011 const Vec512<uint8_t> bytes{_mm512_permutexvar_epi8(v8From64.raw, v.raw)}; 6012 return LowerHalf(LowerHalf(LowerHalf(bytes))); 6013 #else 6014 const Full512<uint32_t> d32; 6015 alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, 6016 0, 2, 4, 6, 8, 10, 12, 14}; 6017 const Vec512<uint32_t> even{ 6018 _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; 6019 return TruncateTo(d, LowerHalf(even)); 6020 #endif 6021 } 6022 6023 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)> 6024 HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint64_t> v) { 6025 const Full512<uint16_t> d16; 6026 alignas(16) static constexpr uint16_t k16From64[8] = {0, 4, 8, 12, 6027 16, 20, 24, 28}; 6028 const Vec512<uint16_t> bytes{ 6029 _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; 6030 return LowerHalf(LowerHalf(bytes)); 6031 } 6032 6033 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)> 6034 HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint64_t> v) { 6035 const Full512<uint32_t> d32; 6036 alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, 6037 0, 2, 4, 6, 8, 10, 12, 14}; 6038 const Vec512<uint32_t> even{ 6039 _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; 6040 return LowerHalf(even); 6041 } 6042 6043 template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)> 6044 
HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint32_t> v) { 6045 #if HWY_TARGET <= HWY_AVX3_DL 6046 const Full512<uint8_t> d8; 6047 const VFromD<decltype(d8)> v8From32 = Dup128VecFromValues( 6048 d8, 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); 6049 const Vec512<uint8_t> bytes{_mm512_permutexvar_epi8(v8From32.raw, v.raw)}; 6050 #else 6051 const Full512<uint32_t> d32; 6052 // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the 6053 // lowest 4 bytes. 6054 const VFromD<decltype(d32)> v8From32 = 6055 Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); 6056 const auto quads = TableLookupBytes(v, v8From32); 6057 // Gather the lowest 4 bytes of 4 128-bit blocks. 6058 const VFromD<decltype(d32)> index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); 6059 const Vec512<uint8_t> bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; 6060 #endif 6061 return LowerHalf(LowerHalf(bytes)); 6062 } 6063 6064 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)> 6065 HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint32_t> v) { 6066 const Full512<uint16_t> d16; 6067 alignas(64) static constexpr uint16_t k16From32[32] = { 6068 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 6069 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; 6070 const Vec512<uint16_t> bytes{ 6071 _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; 6072 return LowerHalf(bytes); 6073 } 6074 6075 template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)> 6076 HWY_API VFromD<D> TruncateTo(D /* tag */, const Vec512<uint16_t> v) { 6077 #if HWY_TARGET <= HWY_AVX3_DL 6078 const Full512<uint8_t> d8; 6079 alignas(64) static constexpr uint8_t k8From16[64] = { 6080 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 6081 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 6082 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 6083 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; 6084 const 
Vec512<uint8_t> bytes{ 6085 _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; 6086 #else 6087 const Full512<uint32_t> d32; 6088 const VFromD<decltype(d32)> v16From32 = Dup128VecFromValues( 6089 d32, 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u); 6090 const auto quads = TableLookupBytes(v, v16From32); 6091 alignas(64) static constexpr uint32_t kIndex32[16] = { 6092 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; 6093 const Vec512<uint8_t> bytes{ 6094 _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; 6095 #endif 6096 return LowerHalf(bytes); 6097 } 6098 6099 // ------------------------------ Convert integer <=> floating point 6100 6101 #if HWY_HAVE_FLOAT16 6102 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> 6103 HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<uint16_t> v) { 6104 return VFromD<D>{_mm512_cvtepu16_ph(v.raw)}; 6105 } 6106 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)> 6107 HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<int16_t> v) { 6108 return VFromD<D>{_mm512_cvtepi16_ph(v.raw)}; 6109 } 6110 #endif // HWY_HAVE_FLOAT16 6111 6112 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 6113 HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<int32_t> v) { 6114 return VFromD<D>{_mm512_cvtepi32_ps(v.raw)}; 6115 } 6116 6117 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 6118 HWY_API VFromD<D> ConvertTo(D /* tag */, Vec512<int64_t> v) { 6119 return VFromD<D>{_mm512_cvtepi64_pd(v.raw)}; 6120 } 6121 6122 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)> 6123 HWY_API VFromD<D> ConvertTo(D /* tag*/, Vec512<uint32_t> v) { 6124 return VFromD<D>{_mm512_cvtepu32_ps(v.raw)}; 6125 } 6126 6127 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)> 6128 HWY_API VFromD<D> ConvertTo(D /* tag*/, Vec512<uint64_t> v) { 6129 return VFromD<D>{_mm512_cvtepu64_pd(v.raw)}; 6130 } 6131 6132 // Truncates (rounds toward zero). 
6133 #if HWY_HAVE_FLOAT16 6134 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I16_D(D)> 6135 HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec512<float16_t> v) { 6136 #if HWY_COMPILER_GCC_ACTUAL 6137 // Workaround for undefined behavior in _mm512_cvttph_epi16 with GCC if any 6138 // values of v[i] are not within the range of an int16_t 6139 6140 #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ 6141 HWY_HAVE_SCALAR_F16_TYPE 6142 if (detail::IsConstantX86VecForF2IConv<int16_t>(v)) { 6143 typedef hwy::float16_t::Native GccF16RawVectType 6144 __attribute__((__vector_size__(64))); 6145 const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); 6146 return VFromD<D>{ 6147 _mm512_set_epi16(detail::X86ConvertScalarFromFloat<int16_t>(raw_v[31]), 6148 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[30]), 6149 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[29]), 6150 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[28]), 6151 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[27]), 6152 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[26]), 6153 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[25]), 6154 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[24]), 6155 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[23]), 6156 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[22]), 6157 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[21]), 6158 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[20]), 6159 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[19]), 6160 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[18]), 6161 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[17]), 6162 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[16]), 6163 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[15]), 6164 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[14]), 6165 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[13]), 6166 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[12]), 6167 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[11]), 6168 
detail::X86ConvertScalarFromFloat<int16_t>(raw_v[10]), 6169 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[9]), 6170 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[8]), 6171 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[7]), 6172 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[6]), 6173 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[5]), 6174 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[4]), 6175 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[3]), 6176 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[2]), 6177 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[1]), 6178 detail::X86ConvertScalarFromFloat<int16_t>(raw_v[0]))}; 6179 } 6180 #endif 6181 6182 __m512i raw_result; 6183 __asm__("vcvttph2w {%1, %0|%0, %1}" 6184 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6185 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6186 :); 6187 return VFromD<D>{raw_result}; 6188 #else 6189 return VFromD<D>{_mm512_cvttph_epi16(v.raw)}; 6190 #endif 6191 } 6192 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)> 6193 HWY_API VFromD<D> ConvertInRangeTo(D /* tag */, VFromD<RebindToFloat<D>> v) { 6194 #if HWY_COMPILER_GCC_ACTUAL 6195 // Workaround for undefined behavior in _mm512_cvttph_epu16 with GCC if any 6196 // values of v[i] are not within the range of an uint16_t 6197 6198 #if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ 6199 HWY_HAVE_SCALAR_F16_TYPE 6200 if (detail::IsConstantX86VecForF2IConv<uint16_t>(v)) { 6201 typedef hwy::float16_t::Native GccF16RawVectType 6202 __attribute__((__vector_size__(64))); 6203 const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw); 6204 return VFromD<D>{_mm512_set_epi16( 6205 static_cast<int16_t>( 6206 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[31])), 6207 static_cast<int16_t>( 6208 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[30])), 6209 static_cast<int16_t>( 6210 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[29])), 6211 static_cast<int16_t>( 6212 
detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[28])), 6213 static_cast<int16_t>( 6214 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[27])), 6215 static_cast<int16_t>( 6216 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[26])), 6217 static_cast<int16_t>( 6218 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[25])), 6219 static_cast<int16_t>( 6220 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[24])), 6221 static_cast<int16_t>( 6222 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[23])), 6223 static_cast<int16_t>( 6224 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[22])), 6225 static_cast<int16_t>( 6226 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[21])), 6227 static_cast<int16_t>( 6228 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[20])), 6229 static_cast<int16_t>( 6230 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[19])), 6231 static_cast<int16_t>( 6232 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[18])), 6233 static_cast<int16_t>( 6234 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[17])), 6235 static_cast<int16_t>( 6236 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[16])), 6237 static_cast<int16_t>( 6238 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[15])), 6239 static_cast<int16_t>( 6240 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[14])), 6241 static_cast<int16_t>( 6242 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[13])), 6243 static_cast<int16_t>( 6244 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[12])), 6245 static_cast<int16_t>( 6246 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[11])), 6247 static_cast<int16_t>( 6248 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[10])), 6249 static_cast<int16_t>( 6250 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[9])), 6251 static_cast<int16_t>( 6252 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[8])), 6253 static_cast<int16_t>( 6254 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[7])), 6255 static_cast<int16_t>( 6256 
detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[6])), 6257 static_cast<int16_t>( 6258 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[5])), 6259 static_cast<int16_t>( 6260 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[4])), 6261 static_cast<int16_t>( 6262 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[3])), 6263 static_cast<int16_t>( 6264 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[2])), 6265 static_cast<int16_t>( 6266 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[1])), 6267 static_cast<int16_t>( 6268 detail::X86ConvertScalarFromFloat<uint16_t>(raw_v[0])))}; 6269 } 6270 #endif 6271 6272 __m512i raw_result; 6273 __asm__("vcvttph2uw {%1, %0|%0, %1}" 6274 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6275 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6276 :); 6277 return VFromD<D>{raw_result}; 6278 #else 6279 return VFromD<D>{_mm512_cvttph_epu16(v.raw)}; 6280 #endif 6281 } 6282 #endif // HWY_HAVE_FLOAT16 6283 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)> 6284 HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec512<float> v) { 6285 #if HWY_X86_HAVE_AVX10_2_OPS 6286 return VFromD<D>{_mm512_cvtts_ps_epi32(v.raw)}; 6287 #elif HWY_COMPILER_GCC_ACTUAL 6288 // Workaround for undefined behavior in _mm512_cvttps_epi32 with GCC if any 6289 // values of v[i] are not within the range of an int32_t 6290 6291 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6292 if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { 6293 typedef float GccF32RawVectType __attribute__((__vector_size__(64))); 6294 const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); 6295 return VFromD<D>{_mm512_setr_epi32( 6296 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]), 6297 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]), 6298 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]), 6299 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[3]), 6300 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[4]), 6301 
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[5]), 6302 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[6]), 6303 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7]), 6304 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[8]), 6305 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[9]), 6306 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[10]), 6307 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[11]), 6308 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[12]), 6309 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[13]), 6310 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[14]), 6311 detail::X86ConvertScalarFromFloat<int32_t>(raw_v[15]))}; 6312 } 6313 #endif 6314 6315 __m512i raw_result; 6316 __asm__("vcvttps2dq {%1, %0|%0, %1}" 6317 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6318 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6319 :); 6320 return VFromD<D>{raw_result}; 6321 #else 6322 return VFromD<D>{_mm512_cvttps_epi32(v.raw)}; 6323 #endif 6324 } 6325 template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)> 6326 HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec512<double> v) { 6327 #if HWY_X86_HAVE_AVX10_2_OPS 6328 return VFromD<D>{_mm512_cvtts_pd_epi64(v.raw)}; 6329 #elif HWY_COMPILER_GCC_ACTUAL 6330 // Workaround for undefined behavior in _mm512_cvttpd_epi64 with GCC if any 6331 // values of v[i] are not within the range of an int64_t 6332 6333 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6334 if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { 6335 typedef double GccF64RawVectType __attribute__((__vector_size__(64))); 6336 const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); 6337 return VFromD<D>{_mm512_setr_epi64( 6338 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[0]), 6339 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[1]), 6340 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[2]), 6341 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[3]), 6342 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[4]), 6343 
detail::X86ConvertScalarFromFloat<int64_t>(raw_v[5]), 6344 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[6]), 6345 detail::X86ConvertScalarFromFloat<int64_t>(raw_v[7]))}; 6346 } 6347 #endif 6348 6349 __m512i raw_result; 6350 __asm__("vcvttpd2qq {%1, %0|%0, %1}" 6351 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6352 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6353 :); 6354 return VFromD<D>{raw_result}; 6355 #else 6356 return VFromD<D>{_mm512_cvttpd_epi64(v.raw)}; 6357 #endif 6358 } 6359 template <class DU, HWY_IF_V_SIZE_D(DU, 64), HWY_IF_U32_D(DU)> 6360 HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) { 6361 #if HWY_X86_HAVE_AVX10_2_OPS 6362 return VFromD<DU>{_mm512_cvtts_ps_epu32(v.raw)}; 6363 #elif HWY_COMPILER_GCC_ACTUAL 6364 // Workaround for undefined behavior in _mm512_cvttps_epu32 with GCC if any 6365 // values of v[i] are not within the range of an uint32_t 6366 6367 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6368 if (detail::IsConstantX86VecForF2IConv<uint32_t>(v)) { 6369 typedef float GccF32RawVectType __attribute__((__vector_size__(64))); 6370 const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); 6371 return VFromD<DU>{_mm512_setr_epi32( 6372 static_cast<int32_t>( 6373 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[0])), 6374 static_cast<int32_t>( 6375 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[1])), 6376 static_cast<int32_t>( 6377 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[2])), 6378 static_cast<int32_t>( 6379 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[3])), 6380 static_cast<int32_t>( 6381 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[4])), 6382 static_cast<int32_t>( 6383 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[5])), 6384 static_cast<int32_t>( 6385 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[6])), 6386 static_cast<int32_t>( 6387 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[7])), 6388 static_cast<int32_t>( 6389 
detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[8])), 6390 static_cast<int32_t>( 6391 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[9])), 6392 static_cast<int32_t>( 6393 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[10])), 6394 static_cast<int32_t>( 6395 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[11])), 6396 static_cast<int32_t>( 6397 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[12])), 6398 static_cast<int32_t>( 6399 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[13])), 6400 static_cast<int32_t>( 6401 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[14])), 6402 static_cast<int32_t>( 6403 detail::X86ConvertScalarFromFloat<uint32_t>(raw_v[15])))}; 6404 } 6405 #endif 6406 6407 __m512i raw_result; 6408 __asm__("vcvttps2udq {%1, %0|%0, %1}" 6409 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6410 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6411 :); 6412 return VFromD<DU>{raw_result}; 6413 #else 6414 return VFromD<DU>{_mm512_cvttps_epu32(v.raw)}; 6415 #endif 6416 } 6417 template <class DU, HWY_IF_V_SIZE_D(DU, 64), HWY_IF_U64_D(DU)> 6418 HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) { 6419 #if HWY_X86_HAVE_AVX10_2_OPS 6420 return VFromD<DU>{_mm512_cvtts_pd_epu64(v.raw)}; 6421 #elif HWY_COMPILER_GCC_ACTUAL 6422 // Workaround for undefined behavior in _mm512_cvttpd_epu64 with GCC if any 6423 // values of v[i] are not within the range of an uint64_t 6424 6425 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6426 if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) { 6427 typedef double GccF64RawVectType __attribute__((__vector_size__(64))); 6428 const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw); 6429 return VFromD<DU>{_mm512_setr_epi64( 6430 static_cast<int64_t>( 6431 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[0])), 6432 static_cast<int64_t>( 6433 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[1])), 6434 static_cast<int64_t>( 6435 
detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[2])), 6436 static_cast<int64_t>( 6437 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[3])), 6438 static_cast<int64_t>( 6439 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[4])), 6440 static_cast<int64_t>( 6441 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[5])), 6442 static_cast<int64_t>( 6443 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[6])), 6444 static_cast<int64_t>( 6445 detail::X86ConvertScalarFromFloat<uint64_t>(raw_v[7])))}; 6446 } 6447 #endif 6448 6449 __m512i raw_result; 6450 __asm__("vcvttpd2uqq {%1, %0|%0, %1}" 6451 : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) 6452 : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) 6453 :); 6454 return VFromD<DU>{raw_result}; 6455 #else 6456 return VFromD<DU>{_mm512_cvttpd_epu64(v.raw)}; 6457 #endif 6458 } 6459 6460 template <class DI, HWY_IF_V_SIZE_D(DI, 64), HWY_IF_I32_D(DI)> 6461 static HWY_INLINE VFromD<DI> NearestIntInRange(DI, 6462 VFromD<RebindToFloat<DI>> v) { 6463 #if HWY_COMPILER_GCC_ACTUAL 6464 // Workaround for undefined behavior in _mm512_cvtps_epi32 with GCC if any 6465 // values of v[i] are not within the range of an int32_t 6466 6467 #if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 6468 if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) { 6469 typedef float GccF32RawVectType __attribute__((__vector_size__(64))); 6470 const auto raw_v = reinterpret_cast<GccF32RawVectType>(v.raw); 6471 return VFromD<DI>{ 6472 _mm512_setr_epi32(detail::X86ScalarNearestInt<int32_t>(raw_v[0]), 6473 detail::X86ScalarNearestInt<int32_t>(raw_v[1]), 6474 detail::X86ScalarNearestInt<int32_t>(raw_v[2]), 6475 detail::X86ScalarNearestInt<int32_t>(raw_v[3]), 6476 detail::X86ScalarNearestInt<int32_t>(raw_v[4]), 6477 detail::X86ScalarNearestInt<int32_t>(raw_v[5]), 6478 detail::X86ScalarNearestInt<int32_t>(raw_v[6]), 6479 detail::X86ScalarNearestInt<int32_t>(raw_v[7]), 6480 detail::X86ScalarNearestInt<int32_t>(raw_v[8]), 6481 
                             detail::X86ScalarNearestInt<int32_t>(raw_v[9]),
                             detail::X86ScalarNearestInt<int32_t>(raw_v[10]),
                             detail::X86ScalarNearestInt<int32_t>(raw_v[11]),
                             detail::X86ScalarNearestInt<int32_t>(raw_v[12]),
                             detail::X86ScalarNearestInt<int32_t>(raw_v[13]),
                             detail::X86ScalarNearestInt<int32_t>(raw_v[14]),
                             detail::X86ScalarNearestInt<int32_t>(raw_v[15]))};
  }
#endif

  // Non-constant input: do the conversion via inline asm so that GCC cannot
  // constant-fold the conversion (which would be UB for out-of-range lanes).
  __m512i raw_result;
  __asm__("vcvtps2dq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<DI>{raw_result};
#else
  return VFromD<DI>{_mm512_cvtps_epi32(v.raw)};
#endif
}

#if HWY_HAVE_FLOAT16
// F16 -> I16 round-to-nearest conversion for 512-bit vectors.
template <class DI, HWY_IF_V_SIZE_D(DI, 64), HWY_IF_I16_D(DI)>
static HWY_INLINE VFromD<DI> NearestIntInRange(DI /*d*/, Vec512<float16_t> v) {
#if HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior in _mm512_cvtph_epi16 with GCC if any
  // values of v[i] are not within the range of an int16_t

#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \
    HWY_HAVE_SCALAR_F16_TYPE
  // Compile-time-constant input: convert lane by lane with the scalar helper,
  // which avoids the intrinsic's UB and allows constant folding.
  if (detail::IsConstantX86VecForF2IConv<int16_t>(v)) {
    typedef hwy::float16_t::Native GccF16RawVectType
        __attribute__((__vector_size__(64)));
    const auto raw_v = reinterpret_cast<GccF16RawVectType>(v.raw);
    // Note: _mm512_set_epi16 takes arguments from the highest lane downward.
    return VFromD<DI>{
        _mm512_set_epi16(detail::X86ScalarNearestInt<int16_t>(raw_v[31]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[30]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[29]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[28]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[27]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[26]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[25]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[24]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[23]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[22]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[21]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[20]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[19]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[18]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[17]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[16]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[15]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[14]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[13]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[12]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[11]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[10]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[9]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[8]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[7]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[6]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[5]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[4]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[3]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[2]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[1]),
                         detail::X86ScalarNearestInt<int16_t>(raw_v[0]))};
  }
#endif

  // Non-constant input: inline asm prevents GCC from constant-folding.
  __m512i raw_result;
  __asm__("vcvtph2w {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<DI>{raw_result};
#else
  return VFromD<DI>{_mm512_cvtph_epi16(v.raw)};
#endif
}
#endif  // HWY_HAVE_FLOAT16

// F64 -> I64 round-to-nearest conversion for 512-bit vectors.
template <class DI, HWY_IF_V_SIZE_D(DI, 64), HWY_IF_I64_D(DI)>
static HWY_INLINE VFromD<DI> NearestIntInRange(DI /*di*/, Vec512<double> v) {
#if HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior in _mm512_cvtpd_epi64 with GCC if any
  // values of v[i] are not within the range of an int64_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  // Compile-time-constant input: convert lane by lane with the scalar helper.
  if (detail::IsConstantX86VecForF2IConv<int64_t>(v)) {
    typedef double GccF64RawVectType __attribute__((__vector_size__(64)));
    const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw);
    return VFromD<DI>{
        _mm512_setr_epi64(detail::X86ScalarNearestInt<int64_t>(raw_v[0]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[1]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[2]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[3]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[4]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[5]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[6]),
                          detail::X86ScalarNearestInt<int64_t>(raw_v[7]))};
  }
#endif

  // Non-constant input: inline asm prevents GCC from constant-folding.
  __m512i raw_result;
  __asm__("vcvtpd2qq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<DI>{raw_result};
#else
  return VFromD<DI>{_mm512_cvtpd_epi64(v.raw)};
#endif
}

// F64 (512-bit) -> I32 (256-bit) demoting round-to-nearest conversion.
template <class DI, HWY_IF_V_SIZE_D(DI, 32), HWY_IF_I32_D(DI)>
static HWY_INLINE VFromD<DI> DemoteToNearestIntInRange(DI /* tag */,
                                                       Vec512<double> v) {
#if HWY_COMPILER_GCC_ACTUAL
  // Workaround for undefined behavior in _mm512_cvtpd_epi32 with GCC if any
  // values of v[i] are not within the range of an int32_t

#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
  // Compile-time-constant input: convert lane by lane with the scalar helper.
  if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) {
    typedef double GccF64RawVectType __attribute__((__vector_size__(64)));
    const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw);
    return VFromD<DI>{
        _mm256_setr_epi32(detail::X86ScalarNearestInt<int32_t>(raw_v[0]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[1]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[2]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[3]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[4]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[5]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[6]),
                          detail::X86ScalarNearestInt<int32_t>(raw_v[7]))};
  }
#endif

  // Non-constant input: inline asm prevents GCC from constant-folding.
  // Note the 256-bit result register: vcvtpd2dq narrows 8 x f64 to 8 x i32.
  __m256i raw_result;
  __asm__("vcvtpd2dq {%1, %0|%0, %1}"
          : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result)
          : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw)
          :);
  return VFromD<DI>{raw_result};
#else
  return VFromD<DI>{_mm512_cvtpd_epi32(v.raw)};
#endif
}

// ================================================== CRYPTO

#if !defined(HWY_DISABLE_PCLMUL_AES)

// One AES encryption round on each of the four 128-bit blocks. Without
// AVX3_DL (VAES), falls back to two 256-bit half-vector AESRound calls.
HWY_API Vec512<uint8_t> AESRound(Vec512<uint8_t> state,
                                 Vec512<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<uint8_t>{_mm512_aesenc_epi128(state.raw, round_key.raw)};
#else
  const DFromV<decltype(state)> d;
  const Half<decltype(d)> d2;
  return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESRound(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// Final AES encryption round (no MixColumns) per 128-bit block.
HWY_API Vec512<uint8_t> AESLastRound(Vec512<uint8_t> state,
                                     Vec512<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<uint8_t>{_mm512_aesenclast_epi128(state.raw, round_key.raw)};
#else
  const DFromV<decltype(state)> d;
  const Half<decltype(d)> d2;
  return Combine(d,
                 AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESLastRound(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// One AES decryption round per 128-bit block.
HWY_API Vec512<uint8_t> AESRoundInv(Vec512<uint8_t> state,
                                    Vec512<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<uint8_t>{_mm512_aesdec_epi128(state.raw, round_key.raw)};
#else
  const Full512<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESRoundInv(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// Final AES decryption round per 128-bit block.
HWY_API Vec512<uint8_t> AESLastRoundInv(Vec512<uint8_t> state,
                                        Vec512<uint8_t> round_key) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<uint8_t>{_mm512_aesdeclast_epi128(state.raw, round_key.raw)};
#else
  const Full512<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(
      d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)),
      AESLastRoundInv(LowerHalf(state), LowerHalf(round_key)));
#endif
}

// Equivalent of AESKEYGENASSIST per 128-bit block, with round constant kRcon.
// On AVX3_DL, emulated via AESLastRound (which applies SubBytes) followed by
// a byte shuffle that performs the RotWord/positioning.
template <uint8_t kRcon>
HWY_API Vec512<uint8_t> AESKeyGenAssist(Vec512<uint8_t> v) {
  const Full512<uint8_t> d;
#if HWY_TARGET <= HWY_AVX3_DL
  // XORs kRcon into bytes 1 and 9 of each block (after the shuffle below,
  // these land where AESKEYGENASSIST places rcon).
  const VFromD<decltype(d)> rconXorMask = Dup128VecFromValues(
      d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0);
  const VFromD<decltype(d)> rotWordShuffle = Dup128VecFromValues(
      d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12);
  const Repartition<uint32_t, decltype(d)> du32;
  // Duplicate odd u32 lanes (words 1 and 3) - the words SubWord acts on.
  const auto w13 = BitCast(d, DupOdd(BitCast(du32, v)));
  const auto sub_word_result = AESLastRound(w13, rconXorMask);
  return TableLookupBytes(sub_word_result, rotWordShuffle);
#else
  const Half<decltype(d)> d2;
  return Combine(d, AESKeyGenAssist<kRcon>(UpperHalf(d2, v)),
                 AESKeyGenAssist<kRcon>(LowerHalf(v)));
#endif
}

// Carry-less multiply of the lower u64 of each 128-bit block.
// Fallback: scalar loop over 128-bit sub-vectors through memory.
HWY_API Vec512<uint64_t> CLMulLower(Vec512<uint64_t> va, Vec512<uint64_t> vb) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)};
#else
  alignas(64) uint64_t a[8];
  alignas(64) uint64_t b[8];
  const DFromV<decltype(va)> d;
  const Half<Half<decltype(d)>> d128;
  Store(va, d, a);
  Store(vb, d, b);
  for (size_t i = 0; i < 8; i += 2) {
    const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i));
    Store(mul, d128, a + i);
  }
  return Load(d, a);
#endif
}

// Carry-less multiply of the upper u64 of each 128-bit block.
HWY_API Vec512<uint64_t> CLMulUpper(Vec512<uint64_t> va, Vec512<uint64_t> vb) {
#if HWY_TARGET <= HWY_AVX3_DL
  return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)};
#else
  alignas(64) uint64_t a[8];
  alignas(64) uint64_t b[8];
  const DFromV<decltype(va)> d;
  const Half<Half<decltype(d)>> d128;
  Store(va, d, a);
  Store(vb, d, b);
  for (size_t i = 0; i < 8; i += 2) {
    const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i));
    Store(mul, d128, a + i);
  }
  return Load(d, a);
#endif
}

#endif  // HWY_DISABLE_PCLMUL_AES

// ================================================== MISC

// ------------------------------ SumsOfAdjQuadAbsDiff (Broadcast,
// SumsOfAdjShufQuadAbsDiff)

template <int kAOffset, int kBOffset>
HWY_API Vec512<uint16_t> SumsOfAdjQuadAbsDiff(Vec512<uint8_t> a,
                                              Vec512<uint8_t> b) {
  static_assert(0 <= kAOffset && kAOffset <= 1,
                "kAOffset must be between 0 and 1");
  static_assert(0 <= kBOffset && kBOffset <= 3,
                "kBOffset must be between 0 and 3");

#if HWY_X86_HAVE_AVX10_2_OPS
  // AVX10.2 now has the _mm512_mpsadbw_epu8 intrinsic available.
  // The immediate packs the same (kAOffset, kBOffset) pair for both the lower
  // and upper halves of each 128-bit block.
  return Vec512<uint16_t>{_mm512_mpsadbw_epu8(
      a.raw, b.raw,
      (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)};
#else
  const DFromV<decltype(a)> d;
  const RepartitionToWideX2<decltype(d)> du32;

  // The _mm512_mpsadbw_epu8 intrinsic is not available prior to AVX10.2.
  // The SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on
  // pre-AVX10.2 targets that support AVX3 using SumsOfShuffledQuadAbsDiff and
  // U32 Broadcast.
  return SumsOfShuffledQuadAbsDiff<kAOffset + 2, kAOffset + 1, kAOffset + 1,
                                   kAOffset>(
      a, BitCast(d, Broadcast<kBOffset>(BitCast(du32, b))));
#endif
}

#if !HWY_IS_MSAN
// ------------------------------ I32/I64 SaturatedAdd (MaskFromVec)

HWY_API Vec512<int32_t> SaturatedAdd(Vec512<int32_t> a, Vec512<int32_t> b) {
  const DFromV<decltype(a)> d;
  const auto sum = a + b;
  // Ternary-logic truth table 0x42 is true exactly when a and b have the same
  // sign but sum's sign differs, i.e. signed overflow occurred.
  const auto overflow_mask = MaskFromVec(
      Vec512<int32_t>{_mm512_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)});
  const auto i32_max = Set(d, LimitsMax<int32_t>());
  // Truth table 0x55 is bitwise NOT: lanes where a < 0 (MSB set) become
  // ~LimitsMax == LimitsMin; all other lanes stay LimitsMax.
  const Vec512<int32_t> overflow_result{_mm512_mask_ternarylogic_epi32(
      i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, sum);
}

HWY_API Vec512<int64_t> SaturatedAdd(Vec512<int64_t> a, Vec512<int64_t> b) {
  const DFromV<decltype(a)> d;
  const auto sum = a + b;
  // Same overflow detection/saturation scheme as the i32 overload above.
  const auto overflow_mask = MaskFromVec(
      Vec512<int64_t>{_mm512_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)});
  const auto i64_max = Set(d, LimitsMax<int64_t>());
  const Vec512<int64_t> overflow_result{_mm512_mask_ternarylogic_epi64(
      i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, sum);
}

// ------------------------------ I32/I64 SaturatedSub (MaskFromVec)

HWY_API Vec512<int32_t> SaturatedSub(Vec512<int32_t> a, Vec512<int32_t> b) {
  const DFromV<decltype(a)> d;
  const auto diff = a - b;
  // Truth table 0x18 is true exactly when a and b have different signs and
  // diff's sign differs from a's, i.e. signed subtraction overflow.
  const auto overflow_mask = MaskFromVec(
      Vec512<int32_t>{_mm512_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)});
  const auto i32_max = Set(d, LimitsMax<int32_t>());
  // As in SaturatedAdd: saturate to LimitsMin when a < 0, else LimitsMax.
  const Vec512<int32_t> overflow_result{_mm512_mask_ternarylogic_epi32(
      i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, diff);
}
HWY_API Vec512<int64_t> SaturatedSub(Vec512<int64_t> a, Vec512<int64_t> b) {
  const DFromV<decltype(a)> d;
  const auto diff = a - b;
  // Truth table 0x18 flags lanes with signed subtraction overflow (signs of a
  // and b differ and diff's sign differs from a's).
  const auto overflow_mask = MaskFromVec(
      Vec512<int64_t>{_mm512_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)});
  const auto i64_max = Set(d, LimitsMax<int64_t>());
  // Truth table 0x55 is bitwise NOT: overflowing lanes with a < 0 saturate to
  // ~LimitsMax == LimitsMin, the rest to LimitsMax.
  const Vec512<int64_t> overflow_result{_mm512_mask_ternarylogic_epi64(
      i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
  return IfThenElse(overflow_mask, overflow_result, diff);
}
#endif  // !HWY_IS_MSAN

// ------------------------------ Mask testing

// Beware: the suffix indicates the number of mask bits, not lane size!

namespace detail {

// AllFalse dispatched on lane size. The kortest intrinsics test the mask
// register directly, avoiding a transfer to a general-purpose register.
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask64_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask32_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask16_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}
template <typename T>
HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestz_mask8_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0;
#endif
}

}  // namespace detail

// Returns true if no lane of `mask` is set.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API bool AllFalse(D /* tag */, const MFromD<D> mask) {
  return detail::AllFalse(hwy::SizeTag<sizeof(TFromD<D>)>(), mask);
}

namespace detail {

// AllTrue dispatched on lane size; the non-intrinsic fallback compares against
// the all-ones constant for the respective mask width.
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask64_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFFFFFFFFFFFFFFFull;
#endif
}
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask32_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFFFFFFFull;
#endif
}
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask16_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFFFull;
#endif
}
template <typename T>
HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512<T> mask) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return _kortestc_mask8_u8(mask.raw, mask.raw);
#else
  return mask.raw == 0xFFull;
#endif
}

}  // namespace detail

// Returns true if every lane of `mask` is set.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API bool AllTrue(D /* tag */, const MFromD<D> mask) {
  return detail::AllTrue(hwy::SizeTag<sizeof(TFromD<D>)>(), mask);
}

// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API MFromD<D> LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) {
  MFromD<D> mask;
  CopyBytes<8 / sizeof(TFromD<D>)>(bits, &mask.raw);
  // N >= 8 (= 512 / 64), so no need to mask invalid bits.
  return mask;
}

// `p` points to at least 8 writable bytes.
// Stores the mask bits to `bits` and returns the number of bytes written.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API size_t StoreMaskBits(D /* tag */, MFromD<D> mask, uint8_t* bits) {
  const size_t kNumBytes = 8 / sizeof(TFromD<D>);
  CopyBytes<kNumBytes>(&mask.raw, bits);
  // N >= 8 (= 512 / 64), so no need to mask invalid bits.
  return kNumBytes;
}

// Number of set lanes: the mask register is already a bitmask, so popcount.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API size_t CountTrue(D /* tag */, const MFromD<D> mask) {
  return PopCount(static_cast<uint64_t>(mask.raw));
}

// Lane sizes >= 2 bytes have at most 32 mask bits, so the 32-bit scan
// suffices; 1-byte lanes need the 64-bit variant below.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD<D> mask) {
  return Num0BitsBelowLS1Bit_Nonzero32(mask.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD<D> mask) {
  return Num0BitsBelowLS1Bit_Nonzero64(mask.raw);
}

// As FindKnownFirstTrue, but returns -1 when the mask is all-false.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) {
  return mask.raw ? static_cast<intptr_t>(FindKnownFirstTrue(d, mask))
                  : intptr_t{-1};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD<D> mask) {
  return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw);
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD<D> mask) {
  return 63 - Num0BitsAboveMS1Bit_Nonzero64(mask.raw);
}

// As FindKnownLastTrue, but returns -1 when the mask is all-false.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
  return mask.raw ? static_cast<intptr_t>(FindKnownLastTrue(d, mask))
                  : intptr_t{-1};
}

// ------------------------------ Compress

// Moves the lanes selected by `mask` to the front; remaining lanes follow.
// Implemented as a table lookup: the 8-bit mask indexes a packed set of eight
// 4-bit lane indices, which drive a full-vector permute.
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec512<T> Compress(Vec512<T> v, Mask512<T> mask) {
  // See CompressIsPartition. u64 is faster than u32.
  alignas(16) static constexpr uint64_t packed_array[256] = {
      // From PrintCompress32x8Tables, without the FirstN extension (there is
      // no benefit to including them because 64-bit CompressStore is anyway
      // masked, but also no harm because TableLookupLanes ignores the MSB).
      0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120,
      0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310,
      0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140,
      0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210,
      0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320,
      0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510,
      0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530,
      0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210,
      0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420,
      0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310,
      0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160,
      0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210,
      0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320,
      0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410,
      0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430,
      0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210,
      0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520,
      0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310,
      0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540,
      0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210,
      0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320,
      0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710,
      0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730,
      0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210,
      0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420,
      0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310,
      0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750,
      0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210,
      0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320,
      0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410,
      0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430,
      0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210,
      0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620,
      0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310,
      0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640,
      0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210,
      0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320,
      0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510,
      0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530,
      0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210,
      0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420,
      0x30765421, 0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310,
      0x10765432, 0x17654320, 0x07654321, 0x76543210};

  // For lane i, shift the i-th 4-bit index down to bits [0, 3) -
  // _mm512_permutexvar_epi64 will ignore the upper bits.
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du64;
  const auto packed = Set(du64, packed_array[mask.raw]);
  alignas(64) static constexpr uint64_t shifts[8] = {0,  4,  8,  12,
                                                     16, 20, 24, 28};
  const auto indices = Indices512<T>{(packed >> Load(du64, shifts)).raw};
  return TableLookupLanes(v, indices);
}

// ------------------------------ Expand

namespace detail {

#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
// Native expand/expand-load wrappers; only available with VBMI2 for 8/16-bit.
HWY_INLINE Vec512<uint8_t> NativeExpand(Vec512<uint8_t> v,
                                        Mask512<uint8_t> mask) {
  return Vec512<uint8_t>{_mm512_maskz_expand_epi8(mask.raw, v.raw)};
}

HWY_INLINE Vec512<uint16_t> NativeExpand(Vec512<uint16_t> v,
                                         Mask512<uint16_t> mask) {
  return Vec512<uint16_t>{_mm512_maskz_expand_epi16(mask.raw, v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U8_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint8_t> mask, D /* d */,
                                      const uint8_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm512_maskz_expandloadu_epi8(mask.raw, unaligned)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U16_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint16_t> mask, D /* d */,
                                      const uint16_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm512_maskz_expandloadu_epi16(mask.raw, unaligned)};
}
#endif  // HWY_TARGET <= HWY_AVX3_DL

// 32/64-bit expand is part of base AVX-512F.
HWY_INLINE Vec512<uint32_t> NativeExpand(Vec512<uint32_t> v,
                                         Mask512<uint32_t> mask) {
  return Vec512<uint32_t>{_mm512_maskz_expand_epi32(mask.raw, v.raw)};
}

HWY_INLINE Vec512<uint64_t> NativeExpand(Vec512<uint64_t> v,
                                         Mask512<uint64_t> mask) {
  return Vec512<uint64_t>{_mm512_maskz_expand_epi64(mask.raw, v.raw)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U32_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint32_t> mask, D /* d */,
                                      const uint32_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm512_maskz_expandloadu_epi32(mask.raw, unaligned)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)>
HWY_INLINE VFromD<D> NativeLoadExpand(Mask512<uint64_t> mask, D /* d */,
                                      const uint64_t* HWY_RESTRICT unaligned) {
  return VFromD<D>{_mm512_maskz_expandloadu_epi64(mask.raw, unaligned)};
}

}  // namespace detail

// Expand for 1-byte lanes: scatter the first CountTrue(mask) lanes of v into
// the lanes selected by `mask`; other lanes are zero.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> Expand(Vec512<T> v, const Mask512<T> mask) {
  const Full512<T> d;
#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
  const RebindToUnsigned<decltype(d)> du;
  const auto mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeExpand(BitCast(du, v), mu));
#else
  // LUTs are infeasible for 2^64 possible masks, so splice together two
  // half-vector Expand.
  const Full256<T> dh;
  constexpr size_t N = MaxLanes(d);
  // We have to shift the input by a variable number of u8. Shuffling requires
  // VBMI2, in which case we would already have NativeExpand. We instead
  // load at an offset, which may incur a store to load forwarding stall.
  alignas(64) T lanes[N];
  Store(v, d, lanes);
  using Bits = typename Mask256<T>::Raw;
  const Mask256<T> maskL{
      static_cast<Bits>(mask.raw & Bits{(1ULL << (N / 2)) - 1})};
  const Mask256<T> maskH{static_cast<Bits>(mask.raw >> (N / 2))};
  const size_t countL = CountTrue(dh, maskL);
  const Vec256<T> expandL = Expand(LowerHalf(v), maskL);
  // The upper half starts consuming input lanes after the countL lanes used
  // by the lower half, hence the reload at offset countL.
  const Vec256<T> expandH = Expand(LoadU(dh, lanes + countL), maskH);
  return Combine(d, expandH, expandL);
#endif
}

// Expand for 2-byte lanes; same splicing strategy as the 1-byte overload.
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_API Vec512<T> Expand(Vec512<T> v, const Mask512<T> mask) {
  const Full512<T> d;
  const RebindToUnsigned<decltype(d)> du;
  const Vec512<uint16_t> vu = BitCast(du, v);
#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
  return BitCast(d, detail::NativeExpand(vu, RebindMask(du, mask)));
#else   // AVX3
  // LUTs are infeasible for 2^32 possible masks, so splice together two
  // half-vector Expand.
  const Full256<T> dh;
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  using Bits = typename Mask256<T>::Raw;
  const Mask256<T> maskL{
      static_cast<Bits>(mask.raw & static_cast<Bits>((1ULL << (N / 2)) - 1))};
  const Mask256<T> maskH{static_cast<Bits>(mask.raw >> (N / 2))};
  // In AVX3 we can permutevar, which avoids a potential store to load
  // forwarding stall vs. reloading the input.
  // iota has 64 entries so that indices up to 63 are in bounds; entries past
  // 31 are zero-initialized (those lanes are masked out by Expand anyway).
  alignas(64) uint16_t iota[64] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
                                   11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
                                   22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
  const Vec512<uint16_t> indices = LoadU(du, iota + CountTrue(dh, maskL));
  const Vec512<uint16_t> shifted{_mm512_permutexvar_epi16(indices.raw, vu.raw)};
  const Vec256<T> expandL = Expand(LowerHalf(v), maskL);
  const Vec256<T> expandH = Expand(LowerHalf(BitCast(d, shifted)), maskH);
  return Combine(d, expandH, expandL);
#endif  // AVX3
}

// Expand for 4/8-byte lanes: always native on AVX-512F.
template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 4) | (1 << 8))>
HWY_API V Expand(V v, const M mask) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const auto mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeExpand(BitCast(du, v), mu));
}

// For smaller vectors, it is likely more efficient to promote to 32-bit.
// This works for u8x16, u16x8, u16x16 (can be promoted to u32x16), but is
// unnecessary if HWY_AVX3_DL, which provides native instructions.
#if HWY_TARGET > HWY_AVX3_DL  // no VBMI2

// Small-vector 8/16-bit Expand: promote to u32, use native 32-bit expand,
// then truncate back.
template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)),
          HWY_IF_LANES_LE_D(DFromV<V>, 16)>
HWY_API V Expand(V v, M mask) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;
  const Rebind<uint32_t, decltype(d)> du32;
  const VFromD<decltype(du)> vu = BitCast(du, v);
  using M32 = MFromD<decltype(du32)>;
  const M32 m32{static_cast<typename M32::Raw>(mask.raw)};
  return BitCast(d, TruncateTo(du, Expand(PromoteTo(du32, vu), m32)));
}

#endif  // HWY_TARGET > HWY_AVX3_DL

// ------------------------------ LoadExpand

// 8/16-bit: expand directly from memory when VBMI2 is available, else load
// then Expand.
template <class D, HWY_IF_V_SIZE_D(D, 64),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
#if HWY_TARGET <= HWY_AVX3_DL  // VBMI2
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  const TU* HWY_RESTRICT pu = reinterpret_cast<const TU*>(unaligned);
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeLoadExpand(mu, du, pu));
#else
  return Expand(LoadU(d, unaligned), mask);
#endif
}

// 32/64-bit: expand-load is always native on AVX-512F.
template <class D, HWY_IF_V_SIZE_D(D, 64),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  const TU* HWY_RESTRICT pu = reinterpret_cast<const TU*>(unaligned);
  const MFromD<decltype(du)> mu = RebindMask(du, mask);
  return BitCast(d, detail::NativeLoadExpand(mu, du, pu));
}

// ------------------------------ CompressNot

// As Compress, but keeps the lanes NOT selected by `mask`. Same table-driven
// permute scheme as Compress, with a complementary index table.
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec512<T> CompressNot(Vec512<T> v, Mask512<T> mask) {
  // See CompressIsPartition. u64 is faster than u32.
  alignas(16) static constexpr uint64_t packed_array[256] = {
      // From PrintCompressNot32x8Tables, without the FirstN extension (there is
      // no benefit to including them because 64-bit CompressStore is anyway
      // masked, but also no harm because TableLookupLanes ignores the MSB).
      0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431,
      0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542,
      0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321,
      0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653,
      0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651,
      0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432,
      0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421,
      0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764,
      0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631,
      0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762,
      0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321,
      0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543,
      0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541,
      0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532,
      0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521,
      0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075,
      0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431,
      0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742,
      0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321,
      0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073,
      0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071,
      0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432,
      0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421,
      0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654,
      0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531,
      0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652,
      0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321,
      0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643,
      0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641,
      0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632,
      0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621,
      0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106,
      0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431,
      0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542,
      0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321,
      0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053,
      0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051,
      0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432,
      0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421,
      0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104,
      0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031,
      0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102,
      0x76543210, 0x76543201, 0x76543210, 0x76543210};

  // For lane i, shift the i-th 4-bit index down to bits [0, 3) -
  // _mm512_permutexvar_epi64 will ignore the upper bits.
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du64;
  const auto packed = Set(du64, packed_array[mask.raw]);
  alignas(64) static constexpr uint64_t shifts[8] = {0,  4,  8,  12,
                                                     16, 20, 24, 28};
  const auto indices = Indices512<T>{(packed >> Load(du64, shifts)).raw};
  return TableLookupLanes(v, indices);
}

// ------------------------------ LoadInterleaved4

// Actually implemented in generic_ops, we just overload LoadTransposedBlocks4.
namespace detail {

// Type-safe wrapper. Selects two 128-bit blocks from each input per kPerm.
template <_MM_PERM_ENUM kPerm, typename T>
Vec512<T> Shuffle128(const Vec512<T> lo, const Vec512<T> hi) {
  const DFromV<decltype(lo)> d;
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_shuffle_i64x2(
                        BitCast(du, lo).raw, BitCast(du, hi).raw, kPerm)});
}
template <_MM_PERM_ENUM kPerm>
Vec512<float> Shuffle128(const Vec512<float> lo, const Vec512<float> hi) {
  return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)};
}
template <_MM_PERM_ENUM kPerm>
Vec512<double> Shuffle128(const Vec512<double> lo, const Vec512<double> hi) {
  return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)};
}

// Input (128-bit blocks):
// 3 2 1 0 (<- first block in unaligned)
// 7 6 5 4
// b a 9 8
// Output:
// 9 6 3 0 (LSB of A)
// a 7 4 1
// b 8 5 2
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API void LoadTransposedBlocks3(D d, const TFromD<D>* HWY_RESTRICT unaligned,
                                   VFromD<D>& A, VFromD<D>& B, VFromD<D>& C) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  const VFromD<D> v3210 = LoadU(d, unaligned + 0 * N);
  const VFromD<D> v7654 = LoadU(d, unaligned + 1 * N);
  const VFromD<D> vba98 = LoadU(d, unaligned + 2 * N);

  // Intermediate shuffles gather the blocks needed by each output.
  const VFromD<D> v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654);
  const VFromD<D> va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98);

  A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976);
  B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976);
  C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98);
}

// Input (128-bit blocks):
// 3 2 1 0 (<- first block in unaligned)
// 7 6 5 4
// b a 9 8
// f e d c
// Output:
// c 8 4 0 (LSB of A)
// d 9 5 1
// e a 6 2
// f b 7 3
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API void LoadTransposedBlocks4(D d, const TFromD<D>* HWY_RESTRICT unaligned,
                                   VFromD<D>& vA, VFromD<D>& vB, VFromD<D>& vC,
                                   VFromD<D>& vD) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  const VFromD<D> v3210 = LoadU(d, unaligned + 0 * N);
  const VFromD<D> v7654 = LoadU(d, unaligned + 1 * N);
  const VFromD<D> vba98 = LoadU(d, unaligned + 2 * N);
  const VFromD<D> vfedc = LoadU(d, unaligned + 3 * N);

  // First level: interleave low/high block pairs of adjacent inputs.
  const VFromD<D> v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654);
  const VFromD<D> vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc);
  const VFromD<D> v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654);
  const VFromD<D> vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc);
  // Second level: pick alternating blocks to complete the 4x4 transpose.
  vA = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98);
  vB = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98);
  vC = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba);
  vD = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba);
}

}  // namespace detail

// ------------------------------ StoreInterleaved2

// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4.

namespace detail {

// Interleaves the 128-bit blocks of i and j so memory receives
// i0 j0 i1 j1 i2 j2 i3 j3:
// Input (128-bit blocks):
// 6 4 2 0 (LSB of i)
// 7 5 3 1
// Output:
// 3 2 1 0
// 7 6 5 4
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API void StoreTransposedBlocks2(const VFromD<D> i, const VFromD<D> j, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  // Gather the lower and upper halves of both inputs...
  const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j);
  const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j);
  // ...then interleave the blocks within each half.
  const auto j1_i1_j0_i0 =
      detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0);
  const auto j3_i3_j2_i2 =
      detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2);
  StoreU(j1_i1_j0_i0, d, unaligned + 0 * N);
  StoreU(j3_i3_j2_i2, d, unaligned + 1 * N);
}

// 3-way interleave: memory receives i0 j0 k0 i1 j1 k1 i2 j2 k2 i3 j3 k3.
// Input (128-bit blocks):
// 9 6 3 0 (LSB of i)
// a 7 4 1
// b 8 5 2
// Output:
// 3 2 1 0
// 7 6 5 4
// b a 9 8
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API void StoreTransposedBlocks3(const VFromD<D> i, const VFromD<D> j,
                                    const VFromD<D> k, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  // Variable names list their 128-bit blocks, MSB first.
  const VFromD<D> j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j);
  const VFromD<D> i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i);
  const VFromD<D> j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j);

  const VFromD<D> out0 =  // i1 k0 j0 i0
      detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0);
  const VFromD<D> out1 =  // j2 i2 k1 j1
      detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0);
  const VFromD<D> out2 =  // k3 j3 i3 k2
      detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1);

  StoreU(out0, d, unaligned + 0 * N);
  StoreU(out1, d, unaligned + 1 * N);
  StoreU(out2, d, unaligned + 2 * N);
}

// 4-way interleave: memory receives i0 j0 k0 l0 i1 j1 k1 l1 ...
// Input (128-bit blocks):
// c 8 4 0 (LSB of i)
// d 9 5 1
// e a 6 2
// f b 7 3
// Output:
// 3 2 1 0
// 7 6 5 4
// b a 9 8
// f e d c
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API void StoreTransposedBlocks4(const VFromD<D> i, const VFromD<D> j,
                                    const VFromD<D> k, const VFromD<D> l, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  HWY_LANES_CONSTEXPR size_t N = Lanes(d);
  // First level: pair up (i, j) and (k, l) halves.
  const VFromD<D> j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j);
  const VFromD<D> l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l);
  const VFromD<D> j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j);
  const VFromD<D> l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l);
  // Second level: pick alternating blocks to complete the transpose.
  const VFromD<D> out0 =
      detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0);
  const VFromD<D> out1 =
      detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0);
  const VFromD<D> out2 =
      detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2);
  const VFromD<D> out3 =
      detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2);
  StoreU(out0, d, unaligned + 0 * N);
  StoreU(out1, d, unaligned + 1 * N);
  StoreU(out2, d, unaligned + 2 * N);
  StoreU(out3, d, unaligned + 3 * N);
}

}  // namespace detail

// ------------------------------ Additional mask logical operations

// Sets the lowest set bit of `mask` and all bits above it. Uses the BMI
// identity -(m & -m): negating the isolated lowest bit.
template <class T>
HWY_API Mask512<T> SetAtOrAfterFirst(Mask512<T> mask) {
  return Mask512<T>{
      static_cast<typename Mask512<T>::Raw>(0u - detail::AVX3Blsi(mask.raw))};
}
// Sets exactly the bits below the lowest set bit: blsi(m) - 1.
// (If mask is all-false, the unsigned wraparound yields all-true.)
template <class T>
HWY_API Mask512<T> SetBeforeFirst(Mask512<T> mask) {
  return Mask512<T>{
      static_cast<typename Mask512<T>::Raw>(detail::AVX3Blsi(mask.raw) - 1u)};
}
// Sets the bits up to and including the lowest set bit (BMI BLSMSK).
template <class T>
HWY_API Mask512<T> SetAtOrBeforeFirst(Mask512<T> mask) {
  return Mask512<T>{
      static_cast<typename Mask512<T>::Raw>(detail::AVX3Blsmsk(mask.raw))};
}
// Keeps only the lowest set bit (BMI BLSI, m & -m).
template <class T>
HWY_API Mask512<T> SetOnlyFirst(Mask512<T> mask) {
  return Mask512<T>{
      static_cast<typename Mask512<T>::Raw>(detail::AVX3Blsi(mask.raw))};
}

// ------------------------------ Shl (Dup128VecFromValues)

HWY_API Vec512<uint16_t> operator<<(Vec512<uint16_t> v, Vec512<uint16_t> bits) {
  return Vec512<uint16_t>{_mm512_sllv_epi16(v.raw, bits.raw)};
}

// 8-bit: may use the << overload for uint16_t.
HWY_API Vec512<uint8_t> operator<<(Vec512<uint8_t> v, Vec512<uint8_t> bits) {
  const DFromV<decltype(v)> d;
#if HWY_TARGET <= HWY_AVX3_DL
  // GF(2^8) multiplication is a carryless multiply (with reduction only for
  // bits >= 8). Masking v so that v << bits[i] cannot cross bit 7 means no
  // reduction occurs, hence the product equals the desired left shift. Both
  // tables return 0 for counts in [8, 16), so out-of-range shifts yield 0.
  // kMask[i] = 0xFF >> i
  const VFromD<decltype(d)> masks =
      Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0,
                          0, 0, 0, 0, 0, 0, 0);
  // kShl[i] = 1 << i
  const VFromD<decltype(d)> shl =
      Dup128VecFromValues(d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0,
                          0, 0, 0, 0, 0, 0, 0);
  v = And(v, TableLookupBytes(masks, bits));
  const VFromD<decltype(d)> mul = TableLookupBytes(shl, bits);
  return VFromD<decltype(d)>{_mm512_gf2p8mul_epi8(v.raw, mul.raw)};
#else
  // Emulate via 16-bit shifts, handling even and odd bytes separately.
  const Repartition<uint16_t, decltype(d)> dw;
  using VW = VFromD<decltype(dw)>;
  const VW even_mask = Set(dw, 0x00FF);
  const VW odd_mask = Set(dw, 0xFF00);
  const VW vw = BitCast(dw, v);
  const VW bits16 = BitCast(dw, bits);
  // Shift even lanes in-place
  const VW evens = vw << And(bits16, even_mask);
  // Odd bytes: mask off the even byte first so its bits cannot shift in.
  const VW odds = And(vw, odd_mask) << ShiftRight<8>(bits16);
  return OddEven(BitCast(d, odds), BitCast(d, evens));
#endif
}

HWY_API Vec512<uint32_t> operator<<(const Vec512<uint32_t> v,
                                    const Vec512<uint32_t> bits) {
  return Vec512<uint32_t>{_mm512_sllv_epi32(v.raw, bits.raw)};
}

HWY_API Vec512<uint64_t> operator<<(const Vec512<uint64_t> v,
                                    const Vec512<uint64_t> bits) {
  return Vec512<uint64_t>{_mm512_sllv_epi64(v.raw, bits.raw)};
}

// Signed left shift is the same as unsigned.
7495 template <typename T, HWY_IF_SIGNED(T)> 7496 HWY_API Vec512<T> operator<<(const Vec512<T> v, const Vec512<T> bits) { 7497 const DFromV<decltype(v)> di; 7498 const RebindToUnsigned<decltype(di)> du; 7499 return BitCast(di, BitCast(du, v) << BitCast(du, bits)); 7500 } 7501 7502 // ------------------------------ Shr (IfVecThenElse) 7503 7504 HWY_API Vec512<uint16_t> operator>>(const Vec512<uint16_t> v, 7505 const Vec512<uint16_t> bits) { 7506 return Vec512<uint16_t>{_mm512_srlv_epi16(v.raw, bits.raw)}; 7507 } 7508 7509 // 8-bit uses 16-bit shifts. 7510 HWY_API Vec512<uint8_t> operator>>(Vec512<uint8_t> v, Vec512<uint8_t> bits) { 7511 const DFromV<decltype(v)> d; 7512 const RepartitionToWide<decltype(d)> dw; 7513 using VW = VFromD<decltype(dw)>; 7514 const VW mask = Set(dw, 0x00FF); 7515 const VW vw = BitCast(dw, v); 7516 const VW bits16 = BitCast(dw, bits); 7517 const VW evens = And(vw, mask) >> And(bits16, mask); 7518 // Shift odd lanes in-place 7519 const VW odds = vw >> ShiftRight<8>(bits16); 7520 return OddEven(BitCast(d, odds), BitCast(d, evens)); 7521 } 7522 7523 HWY_API Vec512<uint32_t> operator>>(const Vec512<uint32_t> v, 7524 const Vec512<uint32_t> bits) { 7525 return Vec512<uint32_t>{_mm512_srlv_epi32(v.raw, bits.raw)}; 7526 } 7527 7528 HWY_API Vec512<uint64_t> operator>>(const Vec512<uint64_t> v, 7529 const Vec512<uint64_t> bits) { 7530 return Vec512<uint64_t>{_mm512_srlv_epi64(v.raw, bits.raw)}; 7531 } 7532 7533 HWY_API Vec512<int16_t> operator>>(const Vec512<int16_t> v, 7534 const Vec512<int16_t> bits) { 7535 return Vec512<int16_t>{_mm512_srav_epi16(v.raw, bits.raw)}; 7536 } 7537 7538 // 8-bit uses 16-bit shifts. 
HWY_API Vec512<int8_t> operator>>(Vec512<int8_t> v, Vec512<int8_t> bits) {
  const DFromV<decltype(v)> d;
  const RepartitionToWide<decltype(d)> dw;
  const RebindToUnsigned<decltype(dw)> dw_u;
  using VW = VFromD<decltype(dw)>;
  const VW mask = Set(dw, 0x00FF);
  const VW vw = BitCast(dw, v);
  const VW bits16 = BitCast(dw, bits);
  // ShiftLeft<8> then arithmetic ShiftRight<8> sign-extends the even (low)
  // byte to 16 bits before the per-lane arithmetic shift.
  const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask);
  // Shift odd lanes in-place. The counts are moved down with an unsigned
  // (logical) shift so the odd count byte is zero- rather than sign-extended.
  const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16)));
  return OddEven(BitCast(d, odds), BitCast(d, evens));
}

HWY_API Vec512<int32_t> operator>>(const Vec512<int32_t> v,
                                   const Vec512<int32_t> bits) {
  return Vec512<int32_t>{_mm512_srav_epi32(v.raw, bits.raw)};
}

HWY_API Vec512<int64_t> operator>>(const Vec512<int64_t> v,
                                   const Vec512<int64_t> bits) {
  return Vec512<int64_t>{_mm512_srav_epi64(v.raw, bits.raw)};
}

// ------------------------------ WidenMulPairwiseAdd

#if HWY_NATIVE_DOT_BF16
// Passing Zero as the accumulator yields just the pairwise bf16 products.
template <class DF, HWY_IF_F32_D(DF), HWY_IF_V_SIZE_D(DF, 64),
          class VBF = VFromD<Repartition<bfloat16_t, DF>>>
HWY_API VFromD<DF> WidenMulPairwiseAdd(DF df, VBF a, VBF b) {
  return VFromD<DF>{_mm512_dpbf16_ps(Zero(df).raw,
                                     reinterpret_cast<__m512bh>(a.raw),
                                     reinterpret_cast<__m512bh>(b.raw))};
}
#endif  // HWY_NATIVE_DOT_BF16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)>
HWY_API VFromD<D> WidenMulPairwiseAdd(D /*d32*/, Vec512<int16_t> a,
                                      Vec512<int16_t> b) {
  return VFromD<D>{_mm512_madd_epi16(a.raw, b.raw)};
}

// ------------------------------ SatWidenMulPairwiseAdd
// maddubs: u8 x i8 pairwise products, saturated i16 sums.
template <class DI16, HWY_IF_V_SIZE_D(DI16, 64), HWY_IF_I16_D(DI16)>
HWY_API VFromD<DI16> SatWidenMulPairwiseAdd(
    DI16 /* tag */, VFromD<Repartition<uint8_t, DI16>> a,
    VFromD<Repartition<int8_t, DI16>> b) {
  return VFromD<DI16>{_mm512_maddubs_epi16(a.raw, b.raw)};
}

// ------------------------------ SatWidenMulPairwiseAccumulate
#if HWY_TARGET <= HWY_AVX3_DL
// dpwssds: i16 pairwise products accumulated into i32 with saturation.
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 64)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
    DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a,
    VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
  return VFromD<DI32>{_mm512_dpwssds_epi32(sum.raw, a.raw, b.raw)};
}
#endif  // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ ReorderWidenMulAccumulate

#if HWY_NATIVE_DOT_BF16
// sum1 is unused: the bf16 dot instruction already produces a single sum.
template <class DF, HWY_IF_F32_D(DF), HWY_IF_V_SIZE_D(DF, 64),
          class VBF = VFromD<Repartition<bfloat16_t, DF>>>
HWY_API VFromD<DF> ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b,
                                             const VFromD<DF> sum0,
                                             VFromD<DF>& /*sum1*/) {
  return VFromD<DF>{_mm512_dpbf16_ps(sum0.raw,
                                     reinterpret_cast<__m512bh>(a.raw),
                                     reinterpret_cast<__m512bh>(b.raw))};
}
#endif  // HWY_NATIVE_DOT_BF16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec512<int16_t> a,
                                            Vec512<int16_t> b,
                                            const VFromD<D> sum0,
                                            VFromD<D>& /*sum1*/) {
  (void)d;  // only used on the non-DL fallback path below
#if HWY_TARGET <= HWY_AVX3_DL
  return VFromD<D>{_mm512_dpwssd_epi32(sum0.raw, a.raw, b.raw)};
#else
  return sum0 + WidenMulPairwiseAdd(d, a, b);
#endif
}

HWY_API Vec512<int32_t> RearrangeToOddPlusEven(const Vec512<int32_t> sum0,
                                               Vec512<int32_t> /*sum1*/) {
  return sum0;  // invariant already holds
}

HWY_API Vec512<uint32_t> RearrangeToOddPlusEven(const Vec512<uint32_t> sum0,
                                                Vec512<uint32_t> /*sum1*/) {
  return sum0;  // invariant already holds
}

// ------------------------------ SumOfMulQuadAccumulate

#if HWY_TARGET <= HWY_AVX3_DL

// dpbusd: per i32 lane, sum += dot of four u8 (a_u) times four i8 (b_i).
template <class DI32, HWY_IF_V_SIZE_D(DI32, 64)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
    DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u,
    VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
  return VFromD<DI32>{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)};
}

#if HWY_X86_HAVE_AVX10_2_OPS
// AVX10.2 adds signed x signed and unsigned x unsigned byte-dot variants.
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 64)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
                                            VFromD<Repartition<int8_t, DI32>> a,
                                            VFromD<Repartition<int8_t, DI32>> b,
                                            VFromD<DI32> sum) {
  return VFromD<DI32>{_mm512_dpbssd_epi32(sum.raw, a.raw, b.raw)};
}

template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_D(DU32, 64)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
    DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
    VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
  return VFromD<DU32>{_mm512_dpbuud_epi32(sum.raw, a.raw, b.raw)};
}
#endif  // HWY_X86_HAVE_AVX10_2_OPS

#endif  // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ Reductions

namespace detail {

// Used by generic_ops-inl. Folds the four 128-bit blocks together by
// applying f across block swaps/reversals; all blocks then hold the result.
template <class D, class Func, HWY_IF_V_SIZE_D(D, 64)>
HWY_INLINE VFromD<D> ReduceAcrossBlocks(D d, Func f, VFromD<D> v) {
  v = f(v, SwapAdjacentBlocks(v));
  return f(v, ReverseBlocks(d, v));
}

}  // namespace detail

// ------------------------------ BitShuffle
#if HWY_TARGET <= HWY_AVX3_DL
template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
          HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, 64)>
HWY_API V BitShuffle(V v, VI idx) {
  const DFromV<decltype(v)> d64;
  const RebindToUnsigned<decltype(d64)> du64;
  const Rebind<uint8_t, decltype(d64)> du8;

  // Per u64 lane, gathers the 8 bits selected by the corresponding 8 index
  // bytes; all 64 result bits are returned in a mask register.
  const __mmask64 mmask64_bit_shuf_result =
      _mm512_bitshuffle_epi64_mask(v.raw, idx.raw);

#if HWY_ARCH_X86_64
  // Move the 64 mask bits into the low 8 bytes of a vector in one step.
  const VFromD<decltype(du8)> vu8_bit_shuf_result{
      _mm_cvtsi64_si128(static_cast<int64_t>(mmask64_bit_shuf_result))};
#else
  // 32-bit targets lack a 64-bit GPR->XMM move: transfer the mask as two
  // 32-bit halves and interleave them back into 8 contiguous bytes.
  const int32_t i32_lo_bit_shuf_result =
      static_cast<int32_t>(mmask64_bit_shuf_result);
  const int32_t i32_hi_bit_shuf_result =
      static_cast<int32_t>(_kshiftri_mask64(mmask64_bit_shuf_result, 32));

  const VFromD<decltype(du8)> vu8_bit_shuf_result = ResizeBitCast(
      du8, InterleaveLower(
               Vec128<uint32_t>{_mm_cvtsi32_si128(i32_lo_bit_shuf_result)},
               Vec128<uint32_t>{_mm_cvtsi32_si128(i32_hi_bit_shuf_result)}));
#endif

  // Zero-extend each result byte back into its u64 lane.
  return BitCast(d64, PromoteTo(du64, vu8_bit_shuf_result));
}
#endif  // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ MultiRotateRight

#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_MULTIROTATERIGHT
#undef HWY_NATIVE_MULTIROTATERIGHT
#else
#define HWY_NATIVE_MULTIROTATERIGHT
#endif

// Note the operand order: the intrinsic takes the per-byte selectors (idx)
// as its first argument and the data (v) as the second.
template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
          HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)>
HWY_API V MultiRotateRight(V v, VI idx) {
  return V{_mm512_multishift_epi64_epi8(idx.raw, v.raw)};
}

#endif  // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ LeadingZeroCount

template <class V, HWY_IF_UI32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
HWY_API V LeadingZeroCount(V v) {
  return V{_mm512_lzcnt_epi32(v.raw)};
}

template <class V, HWY_IF_UI64(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
HWY_API V LeadingZeroCount(V v) {
  return V{_mm512_lzcnt_epi64(v.raw)};
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h -
// the warning seems to be issued at the call site of intrinsics, i.e. our code.
HWY_DIAGNOSTICS(pop)