unroller_test.cc (15178B)
// Copyright Google LLC 2021
// Matthew Kolbe 2023
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for the Unroller framework (unroller-inl.h): elementwise transforms,
// type conversion, reductions (sum/min/dot) and early-exit search, each
// expressed as a small UnrollerUnit/UnrollerUnit2D CRTP functor.

#include <cmath>  // std::abs
#include <vector>

#include "hwy/base.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "hwy/contrib/unroller/unroller_test.cc" //NOLINT
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/contrib/unroller/unroller-inl.h"
#include "hwy/tests/test_util-inl.h"
// clang-format on

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace {

// Scalar reference dot product, accumulated in double to minimize rounding
// error before converting back to T for comparison.
template <typename T>
T DoubleDot(const T* pa, const T* pb, size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    // For reasons unknown, fp16 += does not compile on clang (Arm).
    sum += ConvertScalarTo<double>(pa[i]) * ConvertScalarTo<double>(pb[i]);
  }
  return ConvertScalarTo<T>(sum);
}

// Scalar reference sum, accumulated in double. (Not referenced by the tests
// below in this file; presumably kept for parity with DoubleDot/DoubleMin.)
template <typename T>
T DoubleSum(const T* pa, size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    sum += ConvertScalarTo<double>(pa[i]);
  }
  return ConvertScalarTo<T>(sum);
}

// Scalar reference minimum, starting from T's largest finite value.
template <typename T>
T DoubleMin(const T* pa, size_t num) {
  double min = HighestValue<T>();
  for (size_t i = 0; i < num; ++i) {
    min = HWY_MIN(min, ConvertScalarTo<double>(pa[i]));
  }
  return ConvertScalarTo<T>(min);
}

// Two-input elementwise unit: y[i] = x0[i] * x1[i]. The idx and accumulator
// arguments of the Func hook are unused for a pure map operation.
template <typename T>
struct MultiplyUnit : UnrollerUnit2D<MultiplyUnit<T>, T, T, T> {
  using TT = hn::ScalableTag<T>;
  HWY_INLINE hn::Vec<TT> Func(ptrdiff_t idx, const hn::Vec<TT> x0,
                              const hn::Vec<TT> x1, const hn::Vec<TT> y) {
    (void)idx;
    (void)y;
    return hn::Mul(x0, x1);
  }
};

// Converts FROM_T lanes to TO_T lanes. The three DoConvertVector overloads
// select DemoteTo / ConvertTo / PromoteTo at compile time based on whether
// the source lane is wider than, equal to, or narrower than the destination.
template <typename FROM_T, typename TO_T>
struct ConvertUnit : UnrollerUnit<ConvertUnit<FROM_T, TO_T>, FROM_T, TO_T> {
  using Base = UnrollerUnit<ConvertUnit<FROM_T, TO_T>, FROM_T, TO_T>;
  using Base::MaxUnitLanes;
  using typename Base::LargerD;

  using TT_FROM = hn::Rebind<FROM_T, LargerD>;
  using TT_TO = hn::Rebind<TO_T, LargerD>;

  // Source lane wider than destination: narrowing conversion.
  template <
      class ToD, class FromV,
      hwy::EnableIf<(sizeof(TFromV<FromV>) > sizeof(TFromD<ToD>))>* = nullptr>
  static HWY_INLINE hn::Vec<ToD> DoConvertVector(ToD d, FromV v) {
    return hn::DemoteTo(d, v);
  }
  // Same lane size: same-width conversion (e.g. float <-> int).
  template <
      class ToD, class FromV,
      hwy::EnableIf<(sizeof(TFromV<FromV>) == sizeof(TFromD<ToD>))>* = nullptr>
  static HWY_INLINE hn::Vec<ToD> DoConvertVector(ToD d, FromV v) {
    return hn::ConvertTo(d, v);
  }
  // Source lane narrower than destination: widening conversion.
  template <
      class ToD, class FromV,
      hwy::EnableIf<(sizeof(TFromV<FromV>) < sizeof(TFromD<ToD>))>* = nullptr>
  static HWY_INLINE hn::Vec<ToD> DoConvertVector(ToD d, FromV v) {
    return hn::PromoteTo(d, v);
  }

