tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dot_test.cc (10545B)


      1 // Copyright 2021 Google LLC
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 #include <stdint.h>
     17 #include <stdio.h>
     18 #include <stdlib.h>
     19 
     20 #include "hwy/aligned_allocator.h"
     21 #include "hwy/base.h"
     22 
     23 // clang-format off
     24 #undef HWY_TARGET_INCLUDE
     25 #define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc"
     26 #include "hwy/foreach_target.h"  // IWYU pragma: keep
     27 #include "hwy/highway.h"
     28 #include "hwy/contrib/dot/dot-inl.h"
     29 #include "hwy/tests/test_util-inl.h"
     30 // clang-format on
     31 
     32 HWY_BEFORE_NAMESPACE();
     33 namespace hwy {
     34 namespace HWY_NAMESPACE {
     35 namespace {
     36 
     37 template <typename T1, typename T2>
     38 HWY_NOINLINE T1 SimpleDot(const T1* pa, const T2* pb, size_t num) {
     39  float sum = 0.0f;
     40  for (size_t i = 0; i < num; ++i) {
     41    sum += ConvertScalarTo<float>(pa[i]) * ConvertScalarTo<float>(pb[i]);
     42  }
     43  return ConvertScalarTo<T1>(sum);
     44 }
     45 
     46 HWY_MAYBE_UNUSED HWY_NOINLINE float SimpleDot(const float* pa,
     47                                              const hwy::bfloat16_t* pb,
     48                                              size_t num) {
     49  float sum = 0.0f;
     50  for (size_t i = 0; i < num; ++i) {
     51    sum += pa[i] * F32FromBF16(pb[i]);
     52  }
     53  return sum;
     54 }
     55 
     56 // Overload is required because the generic template hits an internal compiler
     57 // error on aarch64 clang.
     58 HWY_MAYBE_UNUSED HWY_NOINLINE float SimpleDot(const bfloat16_t* pa,
     59                                              const bfloat16_t* pb,
     60                                              size_t num) {
     61  float sum = 0.0f;
     62  for (size_t i = 0; i < num; ++i) {
     63    sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
     64  }
     65  return sum;
     66 }
     67 
     68 class TestDot {
     69  // Computes/verifies one dot product.
     70  template <int kAssumptions, class D>
     71  void Test(D d, size_t num, size_t misalign_a, size_t misalign_b,
     72            RandomState& rng) {
     73    using T = TFromD<D>;
     74    const size_t N = Lanes(d);
     75    const auto random_t = [&rng]() {
     76      const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
     77      return static_cast<float>(bits - 512) * (1.0f / 64);
     78    };
     79 
     80    const size_t padded =
     81        (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num;
     82    AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded);
     83    AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + padded);
     84    HWY_ASSERT(pa && pb);
     85    T* a = pa.get() + misalign_a;
     86    T* b = pb.get() + misalign_b;
     87    size_t i = 0;
     88    for (; i < num; ++i) {
     89      a[i] = ConvertScalarTo<T>(random_t());
     90      b[i] = ConvertScalarTo<T>(random_t());
     91    }
     92    // Fill padding - the values are not used, but avoids MSAN errors.
     93    for (; i < padded; ++i) {
     94      a[i] = ConvertScalarTo<T>(0);
     95      b[i] = ConvertScalarTo<T>(0);
     96    }
     97 
     98    const double expected = SimpleDot(a, b, num);
     99    const double magnitude = expected > 0.0 ? expected : -expected;
    100    const double actual =
    101        ConvertScalarTo<double>(Dot::Compute<kAssumptions>(d, a, b, num));
    102    const double max = static_cast<double>(8 * 8 * num);
    103    HWY_ASSERT(-max <= actual && actual <= max);
    104    // Integer math is exact, so no tolerance.
    105    const double tolerance =
    106        IsFloat<T>() ? 96.0 * ConvertScalarTo<double>(Epsilon<T>()) *
    107                           HWY_MAX(magnitude, 1.0)
    108                     : 0;
    109    HWY_ASSERT(expected - tolerance <= actual &&
    110               actual <= expected + tolerance);
    111  }
    112 
    113  // Runs tests with various alignments.
    114  template <int kAssumptions, class D>
    115  void ForeachMisalign(D d, size_t num, RandomState& rng) {
    116    const size_t N = Lanes(d);
    117    const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
    118    for (size_t ma : misalignments) {
    119      for (size_t mb : misalignments) {
    120        Test<kAssumptions>(d, num, ma, mb, rng);
    121      }
    122    }
    123  }
    124 
    125  // Runs tests with various lengths compatible with the given assumptions.
    126  template <int kAssumptions, class D>
    127  void ForeachCount(D d, RandomState& rng) {
    128    const size_t N = Lanes(d);
    129    const size_t counts[] = {1,
    130                             3,
    131                             7,
    132                             16,
    133                             HWY_MAX(N / 2, 1),
    134                             HWY_MAX(2 * N / 3, 1),
    135                             N,
    136                             N + 1,
    137                             4 * N / 3,
    138                             3 * N,
    139                             8 * N,
    140                             8 * N + 2};
    141    for (size_t num : counts) {
    142      if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue;
    143      if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue;
    144      ForeachMisalign<kAssumptions>(d, num, rng);
    145    }
    146  }
    147 
    148 public:
    149  // Must be inlined on aarch64 for bf16, else clang crashes.
    150  template <class T, class D>
    151  HWY_INLINE void operator()(T /*unused*/, D d) {
    152    RandomState rng;
    153 
    154    // All 8 combinations of the three length-related flags:
    155    ForeachCount<0>(d, rng);
    156    ForeachCount<Dot::kAtLeastOneVector>(d, rng);
    157    ForeachCount<Dot::kMultipleOfVector>(d, rng);
    158    ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng);
    159    ForeachCount<Dot::kPaddedToVector>(d, rng);
    160    ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng);
    161    ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng);
    162    ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector |
    163                 Dot::kAtLeastOneVector>(d, rng);
    164  }
    165 };
    166 
    167 class TestDotF32BF16 {
    168  // Computes/verifies one dot product.
    169  template <int kAssumptions, class D>
    170  void Test(D d, size_t num, size_t misalign_a, size_t misalign_b,
    171            RandomState& rng) {
    172    using T = TFromD<D>;
    173    using T2 = hwy::bfloat16_t;
    174    const size_t N = Lanes(d);
    175    const auto random_t = [&rng]() {
    176      const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
    177      return static_cast<float>(bits - 512) * (1.0f / 64);
    178    };
    179 
    180    const size_t padded =
    181        (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num;
    182    AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded);
    183    AlignedFreeUniquePtr<T2[]> pb = AllocateAligned<T2>(misalign_b + padded);
    184    HWY_ASSERT(pa && pb);
    185    T* a = pa.get() + misalign_a;
    186    T2* b = pb.get() + misalign_b;
    187    size_t i = 0;
    188    for (; i < num; ++i) {
    189      a[i] = ConvertScalarTo<T>(random_t());
    190      b[i] = ConvertScalarTo<T2>(random_t());
    191    }
    192    // Fill padding with NaN - the values are not used, but avoids MSAN errors.
    193    for (; i < padded; ++i) {
    194      ScalableTag<float> df1;
    195      a[i] = ConvertScalarTo<T>(GetLane(NaN(df1)));
    196      b[i] = ConvertScalarTo<T2>(GetLane(NaN(df1)));
    197    }
    198 
    199    const double expected = SimpleDot(a, b, num);
    200    const double magnitude = expected > 0.0 ? expected : -expected;
    201    const double actual =
    202        ConvertScalarTo<double>(Dot::Compute<kAssumptions>(d, a, b, num));
    203    const double max = static_cast<double>(8 * 8 * num);
    204    HWY_ASSERT(-max <= actual && actual <= max);
    205    const double tolerance =
    206        64.0 * ConvertScalarTo<double>(Epsilon<T2>()) * HWY_MAX(magnitude, 1.0);
    207    HWY_ASSERT(expected - tolerance <= actual &&
    208               actual <= expected + tolerance);
    209  }
    210 
    211  // Runs tests with various alignments.
    212  template <int kAssumptions, class D>
    213  void ForeachMisalign(D d, size_t num, RandomState& rng) {
    214    const size_t N = Lanes(d);
    215    const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
    216    for (size_t ma : misalignments) {
    217      for (size_t mb : misalignments) {
    218        Test<kAssumptions>(d, num, ma, mb, rng);
    219      }
    220    }
    221  }
    222 
    223  // Runs tests with various lengths compatible with the given assumptions.
    224  template <int kAssumptions, class D>
    225  void ForeachCount(D d, RandomState& rng) {
    226    const size_t N = Lanes(d);
    227    const size_t counts[] = {1,
    228                             3,
    229                             7,
    230                             16,
    231                             HWY_MAX(N / 2, 1),
    232                             HWY_MAX(2 * N / 3, 1),
    233                             N,
    234                             N + 1,
    235                             4 * N / 3,
    236                             3 * N,
    237                             8 * N,
    238                             8 * N + 2};
    239    for (size_t num : counts) {
    240      if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue;
    241      if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue;
    242      ForeachMisalign<kAssumptions>(d, num, rng);
    243    }
    244  }
    245 
    246 public:
    247  // Must be inlined on aarch64 for bf16, else clang crashes.
    248  template <class T, class D>
    249  HWY_INLINE void operator()(T /*unused*/, D d) {
    250    RandomState rng;
    251 
    252    // All 8 combinations of the three length-related flags:
    253    ForeachCount<0>(d, rng);
    254    ForeachCount<Dot::kAtLeastOneVector>(d, rng);
    255    ForeachCount<Dot::kMultipleOfVector>(d, rng);
    256    ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng);
    257    ForeachCount<Dot::kPaddedToVector>(d, rng);
    258    ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng);
    259    ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng);
    260    ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector |
    261                 Dot::kAtLeastOneVector>(d, rng);
    262  }
    263 };
    264 
    265 // All floating-point types, both arguments same.
    266 void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); }
    267 
    268 // Mixed f32 and bf16 inputs.
    269 void TestAllDotF32BF16() { ForPartialVectors<TestDotF32BF16>()(float()); }
    270 
    271 // Both inputs bf16.
    272 void TestAllDotBF16() { ForShrinkableVectors<TestDot>()(bfloat16_t()); }
    273 
    274 // Both inputs i16.
    275 void TestAllDotI16() { ForShrinkableVectors<TestDot>()(int16_t()); }
    276 
    277 }  // namespace
    278 // NOLINTNEXTLINE(google-readability-namespace-comments)
    279 }  // namespace HWY_NAMESPACE
    280 }  // namespace hwy
    281 HWY_AFTER_NAMESPACE();
    282 
// Test registration. This section is compiled only when HWY_ONCE is nonzero,
// presumably on the single pass after foreach_target.h has re-included this
// file for each enabled target (confirm against the Highway dispatch docs).
#if HWY_ONCE
namespace hwy {
namespace {
HWY_BEFORE_TEST(DotTest);
// Each TestAll* entry point defined above is registered under DotTest.
// NOTE(review): presumably each one is run once per compiled-in target via the
// exported dispatch table; verify against HWY_EXPORT_AND_TEST_P.
HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotF32BF16);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotI16);
HWY_AFTER_TEST();
}  // namespace
}  // namespace hwy
HWY_TEST_MAIN();
#endif  // HWY_ONCE
    295 #endif  // HWY_ONCE