// Copyright 2021 Google LLC
// Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Arm SVE[2] vectors (length not known at compile time).
// External include guard in highway.h - see comment there.

#include <arm_sve.h>

#include "hwy/ops/shared-inl.h"

// Arm C215 declares that SVE vector lengths will always be a power of two.
// We default to relying on this, which makes some operations more efficient.
// You can still opt into fixups by setting this to 0 (unsupported).
#ifndef HWY_SVE_IS_POW2
#define HWY_SVE_IS_POW2 1
#endif

#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128
#define HWY_SVE_HAVE_2 1
#else
#define HWY_SVE_HAVE_2 0
#endif

// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available:
// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip.
#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16)
#define HWY_SVE_HAVE_BF16_FEATURE 1
#else
#define HWY_SVE_HAVE_BF16_FEATURE 0
#endif

// HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type
// is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0.
#if HWY_SVE_HAVE_BF16_FEATURE ||                                       \
    (HWY_COMPILER_CLANG >= 1200 && defined(__ARM_FEATURE_SVE_BF16)) || \
    HWY_COMPILER_GCC_ACTUAL >= 1000
#define HWY_SVE_HAVE_BF16_VEC 1
#else
#define HWY_SVE_HAVE_BF16_VEC 0
#endif

// HWY_SVE_HAVE_F32_TO_BF16C is defined to 1 if the SVE svcvt_bf16_f32_x
// and svcvtnt_bf16_f32_x intrinsics are available, even if the __bf16 type
// is disabled.
#if HWY_SVE_HAVE_BF16_VEC && defined(__ARM_FEATURE_SVE_BF16)
#define HWY_SVE_HAVE_F32_TO_BF16C 1
#else
#define HWY_SVE_HAVE_F32_TO_BF16C 0
#endif

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <class V>
struct DFromV_t {};  // specialized in macros
template <class V>
using DFromV = typename DFromV_t<RemoveConst<V>>::type;

template <class V>
using TFromV = TFromD<DFromV<V>>;

// ================================================== MACROS

// Generate specializations and function definitions using X macros. Although
// harder to read and debug, this is far less bulky than writing everything
// manually.
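//
// For orientation, an illustrative expansion (not itself part of this file):
// HWY_SVE_FOREACH_U32(HWY_SVE_RETV_ARGPVV, Add, add), used below, invokes
// HWY_SVE_RETV_ARGPVV(uint, u, 32, 16, Add, add), which generates roughly
//   HWY_API svuint32_t Add(svuint32_t a, svuint32_t b) {
//     return svadd_u32_x(HWY_SVE_PTRUE(32), a, b);
//   }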

namespace detail {  // for code folding

// Args: BASE, CHAR, BITS, HALF, NAME, OP

// Unsigned:
#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
  X_MACRO(uint, u, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \
  X_MACRO(uint, u, 64, 32, NAME, OP)

// Signed:
#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP)

// Float:
#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 16, 16, NAME, OP)
#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 64, 32, NAME, OP)

#define HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) \
  X_MACRO(bfloat, bf, 16, 16, NAME, OP)

#if HWY_SVE_HAVE_BF16_FEATURE
#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP)
// We have both f16 and bf16, so nothing is emulated.

// NOTE: hwy::EnableIf<!hwy::IsSame<D, D>()>* = nullptr is used instead of
// hwy::EnableIf<false>* = nullptr to avoid compiler errors: because
// !hwy::IsSame<D, D>() depends on the template argument D, it triggers
// SFINAE rather than a hard error, even though it always evaluates to false.
#define HWY_SVE_IF_EMULATED_D(D) hwy::EnableIf<!hwy::IsSame<D, D>()>* = nullptr
#define HWY_GENERIC_IF_EMULATED_D(D) \
  hwy::EnableIf<!hwy::IsSame<D, D>()>* = nullptr
#define HWY_SVE_IF_NOT_EMULATED_D(D) hwy::EnableIf<true>* = nullptr
#else
#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP)
#define HWY_SVE_IF_EMULATED_D(D) HWY_IF_BF16_D(D)
#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D)
#define HWY_SVE_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D)
#endif  // HWY_SVE_HAVE_BF16_FEATURE

// For all element sizes:
#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)         \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks
// bf16 overloads for some intrinsics (especially less-common arithmetic).
// However, it does include f16 because SVE supports that unconditionally.
#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP)

// Commonly used type categories for a given element size:
#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP)

// Commonly used type categories:
#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

// Assemble types for use in x-macros
#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t
#define HWY_SVE_D(BASE, BITS, N, POW2) Simd<HWY_SVE_T(BASE, BITS), N, POW2>
#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t
#define HWY_SVE_TUPLE(BASE, BITS, MUL) sv##BASE##BITS##x##MUL##_t

}  // namespace detail

#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <>                                            \
  struct DFromV_t<HWY_SVE_V(BASE, BITS)> {               \
    using type = ScalableTag<HWY_SVE_T(BASE, BITS)>;     \
  };

HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _)
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
#endif
#undef HWY_SPECIALIZE

// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions, and we only use it when the predicate is all-true anyway.

// vector = f(vector), e.g. Not
#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);   \
  }
#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(v);                            \
  }
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
    return sv##OP##_##CHAR##BITS##_m(no, m, a);                             \
  }
#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(m, v);                                 \
  }
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
    return sv##OP##_##CHAR##BITS##_z(m, a);                                 \
  }

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) {  \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

// vector = f(vector, vector), e.g. Add
#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }
// Same, but with an all-true predicate.
#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {  \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
// User-specified mask. Mask=false value is undefined and must be set by caller
// because SVE instructions take it from one of the two inputs, whereas
// AVX-512, RVV and Highway allow a third argument.
#define HWY_SVE_RETV_ARGMVV(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_x(m, a, b);                             \
  }
// User-specified mask. Mask=false value is zero.
#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP)            \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_z(m, a, b);                             \
  }

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
           HWY_SVE_V(BASE, BITS) c) {                         \
    return sv##OP##_##CHAR##BITS(a, b, c);                    \
  }
#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP)            \
  HWY_API HWY_SVE_V(BASE, BITS)                                           \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
           HWY_SVE_V(BASE, BITS) c) {                                     \
    return sv##OP##_##CHAR##BITS##_x(m, a, b, c);                         \
  }
#define HWY_SVE_RETV_ARGMVVV_Z(BASE, CHAR, BITS, HALF, NAME, OP)             \
  HWY_API HWY_SVE_V(BASE, BITS)                                              \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x,  \
           HWY_SVE_V(BASE, BITS) add) {                                      \
    return sv##OP##_##CHAR##BITS##_z(m, x, mul, add);                        \
  }
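
// Illustrative expansion of the merge (_m) form above (a sketch, not
// generated code): HWY_SVE_RETV_ARGMV_M(int, s, 32, 16, MaskedAbsOr, abs)
// yields roughly
//   HWY_API svint32_t MaskedAbsOr(svint32_t no, svbool_t m, svint32_t a) {
//     return svabs_s32_m(no, m, a);  // inactive lanes take `no`
//   }
// The _z forms instead zero inactive lanes, and _x leaves them unspecified.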

// ------------------------------ Lanes

namespace detail {

// Returns actual lanes of a hardware vector without rounding to a power of
// two.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntb_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcnth_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntw_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntd_pat(SV_ALL);
}

// All-true mask from a macro

#if HWY_SVE_IS_POW2
#define HWY_SVE_ALL_PTRUE(BITS) svptrue_b##BITS()
#define HWY_SVE_PTRUE(BITS) svptrue_b##BITS()
#else
#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL)
#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2)
#endif  // HWY_SVE_IS_POW2

}  // namespace detail

#if HWY_HAVE_SCALABLE

// Returns actual number of lanes after capping by N and shifting. May return 0
// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8).
template <typename T, size_t N, int kPow2>
HWY_API size_t Lanes(Simd<T, N, kPow2> d) {
  const size_t actual = detail::AllHardwareLanes<T>();
  constexpr size_t kMaxLanes = MaxLanes(d);
  constexpr int kClampedPow2 = HWY_MIN(kPow2, 0);
  // Common case of full vectors: avoid any extra instructions.
  if (detail::IsFull(d)) return actual;
  return HWY_MIN(detail::ScaleByPower(actual, kClampedPow2), kMaxLanes);
}

#endif  // HWY_HAVE_SCALABLE
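
// Illustrative lane counts, assuming a hypothetical 256-bit vector length:
//   ScalableTag<float> d;        // full vector:   Lanes(d) == 8
//   ScalableTag<float, -1> dh;   // half vector:   Lanes(dh) == 4
//   CappedTag<float, 2> dc;      // capped to two: Lanes(dc) == 2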
396 // 397 // SVE2_256 is untested due to unavailable hardware and cannot assume 398 // equal minimum and maximum vector lengths as SVE2_128 can. 399 template <class D> 400 svbool_t MakeMask(D d) { 401 #if HWY_TARGET != HWY_SVE2_128 402 HWY_IF_CONSTEXPR(IsFull(d)) { return PTrue(d); } 403 #endif 404 return FirstN(d, Lanes(d)); 405 } 406 407 } // namespace detail 408 409 #ifdef HWY_NATIVE_MASK_FALSE 410 #undef HWY_NATIVE_MASK_FALSE 411 #else 412 #define HWY_NATIVE_MASK_FALSE 413 #endif 414 415 template <class D> 416 HWY_API svbool_t MaskFalse(const D /*d*/) { 417 return detail::PFalse(); 418 } 419 420 #ifdef HWY_NATIVE_SET_MASK 421 #undef HWY_NATIVE_SET_MASK 422 #else 423 #define HWY_NATIVE_SET_MASK 424 #endif 425 426 template <class D> 427 HWY_API svbool_t SetMask(D d, bool val) { 428 // The SVE svdup_n_b* intrinsics are equivalent to the FirstN op below if 429 // detail::IsFull(d) is true since svdup_n_b* is simply a wrapper around the 430 // SVE whilelo instruction. 431 return FirstN(d, size_t{0} - static_cast<size_t>(val)); 432 } 433 434 // ================================================== INIT 435 436 // ------------------------------ Set 437 // vector = f(d, scalar), e.g. Set 438 #define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ 439 template <size_t N, int kPow2> \ 440 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ 441 HWY_SVE_T(BASE, BITS) arg) { \ 442 return sv##OP##_##CHAR##BITS(arg); \ 443 } 444 445 HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) 446 #if HWY_SVE_HAVE_BF16_FEATURE // for if-elif chain 447 HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n) 448 #elif HWY_SVE_HAVE_BF16_VEC 449 // Required for Zero and VFromD 450 template <class D, HWY_IF_BF16_D(D)> 451 HWY_API svbfloat16_t Set(D d, bfloat16_t arg) { 452 return svreinterpret_bf16_u16( 453 Set(RebindToUnsigned<decltype(d)>(), BitCastScalar<uint16_t>(arg))); 454 } 455 #else // neither bf16 feature nor vector: emulate with u16 456 // Required for Zero and VFromD 457 template <class D, HWY_IF_BF16_D(D)> 458 HWY_API svuint16_t Set(D d, bfloat16_t arg) { 459 const RebindToUnsigned<decltype(d)> du; 460 return Set(du, BitCastScalar<uint16_t>(arg)); 461 } 462 #endif // HWY_SVE_HAVE_BF16_FEATURE 463 #undef HWY_SVE_SET 464 465 template <class D> 466 using VFromD = decltype(Set(D(), TFromD<D>())); 467 468 using VBF16 = VFromD<ScalableTag<bfloat16_t>>; 469 470 // ------------------------------ MaskedSetOr/MaskedSet 471 472 #define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ 473 HWY_API HWY_SVE_V(BASE, BITS) \ 474 NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ 475 return sv##OP##_##CHAR##BITS##_m(no, m, op); \ 476 } 477 478 HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n) 479 #undef HWY_SVE_MASKED_SET_OR 480 481 #define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ 482 template <size_t N, int kPow2> \ 483 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ 484 svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ 485 return sv##OP##_##CHAR##BITS##_z(m, op); \ 486 } 487 488 HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n) 489 #undef HWY_SVE_MASKED_SET 490 491 // ------------------------------ Zero 492 493 template <class D> 494 VFromD<D> Zero(D d) { 495 // Cast to support bfloat16_t. 

// ================================================== INIT

// ------------------------------ Set
// vector = f(d, scalar), e.g. Set
#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
                                     HWY_SVE_T(BASE, BITS) arg) {             \
    return sv##OP##_##CHAR##BITS(arg);                                        \
  }

HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n)
#if HWY_SVE_HAVE_BF16_FEATURE  // for if-elif chain
HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n)
#elif HWY_SVE_HAVE_BF16_VEC
// Required for Zero and VFromD
template <class D, HWY_IF_BF16_D(D)>
HWY_API svbfloat16_t Set(D d, bfloat16_t arg) {
  return svreinterpret_bf16_u16(
      Set(RebindToUnsigned<decltype(d)>(), BitCastScalar<uint16_t>(arg)));
}
#else  // neither bf16 feature nor vector: emulate with u16
// Required for Zero and VFromD
template <class D, HWY_IF_BF16_D(D)>
HWY_API svuint16_t Set(D d, bfloat16_t arg) {
  const RebindToUnsigned<decltype(d)> du;
  return Set(du, BitCastScalar<uint16_t>(arg));
}
#endif  // HWY_SVE_HAVE_BF16_FEATURE
#undef HWY_SVE_SET

template <class D>
using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ MaskedSetOr/MaskedSet

#define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS)                                              \
      NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
    return sv##OP##_##CHAR##BITS##_m(no, m, op);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n)
#undef HWY_SVE_MASKED_SET_OR

#define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP)                  \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
                                     svbool_t m, HWY_SVE_T(BASE, BITS) op) {  \
    return sv##OP##_##CHAR##BITS##_z(m, op);                                  \
  }

HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n)
#undef HWY_SVE_MASKED_SET

// ------------------------------ Zero

template <class D>
VFromD<D> Zero(D d) {
  // Cast to support bfloat16_t.
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, Set(du, 0));
}

// ------------------------------ BitCast

namespace detail {

// u8: no change
#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) {  \
    return v;                                                             \
  }                                                                       \
  template <size_t N, int kPow2>                                          \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \
    return v;                                                             \
  }

// All other types
#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP)                        \
  HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) {               \
    return sv##OP##_u8_##CHAR##BITS(v);                                       \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                                            \
      BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \
    return sv##OP##_##CHAR##BITS##_u8(v);                                     \
  }

// U08 is special-cased, hence do not use FOREACH.
HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _)
HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret)

#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CAST, _, reinterpret)
#else  // !(HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC)
template <class V, HWY_SVE_IF_EMULATED_D(DFromV<V>)>
HWY_INLINE svuint8_t BitCastToByte(V v) {
  const RebindToUnsigned<DFromV<V>> du;
  return BitCastToByte(BitCast(du, v));
}

template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_INLINE VFromD<D> BitCastFromByte(D d, svuint8_t v) {
  const RebindToUnsigned<decltype(d)> du;
  return BitCastFromByte(du, v);
}
#endif  // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC

#undef HWY_SVE_CAST_NOP
#undef HWY_SVE_CAST

}  // namespace detail

template <class D, class FromV>
HWY_API VFromD<D> BitCast(D d, FromV v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
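
// Illustrative usage (a sketch): any lane type round-trips through svuint8_t,
// which enables bit-level tricks such as
//   const ScalableTag<float> df;
//   const RebindToUnsigned<decltype(df)> du;  // uint32_t lanes
//   const svfloat32_t neg_zero = BitCast(df, Set(du, 0x80000000u));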

// ------------------------------ Undefined

#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                            \
  HWY_API HWY_SVE_V(BASE, BITS)                             \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) {       \
    return sv##OP##_##CHAR##BITS();                         \
  }

HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef)
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_UNDEFINED, Undefined, undef)
#endif

template <class D, HWY_SVE_IF_EMULATED_D(D)>
VFromD<D> Undefined(D d) {
  const RebindToUnsigned<D> du;
  return BitCast(d, Undefined(du));
}

// ------------------------------ Tuple

// tuples = f(d, v..), e.g. Create2
#define HWY_SVE_CREATE(BASE, CHAR, BITS, HALF, NAME, OP)                 \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 2)                                   \
      NAME##2(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,                   \
              HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1) {      \
    return sv##OP##2_##CHAR##BITS(v0, v1);                               \
  }                                                                      \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) NAME##3(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v0, \
      HWY_SVE_V(BASE, BITS) v1, HWY_SVE_V(BASE, BITS) v2) {              \
    return sv##OP##3_##CHAR##BITS(v0, v1, v2);                           \
  }                                                                      \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 4)                                   \
      NAME##4(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,                   \
              HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,        \
              HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3) {      \
    return sv##OP##4_##CHAR##BITS(v0, v1, v2, v3);                       \
  }

HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create)
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CREATE, Create, create)
#endif
#undef HWY_SVE_CREATE

template <class D>
using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D())));
template <class D>
using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D())));
template <class D>
using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D())));

#define HWY_SVE_GET(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t kIndex>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple) { \
    return sv##OP##2_##CHAR##BITS(tuple, kIndex);                             \
  }                                                                           \
  template <size_t kIndex>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple) { \
    return sv##OP##3_##CHAR##BITS(tuple, kIndex);                             \
  }                                                                           \
  template <size_t kIndex>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple) { \
    return sv##OP##4_##CHAR##BITS(tuple, kIndex);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_GET, Get, get)
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_GET, Get, get)
#endif
#undef HWY_SVE_GET

#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                          \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 2)                                         \
      NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##2_##CHAR##BITS(tuple, kIndex, vec);                         \
  }                                                                            \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 3)                                         \
      NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##3_##CHAR##BITS(tuple, kIndex, vec);                         \
  }                                                                            \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 4)                                         \
      NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##4_##CHAR##BITS(tuple, kIndex, vec);                         \
  }

HWY_SVE_FOREACH(HWY_SVE_SET, Set, set)
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_SET, Set, set)
#endif
#undef HWY_SVE_SET
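
// Illustrative tuple round-trip (a sketch; v0 and v1 are hypothetical
// svfloat32_t vectors):
//   const ScalableTag<float> d;
//   Vec2<decltype(d)> t = Create2(d, v0, v1);
//   const svfloat32_t first = Get2<0>(t);  // == v0
//   t = Set2<1>(t, first);                 // replace the second vector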

// ------------------------------ ResizeBitCast

// Same as BitCast on SVE
template <class D, class FromV>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, v);
}

// ------------------------------ Dup128VecFromValues

template <class D, HWY_IF_I8_D(D)>
HWY_API svint8_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                     TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                     TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                     TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
                                     TFromD<D> t11, TFromD<D> t12,
                                     TFromD<D> t13, TFromD<D> t14,
                                     TFromD<D> t15) {
  return svdupq_n_s8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12,
                     t13, t14, t15);
}

template <class D, HWY_IF_U8_D(D)>
HWY_API svuint8_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9,
                                      TFromD<D> t10, TFromD<D> t11,
                                      TFromD<D> t12, TFromD<D> t13,
                                      TFromD<D> t14, TFromD<D> t15) {
  return svdupq_n_u8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12,
                     t13, t14, t15);
}

template <class D, HWY_IF_I16_D(D)>
HWY_API svint16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return svdupq_n_s16(t0, t1, t2, t3, t4, t5, t6, t7);
}

template <class D, HWY_IF_U16_D(D)>
HWY_API svuint16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                       TFromD<D> t2, TFromD<D> t3,
                                       TFromD<D> t4, TFromD<D> t5,
                                       TFromD<D> t6, TFromD<D> t7) {
  return svdupq_n_u16(t0, t1, t2, t3, t4, t5, t6, t7);
}

template <class D, HWY_IF_F16_D(D)>
HWY_API svfloat16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                        TFromD<D> t2, TFromD<D> t3,
                                        TFromD<D> t4, TFromD<D> t5,
                                        TFromD<D> t6, TFromD<D> t7) {
  return svdupq_n_f16(t0, t1, t2, t3, t4, t5, t6, t7);
}

template <class D, HWY_IF_BF16_D(D)>
HWY_API VBF16 Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                  TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                  TFromD<D> t5, TFromD<D> t6, TFromD<D> t7) {
#if HWY_SVE_HAVE_BF16_FEATURE
  (void)d;
  return svdupq_n_bf16(t0, t1, t2, t3, t4, t5, t6, t7);
#else
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(
      d, Dup128VecFromValues(
             du, BitCastScalar<uint16_t>(t0), BitCastScalar<uint16_t>(t1),
             BitCastScalar<uint16_t>(t2), BitCastScalar<uint16_t>(t3),
             BitCastScalar<uint16_t>(t4), BitCastScalar<uint16_t>(t5),
             BitCastScalar<uint16_t>(t6), BitCastScalar<uint16_t>(t7)));
#endif
}

template <class D, HWY_IF_I32_D(D)>
HWY_API svint32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return svdupq_n_s32(t0, t1, t2, t3);
}

template <class D, HWY_IF_U32_D(D)>
HWY_API svuint32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                       TFromD<D> t2, TFromD<D> t3) {
  return svdupq_n_u32(t0, t1, t2, t3);
}

template <class D, HWY_IF_F32_D(D)>
HWY_API svfloat32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                        TFromD<D> t2, TFromD<D> t3) {
  return svdupq_n_f32(t0, t1, t2, t3);
}

template <class D, HWY_IF_I64_D(D)>
HWY_API svint64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return svdupq_n_s64(t0, t1);
}

template <class D, HWY_IF_U64_D(D)>
HWY_API svuint64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return svdupq_n_u64(t0, t1);
}

template <class D, HWY_IF_F64_D(D)>
HWY_API svfloat64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return svdupq_n_f64(t0, t1);
}
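
// Illustrative usage: broadcast a repeating 128-bit pattern across the whole
// vector, e.g.
//   const ScalableTag<uint32_t> d;
//   const svuint32_t v = Dup128VecFromValues(d, 0u, 1u, 2u, 3u);
// yields lanes 0,1,2,3,0,1,2,3,... regardless of the vector length.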

// ------------------------------ GetLane

namespace detail {
#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_INLINE HWY_SVE_T(BASE, BITS)                         \
      NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) {       \
    return sv##OP##_##CHAR##BITS(mask, v);                 \
  }

HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta)
HWY_SVE_FOREACH(HWY_SVE_GET_LANE, ExtractLastMatchingLaneM, lastb)
#undef HWY_SVE_GET_LANE
}  // namespace detail

template <class V>
HWY_API TFromV<V> GetLane(V v) {
  return detail::GetLaneM(v, detail::PFalse());
}

// ================================================== LOGICAL

// detail::*N() functions accept a scalar argument to avoid extra Set().

// ------------------------------ Not
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not )  // NOLINT

// ------------------------------ And

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V And(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, And(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Or

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, OrN, orr_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Or(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ MaskedOr
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr)

// ------------------------------ Xor

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ AndNot

namespace detail {
#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n)
#undef HWY_SVE_RETV_ARGPVN_SWAP
}  // namespace detail

#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic)
#undef HWY_SVE_RETV_ARGPVV_SWAP

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V AndNot(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b)));
}
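
// Note: AndNot(a, b) computes (~a) & b. The *_SWAP macros pass operands to
// svbic in reverse order because svbic(b, a) computes b & ~a.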

// ------------------------------ Xor3

#if HWY_SVE_HAVE_2

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor3(const V x1, const V x2, const V x3) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3)));
}

#else
template <class V>
HWY_API V Xor3(V x1, V x2, V x3) {
  return Xor(x1, Xor(x2, x3));
}
#endif

// ------------------------------ Or3
template <class V>
HWY_API V Or3(V o1, V o2, V o3) {
  return Or(o1, Or(o2, o3));
}

// ------------------------------ OrAnd
template <class V>
HWY_API V OrAnd(const V o, const V a1, const V a2) {
  return Or(o, And(a1, a2));
}

// ------------------------------ PopulationCount

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

// Need to return original type instead of unsigned.
#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {         \
    return BitCast(DFromV<decltype(v)>(),                               \
                   sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v));  \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt)
#undef HWY_SVE_POPCNT

// ================================================== SIGN

// ------------------------------ Neg
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg)

HWY_API VBF16 Neg(VBF16 v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask<TU>())));
}

// ------------------------------ SaturatedNeg
#if HWY_SVE_HAVE_2
#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32
#undef HWY_NATIVE_SATURATED_NEG_8_16_32
#else
#define HWY_NATIVE_SATURATED_NEG_8_16_32
#endif

#ifdef HWY_NATIVE_SATURATED_NEG_64
#undef HWY_NATIVE_SATURATED_NEG_64
#else
#define HWY_NATIVE_SATURATED_NEG_64
#endif

HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedNeg, qneg)
#endif  // HWY_SVE_HAVE_2

// ================================================== ARITHMETIC

// Per-target flags to prevent generic_ops-inl.h defining Add etc.
#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS
#undef HWY_NATIVE_OPERATOR_REPLACEMENTS
#else
#define HWY_NATIVE_OPERATOR_REPLACEMENTS
#endif

// ------------------------------ Add

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n)
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add)

// ------------------------------ Sub

namespace detail {
// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg.
#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_z(pg, a, b);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n)
#undef HWY_SVE_RETV_ARGPVN_MASK
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub)
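
// Note: detail::AddN/SubN take a scalar operand, mapping to the svadd_n_*/
// svsub_n_* forms and avoiding the Set() broadcast a vector operand needs.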

// ------------------------------ SumsOf8
HWY_API svuint64_t SumsOf8(const svuint8_t v) {
  const ScalableTag<uint32_t> du32;
  const ScalableTag<uint64_t> du64;
  const svbool_t pg = detail::PTrue(du64);

  const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1);
  // Compute pairwise sum of u32 and extend to u64.
#if HWY_SVE_HAVE_2
  return svadalp_u64_x(pg, Zero(du64), sums_of_4);
#else
  const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32);
  // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended)
  const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4));
  return Add(hi, lo);
#endif
}

HWY_API svint64_t SumsOf8(const svint8_t v) {
  const ScalableTag<int32_t> di32;
  const ScalableTag<int64_t> di64;
  const svbool_t pg = detail::PTrue(di64);

  const svint32_t sums_of_4 = svdot_n_s32(Zero(di32), v, 1);
#if HWY_SVE_HAVE_2
  return svadalp_s64_x(pg, Zero(di64), sums_of_4);
#else
  const svint64_t hi = svasr_n_s64_x(pg, BitCast(di64, sums_of_4), 32);
  // Isolate the lower 32 bits (to be added to the upper 32 and sign-extended)
  const svint64_t lo = svextw_s64_x(pg, BitCast(di64, sums_of_4));
  return Add(hi, lo);
#endif
}

// ------------------------------ SumsOf2
#if HWY_SVE_HAVE_2
namespace detail {

HWY_INLINE svint16_t SumsOf2(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) {
  const ScalableTag<int16_t> di16;
  const svbool_t pg = detail::PTrue(di16);
  return svadalp_s16_x(pg, Zero(di16), v);
}

HWY_INLINE svuint16_t SumsOf2(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) {
  const ScalableTag<uint16_t> du16;
  const svbool_t pg = detail::PTrue(du16);
  return svadalp_u16_x(pg, Zero(du16), v);
}

HWY_INLINE svint32_t SumsOf2(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) {
  const ScalableTag<int32_t> di32;
  const svbool_t pg = detail::PTrue(di32);
  return svadalp_s32_x(pg, Zero(di32), v);
}

HWY_INLINE svuint32_t SumsOf2(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) {
  const ScalableTag<uint32_t> du32;
  const svbool_t pg = detail::PTrue(du32);
  return svadalp_u32_x(pg, Zero(du32), v);
}

HWY_INLINE svint64_t SumsOf2(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<4> /*lane_size_tag*/, svint32_t v) {
  const ScalableTag<int64_t> di64;
  const svbool_t pg = detail::PTrue(di64);
  return svadalp_s64_x(pg, Zero(di64), v);
}

HWY_INLINE svuint64_t SumsOf2(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<4> /*lane_size_tag*/, svuint32_t v) {
  const ScalableTag<uint64_t> du64;
  const svbool_t pg = detail::PTrue(du64);
  return svadalp_u64_x(pg, Zero(du64), v);
}

}  // namespace detail
#endif  // HWY_SVE_HAVE_2

// ------------------------------ SumsOf4
namespace detail {

HWY_INLINE svint32_t SumsOf4(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) {
  return svdot_n_s32(Zero(ScalableTag<int32_t>()), v, 1);
}

HWY_INLINE svuint32_t SumsOf4(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) {
  return svdot_n_u32(Zero(ScalableTag<uint32_t>()), v, 1);
}

HWY_INLINE svint64_t SumsOf4(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) {
  return svdot_n_s64(Zero(ScalableTag<int64_t>()), v, 1);
}

HWY_INLINE svuint64_t SumsOf4(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) {
  return svdot_n_u64(Zero(ScalableTag<uint64_t>()), v, 1);
}
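
// Note on the svdot trick used by SumsOf8/SumsOf4: a dot product with an
// all-ones vector sums each aligned group of four adjacent lanes, e.g. u8
// lanes {1,2,3,4, 5,6,7,8, ...} become u32 lanes {10, 26, ...}.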

}  // namespace detail

// ------------------------------ SaturatedAdd

#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB
#undef HWY_NATIVE_I32_SATURATED_ADDSUB
#else
#define HWY_NATIVE_I32_SATURATED_ADDSUB
#endif

#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB
#undef HWY_NATIVE_U32_SATURATED_ADDSUB
#else
#define HWY_NATIVE_U32_SATURATED_ADDSUB
#endif

#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB
#undef HWY_NATIVE_I64_SATURATED_ADDSUB
#else
#define HWY_NATIVE_I64_SATURATED_ADDSUB
#endif

#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB
#undef HWY_NATIVE_U64_SATURATED_ADDSUB
#else
#define HWY_NATIVE_U64_SATURATED_ADDSUB
#endif

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)

// ------------------------------ SaturatedSub

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)

// ------------------------------ AbsDiff
#ifdef HWY_NATIVE_INTEGER_ABS_DIFF
#undef HWY_NATIVE_INTEGER_ABS_DIFF
#else
#define HWY_NATIVE_INTEGER_ABS_DIFF
#endif

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd)

// ------------------------------ ShiftLeft[Same]

#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP)                  \
  template <int kBits>                                                     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {            \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits);       \
  }                                                                        \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME##Same(HWY_SVE_V(BASE, BITS) v, int bits) {                      \
    return sv##OP##_##CHAR##BITS##_x(                                      \
        HWY_SVE_PTRUE(BITS), v, static_cast<HWY_SVE_T(uint, BITS)>(bits)); \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n)

// ------------------------------ ShiftRight[Same]

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ MaskedShift[Left/Right]

#define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  template <int kBits>                                                      \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
    auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits);                \
    return sv##OP##_##CHAR##BITS##_z(m, v, shifts);                         \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeft, lsl_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRight, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRight, lsr_n)

#undef HWY_SVE_SHIFT_Z

// ------------------------------ MaskedShiftRightOr

#define HWY_SVE_SHIFT_OR(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  template <int kBits>                                                       \
  HWY_API HWY_SVE_V(BASE, BITS)                                              \
      NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) v) {  \
    auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits);                 \
    return svsel##_##CHAR##BITS(m, sv##OP##_##CHAR##BITS##_z(m, v, shifts),  \
                                no);                                         \
  }
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, lsr_n)

#undef HWY_SVE_SHIFT_OR
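
// Note: the MaskedShift* ops above zero inactive lanes (_z forms), whereas
// MaskedShiftRightOr selects inactive lanes from `no` via svsel.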

// ------------------------------ RotateRight

#if HWY_SVE_HAVE_2

#define HWY_SVE_ROTATE_RIGHT_N(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <int kBits>                                           \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {  \
    if (kBits == 0) return v;                                    \
    return sv##OP##_##CHAR##BITS(v, Zero(DFromV<decltype(v)>()), \
                                 HWY_MAX(kBits, 1));             \
  }

HWY_SVE_FOREACH_U(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n)
HWY_SVE_FOREACH_I(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n)

#undef HWY_SVE_ROTATE_RIGHT_N

#else  // !HWY_SVE_HAVE_2
template <int kBits, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V RotateRight(const V v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;

  constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;

  return Or(BitCast(d, ShiftRight<kBits>(BitCast(du, v))),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}
#endif

// ------------------------------ Shl, Shr

#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP)           \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \
    const RebindToUnsigned<DFromV<decltype(v)>> du;               \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v,      \
                                     BitCast(du, bits));          \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl)

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr)

#undef HWY_SVE_SHIFT

// ------------------------------ RoundingShiftLeft[Same]/RoundingShr

#if HWY_SVE_HAVE_2

#ifdef HWY_NATIVE_ROUNDING_SHR
#undef HWY_NATIVE_ROUNDING_SHR
#else
#define HWY_NATIVE_ROUNDING_SHR
#endif

#define HWY_SVE_ROUNDING_SHR_N(BASE, CHAR, BITS, HALF, NAME, OP)           \
  template <int kBits>                                                     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {            \
    HWY_IF_CONSTEXPR(kBits == 0) { return v; }                             \
                                                                           \
    return sv##OP##_##CHAR##BITS##_x(                                      \
        HWY_SVE_PTRUE(BITS), v, static_cast<uint64_t>(HWY_MAX(kBits, 1))); \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR_N, RoundingShiftRight, rshr_n)

#undef HWY_SVE_ROUNDING_SHR_N

#define HWY_SVE_ROUNDING_SHR(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \
    const RebindToSigned<DFromV<decltype(v)>> di;                 \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v,      \
                                     Neg(BitCast(di, bits)));     \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR, RoundingShr, rshl)

#undef HWY_SVE_ROUNDING_SHR

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V RoundingShiftRightSame(V v, int bits) {
  const DFromV<V> d;
  using T = TFromD<decltype(d)>;
  return RoundingShr(v, Set(d, static_cast<T>(bits)));
}

#endif  // HWY_SVE_HAVE_2

// ------------------------------ BroadcastSignBit (ShiftRight)
template <class V>
HWY_API V BroadcastSignBit(const V v) {
  return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v);
}
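
// e.g. for int32_t lanes, BroadcastSignBit yields 0 for non-negative inputs
// and -1 (all bits set) for negative inputs: an arithmetic shift right by 31.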

// ------------------------------ Abs (ShiftRight, Add, Xor, AndN)

// Workaround for incorrect results with `svabs`.
#if HWY_COMPILER_CLANG
template <class V, HWY_IF_SIGNED_V(V)>
HWY_API V Abs(V v) {
  const V sign = BroadcastSignBit(v);
  return Xor(Add(v, sign), sign);
}

template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
HWY_NOINLINE V Abs(V v) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;
  using TU = MakeUnsigned<TFromD<decltype(d)>>;
  return BitCast(
      d, detail::AndN(BitCast(du, v), static_cast<TU>(~SignMask<TU>())));
}

#else
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)
#endif
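
// The signed-integer branch above uses the sign-bit identity: for negative v,
// sign == -1, so Xor(Add(v, -1), -1) == ~(v - 1) == -v; for non-negative v,
// sign == 0 and v is returned unchanged. The float branch clears the sign bit.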

// ------------------------------ SaturatedAbs
#if HWY_SVE_HAVE_2
#ifdef HWY_NATIVE_SATURATED_ABS
#undef HWY_NATIVE_SATURATED_ABS
#else
#define HWY_NATIVE_SATURATED_ABS
#endif

HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs)
#endif  // HWY_SVE_HAVE_2

// ------------------------------ MaskedAbsOr
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs)

// ------------------------------ MaskedAbs
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbs, abs)

// ------------------------------ Mul

// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*.
#ifdef HWY_NATIVE_MUL_8
#undef HWY_NATIVE_MUL_8
#else
#define HWY_NATIVE_MUL_8
#endif
#ifdef HWY_NATIVE_MUL_64
#undef HWY_NATIVE_MUL_64
#else
#define HWY_NATIVE_MUL_64
#endif

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul)

// ------------------------------ MulHigh
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)

// ------------------------------ MulFixedPoint15
HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) {
#if HWY_SVE_HAVE_2
  return svqrdmulh_s16(a, b);
#else
  const DFromV<decltype(a)> d;
  const RebindToUnsigned<decltype(d)> du;

  const svuint16_t lo = BitCast(du, Mul(a, b));
  const svint16_t hi = MulHigh(a, b);
  // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must
  // carry that into the result. Instead isolate the top two bits because only
  // they can influence the result.
  const svuint16_t lo_top2 = ShiftRight<14>(lo);
  // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0.
  const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1));
  return Add(Add(hi, hi), BitCast(d, rounding));
#endif
}

// ------------------------------ Div
#ifdef HWY_NATIVE_INT_DIV
#undef HWY_NATIVE_INT_DIV
#else
#define HWY_NATIVE_INT_DIV
#endif

HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, Div, div)
HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPVV, Div, div)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div)

// ------------------------------ ApproximateReciprocal
#ifdef HWY_NATIVE_F64_APPROX_RECIP
#undef HWY_NATIVE_F64_APPROX_RECIP
#else
#define HWY_NATIVE_F64_APPROX_RECIP
#endif

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)

// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ MaskedSqrt
#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_MASKED_SQRT
#endif

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
#undef HWY_NATIVE_F64_APPROX_RSQRT
#else
#define HWY_NATIVE_F64_APPROX_RSQRT
#endif

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte)

// ------------------------------ MulAdd

// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd.
#ifdef HWY_NATIVE_INT_FMA
#undef HWY_NATIVE_INT_FMA
#else
#define HWY_NATIVE_INT_FMA
#endif

#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x,          \
           HWY_SVE_V(BASE, BITS) add) {                                 \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \
  }

HWY_SVE_FOREACH(HWY_SVE_FMA, MulAdd, mad)

// ------------------------------ NegMulAdd
HWY_SVE_FOREACH(HWY_SVE_FMA, NegMulAdd, msb)

// ------------------------------ MulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb)

// ------------------------------ NegMulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz)
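
// e.g. for lanes equal to 2.5f: Round -> 2.0f (nearest, ties to even),
// Floor -> 2.0f, Ceil -> 3.0f, Trunc -> 2.0f.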

// ================================================== MASK

// ------------------------------ RebindMask
template <class D, typename MFrom>
HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) {
  return mask;
}

// ------------------------------ Mask logical

HWY_API svbool_t Not(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svnot_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t And(svbool_t a, svbool_t b) {
  return svand_b_z(b, b, a);  // same order as AndNot for consistency
}
HWY_API svbool_t AndNot(svbool_t a, svbool_t b) {
  return svbic_b_z(b, b, a);  // reversed order like NEON
}
HWY_API svbool_t Or(svbool_t a, svbool_t b) {
  return svsel_b(a, a, b);  // a ? true : b
}
HWY_API svbool_t Xor(svbool_t a, svbool_t b) {
  return svsel_b(a, svnand_b_z(a, a, b), b);  // a ? !(a & b) : b.
}

HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) {
  return svnor_b_z(HWY_SVE_PTRUE(8), a, b);  // !a && !b, undefined if a && b.
}

// ------------------------------ CountTrue

#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP)           \
  template <size_t N, int kPow2>                                       \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \
    return sv##OP##_b##BITS(detail::MakeMask(d), m);                   \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp)
#undef HWY_SVE_COUNT_TRUE

// For 16-bit Compress: full vector, not limited to SV_POW2.
namespace detail {

#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                             \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \
    return sv##OP##_b##BITS(svptrue_b##BITS(), m);                           \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp)
#undef HWY_SVE_COUNT_TRUE_FULL

}  // namespace detail

// ------------------------------ AllFalse
template <class D>
HWY_API bool AllFalse(D d, svbool_t m) {
  return !svptest_any(detail::MakeMask(d), m);
}

// ------------------------------ AllTrue
template <class D>
HWY_API bool AllTrue(D d, svbool_t m) {
  return CountTrue(d, m) == Lanes(d);
}
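
// e.g. with d = ScalableTag<int32_t>() and m = FirstN(d, 3), assuming
// Lanes(d) >= 3: CountTrue(d, m) == 3, AllFalse(d, m) is false, and
// AllTrue(d, m) holds only if Lanes(d) == 3.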

// ------------------------------ FindFirstTrue
template <class D>
HWY_API intptr_t FindFirstTrue(D d, svbool_t m) {
  return AllFalse(d, m) ? intptr_t{-1}
                        : static_cast<intptr_t>(CountTrue(
                              d, svbrkb_b_z(detail::MakeMask(d), m)));
}

// ------------------------------ FindKnownFirstTrue
template <class D>
HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) {
  return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m));
}

// ------------------------------ IfThenElse
#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \
    return sv##OP##_##CHAR##BITS(m, yes, no);                                 \
  }

HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
HWY_SVE_FOREACH_BF16(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
#undef HWY_SVE_IF_THEN_ELSE

template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)>
HWY_API V IfThenElse(const svbool_t mask, V yes, V no) {
  const RebindToUnsigned<D> du;
  return BitCast(
      D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no)));
}

// ------------------------------ IfThenElseZero

template <class V, class D = DFromV<V>, HWY_SVE_IF_NOT_EMULATED_D(D)>
HWY_API V IfThenElseZero(const svbool_t mask, const V yes) {
  return IfThenElse(mask, yes, Zero(D()));
}

template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)>
HWY_API V IfThenElseZero(const svbool_t mask, V yes) {
  const RebindToUnsigned<D> du;
  return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes)));
}

// ------------------------------ IfThenZeroElse

template <class V, class D = DFromV<V>, HWY_SVE_IF_NOT_EMULATED_D(D)>
HWY_API V IfThenZeroElse(const svbool_t mask, const V no) {
  return IfThenElse(mask, Zero(D()), no);
}

template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)>
HWY_API V IfThenZeroElse(const svbool_t mask, V no) {
  const RebindToUnsigned<D> du;
  return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no)));
}

// ------------------------------ Additional mask logical operations
HWY_API svbool_t SetBeforeFirst(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svbrkb_b_z(HWY_SVE_PTRUE(8), m);
}
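
// Illustrative brk semantics (1 = active): for m = {0,0,1,0,1,...} with the
// first active lane at index 2: SetBeforeFirst(m) = {1,1,0,0,0,...},
// SetOnlyFirst(m) = {0,0,1,0,0,...}, SetAtOrBeforeFirst(m) = {1,1,1,0,0,...}.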

HWY_API svbool_t SetAtOrBeforeFirst(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svbrka_b_z(HWY_SVE_PTRUE(8), m);
}

HWY_API svbool_t SetOnlyFirst(svbool_t m) { return svbrka_b_z(m, m); }

HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) {
  return Not(SetBeforeFirst(m));
}

// ------------------------------ PromoteMaskTo

#ifdef HWY_NATIVE_PROMOTE_MASK_TO
#undef HWY_NATIVE_PROMOTE_MASK_TO
#else
#define HWY_NATIVE_PROMOTE_MASK_TO
#endif

template <class DTo, class DFrom,
          HWY_IF_T_SIZE_D(DTo, sizeof(TFromD<DFrom>) * 2)>
HWY_API svbool_t PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
  return svunpklo_b(m);
}

template <class DTo, class DFrom,
          HWY_IF_T_SIZE_GT_D(DTo, sizeof(TFromD<DFrom>) * 2)>
HWY_API svbool_t PromoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) {
  using TFrom = TFromD<DFrom>;
  using TWFrom = MakeWide<MakeUnsigned<TFrom>>;
  static_assert(sizeof(TWFrom) > sizeof(TFrom),
                "sizeof(TWFrom) > sizeof(TFrom) must be true");

  const Rebind<TWFrom, decltype(d_from)> dw_from;
  return PromoteMaskTo(d_to, dw_from, PromoteMaskTo(dw_from, d_from, m));
}

// ------------------------------ DemoteMaskTo

#ifdef HWY_NATIVE_DEMOTE_MASK_TO
#undef HWY_NATIVE_DEMOTE_MASK_TO
#else
#define HWY_NATIVE_DEMOTE_MASK_TO
#endif

template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 1),
          HWY_IF_T_SIZE_D(DFrom, 2)>
HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
  return svuzp1_b8(m, m);
}

template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 2),
          HWY_IF_T_SIZE_D(DFrom, 4)>
HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
  return svuzp1_b16(m, m);
}

template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 4),
          HWY_IF_T_SIZE_D(DFrom, 8)>
HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
  return svuzp1_b32(m, m);
}

template <class DTo, class DFrom,
          HWY_IF_T_SIZE_LE_D(DTo, sizeof(TFromD<DFrom>) / 4)>
HWY_API svbool_t DemoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) {
  using TFrom = TFromD<DFrom>;
  using TNFrom = MakeNarrow<MakeUnsigned<TFrom>>;
  static_assert(sizeof(TNFrom) < sizeof(TFrom),
                "sizeof(TNFrom) < sizeof(TFrom) must be true");

  const Rebind<TNFrom, decltype(d_from)> dn_from;
  return DemoteMaskTo(d_to, dn_from, DemoteMaskTo(dn_from, d_from, m));
}

// ------------------------------ LowerHalfOfMask
#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK
#undef HWY_NATIVE_LOWER_HALF_OF_MASK
#else
#define HWY_NATIVE_LOWER_HALF_OF_MASK
#endif

template <class D>
HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) {
  return m;
}
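
// Background on the helpers above: svunpklo_b widens each flag in the lower
// half of a predicate to cover twice as many bytes, while svuzp1_b* keeps the
// even-indexed flags, halving the coverage; chaining these handles
// promotions/demotions by more than 2x.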
(IfThenElse) 1664 1665 #ifdef HWY_NATIVE_MASKED_ARITH 1666 #undef HWY_NATIVE_MASKED_ARITH 1667 #else 1668 #define HWY_NATIVE_MASKED_ARITH 1669 #endif 1670 1671 namespace detail { 1672 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedMin, minnm) 1673 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedMax, maxnm) 1674 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedMin, min) 1675 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedMax, max) 1676 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) 1677 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) 1678 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) 1679 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) 1680 HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) 1681 HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) 1682 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt) 1683 #if HWY_SVE_HAVE_2 1684 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd) 1685 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub) 1686 #endif 1687 } // namespace detail 1688 1689 template <class V, class M> 1690 HWY_API V MaskedMinOr(V no, M m, V a, V b) { 1691 return IfThenElse(m, detail::MaskedMin(m, a, b), no); 1692 } 1693 1694 template <class V, class M> 1695 HWY_API V MaskedMaxOr(V no, M m, V a, V b) { 1696 return IfThenElse(m, detail::MaskedMax(m, a, b), no); 1697 } 1698 1699 template <class V, class M> 1700 HWY_API V MaskedAddOr(V no, M m, V a, V b) { 1701 return IfThenElse(m, detail::MaskedAdd(m, a, b), no); 1702 } 1703 1704 template <class V, class M> 1705 HWY_API V MaskedSubOr(V no, M m, V a, V b) { 1706 return IfThenElse(m, detail::MaskedSub(m, a, b), no); 1707 } 1708 1709 template <class V, class M> 1710 HWY_API V MaskedMulOr(V no, M m, V a, V b) { 1711 return IfThenElse(m, detail::MaskedMul(m, a, b), no); 1712 } 1713 1714 template <class V, class M, 1715 HWY_IF_T_SIZE_ONE_OF_V( 1716 V, (hwy::IsSame<TFromV<V>, hwy::float16_t>() ? (1 << 2) : 0) | 1717 (1 << 4) | (1 << 8))> 1718 HWY_API V MaskedDivOr(V no, M m, V a, V b) { 1719 return IfThenElse(m, detail::MaskedDiv(m, a, b), no); 1720 } 1721 1722 // I8/U8/I16/U16 MaskedDivOr is implemented after I8/U8/I16/U16 Div 1723 1724 #if HWY_SVE_HAVE_2 1725 template <class V, class M> 1726 HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { 1727 return IfThenElse(m, detail::MaskedSatAdd(m, a, b), no); 1728 } 1729 1730 template <class V, class M> 1731 HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { 1732 return IfThenElse(m, detail::MaskedSatSub(m, a, b), no); 1733 } 1734 #else 1735 template <class V, class M> 1736 HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { 1737 return IfThenElse(m, SaturatedAdd(a, b), no); 1738 } 1739 1740 template <class V, class M> 1741 HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { 1742 return IfThenElse(m, SaturatedSub(a, b), no); 1743 } 1744 #endif 1745 1746 // ------------------------------ MaskedMulAddOr 1747 namespace detail { 1748 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad) 1749 } 1750 1751 // Per-target flag to prevent generic_ops-inl.h from defining int 1752 // MaskedMulAddOr. 
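// A minimal caller-side sketch of the op defined below (identifiers are
// illustrative, not part of this header):
//   const ScalableTag<int32_t> d;
//   const svbool_t m = FirstN(d, 2);  // first two lanes active
//   // Active lanes compute mul * x + add; the rest receive `no`.
//   const svint32_t r = MaskedMulAddOr(no, m, mul, x, add);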
1753 #ifdef HWY_NATIVE_MASKED_INT_FMA
1754 #undef HWY_NATIVE_MASKED_INT_FMA
1755 #else
1756 #define HWY_NATIVE_MASKED_INT_FMA
1757 #endif
1758
1759 template <class V, class M>
1760 HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) {
1761 return IfThenElse(m, detail::MaskedMulAdd(m, mul, x, add), no);
1762 }
1763
1764 template <class V, HWY_IF_FLOAT_V(V), class M>
1765 HWY_API V MaskedSqrtOr(V no, M m, V v) {
1766 return IfThenElse(m, detail::MaskedSqrt(m, v), no);
1767 }
1768
1769 // ================================================== REDUCE
1770
1771 #ifdef HWY_NATIVE_REDUCE_SCALAR
1772 #undef HWY_NATIVE_REDUCE_SCALAR
1773 #else
1774 #define HWY_NATIVE_REDUCE_SCALAR
1775 #endif
1776
1777 // These return T, suitable for ReduceSum.
1778 namespace detail {
1779 #define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \
1780 HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \
1781 /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. */ \
1782 using T = HWY_SVE_T(BASE, BITS); \
1783 using TU = MakeUnsigned<T>; \
1784 constexpr uint64_t kMask = LimitsMax<TU>(); \
1785 return static_cast<T>(static_cast<TU>( \
1786 static_cast<uint64_t>(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \
1787 }
1788
1789 #define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \
1790 HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \
1791 return sv##OP##_##CHAR##BITS(pg, v); \
1792 }
1793
1794 // TODO: Remove SumOfLanesM in favor of using MaskedReduceSum
1795 HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv)
1796 HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv)
1797
1798 HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv)
1799 HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv)
1800 // NaN if all lanes are NaN
1801 HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv)
1802 HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv)
1803
1804 #undef HWY_SVE_REDUCE
1805 #undef HWY_SVE_REDUCE_ADD
1806 } // namespace detail
1807
1808 // detail::SumOfLanesM, detail::MinOfLanesM, and detail::MaxOfLanesM are more
1809 // efficient for N=4 I8/U8 reductions on SVE than the default implementations
1810 // of the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in
1811 // generic_ops-inl.h.
1812 #undef HWY_IF_REDUCE_D
1813 #define HWY_IF_REDUCE_D(D) hwy::EnableIf<HWY_MAX_LANES_D(D) != 1>* = nullptr
1814
1815 #ifdef HWY_NATIVE_REDUCE_SUM_4_UI8
1816 #undef HWY_NATIVE_REDUCE_SUM_4_UI8
1817 #else
1818 #define HWY_NATIVE_REDUCE_SUM_4_UI8
1819 #endif
1820
1821 #ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8
1822 #undef HWY_NATIVE_REDUCE_MINMAX_4_UI8
1823 #else
1824 #define HWY_NATIVE_REDUCE_MINMAX_4_UI8
1825 #endif
1826
1827 template <class D, HWY_IF_REDUCE_D(D)>
1828 HWY_API TFromD<D> ReduceSum(D d, VFromD<D> v) {
1829 return detail::SumOfLanesM(detail::MakeMask(d), v);
1830 }
1831
1832 template <class D, HWY_IF_REDUCE_D(D)>
1833 HWY_API TFromD<D> ReduceMin(D d, VFromD<D> v) {
1834 return detail::MinOfLanesM(detail::MakeMask(d), v);
1835 }
1836
1837 template <class D, HWY_IF_REDUCE_D(D)>
1838 HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
1839 return detail::MaxOfLanesM(detail::MakeMask(d), v);
1840 }
1841
1842 #ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
1843 #undef HWY_NATIVE_MASKED_REDUCE_SCALAR
1844 #else
1845 #define HWY_NATIVE_MASKED_REDUCE_SCALAR
1846 #endif
1847
1848 template <class D, class M>
1849 HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) {
1850 return detail::SumOfLanesM(m, v);
1851 }
1852 template <class D, class M>
1853 HWY_API TFromD<D>
MaskedReduceMin(D /*d*/, M m, VFromD<D> v) { 1854 return detail::MinOfLanesM(m, v); 1855 } 1856 template <class D, class M> 1857 HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) { 1858 return detail::MaxOfLanesM(m, v); 1859 } 1860 1861 // ------------------------------ SumOfLanes 1862 1863 template <class D, HWY_IF_LANES_GT_D(D, 1)> 1864 HWY_API VFromD<D> SumOfLanes(D d, VFromD<D> v) { 1865 return Set(d, ReduceSum(d, v)); 1866 } 1867 template <class D, HWY_IF_LANES_GT_D(D, 1)> 1868 HWY_API VFromD<D> MinOfLanes(D d, VFromD<D> v) { 1869 return Set(d, ReduceMin(d, v)); 1870 } 1871 template <class D, HWY_IF_LANES_GT_D(D, 1)> 1872 HWY_API VFromD<D> MaxOfLanes(D d, VFromD<D> v) { 1873 return Set(d, ReduceMax(d, v)); 1874 } 1875 1876 // ------------------------------ MaskedAdd etc. (IfThenElse) 1877 1878 #ifdef HWY_NATIVE_ZERO_MASKED_ARITH 1879 #undef HWY_NATIVE_ZERO_MASKED_ARITH 1880 #else 1881 #define HWY_NATIVE_ZERO_MASKED_ARITH 1882 #endif 1883 1884 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMax, max) 1885 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedAdd, add) 1886 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedSub, sub) 1887 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMul, mul) 1888 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) 1889 HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) 1890 HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) 1891 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedMulAdd, mad) 1892 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedNegMulAdd, msb) 1893 1894 // I8/U8/I16/U16 MaskedDiv is implemented after I8/U8/I16/U16 Div 1895 1896 #if HWY_SVE_HAVE_2 1897 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedAdd, qadd) 1898 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedSub, qsub) 1899 #else 1900 template <class V, class M> 1901 HWY_API V MaskedSaturatedAdd(M m, V a, V b) { 1902 return IfThenElseZero(m, SaturatedAdd(a, b)); 1903 } 1904 1905 template <class V, class M> 1906 HWY_API V MaskedSaturatedSub(M m, V a, V b) { 1907 return IfThenElseZero(m, SaturatedSub(a, b)); 1908 } 1909 #endif 1910 1911 template <class V, class M, typename D = DFromV<V>, HWY_IF_I16_D(D)> 1912 HWY_API V MaskedMulFixedPoint15(M m, V a, V b) { 1913 return IfThenElseZero(m, MulFixedPoint15(a, b)); 1914 } 1915 1916 template <class D, class M, HWY_IF_UI32_D(D), 1917 class V16 = VFromD<RepartitionToNarrow<D>>> 1918 HWY_API VFromD<D> MaskedWidenMulPairwiseAdd(D d32, M m, V16 a, V16 b) { 1919 return IfThenElseZero(m, WidenMulPairwiseAdd(d32, a, b)); 1920 } 1921 1922 template <class DF, class M, HWY_IF_F32_D(DF), class VBF> 1923 HWY_API VFromD<DF> MaskedWidenMulPairwiseAdd(DF df, M m, VBF a, VBF b) { 1924 return IfThenElseZero(m, WidenMulPairwiseAdd(df, a, b)); 1925 } 1926 1927 // ================================================== COMPARE 1928 1929 // mask = f(vector, vector) 1930 #define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ 1931 HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ 1932 return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ 1933 } 1934 #define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ 1935 HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ 1936 return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ 1937 } 1938 1939 // ------------------------------ Eq 1940 HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) 1941 namespace detail { 1942 HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) 1943 } // namespace detail 1944 1945 // ------------------------------ Ne 1946 
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) 1947 namespace detail { 1948 HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) 1949 } // namespace detail 1950 1951 // ------------------------------ Lt 1952 HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) 1953 namespace detail { 1954 HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) 1955 } // namespace detail 1956 1957 // ------------------------------ Le 1958 HWY_SVE_FOREACH(HWY_SVE_COMPARE, Le, cmple) 1959 namespace detail { 1960 HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LeN, cmple_n) 1961 } // namespace detail 1962 1963 // ------------------------------ Gt/Ge (swapped order) 1964 template <class V> 1965 HWY_API svbool_t Gt(const V a, const V b) { 1966 return Lt(b, a); 1967 } 1968 template <class V> 1969 HWY_API svbool_t Ge(const V a, const V b) { 1970 return Le(b, a); 1971 } 1972 namespace detail { 1973 HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GeN, cmpge_n) 1974 HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GtN, cmpgt_n) 1975 } // namespace detail 1976 1977 #undef HWY_SVE_COMPARE 1978 #undef HWY_SVE_COMPARE_N 1979 1980 // ------------------------------ TestBit 1981 template <class V> 1982 HWY_API svbool_t TestBit(const V a, const V bit) { 1983 return detail::NeN(And(a, bit), 0); 1984 } 1985 1986 // ------------------------------ Min/Max (Lt, IfThenElse) 1987 1988 HWY_SVE_FOREACH_U(HWY_SVE_RETV_ARGPVV, Min, min) 1989 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) 1990 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) 1991 1992 // Workaround for incorrect results with `svmin`. 1993 #if HWY_COMPILER_CLANG 1994 template <class V, HWY_IF_SIGNED_V(V)> 1995 HWY_API V Min(V a, V b) { 1996 return IfThenElse(Lt(a, b), a, b); 1997 } 1998 template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)> 1999 HWY_API V Min(V a, V b) { 2000 return IfThenElse(Or(Lt(a, b), Ne(b, b)), a, b); 2001 } 2002 #else 2003 HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPVV, Min, min) 2004 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) 2005 #endif 2006 2007 namespace detail { 2008 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) 2009 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) 2010 } // namespace detail 2011 2012 // ================================================== SWIZZLE 2013 2014 // ------------------------------ ConcatEven/ConcatOdd 2015 2016 // WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the 2017 // full vector length, not rounded down to a power of two as we require). 
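// ConcatOdd/ConcatEven below therefore use uzp1/uzp2 directly only when d is
// the full vector (and HWY_SVE_IS_POW2); otherwise they compact hi and lo
// separately and stitch the halves with Splice. Illustrative u32 lanes with
// Lanes(d) = 4 on a wider vector:
//   hi = {h0 h1 h2 h3 ..}, lo = {l0 l1 l2 l3 ..}
//   ConcatOdd(d, hi, lo) -> {l1 l3 h1 h3 ..}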
2018 namespace detail { 2019 2020 #define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ 2021 HWY_INLINE HWY_SVE_V(BASE, BITS) \ 2022 NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ 2023 return sv##OP##_##CHAR##BITS(lo, hi); \ 2024 } 2025 HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) 2026 HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) 2027 #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 2028 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, 2029 uzp1) 2030 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, 2031 uzp2) 2032 #endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 2033 #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) 2034 HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) 2035 HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) 2036 #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 2037 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, 2038 ConcatEvenBlocks, uzp1q) 2039 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, 2040 uzp2q) 2041 #endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 2042 #endif // defined(__ARM_FEATURE_SVE_MATMUL_FP64) 2043 #undef HWY_SVE_CONCAT_EVERY_SECOND 2044 2045 // Used to slide up / shift whole register left; mask indicates which range 2046 // to take from lo, and the rest is filled from hi starting at its lowest. 2047 #define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ 2048 HWY_API HWY_SVE_V(BASE, BITS) NAME( \ 2049 HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ 2050 return sv##OP##_##CHAR##BITS(mask, lo, hi); \ 2051 } 2052 HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) 2053 #if HWY_SVE_HAVE_BF16_FEATURE 2054 HWY_SVE_FOREACH_BF16(HWY_SVE_SPLICE, Splice, splice) 2055 #else 2056 template <class V, HWY_IF_BF16_D(DFromV<V>)> 2057 HWY_INLINE V Splice(V hi, V lo, svbool_t mask) { 2058 const DFromV<V> d; 2059 const RebindToUnsigned<decltype(d)> du; 2060 return BitCast(d, Splice(BitCast(du, hi), BitCast(du, lo), mask)); 2061 } 2062 #endif // HWY_SVE_HAVE_BF16_FEATURE 2063 #undef HWY_SVE_SPLICE 2064 2065 } // namespace detail 2066 2067 template <class D> 2068 HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { 2069 #if HWY_SVE_IS_POW2 2070 if (detail::IsFull(d)) return detail::ConcatOddFull(hi, lo); 2071 #endif 2072 const VFromD<D> hi_odd = detail::ConcatOddFull(hi, hi); 2073 const VFromD<D> lo_odd = detail::ConcatOddFull(lo, lo); 2074 return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); 2075 } 2076 2077 template <class D> 2078 HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { 2079 #if HWY_SVE_IS_POW2 2080 if (detail::IsFull(d)) return detail::ConcatEvenFull(hi, lo); 2081 #endif 2082 const VFromD<D> hi_odd = detail::ConcatEvenFull(hi, hi); 2083 const VFromD<D> lo_odd = detail::ConcatEvenFull(lo, lo); 2084 return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); 2085 } 2086 2087 HWY_API svuint8_t U8FromU32(const svuint32_t v) { 2088 const DFromV<svuint32_t> du32; 2089 const RepartitionToNarrow<decltype(du32)> du16; 2090 const RepartitionToNarrow<decltype(du16)> du8; 2091 2092 const svuint16_t cast16 = BitCast(du16, v); 2093 const svuint16_t x2 = svuzp1_u16(cast16, cast16); 2094 const svuint8_t cast8 = BitCast(du8, x2); 2095 return svuzp1_u8(cast8, cast8); 2096 } 2097 2098 // ================================================== MASK 2099 2100 // 
------------------------------ MaskFromVec (Ne) 2101 template <class V> 2102 HWY_API svbool_t MaskFromVec(const V v) { 2103 using T = TFromV<V>; 2104 return detail::NeN(v, ConvertScalarTo<T>(0)); 2105 } 2106 2107 // ------------------------------ VecFromMask 2108 template <class D> 2109 HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) { 2110 const RebindToSigned<D> di; 2111 // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which 2112 // requires an extra instruction plus M0 pipeline. 2113 return BitCast(d, IfThenElseZero(mask, Set(di, -1))); 2114 } 2115 2116 // ------------------------------ BitsFromMask (AndN, Shl, ReduceSum, GetLane 2117 // ConcatEvenFull, U8FromU32) 2118 2119 namespace detail { 2120 2121 // For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. 2122 template <class D, HWY_IF_T_SIZE_D(D, 1)> 2123 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { 2124 return svdup_n_u8_z(m, 1); 2125 } 2126 template <class D, HWY_IF_T_SIZE_D(D, 2)> 2127 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { 2128 const ScalableTag<uint8_t> d8; 2129 const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); 2130 return detail::ConcatEvenFull(b16, b16); // lower half 2131 } 2132 template <class D, HWY_IF_T_SIZE_D(D, 4)> 2133 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { 2134 return U8FromU32(svdup_n_u32_z(m, 1)); 2135 } 2136 template <class D, HWY_IF_T_SIZE_D(D, 8)> 2137 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { 2138 const ScalableTag<uint32_t> d32; 2139 const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); 2140 return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half 2141 } 2142 2143 // Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. 2144 HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { 2145 const ScalableTag<uint8_t> d8; 2146 const ScalableTag<uint16_t> d16; 2147 const ScalableTag<uint32_t> d32; 2148 const ScalableTag<uint64_t> d64; 2149 // TODO(janwas): could use SVE2 BDEP, but it's optional. 2150 x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); 2151 x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); 2152 x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); 2153 return BitCast(d64, x); 2154 } 2155 2156 } // namespace detail 2157 2158 // BitsFromMask is required if `HWY_MAX_BYTES <= 64`, which is true for the 2159 // fixed-size SVE targets. 2160 #if HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 2161 template <class D> 2162 HWY_API uint64_t BitsFromMask(D d, svbool_t mask) { 2163 const Repartition<uint64_t, D> du64; 2164 svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask<D>(mask)); 2165 2166 constexpr size_t N = MaxLanes(d); 2167 static_assert(N < 64, "SVE2_128 and SVE_256 are only 128 or 256 bits"); 2168 const uint64_t valid = (1ull << N) - 1; 2169 HWY_IF_CONSTEXPR(N <= 8) { 2170 // Upper bits are undefined even if N == 8, hence mask. 2171 return GetLane(bits_in_u64) & valid; 2172 } 2173 2174 // Up to 8 of the least-significant bits of each u64 lane are valid. 2175 bits_in_u64 = detail::AndN(bits_in_u64, 0xFF); 2176 2177 // 128-bit vector: only two u64, so avoid ReduceSum. 2178 HWY_IF_CONSTEXPR(HWY_TARGET == HWY_SVE2_128) { 2179 alignas(16) uint64_t lanes[2]; 2180 Store(bits_in_u64, du64, lanes); 2181 // lanes[0] is always valid because we know N > 8, but lanes[1] might 2182 // not be - we may mask it out below. 
2183 const uint64_t result = lanes[0] + (lanes[1] << 8); 2184 // 8-bit lanes, no further masking 2185 HWY_IF_CONSTEXPR(N == 16) return result; 2186 return result & valid; 2187 } 2188 2189 // Shift the 8-bit groups into place in each u64 lane. 2190 alignas(32) uint64_t kShifts[4] = {0 * 8, 1 * 8, 2 * 8, 3 * 8}; 2191 bits_in_u64 = Shl(bits_in_u64, Load(du64, kShifts)); 2192 return ReduceSum(du64, bits_in_u64) & valid; 2193 } 2194 2195 #endif // HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 2196 2197 // ------------------------------ IsNegative (Lt) 2198 #ifdef HWY_NATIVE_IS_NEGATIVE 2199 #undef HWY_NATIVE_IS_NEGATIVE 2200 #else 2201 #define HWY_NATIVE_IS_NEGATIVE 2202 #endif 2203 2204 template <class V, HWY_IF_NOT_UNSIGNED_V(V)> 2205 HWY_API svbool_t IsNegative(V v) { 2206 const DFromV<decltype(v)> d; 2207 const RebindToSigned<decltype(d)> di; 2208 using TI = TFromD<decltype(di)>; 2209 2210 return detail::LtN(BitCast(di, v), static_cast<TI>(0)); 2211 } 2212 2213 // ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) 2214 2215 #if HWY_SVE_HAVE_2 2216 2217 #define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ 2218 HWY_API HWY_SVE_V(BASE, BITS) \ 2219 NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ 2220 HWY_SVE_V(BASE, BITS) no) { \ 2221 return sv##OP##_##CHAR##BITS(yes, no, mask); \ 2222 } 2223 2224 HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) 2225 #undef HWY_SVE_IF_VEC 2226 2227 template <class V, HWY_IF_FLOAT_V(V)> 2228 HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { 2229 const DFromV<V> d; 2230 const RebindToUnsigned<decltype(d)> du; 2231 return BitCast( 2232 d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); 2233 } 2234 2235 #else 2236 2237 template <class V> 2238 HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { 2239 return Or(And(mask, yes), AndNot(mask, no)); 2240 } 2241 2242 #endif // HWY_SVE_HAVE_2 2243 2244 // ------------------------------ BitwiseIfThenElse 2245 2246 #ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE 2247 #undef HWY_NATIVE_BITWISE_IF_THEN_ELSE 2248 #else 2249 #define HWY_NATIVE_BITWISE_IF_THEN_ELSE 2250 #endif 2251 2252 template <class V> 2253 HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { 2254 return IfVecThenElse(mask, yes, no); 2255 } 2256 2257 // ------------------------------ CopySign (BitwiseIfThenElse) 2258 template <class V> 2259 HWY_API V CopySign(const V magn, const V sign) { 2260 const DFromV<decltype(magn)> d; 2261 return BitwiseIfThenElse(SignBit(d), sign, magn); 2262 } 2263 2264 // ------------------------------ CopySignToAbs 2265 template <class V> 2266 HWY_API V CopySignToAbs(const V abs, const V sign) { 2267 #if HWY_SVE_HAVE_2 // CopySign is more efficient than OrAnd 2268 return CopySign(abs, sign); 2269 #else 2270 const DFromV<V> d; 2271 return OrAnd(abs, SignBit(d), sign); 2272 #endif 2273 } 2274 2275 // ------------------------------ Floating-point classification (Ne) 2276 2277 template <class V> 2278 HWY_API svbool_t IsNaN(const V v) { 2279 return Ne(v, v); // could also use cmpuo 2280 } 2281 2282 // Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. 2283 // We use a fused Set/comparison for IsFinite. 
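// The Add(vu, vu) in IsInf below shifts out the sign bit, so a single compare
// handles both +inf and -inf. Worked f32 example (standard IEEE-754 bit
// patterns):
//   +inf = 0x7F800000, doubled -> 0xFF000000, the MaxExponentTimes2 pattern
//   -inf = 0xFF800000, doubled -> 0xFF000000 (the carry out is discarded)
//   NaN  = 0x7FC00000, doubled -> 0xFF800000: no match, mantissa is nonzero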
2284 #ifdef HWY_NATIVE_ISINF 2285 #undef HWY_NATIVE_ISINF 2286 #else 2287 #define HWY_NATIVE_ISINF 2288 #endif 2289 2290 template <class V> 2291 HWY_API svbool_t IsInf(const V v) { 2292 using T = TFromV<V>; 2293 const DFromV<decltype(v)> d; 2294 const RebindToUnsigned<decltype(d)> du; 2295 const RebindToSigned<decltype(d)> di; 2296 2297 // 'Shift left' to clear the sign bit 2298 const VFromD<decltype(du)> vu = BitCast(du, v); 2299 const VFromD<decltype(du)> v2 = Add(vu, vu); 2300 // Check for exponent=max and mantissa=0. 2301 const VFromD<decltype(di)> max2 = Set(di, hwy::MaxExponentTimes2<T>()); 2302 return RebindMask(d, Eq(v2, BitCast(du, max2))); 2303 } 2304 2305 // Returns whether normal/subnormal/zero. 2306 template <class V> 2307 HWY_API svbool_t IsFinite(const V v) { 2308 using T = TFromV<V>; 2309 const DFromV<decltype(v)> d; 2310 const RebindToUnsigned<decltype(d)> du; 2311 const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison 2312 const VFromD<decltype(du)> vu = BitCast(du, v); 2313 // 'Shift left' to clear the sign bit, then right so we can compare with the 2314 // max exponent (cannot compare with MaxExponentTimes2 directly because it is 2315 // negative and non-negative floats would be greater). 2316 const VFromD<decltype(di)> exp = 2317 BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); 2318 return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>())); 2319 } 2320 2321 // ------------------------------ MulByPow2/MulByFloorPow2 2322 2323 #define HWY_SVE_MUL_BY_POW2(BASE, CHAR, BITS, HALF, NAME, OP) \ 2324 HWY_API HWY_SVE_V(BASE, BITS) \ 2325 NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(int, BITS) exp) { \ 2326 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, exp); \ 2327 } 2328 2329 HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale) 2330 2331 #undef HWY_SVE_MUL_BY_POW2 2332 2333 // ------------------------------ MaskedEq etc. 
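// These predicated compares are zeroing: lanes outside m compare as false,
// e.g. MaskedEq(m, a, b) is true only where m is set and a == b.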
2334 #ifdef HWY_NATIVE_MASKED_COMP 2335 #undef HWY_NATIVE_MASKED_COMP 2336 #else 2337 #define HWY_NATIVE_MASKED_COMP 2338 #endif 2339 2340 // mask = f(mask, vector, vector) 2341 #define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ 2342 HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ 2343 HWY_SVE_V(BASE, BITS) b) { \ 2344 return sv##OP##_##CHAR##BITS(m, a, b); \ 2345 } 2346 2347 HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq) 2348 HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne) 2349 HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt) 2350 HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple) 2351 2352 #undef HWY_SVE_COMPARE_Z 2353 2354 template <class V, class M, class D = DFromV<V>> 2355 HWY_API MFromD<D> MaskedGt(M m, V a, V b) { 2356 // Swap args to reverse comparison 2357 return MaskedLt(m, b, a); 2358 } 2359 2360 template <class V, class M, class D = DFromV<V>> 2361 HWY_API MFromD<D> MaskedGe(M m, V a, V b) { 2362 // Swap args to reverse comparison 2363 return MaskedLe(m, b, a); 2364 } 2365 2366 template <class V, class M, class D = DFromV<V>> 2367 HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) { 2368 return MaskedNe(m, v, v); 2369 } 2370 2371 // ================================================== MEMORY 2372 2373 // ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream 2374 2375 #define HWY_SVE_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \ 2376 template <size_t N, int kPow2> \ 2377 HWY_API HWY_SVE_V(BASE, BITS) \ 2378 LoadU(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2379 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ 2380 return svld1_##CHAR##BITS(detail::MakeMask(d), \ 2381 detail::NativeLanePointer(p)); \ 2382 } \ 2383 template <size_t N, int kPow2> \ 2384 HWY_API HWY_SVE_V(BASE, BITS) \ 2385 MaskedLoad(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ 2386 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ 2387 return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \ 2388 } \ 2389 template <size_t N, int kPow2> \ 2390 HWY_API void StoreU(HWY_SVE_V(BASE, BITS) v, \ 2391 HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2392 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ 2393 svst1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), v); \ 2394 } \ 2395 template <size_t N, int kPow2> \ 2396 HWY_API void Stream(HWY_SVE_V(BASE, BITS) v, \ 2397 HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2398 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ 2399 svstnt1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \ 2400 v); \ 2401 } \ 2402 template <size_t N, int kPow2> \ 2403 HWY_API void BlendedStore(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ 2404 HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ 2405 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ 2406 svst1_##CHAR##BITS(m, detail::NativeLanePointer(p), v); \ 2407 } 2408 2409 HWY_SVE_FOREACH(HWY_SVE_MEM, _, _) 2410 HWY_SVE_FOREACH_BF16(HWY_SVE_MEM, _, _) 2411 2412 template <class D, HWY_SVE_IF_EMULATED_D(D)> 2413 HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) { 2414 const RebindToUnsigned<decltype(d)> du; 2415 return BitCast(d, LoadU(du, detail::U16LanePointer(p))); 2416 } 2417 2418 template <class D, HWY_SVE_IF_EMULATED_D(D)> 2419 HWY_API void StoreU(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT p) { 2420 const RebindToUnsigned<decltype(d)> du; 2421 StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); 2422 } 2423 2424 template <class D, HWY_SVE_IF_EMULATED_D(D)> 2425 HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d, 2426 const TFromD<D>* HWY_RESTRICT p) { 2427 const RebindToUnsigned<decltype(d)> du; 
2428 return BitCast(d,
2429 MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p)));
2430 }
2431
2432 // MaskedLoadOr is generic and does not require emulation.
2433
2434 template <class D, HWY_SVE_IF_EMULATED_D(D)>
2435 HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
2436 TFromD<D>* HWY_RESTRICT p) {
2437 const RebindToUnsigned<decltype(d)> du;
2438 BlendedStore(BitCast(du, v), RebindMask(du, m), du,
2439 detail::U16LanePointer(p));
2440 }
2441
2442 #undef HWY_SVE_MEM
2443
2444 #if HWY_TARGET != HWY_SVE2_128
2445 namespace detail {
2446 #define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \
2447 template <size_t N, int kPow2> \
2448 HWY_API HWY_SVE_V(BASE, BITS) \
2449 NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
2450 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
2451 /* All-true predicate to load all 128 bits. */ \
2452 return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), \
2453 detail::NativeLanePointer(p)); \
2454 }
2455
2456 HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq)
2457 HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq)
2458
2459 template <class D, HWY_SVE_IF_EMULATED_D(D)>
2460 HWY_API VFromD<D> LoadDupFull128(D d, const TFromD<D>* HWY_RESTRICT p) {
2461 const RebindToUnsigned<decltype(d)> du;
2462 return BitCast(d, LoadDupFull128(du, detail::U16LanePointer(p)));
2463 }
2464
2465 } // namespace detail
2466 #endif // HWY_TARGET != HWY_SVE2_128
2467
2468 #if HWY_TARGET == HWY_SVE2_128
2469 // On the HWY_SVE2_128 target, LoadDup128 is the same as LoadU because
2470 // vectors cannot exceed 16 bytes there.
2471 template <class D>
2472 HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
2473 return LoadU(d, p);
2474 }
2475 #else // HWY_TARGET != HWY_SVE2_128
2476 // If D().MaxBytes() <= 16, simply do a LoadU operation.
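// Hypothetical usage, broadcasting one 128-bit pattern to every block of the
// vector (kPat is illustrative, not part of this header):
//   alignas(16) static const float kPat[4] = {0.f, 1.f, 2.f, 3.f};
//   const svfloat32_t v = LoadDup128(ScalableTag<float>(), kPat);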
2477 template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
2478 HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
2479 return LoadU(d, p);
2480 }
2481
2482 // If D().MaxBytes() > 16, the vector must be loaded via ld1rq.
2483 template <class D, HWY_IF_V_SIZE_GT_D(D, 16)>
2484 HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
2485 return detail::LoadDupFull128(d, p);
2486 }
2487
2488 #endif // HWY_TARGET != HWY_SVE2_128
2489
2490 // Truncate to smaller size and store
2491 #ifdef HWY_NATIVE_STORE_TRUNCATED
2492 #undef HWY_NATIVE_STORE_TRUNCATED
2493 #else
2494 #define HWY_NATIVE_STORE_TRUNCATED
2495 #endif
2496
2497 #define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
2498 template <size_t N, int kPow2> \
2499 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
2500 const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
2501 HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
2502 sv##OP##_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \
2503 v); \
2504 }
2505
2506 #define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \
2507 HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8)
2508 #define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \
2509 HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16)
2510 #define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \
2511 HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32)
2512
2513 HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
2514 HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
2515 HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
2516 HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
2517 HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
2518 HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w)
2519
2520 #undef HWY_SVE_STORE_TRUNCATED
2521
2522 // ------------------------------ Load/Store
2523
2524 // SVE only requires lane alignment, not natural alignment of the entire
2525 // vector, so Load/Store are the same as LoadU/StoreU.
2526 template <class D>
2527 HWY_API VFromD<D> Load(D d, const TFromD<D>* HWY_RESTRICT p) {
2528 return LoadU(d, p);
2529 }
2530
2531 template <class V, class D>
2532 HWY_API void Store(const V v, D d, TFromD<D>* HWY_RESTRICT p) {
2533 StoreU(v, d, p);
2534 }
2535
2536 // ------------------------------ MaskedLoadOr
2537
2538 // SVE MaskedLoad hard-codes zero, so this requires an extra blend.
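// A common use is loop remainders, where inactive lanes need a value other
// than zero. Sketch (identifiers are illustrative):
//   const svbool_t m = FirstN(d, count - i);
//   const auto v = MaskedLoadOr(Set(d, 1.0f), m, d, p + i);  // pad with 1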
2539 template <class D> 2540 HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d, 2541 const TFromD<D>* HWY_RESTRICT p) { 2542 return IfThenElse(m, MaskedLoad(m, d, p), v); 2543 } 2544 2545 // ------------------------------ ScatterOffset/Index 2546 2547 #ifdef HWY_NATIVE_SCATTER 2548 #undef HWY_NATIVE_SCATTER 2549 #else 2550 #define HWY_NATIVE_SCATTER 2551 #endif 2552 2553 #define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ 2554 template <size_t N, int kPow2> \ 2555 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ 2556 HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2557 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ 2558 HWY_SVE_V(int, BITS) offset) { \ 2559 sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ 2560 v); \ 2561 } 2562 2563 #define HWY_SVE_MASKED_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ 2564 template <size_t N, int kPow2> \ 2565 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ 2566 HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, \ 2567 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ 2568 HWY_SVE_V(int, BITS) indices) { \ 2569 sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices, v); \ 2570 } 2571 2572 HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) 2573 HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_SCATTER_INDEX, MaskedScatterIndex, 2574 st1_scatter) 2575 #undef HWY_SVE_SCATTER_OFFSET 2576 #undef HWY_SVE_MASKED_SCATTER_INDEX 2577 2578 template <class D> 2579 HWY_API void ScatterIndex(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT p, 2580 VFromD<RebindToSigned<D>> indices) { 2581 MaskedScatterIndex(v, detail::MakeMask(d), d, p, indices); 2582 } 2583 2584 // ------------------------------ GatherOffset/Index 2585 2586 #ifdef HWY_NATIVE_GATHER 2587 #undef HWY_NATIVE_GATHER 2588 #else 2589 #define HWY_NATIVE_GATHER 2590 #endif 2591 2592 #define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ 2593 template <size_t N, int kPow2> \ 2594 HWY_API HWY_SVE_V(BASE, BITS) \ 2595 NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2596 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ 2597 HWY_SVE_V(int, BITS) offset) { \ 2598 return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ 2599 offset); \ 2600 } 2601 #define HWY_SVE_MASKED_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ 2602 template <size_t N, int kPow2> \ 2603 HWY_API HWY_SVE_V(BASE, BITS) \ 2604 NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2605 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ 2606 HWY_SVE_V(int, BITS) indices) { \ 2607 const RebindToSigned<decltype(d)> di; \ 2608 (void)di; /* for HWY_DASSERT */ \ 2609 HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ 2610 return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices); \ 2611 } 2612 2613 HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) 2614 HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_GATHER_INDEX, MaskedGatherIndex, 2615 ld1_gather) 2616 #undef HWY_SVE_GATHER_OFFSET 2617 #undef HWY_SVE_MASKED_GATHER_INDEX 2618 2619 template <class D> 2620 HWY_API VFromD<D> MaskedGatherIndexOr(VFromD<D> no, svbool_t m, D d, 2621 const TFromD<D>* HWY_RESTRICT p, 2622 VFromD<RebindToSigned<D>> indices) { 2623 return IfThenElse(m, MaskedGatherIndex(m, d, p, indices), no); 2624 } 2625 2626 template <class D> 2627 HWY_API VFromD<D> GatherIndex(D d, const TFromD<D>* HWY_RESTRICT p, 2628 VFromD<RebindToSigned<D>> indices) { 2629 return MaskedGatherIndex(detail::MakeMask(d), d, p, indices); 2630 } 2631 2632 // ------------------------------ LoadInterleaved2 2633 2634 // Per-target flag 
to prevent generic_ops-inl.h from defining LoadInterleaved2. 2635 #ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED 2636 #undef HWY_NATIVE_LOAD_STORE_INTERLEAVED 2637 #else 2638 #define HWY_NATIVE_LOAD_STORE_INTERLEAVED 2639 #endif 2640 2641 #define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ 2642 template <size_t N, int kPow2> \ 2643 HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2644 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ 2645 HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ 2646 const HWY_SVE_TUPLE(BASE, BITS, 2) tuple = sv##OP##_##CHAR##BITS( \ 2647 detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ 2648 v0 = svget2(tuple, 0); \ 2649 v1 = svget2(tuple, 1); \ 2650 } 2651 HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) 2652 HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD2, LoadInterleaved2, ld2) 2653 2654 #undef HWY_SVE_LOAD2 2655 2656 // ------------------------------ LoadInterleaved3 2657 2658 #define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ 2659 template <size_t N, int kPow2> \ 2660 HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2661 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ 2662 HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ 2663 HWY_SVE_V(BASE, BITS) & v2) { \ 2664 const HWY_SVE_TUPLE(BASE, BITS, 3) tuple = sv##OP##_##CHAR##BITS( \ 2665 detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ 2666 v0 = svget3(tuple, 0); \ 2667 v1 = svget3(tuple, 1); \ 2668 v2 = svget3(tuple, 2); \ 2669 } 2670 HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) 2671 HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD3, LoadInterleaved3, ld3) 2672 2673 #undef HWY_SVE_LOAD3 2674 2675 // ------------------------------ LoadInterleaved4 2676 2677 #define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ 2678 template <size_t N, int kPow2> \ 2679 HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2680 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ 2681 HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ 2682 HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ 2683 const HWY_SVE_TUPLE(BASE, BITS, 4) tuple = sv##OP##_##CHAR##BITS( \ 2684 detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ 2685 v0 = svget4(tuple, 0); \ 2686 v1 = svget4(tuple, 1); \ 2687 v2 = svget4(tuple, 2); \ 2688 v3 = svget4(tuple, 3); \ 2689 } 2690 HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) 2691 HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD4, LoadInterleaved4, ld4) 2692 2693 #undef HWY_SVE_LOAD4 2694 2695 // ------------------------------ StoreInterleaved2 2696 2697 #define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ 2698 template <size_t N, int kPow2> \ 2699 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ 2700 HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2701 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ 2702 sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ 2703 detail::NativeLanePointer(unaligned), \ 2704 Create2(d, v0, v1)); \ 2705 } 2706 HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) 2707 HWY_SVE_FOREACH_BF16(HWY_SVE_STORE2, StoreInterleaved2, st2) 2708 2709 #undef HWY_SVE_STORE2 2710 2711 // ------------------------------ StoreInterleaved3 2712 2713 #define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ 2714 template <size_t N, int kPow2> \ 2715 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ 2716 HWY_SVE_V(BASE, BITS) v2, \ 2717 HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2718 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ 2719 sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ 
2720 detail::NativeLanePointer(unaligned), \ 2721 Create3(d, v0, v1, v2)); \ 2722 } 2723 HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) 2724 HWY_SVE_FOREACH_BF16(HWY_SVE_STORE3, StoreInterleaved3, st3) 2725 2726 #undef HWY_SVE_STORE3 2727 2728 // ------------------------------ StoreInterleaved4 2729 2730 #define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ 2731 template <size_t N, int kPow2> \ 2732 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ 2733 HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ 2734 HWY_SVE_D(BASE, BITS, N, kPow2) d, \ 2735 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ 2736 sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ 2737 detail::NativeLanePointer(unaligned), \ 2738 Create4(d, v0, v1, v2, v3)); \ 2739 } 2740 HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) 2741 HWY_SVE_FOREACH_BF16(HWY_SVE_STORE4, StoreInterleaved4, st4) 2742 2743 #undef HWY_SVE_STORE4 2744 2745 // Fall back on generic Load/StoreInterleaved[234] for any emulated types. 2746 // Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_SVE_IF_EMULATED_D. 2747 2748 // ================================================== CONVERT 2749 2750 // ------------------------------ PromoteTo 2751 2752 // Same sign 2753 #define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ 2754 template <size_t N, int kPow2> \ 2755 HWY_API HWY_SVE_V(BASE, BITS) NAME( \ 2756 HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ 2757 return sv##OP##_##CHAR##BITS(v); \ 2758 } 2759 2760 HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) 2761 HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) 2762 HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) 2763 2764 // 2x 2765 template <size_t N, int kPow2> 2766 HWY_API svuint32_t PromoteTo(Simd<uint32_t, N, kPow2> dto, svuint8_t vfrom) { 2767 const RepartitionToWide<DFromV<decltype(vfrom)>> d2; 2768 return PromoteTo(dto, PromoteTo(d2, vfrom)); 2769 } 2770 template <size_t N, int kPow2> 2771 HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svint8_t vfrom) { 2772 const RepartitionToWide<DFromV<decltype(vfrom)>> d2; 2773 return PromoteTo(dto, PromoteTo(d2, vfrom)); 2774 } 2775 template <size_t N, int kPow2> 2776 HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> dto, svuint16_t vfrom) { 2777 const RepartitionToWide<DFromV<decltype(vfrom)>> d2; 2778 return PromoteTo(dto, PromoteTo(d2, vfrom)); 2779 } 2780 template <size_t N, int kPow2> 2781 HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> dto, svint16_t vfrom) { 2782 const RepartitionToWide<DFromV<decltype(vfrom)>> d2; 2783 return PromoteTo(dto, PromoteTo(d2, vfrom)); 2784 } 2785 2786 // 3x 2787 template <size_t N, int kPow2> 2788 HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> dto, svuint8_t vfrom) { 2789 const RepartitionToNarrow<decltype(dto)> d4; 2790 const RepartitionToNarrow<decltype(d4)> d2; 2791 return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); 2792 } 2793 template <size_t N, int kPow2> 2794 HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> dto, svint8_t vfrom) { 2795 const RepartitionToNarrow<decltype(dto)> d4; 2796 const RepartitionToNarrow<decltype(d4)> d2; 2797 return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); 2798 } 2799 2800 // Sign change 2801 template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V), 2802 HWY_IF_LANES_GT(sizeof(TFromD<D>), sizeof(TFromV<V>))> 2803 HWY_API VFromD<D> PromoteTo(D di, V v) { 2804 const RebindToUnsigned<decltype(di)> du; 2805 return BitCast(di, 
PromoteTo(du, v)); 2806 } 2807 2808 // ------------------------------ PromoteTo F 2809 2810 // Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. 2811 #ifdef HWY_NATIVE_F16C 2812 #undef HWY_NATIVE_F16C 2813 #else 2814 #define HWY_NATIVE_F16C 2815 #endif 2816 2817 // Unlike Highway's ZipLower, this returns the same type. 2818 namespace detail { 2819 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) 2820 } // namespace detail 2821 2822 template <size_t N, int kPow2> 2823 HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> /* d */, 2824 const svfloat16_t v) { 2825 // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so 2826 // first replicate each lane once. 2827 const svfloat16_t vv = detail::ZipLowerSame(v, v); 2828 return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), vv); 2829 } 2830 2831 #ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 2832 #undef HWY_NATIVE_PROMOTE_F16_TO_F64 2833 #else 2834 #define HWY_NATIVE_PROMOTE_F16_TO_F64 2835 #endif 2836 2837 template <size_t N, int kPow2> 2838 HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, 2839 const svfloat16_t v) { 2840 // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so 2841 // first replicate each lane once. 2842 const svfloat16_t vv = detail::ZipLowerSame(v, v); 2843 return svcvt_f64_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), 2844 detail::ZipLowerSame(vv, vv)); 2845 } 2846 2847 template <size_t N, int kPow2> 2848 HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, 2849 const svfloat32_t v) { 2850 const svfloat32_t vv = detail::ZipLowerSame(v, v); 2851 return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N, kPow2>()), vv); 2852 } 2853 2854 template <size_t N, int kPow2> 2855 HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, 2856 const svint32_t v) { 2857 const svint32_t vv = detail::ZipLowerSame(v, v); 2858 return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N, kPow2>()), vv); 2859 } 2860 2861 template <size_t N, int kPow2> 2862 HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, 2863 const svuint32_t v) { 2864 const svuint32_t vv = detail::ZipLowerSame(v, v); 2865 return svcvt_f64_u32_x(detail::PTrue(Simd<uint32_t, N, kPow2>()), vv); 2866 } 2867 2868 template <size_t N, int kPow2> 2869 HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> /* d */, 2870 const svfloat32_t v) { 2871 const svfloat32_t vv = detail::ZipLowerSame(v, v); 2872 return svcvt_s64_f32_x(detail::PTrue(Simd<float, N, kPow2>()), vv); 2873 } 2874 2875 template <size_t N, int kPow2> 2876 HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> /* d */, 2877 const svfloat32_t v) { 2878 const svfloat32_t vv = detail::ZipLowerSame(v, v); 2879 return svcvt_u64_f32_x(detail::PTrue(Simd<float, N, kPow2>()), vv); 2880 } 2881 2882 // ------------------------------ PromoteUpperTo 2883 2884 namespace detail { 2885 HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) 2886 HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) 2887 HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) 2888 #undef HWY_SVE_PROMOTE_TO 2889 } // namespace detail 2890 2891 #ifdef HWY_NATIVE_PROMOTE_UPPER_TO 2892 #undef HWY_NATIVE_PROMOTE_UPPER_TO 2893 #else 2894 #define HWY_NATIVE_PROMOTE_UPPER_TO 2895 #endif 2896 2897 // Unsigned->Unsigned or Signed->Signed 2898 template <class D, class V, typename TD = TFromD<D>, typename TV = TFromV<V>, 2899 hwy::EnableIf<IsInteger<TD>() && IsInteger<TV>() && 2900 (IsSigned<TD>() == 
IsSigned<TV>())>* = nullptr>
2901 HWY_API VFromD<D> PromoteUpperTo(D d, V v) {
2902 if (detail::IsFull(d)) {
2903 return detail::PromoteUpperTo(d, v);
2904 }
2905 const Rebind<TFromV<V>, decltype(d)> dh;
2906 return PromoteTo(d, UpperHalf(dh, v));
2907 }
2908
2909 // Differing signs or either is float
2910 template <class D, class V, typename TD = TFromD<D>, typename TV = TFromV<V>,
2911 hwy::EnableIf<!IsInteger<TD>() || !IsInteger<TV>() ||
2912 (IsSigned<TD>() != IsSigned<TV>())>* = nullptr>
2913 HWY_API VFromD<D> PromoteUpperTo(D d, V v) {
2914 // Lanes(d) may differ from Lanes(DFromV<V>()). Use the lane type from V
2915 // because it cannot be deduced from D (could be either bf16 or f16).
2916 const Rebind<TFromV<V>, decltype(d)> dh;
2917 return PromoteTo(d, UpperHalf(dh, v));
2918 }
2919
2920 // ------------------------------ DemoteTo U
2921
2922 namespace detail {
2923
2924 // Saturates unsigned vectors to half/quarter-width TN.
2925 template <typename TN, class VU>
2926 VU SaturateU(VU v) {
2927 return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>()));
2928 }
2929
2930 // Saturates signed vectors to half/quarter-width TN.
2931 template <typename TN, class VI>
2932 VI SaturateI(VI v) {
2933 return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>());
2934 }
2935
2936 } // namespace detail
2937
2938 template <size_t N, int kPow2>
2939 HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint16_t v) {
2940 #if HWY_SVE_HAVE_2
2941 const svuint8_t vn = BitCast(dn, svqxtunb_s16(v));
2942 #else
2943 const DFromV<decltype(v)> di;
2944 const RebindToUnsigned<decltype(di)> du;
2945 using TN = TFromD<decltype(dn)>;
2946 // First clamp negative numbers to zero and cast to unsigned.
2947 const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0));
2948 // Saturate to unsigned-max and halve the width.
2949 const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
2950 #endif
2951 return svuzp1_u8(vn, vn);
2952 }
2953
2954 template <size_t N, int kPow2>
2955 HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint32_t v) {
2956 #if HWY_SVE_HAVE_2
2957 const svuint16_t vn = BitCast(dn, svqxtunb_s32(v));
2958 #else
2959 const DFromV<decltype(v)> di;
2960 const RebindToUnsigned<decltype(di)> du;
2961 using TN = TFromD<decltype(dn)>;
2962 // First clamp negative numbers to zero and cast to unsigned.
2963 const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0));
2964 // Saturate to unsigned-max and halve the width.
2965 const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
2966 #endif
2967 return svuzp1_u16(vn, vn);
2968 }
2969
2970 template <size_t N, int kPow2>
2971 HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint32_t v) {
2972 const DFromV<decltype(v)> di;
2973 const RebindToUnsigned<decltype(di)> du;
2974 const RepartitionToNarrow<decltype(du)> d2;
2975 #if HWY_SVE_HAVE_2
2976 const svuint16_t cast16 = BitCast(d2, svqxtnb_u16(svqxtunb_s32(v)));
2977 #else
2978 using TN = TFromD<decltype(dn)>;
2979 // First clamp negative numbers to zero and cast to unsigned.
2980 const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0));
2981 // Saturate to unsigned-max and quarter the width.
2982 const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped)); 2983 #endif 2984 const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); 2985 return svuzp1_u8(x2, x2); 2986 } 2987 2988 template <size_t N, int kPow2> 2989 HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint16_t v) { 2990 #if HWY_SVE_HAVE_2 2991 const svuint8_t vn = BitCast(dn, svqxtnb_u16(v)); 2992 #else 2993 using TN = TFromD<decltype(dn)>; 2994 const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(v)); 2995 #endif 2996 return svuzp1_u8(vn, vn); 2997 } 2998 2999 template <size_t N, int kPow2> 3000 HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svuint32_t v) { 3001 #if HWY_SVE_HAVE_2 3002 const svuint16_t vn = BitCast(dn, svqxtnb_u32(v)); 3003 #else 3004 using TN = TFromD<decltype(dn)>; 3005 const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(v)); 3006 #endif 3007 return svuzp1_u16(vn, vn); 3008 } 3009 3010 template <size_t N, int kPow2> 3011 HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint32_t v) { 3012 using TN = TFromD<decltype(dn)>; 3013 return U8FromU32(detail::SaturateU<TN>(v)); 3014 } 3015 3016 // ------------------------------ Truncations 3017 3018 template <size_t N, int kPow2> 3019 HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, 3020 const svuint64_t v) { 3021 const DFromV<svuint8_t> d; 3022 const svuint8_t v1 = BitCast(d, v); 3023 const svuint8_t v2 = svuzp1_u8(v1, v1); 3024 const svuint8_t v3 = svuzp1_u8(v2, v2); 3025 return svuzp1_u8(v3, v3); 3026 } 3027 3028 template <size_t N, int kPow2> 3029 HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */, 3030 const svuint64_t v) { 3031 const DFromV<svuint16_t> d; 3032 const svuint16_t v1 = BitCast(d, v); 3033 const svuint16_t v2 = svuzp1_u16(v1, v1); 3034 return svuzp1_u16(v2, v2); 3035 } 3036 3037 template <size_t N, int kPow2> 3038 HWY_API svuint32_t TruncateTo(Simd<uint32_t, N, kPow2> /* tag */, 3039 const svuint64_t v) { 3040 const DFromV<svuint32_t> d; 3041 const svuint32_t v1 = BitCast(d, v); 3042 return svuzp1_u32(v1, v1); 3043 } 3044 3045 template <size_t N, int kPow2> 3046 HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, 3047 const svuint32_t v) { 3048 const DFromV<svuint8_t> d; 3049 const svuint8_t v1 = BitCast(d, v); 3050 const svuint8_t v2 = svuzp1_u8(v1, v1); 3051 return svuzp1_u8(v2, v2); 3052 } 3053 3054 template <size_t N, int kPow2> 3055 HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */, 3056 const svuint32_t v) { 3057 const DFromV<svuint16_t> d; 3058 const svuint16_t v1 = BitCast(d, v); 3059 return svuzp1_u16(v1, v1); 3060 } 3061 3062 template <size_t N, int kPow2> 3063 HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, 3064 const svuint16_t v) { 3065 const DFromV<svuint8_t> d; 3066 const svuint8_t v1 = BitCast(d, v); 3067 return svuzp1_u8(v1, v1); 3068 } 3069 3070 // ------------------------------ DemoteTo I 3071 3072 template <size_t N, int kPow2> 3073 HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint16_t v) { 3074 #if HWY_SVE_HAVE_2 3075 const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); 3076 #else 3077 using TN = TFromD<decltype(dn)>; 3078 const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v)); 3079 #endif 3080 return svuzp1_s8(vn, vn); 3081 } 3082 3083 template <size_t N, int kPow2> 3084 HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint32_t v) { 3085 #if HWY_SVE_HAVE_2 3086 const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); 3087 #else 3088 using TN = TFromD<decltype(dn)>; 
3089 const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v)); 3090 #endif 3091 return svuzp1_s16(vn, vn); 3092 } 3093 3094 template <size_t N, int kPow2> 3095 HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint32_t v) { 3096 const RepartitionToWide<decltype(dn)> d2; 3097 #if HWY_SVE_HAVE_2 3098 const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); 3099 #else 3100 using TN = TFromD<decltype(dn)>; 3101 const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v)); 3102 #endif 3103 const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); 3104 return BitCast(dn, svuzp1_s8(v2, v2)); 3105 } 3106 3107 // ------------------------------ I64/U64 DemoteTo 3108 3109 template <size_t N, int kPow2> 3110 HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> dn, const svint64_t v) { 3111 const Rebind<uint64_t, decltype(dn)> du64; 3112 const RebindToUnsigned<decltype(dn)> dn_u; 3113 #if HWY_SVE_HAVE_2 3114 const svuint64_t vn = BitCast(du64, svqxtnb_s64(v)); 3115 #else 3116 using TN = TFromD<decltype(dn)>; 3117 const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v)); 3118 #endif 3119 return BitCast(dn, TruncateTo(dn_u, vn)); 3120 } 3121 3122 template <size_t N, int kPow2> 3123 HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint64_t v) { 3124 const Rebind<uint64_t, decltype(dn)> du64; 3125 const RebindToUnsigned<decltype(dn)> dn_u; 3126 #if HWY_SVE_HAVE_2 3127 const svuint64_t vn = BitCast(du64, svqxtnb_s32(svqxtnb_s64(v))); 3128 #else 3129 using TN = TFromD<decltype(dn)>; 3130 const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v)); 3131 #endif 3132 return BitCast(dn, TruncateTo(dn_u, vn)); 3133 } 3134 3135 template <size_t N, int kPow2> 3136 HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint64_t v) { 3137 const Rebind<uint64_t, decltype(dn)> du64; 3138 const RebindToUnsigned<decltype(dn)> dn_u; 3139 using TN = TFromD<decltype(dn)>; 3140 const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v)); 3141 return BitCast(dn, TruncateTo(dn_u, vn)); 3142 } 3143 3144 template <size_t N, int kPow2> 3145 HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> dn, const svint64_t v) { 3146 const Rebind<uint64_t, decltype(dn)> du64; 3147 #if HWY_SVE_HAVE_2 3148 const svuint64_t vn = BitCast(du64, svqxtunb_s64(v)); 3149 #else 3150 using TN = TFromD<decltype(dn)>; 3151 // First clamp negative numbers to zero and cast to unsigned. 3152 const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); 3153 // Saturate to unsigned-max 3154 const svuint64_t vn = detail::SaturateU<TN>(clamped); 3155 #endif 3156 return TruncateTo(dn, vn); 3157 } 3158 3159 template <size_t N, int kPow2> 3160 HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint64_t v) { 3161 const Rebind<uint64_t, decltype(dn)> du64; 3162 #if HWY_SVE_HAVE_2 3163 const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtunb_s64(v))); 3164 #else 3165 using TN = TFromD<decltype(dn)>; 3166 // First clamp negative numbers to zero and cast to unsigned. 3167 const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); 3168 // Saturate to unsigned-max 3169 const svuint64_t vn = detail::SaturateU<TN>(clamped); 3170 #endif 3171 return TruncateTo(dn, vn); 3172 } 3173 3174 template <size_t N, int kPow2> 3175 HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint64_t v) { 3176 const Rebind<uint64_t, decltype(dn)> du64; 3177 using TN = TFromD<decltype(dn)>; 3178 // First clamp negative numbers to zero and cast to unsigned. 
3179 const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); 3180 // Saturate to unsigned-max 3181 const svuint64_t vn = detail::SaturateU<TN>(clamped); 3182 return TruncateTo(dn, vn); 3183 } 3184 3185 template <size_t N, int kPow2> 3186 HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> dn, const svuint64_t v) { 3187 const Rebind<uint64_t, decltype(dn)> du64; 3188 #if HWY_SVE_HAVE_2 3189 const svuint64_t vn = BitCast(du64, svqxtnb_u64(v)); 3190 #else 3191 using TN = TFromD<decltype(dn)>; 3192 const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v)); 3193 #endif 3194 return TruncateTo(dn, vn); 3195 } 3196 3197 template <size_t N, int kPow2> 3198 HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svuint64_t v) { 3199 const Rebind<uint64_t, decltype(dn)> du64; 3200 #if HWY_SVE_HAVE_2 3201 const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtnb_u64(v))); 3202 #else 3203 using TN = TFromD<decltype(dn)>; 3204 const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v)); 3205 #endif 3206 return TruncateTo(dn, vn); 3207 } 3208 3209 template <size_t N, int kPow2> 3210 HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint64_t v) { 3211 const Rebind<uint64_t, decltype(dn)> du64; 3212 using TN = TFromD<decltype(dn)>; 3213 const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v)); 3214 return TruncateTo(dn, vn); 3215 } 3216 3217 // ------------------------------ Unsigned to signed demotions 3218 3219 // Disable the default unsigned to signed DemoteTo/ReorderDemote2To 3220 // implementations in generic_ops-inl.h on SVE/SVE2 as the SVE/SVE2 targets have 3221 // target-specific implementations of the unsigned to signed DemoteTo and 3222 // ReorderDemote2To ops 3223 3224 // NOTE: hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr is used instead of 3225 // hwy::EnableIf<false>* = nullptr to avoid compiler errors since 3226 // !hwy::IsSame<V, V>() is always false and as !hwy::IsSame<V, V>() will cause 3227 // SFINAE to occur instead of a hard error due to a dependency on the V template 3228 // argument 3229 #undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V 3230 #define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ 3231 hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr 3232 3233 template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V), 3234 HWY_IF_T_SIZE_LE_D(D, sizeof(TFromV<V>) - 1)> 3235 HWY_API VFromD<D> DemoteTo(D dn, V v) { 3236 const RebindToUnsigned<D> dn_u; 3237 return BitCast(dn, TruncateTo(dn_u, detail::SaturateU<TFromD<D>>(v))); 3238 } 3239 3240 // ------------------------------ PromoteEvenTo/PromoteOddTo 3241 3242 // Signed to signed PromoteEvenTo: 1 instruction instead of 2 in generic-inl.h. 3243 // Might as well also enable unsigned to unsigned, though it is just an And. 3244 namespace detail { 3245 HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extb) 3246 HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, exth) 3247 HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extw) 3248 } // namespace detail 3249 3250 #include "hwy/ops/inside-inl.h" 3251 3252 // ------------------------------ DemoteTo F 3253 3254 // We already toggled HWY_NATIVE_F16C above. 
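// Usage sketch (illustrative only, not part of this header), assuming a
// caller inside HWY_NAMESPACE with hwy/highway.h included:
//   const ScalableTag<float> df32;
//   const Rebind<float16_t, decltype(df32)> df16;
//   const svfloat16_t vn = DemoteTo(df16, Set(df32, 1.5f));
//   // The f16 results are packed into the lower half of the vector.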
3255 3256 template <size_t N, int kPow2> 3257 HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat32_t v) { 3258 const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); 3259 return detail::ConcatEvenFull(in_even, 3260 in_even); // lower half 3261 } 3262 3263 #ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 3264 #undef HWY_NATIVE_DEMOTE_F64_TO_F16 3265 #else 3266 #define HWY_NATIVE_DEMOTE_F64_TO_F16 3267 #endif 3268 3269 template <size_t N, int kPow2> 3270 HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat64_t v) { 3271 const svfloat16_t in_lo16 = svcvt_f16_f64_x(detail::PTrue(d), v); 3272 const svfloat16_t in_even = detail::ConcatEvenFull(in_lo16, in_lo16); 3273 return detail::ConcatEvenFull(in_even, 3274 in_even); // lower half 3275 } 3276 3277 #ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 3278 #undef HWY_NATIVE_DEMOTE_F32_TO_BF16 3279 #else 3280 #define HWY_NATIVE_DEMOTE_F32_TO_BF16 3281 #endif 3282 3283 #if !HWY_SVE_HAVE_F32_TO_BF16C 3284 namespace detail { 3285 3286 // Round a F32 value to the nearest BF16 value, with the result returned as the 3287 // rounded F32 value bitcasted to an U32 3288 3289 // RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent 3290 // NaN F32 values from being converted to an infinity 3291 HWY_INLINE svuint32_t RoundF32ForDemoteToBF16(svfloat32_t v) { 3292 const DFromV<decltype(v)> df32; 3293 const RebindToUnsigned<decltype(df32)> du32; 3294 3295 const auto is_non_nan = Eq(v, v); 3296 const auto bits32 = BitCast(du32, v); 3297 3298 const auto round_incr = 3299 detail::AddN(detail::AndN(ShiftRight<16>(bits32), 1u), 0x7FFFu); 3300 return MaskedAddOr(detail::OrN(bits32, 0x00400000u), is_non_nan, bits32, 3301 round_incr); 3302 } 3303 3304 } // namespace detail 3305 #endif // !HWY_SVE_HAVE_F32_TO_BF16C 3306 3307 template <size_t N, int kPow2> 3308 HWY_API VBF16 DemoteTo(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t v) { 3309 #if HWY_SVE_HAVE_F32_TO_BF16C 3310 const VBF16 in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), v); 3311 return detail::ConcatEvenFull(in_even, in_even); 3312 #else 3313 const svuint16_t in_odd = 3314 BitCast(ScalableTag<uint16_t>(), detail::RoundF32ForDemoteToBF16(v)); 3315 return BitCast(dbf16, detail::ConcatOddFull(in_odd, in_odd)); // lower half 3316 #endif 3317 } 3318 3319 template <size_t N, int kPow2> 3320 HWY_API svfloat32_t DemoteTo(Simd<float32_t, N, kPow2> d, const svfloat64_t v) { 3321 const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); 3322 return detail::ConcatEvenFull(in_even, 3323 in_even); // lower half 3324 } 3325 3326 template <size_t N, int kPow2> 3327 HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> d, const svfloat64_t v) { 3328 const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); 3329 return detail::ConcatEvenFull(in_even, 3330 in_even); // lower half 3331 } 3332 3333 template <size_t N, int kPow2> 3334 HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> d, const svfloat64_t v) { 3335 const svuint32_t in_even = svcvt_u32_f64_x(detail::PTrue(d), v); 3336 return detail::ConcatEvenFull(in_even, 3337 in_even); // lower half 3338 } 3339 3340 template <size_t N, int kPow2> 3341 HWY_API svfloat32_t DemoteTo(Simd<float, N, kPow2> d, const svint64_t v) { 3342 const svfloat32_t in_even = svcvt_f32_s64_x(detail::PTrue(d), v); 3343 return detail::ConcatEvenFull(in_even, 3344 in_even); // lower half 3345 } 3346 3347 template <size_t N, int kPow2> 3348 HWY_API svfloat32_t DemoteTo(Simd<float, N, kPow2> d, const svuint64_t v) { 3349 const svfloat32_t in_even 
= svcvt_f32_u64_x(detail::PTrue(d), v); 3350 return detail::ConcatEvenFull(in_even, 3351 in_even); // lower half 3352 } 3353 3354 // ------------------------------ ConvertTo F 3355 3356 #define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ 3357 /* Float from signed */ \ 3358 template <size_t N, int kPow2> \ 3359 HWY_API HWY_SVE_V(BASE, BITS) \ 3360 NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ 3361 return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ 3362 } \ 3363 /* Float from unsigned */ \ 3364 template <size_t N, int kPow2> \ 3365 HWY_API HWY_SVE_V(BASE, BITS) \ 3366 NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ 3367 return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ 3368 } \ 3369 /* Signed from float, rounding toward zero */ \ 3370 template <size_t N, int kPow2> \ 3371 HWY_API HWY_SVE_V(int, BITS) \ 3372 NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ 3373 return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ 3374 } \ 3375 /* Unsigned from float, rounding toward zero */ \ 3376 template <size_t N, int kPow2> \ 3377 HWY_API HWY_SVE_V(uint, BITS) \ 3378 NAME(HWY_SVE_D(uint, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ 3379 return sv##OP##_u##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ 3380 } 3381 3382 HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) 3383 #undef HWY_SVE_CONVERT 3384 3385 // ------------------------------ MaskedConvertTo F 3386 3387 #define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ 3388 /* Float from signed */ \ 3389 template <size_t N, int kPow2> \ 3390 HWY_API HWY_SVE_V(BASE, BITS) \ 3391 NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ 3392 HWY_SVE_V(int, BITS) v) { \ 3393 return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \ 3394 } \ 3395 /* Float from unsigned */ \ 3396 template <size_t N, int kPow2> \ 3397 HWY_API HWY_SVE_V(BASE, BITS) \ 3398 NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ 3399 HWY_SVE_V(uint, BITS) v) { \ 3400 return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \ 3401 } \ 3402 /* Signed from float, rounding toward zero */ \ 3403 template <size_t N, int kPow2> \ 3404 HWY_API HWY_SVE_V(int, BITS) \ 3405 NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \ 3406 HWY_SVE_V(BASE, BITS) v) { \ 3407 return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \ 3408 } \ 3409 /* Unsigned from float, rounding toward zero */ \ 3410 template <size_t N, int kPow2> \ 3411 HWY_API HWY_SVE_V(uint, BITS) \ 3412 NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \ 3413 HWY_SVE_V(BASE, BITS) v) { \ 3414 return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \ 3415 } 3416 3417 HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertTo, cvt) 3418 #undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO 3419 3420 // ------------------------------ NearestInt (Round, ConvertTo) 3421 template <class VF, class DI = RebindToSigned<DFromV<VF>>> 3422 HWY_API VFromD<DI> NearestInt(VF v) { 3423 // No single instruction, round then truncate. 3424 return ConvertTo(DI(), Round(v)); 3425 } 3426 3427 template <class DI32, HWY_IF_I32_D(DI32)> 3428 HWY_API VFromD<DI32> DemoteToNearestInt(DI32 di32, 3429 VFromD<Rebind<double, DI32>> v) { 3430 // No single instruction, round then demote. 
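// Round is to-nearest with ties to even; the f64->i32 conversion (FCVTZS)
// saturates out-of-range values to the int32_t limits.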
3431 return DemoteTo(di32, Round(v)); 3432 } 3433 3434 // ------------------------------ Iota (AddN, ConvertTo) 3435 3436 #define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ 3437 template <size_t N, int kPow2, typename T2> \ 3438 HWY_API HWY_SVE_V(BASE, BITS) \ 3439 NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, T2 first) { \ 3440 return sv##OP##_##CHAR##BITS( \ 3441 ConvertScalarTo<HWY_SVE_T(BASE, BITS)>(first), 1); \ 3442 } 3443 3444 HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) 3445 #undef HWY_SVE_IOTA 3446 3447 template <class D, typename T = TFromD<D>, typename T2, HWY_IF_FLOAT(T)> 3448 HWY_API VFromD<D> Iota(const D d, T2 first) { 3449 const RebindToSigned<D> di; 3450 const T first_f = ConvertScalarTo<T>(first); 3451 const VFromD<D> iota_f = ConvertTo(d, Iota(di, 0)); 3452 return detail::AddN(iota_f, first_f); 3453 } 3454 3455 // ================================================== LANE ACCESS 3456 3457 // ------------------------------ ExtractLane (GetLaneM, FirstN) 3458 template <class V> 3459 HWY_API TFromV<V> ExtractLane(V v, size_t i) { 3460 return detail::GetLaneM(v, FirstN(DFromV<V>(), i)); 3461 } 3462 3463 // ------------------------------ InsertLane (IfThenElse, EqN) 3464 template <class V, typename T> 3465 HWY_API V InsertLane(const V v, size_t i, T t) { 3466 static_assert(sizeof(TFromV<V>) == sizeof(T), "Lane size mismatch"); 3467 const DFromV<V> d; 3468 const RebindToSigned<decltype(d)> di; 3469 using TI = TFromD<decltype(di)>; 3470 const svbool_t is_i = detail::EqN(Iota(di, 0), static_cast<TI>(i)); 3471 // The actual type may be int16_t for special floats; copy, not cast. 3472 TFromV<V> t_bits; 3473 hwy::CopySameSize(&t, &t_bits); 3474 return IfThenElse(RebindMask(d, is_i), Set(d, t_bits), v); 3475 } 3476 3477 // ------------------------------ GetExponent 3478 3479 #if HWY_SVE_HAVE_2 || HWY_IDE 3480 #ifdef HWY_NATIVE_GET_EXPONENT 3481 #undef HWY_NATIVE_GET_EXPONENT 3482 #else 3483 #define HWY_NATIVE_GET_EXPONENT 3484 #endif 3485 3486 namespace detail { 3487 #define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \ 3488 HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ 3489 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ 3490 } 3491 HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb) 3492 #undef HWY_SVE_GET_EXP 3493 } // namespace detail 3494 3495 template <class V, HWY_IF_FLOAT_V(V)> 3496 HWY_API V GetExponent(V v) { 3497 const DFromV<V> d; 3498 const RebindToSigned<decltype(d)> di; 3499 const VFromD<decltype(di)> exponent_int = detail::GetExponent(v); 3500 // convert integer to original type 3501 return ConvertTo(d, exponent_int); 3502 } 3503 #endif // HWY_SVE_HAVE_2 3504 3505 // ------------------------------ InterleaveLower 3506 3507 template <class D, class V> 3508 HWY_API V InterleaveLower(D d, const V a, const V b) { 3509 static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch"); 3510 #if HWY_TARGET == HWY_SVE2_128 3511 (void)d; 3512 return detail::ZipLowerSame(a, b); 3513 #else 3514 // Move lower halves of blocks to lower half of vector. 
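// ConcatEvenFull on u64 lanes gathers each 128-bit block's low 64 bits;
// ZipLowerSame then interleaves the lanes of a and b.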
3515 const Repartition<uint64_t, decltype(d)> d64; 3516 const auto a64 = BitCast(d64, a); 3517 const auto b64 = BitCast(d64, b); 3518 const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half 3519 const auto b_blocks = detail::ConcatEvenFull(b64, b64); 3520 return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); 3521 #endif 3522 } 3523 3524 template <class V> 3525 HWY_API V InterleaveLower(const V a, const V b) { 3526 return InterleaveLower(DFromV<V>(), a, b); 3527 } 3528 3529 // ------------------------------ InterleaveUpper 3530 3531 // Only use zip2 if vectors are a power of two; otherwise getting the actual 3532 // "upper half" requires MaskUpperHalf. 3533 namespace detail { 3534 // Unlike Highway's ZipUpper, this returns the same type. 3535 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) 3536 } // namespace detail 3537 3538 // Full vector: guaranteed to have at least one block 3539 template <class D, class V = VFromD<D>, 3540 hwy::EnableIf<detail::IsFull(D())>* = nullptr> 3541 HWY_API V InterleaveUpper(D d, const V a, const V b) { 3542 #if HWY_TARGET == HWY_SVE2_128 3543 (void)d; 3544 return detail::ZipUpperSame(a, b); 3545 #else 3546 // Move upper halves of blocks to lower half of vector. 3547 const Repartition<uint64_t, decltype(d)> d64; 3548 const auto a64 = BitCast(d64, a); 3549 const auto b64 = BitCast(d64, b); 3550 const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half 3551 const auto b_blocks = detail::ConcatOddFull(b64, b64); 3552 return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); 3553 #endif 3554 } 3555 3556 // Capped/fractional: need runtime check 3557 template <class D, class V = VFromD<D>, 3558 hwy::EnableIf<!detail::IsFull(D())>* = nullptr> 3559 HWY_API V InterleaveUpper(D d, const V a, const V b) { 3560 // Less than one block: treat as capped 3561 if (Lanes(d) * sizeof(TFromD<D>) < 16) { 3562 const Half<decltype(d)> d2; 3563 return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); 3564 } 3565 return InterleaveUpper(DFromV<V>(), a, b); 3566 } 3567 3568 // ------------------------------ InterleaveWholeLower 3569 #ifdef HWY_NATIVE_INTERLEAVE_WHOLE 3570 #undef HWY_NATIVE_INTERLEAVE_WHOLE 3571 #else 3572 #define HWY_NATIVE_INTERLEAVE_WHOLE 3573 #endif 3574 3575 template <class D> 3576 HWY_API VFromD<D> InterleaveWholeLower(D /*d*/, VFromD<D> a, VFromD<D> b) { 3577 return detail::ZipLowerSame(a, b); 3578 } 3579 3580 // ------------------------------ InterleaveWholeUpper 3581 3582 template <class D> 3583 HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) { 3584 if (HWY_SVE_IS_POW2 && detail::IsFull(d)) { 3585 return detail::ZipUpperSame(a, b); 3586 } 3587 3588 const Half<decltype(d)> d2; 3589 return InterleaveWholeLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); 3590 } 3591 3592 // ------------------------------ Per4LaneBlockShuffle 3593 3594 namespace detail { 3595 3596 template <size_t kLaneSize, size_t kVectSize, class V, 3597 HWY_IF_NOT_T_SIZE_V(V, 8)> 3598 HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, 3599 hwy::SizeTag<kLaneSize> /*lane_size_tag*/, 3600 hwy::SizeTag<kVectSize> /*vect_size_tag*/, 3601 V v) { 3602 const DFromV<decltype(v)> d; 3603 const RebindToUnsigned<decltype(d)> du; 3604 const RepartitionToWide<decltype(du)> dw; 3605 3606 const auto evens = BitCast(dw, ConcatEvenFull(v, v)); 3607 return BitCast(d, ZipLowerSame(evens, evens)); 3608 } 3609 3610 template <size_t kLaneSize, size_t kVectSize, class V, 3611 HWY_IF_NOT_T_SIZE_V(V, 8)> 3612
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, 3613 hwy::SizeTag<kLaneSize> /*lane_size_tag*/, 3614 hwy::SizeTag<kVectSize> /*vect_size_tag*/, 3615 V v) { 3616 const DFromV<decltype(v)> d; 3617 const RebindToUnsigned<decltype(d)> du; 3618 const RepartitionToWide<decltype(du)> dw; 3619 3620 const auto odds = BitCast(dw, ConcatOddFull(v, v)); 3621 return BitCast(d, ZipLowerSame(odds, odds)); 3622 } 3623 3624 } // namespace detail 3625 3626 // ================================================== COMBINE 3627 3628 namespace detail { 3629 3630 #if (HWY_TARGET == HWY_SVE_256 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE 3631 template <class D, HWY_IF_T_SIZE_D(D, 1)> 3632 svbool_t MaskLowerHalf(D d) { 3633 switch (MaxLanes(d)) { 3634 case 32: 3635 return svptrue_pat_b8(SV_VL16); 3636 case 16: 3637 return svptrue_pat_b8(SV_VL8); 3638 case 8: 3639 return svptrue_pat_b8(SV_VL4); 3640 case 4: 3641 return svptrue_pat_b8(SV_VL2); 3642 default: 3643 return svptrue_pat_b8(SV_VL1); 3644 } 3645 } 3646 template <class D, HWY_IF_T_SIZE_D(D, 2)> 3647 svbool_t MaskLowerHalf(D d) { 3648 switch (MaxLanes(d)) { 3649 case 16: 3650 return svptrue_pat_b16(SV_VL8); 3651 case 8: 3652 return svptrue_pat_b16(SV_VL4); 3653 case 4: 3654 return svptrue_pat_b16(SV_VL2); 3655 default: 3656 return svptrue_pat_b16(SV_VL1); 3657 } 3658 } 3659 template <class D, HWY_IF_T_SIZE_D(D, 4)> 3660 svbool_t MaskLowerHalf(D d) { 3661 switch (MaxLanes(d)) { 3662 case 8: 3663 return svptrue_pat_b32(SV_VL4); 3664 case 4: 3665 return svptrue_pat_b32(SV_VL2); 3666 default: 3667 return svptrue_pat_b32(SV_VL1); 3668 } 3669 } 3670 template <class D, HWY_IF_T_SIZE_D(D, 8)> 3671 svbool_t MaskLowerHalf(D d) { 3672 switch (MaxLanes(d)) { 3673 case 4: 3674 return svptrue_pat_b64(SV_VL2); 3675 default: 3676 return svptrue_pat_b64(SV_VL1); 3677 } 3678 } 3679 #endif 3680 #if (HWY_TARGET == HWY_SVE2_128 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE 3681 template <class D, HWY_IF_T_SIZE_D(D, 1)> 3682 svbool_t MaskLowerHalf(D d) { 3683 switch (MaxLanes(d)) { 3684 case 16: 3685 return svptrue_pat_b8(SV_VL8); 3686 case 8: 3687 return svptrue_pat_b8(SV_VL4); 3688 case 4: 3689 return svptrue_pat_b8(SV_VL2); 3690 case 2: 3691 case 1: 3692 default: 3693 return svptrue_pat_b8(SV_VL1); 3694 } 3695 } 3696 template <class D, HWY_IF_T_SIZE_D(D, 2)> 3697 svbool_t MaskLowerHalf(D d) { 3698 switch (MaxLanes(d)) { 3699 case 8: 3700 return svptrue_pat_b16(SV_VL4); 3701 case 4: 3702 return svptrue_pat_b16(SV_VL2); 3703 case 2: 3704 case 1: 3705 default: 3706 return svptrue_pat_b16(SV_VL1); 3707 } 3708 } 3709 template <class D, HWY_IF_T_SIZE_D(D, 4)> 3710 svbool_t MaskLowerHalf(D d) { 3711 return svptrue_pat_b32(MaxLanes(d) == 4 ? SV_VL2 : SV_VL1); 3712 } 3713 template <class D, HWY_IF_T_SIZE_D(D, 8)> 3714 svbool_t MaskLowerHalf(D /*d*/) { 3715 return svptrue_pat_b64(SV_VL1); 3716 } 3717 #endif // HWY_TARGET == HWY_SVE2_128 3718 #if (HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128) || \ 3719 !HWY_HAVE_CONSTEXPR_LANES 3720 template <class D> 3721 svbool_t MaskLowerHalf(D d) { 3722 return FirstN(d, Lanes(d) / 2); 3723 } 3724 #endif 3725 3726 template <class D> 3727 svbool_t MaskUpperHalf(D d) { 3728 // TODO(janwas): WHILEGE on SVE2 3729 if (HWY_SVE_IS_POW2 && IsFull(d)) { 3730 return Not(MaskLowerHalf(d)); 3731 } 3732 3733 // For Splice to work as intended, make sure bits above Lanes(d) are zero. 
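// AndNot keeps exactly the upper-half lanes that lie below Lanes(d).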
3734 return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); 3735 } 3736 3737 // Right-shift vector pair by constexpr; can be used to slide down (=N) or up 3738 // (=Lanes()-N). 3739 #define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ 3740 template <size_t kIndex> \ 3741 HWY_API HWY_SVE_V(BASE, BITS) \ 3742 NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ 3743 return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ 3744 } 3745 HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) 3746 #undef HWY_SVE_EXT 3747 3748 } // namespace detail 3749 3750 // ------------------------------ ConcatUpperLower 3751 template <class D, class V> 3752 HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { 3753 return IfThenElse(detail::MaskLowerHalf(d), lo, hi); 3754 } 3755 3756 // ------------------------------ ConcatLowerLower 3757 template <class D, class V> 3758 HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { 3759 if (detail::IsFull(d)) { 3760 #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 3761 return detail::ConcatEvenBlocks(hi, lo); 3762 #endif 3763 #if HWY_TARGET == HWY_SVE2_128 3764 const Repartition<uint64_t, D> du64; 3765 const auto lo64 = BitCast(du64, lo); 3766 return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); 3767 #endif 3768 } 3769 return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); 3770 } 3771 3772 // ------------------------------ ConcatLowerUpper 3773 template <class D, class V> 3774 HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) { 3775 #if HWY_HAVE_CONSTEXPR_LANES 3776 if (detail::IsFull(d)) { 3777 return detail::Ext<Lanes(d) / 2>(hi, lo); 3778 } 3779 #endif 3780 return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); 3781 } 3782 3783 // ------------------------------ ConcatUpperUpper 3784 template <class D, class V> 3785 HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { 3786 if (detail::IsFull(d)) { 3787 #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 3788 return detail::ConcatOddBlocks(hi, lo); 3789 #endif 3790 #if HWY_TARGET == HWY_SVE2_128 3791 const Repartition<uint64_t, D> du64; 3792 const auto lo64 = BitCast(du64, lo); 3793 return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); 3794 #endif 3795 } 3796 const svbool_t mask_upper = detail::MaskUpperHalf(d); 3797 const V lo_upper = detail::Splice(lo, lo, mask_upper); 3798 return IfThenElse(mask_upper, hi, lo_upper); 3799 } 3800 3801 // ------------------------------ Combine 3802 template <class D, class V2> 3803 HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) { 3804 return ConcatLowerLower(d, hi, lo); 3805 } 3806 3807 // ------------------------------ ZeroExtendVector 3808 template <class D, class V> 3809 HWY_API V ZeroExtendVector(const D d, const V lo) { 3810 return Combine(d, Zero(Half<D>()), lo); 3811 } 3812 3813 // ------------------------------ Lower/UpperHalf 3814 3815 template <class D2, class V> 3816 HWY_API V LowerHalf(D2 /* tag */, const V v) { 3817 return v; 3818 } 3819 3820 template <class V> 3821 HWY_API V LowerHalf(const V v) { 3822 return v; 3823 } 3824 3825 template <class DH, class V> 3826 HWY_API V UpperHalf(const DH dh, const V v) { 3827 const Twice<decltype(dh)> d; 3828 // Cast so that we support bfloat16_t. 
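// (Several bf16 intrinsics are unavailable, so operate on the same-sized
// unsigned type.)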
3829 const RebindToUnsigned<decltype(d)> du; 3830 const VFromD<decltype(du)> vu = BitCast(du, v); 3831 #if HWY_HAVE_CONSTEXPR_LANES 3832 return BitCast(d, detail::Ext<Lanes(dh)>(vu, vu)); 3833 #else 3834 const MFromD<decltype(du)> mask = detail::MaskUpperHalf(du); 3835 return BitCast(d, detail::Splice(vu, vu, mask)); 3836 #endif 3837 } 3838 3839 // ================================================== SWIZZLE 3840 3841 // ------------------------------ DupEven 3842 3843 namespace detail { 3844 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) 3845 } // namespace detail 3846 3847 template <class V> 3848 HWY_API V DupEven(const V v) { 3849 return detail::InterleaveEven(v, v); 3850 } 3851 3852 // ------------------------------ DupOdd 3853 3854 namespace detail { 3855 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) 3856 } // namespace detail 3857 3858 template <class V> 3859 HWY_API V DupOdd(const V v) { 3860 return detail::InterleaveOdd(v, v); 3861 } 3862 3863 // ------------------------------ OddEven 3864 3865 #if HWY_SVE_HAVE_2 3866 3867 #define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ 3868 HWY_API HWY_SVE_V(BASE, BITS) \ 3869 NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ 3870 return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ 3871 } 3872 3873 HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) 3874 #undef HWY_SVE_ODD_EVEN 3875 3876 template <class V, HWY_IF_FLOAT_V(V)> 3877 HWY_API V OddEven(const V odd, const V even) { 3878 const DFromV<V> d; 3879 const RebindToUnsigned<decltype(d)> du; 3880 return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); 3881 } 3882 3883 #else 3884 3885 template <class V> 3886 HWY_API V OddEven(const V odd, const V even) { 3887 const auto odd_in_even = detail::Ext<1>(odd, odd); 3888 return detail::InterleaveEven(even, odd_in_even); 3889 } 3890 3891 #endif // HWY_SVE_HAVE_2 3892 3893 // ------------------------------ InterleaveEven 3894 template <class D> 3895 HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) { 3896 return detail::InterleaveEven(a, b); 3897 } 3898 3899 // ------------------------------ InterleaveOdd 3900 template <class D> 3901 HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) { 3902 return detail::InterleaveOdd(a, b); 3903 } 3904 3905 // ------------------------------ OddEvenBlocks 3906 template <class V> 3907 HWY_API V OddEvenBlocks(const V odd, const V even) { 3908 const DFromV<V> d; 3909 #if HWY_TARGET == HWY_SVE_256 3910 return ConcatUpperLower(d, odd, even); 3911 #elif HWY_TARGET == HWY_SVE2_128 3912 (void)odd; 3913 (void)d; 3914 return even; 3915 #else 3916 const RebindToUnsigned<decltype(d)> du; 3917 using TU = TFromD<decltype(du)>; 3918 constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); 3919 const auto idx_block = ShiftRight<kShift>(Iota(du, 0)); 3920 const auto lsb = detail::AndN(idx_block, static_cast<TU>(1)); 3921 const svbool_t is_even = detail::EqN(lsb, static_cast<TU>(0)); 3922 return IfThenElse(is_even, even, odd); 3923 #endif 3924 } 3925 3926 // ------------------------------ TableLookupLanes 3927 3928 template <class D, class VI> 3929 HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) { 3930 using TI = TFromV<VI>; 3931 static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index/lane size mismatch"); 3932 const RebindToUnsigned<D> du; 3933 const auto indices = BitCast(du, vec); 3934 #if HWY_IS_DEBUG_BUILD 3935 using TU = MakeUnsigned<TI>; 3936 const size_t twice_max_lanes = Lanes(d) * 2; 3937 HWY_DASSERT(AllTrue( 3938 du, Eq(indices,
3939 detail::AndN(indices, static_cast<TU>(twice_max_lanes - 1))))); 3940 #else 3941 (void)d; 3942 #endif 3943 return indices; 3944 } 3945 3946 template <class D, typename TI> 3947 HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) { 3948 static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane"); 3949 return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx)); 3950 } 3951 3952 #define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ 3953 HWY_API HWY_SVE_V(BASE, BITS) \ 3954 NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ 3955 return sv##OP##_##CHAR##BITS(v, idx); \ 3956 } 3957 3958 HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) 3959 #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 3960 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE, TableLookupLanes, tbl) 3961 #endif 3962 #undef HWY_SVE_TABLE 3963 3964 #if HWY_SVE_HAVE_2 3965 namespace detail { 3966 #define HWY_SVE_TABLE2(BASE, CHAR, BITS, HALF, NAME, OP) \ 3967 HWY_API HWY_SVE_V(BASE, BITS) \ 3968 NAME(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(uint, BITS) idx) { \ 3969 return sv##OP##_##CHAR##BITS(tuple, idx); \ 3970 } 3971 3972 HWY_SVE_FOREACH(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, tbl2) 3973 #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 3974 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, 3975 tbl2) 3976 #endif 3977 #undef HWY_SVE_TABLE2 3978 } // namespace detail 3979 #endif // HWY_SVE_HAVE_2 3980 3981 template <class D> 3982 HWY_API VFromD<D> TwoTablesLookupLanes(D d, VFromD<D> a, VFromD<D> b, 3983 VFromD<RebindToUnsigned<D>> idx) { 3984 // SVE2 has an instruction for this, but it only works for full 2^n vectors. 3985 #if HWY_SVE_HAVE_2 && HWY_SVE_IS_POW2 3986 if (detail::IsFull(d)) { 3987 return detail::NativeTwoTableLookupLanes(Create2(d, a, b), idx); 3988 } 3989 #endif 3990 const RebindToUnsigned<decltype(d)> du; 3991 using TU = TFromD<decltype(du)>; 3992 3993 const size_t num_of_lanes = Lanes(d); 3994 const auto idx_mod = detail::AndN(idx, static_cast<TU>(num_of_lanes - 1)); 3995 const auto sel_a_mask = Eq(idx, idx_mod); 3996 3997 const auto a_lookup_result = TableLookupLanes(a, idx_mod); 3998 const auto b_lookup_result = TableLookupLanes(b, idx_mod); 3999 return IfThenElse(sel_a_mask, a_lookup_result, b_lookup_result); 4000 } 4001 4002 template <class V> 4003 HWY_API V TwoTablesLookupLanes(V a, V b, 4004 VFromD<RebindToUnsigned<DFromV<V>>> idx) { 4005 const DFromV<decltype(a)> d; 4006 return TwoTablesLookupLanes(d, a, b, idx); 4007 } 4008 4009 // ------------------------------ SlideUpLanes (FirstN) 4010 template <class D> 4011 HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) { 4012 return detail::Splice(v, Zero(d), FirstN(d, amt)); 4013 } 4014 4015 // ------------------------------ Slide1Up 4016 4017 #ifdef HWY_NATIVE_SLIDE1_UP_DOWN 4018 #undef HWY_NATIVE_SLIDE1_UP_DOWN 4019 #else 4020 #define HWY_NATIVE_SLIDE1_UP_DOWN 4021 #endif 4022 4023 template <class D> 4024 HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) { 4025 return SlideUpLanes(d, v, 1); 4026 } 4027 4028 // ------------------------------ SlideDownLanes (TableLookupLanes) 4029 template <class D> 4030 HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) { 4031 const RebindToUnsigned<decltype(d)> du; 4032 using TU = TFromD<decltype(du)>; 4033 const auto idx = Iota(du, static_cast<TU>(amt)); 4034 return IfThenElseZero(FirstN(d, Lanes(d) - amt), TableLookupLanes(v, idx)); 4035 } 4036 4037 // ------------------------------ Slide1Down 4038 template
<class D> 4039 HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) { 4040 return SlideDownLanes(d, v, 1); 4041 } 4042 4043 // ------------------------------ SwapAdjacentBlocks (TableLookupLanes) 4044 4045 namespace detail { 4046 4047 template <typename T, size_t N, int kPow2> 4048 constexpr size_t LanesPerBlock(Simd<T, N, kPow2> d) { 4049 // We might have a capped vector smaller than a block, so honor that. 4050 return HWY_MIN(16 / sizeof(T), MaxLanes(d)); 4051 } 4052 4053 } // namespace detail 4054 4055 template <class V> 4056 HWY_API V SwapAdjacentBlocks(const V v) { 4057 const DFromV<V> d; 4058 #if HWY_TARGET == HWY_SVE_256 4059 return ConcatLowerUpper(d, v, v); 4060 #elif HWY_TARGET == HWY_SVE2_128 4061 (void)d; 4062 return v; 4063 #else 4064 const RebindToUnsigned<decltype(d)> du; 4065 constexpr auto kLanesPerBlock = 4066 static_cast<TFromD<decltype(du)>>(detail::LanesPerBlock(d)); 4067 const VFromD<decltype(du)> idx = detail::XorN(Iota(du, 0), kLanesPerBlock); 4068 return TableLookupLanes(v, idx); 4069 #endif 4070 } 4071 4072 // ------------------------------ InterleaveEvenBlocks 4073 // (ConcatLowerLower, SlideUpLanes, OddEvenBlocks) 4074 4075 template <class D, class V = VFromD<D>> 4076 HWY_API V InterleaveEvenBlocks(D d, V a, V b) { 4077 #if HWY_TARGET == HWY_SVE_256 4078 return ConcatLowerLower(d, b, a); 4079 #elif HWY_TARGET == HWY_SVE2_128 4080 (void)d; 4081 (void)b; 4082 return a; 4083 #else 4084 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); 4085 return OddEvenBlocks(SlideUpLanes(d, b, kLanesPerBlock), a); 4086 #endif 4087 } 4088 4089 // ------------------------------ InterleaveOddBlocks 4090 // (ConcatUpperUpper, SlideDownLanes, OddEvenBlocks) 4091 4092 template <class D, class V = VFromD<D>> 4093 HWY_API V InterleaveOddBlocks(D d, V a, V b) { 4094 #if HWY_TARGET == HWY_SVE_256 4095 return ConcatUpperUpper(d, b, a); 4096 #elif HWY_TARGET == HWY_SVE2_128 4097 (void)d; 4098 (void)b; 4099 return a; 4100 #else 4101 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); 4102 return OddEvenBlocks(b, SlideDownLanes(d, a, kLanesPerBlock)); 4103 #endif 4104 } 4105 4106 // ------------------------------ Reverse 4107 4108 namespace detail { 4109 4110 #define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ 4111 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ 4112 return sv##OP##_##CHAR##BITS(v); \ 4113 } 4114 4115 HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) 4116 #if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 4117 HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_REVERSE, ReverseFull, rev) 4118 #endif 4119 #undef HWY_SVE_REVERSE 4120 4121 } // namespace detail 4122 4123 template <class D, class V> 4124 HWY_API V Reverse(D d, V v) { 4125 using T = TFromD<D>; 4126 const auto reversed = detail::ReverseFull(v); 4127 if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; 4128 // Shift right to remove extra (non-pow2 and remainder) lanes. 4129 // TODO(janwas): on SVE2, use WHILEGE. 4130 // Avoids FirstN truncating to the return vector size. Must also avoid Not 4131 // because that is limited to SV_POW2. 
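// Instead, form the mask on a full ScalableTag and invert it via svnot_b_z.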
4132 const ScalableTag<T> dfull; 4133 const svbool_t all_true = detail::AllPTrue(dfull); 4134 const size_t all_lanes = detail::AllHardwareLanes<T>(); 4135 const size_t want_lanes = Lanes(d); 4136 HWY_DASSERT(want_lanes <= all_lanes); 4137 const svbool_t mask = 4138 svnot_b_z(all_true, FirstN(dfull, all_lanes - want_lanes)); 4139 return detail::Splice(reversed, reversed, mask); 4140 } 4141 4142 // ------------------------------ Reverse2 4143 4144 // Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 4145 #ifdef HWY_NATIVE_REVERSE2_8 4146 #undef HWY_NATIVE_REVERSE2_8 4147 #else 4148 #define HWY_NATIVE_REVERSE2_8 4149 #endif 4150 4151 template <class D, HWY_IF_T_SIZE_D(D, 1)> 4152 HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { 4153 const RebindToUnsigned<decltype(d)> du; 4154 const RepartitionToWide<decltype(du)> dw; 4155 return BitCast(d, svrevb_u16_x(detail::PTrue(d), BitCast(dw, v))); 4156 } 4157 4158 template <class D, HWY_IF_T_SIZE_D(D, 2)> 4159 HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { 4160 const RebindToUnsigned<decltype(d)> du; 4161 const RepartitionToWide<decltype(du)> dw; 4162 return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); 4163 } 4164 4165 template <class D, HWY_IF_T_SIZE_D(D, 4)> 4166 HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { 4167 const RebindToUnsigned<decltype(d)> du; 4168 const RepartitionToWide<decltype(du)> dw; 4169 return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); 4170 } 4171 4172 template <class D, HWY_IF_T_SIZE_D(D, 8)> 4173 HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { // 3210 4174 #if HWY_TARGET == HWY_SVE2_128 4175 if (detail::IsFull(d)) { 4176 return detail::Ext<1>(v, v); 4177 } 4178 #endif 4179 (void)d; 4180 const auto odd_in_even = detail::Ext<1>(v, v); // x321 4181 return detail::InterleaveEven(odd_in_even, v); // 2301 4182 } 4183 4184 // ------------------------------ Reverse4 (TableLookupLanes) 4185 4186 template <class D, HWY_IF_T_SIZE_D(D, 1)> 4187 HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { 4188 const RebindToUnsigned<decltype(d)> du; 4189 const RepartitionToWideX2<decltype(du)> du32; 4190 return BitCast(d, svrevb_u32_x(detail::PTrue(d), BitCast(du32, v))); 4191 } 4192 4193 template <class D, HWY_IF_T_SIZE_D(D, 2)> 4194 HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { 4195 const RebindToUnsigned<decltype(d)> du; 4196 const RepartitionToWideX2<decltype(du)> du64; 4197 return BitCast(d, svrevh_u64_x(detail::PTrue(d), BitCast(du64, v))); 4198 } 4199 4200 template <class D, HWY_IF_T_SIZE_D(D, 4)> 4201 HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { 4202 if (HWY_TARGET == HWY_SVE2_128 && detail::IsFull(d)) { 4203 return detail::ReverseFull(v); 4204 } 4205 // TODO(janwas): is this approach faster than Shuffle0123? 4206 const RebindToUnsigned<decltype(d)> du; 4207 const auto idx = detail::XorN(Iota(du, 0), 3); 4208 return TableLookupLanes(v, idx); 4209 } 4210 4211 template <class D, HWY_IF_T_SIZE_D(D, 8)> 4212 HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { 4213 if (HWY_TARGET == HWY_SVE_256 && detail::IsFull(d)) { 4214 return detail::ReverseFull(v); 4215 } 4216 // TODO(janwas): is this approach faster than Shuffle0123? 
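// Iota ^ 3 maps each group of four indices 0,1,2,3 to 3,2,1,0.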
4217 const RebindToUnsigned<decltype(d)> du; 4218 const auto idx = detail::XorN(Iota(du, 0), 3); 4219 return TableLookupLanes(v, idx); 4220 } 4221 4222 // ------------------------------ Reverse8 (TableLookupLanes) 4223 4224 template <class D, HWY_IF_T_SIZE_D(D, 1)> 4225 HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { 4226 const Repartition<uint64_t, decltype(d)> du64; 4227 return BitCast(d, svrevb_u64_x(detail::PTrue(d), BitCast(du64, v))); 4228 } 4229 4230 template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)> 4231 HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { 4232 const RebindToUnsigned<decltype(d)> du; 4233 const auto idx = detail::XorN(Iota(du, 0), 7); 4234 return TableLookupLanes(v, idx); 4235 } 4236 4237 // ------------------------------- ReverseBits 4238 4239 #ifdef HWY_NATIVE_REVERSE_BITS_UI8 4240 #undef HWY_NATIVE_REVERSE_BITS_UI8 4241 #else 4242 #define HWY_NATIVE_REVERSE_BITS_UI8 4243 #endif 4244 4245 #ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 4246 #undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 4247 #else 4248 #define HWY_NATIVE_REVERSE_BITS_UI16_32_64 4249 #endif 4250 4251 #define HWY_SVE_REVERSE_BITS(BASE, CHAR, BITS, HALF, NAME, OP) \ 4252 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ 4253 const DFromV<decltype(v)> d; \ 4254 return sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v); \ 4255 } 4256 4257 HWY_SVE_FOREACH_UI(HWY_SVE_REVERSE_BITS, ReverseBits, rbit) 4258 #undef HWY_SVE_REVERSE_BITS 4259 4260 // ------------------------------ Block insert/extract/broadcast ops 4261 #if HWY_TARGET != HWY_SVE2_128 4262 4263 #ifdef HWY_NATIVE_BLK_INSERT_EXTRACT 4264 #undef HWY_NATIVE_BLK_INSERT_EXTRACT 4265 #else 4266 #define HWY_NATIVE_BLK_INSERT_EXTRACT 4267 #endif 4268 4269 template <int kBlockIdx, class V> 4270 HWY_API V InsertBlock(V v, V blk_to_insert) { 4271 const DFromV<decltype(v)> d; 4272 static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), 4273 "Invalid block index"); 4274 4275 #if HWY_TARGET == HWY_SVE_256 4276 return (kBlockIdx == 0) ? 
ConcatUpperLower(d, v, blk_to_insert) 4277 : ConcatLowerLower(d, blk_to_insert, v); 4278 #else 4279 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); 4280 4281 constexpr size_t kBlockOffset = 4282 static_cast<size_t>(kBlockIdx) * kLanesPerBlock; 4283 const auto splice_mask = FirstN(d, kBlockOffset); 4284 const auto sel_lo_mask = FirstN(d, kBlockOffset + kLanesPerBlock); 4285 4286 const auto splice_result = detail::Splice(blk_to_insert, v, splice_mask); 4287 return IfThenElse(sel_lo_mask, splice_result, v); 4288 #endif 4289 } 4290 4291 template <int kBlockIdx, class V> 4292 HWY_API V ExtractBlock(V v) { 4293 const DFromV<decltype(v)> d; 4294 static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), 4295 "Invalid block index"); 4296 4297 if (kBlockIdx == 0) return v; 4298 4299 #if HWY_TARGET == HWY_SVE_256 4300 return UpperHalf(Half<decltype(d)>(), v); 4301 #else 4302 const RebindToUnsigned<decltype(d)> du; 4303 using TU = TFromD<decltype(du)>; 4304 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); 4305 constexpr size_t kBlockOffset = 4306 static_cast<size_t>(kBlockIdx) * kLanesPerBlock; 4307 const auto splice_mask = 4308 RebindMask(d, detail::LtN(Iota(du, static_cast<TU>(0u - kBlockOffset)), 4309 static_cast<TU>(kLanesPerBlock))); 4310 return detail::Splice(v, v, splice_mask); 4311 #endif 4312 } 4313 4314 template <int kBlockIdx, class V> 4315 HWY_API V BroadcastBlock(V v) { 4316 const DFromV<decltype(v)> d; 4317 static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), 4318 "Invalid block index"); 4319 4320 const RebindToUnsigned<decltype(d)> du; // for bfloat16_t 4321 using VU = VFromD<decltype(du)>; 4322 const VU vu = BitCast(du, v); 4323 4324 #if HWY_TARGET == HWY_SVE_256 4325 return BitCast(d, (kBlockIdx == 0) ? ConcatLowerLower(du, vu, vu) 4326 : ConcatUpperUpper(du, vu, vu)); 4327 #else 4328 using TU = TFromD<decltype(du)>; 4329 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); 4330 constexpr size_t kBlockOffset = 4331 static_cast<size_t>(kBlockIdx) * kLanesPerBlock; 4332 4333 const VU idx = detail::AddN( 4334 detail::AndN(Iota(du, TU{0}), static_cast<TU>(kLanesPerBlock - 1)), 4335 static_cast<TU>(kBlockOffset)); 4336 return BitCast(d, TableLookupLanes(vu, idx)); 4337 #endif 4338 } 4339 4340 #endif // HWY_TARGET != HWY_SVE2_128 4341 4342 // ------------------------------ Compress (PromoteTo) 4343 4344 template <typename T> 4345 struct CompressIsPartition { 4346 #if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 4347 // Optimization for 64-bit lanes (could also be applied to 32-bit, but that 4348 // requires a larger table). 
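// ("Partition" means the true-mask lanes come first, followed by all
// remaining lanes, rather than leaving the tail unspecified.)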
4349 enum { value = (sizeof(T) == 8) }; 4350 #else 4351 enum { value = 0 }; 4352 #endif // HWY_TARGET == HWY_SVE_256 4353 }; 4354 4355 #define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ 4356 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ 4357 return sv##OP##_##CHAR##BITS(mask, v); \ 4358 } 4359 4360 #if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 4361 HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) 4362 HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) 4363 #else 4364 HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) 4365 #endif 4366 #undef HWY_SVE_COMPRESS 4367 4368 #if HWY_TARGET == HWY_SVE_256 || HWY_IDE 4369 template <class V, HWY_IF_T_SIZE_V(V, 8)> 4370 HWY_API V Compress(V v, svbool_t mask) { 4371 const DFromV<V> d; 4372 const RebindToUnsigned<decltype(d)> du64; 4373 4374 // Convert mask into bitfield via horizontal sum (faster than ORV) of masked 4375 // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for 4376 // SetTableIndices. 4377 const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); 4378 const size_t offset = detail::SumOfLanesM(mask, bits); 4379 4380 // See CompressIsPartition. 4381 alignas(16) static constexpr uint64_t table[4 * 16] = { 4382 // PrintCompress64x4Tables 4383 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, 4384 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, 4385 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; 4386 return TableLookupLanes(v, SetTableIndices(d, table + offset)); 4387 } 4388 4389 #endif // HWY_TARGET == HWY_SVE_256 4390 #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE 4391 template <class V, HWY_IF_T_SIZE_V(V, 8)> 4392 HWY_API V Compress(V v, svbool_t mask) { 4393 // If mask == 10: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 4394 // swaps upper/lower (the lower half is set to the upper half, and the 4395 // remaining upper half is filled from the lower half of the second v), and 4396 // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10 4397 // unchanged and map everything else to 00. 4398 const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane 4399 return detail::Splice(v, v, AndNot(maskLL, mask)); 4400 } 4401 4402 #endif // HWY_TARGET == HWY_SVE2_128 4403 4404 template <class V, HWY_IF_T_SIZE_V(V, 2)> 4405 HWY_API V Compress(V v, svbool_t mask16) { 4406 static_assert(!IsSame<V, svfloat16_t>(), "Must use overload"); 4407 const DFromV<V> d16; 4408 4409 // Promote vector and mask to 32-bit 4410 const RepartitionToWide<decltype(d16)> dw; 4411 const auto v32L = PromoteTo(dw, v); 4412 const auto v32H = detail::PromoteUpperTo(dw, v); 4413 const svbool_t mask32L = svunpklo_b(mask16); 4414 const svbool_t mask32H = svunpkhi_b(mask16); 4415 4416 const auto compressedL = Compress(v32L, mask32L); 4417 const auto compressedH = Compress(v32H, mask32H); 4418 4419 // Demote to 16-bit (already in range) - separately so we can splice 4420 const V evenL = BitCast(d16, compressedL); 4421 const V evenH = BitCast(d16, compressedH); 4422 const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half 4423 const V v16H = detail::ConcatEvenFull(evenH, evenH); 4424 4425 // We need to combine two vectors of non-constexpr length, so the only option 4426 // is Splice, which requires us to synthesize a mask. NOTE: this function uses 4427 // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. 
4428 const size_t countL = detail::CountTrueFull(dw, mask32L); 4429 const auto compressed_maskL = FirstN(d16, countL); 4430 return detail::Splice(v16H, v16L, compressed_maskL); 4431 } 4432 4433 // Must treat float16_t as integers so we can ConcatEven. 4434 HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { 4435 const DFromV<decltype(v)> df; 4436 const RebindToSigned<decltype(df)> di; 4437 return BitCast(df, Compress(BitCast(di, v), mask16)); 4438 } 4439 4440 // ------------------------------ CompressNot 4441 4442 // 2 or 4 bytes 4443 template <class V, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 2) | (1 << 4))> 4444 HWY_API V CompressNot(V v, const svbool_t mask) { 4445 return Compress(v, Not(mask)); 4446 } 4447 4448 template <class V, HWY_IF_T_SIZE_V(V, 8)> 4449 HWY_API V CompressNot(V v, svbool_t mask) { 4450 #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE 4451 // If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 4452 // swaps upper/lower (the lower half is set to the upper half, and the 4453 // remaining upper half is filled from the lower half of the second v), and 4454 // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map 4455 // 01 to 10, and everything else to 00. 4456 const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane 4457 return detail::Splice(v, v, AndNot(mask, maskLL)); 4458 #endif 4459 #if HWY_TARGET == HWY_SVE_256 || HWY_IDE 4460 const DFromV<V> d; 4461 const RebindToUnsigned<decltype(d)> du64; 4462 4463 // Convert mask into bitfield via horizontal sum (faster than ORV) of masked 4464 // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for 4465 // SetTableIndices. 4466 const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); 4467 const size_t offset = detail::SumOfLanesM(mask, bits); 4468 4469 // See CompressIsPartition. 4470 alignas(16) static constexpr uint64_t table[4 * 16] = { 4471 // PrintCompressNot64x4Tables 4472 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, 4473 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, 4474 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; 4475 return TableLookupLanes(v, SetTableIndices(d, table + offset)); 4476 #endif // HWY_TARGET == HWY_SVE_256 4477 4478 return Compress(v, Not(mask)); 4479 } 4480 4481 // ------------------------------ CompressBlocksNot 4482 HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { 4483 #if HWY_TARGET == HWY_SVE2_128 4484 (void)mask; 4485 return v; 4486 #endif 4487 #if HWY_TARGET == HWY_SVE_256 || HWY_IDE 4488 uint64_t bits = 0; // predicate reg is 32-bit 4489 CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient 4490 // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. 4491 const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); 4492 // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. 
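// Rows at offsets 0, 8 and 12 are the identity; the row at offset 4 swaps
// the two blocks.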
4493 alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, 4494 0, 1, 2, 3, 0, 1, 2, 3}; 4495 const ScalableTag<uint64_t> d; 4496 return TableLookupLanes(v, SetTableIndices(d, table + offset)); 4497 #endif 4498 4499 return CompressNot(v, mask); 4500 } 4501 4502 // ------------------------------ CompressStore 4503 template <class V, class D, HWY_IF_NOT_T_SIZE_D(D, 1)> 4504 HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, 4505 TFromD<D>* HWY_RESTRICT unaligned) { 4506 StoreU(Compress(v, mask), d, unaligned); 4507 return CountTrue(d, mask); 4508 } 4509 4510 // ------------------------------ CompressBlendedStore 4511 template <class V, class D, HWY_IF_NOT_T_SIZE_D(D, 1)> 4512 HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, 4513 TFromD<D>* HWY_RESTRICT unaligned) { 4514 const size_t count = CountTrue(d, mask); 4515 const svbool_t store_mask = FirstN(d, count); 4516 BlendedStore(Compress(v, mask), store_mask, d, unaligned); 4517 return count; 4518 } 4519 4520 // ================================================== MASK (2) 4521 4522 // ------------------------------ FindKnownLastTrue 4523 template <class D> 4524 HWY_API size_t FindKnownLastTrue(D d, svbool_t m) { 4525 const RebindToUnsigned<decltype(d)> du; 4526 return static_cast<size_t>(detail::ExtractLastMatchingLaneM( 4527 Iota(du, 0), And(m, detail::MakeMask(d)))); 4528 } 4529 4530 // ------------------------------ FindLastTrue 4531 template <class D> 4532 HWY_API intptr_t FindLastTrue(D d, svbool_t m) { 4533 return AllFalse(d, m) ? intptr_t{-1} 4534 : static_cast<intptr_t>(FindKnownLastTrue(d, m)); 4535 } 4536 4537 // ================================================== BLOCKWISE 4538 4539 // ------------------------------ CombineShiftRightBytes 4540 4541 // Prevent accidentally using these for 128-bit vectors - should not be 4542 // necessary. 4543 #if HWY_TARGET != HWY_SVE2_128 4544 namespace detail { 4545 4546 // For x86-compatible behaviour mandated by Highway API: TableLookupBytes 4547 // offsets are implicitly relative to the start of their 128-bit block. 
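// E.g. for u8 lanes, iota0 & ~15 yields offsets 0,..,0,16,..,16, so adding
// them to idx keeps every lookup within its own 128-bit block.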
4548 template <class D, class V> 4549 HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { 4550 using T = MakeUnsigned<TFromD<D>>; 4551 return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0); 4552 } 4553 4554 template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 1)> 4555 svbool_t FirstNPerBlock(D d) { 4556 const RebindToUnsigned<decltype(d)> du; 4557 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); 4558 const svuint8_t idx_mod = 4559 svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, 4560 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, 4561 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, 4562 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, 4563 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, 4564 15 % kLanesPerBlock); 4565 return detail::LtN(BitCast(du, idx_mod), kLanes); 4566 } 4567 template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 2)> 4568 svbool_t FirstNPerBlock(D d) { 4569 const RebindToUnsigned<decltype(d)> du; 4570 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); 4571 const svuint16_t idx_mod = 4572 svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, 4573 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, 4574 6 % kLanesPerBlock, 7 % kLanesPerBlock); 4575 return detail::LtN(BitCast(du, idx_mod), kLanes); 4576 } 4577 template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 4)> 4578 svbool_t FirstNPerBlock(D d) { 4579 const RebindToUnsigned<decltype(d)> du; 4580 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); 4581 const svuint32_t idx_mod = 4582 svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, 4583 3 % kLanesPerBlock); 4584 return detail::LtN(BitCast(du, idx_mod), kLanes); 4585 } 4586 template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 8)> 4587 svbool_t FirstNPerBlock(D d) { 4588 const RebindToUnsigned<decltype(d)> du; 4589 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); 4590 const svuint64_t idx_mod = 4591 svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); 4592 return detail::LtN(BitCast(du, idx_mod), kLanes); 4593 } 4594 4595 } // namespace detail 4596 #endif // HWY_TARGET != HWY_SVE2_128 4597 4598 template <size_t kBytes, class D, class V = VFromD<D>> 4599 HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { 4600 const Repartition<uint8_t, decltype(d)> d8; 4601 const auto hi8 = BitCast(d8, hi); 4602 const auto lo8 = BitCast(d8, lo); 4603 #if HWY_TARGET == HWY_SVE2_128 4604 return BitCast(d, detail::Ext<kBytes>(hi8, lo8)); 4605 #else 4606 const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); 4607 const auto lo_down = detail::Ext<kBytes>(lo8, lo8); 4608 const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); 4609 return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); 4610 #endif 4611 } 4612 4613 // ------------------------------ Shuffle2301 4614 template <class V> 4615 HWY_API V Shuffle2301(const V v) { 4616 const DFromV<V> d; 4617 static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); 4618 return Reverse2(d, v); 4619 } 4620 4621 // ------------------------------ Shuffle2103 4622 template <class V> 4623 HWY_API V Shuffle2103(const V v) { 4624 const DFromV<V> d; 4625 const Repartition<uint8_t, decltype(d)> d8; 4626 static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); 4627 const svuint8_t v8 = BitCast(d8, v); 4628 return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); 4629 } 4630 4631 // 
------------------------------ Shuffle0321 4632 template <class V> 4633 HWY_API V Shuffle0321(const V v) { 4634 const DFromV<V> d; 4635 const Repartition<uint8_t, decltype(d)> d8; 4636 static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); 4637 const svuint8_t v8 = BitCast(d8, v); 4638 return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); 4639 } 4640 4641 // ------------------------------ Shuffle1032 4642 template <class V> 4643 HWY_API V Shuffle1032(const V v) { 4644 const DFromV<V> d; 4645 const Repartition<uint8_t, decltype(d)> d8; 4646 static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); 4647 const svuint8_t v8 = BitCast(d8, v); 4648 return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); 4649 } 4650 4651 // ------------------------------ Shuffle01 4652 template <class V> 4653 HWY_API V Shuffle01(const V v) { 4654 const DFromV<V> d; 4655 const Repartition<uint8_t, decltype(d)> d8; 4656 static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types"); 4657 const svuint8_t v8 = BitCast(d8, v); 4658 return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); 4659 } 4660 4661 // ------------------------------ Shuffle0123 4662 template <class V> 4663 HWY_API V Shuffle0123(const V v) { 4664 return Shuffle2301(Shuffle1032(v)); 4665 } 4666 4667 // ------------------------------ ReverseBlocks (Reverse, Shuffle01) 4668 template <class D, class V = VFromD<D>> 4669 HWY_API V ReverseBlocks(D d, V v) { 4670 #if HWY_TARGET == HWY_SVE_256 4671 if (detail::IsFull(d)) { 4672 return SwapAdjacentBlocks(v); 4673 } else if (detail::IsFull(Twice<D>())) { 4674 return v; 4675 } 4676 #elif HWY_TARGET == HWY_SVE2_128 4677 (void)d; 4678 return v; 4679 #endif 4680 const Repartition<uint64_t, D> du64; 4681 return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); 4682 } 4683 4684 // ------------------------------ TableLookupBytes 4685 4686 template <class V, class VI> 4687 HWY_API VI TableLookupBytes(const V v, const VI idx) { 4688 const DFromV<VI> d; 4689 const Repartition<uint8_t, decltype(d)> du8; 4690 #if HWY_TARGET == HWY_SVE2_128 4691 return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); 4692 #else 4693 const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); 4694 const auto idx8 = Add(BitCast(du8, idx), offsets128); 4695 return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); 4696 #endif 4697 } 4698 4699 template <class V, class VI> 4700 HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { 4701 const DFromV<VI> d; 4702 // Mask size must match vector type, so cast everything to this type. 
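// i8 also makes MSB-set indices negative, so LtN(idx8, 0) finds the lanes
// to zero, matching x86 pshufb semantics.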
4703 const Repartition<int8_t, decltype(d)> di8; 4704 4705 auto idx8 = BitCast(di8, idx); 4706 const auto msb = detail::LtN(idx8, 0); 4707 4708 const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); 4709 return BitCast(d, IfThenZeroElse(msb, lookup)); 4710 } 4711 4712 // ------------------------------ Broadcast 4713 4714 #ifdef HWY_NATIVE_BROADCASTLANE 4715 #undef HWY_NATIVE_BROADCASTLANE 4716 #else 4717 #define HWY_NATIVE_BROADCASTLANE 4718 #endif 4719 4720 namespace detail { 4721 #define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ 4722 template <int kLane> \ 4723 HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ 4724 return sv##OP##_##CHAR##BITS(v, kLane); \ 4725 } 4726 4727 HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) 4728 #undef HWY_SVE_BROADCAST 4729 } // namespace detail 4730 4731 template <int kLane, class V> 4732 HWY_API V Broadcast(const V v) { 4733 const DFromV<V> d; 4734 const RebindToUnsigned<decltype(d)> du; 4735 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); 4736 static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); 4737 #if HWY_TARGET == HWY_SVE2_128 4738 return detail::BroadcastLane<kLane>(v); 4739 #else 4740 auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); 4741 if (kLane != 0) { 4742 idx = detail::AddN(idx, kLane); 4743 } 4744 return TableLookupLanes(v, idx); 4745 #endif 4746 } 4747 4748 template <int kLane, class V> 4749 HWY_API V BroadcastLane(const V v) { 4750 static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane"); 4751 return detail::BroadcastLane<kLane>(v); 4752 } 4753 4754 // ------------------------------ ShiftLeftLanes 4755 4756 template <size_t kLanes, class D, class V = VFromD<D>> 4757 HWY_API V ShiftLeftLanes(D d, const V v) { 4758 const auto zero = Zero(d); 4759 const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); 4760 #if HWY_TARGET == HWY_SVE2_128 4761 return shifted; 4762 #else 4763 // Match x86 semantics by zeroing lower lanes in 128-bit blocks 4764 return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted); 4765 #endif 4766 } 4767 4768 template <size_t kLanes, class V> 4769 HWY_API V ShiftLeftLanes(const V v) { 4770 return ShiftLeftLanes<kLanes>(DFromV<V>(), v); 4771 } 4772 4773 // ------------------------------ ShiftRightLanes 4774 template <size_t kLanes, class D, class V = VFromD<D>> 4775 HWY_API V ShiftRightLanes(D d, V v) { 4776 // For capped/fractional vectors, clear upper lanes so we shift in zeros. 
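// (Ext below would otherwise pull in live hardware lanes beyond Lanes(d).)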
4777 if (!detail::IsFull(d)) { 4778 v = IfThenElseZero(detail::MakeMask(d), v); 4779 } 4780 4781 #if HWY_TARGET == HWY_SVE2_128 4782 return detail::Ext<kLanes>(Zero(d), v); 4783 #else 4784 const auto shifted = detail::Ext<kLanes>(v, v); 4785 // Match x86 semantics by zeroing upper lanes in 128-bit blocks 4786 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); 4787 const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d); 4788 return IfThenElseZero(mask, shifted); 4789 #endif 4790 } 4791 4792 // ------------------------------ ShiftLeftBytes 4793 4794 template <int kBytes, class D, class V = VFromD<D>> 4795 HWY_API V ShiftLeftBytes(const D d, const V v) { 4796 const Repartition<uint8_t, decltype(d)> d8; 4797 return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v))); 4798 } 4799 4800 template <int kBytes, class V> 4801 HWY_API V ShiftLeftBytes(const V v) { 4802 return ShiftLeftBytes<kBytes>(DFromV<V>(), v); 4803 } 4804 4805 // ------------------------------ ShiftRightBytes 4806 template <int kBytes, class D, class V = VFromD<D>> 4807 HWY_API V ShiftRightBytes(const D d, const V v) { 4808 const Repartition<uint8_t, decltype(d)> d8; 4809 return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v))); 4810 } 4811 4812 // ------------------------------ ZipLower 4813 4814 template <class V, class DW = RepartitionToWide<DFromV<V>>> 4815 HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { 4816 const RepartitionToNarrow<DW> dn; 4817 static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); 4818 return BitCast(dw, InterleaveLower(dn, a, b)); 4819 } 4820 template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> 4821 HWY_API VFromD<DW> ZipLower(const V a, const V b) { 4822 return BitCast(DW(), InterleaveLower(D(), a, b)); 4823 } 4824 4825 // ------------------------------ ZipUpper 4826 template <class V, class DW = RepartitionToWide<DFromV<V>>> 4827 HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { 4828 const RepartitionToNarrow<DW> dn; 4829 static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); 4830 return BitCast(dw, InterleaveUpper(dn, a, b)); 4831 } 4832 4833 // ================================================== Ops with dependencies 4834 4835 // ------------------------------ AddSub (Reverse2) 4836 4837 // NOTE: svcadd_f*_x(HWY_SVE_PTRUE(BITS), a, b, 90) computes a[i] - b[i + 1] in 4838 // the even lanes and a[i] + b[i - 1] in the odd lanes. 4839 4840 #define HWY_SVE_ADDSUB_F(BASE, CHAR, BITS, HALF, NAME, OP) \ 4841 HWY_API HWY_SVE_V(BASE, BITS) \ 4842 NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ 4843 const DFromV<decltype(b)> d; \ 4844 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, Reverse2(d, b), \ 4845 90); \ 4846 } 4847 4848 HWY_SVE_FOREACH_F(HWY_SVE_ADDSUB_F, AddSub, cadd) 4849 4850 #undef HWY_SVE_ADDSUB_F 4851 4852 // NOTE: svcadd_s*(a, b, 90) and svcadd_u*(a, b, 90) compute a[i] - b[i + 1] in 4853 // the even lanes and a[i] + b[i - 1] in the odd lanes. 
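// Reversing adjacent lanes of b therefore turns CADD #90 into AddSub, i.e.
// even lanes compute a[i] - b[i] and odd lanes a[i] + b[i]. Illustrative
// sketch: a = {1, 2, 3, 4}, b = {10, 20, 30, 40} -> {-9, 22, -27, 44}.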

#if HWY_SVE_HAVE_2
#define HWY_SVE_ADDSUB_UI(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    const DFromV<decltype(b)> d;                               \
    return sv##OP##_##CHAR##BITS(a, Reverse2(d, b), 90);       \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_ADDSUB_UI, AddSub, cadd)

#undef HWY_SVE_ADDSUB_UI

// Disable the default implementation of AddSub in generic_ops-inl.h on SVE2
#undef HWY_IF_ADDSUB_V
#define HWY_IF_ADDSUB_V(V)         \
  HWY_IF_LANES_GT_D(DFromV<V>, 1), \
      hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr

#else  // !HWY_SVE_HAVE_2

// Disable the default implementation of AddSub in generic_ops-inl.h for
// floating-point vectors on SVE, but enable the default implementation of
// AddSub in generic_ops-inl.h for integer vectors on SVE targets that do not
// support SVE2
#undef HWY_IF_ADDSUB_V
#define HWY_IF_ADDSUB_V(V) \
  HWY_IF_LANES_GT_D(DFromV<V>, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)

#endif  // HWY_SVE_HAVE_2

// ------------------------------ MulAddSub (AddSub)

template <class V, HWY_IF_LANES_GT_D(DFromV<V>, 1), HWY_IF_FLOAT_V(V)>
HWY_API V MulAddSub(V mul, V x, V sub_or_add) {
  using T = TFromV<V>;

  const DFromV<V> d;
  const T neg_zero = ConvertScalarTo<T>(-0.0f);

  return MulAdd(mul, x, AddSub(Set(d, neg_zero), sub_or_add));
}

#if HWY_SVE_HAVE_2

// Disable the default implementation of MulAddSub in generic_ops-inl.h on SVE2
#undef HWY_IF_MULADDSUB_V
#define HWY_IF_MULADDSUB_V(V)      \
  HWY_IF_LANES_GT_D(DFromV<V>, 1), \
      hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr

template <class V, HWY_IF_LANES_GT_D(DFromV<V>, 1),
          HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V MulAddSub(V mul, V x, V sub_or_add) {
  const DFromV<V> d;
  return MulAdd(mul, x, AddSub(Zero(d), sub_or_add));
}

#else  // !HWY_SVE_HAVE_2

// Disable the default implementation of MulAddSub in generic_ops-inl.h for
// floating-point vectors on SVE, but enable the default implementation of
// MulAddSub in generic_ops-inl.h for integer vectors on SVE targets that do
// not support SVE2
#undef HWY_IF_MULADDSUB_V
#define HWY_IF_MULADDSUB_V(V) \
  HWY_IF_LANES_GT_D(DFromV<V>, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)

#endif  // HWY_SVE_HAVE_2

// ------------------------------ PromoteTo bfloat16 (ZipLower)
template <size_t N, int kPow2>
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> df32, VBF16 v) {
  const ScalableTag<uint16_t> du16;
  return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), BitCast(du16, v)));
}

// ------------------------------ PromoteEvenTo/PromoteOddTo (ConcatOddFull)

namespace detail {

// Signed to signed PromoteEvenTo
template <class D>
HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/,
                                   hwy::SizeTag<2> /*to_lane_size_tag*/,
                                   hwy::SignedTag /*from_type_tag*/, D d_to,
                                   svint8_t v) {
  return svextb_s16_x(detail::PTrue(d_to), BitCast(d_to, v));
}

template <class D>
HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/,
                                   hwy::SizeTag<4> /*to_lane_size_tag*/,
                                   hwy::SignedTag /*from_type_tag*/, D d_to,
                                   svint16_t v) {
  return svexth_s32_x(detail::PTrue(d_to), BitCast(d_to, v));
}

template <class D>
HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, 4954 hwy::SizeTag<8> /*to_lane_size_tag*/, 4955 hwy::SignedTag /*from_type_tag*/, D d_to, 4956 svint32_t v) { 4957 return svextw_s64_x(detail::PTrue(d_to), BitCast(d_to, v)); 4958 } 4959 4960 // F16->F32 PromoteEvenTo 4961 template <class D> 4962 HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, 4963 hwy::SizeTag<4> /*to_lane_size_tag*/, 4964 hwy::FloatTag /*from_type_tag*/, D d_to, 4965 svfloat16_t v) { 4966 const Repartition<float, decltype(d_to)> d_from; 4967 return svcvt_f32_f16_x(detail::PTrue(d_from), v); 4968 } 4969 4970 // F32->F64 PromoteEvenTo 4971 template <class D> 4972 HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, 4973 hwy::SizeTag<8> /*to_lane_size_tag*/, 4974 hwy::FloatTag /*from_type_tag*/, D d_to, 4975 svfloat32_t v) { 4976 const Repartition<float, decltype(d_to)> d_from; 4977 return svcvt_f64_f32_x(detail::PTrue(d_from), v); 4978 } 4979 4980 // I32->F64 PromoteEvenTo 4981 template <class D> 4982 HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, 4983 hwy::SizeTag<8> /*to_lane_size_tag*/, 4984 hwy::SignedTag /*from_type_tag*/, D d_to, 4985 svint32_t v) { 4986 const Repartition<float, decltype(d_to)> d_from; 4987 return svcvt_f64_s32_x(detail::PTrue(d_from), v); 4988 } 4989 4990 // U32->F64 PromoteEvenTo 4991 template <class D> 4992 HWY_INLINE VFromD<D> PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, 4993 hwy::SizeTag<8> /*to_lane_size_tag*/, 4994 hwy::UnsignedTag /*from_type_tag*/, D d_to, 4995 svuint32_t v) { 4996 const Repartition<float, decltype(d_to)> d_from; 4997 return svcvt_f64_u32_x(detail::PTrue(d_from), v); 4998 } 4999 5000 // F32->I64 PromoteEvenTo 5001 template <class D> 5002 HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, 5003 hwy::SizeTag<8> /*to_lane_size_tag*/, 5004 hwy::FloatTag /*from_type_tag*/, D d_to, 5005 svfloat32_t v) { 5006 const Repartition<float, decltype(d_to)> d_from; 5007 return svcvt_s64_f32_x(detail::PTrue(d_from), v); 5008 } 5009 5010 // F32->U64 PromoteEvenTo 5011 template <class D> 5012 HWY_INLINE VFromD<D> PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, 5013 hwy::SizeTag<8> /*to_lane_size_tag*/, 5014 hwy::FloatTag /*from_type_tag*/, D d_to, 5015 svfloat32_t v) { 5016 const Repartition<float, decltype(d_to)> d_from; 5017 return svcvt_u64_f32_x(detail::PTrue(d_from), v); 5018 } 5019 5020 // F16->F32 PromoteOddTo 5021 template <class D> 5022 HWY_INLINE VFromD<D> PromoteOddTo(hwy::FloatTag to_type_tag, 5023 hwy::SizeTag<4> to_lane_size_tag, 5024 hwy::FloatTag from_type_tag, D d_to, 5025 svfloat16_t v) { 5026 return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, 5027 DupOdd(v)); 5028 } 5029 5030 // I32/U32/F32->F64 PromoteOddTo 5031 template <class FromTypeTag, class D, class V> 5032 HWY_INLINE VFromD<D> PromoteOddTo(hwy::FloatTag to_type_tag, 5033 hwy::SizeTag<8> to_lane_size_tag, 5034 FromTypeTag from_type_tag, D d_to, V v) { 5035 return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, 5036 DupOdd(v)); 5037 } 5038 5039 // F32->I64/U64 PromoteOddTo 5040 template <class ToTypeTag, class D, HWY_IF_UI64_D(D)> 5041 HWY_INLINE VFromD<D> PromoteOddTo(ToTypeTag to_type_tag, 5042 hwy::SizeTag<8> to_lane_size_tag, 5043 hwy::FloatTag from_type_tag, D d_to, 5044 svfloat32_t v) { 5045 return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, 5046 DupOdd(v)); 5047 } 5048 5049 } // namespace detail 5050 5051 // ------------------------------ ReorderDemote2To (OddEven) 5052 
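// ReorderDemote2To demotes the lanes of a and b into a single vector whose
// lane order is implementation-defined; OrderedDemote2To (below) additionally
// guarantees that the lower half holds the demoted a and the upper half the
// demoted b. Illustrative sketch of the non-SVE2 i32 -> i16 path, which
// interleaves the saturated even/odd results:
//   a = {a0, a1, ..}, b = {b0, b1, ..}
//   -> {sat(a0), sat(b0), sat(a1), sat(b1), ..}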
5053 template <size_t N, int kPow2> 5054 HWY_API VBF16 ReorderDemote2To(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t a, 5055 svfloat32_t b) { 5056 #if HWY_SVE_HAVE_F32_TO_BF16C 5057 const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); 5058 return svcvtnt_bf16_f32_x(b_in_even, detail::PTrue(dbf16), a); 5059 #else 5060 (void)dbf16; 5061 const auto a_in_odd = 5062 BitCast(ScalableTag<uint16_t>(), detail::RoundF32ForDemoteToBF16(a)); 5063 const auto b_in_odd = 5064 BitCast(ScalableTag<uint16_t>(), detail::RoundF32ForDemoteToBF16(b)); 5065 return BitCast(dbf16, detail::InterleaveOdd(b_in_odd, a_in_odd)); 5066 #endif 5067 } 5068 5069 template <size_t N, int kPow2> 5070 HWY_API svint16_t ReorderDemote2To(Simd<int16_t, N, kPow2> d16, svint32_t a, 5071 svint32_t b) { 5072 #if HWY_SVE_HAVE_2 5073 (void)d16; 5074 const svint16_t a_in_even = svqxtnb_s32(a); 5075 return svqxtnt_s32(a_in_even, b); 5076 #else 5077 const svint16_t a16 = BitCast(d16, detail::SaturateI<int16_t>(a)); 5078 const svint16_t b16 = BitCast(d16, detail::SaturateI<int16_t>(b)); 5079 return detail::InterleaveEven(a16, b16); 5080 #endif 5081 } 5082 5083 template <size_t N, int kPow2> 5084 HWY_API svuint16_t ReorderDemote2To(Simd<uint16_t, N, kPow2> d16, svint32_t a, 5085 svint32_t b) { 5086 #if HWY_SVE_HAVE_2 5087 (void)d16; 5088 const svuint16_t a_in_even = svqxtunb_s32(a); 5089 return svqxtunt_s32(a_in_even, b); 5090 #else 5091 const Repartition<uint32_t, decltype(d16)> du32; 5092 const svuint32_t clamped_a = BitCast(du32, detail::MaxN(a, 0)); 5093 const svuint32_t clamped_b = BitCast(du32, detail::MaxN(b, 0)); 5094 const svuint16_t a16 = BitCast(d16, detail::SaturateU<uint16_t>(clamped_a)); 5095 const svuint16_t b16 = BitCast(d16, detail::SaturateU<uint16_t>(clamped_b)); 5096 return detail::InterleaveEven(a16, b16); 5097 #endif 5098 } 5099 5100 template <size_t N, int kPow2> 5101 HWY_API svuint16_t ReorderDemote2To(Simd<uint16_t, N, kPow2> d16, svuint32_t a, 5102 svuint32_t b) { 5103 #if HWY_SVE_HAVE_2 5104 (void)d16; 5105 const svuint16_t a_in_even = svqxtnb_u32(a); 5106 return svqxtnt_u32(a_in_even, b); 5107 #else 5108 const svuint16_t a16 = BitCast(d16, detail::SaturateU<uint16_t>(a)); 5109 const svuint16_t b16 = BitCast(d16, detail::SaturateU<uint16_t>(b)); 5110 return detail::InterleaveEven(a16, b16); 5111 #endif 5112 } 5113 5114 template <size_t N, int kPow2> 5115 HWY_API svint8_t ReorderDemote2To(Simd<int8_t, N, kPow2> d8, svint16_t a, 5116 svint16_t b) { 5117 #if HWY_SVE_HAVE_2 5118 (void)d8; 5119 const svint8_t a_in_even = svqxtnb_s16(a); 5120 return svqxtnt_s16(a_in_even, b); 5121 #else 5122 const svint8_t a8 = BitCast(d8, detail::SaturateI<int8_t>(a)); 5123 const svint8_t b8 = BitCast(d8, detail::SaturateI<int8_t>(b)); 5124 return detail::InterleaveEven(a8, b8); 5125 #endif 5126 } 5127 5128 template <size_t N, int kPow2> 5129 HWY_API svuint8_t ReorderDemote2To(Simd<uint8_t, N, kPow2> d8, svint16_t a, 5130 svint16_t b) { 5131 #if HWY_SVE_HAVE_2 5132 (void)d8; 5133 const svuint8_t a_in_even = svqxtunb_s16(a); 5134 return svqxtunt_s16(a_in_even, b); 5135 #else 5136 const Repartition<uint16_t, decltype(d8)> du16; 5137 const svuint16_t clamped_a = BitCast(du16, detail::MaxN(a, 0)); 5138 const svuint16_t clamped_b = BitCast(du16, detail::MaxN(b, 0)); 5139 const svuint8_t a8 = BitCast(d8, detail::SaturateU<uint8_t>(clamped_a)); 5140 const svuint8_t b8 = BitCast(d8, detail::SaturateU<uint8_t>(clamped_b)); 5141 return detail::InterleaveEven(a8, b8); 5142 #endif 5143 } 5144 5145 template <size_t N, int kPow2> 5146 
HWY_API svuint8_t ReorderDemote2To(Simd<uint8_t, N, kPow2> d8, svuint16_t a, 5147 svuint16_t b) { 5148 #if HWY_SVE_HAVE_2 5149 (void)d8; 5150 const svuint8_t a_in_even = svqxtnb_u16(a); 5151 return svqxtnt_u16(a_in_even, b); 5152 #else 5153 const svuint8_t a8 = BitCast(d8, detail::SaturateU<uint8_t>(a)); 5154 const svuint8_t b8 = BitCast(d8, detail::SaturateU<uint8_t>(b)); 5155 return detail::InterleaveEven(a8, b8); 5156 #endif 5157 } 5158 5159 template <size_t N, int kPow2> 5160 HWY_API svint32_t ReorderDemote2To(Simd<int32_t, N, kPow2> d32, svint64_t a, 5161 svint64_t b) { 5162 #if HWY_SVE_HAVE_2 5163 (void)d32; 5164 const svint32_t a_in_even = svqxtnb_s64(a); 5165 return svqxtnt_s64(a_in_even, b); 5166 #else 5167 const svint32_t a32 = BitCast(d32, detail::SaturateI<int32_t>(a)); 5168 const svint32_t b32 = BitCast(d32, detail::SaturateI<int32_t>(b)); 5169 return detail::InterleaveEven(a32, b32); 5170 #endif 5171 } 5172 5173 template <size_t N, int kPow2> 5174 HWY_API svuint32_t ReorderDemote2To(Simd<uint32_t, N, kPow2> d32, svint64_t a, 5175 svint64_t b) { 5176 #if HWY_SVE_HAVE_2 5177 (void)d32; 5178 const svuint32_t a_in_even = svqxtunb_s64(a); 5179 return svqxtunt_s64(a_in_even, b); 5180 #else 5181 const Repartition<uint64_t, decltype(d32)> du64; 5182 const svuint64_t clamped_a = BitCast(du64, detail::MaxN(a, 0)); 5183 const svuint64_t clamped_b = BitCast(du64, detail::MaxN(b, 0)); 5184 const svuint32_t a32 = BitCast(d32, detail::SaturateU<uint32_t>(clamped_a)); 5185 const svuint32_t b32 = BitCast(d32, detail::SaturateU<uint32_t>(clamped_b)); 5186 return detail::InterleaveEven(a32, b32); 5187 #endif 5188 } 5189 5190 template <size_t N, int kPow2> 5191 HWY_API svuint32_t ReorderDemote2To(Simd<uint32_t, N, kPow2> d32, svuint64_t a, 5192 svuint64_t b) { 5193 #if HWY_SVE_HAVE_2 5194 (void)d32; 5195 const svuint32_t a_in_even = svqxtnb_u64(a); 5196 return svqxtnt_u64(a_in_even, b); 5197 #else 5198 const svuint32_t a32 = BitCast(d32, detail::SaturateU<uint32_t>(a)); 5199 const svuint32_t b32 = BitCast(d32, detail::SaturateU<uint32_t>(b)); 5200 return detail::InterleaveEven(a32, b32); 5201 #endif 5202 } 5203 5204 template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V), 5205 HWY_IF_T_SIZE_D(D, sizeof(TFromV<V>) / 2)> 5206 HWY_API VFromD<D> ReorderDemote2To(D dn, V a, V b) { 5207 const auto clamped_a = BitCast(dn, detail::SaturateU<TFromD<D>>(a)); 5208 const auto clamped_b = BitCast(dn, detail::SaturateU<TFromD<D>>(b)); 5209 return detail::InterleaveEven(clamped_a, clamped_b); 5210 } 5211 5212 template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>), 5213 HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 5214 HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2)> 5215 HWY_API VFromD<D> OrderedDemote2To(D dn, V a, V b) { 5216 const Half<decltype(dn)> dnh; 5217 const auto demoted_a = DemoteTo(dnh, a); 5218 const auto demoted_b = DemoteTo(dnh, b); 5219 return Combine(dn, demoted_b, demoted_a); 5220 } 5221 5222 template <size_t N, int kPow2> 5223 HWY_API VBF16 OrderedDemote2To(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t a, 5224 svfloat32_t b) { 5225 #if HWY_SVE_HAVE_F32_TO_BF16C 5226 (void)dbf16; 5227 const VBF16 a_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), a); 5228 const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); 5229 return ConcatEven(dbf16, b_in_even, a_in_even); 5230 #else 5231 const RebindToUnsigned<decltype(dbf16)> du16; 5232 const svuint16_t a_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); 5233 const svuint16_t b_in_odd = BitCast(du16, 
detail::RoundF32ForDemoteToBF16(b)); 5234 return BitCast(dbf16, ConcatOdd(du16, b_in_odd, a_in_odd)); // lower half 5235 #endif 5236 } 5237 5238 // ------------------------------ I8/U8/I16/U16 Div 5239 5240 template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), 5241 HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))> 5242 HWY_API V Div(V a, V b) { 5243 const DFromV<decltype(a)> d; 5244 const Half<decltype(d)> dh; 5245 const RepartitionToWide<decltype(d)> dw; 5246 5247 const auto q_lo = 5248 Div(PromoteTo(dw, LowerHalf(dh, a)), PromoteTo(dw, LowerHalf(dh, b))); 5249 const auto q_hi = Div(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)); 5250 5251 return OrderedDemote2To(d, q_lo, q_hi); 5252 } 5253 5254 // ------------------------------ I8/U8/I16/U16 MaskedDivOr 5255 template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), 5256 HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> 5257 HWY_API V MaskedDivOr(V no, M m, V a, V b) { 5258 return IfThenElse(m, Div(a, b), no); 5259 } 5260 5261 template <class V, class M, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2)), 5262 HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> 5263 HWY_API V MaskedDiv(M m, V a, V b) { 5264 return IfThenElseZero(m, Div(a, b)); 5265 } 5266 5267 // ------------------------------ Mod (Div, NegMulAdd) 5268 template <class V> 5269 HWY_API V Mod(V a, V b) { 5270 return NegMulAdd(Div(a, b), b, a); 5271 } 5272 5273 // ------------------------------ MaskedModOr (Mod) 5274 template <class V, class M> 5275 HWY_API V MaskedModOr(V no, M m, V a, V b) { 5276 return IfThenElse(m, Mod(a, b), no); 5277 } 5278 5279 // ------------------------------ IfNegativeThenElse (BroadcastSignBit) 5280 template <class V> 5281 HWY_API V IfNegativeThenElse(V v, V yes, V no) { 5282 static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float"); 5283 return IfThenElse(IsNegative(v), yes, no); 5284 } 5285 // ------------------------------ IfNegativeThenNegOrUndefIfZero 5286 5287 #ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG 5288 #undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG 5289 #else 5290 #define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG 5291 #endif 5292 5293 #define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \ 5294 HWY_API HWY_SVE_V(BASE, BITS) \ 5295 NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \ 5296 return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \ 5297 } 5298 5299 HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg) 5300 5301 #undef HWY_SVE_NEG_IF 5302 5303 // ------------------------------ AverageRound (ShiftRight) 5304 5305 #ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 5306 #undef HWY_NATIVE_AVERAGE_ROUND_UI32 5307 #else 5308 #define HWY_NATIVE_AVERAGE_ROUND_UI32 5309 #endif 5310 5311 #ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 5312 #undef HWY_NATIVE_AVERAGE_ROUND_UI64 5313 #else 5314 #define HWY_NATIVE_AVERAGE_ROUND_UI64 5315 #endif 5316 5317 #if HWY_SVE_HAVE_2 5318 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) 5319 #else 5320 template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> 5321 HWY_API V AverageRound(const V a, const V b) { 5322 return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); 5323 } 5324 #endif // HWY_SVE_HAVE_2 5325 5326 // ------------------------------ LoadMaskBits (TestBit) 5327 5328 // `p` points to at least 8 readable bytes, not all of which need be valid. 
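// Bit i of the packed input governs lane i (LSB first), the inverse of
// StoreMaskBits below. Illustrative sketch for 4 lanes:
//   bits[0] = 0b0101 -> mask lanes {true, false, true, false}.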
5329 template <class D, HWY_IF_T_SIZE_D(D, 1)> 5330 HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { 5331 #if HWY_COMPILER_CLANG >= 2200 || HWY_COMPILER_GCC_ACTUAL >= 1200 5332 typedef svbool_t UnalignedSveMaskT 5333 __attribute__((__aligned__(1), __may_alias__)); 5334 (void)d; 5335 return *reinterpret_cast<const UnalignedSveMaskT*>(bits); 5336 #else 5337 // TODO(janwas): with SVE2.1, load to vector, then PMOV 5338 const RebindToUnsigned<D> du; 5339 const svuint8_t iota = Iota(du, 0); 5340 5341 // Load correct number of bytes (bits/8) with 7 zeros after each. 5342 const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); 5343 // Replicate bytes 8x such that each byte contains the bit that governs it. 5344 const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); 5345 5346 const svuint8_t bit = 5347 svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); 5348 return TestBit(rep8, bit); 5349 #endif 5350 } 5351 5352 template <class D, HWY_IF_T_SIZE_D(D, 2)> 5353 HWY_INLINE svbool_t LoadMaskBits(D /* tag */, 5354 const uint8_t* HWY_RESTRICT bits) { 5355 const RebindToUnsigned<D> du; 5356 const Repartition<uint8_t, D> du8; 5357 5358 // There may be up to 128 bits; avoid reading past the end. 5359 const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); 5360 5361 // Replicate bytes 16x such that each lane contains the bit that governs it. 5362 const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); 5363 5364 const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); 5365 return TestBit(BitCast(du, rep16), bit); 5366 } 5367 5368 template <class D, HWY_IF_T_SIZE_D(D, 4)> 5369 HWY_INLINE svbool_t LoadMaskBits(D /* tag */, 5370 const uint8_t* HWY_RESTRICT bits) { 5371 const RebindToUnsigned<D> du; 5372 const Repartition<uint8_t, D> du8; 5373 5374 // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, 5375 // so we can skip computing the actual length (Lanes(du)+7)/8. 5376 const svuint8_t bytes = svld1(FirstN(du8, 8), bits); 5377 5378 // Replicate bytes 32x such that each lane contains the bit that governs it. 5379 const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); 5380 5381 // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. 5382 const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); 5383 5384 return TestBit(BitCast(du, rep32), bit); 5385 } 5386 5387 template <class D, HWY_IF_T_SIZE_D(D, 8)> 5388 HWY_INLINE svbool_t LoadMaskBits(D /* tag */, 5389 const uint8_t* HWY_RESTRICT bits) { 5390 const RebindToUnsigned<D> du; 5391 5392 // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. 5393 // The "at least 8 byte" guarantee in quick_reference ensures this is safe. 5394 uint32_t mask_bits; 5395 CopyBytes<4>(bits, &mask_bits); // copy from bytes 5396 const auto vbits = Set(du, mask_bits); 5397 5398 // 2 ^ {0,1, .., 31}, will not have more lanes than that. 
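  // Lane i then tests bit i of the broadcast mask_bits: lane i is true iff
  // (mask_bits >> i) & 1, e.g. mask_bits = 0b110 -> {false, true, true, ..}.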
5399 const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); 5400 5401 return TestBit(vbits, bit); 5402 } 5403 5404 // ------------------------------ Dup128MaskFromMaskBits 5405 5406 template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_LE_D(D, 8)> 5407 HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { 5408 const RebindToUnsigned<decltype(d)> du; 5409 5410 constexpr size_t kN = MaxLanes(d); 5411 if (kN < 8) mask_bits &= (1u << kN) - 1; 5412 5413 // Replicate the lower 8 bits of mask_bits to each u8 lane 5414 const svuint8_t bytes = BitCast(du, Set(du, static_cast<uint8_t>(mask_bits))); 5415 5416 const svuint8_t bit = 5417 svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); 5418 return TestBit(bytes, bit); 5419 } 5420 5421 template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_GT_D(D, 8)> 5422 HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { 5423 const RebindToUnsigned<decltype(d)> du; 5424 const Repartition<uint16_t, decltype(du)> du16; 5425 5426 // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, 5427 // and then bitcast the replicated mask_bits to a u8 vector 5428 const svuint8_t bytes = 5429 BitCast(du, Set(du16, static_cast<uint16_t>(mask_bits))); 5430 // Replicate bytes 8x such that each byte contains the bit that governs it. 5431 const svuint8_t rep8 = svtbl_u8(bytes, ShiftRight<3>(Iota(du, 0))); 5432 5433 const svuint8_t bit = 5434 svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); 5435 return TestBit(rep8, bit); 5436 } 5437 5438 template <class D, HWY_IF_T_SIZE_D(D, 2)> 5439 HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { 5440 const RebindToUnsigned<decltype(d)> du; 5441 const Repartition<uint8_t, decltype(d)> du8; 5442 5443 constexpr size_t kN = MaxLanes(d); 5444 if (kN < 8) mask_bits &= (1u << kN) - 1; 5445 5446 // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits 5447 const svuint8_t bytes = Set(du8, static_cast<uint8_t>(mask_bits)); 5448 5449 const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); 5450 return TestBit(BitCast(du, bytes), bit); 5451 } 5452 5453 template <class D, HWY_IF_T_SIZE_D(D, 4)> 5454 HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { 5455 const RebindToUnsigned<decltype(d)> du; 5456 const Repartition<uint8_t, decltype(d)> du8; 5457 5458 constexpr size_t kN = MaxLanes(d); 5459 if (kN < 4) mask_bits &= (1u << kN) - 1; 5460 5461 // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits 5462 const svuint8_t bytes = Set(du8, static_cast<uint8_t>(mask_bits)); 5463 5464 const svuint32_t bit = svdupq_n_u32(1, 2, 4, 8); 5465 return TestBit(BitCast(du, bytes), bit); 5466 } 5467 5468 template <class D, HWY_IF_T_SIZE_D(D, 8)> 5469 HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) { 5470 const RebindToUnsigned<decltype(d)> du; 5471 const Repartition<uint8_t, decltype(d)> du8; 5472 5473 if (MaxLanes(d) < 2) mask_bits &= 1u; 5474 5475 // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits 5476 const svuint8_t bytes = Set(du8, static_cast<uint8_t>(mask_bits)); 5477 5478 const svuint64_t bit = svdupq_n_u64(1, 2); 5479 return TestBit(BitCast(du, bytes), bit); 5480 } 5481 5482 // ------------------------------ StoreMaskBits (BitsFromMask) 5483 5484 // `p` points to at least 8 writable bytes. 
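// Illustrative sketch for 4 lanes: mask lanes {true, false, true, true} are
// packed LSB-first into bits[0] = 0b1101 and the return value is 1 (bytes
// written); undefined upper bits of the final byte are cleared.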
5485 // TODO(janwas): with SVE2.1, use PMOV to store to vector, then StoreU 5486 template <class D> 5487 HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { 5488 #if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 5489 constexpr size_t N = MaxLanes(d); 5490 const uint64_t bits64 = BitsFromMask(d, m); 5491 HWY_IF_CONSTEXPR(N < 8) { 5492 // BitsFromMask guarantees upper bits are zero, hence no masking. 5493 bits[0] = static_cast<uint8_t>(bits64); 5494 } 5495 else { 5496 static_assert(N % 8 == 0, "N is pow2 >= 8, hence divisible"); 5497 static_assert(HWY_IS_LITTLE_ENDIAN, ""); 5498 hwy::CopyBytes<N / 8>(&bits64, bits); 5499 } 5500 constexpr size_t num_bytes = hwy::DivCeil(N, size_t{8}); 5501 return num_bytes; 5502 #else 5503 svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask<D>(m)); 5504 5505 const size_t num_bits = Lanes(d); 5506 const size_t num_bytes = hwy::DivCeil(num_bits, size_t{8}); 5507 5508 // Truncate each u64 to 8 bits and store to u8. 5509 svst1b_u64(FirstN(ScalableTag<uint64_t>(), num_bytes), bits, bits_in_u64); 5510 5511 // Non-full byte, need to clear the undefined upper bits. Can happen for 5512 // capped/fractional vectors or large T and small hardware vectors. 5513 if (num_bits < 8) { 5514 const int mask = static_cast<int>((1ull << num_bits) - 1); 5515 bits[0] = static_cast<uint8_t>(bits[0] & mask); 5516 } 5517 // Else: we wrote full bytes because num_bits is a power of two >= 8. 5518 5519 return num_bytes; 5520 #endif // HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 5521 } 5522 5523 // ------------------------------ CompressBits (LoadMaskBits) 5524 template <class V, HWY_IF_NOT_T_SIZE_V(V, 1)> 5525 HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { 5526 return Compress(v, LoadMaskBits(DFromV<V>(), bits)); 5527 } 5528 5529 // ------------------------------ CompressBitsStore (LoadMaskBits) 5530 template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)> 5531 HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits, 5532 D d, TFromD<D>* HWY_RESTRICT unaligned) { 5533 return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); 5534 } 5535 5536 // ------------------------------ Expand (StoreMaskBits) 5537 5538 #ifdef HWY_NATIVE_EXPAND 5539 #undef HWY_NATIVE_EXPAND 5540 #else 5541 #define HWY_NATIVE_EXPAND 5542 #endif 5543 5544 namespace detail { 5545 5546 HWY_INLINE svuint8_t IndicesForExpandFromBits(uint64_t mask_bits) { 5547 const CappedTag<uint8_t, 8> du8; 5548 alignas(16) static constexpr uint8_t table[8 * 256] = { 5549 // PrintExpand8x8Tables 5550 128, 128, 128, 128, 128, 128, 128, 128, // 5551 0, 128, 128, 128, 128, 128, 128, 128, // 5552 128, 0, 128, 128, 128, 128, 128, 128, // 5553 0, 1, 128, 128, 128, 128, 128, 128, // 5554 128, 128, 0, 128, 128, 128, 128, 128, // 5555 0, 128, 1, 128, 128, 128, 128, 128, // 5556 128, 0, 1, 128, 128, 128, 128, 128, // 5557 0, 1, 2, 128, 128, 128, 128, 128, // 5558 128, 128, 128, 0, 128, 128, 128, 128, // 5559 0, 128, 128, 1, 128, 128, 128, 128, // 5560 128, 0, 128, 1, 128, 128, 128, 128, // 5561 0, 1, 128, 2, 128, 128, 128, 128, // 5562 128, 128, 0, 1, 128, 128, 128, 128, // 5563 0, 128, 1, 2, 128, 128, 128, 128, // 5564 128, 0, 1, 2, 128, 128, 128, 128, // 5565 0, 1, 2, 3, 128, 128, 128, 128, // 5566 128, 128, 128, 128, 0, 128, 128, 128, // 5567 0, 128, 128, 128, 1, 128, 128, 128, // 5568 128, 0, 128, 128, 1, 128, 128, 128, // 5569 0, 1, 128, 128, 2, 128, 128, 128, // 5570 128, 128, 0, 128, 1, 128, 128, 128, // 5571 0, 128, 1, 128, 2, 128, 128, 128, // 5572 128, 0, 
1, 128, 2, 128, 128, 128, // 5573 0, 1, 2, 128, 3, 128, 128, 128, // 5574 128, 128, 128, 0, 1, 128, 128, 128, // 5575 0, 128, 128, 1, 2, 128, 128, 128, // 5576 128, 0, 128, 1, 2, 128, 128, 128, // 5577 0, 1, 128, 2, 3, 128, 128, 128, // 5578 128, 128, 0, 1, 2, 128, 128, 128, // 5579 0, 128, 1, 2, 3, 128, 128, 128, // 5580 128, 0, 1, 2, 3, 128, 128, 128, // 5581 0, 1, 2, 3, 4, 128, 128, 128, // 5582 128, 128, 128, 128, 128, 0, 128, 128, // 5583 0, 128, 128, 128, 128, 1, 128, 128, // 5584 128, 0, 128, 128, 128, 1, 128, 128, // 5585 0, 1, 128, 128, 128, 2, 128, 128, // 5586 128, 128, 0, 128, 128, 1, 128, 128, // 5587 0, 128, 1, 128, 128, 2, 128, 128, // 5588 128, 0, 1, 128, 128, 2, 128, 128, // 5589 0, 1, 2, 128, 128, 3, 128, 128, // 5590 128, 128, 128, 0, 128, 1, 128, 128, // 5591 0, 128, 128, 1, 128, 2, 128, 128, // 5592 128, 0, 128, 1, 128, 2, 128, 128, // 5593 0, 1, 128, 2, 128, 3, 128, 128, // 5594 128, 128, 0, 1, 128, 2, 128, 128, // 5595 0, 128, 1, 2, 128, 3, 128, 128, // 5596 128, 0, 1, 2, 128, 3, 128, 128, // 5597 0, 1, 2, 3, 128, 4, 128, 128, // 5598 128, 128, 128, 128, 0, 1, 128, 128, // 5599 0, 128, 128, 128, 1, 2, 128, 128, // 5600 128, 0, 128, 128, 1, 2, 128, 128, // 5601 0, 1, 128, 128, 2, 3, 128, 128, // 5602 128, 128, 0, 128, 1, 2, 128, 128, // 5603 0, 128, 1, 128, 2, 3, 128, 128, // 5604 128, 0, 1, 128, 2, 3, 128, 128, // 5605 0, 1, 2, 128, 3, 4, 128, 128, // 5606 128, 128, 128, 0, 1, 2, 128, 128, // 5607 0, 128, 128, 1, 2, 3, 128, 128, // 5608 128, 0, 128, 1, 2, 3, 128, 128, // 5609 0, 1, 128, 2, 3, 4, 128, 128, // 5610 128, 128, 0, 1, 2, 3, 128, 128, // 5611 0, 128, 1, 2, 3, 4, 128, 128, // 5612 128, 0, 1, 2, 3, 4, 128, 128, // 5613 0, 1, 2, 3, 4, 5, 128, 128, // 5614 128, 128, 128, 128, 128, 128, 0, 128, // 5615 0, 128, 128, 128, 128, 128, 1, 128, // 5616 128, 0, 128, 128, 128, 128, 1, 128, // 5617 0, 1, 128, 128, 128, 128, 2, 128, // 5618 128, 128, 0, 128, 128, 128, 1, 128, // 5619 0, 128, 1, 128, 128, 128, 2, 128, // 5620 128, 0, 1, 128, 128, 128, 2, 128, // 5621 0, 1, 2, 128, 128, 128, 3, 128, // 5622 128, 128, 128, 0, 128, 128, 1, 128, // 5623 0, 128, 128, 1, 128, 128, 2, 128, // 5624 128, 0, 128, 1, 128, 128, 2, 128, // 5625 0, 1, 128, 2, 128, 128, 3, 128, // 5626 128, 128, 0, 1, 128, 128, 2, 128, // 5627 0, 128, 1, 2, 128, 128, 3, 128, // 5628 128, 0, 1, 2, 128, 128, 3, 128, // 5629 0, 1, 2, 3, 128, 128, 4, 128, // 5630 128, 128, 128, 128, 0, 128, 1, 128, // 5631 0, 128, 128, 128, 1, 128, 2, 128, // 5632 128, 0, 128, 128, 1, 128, 2, 128, // 5633 0, 1, 128, 128, 2, 128, 3, 128, // 5634 128, 128, 0, 128, 1, 128, 2, 128, // 5635 0, 128, 1, 128, 2, 128, 3, 128, // 5636 128, 0, 1, 128, 2, 128, 3, 128, // 5637 0, 1, 2, 128, 3, 128, 4, 128, // 5638 128, 128, 128, 0, 1, 128, 2, 128, // 5639 0, 128, 128, 1, 2, 128, 3, 128, // 5640 128, 0, 128, 1, 2, 128, 3, 128, // 5641 0, 1, 128, 2, 3, 128, 4, 128, // 5642 128, 128, 0, 1, 2, 128, 3, 128, // 5643 0, 128, 1, 2, 3, 128, 4, 128, // 5644 128, 0, 1, 2, 3, 128, 4, 128, // 5645 0, 1, 2, 3, 4, 128, 5, 128, // 5646 128, 128, 128, 128, 128, 0, 1, 128, // 5647 0, 128, 128, 128, 128, 1, 2, 128, // 5648 128, 0, 128, 128, 128, 1, 2, 128, // 5649 0, 1, 128, 128, 128, 2, 3, 128, // 5650 128, 128, 0, 128, 128, 1, 2, 128, // 5651 0, 128, 1, 128, 128, 2, 3, 128, // 5652 128, 0, 1, 128, 128, 2, 3, 128, // 5653 0, 1, 2, 128, 128, 3, 4, 128, // 5654 128, 128, 128, 0, 128, 1, 2, 128, // 5655 0, 128, 128, 1, 128, 2, 3, 128, // 5656 128, 0, 128, 1, 128, 2, 3, 128, // 5657 0, 1, 128, 2, 128, 3, 4, 128, // 5658 128, 128, 0, 1, 128, 2, 3, 128, // 5659 
0, 128, 1, 2, 128, 3, 4, 128, // 5660 128, 0, 1, 2, 128, 3, 4, 128, // 5661 0, 1, 2, 3, 128, 4, 5, 128, // 5662 128, 128, 128, 128, 0, 1, 2, 128, // 5663 0, 128, 128, 128, 1, 2, 3, 128, // 5664 128, 0, 128, 128, 1, 2, 3, 128, // 5665 0, 1, 128, 128, 2, 3, 4, 128, // 5666 128, 128, 0, 128, 1, 2, 3, 128, // 5667 0, 128, 1, 128, 2, 3, 4, 128, // 5668 128, 0, 1, 128, 2, 3, 4, 128, // 5669 0, 1, 2, 128, 3, 4, 5, 128, // 5670 128, 128, 128, 0, 1, 2, 3, 128, // 5671 0, 128, 128, 1, 2, 3, 4, 128, // 5672 128, 0, 128, 1, 2, 3, 4, 128, // 5673 0, 1, 128, 2, 3, 4, 5, 128, // 5674 128, 128, 0, 1, 2, 3, 4, 128, // 5675 0, 128, 1, 2, 3, 4, 5, 128, // 5676 128, 0, 1, 2, 3, 4, 5, 128, // 5677 0, 1, 2, 3, 4, 5, 6, 128, // 5678 128, 128, 128, 128, 128, 128, 128, 0, // 5679 0, 128, 128, 128, 128, 128, 128, 1, // 5680 128, 0, 128, 128, 128, 128, 128, 1, // 5681 0, 1, 128, 128, 128, 128, 128, 2, // 5682 128, 128, 0, 128, 128, 128, 128, 1, // 5683 0, 128, 1, 128, 128, 128, 128, 2, // 5684 128, 0, 1, 128, 128, 128, 128, 2, // 5685 0, 1, 2, 128, 128, 128, 128, 3, // 5686 128, 128, 128, 0, 128, 128, 128, 1, // 5687 0, 128, 128, 1, 128, 128, 128, 2, // 5688 128, 0, 128, 1, 128, 128, 128, 2, // 5689 0, 1, 128, 2, 128, 128, 128, 3, // 5690 128, 128, 0, 1, 128, 128, 128, 2, // 5691 0, 128, 1, 2, 128, 128, 128, 3, // 5692 128, 0, 1, 2, 128, 128, 128, 3, // 5693 0, 1, 2, 3, 128, 128, 128, 4, // 5694 128, 128, 128, 128, 0, 128, 128, 1, // 5695 0, 128, 128, 128, 1, 128, 128, 2, // 5696 128, 0, 128, 128, 1, 128, 128, 2, // 5697 0, 1, 128, 128, 2, 128, 128, 3, // 5698 128, 128, 0, 128, 1, 128, 128, 2, // 5699 0, 128, 1, 128, 2, 128, 128, 3, // 5700 128, 0, 1, 128, 2, 128, 128, 3, // 5701 0, 1, 2, 128, 3, 128, 128, 4, // 5702 128, 128, 128, 0, 1, 128, 128, 2, // 5703 0, 128, 128, 1, 2, 128, 128, 3, // 5704 128, 0, 128, 1, 2, 128, 128, 3, // 5705 0, 1, 128, 2, 3, 128, 128, 4, // 5706 128, 128, 0, 1, 2, 128, 128, 3, // 5707 0, 128, 1, 2, 3, 128, 128, 4, // 5708 128, 0, 1, 2, 3, 128, 128, 4, // 5709 0, 1, 2, 3, 4, 128, 128, 5, // 5710 128, 128, 128, 128, 128, 0, 128, 1, // 5711 0, 128, 128, 128, 128, 1, 128, 2, // 5712 128, 0, 128, 128, 128, 1, 128, 2, // 5713 0, 1, 128, 128, 128, 2, 128, 3, // 5714 128, 128, 0, 128, 128, 1, 128, 2, // 5715 0, 128, 1, 128, 128, 2, 128, 3, // 5716 128, 0, 1, 128, 128, 2, 128, 3, // 5717 0, 1, 2, 128, 128, 3, 128, 4, // 5718 128, 128, 128, 0, 128, 1, 128, 2, // 5719 0, 128, 128, 1, 128, 2, 128, 3, // 5720 128, 0, 128, 1, 128, 2, 128, 3, // 5721 0, 1, 128, 2, 128, 3, 128, 4, // 5722 128, 128, 0, 1, 128, 2, 128, 3, // 5723 0, 128, 1, 2, 128, 3, 128, 4, // 5724 128, 0, 1, 2, 128, 3, 128, 4, // 5725 0, 1, 2, 3, 128, 4, 128, 5, // 5726 128, 128, 128, 128, 0, 1, 128, 2, // 5727 0, 128, 128, 128, 1, 2, 128, 3, // 5728 128, 0, 128, 128, 1, 2, 128, 3, // 5729 0, 1, 128, 128, 2, 3, 128, 4, // 5730 128, 128, 0, 128, 1, 2, 128, 3, // 5731 0, 128, 1, 128, 2, 3, 128, 4, // 5732 128, 0, 1, 128, 2, 3, 128, 4, // 5733 0, 1, 2, 128, 3, 4, 128, 5, // 5734 128, 128, 128, 0, 1, 2, 128, 3, // 5735 0, 128, 128, 1, 2, 3, 128, 4, // 5736 128, 0, 128, 1, 2, 3, 128, 4, // 5737 0, 1, 128, 2, 3, 4, 128, 5, // 5738 128, 128, 0, 1, 2, 3, 128, 4, // 5739 0, 128, 1, 2, 3, 4, 128, 5, // 5740 128, 0, 1, 2, 3, 4, 128, 5, // 5741 0, 1, 2, 3, 4, 5, 128, 6, // 5742 128, 128, 128, 128, 128, 128, 0, 1, // 5743 0, 128, 128, 128, 128, 128, 1, 2, // 5744 128, 0, 128, 128, 128, 128, 1, 2, // 5745 0, 1, 128, 128, 128, 128, 2, 3, // 5746 128, 128, 0, 128, 128, 128, 1, 2, // 5747 0, 128, 1, 128, 128, 128, 2, 3, // 5748 128, 0, 1, 128, 128, 
128, 2, 3, // 5749 0, 1, 2, 128, 128, 128, 3, 4, // 5750 128, 128, 128, 0, 128, 128, 1, 2, // 5751 0, 128, 128, 1, 128, 128, 2, 3, // 5752 128, 0, 128, 1, 128, 128, 2, 3, // 5753 0, 1, 128, 2, 128, 128, 3, 4, // 5754 128, 128, 0, 1, 128, 128, 2, 3, // 5755 0, 128, 1, 2, 128, 128, 3, 4, // 5756 128, 0, 1, 2, 128, 128, 3, 4, // 5757 0, 1, 2, 3, 128, 128, 4, 5, // 5758 128, 128, 128, 128, 0, 128, 1, 2, // 5759 0, 128, 128, 128, 1, 128, 2, 3, // 5760 128, 0, 128, 128, 1, 128, 2, 3, // 5761 0, 1, 128, 128, 2, 128, 3, 4, // 5762 128, 128, 0, 128, 1, 128, 2, 3, // 5763 0, 128, 1, 128, 2, 128, 3, 4, // 5764 128, 0, 1, 128, 2, 128, 3, 4, // 5765 0, 1, 2, 128, 3, 128, 4, 5, // 5766 128, 128, 128, 0, 1, 128, 2, 3, // 5767 0, 128, 128, 1, 2, 128, 3, 4, // 5768 128, 0, 128, 1, 2, 128, 3, 4, // 5769 0, 1, 128, 2, 3, 128, 4, 5, // 5770 128, 128, 0, 1, 2, 128, 3, 4, // 5771 0, 128, 1, 2, 3, 128, 4, 5, // 5772 128, 0, 1, 2, 3, 128, 4, 5, // 5773 0, 1, 2, 3, 4, 128, 5, 6, // 5774 128, 128, 128, 128, 128, 0, 1, 2, // 5775 0, 128, 128, 128, 128, 1, 2, 3, // 5776 128, 0, 128, 128, 128, 1, 2, 3, // 5777 0, 1, 128, 128, 128, 2, 3, 4, // 5778 128, 128, 0, 128, 128, 1, 2, 3, // 5779 0, 128, 1, 128, 128, 2, 3, 4, // 5780 128, 0, 1, 128, 128, 2, 3, 4, // 5781 0, 1, 2, 128, 128, 3, 4, 5, // 5782 128, 128, 128, 0, 128, 1, 2, 3, // 5783 0, 128, 128, 1, 128, 2, 3, 4, // 5784 128, 0, 128, 1, 128, 2, 3, 4, // 5785 0, 1, 128, 2, 128, 3, 4, 5, // 5786 128, 128, 0, 1, 128, 2, 3, 4, // 5787 0, 128, 1, 2, 128, 3, 4, 5, // 5788 128, 0, 1, 2, 128, 3, 4, 5, // 5789 0, 1, 2, 3, 128, 4, 5, 6, // 5790 128, 128, 128, 128, 0, 1, 2, 3, // 5791 0, 128, 128, 128, 1, 2, 3, 4, // 5792 128, 0, 128, 128, 1, 2, 3, 4, // 5793 0, 1, 128, 128, 2, 3, 4, 5, // 5794 128, 128, 0, 128, 1, 2, 3, 4, // 5795 0, 128, 1, 128, 2, 3, 4, 5, // 5796 128, 0, 1, 128, 2, 3, 4, 5, // 5797 0, 1, 2, 128, 3, 4, 5, 6, // 5798 128, 128, 128, 0, 1, 2, 3, 4, // 5799 0, 128, 128, 1, 2, 3, 4, 5, // 5800 128, 0, 128, 1, 2, 3, 4, 5, // 5801 0, 1, 128, 2, 3, 4, 5, 6, // 5802 128, 128, 0, 1, 2, 3, 4, 5, // 5803 0, 128, 1, 2, 3, 4, 5, 6, // 5804 128, 0, 1, 2, 3, 4, 5, 6, // 5805 0, 1, 2, 3, 4, 5, 6, 7}; 5806 return Load(du8, table + mask_bits * 8); 5807 } 5808 5809 template <class D, HWY_IF_T_SIZE_D(D, 1)> 5810 HWY_INLINE svuint8_t LaneIndicesFromByteIndices(D, svuint8_t idx) { 5811 return idx; 5812 } 5813 template <class D, class DU = RebindToUnsigned<D>, HWY_IF_NOT_T_SIZE_D(D, 1)> 5814 HWY_INLINE VFromD<DU> LaneIndicesFromByteIndices(D, svuint8_t idx) { 5815 return PromoteTo(DU(), idx); 5816 } 5817 5818 // General case when we don't know the vector size, 8 elements at a time. 5819 template <class V> 5820 HWY_INLINE V ExpandLoop(V v, svbool_t mask) { 5821 const DFromV<V> d; 5822 using T = TFromV<V>; 5823 uint8_t mask_bytes[256 / 8]; 5824 StoreMaskBits(d, mask, mask_bytes); 5825 5826 // ShiftLeftLanes is expensive, so we're probably better off storing to memory 5827 // and loading the final result. 5828 alignas(16) T out[2 * MaxLanes(d)]; 5829 5830 svbool_t next = svpfalse_b(); 5831 size_t input_consumed = 0; 5832 const V iota = Iota(d, 0); 5833 for (size_t i = 0; i < Lanes(d); i += 8) { 5834 uint64_t mask_bits = mask_bytes[i / 8]; 5835 5836 // We want to skip past the v lanes already consumed. There is no 5837 // instruction for variable-shift-reg, but we can splice. 
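    // e.g. after 3 lanes have been consumed by earlier iterations, `next` is
    // true from lane 3 upward, and the splice rotates v down by 3 lanes so
    // this iteration again reads its 8 source lanes from the bottom of vH.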
5838 const V vH = detail::Splice(v, v, next); 5839 input_consumed += PopCount(mask_bits); 5840 next = detail::GeN(iota, ConvertScalarTo<T>(input_consumed)); 5841 5842 const auto idx = detail::LaneIndicesFromByteIndices( 5843 d, detail::IndicesForExpandFromBits(mask_bits)); 5844 const V expand = TableLookupLanes(vH, idx); 5845 StoreU(expand, d, out + i); 5846 } 5847 return LoadU(d, out); 5848 } 5849 5850 } // namespace detail 5851 5852 template <class V, HWY_IF_T_SIZE_V(V, 1)> 5853 HWY_API V Expand(V v, svbool_t mask) { 5854 #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE 5855 const DFromV<V> d; 5856 uint8_t mask_bytes[256 / 8]; 5857 StoreMaskBits(d, mask, mask_bytes); 5858 const uint64_t maskL = mask_bytes[0]; 5859 const uint64_t maskH = mask_bytes[1]; 5860 5861 // We want to skip past the v bytes already consumed by expandL. There is no 5862 // instruction for shift-reg by variable bytes, but we can splice. Instead of 5863 // GeN, Not(FirstN()) would also work. 5864 using T = TFromV<V>; 5865 const T countL = static_cast<T>(PopCount(maskL)); 5866 const V vH = detail::Splice(v, v, detail::GeN(Iota(d, 0), countL)); 5867 5868 const svuint8_t idxL = detail::IndicesForExpandFromBits(maskL); 5869 const svuint8_t idxH = detail::IndicesForExpandFromBits(maskH); 5870 return Combine(d, TableLookupLanes(vH, idxH), TableLookupLanes(v, idxL)); 5871 #else 5872 return detail::ExpandLoop(v, mask); 5873 #endif 5874 } 5875 5876 template <class V, HWY_IF_T_SIZE_V(V, 2)> 5877 HWY_API V Expand(V v, svbool_t mask) { 5878 #if HWY_TARGET == HWY_SVE2_128 || HWY_IDE // 16x8 5879 const DFromV<V> d; 5880 const RebindToUnsigned<decltype(d)> du16; 5881 const Rebind<uint8_t, decltype(d)> du8; 5882 // Convert mask into bitfield via horizontal sum (faster than ORV) of 8 bits. 5883 // Pre-multiply by N so we can use it as an offset for Load. 5884 const svuint16_t bits = Shl(Set(du16, 1), Iota(du16, 3)); 5885 const size_t offset = detail::SumOfLanesM(mask, bits); 5886 5887 // Storing as 8-bit reduces table size from 4 KiB to 2 KiB. We cannot apply 5888 // the nibble trick used below because not all indices fit within one lane. 
5889 alignas(16) static constexpr uint8_t table[8 * 256] = { 5890 // PrintExpand16x8LaneTables 5891 255, 255, 255, 255, 255, 255, 255, 255, // 5892 0, 255, 255, 255, 255, 255, 255, 255, // 5893 255, 0, 255, 255, 255, 255, 255, 255, // 5894 0, 1, 255, 255, 255, 255, 255, 255, // 5895 255, 255, 0, 255, 255, 255, 255, 255, // 5896 0, 255, 1, 255, 255, 255, 255, 255, // 5897 255, 0, 1, 255, 255, 255, 255, 255, // 5898 0, 1, 2, 255, 255, 255, 255, 255, // 5899 255, 255, 255, 0, 255, 255, 255, 255, // 5900 0, 255, 255, 1, 255, 255, 255, 255, // 5901 255, 0, 255, 1, 255, 255, 255, 255, // 5902 0, 1, 255, 2, 255, 255, 255, 255, // 5903 255, 255, 0, 1, 255, 255, 255, 255, // 5904 0, 255, 1, 2, 255, 255, 255, 255, // 5905 255, 0, 1, 2, 255, 255, 255, 255, // 5906 0, 1, 2, 3, 255, 255, 255, 255, // 5907 255, 255, 255, 255, 0, 255, 255, 255, // 5908 0, 255, 255, 255, 1, 255, 255, 255, // 5909 255, 0, 255, 255, 1, 255, 255, 255, // 5910 0, 1, 255, 255, 2, 255, 255, 255, // 5911 255, 255, 0, 255, 1, 255, 255, 255, // 5912 0, 255, 1, 255, 2, 255, 255, 255, // 5913 255, 0, 1, 255, 2, 255, 255, 255, // 5914 0, 1, 2, 255, 3, 255, 255, 255, // 5915 255, 255, 255, 0, 1, 255, 255, 255, // 5916 0, 255, 255, 1, 2, 255, 255, 255, // 5917 255, 0, 255, 1, 2, 255, 255, 255, // 5918 0, 1, 255, 2, 3, 255, 255, 255, // 5919 255, 255, 0, 1, 2, 255, 255, 255, // 5920 0, 255, 1, 2, 3, 255, 255, 255, // 5921 255, 0, 1, 2, 3, 255, 255, 255, // 5922 0, 1, 2, 3, 4, 255, 255, 255, // 5923 255, 255, 255, 255, 255, 0, 255, 255, // 5924 0, 255, 255, 255, 255, 1, 255, 255, // 5925 255, 0, 255, 255, 255, 1, 255, 255, // 5926 0, 1, 255, 255, 255, 2, 255, 255, // 5927 255, 255, 0, 255, 255, 1, 255, 255, // 5928 0, 255, 1, 255, 255, 2, 255, 255, // 5929 255, 0, 1, 255, 255, 2, 255, 255, // 5930 0, 1, 2, 255, 255, 3, 255, 255, // 5931 255, 255, 255, 0, 255, 1, 255, 255, // 5932 0, 255, 255, 1, 255, 2, 255, 255, // 5933 255, 0, 255, 1, 255, 2, 255, 255, // 5934 0, 1, 255, 2, 255, 3, 255, 255, // 5935 255, 255, 0, 1, 255, 2, 255, 255, // 5936 0, 255, 1, 2, 255, 3, 255, 255, // 5937 255, 0, 1, 2, 255, 3, 255, 255, // 5938 0, 1, 2, 3, 255, 4, 255, 255, // 5939 255, 255, 255, 255, 0, 1, 255, 255, // 5940 0, 255, 255, 255, 1, 2, 255, 255, // 5941 255, 0, 255, 255, 1, 2, 255, 255, // 5942 0, 1, 255, 255, 2, 3, 255, 255, // 5943 255, 255, 0, 255, 1, 2, 255, 255, // 5944 0, 255, 1, 255, 2, 3, 255, 255, // 5945 255, 0, 1, 255, 2, 3, 255, 255, // 5946 0, 1, 2, 255, 3, 4, 255, 255, // 5947 255, 255, 255, 0, 1, 2, 255, 255, // 5948 0, 255, 255, 1, 2, 3, 255, 255, // 5949 255, 0, 255, 1, 2, 3, 255, 255, // 5950 0, 1, 255, 2, 3, 4, 255, 255, // 5951 255, 255, 0, 1, 2, 3, 255, 255, // 5952 0, 255, 1, 2, 3, 4, 255, 255, // 5953 255, 0, 1, 2, 3, 4, 255, 255, // 5954 0, 1, 2, 3, 4, 5, 255, 255, // 5955 255, 255, 255, 255, 255, 255, 0, 255, // 5956 0, 255, 255, 255, 255, 255, 1, 255, // 5957 255, 0, 255, 255, 255, 255, 1, 255, // 5958 0, 1, 255, 255, 255, 255, 2, 255, // 5959 255, 255, 0, 255, 255, 255, 1, 255, // 5960 0, 255, 1, 255, 255, 255, 2, 255, // 5961 255, 0, 1, 255, 255, 255, 2, 255, // 5962 0, 1, 2, 255, 255, 255, 3, 255, // 5963 255, 255, 255, 0, 255, 255, 1, 255, // 5964 0, 255, 255, 1, 255, 255, 2, 255, // 5965 255, 0, 255, 1, 255, 255, 2, 255, // 5966 0, 1, 255, 2, 255, 255, 3, 255, // 5967 255, 255, 0, 1, 255, 255, 2, 255, // 5968 0, 255, 1, 2, 255, 255, 3, 255, // 5969 255, 0, 1, 2, 255, 255, 3, 255, // 5970 0, 1, 2, 3, 255, 255, 4, 255, // 5971 255, 255, 255, 255, 0, 255, 1, 255, // 5972 0, 255, 255, 255, 1, 255, 2, 255, // 5973 255, 0, 
255, 255, 1, 255, 2, 255, // 5974 0, 1, 255, 255, 2, 255, 3, 255, // 5975 255, 255, 0, 255, 1, 255, 2, 255, // 5976 0, 255, 1, 255, 2, 255, 3, 255, // 5977 255, 0, 1, 255, 2, 255, 3, 255, // 5978 0, 1, 2, 255, 3, 255, 4, 255, // 5979 255, 255, 255, 0, 1, 255, 2, 255, // 5980 0, 255, 255, 1, 2, 255, 3, 255, // 5981 255, 0, 255, 1, 2, 255, 3, 255, // 5982 0, 1, 255, 2, 3, 255, 4, 255, // 5983 255, 255, 0, 1, 2, 255, 3, 255, // 5984 0, 255, 1, 2, 3, 255, 4, 255, // 5985 255, 0, 1, 2, 3, 255, 4, 255, // 5986 0, 1, 2, 3, 4, 255, 5, 255, // 5987 255, 255, 255, 255, 255, 0, 1, 255, // 5988 0, 255, 255, 255, 255, 1, 2, 255, // 5989 255, 0, 255, 255, 255, 1, 2, 255, // 5990 0, 1, 255, 255, 255, 2, 3, 255, // 5991 255, 255, 0, 255, 255, 1, 2, 255, // 5992 0, 255, 1, 255, 255, 2, 3, 255, // 5993 255, 0, 1, 255, 255, 2, 3, 255, // 5994 0, 1, 2, 255, 255, 3, 4, 255, // 5995 255, 255, 255, 0, 255, 1, 2, 255, // 5996 0, 255, 255, 1, 255, 2, 3, 255, // 5997 255, 0, 255, 1, 255, 2, 3, 255, // 5998 0, 1, 255, 2, 255, 3, 4, 255, // 5999 255, 255, 0, 1, 255, 2, 3, 255, // 6000 0, 255, 1, 2, 255, 3, 4, 255, // 6001 255, 0, 1, 2, 255, 3, 4, 255, // 6002 0, 1, 2, 3, 255, 4, 5, 255, // 6003 255, 255, 255, 255, 0, 1, 2, 255, // 6004 0, 255, 255, 255, 1, 2, 3, 255, // 6005 255, 0, 255, 255, 1, 2, 3, 255, // 6006 0, 1, 255, 255, 2, 3, 4, 255, // 6007 255, 255, 0, 255, 1, 2, 3, 255, // 6008 0, 255, 1, 255, 2, 3, 4, 255, // 6009 255, 0, 1, 255, 2, 3, 4, 255, // 6010 0, 1, 2, 255, 3, 4, 5, 255, // 6011 255, 255, 255, 0, 1, 2, 3, 255, // 6012 0, 255, 255, 1, 2, 3, 4, 255, // 6013 255, 0, 255, 1, 2, 3, 4, 255, // 6014 0, 1, 255, 2, 3, 4, 5, 255, // 6015 255, 255, 0, 1, 2, 3, 4, 255, // 6016 0, 255, 1, 2, 3, 4, 5, 255, // 6017 255, 0, 1, 2, 3, 4, 5, 255, // 6018 0, 1, 2, 3, 4, 5, 6, 255, // 6019 255, 255, 255, 255, 255, 255, 255, 0, // 6020 0, 255, 255, 255, 255, 255, 255, 1, // 6021 255, 0, 255, 255, 255, 255, 255, 1, // 6022 0, 1, 255, 255, 255, 255, 255, 2, // 6023 255, 255, 0, 255, 255, 255, 255, 1, // 6024 0, 255, 1, 255, 255, 255, 255, 2, // 6025 255, 0, 1, 255, 255, 255, 255, 2, // 6026 0, 1, 2, 255, 255, 255, 255, 3, // 6027 255, 255, 255, 0, 255, 255, 255, 1, // 6028 0, 255, 255, 1, 255, 255, 255, 2, // 6029 255, 0, 255, 1, 255, 255, 255, 2, // 6030 0, 1, 255, 2, 255, 255, 255, 3, // 6031 255, 255, 0, 1, 255, 255, 255, 2, // 6032 0, 255, 1, 2, 255, 255, 255, 3, // 6033 255, 0, 1, 2, 255, 255, 255, 3, // 6034 0, 1, 2, 3, 255, 255, 255, 4, // 6035 255, 255, 255, 255, 0, 255, 255, 1, // 6036 0, 255, 255, 255, 1, 255, 255, 2, // 6037 255, 0, 255, 255, 1, 255, 255, 2, // 6038 0, 1, 255, 255, 2, 255, 255, 3, // 6039 255, 255, 0, 255, 1, 255, 255, 2, // 6040 0, 255, 1, 255, 2, 255, 255, 3, // 6041 255, 0, 1, 255, 2, 255, 255, 3, // 6042 0, 1, 2, 255, 3, 255, 255, 4, // 6043 255, 255, 255, 0, 1, 255, 255, 2, // 6044 0, 255, 255, 1, 2, 255, 255, 3, // 6045 255, 0, 255, 1, 2, 255, 255, 3, // 6046 0, 1, 255, 2, 3, 255, 255, 4, // 6047 255, 255, 0, 1, 2, 255, 255, 3, // 6048 0, 255, 1, 2, 3, 255, 255, 4, // 6049 255, 0, 1, 2, 3, 255, 255, 4, // 6050 0, 1, 2, 3, 4, 255, 255, 5, // 6051 255, 255, 255, 255, 255, 0, 255, 1, // 6052 0, 255, 255, 255, 255, 1, 255, 2, // 6053 255, 0, 255, 255, 255, 1, 255, 2, // 6054 0, 1, 255, 255, 255, 2, 255, 3, // 6055 255, 255, 0, 255, 255, 1, 255, 2, // 6056 0, 255, 1, 255, 255, 2, 255, 3, // 6057 255, 0, 1, 255, 255, 2, 255, 3, // 6058 0, 1, 2, 255, 255, 3, 255, 4, // 6059 255, 255, 255, 0, 255, 1, 255, 2, // 6060 0, 255, 255, 1, 255, 2, 255, 3, // 6061 255, 0, 255, 1, 255, 2, 255, 3, // 
6062 0, 1, 255, 2, 255, 3, 255, 4, // 6063 255, 255, 0, 1, 255, 2, 255, 3, // 6064 0, 255, 1, 2, 255, 3, 255, 4, // 6065 255, 0, 1, 2, 255, 3, 255, 4, // 6066 0, 1, 2, 3, 255, 4, 255, 5, // 6067 255, 255, 255, 255, 0, 1, 255, 2, // 6068 0, 255, 255, 255, 1, 2, 255, 3, // 6069 255, 0, 255, 255, 1, 2, 255, 3, // 6070 0, 1, 255, 255, 2, 3, 255, 4, // 6071 255, 255, 0, 255, 1, 2, 255, 3, // 6072 0, 255, 1, 255, 2, 3, 255, 4, // 6073 255, 0, 1, 255, 2, 3, 255, 4, // 6074 0, 1, 2, 255, 3, 4, 255, 5, // 6075 255, 255, 255, 0, 1, 2, 255, 3, // 6076 0, 255, 255, 1, 2, 3, 255, 4, // 6077 255, 0, 255, 1, 2, 3, 255, 4, // 6078 0, 1, 255, 2, 3, 4, 255, 5, // 6079 255, 255, 0, 1, 2, 3, 255, 4, // 6080 0, 255, 1, 2, 3, 4, 255, 5, // 6081 255, 0, 1, 2, 3, 4, 255, 5, // 6082 0, 1, 2, 3, 4, 5, 255, 6, // 6083 255, 255, 255, 255, 255, 255, 0, 1, // 6084 0, 255, 255, 255, 255, 255, 1, 2, // 6085 255, 0, 255, 255, 255, 255, 1, 2, // 6086 0, 1, 255, 255, 255, 255, 2, 3, // 6087 255, 255, 0, 255, 255, 255, 1, 2, // 6088 0, 255, 1, 255, 255, 255, 2, 3, // 6089 255, 0, 1, 255, 255, 255, 2, 3, // 6090 0, 1, 2, 255, 255, 255, 3, 4, // 6091 255, 255, 255, 0, 255, 255, 1, 2, // 6092 0, 255, 255, 1, 255, 255, 2, 3, // 6093 255, 0, 255, 1, 255, 255, 2, 3, // 6094 0, 1, 255, 2, 255, 255, 3, 4, // 6095 255, 255, 0, 1, 255, 255, 2, 3, // 6096 0, 255, 1, 2, 255, 255, 3, 4, // 6097 255, 0, 1, 2, 255, 255, 3, 4, // 6098 0, 1, 2, 3, 255, 255, 4, 5, // 6099 255, 255, 255, 255, 0, 255, 1, 2, // 6100 0, 255, 255, 255, 1, 255, 2, 3, // 6101 255, 0, 255, 255, 1, 255, 2, 3, // 6102 0, 1, 255, 255, 2, 255, 3, 4, // 6103 255, 255, 0, 255, 1, 255, 2, 3, // 6104 0, 255, 1, 255, 2, 255, 3, 4, // 6105 255, 0, 1, 255, 2, 255, 3, 4, // 6106 0, 1, 2, 255, 3, 255, 4, 5, // 6107 255, 255, 255, 0, 1, 255, 2, 3, // 6108 0, 255, 255, 1, 2, 255, 3, 4, // 6109 255, 0, 255, 1, 2, 255, 3, 4, // 6110 0, 1, 255, 2, 3, 255, 4, 5, // 6111 255, 255, 0, 1, 2, 255, 3, 4, // 6112 0, 255, 1, 2, 3, 255, 4, 5, // 6113 255, 0, 1, 2, 3, 255, 4, 5, // 6114 0, 1, 2, 3, 4, 255, 5, 6, // 6115 255, 255, 255, 255, 255, 0, 1, 2, // 6116 0, 255, 255, 255, 255, 1, 2, 3, // 6117 255, 0, 255, 255, 255, 1, 2, 3, // 6118 0, 1, 255, 255, 255, 2, 3, 4, // 6119 255, 255, 0, 255, 255, 1, 2, 3, // 6120 0, 255, 1, 255, 255, 2, 3, 4, // 6121 255, 0, 1, 255, 255, 2, 3, 4, // 6122 0, 1, 2, 255, 255, 3, 4, 5, // 6123 255, 255, 255, 0, 255, 1, 2, 3, // 6124 0, 255, 255, 1, 255, 2, 3, 4, // 6125 255, 0, 255, 1, 255, 2, 3, 4, // 6126 0, 1, 255, 2, 255, 3, 4, 5, // 6127 255, 255, 0, 1, 255, 2, 3, 4, // 6128 0, 255, 1, 2, 255, 3, 4, 5, // 6129 255, 0, 1, 2, 255, 3, 4, 5, // 6130 0, 1, 2, 3, 255, 4, 5, 6, // 6131 255, 255, 255, 255, 0, 1, 2, 3, // 6132 0, 255, 255, 255, 1, 2, 3, 4, // 6133 255, 0, 255, 255, 1, 2, 3, 4, // 6134 0, 1, 255, 255, 2, 3, 4, 5, // 6135 255, 255, 0, 255, 1, 2, 3, 4, // 6136 0, 255, 1, 255, 2, 3, 4, 5, // 6137 255, 0, 1, 255, 2, 3, 4, 5, // 6138 0, 1, 2, 255, 3, 4, 5, 6, // 6139 255, 255, 255, 0, 1, 2, 3, 4, // 6140 0, 255, 255, 1, 2, 3, 4, 5, // 6141 255, 0, 255, 1, 2, 3, 4, 5, // 6142 0, 1, 255, 2, 3, 4, 5, 6, // 6143 255, 255, 0, 1, 2, 3, 4, 5, // 6144 0, 255, 1, 2, 3, 4, 5, 6, // 6145 255, 0, 1, 2, 3, 4, 5, 6, // 6146 0, 1, 2, 3, 4, 5, 6, 7}; 6147 const svuint16_t indices = PromoteTo(du16, Load(du8, table + offset)); 6148 return TableLookupLanes(v, indices); // already zeros mask=false lanes 6149 #else 6150 return detail::ExpandLoop(v, mask); 6151 #endif 6152 } 6153 6154 template <class V, HWY_IF_T_SIZE_V(V, 4)> 6155 HWY_API V Expand(V v, svbool_t mask) { 6156 
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 32x8 6157 const DFromV<V> d; 6158 const RebindToUnsigned<decltype(d)> du32; 6159 // Convert mask into bitfield via horizontal sum (faster than ORV). 6160 const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); 6161 const size_t code = detail::SumOfLanesM(mask, bits); 6162 6163 alignas(16) constexpr uint32_t packed_array[256] = { 6164 // PrintExpand32x8. 6165 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, 6166 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, 6167 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, 6168 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, 6169 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, 6170 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, 6171 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, 6172 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, 6173 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, 6174 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, 6175 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, 6176 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, 6177 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, 6178 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, 6179 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, 6180 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, 6181 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0, 6182 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, 6183 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, 6184 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, 6185 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, 6186 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, 6187 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, 6188 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, 6189 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, 6190 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, 6191 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, 6192 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, 6193 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, 6194 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, 6195 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, 6196 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, 6197 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, 6198 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, 6199 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, 6200 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, 6201 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, 6202 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, 6203 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, 6204 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, 6205 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 
  const svuint32_t packed = Set(du32, packed_array[code]);
  const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF);
  return TableLookupLanes(v, indices);  // already zeros mask=false lanes
#elif HWY_TARGET == HWY_SVE2_128  // 32x4
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du32;
  // Convert mask into bitfield via horizontal sum (faster than ORV).
  const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0));
  const size_t offset = detail::SumOfLanesM(mask, bits);

  alignas(16) constexpr uint32_t packed_array[16] = {
      // PrintExpand64x4Nibble - same for 32x4.
      0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0,
      0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10,
      0x000010ff, 0x000021f0, 0x0000210f, 0x00003210};

  // For lane i, shift the i-th 4-bit index down and mask with 0xF because
  // svtbl zeros outputs if the index is out of bounds.
  const svuint32_t packed = Set(du32, packed_array[offset]);
  const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF);
  return TableLookupLanes(v, indices);  // already zeros mask=false lanes
#else
  return detail::ExpandLoop(v, mask);
#endif
}

template <class V, HWY_IF_T_SIZE_V(V, 8)>
HWY_API V Expand(V v, svbool_t mask) {
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE  // 64x4
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du64;

  // Convert mask into bitfield via horizontal sum (faster than ORV) of masked
  // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for
  // SetTableIndices.
  const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2));
  const size_t offset = detail::SumOfLanesM(mask, bits);

  alignas(16) static constexpr uint64_t table[4 * 16] = {
      // PrintExpand64x4Tables - small enough to store uncompressed.
      255, 255, 255, 255, 0, 255, 255, 255, 255, 0, 255, 255, 0, 1, 255, 255,
      255, 255, 0, 255, 0, 255, 1, 255, 255, 0, 1, 255, 0, 1, 2, 255,
      255, 255, 255, 0, 0, 255, 255, 1, 255, 0, 255, 1, 0, 1, 255, 2,
      255, 255, 0, 1, 0, 255, 1, 2, 255, 0, 1, 2, 0, 1, 2, 3};
  // This already zeros mask=false lanes.
  return TableLookupLanes(v, SetTableIndices(d, table + offset));
#elif HWY_TARGET == HWY_SVE2_128  // 64x2
  // Same as Compress, just zero out the mask=false lanes.
  return IfThenElseZero(mask, Compress(v, mask));
#else
  return detail::ExpandLoop(v, mask);
#endif
}
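
// Example of Expand (illustrative sketch, not part of the API): place
// successive lanes of v into the active lanes of the mask, in order, and
// zero the inactive lanes.
//   const ScalableTag<uint32_t> d;
//   const svuint32_t v = Iota(d, 1);  // {1, 2, 3, 4, ...}
//   // Mask with every even lane active:
//   const svbool_t even = Eq(And(Iota(d, 0), Set(d, 1)), Zero(d));
//   const svuint32_t out = Expand(v, even);  // {1, 0, 2, 0, 3, 0, ...}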

// ------------------------------ LoadExpand

template <class D>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
  return Expand(LoadU(d, unaligned), mask);
}

// ------------------------------ MulEven (InterleaveEven)

#if HWY_SVE_HAVE_2
namespace detail {
#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulEvenNative, mullb)
HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulEvenNative, mullb)
HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb)
HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulOddNative, mullt)
HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulOddNative, mullt)
HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulOddNative, mullt)
#undef HWY_SVE_MUL_EVEN
}  // namespace detail
#endif

template <class V, class DW = RepartitionToWide<DFromV<V>>,
          HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))>
HWY_API VFromD<DW> MulEven(const V a, const V b) {
#if HWY_SVE_HAVE_2
  return BitCast(DW(), detail::MulEvenNative(a, b));
#else
  const auto lo = Mul(a, b);
  const auto hi = MulHigh(a, b);
  return BitCast(DW(), detail::InterleaveEven(lo, hi));
#endif
}

template <class V, class DW = RepartitionToWide<DFromV<V>>,
          HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))>
HWY_API VFromD<DW> MulOdd(const V a, const V b) {
#if HWY_SVE_HAVE_2
  return BitCast(DW(), detail::MulOddNative(a, b));
#else
  const auto lo = Mul(a, b);
  const auto hi = MulHigh(a, b);
  return BitCast(DW(), detail::InterleaveOdd(lo, hi));
#endif
}

HWY_API svint64_t MulEven(const svint64_t a, const svint64_t b) {
  const auto lo = Mul(a, b);
  const auto hi = MulHigh(a, b);
  return detail::InterleaveEven(lo, hi);
}

HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) {
  const auto lo = Mul(a, b);
  const auto hi = MulHigh(a, b);
  return detail::InterleaveEven(lo, hi);
}

HWY_API svint64_t MulOdd(const svint64_t a, const svint64_t b) {
  const auto lo = Mul(a, b);
  const auto hi = MulHigh(a, b);
  return detail::InterleaveOdd(lo, hi);
}

HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
  const auto lo = Mul(a, b);
  const auto hi = MulHigh(a, b);
  return detail::InterleaveOdd(lo, hi);
}
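
// Example of MulEven / MulOdd (illustrative sketch): widening products of the
// even- (resp. odd-) indexed lane pairs.
//   const ScalableTag<uint8_t> d8;
//   const svuint8_t a = Set(d8, 200), b = Set(d8, 2);
//   const svuint16_t even = MulEven(a, b);  // every u16 lane is 400
//   const svuint16_t odd = MulOdd(a, b);    // likewise 400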

// ------------------------------ PairwiseAdd/PairwiseSub
#if HWY_TARGET != HWY_SCALAR
#if HWY_SVE_HAVE_2 || HWY_IDE

#ifdef HWY_NATIVE_PAIRWISE_ADD
#undef HWY_NATIVE_PAIRWISE_ADD
#else
#define HWY_NATIVE_PAIRWISE_ADD
#endif

namespace detail {
#define HWY_SVE_SV_PAIRWISE_ADD(BASE, CHAR, BITS, HALF, NAME, OP)          \
  template <size_t N, int kPow2>                                           \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, HWY_SVE_V(BASE, BITS) a, \
           HWY_SVE_V(BASE, BITS) b) {                                      \
    return sv##OP##_##CHAR##BITS##_m(HWY_SVE_PTRUE(BITS), a, b);           \
  }

HWY_SVE_FOREACH(HWY_SVE_SV_PAIRWISE_ADD, PairwiseAdd, addp)
#undef HWY_SVE_SV_PAIRWISE_ADD
}  // namespace detail

// Pairwise add, returning the interleaved outputs of a and b.
template <class D, class V, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseAdd(D d, V a, V b) {
  return detail::PairwiseAdd(d, a, b);
}

#endif  // HWY_SVE_HAVE_2
#endif  // HWY_TARGET != HWY_SCALAR

// ------------------------------ WidenMulPairwiseAdd

template <size_t N, int kPow2>
HWY_API svfloat32_t WidenMulPairwiseAdd(Simd<float, N, kPow2> df, VBF16 a,
                                        VBF16 b) {
#if HWY_SVE_HAVE_F32_TO_BF16C
  const svfloat32_t even = svbfmlalb_f32(Zero(df), a, b);
  return svbfmlalt_f32(even, a, b);
#else
  return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b),
                Mul(PromoteOddTo(df, a), PromoteOddTo(df, b)));
#endif  // HWY_SVE_HAVE_F32_TO_BF16C
}

template <size_t N, int kPow2>
HWY_API svint32_t WidenMulPairwiseAdd(Simd<int32_t, N, kPow2> d32, svint16_t a,
                                      svint16_t b) {
#if HWY_SVE_HAVE_2
  (void)d32;
  return svmlalt_s32(svmullb_s32(a, b), a, b);
#else
  return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b),
                Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b)));
#endif
}

template <size_t N, int kPow2>
HWY_API svuint32_t WidenMulPairwiseAdd(Simd<uint32_t, N, kPow2> d32,
                                       svuint16_t a, svuint16_t b) {
#if HWY_SVE_HAVE_2
  (void)d32;
  return svmlalt_u32(svmullb_u32(a, b), a, b);
#else
  return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b),
                Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b)));
#endif
}
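
// Example of WidenMulPairwiseAdd (illustrative sketch): each i32 lane is the
// sum of two adjacent widened i16 products, a[2i]*b[2i] + a[2i+1]*b[2i+1].
//   const ScalableTag<int32_t> d32;
//   const Repartition<int16_t, decltype(d32)> d16;
//   const svint16_t a = Set(d16, 3), b = Set(d16, 4);
//   const svint32_t s = WidenMulPairwiseAdd(d32, a, b);  // every lane is 24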

// ------------------------------ SatWidenMulPairwiseAccumulate
#if HWY_SVE_HAVE_2
#define HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                                             \
  HWY_API HWY_SVE_V(BASE, BITS)                                              \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) dw, HWY_SVE_V(BASE, HALF) a,      \
           HWY_SVE_V(BASE, HALF) b, HWY_SVE_V(BASE, BITS) sum) {             \
    auto product = svmlalt_##CHAR##BITS(svmullb_##CHAR##BITS(a, b), a, b);   \
    const auto mul_overflow = IfThenElseZero(                                \
        Eq(product, Set(dw, LimitsMin<int##BITS##_t>())), Set(dw, -1));      \
    return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)),  \
                        Add(product, mul_overflow));                         \
  }
HWY_SVE_FOREACH_UI16(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2,
                     SatWidenMulPairwiseAccumulate, _)
HWY_SVE_FOREACH_UI32(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2,
                     SatWidenMulPairwiseAccumulate, _)
HWY_SVE_FOREACH_UI64(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2,
                     SatWidenMulPairwiseAccumulate, _)

#undef HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2
#endif

// ------------------------------ SatWidenMulAccumFixedPoint

#if HWY_SVE_HAVE_2

#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#else
#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 /*di32*/,
                                                VFromD<Rebind<int16_t, DI32>> a,
                                                VFromD<Rebind<int16_t, DI32>> b,
                                                VFromD<DI32> sum) {
  return svqdmlalb_s32(sum, detail::ZipLowerSame(a, a),
                       detail::ZipLowerSame(b, b));
}

#endif  // HWY_SVE_HAVE_2

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

#if HWY_SVE_HAVE_BF16_FEATURE

// NOTE: we currently do not use SVE BFDOT for bf16 ReorderWidenMulAccumulate
// because, apparently unlike NEON, it uses round-to-odd unless the additional
// FEAT_EBF16 feature is available and enabled.
#ifdef HWY_NATIVE_MUL_EVEN_BF16
#undef HWY_NATIVE_MUL_EVEN_BF16
#else
#define HWY_NATIVE_MUL_EVEN_BF16
#endif

template <size_t N, int kPow2>
HWY_API svfloat32_t MulEvenAdd(Simd<float, N, kPow2> /* d */, VBF16 a, VBF16 b,
                               const svfloat32_t c) {
  return svbfmlalb_f32(c, a, b);
}

template <size_t N, int kPow2>
HWY_API svfloat32_t MulOddAdd(Simd<float, N, kPow2> /* d */, VBF16 a, VBF16 b,
                              const svfloat32_t c) {
  return svbfmlalt_f32(c, a, b);
}

#endif  // HWY_SVE_HAVE_BF16_FEATURE

template <size_t N, int kPow2>
HWY_API svint32_t ReorderWidenMulAccumulate(Simd<int32_t, N, kPow2> d32,
                                            svint16_t a, svint16_t b,
                                            const svint32_t sum0,
                                            svint32_t& sum1) {
#if HWY_SVE_HAVE_2
  (void)d32;
  sum1 = svmlalt_s32(sum1, a, b);
  return svmlalb_s32(sum0, a, b);
#else
  // Lane order within sum0/1 is undefined, hence we can avoid the
  // longer-latency lane-crossing PromoteTo by using PromoteEvenTo.
  sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1);
  return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0);
#endif
}

template <size_t N, int kPow2>
HWY_API svuint32_t ReorderWidenMulAccumulate(Simd<uint32_t, N, kPow2> d32,
                                             svuint16_t a, svuint16_t b,
                                             const svuint32_t sum0,
                                             svuint32_t& sum1) {
#if HWY_SVE_HAVE_2
  (void)d32;
  sum1 = svmlalt_u32(sum1, a, b);
  return svmlalb_u32(sum0, a, b);
#else
  // Lane order within sum0/1 is undefined, hence we can avoid the
  // longer-latency lane-crossing PromoteTo by using PromoteEvenTo.
  sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1);
  return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0);
#endif
}

// ------------------------------ RearrangeToOddPlusEven
template <class VW>
HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
  // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes.
  return Add(sum0, sum1);
}
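
// Example (illustrative sketch): accumulating a widening i16 dot product.
// Lane order within sum0/sum1 is unspecified, so combine them via
// RearrangeToOddPlusEven before reducing.
//   const ScalableTag<int32_t> d32;
//   const Repartition<int16_t, decltype(d32)> d16;
//   svint32_t sum0 = Zero(d32), sum1 = Zero(d32);
//   const svint16_t a = Set(d16, 2), b = Set(d16, 3);
//   sum0 = ReorderWidenMulAccumulate(d32, a, b, sum0, sum1);
//   const svint32_t dot = RearrangeToOddPlusEven(sum0, sum1);  // lanes = 12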

// ------------------------------ SumOfMulQuadAccumulate

#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/, svint8_t a,
                                            svint8_t b, svint32_t sum) {
  return svdot_s32(sum, a, b);
}

#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#endif

template <class DU32, HWY_IF_U32_D(DU32)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a,
                                            svuint8_t b, svuint32_t sum) {
  return svdot_u32(sum, a, b);
}

#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u,
                                            svint8_t b_i, svint32_t sum) {
#if HWY_SVE_HAVE_2
  (void)di32;
  return svusdot_s32(sum, a_u, b_i);
#else
  const RebindToUnsigned<decltype(di32)> du32;
  const Repartition<uint8_t, decltype(di32)> du8;

  const auto b_u = BitCast(du8, b_i);
  const auto result_sum0 = svdot_u32(BitCast(du32, sum), a_u, b_u);
  const auto result_sum1 =
      ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u)));

  return BitCast(di32, Sub(result_sum0, result_sum1));
#endif
}

#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
#endif

template <class DI64, HWY_IF_I64_D(DI64)>
HWY_API VFromD<DI64> SumOfMulQuadAccumulate(DI64 /*di64*/, svint16_t a,
                                            svint16_t b, svint64_t sum) {
  return svdot_s64(sum, a, b);
}

#ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE
#endif

template <class DU64, HWY_IF_U64_D(DU64)>
HWY_API VFromD<DU64> SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a,
                                            svuint16_t b, svuint64_t sum) {
  return svdot_u64(sum, a, b);
}
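
// Example of SumOfMulQuadAccumulate (illustrative sketch): each i32 lane
// accumulates the dot product of the corresponding four i8 lanes (via SDOT).
//   const ScalableTag<int32_t> d32;
//   const Repartition<int8_t, decltype(d32)> d8;
//   const svint8_t a = Set(d8, 2), b = Set(d8, 5);
//   const svint32_t s = SumOfMulQuadAccumulate(d32, a, b, Zero(d32));
//   // every i32 lane is 4 * (2 * 5) = 40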

// ------------------------------ MulComplex* / MaskedMulComplex*

// Per-target flag to prevent generic_ops-inl.h from defining MulComplex*.
#ifdef HWY_NATIVE_CPLX
#undef HWY_NATIVE_CPLX
#else
#define HWY_NATIVE_CPLX
#endif

template <class V, HWY_IF_NOT_UNSIGNED(TFromV<V>)>
HWY_API V ComplexConj(V a) {
  return OddEven(Neg(a), a);
}

namespace detail {
#define HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, ROT)      \
  HWY_API HWY_SVE_V(BASE, BITS)                                          \
      NAME##ROT(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b,        \
                HWY_SVE_V(BASE, BITS) c) {                               \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b, c, ROT); \
  }                                                                      \
  HWY_API HWY_SVE_V(BASE, BITS)                                          \
      NAME##Z##ROT(svbool_t m, HWY_SVE_V(BASE, BITS) a,                  \
                   HWY_SVE_V(BASE, BITS) b, HWY_SVE_V(BASE, BITS) c) {   \
    return sv##OP##_##CHAR##BITS##_z(m, a, b, c, ROT);                   \
  }

#define HWY_SVE_CPLX_FMA(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 0)   \
  HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 90)  \
  HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 180) \
  HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 270)

// Only SVE2 has complex multiply-add for integer types, and those intrinsics
// do not include masked variants.
HWY_SVE_FOREACH_F(HWY_SVE_CPLX_FMA, ComplexMulAdd, cmla)
#undef HWY_SVE_CPLX_FMA
#undef HWY_SVE_CPLX_FMA_ROT
}  // namespace detail

template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) {
  const V t = detail::ComplexMulAddZ0(mask, c, b, a);
  return detail::ComplexMulAddZ270(mask, t, b, a);
}

template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulComplexConj(M mask, V a, V b) {
  return MaskedMulComplexConjAdd(mask, a, b, Zero(DFromV<V>()));
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplexAdd(V a, V b, V c) {
  return detail::ComplexMulAdd90(detail::ComplexMulAdd0(c, a, b), a, b);
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplex(V a, V b) {
  return MulComplexAdd(a, b, Zero(DFromV<V>()));
}

template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) {
  return IfThenElse(mask, MulComplex(a, b), no);
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplexConjAdd(V a, V b, V c) {
  return detail::ComplexMulAdd270(detail::ComplexMulAdd0(c, b, a), b, a);
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplexConj(V a, V b) {
  return MulComplexConjAdd(a, b, Zero(DFromV<V>()));
}
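
// Example of MulComplex (illustrative sketch): lanes hold interleaved
// {re, im} pairs, so with a = 1 + 2i and b = 3 + 4i, a * b = -5 + 10i.
//   const ScalableTag<float> df;
//   const svfloat32_t a = OddEven(Set(df, 2.0f), Set(df, 1.0f));  // {1, 2, ..}
//   const svfloat32_t b = OddEven(Set(df, 4.0f), Set(df, 3.0f));  // {3, 4, ..}
//   const svfloat32_t p = MulComplex(a, b);  // {-5, 10, -5, 10, ...}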

// TODO: SVE2 does provide integer complex multiply intrinsics, but no masked
// variants.
template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplex(V a, V b) {
  // a = u + iv, b = x + iy
  const auto u = DupEven(a);
  const auto v = DupOdd(a);
  const auto x = DupEven(b);
  const auto y = DupOdd(b);

  return OddEven(MulAdd(u, y, Mul(v, x)), Sub(Mul(u, x), Mul(v, y)));
}

template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplexConj(V a, V b) {
  // a = u + iv, b = x + iy
  const auto u = DupEven(a);
  const auto v = DupOdd(a);
  const auto x = DupEven(b);
  const auto y = DupOdd(b);

  return OddEven(Sub(Mul(v, x), Mul(u, y)), MulAdd(u, x, Mul(v, y)));
}

template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplexAdd(V a, V b, V c) {
  return Add(MulComplex(a, b), c);
}

template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplexConjAdd(V a, V b, V c) {
  return Add(MulComplexConj(a, b), c);
}

template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) {
  return IfThenElseZero(mask, MulComplexConjAdd(a, b, c));
}

template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulComplexConj(M mask, V a, V b) {
  return IfThenElseZero(mask, MulComplexConj(a, b));
}

template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) {
  return IfThenElse(mask, MulComplex(a, b), no);
}

// ------------------------------ AESRound / CLMul

// Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a
// baseline, in which case we check for AES support at runtime.
#if defined(__ARM_FEATURE_SVE2_AES) || \
    (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH && HWY_BASELINE_SVE2 == 0)

// Per-target flag to prevent generic_ops-inl.h from defining AESRound.
#ifdef HWY_NATIVE_AES
#undef HWY_NATIVE_AES
#else
#define HWY_NATIVE_AES
#endif

HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) {
  // It is not clear whether E and MC fuse like they did on NEON.
  return Xor(svaesmc_u8(svaese_u8(state, svdup_n_u8(0))), round_key);
}

HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) {
  return Xor(svaese_u8(state, svdup_n_u8(0)), round_key);
}

HWY_API svuint8_t AESInvMixColumns(svuint8_t state) {
  return svaesimc_u8(state);
}

HWY_API svuint8_t AESRoundInv(svuint8_t state, svuint8_t round_key) {
  return Xor(svaesimc_u8(svaesd_u8(state, svdup_n_u8(0))), round_key);
}

HWY_API svuint8_t AESLastRoundInv(svuint8_t state, svuint8_t round_key) {
  return Xor(svaesd_u8(state, svdup_n_u8(0)), round_key);
}

template <uint8_t kRcon>
HWY_API svuint8_t AESKeyGenAssist(svuint8_t v) {
  alignas(16) static constexpr uint8_t kRconXorMask[16] = {
      0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0};
  alignas(16) static constexpr uint8_t kRotWordShuffle[16] = {
      0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12};
  const DFromV<decltype(v)> d;
  const Repartition<uint32_t, decltype(d)> du32;
  const auto w13 = BitCast(d, DupOdd(BitCast(du32, v)));
  const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask));
  return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle));
}

HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) {
  return svpmullb_pair(a, b);
}

HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) {
  return svpmullt_pair(a, b);
}

#endif  // __ARM_FEATURE_SVE2_AES
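
// Example of AESRound (illustrative sketch): AES-128 encryption of the blocks
// in a vector; input and key0..key10 are hypothetical caller-provided arrays,
// with the round keys already expanded.
//   const ScalableTag<uint8_t> d8;
//   svuint8_t block = Xor(LoadU(d8, input), LoadDup128(d8, key0));
//   block = AESRound(block, LoadDup128(d8, key1));
//   // ... repeat AESRound for key2..key9 ...
//   block = AESLastRound(block, LoadDup128(d8, key10));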

// ------------------------------ Lt128

namespace detail {
#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP)                        \
  template <size_t N, int kPow2>                                             \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \
    return sv##OP##_b##BITS(m, m);                                           \
  }

HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1)  // actually for bool
HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2)   // actually for bool
#undef HWY_SVE_DUP

#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
template <class D>
HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHx = Eq(a, b);  // only odd lanes used
  // Convert to vector: more pipelines can execute vector TRN* instructions
  // than the predicate version.
  const svuint64_t ltHL = VecFromMask(d, Lt(a, b));
  // Move into upper lane: ltL if the upper half is equal, otherwise ltH.
  // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated.
  const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL);
  // Duplicate upper lane into lower.
  return DupOdd(ltHx);
}
#endif
}  // namespace detail

template <class D>
HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return MaskFromVec(detail::Lt128Vec(d, a, b));
#else
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHx = Eq(a, b);  // only odd lanes used
  const svbool_t ltHL = Lt(a, b);
  // Move into upper lane: ltL if the upper half is equal, otherwise ltH.
  const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL);
  // Duplicate upper lane into lower.
  return detail::DupOddB(d, ltHx);
#endif  // HWY_TARGET != HWY_SVE_256
}

// ------------------------------ Lt128Upper

template <class D>
HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t ltHL = Lt(a, b);
  return detail::DupOddB(d, ltHL);
}
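
// Example of Lt128 (illustrative sketch): adjacent u64 lanes form one 128-bit
// number (low half in the even lane); the result mask is identical in both
// lanes of each pair.
//   const ScalableTag<uint64_t> du64;
//   const svuint64_t a = OddEven(Set(du64, 2), Set(du64, 1));  // hi=2, lo=1
//   const svuint64_t b = OddEven(Set(du64, 3), Set(du64, 0));  // hi=3, lo=0
//   const svbool_t lt = Lt128(du64, a, b);  // all true: 2:1 < 3:0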

// ------------------------------ Eq128, Ne128

#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
namespace detail {

template <class D>
HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  // Convert to vector: more pipelines can execute vector TRN* instructions
  // than the predicate version.
  const svuint64_t eqHL = VecFromMask(d, Eq(a, b));
  // Duplicate upper and lower.
  const svuint64_t eqHH = DupOdd(eqHL);
  const svuint64_t eqLL = DupEven(eqHL);
  return And(eqLL, eqHH);
}

template <class D>
HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  // Convert to vector: more pipelines can execute vector TRN* instructions
  // than the predicate version.
  const svuint64_t neHL = VecFromMask(d, Ne(a, b));
  // Duplicate upper and lower.
  const svuint64_t neHH = DupOdd(neHL);
  const svuint64_t neLL = DupEven(neHL);
  return Or(neLL, neHH);
}

}  // namespace detail
#endif

template <class D>
HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return MaskFromVec(detail::Eq128Vec(d, a, b));
#else
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHL = Eq(a, b);
  const svbool_t eqHH = detail::DupOddB(d, eqHL);
  const svbool_t eqLL = detail::DupEvenB(d, eqHL);
  return And(eqLL, eqHH);
#endif  // HWY_TARGET != HWY_SVE_256
}

template <class D>
HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return MaskFromVec(detail::Ne128Vec(d, a, b));
#else
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t neHL = Ne(a, b);
  const svbool_t neHH = detail::DupOddB(d, neHL);
  const svbool_t neLL = detail::DupEvenB(d, neHL);
  return Or(neLL, neHH);
#endif  // HWY_TARGET != HWY_SVE_256
}

// ------------------------------ Eq128Upper, Ne128Upper

template <class D>
HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHL = Eq(a, b);
  return detail::DupOddB(d, eqHL);
}

template <class D>
HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t neHL = Ne(a, b);
  return detail::DupOddB(d, neHL);
}

// ------------------------------ Min128, Max128 (Lt128)

template <class D>
HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b);
#else
  return IfThenElse(Lt128(d, a, b), a, b);
#endif
}

template <class D>
HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b);
#else
  return IfThenElse(Lt128(d, b, a), a, b);
#endif
}

template <class D>
HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) {
  return IfThenElse(Lt128Upper(d, a, b), a, b);
}

template <class D>
HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) {
  return IfThenElse(Lt128Upper(d, b, a), a, b);
}
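
// Example of Min128 (illustrative sketch): per-pair minimum of 128-bit
// numbers, keeping whole {lo, hi} pairs together rather than selecting lanes
// independently.
//   const ScalableTag<uint64_t> du64;
//   const svuint64_t a = OddEven(Set(du64, 2), Set(du64, 1));  // pairs 2:1
//   const svuint64_t b = OddEven(Set(du64, 3), Set(du64, 0));  // pairs 3:0
//   const svuint64_t mn = Min128(du64, a, b);  // {1, 2, 1, 2, ...} == a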

// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex

#ifdef HWY_NATIVE_LEADING_ZERO_COUNT
#undef HWY_NATIVE_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_LEADING_ZERO_COUNT
#endif

#define HWY_SVE_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {        \
    const DFromV<decltype(v)> d;                                       \
    return BitCast(d, sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v)); \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_LEADING_ZERO_COUNT, LeadingZeroCount, clz)
#undef HWY_SVE_LEADING_ZERO_COUNT

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V TrailingZeroCount(V v) {
  return LeadingZeroCount(ReverseBits(v));
}

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V HighestSetBitIndex(V v) {
  const DFromV<decltype(v)> d;
  using T = TFromD<decltype(d)>;
  return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v)));
}

#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
    const DFromV<decltype(v)> d;                                            \
    return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v));                     \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount,
                   clz)
#undef HWY_SVE_MASKED_LEADING_ZERO_COUNT

// ================================================== END MACROS
#undef HWY_SVE_ALL_PTRUE
#undef HWY_SVE_D
#undef HWY_SVE_FOREACH
#undef HWY_SVE_FOREACH_BF16
#undef HWY_SVE_FOREACH_BF16_UNCONDITIONAL
#undef HWY_SVE_FOREACH_F
#undef HWY_SVE_FOREACH_F16
#undef HWY_SVE_FOREACH_F32
#undef HWY_SVE_FOREACH_F3264
#undef HWY_SVE_FOREACH_F64
#undef HWY_SVE_FOREACH_I
#undef HWY_SVE_FOREACH_I08
#undef HWY_SVE_FOREACH_I16
#undef HWY_SVE_FOREACH_I32
#undef HWY_SVE_FOREACH_I64
#undef HWY_SVE_FOREACH_IF
#undef HWY_SVE_FOREACH_U
#undef HWY_SVE_FOREACH_U08
#undef HWY_SVE_FOREACH_U16
#undef HWY_SVE_FOREACH_U32
#undef HWY_SVE_FOREACH_U64
#undef HWY_SVE_FOREACH_UI
#undef HWY_SVE_FOREACH_UI08
#undef HWY_SVE_FOREACH_UI16
#undef HWY_SVE_FOREACH_UI32
#undef HWY_SVE_FOREACH_UI64
#undef HWY_SVE_FOREACH_UIF3264
#undef HWY_SVE_HAVE_2
#undef HWY_SVE_IF_EMULATED_D
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMVV_Z
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGMV_M
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_RETV_ARGMVVV_Z
#undef HWY_SVE_RETV_ARGMVVV
#undef HWY_SVE_T
#undef HWY_SVE_UNDEFINED
#undef HWY_SVE_V

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();