matvec-inl.h
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard (still compiled once per target)
#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#endif

#include <stddef.h>
#include <stdint.h>

#include "hwy/cache_control.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename TA, typename TB>
TA AddScalar(TA a, TB b) {
  return ConvertScalarTo<TA>(ConvertScalarTo<float>(a) +
                             ConvertScalarTo<float>(b));
}

template <size_t kOuter, size_t kInner, typename T, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat,
                                const T* HWY_RESTRICT vec,
                                const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  (void)add;

  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize2 = 64 / sizeof(T);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);

  const ScalableTag<T> d;
  const size_t N = Lanes(d);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize2 >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(T);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN T buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not masked.
             // Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(d);
               auto sum1 = Zero(d);
               // 4x unrolling barely helps SKX but likely helps Arm V2.
               auto sum2 = Zero(d);
               auto sum3 = Zero(d);

               const T* HWY_RESTRICT row = &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 4 * N <= kInner; i += 4 * N) {
                 const auto a0 = LoadU(d, row + i + 0 * N);
                 const auto v0 = LoadU(d, vec + i + 0 * N);
                 sum0 = MulAdd(a0, v0, sum0);

                 const auto a1 = LoadU(d, row + i + 1 * N);
                 const auto v1 = LoadU(d, vec + i + 1 * N);
                 sum1 = MulAdd(a1, v1, sum1);

                 const auto a2 = LoadU(d, row + i + 2 * N);
                 const auto v2 = LoadU(d, vec + i + 2 * N);
                 sum2 = MulAdd(a2, v2, sum2);

                 const auto a3 = LoadU(d, row + i + 3 * N);
                 const auto v3 = LoadU(d, vec + i + 3 * N);
                 sum3 = MulAdd(a3, v3, sum3);
               }
               // Last entire vectors
               for (; i + N <= kInner; i += N) {
                 const auto a0 = LoadU(d, row + i);
                 const auto v0 = LoadU(d, vec + i);
                 sum0 = MulAdd(a0, v0, sum0);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const auto a0 = LoadN(d, row + i, remainder);
                 const auto v0 = LoadN(d, vec + i, remainder);
                 sum1 = MulAdd(a0, v0, sum1);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum1);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(d, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N) {
               Stream(Load(d, buf + i), d, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
    auto sum0 = Zero(d);

    const T* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const auto a0 = LoadU(d, row + i);
      const auto v0 = LoadU(d, vec + i);
      sum0 = MulAdd(a0, v0, sum0);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const auto a0 = LoadN(d, row + i, remainder);
      const auto v0 = LoadN(d, vec + i, remainder);
      sum0 = MulAdd(a0, v0, sum0);
    }
    out[r] = ReduceSum(d, sum0);
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

// Multiplies mat with vec, adds add and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// add is a (kOuter,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec + add.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVecAdd(const T* HWY_RESTRICT mat,
                            const T* HWY_RESTRICT vec,
                            const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
                            hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, T, true>(mat, vec, add, out, pool);
}

// Multiplies mat with vec and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVec(const T* HWY_RESTRICT mat, const T* HWY_RESTRICT vec,
                         T* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, T, false>(mat, vec, /*add=*/nullptr, out, pool);
}
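
// Usage sketch (illustrative only, not part of the library): a thin caller
// that fixes the matrix shape at compile time and forwards to the f32
// MatVecAdd above. The function name and the 512x512 shape are assumptions for
// the example; any kOuter/kInner work as long as the arrays have matching
// extents, and out receives mat @ vec + add.
HWY_MAYBE_UNUSED static void ExampleMatVecAddF32(
    const float* HWY_RESTRICT mat,  // 512 x 512, row-major
    const float* HWY_RESTRICT vec,  // 512
    const float* HWY_RESTRICT add,  // 512
    float* HWY_RESTRICT out,        // 512
    hwy::ThreadPool& pool) {
  MatVecAdd</*kOuter=*/512, /*kInner=*/512>(mat, vec, add, out, pool);
}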

// This target (HWY_SCALAR) lacks too many of the ops required by our
// implementation; use HWY_EMU128 instead.
#if HWY_TARGET != HWY_SCALAR

// Specialization for bf16 matrix, which halves memory bandwidth requirements.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
                                const float* HWY_RESTRICT vec,
                                const float* HWY_RESTRICT add,
                                float* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize2 = 64 / sizeof(float);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);

  const ScalableTag<float> d;
  const Repartition<hwy::bfloat16_t, decltype(d)> d16;
  // In the remainder loop, we only process a single f32 vector, so load half
  // vectors of bf16 to avoid overrun.
  const Half<decltype(d16)> d16h;
  using V = Vec<decltype(d)>;
  using V16 = Vec<decltype(d16)>;
  using V16H = Vec<decltype(d16h)>;
  const size_t N = Lanes(d);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize2 >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(float);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN float buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not masked.
             // Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(d);
               auto sum1 = Zero(d);
               // 4x unrolling barely helps SKX but likely helps Arm V2.
               auto sum2 = Zero(d);
               auto sum3 = Zero(d);

               const hwy::bfloat16_t* HWY_RESTRICT row =
                   &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 4 * N <= kInner; i += 4 * N) {
                 const V16 b0 = LoadU(d16, row + i + 0 * N);
                 const V a0 = PromoteLowerTo(d, b0);
                 const V a1 = PromoteUpperTo(d, b0);

                 const V16 b1 = LoadU(d16, row + i + 2 * N);
                 const V a2 = PromoteLowerTo(d, b1);
                 const V a3 = PromoteUpperTo(d, b1);

                 const V v0 = LoadU(d, vec + i + 0 * N);
                 sum0 = MulAdd(a0, v0, sum0);

                 const V v1 = LoadU(d, vec + i + 1 * N);
                 sum1 = MulAdd(a1, v1, sum1);

                 const V v2 = LoadU(d, vec + i + 2 * N);
                 sum2 = MulAdd(a2, v2, sum2);

                 const V v3 = LoadU(d, vec + i + 3 * N);
                 sum3 = MulAdd(a3, v3, sum3);
               }
               // Last entire vectors
               for (; i + N <= kInner; i += N) {
                 const V16H b0 = LoadU(d16h, row + i);
                 const V a0 = PromoteTo(d, b0);
                 const V v0 = LoadU(d, vec + i);
                 sum0 = MulAdd(a0, v0, sum0);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const V16H b0 = LoadN(d16h, row + i, remainder);
                 const V a0 = PromoteTo(d, b0);
                 const V v0 = LoadN(d, vec + i, remainder);
                 sum1 = MulAdd(a0, v0, sum1);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum1);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(d, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N) {
               Stream(Load(d, buf + i), d, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
    auto sum0 = Zero(d);

    const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const V16H b0 = LoadU(d16h, row + i);
      const V a0 = PromoteTo(d, b0);
      const V v0 = LoadU(d, vec + i);
      sum0 = MulAdd(a0, v0, sum0);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const V16H b0 = LoadN(d16h, row + i, remainder);
      const V a0 = PromoteTo(d, b0);
      const V v0 = LoadN(d, vec + i, remainder);
      sum0 = MulAdd(a0, v0, sum0);
    }
    out[r] = ReduceSum(d, sum0);
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
                            const float* HWY_RESTRICT vec,
                            const float* HWY_RESTRICT add,
                            float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
                         const float* HWY_RESTRICT vec, float* HWY_RESTRICT out,
                         hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
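
// Usage sketch (illustrative only, not part of the library): calling the
// bf16-matrix overload. Only the matrix is stored as bf16; vec, add and out
// remain f32, which halves the bytes read per matrix row relative to the
// all-f32 path. The function name and the 256x1024 shape are assumptions for
// the example.
HWY_MAYBE_UNUSED static void ExampleMatVecAddBF16Mat(
    const hwy::bfloat16_t* HWY_RESTRICT mat,  // 256 x 1024, row-major bf16
    const float* HWY_RESTRICT vec,            // 1024, f32
    const float* HWY_RESTRICT add,            // 256, f32
    float* HWY_RESTRICT out,                  // 256, f32
    hwy::ThreadPool& pool) {
  MatVecAdd</*kOuter=*/256, /*kInner=*/1024>(mat, vec, add, out, pool);
}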

// Both mat and vec are bf16.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
                                const hwy::bfloat16_t* HWY_RESTRICT vec,
                                const hwy::bfloat16_t* HWY_RESTRICT add,
                                float* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize2 = 64 / sizeof(bfloat16_t);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);

  const ScalableTag<float> df;
  const Repartition<hwy::bfloat16_t, decltype(df)> d16;
  using V16 = Vec<decltype(d16)>;
  const size_t N = Lanes(d16);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize2 >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN float buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not masked.
             // Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(df);
               auto sum1 = Zero(df);
               auto sum2 = Zero(df);
               auto sum3 = Zero(df);

               const hwy::bfloat16_t* HWY_RESTRICT row =
                   &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 2 * N <= kInner; i += 2 * N) {
                 const V16 b0 = LoadU(d16, row + i + 0 * N);
                 const V16 b1 = LoadU(d16, row + i + 1 * N);
                 const V16 v0 = LoadU(d16, vec + i + 0 * N);
                 const V16 v1 = LoadU(d16, vec + i + 1 * N);
                 sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
                 sum2 = ReorderWidenMulAccumulate(df, b1, v1, sum2, sum3);
               }
               // Last entire vector
               for (; i + N <= kInner; i += N) {
                 const V16 b0 = LoadU(d16, row + i);
                 const V16 v0 = LoadU(d16, vec + i);
                 sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const V16 b0 = LoadN(d16, row + i, remainder);
                 const V16 v0 = LoadN(d16, vec + i, remainder);
                 sum2 = ReorderWidenMulAccumulate(df, b0, v0, sum2, sum3);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum0 = Add(sum0, sum1);
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(df, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N / 2) {
               Stream(Load(df, buf + i), df, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
    auto sum0 = Zero(df);
    auto sum1 = Zero(df);

    const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const V16 b0 = LoadU(d16, row + i);
      const V16 v0 = LoadU(d16, vec + i);
      sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const V16 b0 = LoadN(d16, row + i, remainder);
      const V16 v0 = LoadN(d16, vec + i, remainder);
      sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
    }
    out[r] = ReduceSum(df, Add(sum0, sum1));
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
                            const hwy::bfloat16_t* HWY_RESTRICT vec,
                            const hwy::bfloat16_t* HWY_RESTRICT add,
                            float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
                         const hwy::bfloat16_t* HWY_RESTRICT vec,
                         float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
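
// Usage sketch (illustrative only, not part of the library): both the matrix
// and the vector are bf16; products are widened into f32 accumulators via
// ReorderWidenMulAccumulate, so the output stays f32. The function name and
// the 64x2048 shape are assumptions for the example.
HWY_MAYBE_UNUSED static void ExampleMatVecBF16(
    const hwy::bfloat16_t* HWY_RESTRICT mat,  // 64 x 2048, row-major bf16
    const hwy::bfloat16_t* HWY_RESTRICT vec,  // 2048, bf16
    float* HWY_RESTRICT out,                  // 64, f32
    hwy::ThreadPool& pool) {
  MatVec</*kOuter=*/64, /*kInner=*/2048>(mat, vec, out, pool);
}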

#endif  // HWY_TARGET != HWY_SCALAR

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_