// dot-inl.h
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Per-target "toggle" include guard: this *-inl.h file is deliberately
// re-includable, once per SIMD target, as is standard for Highway headers.
// clang-format off
#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
// clang-format on
#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#endif

#include <stddef.h>
#include <stdint.h>

#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

// Vectorized dot products. Each Compute() overload returns
// sum{pa[i] * pb[i]} for i in [0, num_elements).
//
// NOTE: the D argument describes the inputs, not the output, because both
// f32/f32, bf16/bf16, and f32/bf16 inputs accumulate to f32.
struct Dot {
  // Specify zero or more of these, ORed together, as the kAssumptions template
  // argument to Compute. Each one may improve performance or reduce code size,
  // at the cost of additional requirements on the arguments.
  enum Assumptions {
    // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
    kAtLeastOneVector = 1,
    // num_elements is divisible by N (a power of two, so this can be used if
    // the problem size is known to be a power of two >= HWY_MAX_BYTES /
    // sizeof(T)).
    kMultipleOfVector = 2,
    // RoundUpTo(num_elements, N) elements are accessible; their value does not
    // matter (will be treated as if they were zero).
    kPaddedToVector = 4,
  };

  // Returns sum{pa[i] * pb[i]} for floating-point inputs, including float16_t
  // and double if HWY_HAVE_FLOAT16/64. Aligning the
  // pointers to a multiple of N elements is helpful but not required.
  template <int kAssumptions, class D, typename T = TFromD<D>>
  static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
                              const T* const HWY_RESTRICT pb,
                              const size_t num_elements) {
    static_assert(IsFloat<T>(), "MulAdd requires float type");
    using V = decltype(Zero(d));

    HWY_LANES_CONSTEXPR size_t N = Lanes(d);
    size_t i = 0;

    // Decode the kAssumptions bitfield into compile-time flags so the
    // unneeded paths below are dead-code-eliminated.
    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      // Only 2x unroll to avoid excessive code size.
      T sum0 = ConvertScalarTo<T>(0);
      T sum1 = ConvertScalarTo<T>(0);
      for (; i + 2 <= num_elements; i += 2) {
        // For reasons unknown, fp16 += does not compile on clang (Arm).
        sum0 = ConvertScalarTo<T>(sum0 + pa[i + 0] * pb[i + 0]);
        sum1 = ConvertScalarTo<T>(sum1 + pa[i + 1] * pb[i + 1]);
      }
      // At most one leftover element because of the 2x unroll.
      if (i < num_elements) {
        sum1 = ConvertScalarTo<T>(sum1 + pa[i] * pb[i]);
      }
      return ConvertScalarTo<T>(sum0 + sum1);
    }

    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    // for unaligned inputs (each unaligned pointer halves the throughput
    // because it occupies both L1 load ports for a cycle). We cannot have
    // arrays of vectors on RVV/SVE, so always unroll 4x.
    V sum0 = Zero(d);
    V sum1 = Zero(d);
    V sum2 = Zero(d);
    V sum3 = Zero(d);

    // Main loop: unrolled 4x; loads are interleaved with the FMAs to help
    // hide load latency.
    for (; i + 4 * N <= num_elements; /* i += 4 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = MulAdd(a0, b0, sum0);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum1 = MulAdd(a1, b1, sum1);
      const auto a2 = LoadU(d, pa + i);
      const auto b2 = LoadU(d, pb + i);
      i += N;
      sum2 = MulAdd(a2, b2, sum2);
      const auto a3 = LoadU(d, pa + i);
      const auto b3 = LoadU(d, pb + i);
      i += N;
      sum3 = MulAdd(a3, b3, sum3);
    }

    // Up to 3 iterations of whole vectors
    for (; i + N <= num_elements; i += N) {
      const auto a = LoadU(d, pa + i);
      const auto b = LoadU(d, pb + i);
      sum0 = MulAdd(a, b, sum0);
    }

    // Handle the final partial vector, unless the caller promised there is
    // none (kMultipleOfVector).
    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          // Padding is readable: load a full vector and zero the lanes past
          // the valid range so they do not affect the sum.
          const auto mask = FirstN(d, remaining);
          const auto a = LoadU(d, pa + i);
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          // The first (N - remaining) lanes overlap already-accumulated
          // elements; zero them to avoid double-counting.
          const auto skip = FirstN(d, N - remaining);
          const auto a = LoadU(d, pa + i);  // always unaligned
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(d, sum0);
  }

  // f32 * bf16: returns sum{pa[i] * f32(pb[i])}. bf16 inputs are widened to
  // f32 before multiplying, so accumulation is entirely in f32.
  template <int kAssumptions, class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE float Compute(const DF df,
                                  const float* const HWY_RESTRICT pa,
                                  const hwy::bfloat16_t* const HWY_RESTRICT pb,
                                  const size_t num_elements) {
#if HWY_TARGET == HWY_SCALAR
    // Scalar target: no Repartition (would require >1 lane).
    const Rebind<hwy::bfloat16_t, DF> dbf;
#else
    // One bf16 vector holds 2*NF elements, i.e. two f32 vectors' worth.
    const Repartition<hwy::bfloat16_t, DF> dbf;
    using VBF = decltype(Zero(dbf));
#endif
    // Half-width bf16 descriptor: NF elements, promoted to one f32 vector.
    const Half<decltype(dbf)> dbfh;
    using VF = decltype(Zero(df));

    HWY_LANES_CONSTEXPR size_t NF = Lanes(df);

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < NF)) {
      // Only 2x unroll to avoid excessive code size.
      float sum0 = 0.0f;
      float sum1 = 0.0f;
      size_t i = 0;
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += pa[i + 0] * ConvertScalarTo<float>(pb[i + 0]);
        sum1 += pa[i + 1] * ConvertScalarTo<float>(pb[i + 1]);
      }
      for (; i < num_elements; ++i) {
        sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
      }
      return sum0 + sum1;
    }

    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    // for unaligned inputs (each unaligned pointer halves the throughput
    // because it occupies both L1 load ports for a cycle). We cannot have
    // arrays of vectors on RVV/SVE, so always unroll 4x.
    VF sum0 = Zero(df);
    VF sum1 = Zero(df);
    VF sum2 = Zero(df);
    VF sum3 = Zero(df);

    size_t i = 0;

#if HWY_TARGET != HWY_SCALAR  // PromoteUpperTo supported
    // Main loop: unrolled. Each bf16 load (b0/b2) covers 2*NF elements, so it
    // feeds two f32 FMAs via PromoteLowerTo/PromoteUpperTo.
    for (; i + 4 * NF <= num_elements; /* i += 4 * NF */) {  // incr in loop
      const VF a0 = LoadU(df, pa + i);
      const VBF b0 = LoadU(dbf, pb + i);
      i += NF;
      sum0 = MulAdd(a0, PromoteLowerTo(df, b0), sum0);
      const VF a1 = LoadU(df, pa + i);
      i += NF;
      sum1 = MulAdd(a1, PromoteUpperTo(df, b0), sum1);
      const VF a2 = LoadU(df, pa + i);
      const VBF b2 = LoadU(dbf, pb + i);
      i += NF;
      sum2 = MulAdd(a2, PromoteLowerTo(df, b2), sum2);
      const VF a3 = LoadU(df, pa + i);
      i += NF;
      sum3 = MulAdd(a3, PromoteUpperTo(df, b2), sum3);
    }
