tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

matvec-inl.h (19278B)


// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard (still compiled once per target)
#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#endif
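
// This toggle lets hwy/foreach_target.h re-include the header once per SIMD
// target. A typical translation unit looks like this (a sketch;
// "my_matvec.cc" is a hypothetical file name standing in for the includer's
// own path):
//
//   #undef HWY_TARGET_INCLUDE
//   #define HWY_TARGET_INCLUDE "my_matvec.cc"  // this file's own path
//   #include "hwy/foreach_target.h"            // must precede highway.h
//   #include "hwy/contrib/matvec/matvec-inl.h"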

#include <stddef.h>
#include <stdint.h>

#include "hwy/cache_control.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename TA, typename TB>
TA AddScalar(TA a, TB b) {
  return ConvertScalarTo<TA>(ConvertScalarTo<float>(a) +
                             ConvertScalarTo<float>(b));
}
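
// E.g. with TA = float and TB = hwy::bfloat16_t (the all-bf16 kAdd path
// below), the bf16 operand is widened to f32 and the sum stays in f32; with
// TA = TB = float this reduces to a plain float add.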

template <size_t kOuter, size_t kInner, typename T, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat,
                                const T* HWY_RESTRICT vec,
                                const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  (void)add;

  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize2 = 64 / sizeof(T);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);

  const ScalableTag<T> d;
  const size_t N = Lanes(d);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize2 >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(T);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN T buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not masked.
             // Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(d);
               auto sum1 = Zero(d);
               // 4x unrolling barely helps SKX but likely helps Arm V2.
               auto sum2 = Zero(d);
               auto sum3 = Zero(d);

               const T* HWY_RESTRICT row = &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 4 * N <= kInner; i += 4 * N) {
                 const auto a0 = LoadU(d, row + i + 0 * N);
                 const auto v0 = LoadU(d, vec + i + 0 * N);
                 sum0 = MulAdd(a0, v0, sum0);

                 const auto a1 = LoadU(d, row + i + 1 * N);
                 const auto v1 = LoadU(d, vec + i + 1 * N);
                 sum1 = MulAdd(a1, v1, sum1);

                 const auto a2 = LoadU(d, row + i + 2 * N);
                 const auto v2 = LoadU(d, vec + i + 2 * N);
                 sum2 = MulAdd(a2, v2, sum2);

                 const auto a3 = LoadU(d, row + i + 3 * N);
                 const auto v3 = LoadU(d, vec + i + 3 * N);
                 sum3 = MulAdd(a3, v3, sum3);
               }
               // Last entire vectors
               for (; i + N <= kInner; i += N) {
                 const auto a0 = LoadU(d, row + i);
                 const auto v0 = LoadU(d, vec + i);
                 sum0 = MulAdd(a0, v0, sum0);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const auto a0 = LoadN(d, row + i, remainder);
                 const auto v0 = LoadN(d, vec + i, remainder);
                 sum1 = MulAdd(a0, v0, sum1);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum1);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(d, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N) {
               Stream(Load(d, buf + i), d, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
    auto sum0 = Zero(d);

    const T* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const auto a0 = LoadU(d, row + i);
      const auto v0 = LoadU(d, vec + i);
      sum0 = MulAdd(a0, v0, sum0);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const auto a0 = LoadN(d, row + i, remainder);
      const auto v0 = LoadN(d, vec + i, remainder);
      sum0 = MulAdd(a0, v0, sum0);
    }
    out[r] = ReduceSum(d, sum0);
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

// Multiplies mat with vec, adds add and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// add is a (kOuter,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec + add.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVecAdd(const T* HWY_RESTRICT mat,
                            const T* HWY_RESTRICT vec,
                            const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
                            hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, T, true>(mat, vec, add, out, pool);
}

// Multiplies mat with vec and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVec(const T* HWY_RESTRICT mat, const T* HWY_RESTRICT vec,
                         T* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, T, false>(mat, vec, /*add=*/nullptr, out, pool);
}
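
// Example usage (a sketch; buffer allocation, thread count and the
// kOuter/kInner values are arbitrary caller choices, not requirements):
//
//   constexpr size_t kOuter = 128, kInner = 512;
//   hwy::ThreadPool pool(4);  // 4 worker threads
//   std::vector<float> mat(kOuter * kInner), vec(kInner), add(kOuter);
//   std::vector<float> out(kOuter);
//   MatVecAdd<kOuter, kInner>(mat.data(), vec.data(), add.data(), out.data(),
//                             pool);
//   MatVec<kOuter, kInner>(mat.data(), vec.data(), out.data(), pool);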

// This target lacks too many of the ops required by our implementation;
// use HWY_EMU128 instead.
#if HWY_TARGET != HWY_SCALAR

// Specialization for bf16 matrix, which halves memory bandwidth requirements.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
                                const float* HWY_RESTRICT vec,
                                const float* HWY_RESTRICT add,
                                float* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize2 = 64 / sizeof(float);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);

  const ScalableTag<float> d;
  const Repartition<hwy::bfloat16_t, decltype(d)> d16;
  // In the remainder loop, we only process a single f32 vector, so load half
  // vectors of bf16 to avoid overrun.
  const Half<decltype(d16)> d16h;
  using V = Vec<decltype(d)>;
  using V16 = Vec<decltype(d16)>;
  using V16H = Vec<decltype(d16h)>;
  const size_t N = Lanes(d);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize2 >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(float);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN float buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not masked.
             // Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(d);
               auto sum1 = Zero(d);
               // 4x unrolling barely helps SKX but likely helps Arm V2.
               auto sum2 = Zero(d);
               auto sum3 = Zero(d);

               const hwy::bfloat16_t* HWY_RESTRICT row =
                   &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 4 * N <= kInner; i += 4 * N) {
                 const V16 b0 = LoadU(d16, row + i + 0 * N);
                 const V a0 = PromoteLowerTo(d, b0);
                 const V a1 = PromoteUpperTo(d, b0);

                 const V16 b1 = LoadU(d16, row + i + 2 * N);
                 const V a2 = PromoteLowerTo(d, b1);
                 const V a3 = PromoteUpperTo(d, b1);

                 const V v0 = LoadU(d, vec + i + 0 * N);
                 sum0 = MulAdd(a0, v0, sum0);

                 const V v1 = LoadU(d, vec + i + 1 * N);
                 sum1 = MulAdd(a1, v1, sum1);

                 const V v2 = LoadU(d, vec + i + 2 * N);
                 sum2 = MulAdd(a2, v2, sum2);

                 const V v3 = LoadU(d, vec + i + 3 * N);
                 sum3 = MulAdd(a3, v3, sum3);
               }
               // Last entire vectors
               for (; i + N <= kInner; i += N) {
                 const V16H b0 = LoadU(d16h, row + i);
                 const V a0 = PromoteTo(d, b0);
                 const V v0 = LoadU(d, vec + i);
                 sum0 = MulAdd(a0, v0, sum0);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const V16H b0 = LoadN(d16h, row + i, remainder);
                 const V a0 = PromoteTo(d, b0);
                 const V v0 = LoadN(d, vec + i, remainder);
                 sum1 = MulAdd(a0, v0, sum1);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum1);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(d, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N) {
               Stream(Load(d, buf + i), d, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
    auto sum0 = Zero(d);

    const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const V16H b0 = LoadU(d16h, row + i);
      const V a0 = PromoteTo(d, b0);
      const V v0 = LoadU(d, vec + i);
      sum0 = MulAdd(a0, v0, sum0);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const V16H b0 = LoadN(d16h, row + i, remainder);
      const V a0 = PromoteTo(d, b0);
      const V v0 = LoadN(d, vec + i, remainder);
      sum0 = MulAdd(a0, v0, sum0);
    }
    out[r] = ReduceSum(d, sum0);
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
                            const float* HWY_RESTRICT vec,
                            const float* HWY_RESTRICT add,
                            float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
                         const float* HWY_RESTRICT vec, float* HWY_RESTRICT out,
                         hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
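
// Example usage of the bf16-matrix overload (a sketch, reusing `pool`,
// `kOuter` and `kInner` from the sketch above):
//
//   std::vector<hwy::bfloat16_t> mat16(kOuter * kInner);
//   std::vector<float> vecf(kInner), outf(kOuter);
//   MatVec<kOuter, kInner>(mat16.data(), vecf.data(), outf.data(), pool);
//
// Only the matrix is stored as bf16; `vec`, `add` and `out` remain f32, and
// the accumulation still happens in f32 after promoting each matrix vector.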

// Both mat and vec are bf16.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
                                const hwy::bfloat16_t* HWY_RESTRICT vec,
                                const hwy::bfloat16_t* HWY_RESTRICT add,
                                float* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize2 = 64 / sizeof(bfloat16_t);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);

  const ScalableTag<float> df;
  const Repartition<hwy::bfloat16_t, decltype(df)> d16;
  using V16 = Vec<decltype(d16)>;
  const size_t N = Lanes(d16);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize2 >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN float buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not masked.
             // Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(df);
               auto sum1 = Zero(df);
               auto sum2 = Zero(df);
               auto sum3 = Zero(df);

               const hwy::bfloat16_t* HWY_RESTRICT row =
                   &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 2 * N <= kInner; i += 2 * N) {
                 const V16 b0 = LoadU(d16, row + i + 0 * N);
                 const V16 b1 = LoadU(d16, row + i + 1 * N);
                 const V16 v0 = LoadU(d16, vec + i + 0 * N);
                 const V16 v1 = LoadU(d16, vec + i + 1 * N);
                 sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
                 sum2 = ReorderWidenMulAccumulate(df, b1, v1, sum2, sum3);
               }
               // Last entire vector
               for (; i + N <= kInner; i += N) {
                 const V16 b0 = LoadU(d16, row + i);
                 const V16 v0 = LoadU(d16, vec + i);
                 sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const V16 b0 = LoadN(d16, row + i, remainder);
                 const V16 v0 = LoadN(d16, vec + i, remainder);
                 sum2 = ReorderWidenMulAccumulate(df, b0, v0, sum2, sum3);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum0 = Add(sum0, sum1);
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(df, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N / 2) {
               Stream(Load(df, buf + i), df, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
    auto sum0 = Zero(df);
    auto sum1 = Zero(df);

    const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const V16 b0 = LoadU(d16, row + i);
      const V16 v0 = LoadU(d16, vec + i);
      sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const V16 b0 = LoadN(d16, row + i, remainder);
      const V16 v0 = LoadN(d16, vec + i, remainder);
      sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
    }
    out[r] = ReduceSum(df, Add(sum0, sum1));
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
                            const hwy::bfloat16_t* HWY_RESTRICT vec,
                            const hwy::bfloat16_t* HWY_RESTRICT add,
                            float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
                         const hwy::bfloat16_t* HWY_RESTRICT vec,
                         float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
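
// Example usage of the all-bf16 overload (a sketch, again reusing `pool`,
// `kOuter` and `kInner` from the first sketch):
//
//   std::vector<hwy::bfloat16_t> mat16(kOuter * kInner), vec16(kInner);
//   std::vector<float> outf(kOuter);
//   MatVec<kOuter, kInner>(mat16.data(), vec16.data(), outf.data(), pool);
//
// Both inputs are bf16, but ReorderWidenMulAccumulate widens the products and
// accumulates into f32 vectors, so `out` stays f32.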

#endif  // HWY_TARGET != HWY_SCALAR

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_