  hn::Vec<TT_TO> Func(ptrdiff_t idx, const hn::Vec<TT_FROM> x,
                      const hn::Vec<TT_TO> y) {
    (void)idx;
    (void)y;
    TT_TO d;
    return DoConvertVector(d, x);
  }
};

// Returns a value that does not compare equal to `value`.
// Float: NaN never compares equal to anything. Non-float: value + 1 with
// wraparound, which differs from `value` for any fixed-width integer.
template <class D, HWY_IF_FLOAT_D(D)>
HWY_INLINE Vec<D> OtherValue(D d, TFromD<D> /*value*/) {
  return NaN(d);
}
template <class D, HWY_IF_NOT_FLOAT_D(D)>
HWY_INLINE Vec<D> OtherValue(D d, TFromD<D> value) {
  return hn::Set(d, hwy::AddWithWraparound(value, 1));
}

// Search unit: scans for the first lane equal to `to_find` and outputs its
// element index (or -1 if absent), stopping early once found.
// Caveat: stores lane indices as MakeSigned<T>, which may overflow for 8-bit T
// on HWY_RVV.
template <typename T>
struct FindUnit : UnrollerUnit<FindUnit<T>, T, MakeSigned<T>> {
  using TI = MakeSigned<T>;
  using Base = UnrollerUnit<FindUnit<T>, T, TI>;
  using Base::ActualLanes;
  using Base::MaxUnitLanes;

  using D = hn::CappedTag<T, MaxUnitLanes()>;
  T to_find;  // the needle; lanes are compared against it in Func.
  D d;
  using DI = RebindToSigned<D>;
  DI di;

  FindUnit(T find) : to_find(find) {}

  // If any lane matches, broadcast the absolute element index (block offset
  // `idx` + lane offset); otherwise pass through the running result `y`.
  hn::Vec<DI> Func(ptrdiff_t idx, const hn::Vec<D> x, const hn::Vec<DI> y) {
    const Mask<D> msk = hn::Eq(x, hn::Set(d, to_find));
    const TI first_idx = static_cast<TI>(hn::FindFirstTrue(d, msk));
    if (first_idx > -1)
      return hn::Set(di, static_cast<TI>(static_cast<TI>(idx) + first_idx));
    else
      return y;
  }

  // Padding for partial loads: guaranteed to never equal to_find, so padded
  // lanes cannot produce a false match.
  hn::Vec<D> X0InitImpl() { return OtherValue(D(), to_find); }

  // Initial result: -1 means "not found".
  hn::Vec<DI> YInitImpl() { return hn::Set(di, TI{-1}); }

  // Partial load: valid lanes come from memory, the rest from X0InitImpl().
  // NOTE(review): negative `places` appears to denote a remainder at the end
  // of the array (mask covers the last -places lanes) — confirm against the
  // Unroller driver in unroller-inl.h.
  hn::Vec<D> MaskLoadImpl(const ptrdiff_t idx, const T* from,
                          const ptrdiff_t places) {
    auto mask = hn::FirstN(d, static_cast<size_t>(places));
    auto maskneg = hn::Not(hn::FirstN(
        d,
        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    if (places < 0) mask = maskneg;
    return hn::IfThenElse(mask, hn::MaskedLoad(mask, d, from + idx),
                          X0InitImpl());
  }

  // Writes the current (broadcast) result; returns true to continue scanning
  // while the result is still -1, false to short-circuit once found.
  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, TI* to,
                                const hn::Vec<DI> x) {
    (void)idx;

    TI a = hn::GetLane(x);
    to[0] = a;

    if (a == -1) return true;

    return false;
  }

  // Remainder store: the result is a single scalar regardless of `places`.
  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, TI* to, const hn::Vec<DI> x,
                          const ptrdiff_t places) {
    (void)idx;
    (void)places;
    TI a = hn::GetLane(x);
    to[0] = a;
    return 1;
  }
};

