dot_test.cc (10545B)
1 // Copyright 2021 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #include <stdint.h> 17 #include <stdio.h> 18 #include <stdlib.h> 19 20 #include "hwy/aligned_allocator.h" 21 #include "hwy/base.h" 22 23 // clang-format off 24 #undef HWY_TARGET_INCLUDE 25 #define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc" 26 #include "hwy/foreach_target.h" // IWYU pragma: keep 27 #include "hwy/highway.h" 28 #include "hwy/contrib/dot/dot-inl.h" 29 #include "hwy/tests/test_util-inl.h" 30 // clang-format on 31 32 HWY_BEFORE_NAMESPACE(); 33 namespace hwy { 34 namespace HWY_NAMESPACE { 35 namespace { 36 37 template <typename T1, typename T2> 38 HWY_NOINLINE T1 SimpleDot(const T1* pa, const T2* pb, size_t num) { 39 float sum = 0.0f; 40 for (size_t i = 0; i < num; ++i) { 41 sum += ConvertScalarTo<float>(pa[i]) * ConvertScalarTo<float>(pb[i]); 42 } 43 return ConvertScalarTo<T1>(sum); 44 } 45 46 HWY_MAYBE_UNUSED HWY_NOINLINE float SimpleDot(const float* pa, 47 const hwy::bfloat16_t* pb, 48 size_t num) { 49 float sum = 0.0f; 50 for (size_t i = 0; i < num; ++i) { 51 sum += pa[i] * F32FromBF16(pb[i]); 52 } 53 return sum; 54 } 55 56 // Overload is required because the generic template hits an internal compiler 57 // error on aarch64 clang. 58 HWY_MAYBE_UNUSED HWY_NOINLINE float SimpleDot(const bfloat16_t* pa, 59 const bfloat16_t* pb, 60 size_t num) { 61 float sum = 0.0f; 62 for (size_t i = 0; i < num; ++i) { 63 sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); 64 } 65 return sum; 66 } 67 68 class TestDot { 69 // Computes/verifies one dot product. 70 template <int kAssumptions, class D> 71 void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, 72 RandomState& rng) { 73 using T = TFromD<D>; 74 const size_t N = Lanes(d); 75 const auto random_t = [&rng]() { 76 const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; 77 return static_cast<float>(bits - 512) * (1.0f / 64); 78 }; 79 80 const size_t padded = 81 (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; 82 AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded); 83 AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + padded); 84 HWY_ASSERT(pa && pb); 85 T* a = pa.get() + misalign_a; 86 T* b = pb.get() + misalign_b; 87 size_t i = 0; 88 for (; i < num; ++i) { 89 a[i] = ConvertScalarTo<T>(random_t()); 90 b[i] = ConvertScalarTo<T>(random_t()); 91 } 92 // Fill padding - the values are not used, but avoids MSAN errors. 93 for (; i < padded; ++i) { 94 a[i] = ConvertScalarTo<T>(0); 95 b[i] = ConvertScalarTo<T>(0); 96 } 97 98 const double expected = SimpleDot(a, b, num); 99 const double magnitude = expected > 0.0 ? expected : -expected; 100 const double actual = 101 ConvertScalarTo<double>(Dot::Compute<kAssumptions>(d, a, b, num)); 102 const double max = static_cast<double>(8 * 8 * num); 103 HWY_ASSERT(-max <= actual && actual <= max); 104 // Integer math is exact, so no tolerance. 105 const double tolerance = 106 IsFloat<T>() ? 96.0 * ConvertScalarTo<double>(Epsilon<T>()) * 107 HWY_MAX(magnitude, 1.0) 108 : 0; 109 HWY_ASSERT(expected - tolerance <= actual && 110 actual <= expected + tolerance); 111 } 112 113 // Runs tests with various alignments. 114 template <int kAssumptions, class D> 115 void ForeachMisalign(D d, size_t num, RandomState& rng) { 116 const size_t N = Lanes(d); 117 const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; 118 for (size_t ma : misalignments) { 119 for (size_t mb : misalignments) { 120 Test<kAssumptions>(d, num, ma, mb, rng); 121 } 122 } 123 } 124 125 // Runs tests with various lengths compatible with the given assumptions. 126 template <int kAssumptions, class D> 127 void ForeachCount(D d, RandomState& rng) { 128 const size_t N = Lanes(d); 129 const size_t counts[] = {1, 130 3, 131 7, 132 16, 133 HWY_MAX(N / 2, 1), 134 HWY_MAX(2 * N / 3, 1), 135 N, 136 N + 1, 137 4 * N / 3, 138 3 * N, 139 8 * N, 140 8 * N + 2}; 141 for (size_t num : counts) { 142 if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; 143 if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; 144 ForeachMisalign<kAssumptions>(d, num, rng); 145 } 146 } 147 148 public: 149 // Must be inlined on aarch64 for bf16, else clang crashes. 150 template <class T, class D> 151 HWY_INLINE void operator()(T /*unused*/, D d) { 152 RandomState rng; 153 154 // All 8 combinations of the three length-related flags: 155 ForeachCount<0>(d, rng); 156 ForeachCount<Dot::kAtLeastOneVector>(d, rng); 157 ForeachCount<Dot::kMultipleOfVector>(d, rng); 158 ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng); 159 ForeachCount<Dot::kPaddedToVector>(d, rng); 160 ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng); 161 ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng); 162 ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector | 163 Dot::kAtLeastOneVector>(d, rng); 164 } 165 }; 166 167 class TestDotF32BF16 { 168 // Computes/verifies one dot product. 169 template <int kAssumptions, class D> 170 void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, 171 RandomState& rng) { 172 using T = TFromD<D>; 173 using T2 = hwy::bfloat16_t; 174 const size_t N = Lanes(d); 175 const auto random_t = [&rng]() { 176 const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; 177 return static_cast<float>(bits - 512) * (1.0f / 64); 178 }; 179 180 const size_t padded = 181 (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; 182 AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded); 183 AlignedFreeUniquePtr<T2[]> pb = AllocateAligned<T2>(misalign_b + padded); 184 HWY_ASSERT(pa && pb); 185 T* a = pa.get() + misalign_a; 186 T2* b = pb.get() + misalign_b; 187 size_t i = 0; 188 for (; i < num; ++i) { 189 a[i] = ConvertScalarTo<T>(random_t()); 190 b[i] = ConvertScalarTo<T2>(random_t()); 191 } 192 // Fill padding with NaN - the values are not used, but avoids MSAN errors. 193 for (; i < padded; ++i) { 194 ScalableTag<float> df1; 195 a[i] = ConvertScalarTo<T>(GetLane(NaN(df1))); 196 b[i] = ConvertScalarTo<T2>(GetLane(NaN(df1))); 197 } 198 199 const double expected = SimpleDot(a, b, num); 200 const double magnitude = expected > 0.0 ? expected : -expected; 201 const double actual = 202 ConvertScalarTo<double>(Dot::Compute<kAssumptions>(d, a, b, num)); 203 const double max = static_cast<double>(8 * 8 * num); 204 HWY_ASSERT(-max <= actual && actual <= max); 205 const double tolerance = 206 64.0 * ConvertScalarTo<double>(Epsilon<T2>()) * HWY_MAX(magnitude, 1.0); 207 HWY_ASSERT(expected - tolerance <= actual && 208 actual <= expected + tolerance); 209 } 210 211 // Runs tests with various alignments. 212 template <int kAssumptions, class D> 213 void ForeachMisalign(D d, size_t num, RandomState& rng) { 214 const size_t N = Lanes(d); 215 const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; 216 for (size_t ma : misalignments) { 217 for (size_t mb : misalignments) { 218 Test<kAssumptions>(d, num, ma, mb, rng); 219 } 220 } 221 } 222 223 // Runs tests with various lengths compatible with the given assumptions. 224 template <int kAssumptions, class D> 225 void ForeachCount(D d, RandomState& rng) { 226 const size_t N = Lanes(d); 227 const size_t counts[] = {1, 228 3, 229 7, 230 16, 231 HWY_MAX(N / 2, 1), 232 HWY_MAX(2 * N / 3, 1), 233 N, 234 N + 1, 235 4 * N / 3, 236 3 * N, 237 8 * N, 238 8 * N + 2}; 239 for (size_t num : counts) { 240 if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; 241 if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; 242 ForeachMisalign<kAssumptions>(d, num, rng); 243 } 244 } 245 246 public: 247 // Must be inlined on aarch64 for bf16, else clang crashes. 248 template <class T, class D> 249 HWY_INLINE void operator()(T /*unused*/, D d) { 250 RandomState rng; 251 252 // All 8 combinations of the three length-related flags: 253 ForeachCount<0>(d, rng); 254 ForeachCount<Dot::kAtLeastOneVector>(d, rng); 255 ForeachCount<Dot::kMultipleOfVector>(d, rng); 256 ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng); 257 ForeachCount<Dot::kPaddedToVector>(d, rng); 258 ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng); 259 ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng); 260 ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector | 261 Dot::kAtLeastOneVector>(d, rng); 262 } 263 }; 264 265 // All floating-point types, both arguments same. 266 void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); } 267 268 // Mixed f32 and bf16 inputs. 269 void TestAllDotF32BF16() { ForPartialVectors<TestDotF32BF16>()(float()); } 270 271 // Both inputs bf16. 272 void TestAllDotBF16() { ForShrinkableVectors<TestDot>()(bfloat16_t()); } 273 274 // Both inputs i16. 275 void TestAllDotI16() { ForShrinkableVectors<TestDot>()(int16_t()); } 276 277 } // namespace 278 // NOLINTNEXTLINE(google-readability-namespace-comments) 279 } // namespace HWY_NAMESPACE 280 } // namespace hwy 281 HWY_AFTER_NAMESPACE(); 282 283 #if HWY_ONCE 284 namespace hwy { 285 namespace { 286 HWY_BEFORE_TEST(DotTest); 287 HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot); 288 HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotF32BF16); 289 HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16); 290 HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotI16); 291 HWY_AFTER_TEST(); 292 } // namespace 293 } // namespace hwy 294 HWY_TEST_MAIN(); 295 #endif // HWY_ONCE