tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dot-inl.h (17983B)


      1 // Copyright 2021 Google LLC
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 // clang-format off
     17 #if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
     18 // clang-format on
     19 #ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
     20 #undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
     21 #else
     22 #define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
     23 #endif
     24 
     25 #include <stddef.h>
     26 #include <stdint.h>
     27 
     28 #include "hwy/highway.h"
     29 
     30 HWY_BEFORE_NAMESPACE();
     31 namespace hwy {
     32 namespace HWY_NAMESPACE {
     33 
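     34 // NOTE: the D argument describes the inputs, not the output, because the
     35 // accumulator type may differ from the input type: f32/f32, bf16/bf16 and
     35 // f32/bf16 inputs all accumulate to f32, i16/i16 inputs accumulate to i32,
     35 // and float16_t/double inputs accumulate in their own type.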
     36 struct Dot {
     37  // Specify zero or more of these, ORed together, as the kAssumptions template
     38  // argument to Compute. Each one may improve performance or reduce code size,
     39  // at the cost of additional requirements on the arguments.
     40  enum Assumptions {
     41    // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
     42    kAtLeastOneVector = 1,
     43    // num_elements is divisible by N (a power of two, so this can be used if
     44    // the problem size is known to be a power of two >= HWY_MAX_BYTES /
     45    // sizeof(T)).
     46    kMultipleOfVector = 2,
     47    // RoundUpTo(num_elements, N) elements are accessible; their value does not
     48    // matter (will be treated as if they were zero).
     49    kPaddedToVector = 4,
     50  };
     51 
     52  // Returns sum{pa[i] * pb[i]} for floating-point inputs, including float16_t
     53  // and double if HWY_HAVE_FLOAT16/64. Aligning the
     54  // pointers to a multiple of N elements is helpful but not required.
     55  template <int kAssumptions, class D, typename T = TFromD<D>>
     56  static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
     57                              const T* const HWY_RESTRICT pb,
     58                              const size_t num_elements) {
     59    static_assert(IsFloat<T>(), "MulAdd requires float type");
     60    using V = decltype(Zero(d));
     61 
     62    HWY_LANES_CONSTEXPR size_t N = Lanes(d);
     63    size_t i = 0;
     64 
     65    constexpr bool kIsAtLeastOneVector =
     66        (kAssumptions & kAtLeastOneVector) != 0;
     67    constexpr bool kIsMultipleOfVector =
     68        (kAssumptions & kMultipleOfVector) != 0;
     69    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
     70 
     71    // Won't be able to do a full vector load without padding => scalar loop.
     72    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
     73        HWY_UNLIKELY(num_elements < N)) {
     74      // Only 2x unroll to avoid excessive code size.
     75      T sum0 = ConvertScalarTo<T>(0);
     76      T sum1 = ConvertScalarTo<T>(0);
     77      for (; i + 2 <= num_elements; i += 2) {
     78        // For reasons unknown, fp16 += does not compile on clang (Arm).
     79        sum0 = ConvertScalarTo<T>(sum0 + pa[i + 0] * pb[i + 0]);
     80        sum1 = ConvertScalarTo<T>(sum1 + pa[i + 1] * pb[i + 1]);
     81      }
     82      if (i < num_elements) {
     83        sum1 = ConvertScalarTo<T>(sum1 + pa[i] * pb[i]);
     84      }
     85      return ConvertScalarTo<T>(sum0 + sum1);
     86    }
     87 
     88    // Compiler doesn't make independent sum* accumulators, so unroll manually.
     89    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
     90    // for unaligned inputs (each unaligned pointer halves the throughput
     91    // because it occupies both L1 load ports for a cycle). We cannot have
     92    // arrays of vectors on RVV/SVE, so always unroll 4x.
     93    V sum0 = Zero(d);
     94    V sum1 = Zero(d);
     95    V sum2 = Zero(d);
     96    V sum3 = Zero(d);
     97 
     98    // Main loop: unrolled
     99    for (; i + 4 * N <= num_elements; /* i += 4 * N */) {  // incr in loop
    100      const auto a0 = LoadU(d, pa + i);
    101      const auto b0 = LoadU(d, pb + i);
    102      i += N;
    103      sum0 = MulAdd(a0, b0, sum0);
    104      const auto a1 = LoadU(d, pa + i);
    105      const auto b1 = LoadU(d, pb + i);
    106      i += N;
    107      sum1 = MulAdd(a1, b1, sum1);
    108      const auto a2 = LoadU(d, pa + i);
    109      const auto b2 = LoadU(d, pb + i);
    110      i += N;
    111      sum2 = MulAdd(a2, b2, sum2);
    112      const auto a3 = LoadU(d, pa + i);
    113      const auto b3 = LoadU(d, pb + i);
    114      i += N;
    115      sum3 = MulAdd(a3, b3, sum3);
    116    }
    117 
    118    // Up to 3 iterations of whole vectors
    119    for (; i + N <= num_elements; i += N) {
    120      const auto a = LoadU(d, pa + i);
    121      const auto b = LoadU(d, pb + i);
    122      sum0 = MulAdd(a, b, sum0);
    123    }
    124 
    125    if (!kIsMultipleOfVector) {
    126      const size_t remaining = num_elements - i;
    127      if (remaining != 0) {
    128        if (kIsPaddedToVector) {
    129          const auto mask = FirstN(d, remaining);
    130          const auto a = LoadU(d, pa + i);
    131          const auto b = LoadU(d, pb + i);
    132          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
    133        } else {
    134          // Unaligned load such that the last element is in the highest lane -
    135          // ensures we do not touch any elements outside the valid range.
    136          // If we get here, then num_elements >= N.
    137          HWY_DASSERT(i >= N);
    138          i += remaining - N;
    139          const auto skip = FirstN(d, N - remaining);
    140          const auto a = LoadU(d, pa + i);  // always unaligned
    141          const auto b = LoadU(d, pb + i);
    142          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
    143        }
    144      }
    145    }  // kMultipleOfVector
    146 
    147    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    148    sum0 = Add(sum0, sum1);
    149    sum2 = Add(sum2, sum3);
    150    sum0 = Add(sum0, sum2);
    151    return ReduceSum(d, sum0);
    152  }
    153 
    154  // f32 * bf16
    155  template <int kAssumptions, class DF, HWY_IF_F32_D(DF)>
    156  static HWY_INLINE float Compute(const DF df,
    157                                  const float* const HWY_RESTRICT pa,
    158                                  const hwy::bfloat16_t* const HWY_RESTRICT pb,
    159                                  const size_t num_elements) {
    160 #if HWY_TARGET == HWY_SCALAR
    161    const Rebind<hwy::bfloat16_t, DF> dbf;
    162 #else
    163    const Repartition<hwy::bfloat16_t, DF> dbf;
    164    using VBF = decltype(Zero(dbf));
    165 #endif
    166    const Half<decltype(dbf)> dbfh;
    167    using VF = decltype(Zero(df));
    168 
    169    HWY_LANES_CONSTEXPR size_t NF = Lanes(df);
    170 
    171    constexpr bool kIsAtLeastOneVector =
    172        (kAssumptions & kAtLeastOneVector) != 0;
    173    constexpr bool kIsMultipleOfVector =
    174        (kAssumptions & kMultipleOfVector) != 0;
    175    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
    176 
    177    // Won't be able to do a full vector load without padding => scalar loop.
    178    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
    179        HWY_UNLIKELY(num_elements < NF)) {
    180      // Only 2x unroll to avoid excessive code size.
    181      float sum0 = 0.0f;
    182      float sum1 = 0.0f;
    183      size_t i = 0;
    184      for (; i + 2 <= num_elements; i += 2) {
    185        sum0 += pa[i + 0] * ConvertScalarTo<float>(pb[i + 0]);
    186        sum1 += pa[i + 1] * ConvertScalarTo<float>(pb[i + 1]);
    187      }
    188      for (; i < num_elements; ++i) {
    189        sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
    190      }
    191      return sum0 + sum1;
    192    }
    193 
    194    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    195    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    196    // for unaligned inputs (each unaligned pointer halves the throughput
    197    // because it occupies both L1 load ports for a cycle). We cannot have
    198    // arrays of vectors on RVV/SVE, so always unroll 4x.
    199    VF sum0 = Zero(df);
    200    VF sum1 = Zero(df);
    201    VF sum2 = Zero(df);
    202    VF sum3 = Zero(df);
    203 
    204    size_t i = 0;
    205 
    206 #if HWY_TARGET != HWY_SCALAR  // PromoteUpperTo supported
    207    // Main loop: unrolled
    208    for (; i + 4 * NF <= num_elements; /* i += 4 * N */) {  // incr in loop
    209      const VF a0 = LoadU(df, pa + i);
    210      const VBF b0 = LoadU(dbf, pb + i);
    211      i += NF;
    212      sum0 = MulAdd(a0, PromoteLowerTo(df, b0), sum0);
    213      const VF a1 = LoadU(df, pa + i);
    214      i += NF;
    215      sum1 = MulAdd(a1, PromoteUpperTo(df, b0), sum1);
    216      const VF a2 = LoadU(df, pa + i);
    217      const VBF b2 = LoadU(dbf, pb + i);
    218      i += NF;
    219      sum2 = MulAdd(a2, PromoteLowerTo(df, b2), sum2);
    220      const VF a3 = LoadU(df, pa + i);
    221      i += NF;
    222      sum3 = MulAdd(a3, PromoteUpperTo(df, b2), sum3);
    223    }
    224 #endif  // HWY_TARGET == HWY_SCALAR
    225 
    226    // Up to 3 iterations of whole vectors
    227    for (; i + NF <= num_elements; i += NF) {
    228      const VF a = LoadU(df, pa + i);
    229      const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
    230      sum0 = MulAdd(a, b, sum0);
    231    }
    232 
    233    if (!kIsMultipleOfVector) {
    234      const size_t remaining = num_elements - i;
    235      if (remaining != 0) {
    236        if (kIsPaddedToVector) {
    237          const auto mask = FirstN(df, remaining);
    238          const VF a = LoadU(df, pa + i);
    239          const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
    240          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
    241        } else {
    242          // Unaligned load such that the last element is in the highest lane -
    243          // ensures we do not touch any elements outside the valid range.
    244          // If we get here, then num_elements >= N.
    245          HWY_DASSERT(i >= NF);
    246          i += remaining - NF;
    247          const auto skip = FirstN(df, NF - remaining);
    248          const VF a = LoadU(df, pa + i);  // always unaligned
    249          const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
    250          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
    251        }
    252      }
    253    }  // kMultipleOfVector
    254 
    255    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    256    sum0 = Add(sum0, sum1);
    257    sum2 = Add(sum2, sum3);
    258    sum0 = Add(sum0, sum2);
    259    return ReduceSum(df, sum0);
    260  }
    261 
    262  // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
    263  // multiple of N elements is helpful but not required.
    264  template <int kAssumptions, class D, HWY_IF_BF16_D(D)>
    265  static HWY_INLINE float Compute(const D d,
    266                                  const bfloat16_t* const HWY_RESTRICT pa,
    267                                  const bfloat16_t* const HWY_RESTRICT pb,
    268                                  const size_t num_elements) {
    269    const RebindToUnsigned<D> du16;
    270    const Repartition<float, D> df32;
    271 
    272    using V = decltype(Zero(df32));
    273    HWY_LANES_CONSTEXPR size_t N = Lanes(d);
    274    size_t i = 0;
    275 
    276    constexpr bool kIsAtLeastOneVector =
    277        (kAssumptions & kAtLeastOneVector) != 0;
    278    constexpr bool kIsMultipleOfVector =
    279        (kAssumptions & kMultipleOfVector) != 0;
    280    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
    281 
    282    // Won't be able to do a full vector load without padding => scalar loop.
    283    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
    284        HWY_UNLIKELY(num_elements < N)) {
    285      float sum0 = 0.0f;  // Only 2x unroll to avoid excessive code size for..
    286      float sum1 = 0.0f;  // this unlikely(?) case.
    287      for (; i + 2 <= num_elements; i += 2) {
    288        sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
    289        sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
    290      }
    291      if (i < num_elements) {
    292        sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
    293      }
    294      return sum0 + sum1;
    295    }
    296 
    297    // See comment in the other Compute() overload. Unroll 2x, but we need
    298    // twice as many sums for ReorderWidenMulAccumulate.
    299    V sum0 = Zero(df32);
    300    V sum1 = Zero(df32);
    301    V sum2 = Zero(df32);
    302    V sum3 = Zero(df32);
    303 
    304    // Main loop: unrolled
    305    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
    306      const auto a0 = LoadU(d, pa + i);
    307      const auto b0 = LoadU(d, pb + i);
    308      i += N;
    309      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
    310      const auto a1 = LoadU(d, pa + i);
    311      const auto b1 = LoadU(d, pb + i);
    312      i += N;
    313      sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
    314    }
    315 
    316    // Possibly one more iteration of whole vectors
    317    if (i + N <= num_elements) {
    318      const auto a0 = LoadU(d, pa + i);
    319      const auto b0 = LoadU(d, pb + i);
    320      i += N;
    321      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
    322    }
    323 
    324    if (!kIsMultipleOfVector) {
    325      const size_t remaining = num_elements - i;
    326      if (remaining != 0) {
    327        if (kIsPaddedToVector) {
    328          const auto mask = FirstN(du16, remaining);
    329          const auto va = LoadU(d, pa + i);
    330          const auto vb = LoadU(d, pb + i);
    331          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
    332          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
    333          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
    334 
    335        } else {
    336          // Unaligned load such that the last element is in the highest lane -
    337          // ensures we do not touch any elements outside the valid range.
    338          // If we get here, then num_elements >= N.
    339          HWY_DASSERT(i >= N);
    340          i += remaining - N;
    341          const auto skip = FirstN(du16, N - remaining);
    342          const auto va = LoadU(d, pa + i);  // always unaligned
    343          const auto vb = LoadU(d, pb + i);
    344          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
    345          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
    346          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
    347        }
    348      }
    349    }  // kMultipleOfVector
    350 
    351    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    352    sum0 = Add(sum0, sum1);
    353    sum2 = Add(sum2, sum3);
    354    sum0 = Add(sum0, sum2);
    355    return ReduceSum(df32, sum0);
    356  }
    357 
    358  // Returns sum{i32(pa[i]) * i32(pb[i])} for i16 inputs. Aligning the pointers
    359  // to a multiple of N elements is helpful but not required.
    360  template <int kAssumptions, class D, HWY_IF_I16_D(D)>
    361  static HWY_INLINE int32_t Compute(const D d,
    362                                    const int16_t* const HWY_RESTRICT pa,
    363                                    const int16_t* const HWY_RESTRICT pb,
    364                                    const size_t num_elements) {
    365    const RebindToUnsigned<D> du16;
    366    const RepartitionToWide<D> di32;
    367 
    368    using VI32 = Vec<decltype(di32)>;
    369    HWY_LANES_CONSTEXPR size_t N = Lanes(d);
    370    size_t i = 0;
    371 
    372    constexpr bool kIsAtLeastOneVector =
    373        (kAssumptions & kAtLeastOneVector) != 0;
    374    constexpr bool kIsMultipleOfVector =
    375        (kAssumptions & kMultipleOfVector) != 0;
    376    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
    377 
    378    // Won't be able to do a full vector load without padding => scalar loop.
    379    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
    380        HWY_UNLIKELY(num_elements < N)) {
    381      int32_t sum0 = 0;  // Only 2x unroll to avoid excessive code size for..
    382      int32_t sum1 = 0;  // this unlikely(?) case.
    383      for (; i + 2 <= num_elements; i += 2) {
    384        sum0 += int32_t{pa[i + 0]} * int32_t{pb[i + 0]};
    385        sum1 += int32_t{pa[i + 1]} * int32_t{pb[i + 1]};
    386      }
    387      if (i < num_elements) {
    388        sum1 += int32_t{pa[i]} * int32_t{pb[i]};
    389      }
    390      return sum0 + sum1;
    391    }
    392 
    393    // See comment in the other Compute() overload. Unroll 2x, but we need
    394    // twice as many sums for ReorderWidenMulAccumulate.
    395    VI32 sum0 = Zero(di32);
    396    VI32 sum1 = Zero(di32);
    397    VI32 sum2 = Zero(di32);
    398    VI32 sum3 = Zero(di32);
    399 
    400    // Main loop: unrolled
    401    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
    402      const auto a0 = LoadU(d, pa + i);
    403      const auto b0 = LoadU(d, pb + i);
    404      i += N;
    405      sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1);
    406      const auto a1 = LoadU(d, pa + i);
    407      const auto b1 = LoadU(d, pb + i);
    408      i += N;
    409      sum2 = ReorderWidenMulAccumulate(di32, a1, b1, sum2, sum3);
    410    }
    411 
    412    // Possibly one more iteration of whole vectors
    413    if (i + N <= num_elements) {
    414      const auto a0 = LoadU(d, pa + i);
    415      const auto b0 = LoadU(d, pb + i);
    416      i += N;
    417      sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1);
    418    }
    419 
    420    if (!kIsMultipleOfVector) {
    421      const size_t remaining = num_elements - i;
    422      if (remaining != 0) {
    423        if (kIsPaddedToVector) {
    424          const auto mask = FirstN(du16, remaining);
    425          const auto va = LoadU(d, pa + i);
    426          const auto vb = LoadU(d, pb + i);
    427          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
    428          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
    429          sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3);
    430 
    431        } else {
    432          // Unaligned load such that the last element is in the highest lane -
    433          // ensures we do not touch any elements outside the valid range.
    434          // If we get here, then num_elements >= N.
    435          HWY_DASSERT(i >= N);
    436          i += remaining - N;
    437          const auto skip = FirstN(du16, N - remaining);
    438          const auto va = LoadU(d, pa + i);  // always unaligned
    439          const auto vb = LoadU(d, pb + i);
    440          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
    441          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
    442          sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3);
    443        }
    444      }
    445    }  // kMultipleOfVector
    446 
    447    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    448    sum0 = Add(sum0, sum1);
    449    sum2 = Add(sum2, sum3);
    450    sum0 = Add(sum0, sum2);
    451    return ReduceSum(di32, sum0);
    452  }
    453 };
    454 
    455 // NOLINTNEXTLINE(google-readability-namespace-comments)
    456 }  // namespace HWY_NAMESPACE
    457 }  // namespace hwy
    458 HWY_AFTER_NAMESPACE();
    459 
    460 #endif  // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_