// Sum reducer: accumulates lanes into y, never stores per-block results, and
// collapses the accumulator with ReduceSum at the end.
template <typename T>
struct AccumulateUnit : UnrollerUnit<AccumulateUnit<T>, T, T> {
  using TT = hn::ScalableTag<T>;
  hn::Vec<TT> Func(ptrdiff_t idx, const hn::Vec<TT> x, const hn::Vec<TT> y) {
    (void)idx;
    return hn::Add(x, y);
  }

  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to,
                                const hn::Vec<TT> x) {
    // no stores in a reducer
    (void)idx;
    (void)to;
    (void)x;
    return true;
  }

  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec<TT> x,
                          const ptrdiff_t places) {
    // no stores in a reducer
    (void)idx;
    (void)to;
    (void)x;
    (void)places;
    return 0;
  }

  // Final horizontal reduction: one scalar result.
  ptrdiff_t ReduceImpl(const hn::Vec<TT> x, T* to) {
    const hn::ScalableTag<T> d;
    (*to) = hn::ReduceSum(d, x);
    return 1;
  }

  // Combines the four unrolled accumulators into y.
  void ReduceImpl(const hn::Vec<TT> x0, const hn::Vec<TT> x1,
                  const hn::Vec<TT> x2, hn::Vec<TT>* y) {
    (*y) = hn::Add(hn::Add(*y, x0), hn::Add(x1, x2));
  }
};

// Min reducer: identity element is HighestValue<T>, so padded lanes in a
// partial load can never win the minimum.
template <typename T>
struct MinUnit : UnrollerUnit<MinUnit<T>, T, T> {
  using Base = UnrollerUnit<MinUnit<T>, T, T>;
  using Base::ActualLanes;

  using TT = hn::ScalableTag<T>;
  TT d;

  hn::Vec<TT> Func(const ptrdiff_t idx, const hn::Vec<TT> x,
                   const hn::Vec<TT> y) {
    (void)idx;
    return hn::Min(y, x);
  }

  // Initial accumulator: the identity for Min.
  hn::Vec<TT> YInitImpl() { return hn::Set(d, HighestValue<T>()); }

  // Partial load: invalid lanes are filled with the Min identity.
  // NOTE(review): same negative-`places` remainder convention as
  // FindUnit::MaskLoadImpl — confirm against unroller-inl.h.
  hn::Vec<TT> MaskLoadImpl(const ptrdiff_t idx, const T* from,
                           const ptrdiff_t places) {
    auto mask = hn::FirstN(d, static_cast<size_t>(places));
    auto maskneg = hn::Not(hn::FirstN(
        d,
        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    if (places < 0) mask = maskneg;

    auto def = YInitImpl();
    return hn::MaskedLoadOr(def, mask, d, from + idx);
  }

  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to,
                                const hn::Vec<TT> x) {
    // no stores in a reducer
    (void)idx;
    (void)to;
    (void)x;
    return true;
  }

  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec<TT> x,
                          const ptrdiff_t places) {
    // no stores in a reducer
    (void)idx;
    (void)to;
    (void)x;
    (void)places;
    return 0;
  }

  // Final horizontal reduction: minimum across all lanes.
  ptrdiff_t ReduceImpl(const hn::Vec<TT> x, T* to) {
    auto minvect = hn::MinOfLanes(d, x);
    (*to) = hn::ExtractLane(minvect, 0);
    return 1;
  }

  // Combines the four unrolled accumulators into y.
  void ReduceImpl(const hn::Vec<TT> x0, const hn::Vec<TT> x1,
                  const hn::Vec<TT> x2, hn::Vec<TT>* y) {
    auto a = hn::Min(x1, x0);
    auto b = hn::Min(*y, x2);
    (*y) = hn::Min(a, b);
  }
};

// Fused dot-product reducer: y += x0 * x1 via MulAdd, reduced with ReduceSum.
template <typename T>
struct DotUnit : UnrollerUnit2D<DotUnit<T>, T, T, T> {
  using TT = hn::ScalableTag<T>;

