tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

matvec_test.cc (11071B)


      1 // Copyright 2023 Google LLC
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 #include "hwy/base.h"
     17 
     18 // Reduce targets to avoid timeout under emulation.
     19 #ifndef HWY_DISABLED_TARGETS
     20 #define HWY_DISABLED_TARGETS (HWY_SVE2_128 | HWY_SVE2 | HWY_SVE_256 | HWY_NEON)
     21 #endif
     22 
     23 #include <stddef.h>
     24 #include <stdint.h>
     25 
     26 #include <cmath>  // std::abs
     27 
     28 #include "hwy/aligned_allocator.h"
     29 
     30 // clang-format off
     31 #undef HWY_TARGET_INCLUDE
     32 #define HWY_TARGET_INCLUDE "hwy/contrib/matvec/matvec_test.cc"  // NOLINT
     33 #include "hwy/foreach_target.h"  // IWYU pragma: keep
     34 // Must come after foreach_target.h
     35 #include "hwy/contrib/algo/transform-inl.h"
     36 #include "hwy/contrib/matvec/matvec-inl.h"
     37 #include "hwy/highway.h"
     38 #include "hwy/contrib/thread_pool/thread_pool.h"
     39 #include "hwy/contrib/thread_pool/topology.h"
     40 #include "hwy/tests/test_util-inl.h"
     41 // clang-format on
     42 
     43 HWY_BEFORE_NAMESPACE();
     44 namespace hwy {
     45 namespace HWY_NAMESPACE {
     46 namespace {
     47 
     48 template <typename MatT, typename T>
     49 HWY_NOINLINE void SimpleMatVecAdd(const MatT* HWY_RESTRICT mat,
     50                                  const T* HWY_RESTRICT vec,
     51                                  const T* HWY_RESTRICT add, size_t rows,
     52                                  size_t cols, T* HWY_RESTRICT out,
     53                                  ThreadPool& pool) {
     54  if (add) {
     55    pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
     56      double dot = 0.0;
     57      for (size_t c = 0; c < cols; c++) {
     58        // For reasons unknown, fp16 += does not compile on clang (Arm).
     59        dot += ConvertScalarTo<double>(mat[r * cols + c]) *
     60               ConvertScalarTo<double>(vec[c]);
     61      }
     62      out[r] = ConvertScalarTo<T>(dot + ConvertScalarTo<double>(add[r]));
     63    });
     64  } else {
     65    pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
     66      double dot = 0.0;
     67      for (size_t c = 0; c < cols; c++) {
     68        // For reasons unknown, fp16 += does not compile on clang (Arm).
     69        dot += ConvertScalarTo<double>(mat[r * cols + c]) *
     70               ConvertScalarTo<double>(vec[c]);
     71      }
     72      out[r] = ConvertScalarTo<T>(dot);
     73    });
     74  }
     75 }
     76 
     77 HWY_MAYBE_UNUSED HWY_NOINLINE void SimpleMatVecAdd(
     78    const hwy::bfloat16_t* HWY_RESTRICT mat, const float* HWY_RESTRICT vec,
     79    const float* add, size_t rows, size_t cols, float* HWY_RESTRICT out,
     80    ThreadPool& pool) {
     81  if (add) {
     82    pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
     83      float dot = 0.0f;
     84      for (size_t c = 0; c < cols; c++) {
     85        dot += F32FromBF16(mat[r * cols + c]) * vec[c];
     86      }
     87      out[r] = dot + add[r];
     88    });
     89  } else {
     90    pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
     91      float dot = 0.0f;
     92      for (size_t c = 0; c < cols; c++) {
     93        dot += F32FromBF16(mat[r * cols + c]) * vec[c];
     94      }
     95      out[r] = dot;
     96    });
     97  }
     98 }
     99 
    100 HWY_MAYBE_UNUSED HWY_NOINLINE void SimpleMatVecAdd(
    101    const hwy::bfloat16_t* HWY_RESTRICT mat,
    102    const hwy::bfloat16_t* HWY_RESTRICT vec,
    103    const hwy::bfloat16_t* HWY_RESTRICT add, size_t rows, size_t cols,
    104    float* HWY_RESTRICT out, ThreadPool& pool) {
    105  if (add) {
    106    pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
    107      float dot = 0.0f;
    108      for (size_t c = 0; c < cols; c++) {
    109        dot += F32FromBF16(mat[r * cols + c]) * F32FromBF16(vec[c]);
    110      }
    111      out[r] = dot + F32FromBF16(add[r]);
    112    });
    113  } else {
    114    pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
    115      float dot = 0.0f;
    116      for (size_t c = 0; c < cols; c++) {
    117        dot += F32FromBF16(mat[r * cols + c]) * F32FromBF16(vec[c]);
    118      }
    119      out[r] = dot;
    120    });
    121  }
    122 }
    123 