#endif  // HWY_TARGET != HWY_SCALAR

    // Up to 3 iterations of whole vectors
    for (; i + NF <= num_elements; i += NF) {
      const VF a = LoadU(df, pa + i);
      const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
      sum0 = MulAdd(a, b, sum0);
    }

    // Handle the final partial vector, unless the caller promised there is
    // none (kMultipleOfVector).
    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          // Padding is readable: zero lanes past the valid range.
          const auto mask = FirstN(df, remaining);
          const VF a = LoadU(df, pa + i);
          const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= NF);
          i += remaining - NF;
          // Zero the lanes that overlap already-accumulated elements.
          const auto skip = FirstN(df, NF - remaining);
          const VF a = LoadU(df, pa + i);  // always unaligned
          const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(df, sum0);
  }

  // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
  // multiple of N elements is helpful but not required.
  template <int kAssumptions, class D, HWY_IF_BF16_D(D)>
  static HWY_INLINE float Compute(const D d,
                                  const bfloat16_t* const HWY_RESTRICT pa,
                                  const bfloat16_t* const HWY_RESTRICT pb,
                                  const size_t num_elements) {
    // du16 is used for lane masking via BitCast (FirstN on bf16 descriptors
    // is avoided); df32 is the widened accumulator type.
    const RebindToUnsigned<D> du16;
    const Repartition<float, D> df32;

    using V = decltype(Zero(df32));
    HWY_LANES_CONSTEXPR size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      float sum0 = 0.0f;  // Only 2x unroll to avoid excessive code size for..
      float sum1 = 0.0f;  // this unlikely(?) case.
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
        sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
      }
      if (i < num_elements) {
        sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
      }
      return sum0 + sum1;
    }

    // See comment in the other Compute() overload. Unroll 2x, but we need
    // twice as many sums for ReorderWidenMulAccumulate.
    V sum0 = Zero(df32);
    V sum1 = Zero(df32);
    V sum2 = Zero(df32);
    V sum3 = Zero(df32);

    // Main loop: unrolled. ReorderWidenMulAccumulate widens bf16 lanes to f32
    // and spreads products across a pair of accumulators (sum0/sum1 and
    // sum2/sum3); lane order within the pair is unspecified but irrelevant
    // because we only need the total sum.
    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
    }

    // Possibly one more iteration of whole vectors
    if (i + N <= num_elements) {
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
    }

    // Handle the final partial vector, unless the caller promised there is
    // none (kMultipleOfVector).
    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          // Padding is readable: mask (via u16 BitCast) zeroes lanes past the
          // valid range.
          const auto mask = FirstN(du16, remaining);
          const auto va = LoadU(d, pa + i);
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);

        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          // Zero the lanes that overlap already-accumulated elements.
          const auto skip = FirstN(du16, N - remaining);
          const auto va = LoadU(d, pa + i);  // always unaligned
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(df32, sum0);
  }

  // Returns sum{i32(pa[i]) * i32(pb[i])} for i16 inputs. Aligning the pointers
  // to a multiple of N elements is helpful but not required.
  template <int kAssumptions, class D, HWY_IF_I16_D(D)>
  static HWY_INLINE int32_t Compute(const D d,
                                    const int16_t* const HWY_RESTRICT pa,
                                    const int16_t* const HWY_RESTRICT pb,
                                    const size_t num_elements) {
    // du16 is used for lane masking via BitCast; di32 is the widened
    // accumulator type.
    const RebindToUnsigned<D> du16;
    const RepartitionToWide<D> di32;

    using VI32 = Vec<decltype(di32)>;
    HWY_LANES_CONSTEXPR size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      int32_t sum0 = 0;  // Only 2x unroll to avoid excessive code size for..
      int32_t sum1 = 0;  // this unlikely(?) case.
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += int32_t{pa[i + 0]} * int32_t{pb[i + 0]};
        sum1 += int32_t{pa[i + 1]} * int32_t{pb[i + 1]};
      }
      if (i < num_elements) {
        sum1 += int32_t{pa[i]} * int32_t{pb[i]};
      }
      return sum0 + sum1;
    }

    // See comment in the other Compute() overload. Unroll 2x, but we need
    // twice as many sums for ReorderWidenMulAccumulate.
    VI32 sum0 = Zero(di32);
    VI32 sum1 = Zero(di32);
    VI32 sum2 = Zero(di32);
    VI32 sum3 = Zero(di32);

    // Main loop: unrolled. ReorderWidenMulAccumulate widens i16 lanes to i32
    // and spreads products across a pair of accumulators; lane order within
    // the pair is unspecified but irrelevant for the total sum.
    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum2 = ReorderWidenMulAccumulate(di32, a1, b1, sum2, sum3);
    }

    // Possibly one more iteration of whole vectors
    if (i + N <= num_elements) {
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1);
    }

    // Handle the final partial vector, unless the caller promised there is
    // none (kMultipleOfVector).
    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          // Padding is readable: mask (via u16 BitCast) zeroes lanes past the
          // valid range.
          const auto mask = FirstN(du16, remaining);
          const auto va = LoadU(d, pa + i);
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3);

        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          // Zero the lanes that overlap already-accumulated elements.
          const auto skip = FirstN(du16, N - remaining);
          const auto va = LoadU(d, pa + i);  // always unaligned
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(di32, sum0);
  }
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_