  hn::Vec<TT> Func(const ptrdiff_t idx, const hn::Vec<TT> x0,
                   const hn::Vec<TT> x1, const hn::Vec<TT> y) {
    (void)idx;
    return hn::MulAdd(x0, x1, y);
  }

  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to,
                                const hn::Vec<TT> x) {
    // no stores in a reducer
    (void)idx;
    (void)to;
    (void)x;
    return true;
  }

  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec<TT> x,
                          const ptrdiff_t places) {
    // no stores in a reducer
    (void)idx;
    (void)to;
    (void)x;
    (void)places;
    return 0;
  }

  // Final horizontal reduction: one scalar result.
  ptrdiff_t ReduceImpl(const hn::Vec<TT> x, T* to) {
    const hn::ScalableTag<T> d;
    (*to) = hn::ReduceSum(d, x);
    return 1;
  }

  // Combines the four unrolled accumulators into y.
  void ReduceImpl(const hn::Vec<TT> x0, const hn::Vec<TT> x1,
                  const hn::Vec<TT> x2, hn::Vec<TT>* y) {
    (*y) = hn::Add(hn::Add(*y, x0), hn::Add(x1, x2));
  }
};

// Array lengths to test: small fixed sizes plus multiples/fractions of the
// vector lane count N, including off-by-one cases around N and 256*N.
template <class D>
std::vector<size_t> Counts(D d) {
  const size_t N = Lanes(d);
  return std::vector<size_t>{1,
                             3,
                             7,
                             16,
                             HWY_MAX(N / 2, 1),
                             HWY_MAX(2 * N / 3, 1),
                             N,
                             N + 1,
                             4 * N / 3,
                             3 * N,
                             8 * N,
                             8 * N + 2,
                             256 * N - 1,
                             256 * N};
}

// Exercises MultiplyUnit + AccumulateUnit, DotUnit and MinUnit against the
// scalar double-precision references, with a relative tolerance scaled by
// T's epsilon for the dot products and an absolute tolerance for min.
struct TestDot {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    // TODO(janwas): avoid internal compiler error
#if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_COMPILER_MSVC
    (void)d;
#else
    RandomState rng;
    // Random values in [-8, 8) with 1/64 granularity: exactly representable
    // even in f16, so the scalar reference and SIMD path see the same inputs.
    const auto random_t = [&rng]() {
      const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
      return static_cast<float>(bits - 512) * (1.0f / 64);
    };

    for (size_t num : Counts(d)) {
      AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(num);
      AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(num);
      AlignedFreeUniquePtr<T[]> py = AllocateAligned<T>(num);

      HWY_ASSERT(pa && pb && py);
      T* a = pa.get();
      T* b = pb.get();
      T* y = py.get();

      size_t i = 0;
      for (; i < num; ++i) {
        a[i] = ConvertScalarTo<T>(random_t());
        b[i] = ConvertScalarTo<T>(random_t());
      }

      const T expected_dot = DoubleDot(a, b, num);
      const double expected_dot_f64 = ConvertScalarTo<double>(expected_dot);
      // Dot product in two passes: elementwise multiply, then sum-reduce.
      MultiplyUnit<T> multfn;
      Unroller(multfn, a, b, y, static_cast<ptrdiff_t>(num));
      AccumulateUnit<T> accfn;
      T dot_via_mul_acc;
      Unroller(accfn, y, &dot_via_mul_acc, static_cast<ptrdiff_t>(num));
      // Relative tolerance: accumulation order differs between the scalar
      // reference and the unrolled SIMD reduction.
      const double tolerance = 120.0 *
                               ConvertScalarTo<double>(hwy::Epsilon<T>()) *
                               std::abs(expected_dot_f64);
      HWY_ASSERT(std::abs(expected_dot_f64 - ConvertScalarTo<double>(
                                                 dot_via_mul_acc)) < tolerance);

      // Dot product in a single fused pass.
      DotUnit<T> dotfn;
      T dotr;
      Unroller(dotfn, a, b, &dotr, static_cast<ptrdiff_t>(num));
      const double dotr_f64 = ConvertScalarTo<double>(dotr);
      HWY_ASSERT(std::abs(expected_dot_f64 - dotr_f64) < tolerance);

      const T expected_min = DoubleMin(a, num);
      MinUnit<T> minfn;
      T minr;
      Unroller(minfn, a, &minr, static_cast<ptrdiff_t>(num));

      const double l1 = std::abs(ConvertScalarTo<double>(expected_min) -
                                 ConvertScalarTo<double>(minr));
      // Unlike above, tolerance is absolute, there should be no numerical
      // differences between T and double because we just compute the min.
      HWY_ASSERT(l1 < 1E-7);
    }
#endif
  }
};