// Workaround for incorrect codegen on Arm, which results in values of `av`
// >= 1E10. Can also be prevented by calling `Print(du, indices)`.
// NOTE(review): forcing the generator out-of-line appears to perturb the
// miscompile; the root cause is not visible from this file.
#if HWY_ARCH_ARM && HWY_COMPILER_CLANG
#define GENERATE_INLINE HWY_NOINLINE
#else
#define GENERATE_INLINE HWY_INLINE
#endif
    131 
    132 struct GenerateMod {
    133  template <class D, HWY_IF_NOT_BF16_D(D), HWY_IF_LANES_GT_D(D, 1)>
    134  GENERATE_INLINE Vec<D> operator()(D d,
    135                                    Vec<RebindToUnsigned<D>> indices) const {
    136    const RebindToUnsigned<D> du;
    137    return Reverse2(d, ConvertTo(d, And(indices, Set(du, 0xF))));
    138  }
    139 
    140  template <class D, HWY_IF_NOT_BF16_D(D), HWY_IF_LANES_LE_D(D, 1)>
    141  GENERATE_INLINE Vec<D> operator()(D d,
    142                                    Vec<RebindToUnsigned<D>> indices) const {
    143    const RebindToUnsigned<D> du;
    144    return ConvertTo(d, And(indices, Set(du, 0xF)));
    145  }
    146 
    147  // Requires >= 4 bf16 lanes for float32 Reverse2.
    148  template <class D, HWY_IF_BF16_D(D), HWY_IF_LANES_GT_D(D, 2)>
    149  GENERATE_INLINE Vec<D> operator()(D d,
    150                                    Vec<RebindToUnsigned<D>> indices) const {
    151    const RebindToUnsigned<D> du;
    152    const RebindToSigned<D> di;
    153    const RepartitionToWide<decltype(di)> dw;
    154    const RebindToFloat<decltype(dw)> df;
    155    indices = And(indices, Set(du, 0xF));
    156    const Vec<decltype(df)> i0 = ConvertTo(df, PromoteLowerTo(dw, indices));
    157    const Vec<decltype(df)> i1 = ConvertTo(df, PromoteUpperTo(dw, indices));
    158    return OrderedDemote2To(d, Reverse2(df, i0), Reverse2(df, i1));
    159  }
    160 
    161  // For one or two lanes, we don't have OrderedDemote2To nor Reverse2.
    162  template <class D, HWY_IF_BF16_D(D), HWY_IF_LANES_LE_D(D, 2)>
    163  GENERATE_INLINE Vec<D> operator()(D d,
    164                                    Vec<RebindToUnsigned<D>> indices) const {
    165    const Rebind<float, D> df;
    166    return DemoteTo(d, Set(df, static_cast<float>(GetLane(indices))));
    167  }
    168 };
    169 
    170 // MatT is usually the same as T, but can also be bfloat16_t when T = float.
    171 template <typename MatT, typename VecT>
    172 class TestMatVecAdd {
    173  template <size_t kRows, size_t kCols, class D, typename T = TFromD<D>>
    174  HWY_NOINLINE void Test(D d, ThreadPool& pool) {
    175 // This target lacks too many ops required in our implementation, use
    176 // HWY_EMU128 instead.
    177 #if HWY_TARGET != HWY_SCALAR
    178    const Repartition<MatT, D> dm;
    179    const Repartition<VecT, D> dv;
    180    const size_t misalign = 3 * Lanes(d) / 5;
    181    // Fill matrix and vector with small integer values
    182    const size_t area = kRows * kCols;
    183    AlignedFreeUniquePtr<MatT[]> storage_m =
    184        AllocateAligned<MatT>(misalign + area);
    185    AlignedFreeUniquePtr<VecT[]> storage_v =
    186        AllocateAligned<VecT>(misalign + kCols);
    187    AlignedFreeUniquePtr<VecT[]> storage_a =
    188        AllocateAligned<VecT>(misalign + kRows);
    189    HWY_ASSERT(storage_m && storage_v && storage_a);
    190    MatT* pm = storage_m.get() + misalign;
    191    VecT* pv = storage_v.get() + misalign;
    192    VecT* av = storage_a.get() + misalign;
    193    Generate(dm, pm, area, GenerateMod());
    194    Generate(dv, pv, kCols, GenerateMod());
    195    Generate(dv, av, kRows, GenerateMod());
    196 
    197    AlignedFreeUniquePtr<T[]> expected_without_add = AllocateAligned<T>(kRows);
    198    HWY_ASSERT(expected_without_add);
    199    SimpleMatVecAdd(pm, pv, static_cast<VecT*>(nullptr), kRows, kCols,
    200                    expected_without_add.get(), pool);
    201 
    202    AlignedFreeUniquePtr<T[]> actual_without_add = AllocateAligned<T>(kRows);
    203    HWY_ASSERT(actual_without_add);
    204    MatVec<kRows, kCols>(pm, pv, actual_without_add.get(), pool);
    205 
    206    const auto assert_close = [&](const AlignedFreeUniquePtr<T[]>& expected,
    207                                  const AlignedFreeUniquePtr<T[]>& actual,
    208                                  bool with_add) {
    209      for (size_t i = 0; i < kRows; ++i) {
    210        const double exp = ConvertScalarTo<double>(expected[i]);
    211        const double act = ConvertScalarTo<double>(actual[i]);
    212        const double epsilon =
    213            1.0 / (1ULL << HWY_MIN(MantissaBits<MatT>(), MantissaBits<VecT>()));
    214        const double tolerance = exp * 20.0 / epsilon;
    215        const double l1 = std::abs(exp - act);
    216        const double rel = exp == 0.0 ? 0.0 : l1 / exp;
    217 
    218        if (l1 > tolerance && rel > epsilon) {
    219          fprintf(stderr,
    220                  "%s/%s %zu x %zu, %s: mismatch at %zu: %E != %E; "
    221                  "tol %f l1 %f rel %E\n",
    222                  TypeName(MatT(), 1).c_str(), TypeName(VecT(), 1).c_str(),
    223                  kRows, kCols, (with_add ? "with add" : "without add"), i, exp,
    224                  act, tolerance, l1, rel);
    225          HWY_ASSERT(0);
    226        }
    227      }
    228    };
    229 
    230    assert_close(expected_without_add, actual_without_add, /*with_add=*/false);
    231 
    232    AlignedFreeUniquePtr<T[]> expected_with_add = AllocateAligned<T>(kRows);
    233    SimpleMatVecAdd(pm, pv, av, kRows, kCols, expected_with_add.get(), pool);
    234 
    235    AlignedFreeUniquePtr<T[]> actual_with_add = AllocateAligned<T>(kRows);
    236    MatVecAdd<kRows, kCols>(pm, pv, av, actual_with_add.get(), pool);
    237 
    238    assert_close(expected_with_add, actual_with_add, /*with_add=*/true);
    239 
    240 #else
    241    (void)d;
    242    (void)pool;
    243 #endif  // HWY_TARGET != HWY_SCALAR
    244  }
    245 
    246  template <class D>
    247  HWY_NOINLINE void CreatePoolAndTest(D d, size_t num_threads) {
    248    // Threads might not work on WASM; run only on main thread.
    249    if (HaveThreadingSupport()) num_threads = 0;
    250 
    251    ThreadPool pool(HWY_MIN(num_threads, ThreadPool::MaxThreads()));
    252 
    253    Test<AdjustedReps(192), AdjustedReps(256)>(d, pool);
    254 // Fewer tests due to compiler OOM
    255 #if !HWY_ARCH_RISCV
    256    Test<40, AdjustedReps(512)>(d, pool);
    257    Test<AdjustedReps(1024), 50>(d, pool);
    258 
    259    // Too large for low-precision vectors/accumulators.
    260    if (sizeof(TFromD<D>) != 2 && sizeof(VecT) != 2) {
    261      Test<AdjustedReps(1536), AdjustedReps(1536)>(d, pool);
    262    }
    263 #endif  // !HWY_ARCH_RISCV
    264  }
    265 
    266 public:
    267  template <class T, class D>
    268  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    269    CreatePoolAndTest(d, 13);
    270 // Fewer tests due to compiler OOM
    271 #if !HWY_ARCH_RISCV
    272    CreatePoolAndTest(d, 16);
    273 #endif
    274  }
    275 };
    276 
    277 void TestAllMatVecAdd() {
    278 #if HWY_HAVE_FLOAT16
    279  ForPartialVectors<TestMatVecAdd<float16_t, float16_t>>()(float16_t());
    280 #endif
    281  ForPartialVectors<TestMatVecAdd<float, float>>()(float());
    282 #if HWY_HAVE_FLOAT64
    283  ForPartialVectors<TestMatVecAdd<double, double>>()(double());
    284 #endif
    285 }
    286 
    287 void TestAllMatVecBF16() {
    288  ForGEVectors<32, TestMatVecAdd<bfloat16_t, float>>()(float());
    289 }
    290 
    291 void TestAllMatVecBF16Both() {
    292  ForGEVectors<32, TestMatVecAdd<bfloat16_t, bfloat16_t>>()(float());
    293 }
    294 
    295 }  // namespace
    296 // NOLINTNEXTLINE(google-readability-namespace-comments)
    297 }  // namespace HWY_NAMESPACE
    298 }  // namespace hwy
    299 HWY_AFTER_NAMESPACE();
    300 
// Compiled exactly once (not per-target): registers the exported test entry
// points. foreach_target.h re-includes this file for each target (see the
// HWY_TARGET_INCLUDE above), and these macros dispatch to the per-target
// instantiations.
#if HWY_ONCE
namespace hwy {
namespace {
HWY_BEFORE_TEST(MatVecTest);
HWY_EXPORT_AND_TEST_P(MatVecTest, TestAllMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatVecTest, TestAllMatVecBF16);
HWY_EXPORT_AND_TEST_P(MatVecTest, TestAllMatVecBF16Both);
HWY_AFTER_TEST();
}  // namespace
}  // namespace hwy
HWY_TEST_MAIN();
#endif  // HWY_ONCE