void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); }

// Round-trips float -> int -> float through ConvertUnit and checks against
// a scalar truncating cast; inputs are multiples of 0.25.
struct TestConvert {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    // TODO(janwas): avoid internal compiler error
#if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_COMPILER_MSVC
    (void)d;
#else
    for (size_t num : Counts(d)) {
      AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(num);
      AlignedFreeUniquePtr<int[]> pto = AllocateAligned<int>(num);
      HWY_ASSERT(pa && pto);
      T* HWY_RESTRICT a = pa.get();
      int* HWY_RESTRICT to = pto.get();

      for (size_t i = 0; i < num; ++i) {
        a[i] = ConvertScalarTo<T>(static_cast<double>(i) * 0.25);
      }

      ConvertUnit<T, int> cvtfn;
      Unroller(cvtfn, a, to, static_cast<ptrdiff_t>(num));
      for (size_t i = 0; i < num; ++i) {
        // TODO(janwas): RVV QEMU fcvt_rtz appears to 'truncate' 4.75 to 5.
        HWY_ASSERT(
            static_cast<int>(a[i]) == to[i] ||
            (HWY_TARGET == HWY_RVV && static_cast<int>(a[i]) == to[i] - 1));
      }

      // Convert back; the int values are exactly representable in T here,
      // so the result must match exactly.
      ConvertUnit<int, T> cvtbackfn;
      Unroller(cvtbackfn, to, a, static_cast<ptrdiff_t>(num));
      for (size_t i = 0; i < num; ++i) {
        HWY_ASSERT_EQ(ConvertScalarTo<T>(to[i]), a[i]);
      }
    }
#endif
  }
};

void TestAllConvert() { ForFloat3264Types(ForPartialVectors<TestConvert>()); }

// Exercises FindUnit: finds the last element, the first element, and verifies
// that a value not present yields -1.
struct TestFind {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    for (size_t num : Counts(d)) {
      AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(num);
      HWY_ASSERT(pa);
      T* a = pa.get();

      for (size_t i = 0; i < num; ++i) a[i] = ConvertScalarTo<T>(i);

      FindUnit<T> cvtfn(ConvertScalarTo<T>(num - 1));
      MakeSigned<T> idx = 0;
      // Explicitly test input can be const
      const T* const_a = a;
      Unroller(cvtfn, const_a, &idx, static_cast<ptrdiff_t>(num));
      // For f16 with large num, several i may round to the same value, so
      // only check that the found element compares equal to the needle.
      HWY_ASSERT(static_cast<MakeUnsigned<T>>(idx) < num);
      HWY_ASSERT(a[idx] == ConvertScalarTo<T>(num - 1));

      FindUnit<T> cvtfnzero((T)(0));
      Unroller(cvtfnzero, a, &idx, static_cast<ptrdiff_t>(num));
      HWY_ASSERT(static_cast<MakeUnsigned<T>>(idx) < num);
      HWY_ASSERT(a[idx] == (T)(0));

      // For f16, we cannot search for `num` because it may round to a value
      // that is actually in the (large) array.
      FindUnit<T> cvtfnnotin(HighestValue<T>());
      Unroller(cvtfnnotin, a, &idx, static_cast<ptrdiff_t>(num));
      HWY_ASSERT(idx == -1);
    }
  }
};

void TestAllFind() { ForFloatTypes(ForPartialVectors<TestFind>()); }

}  // namespace
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace hwy {
namespace {
HWY_BEFORE_TEST(UnrollerTest);
HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllDot);
HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllConvert);
HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllFind);
HWY_AFTER_TEST();
}  // namespace
}  // namespace hwy
HWY_TEST_MAIN();
#endif  // HWY_ONCE