unroller-inl.h (14545B)
1 // Copyright 2023 Matthew Kolbe 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \ 17 defined(HWY_TARGET_TOGGLE) 18 #ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ 19 #undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ 20 #else 21 #define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ 22 #endif 23 24 #include <cstdlib> // std::abs 25 26 #include "hwy/highway.h" 27 28 HWY_BEFORE_NAMESPACE(); 29 namespace hwy { 30 namespace HWY_NAMESPACE { 31 32 namespace hn = hwy::HWY_NAMESPACE; 33 34 template <class DERIVED, typename IN_T, typename OUT_T> 35 struct UnrollerUnit { 36 static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T)); 37 using LargerT = SignedFromSize<kMaxTSize>; // only the size matters. 38 39 DERIVED* me() { return static_cast<DERIVED*>(this); } 40 41 static constexpr size_t MaxUnitLanes() { 42 return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>); 43 } 44 static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); } 45 46 using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>; 47 using IT = hn::Rebind<IN_T, LargerD>; 48 using OT = hn::Rebind<OUT_T, LargerD>; 49 IT d_in; 50 OT d_out; 51 using Y_VEC = hn::Vec<OT>; 52 using X_VEC = hn::Vec<IT>; 53 54 Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) { 55 return me()->Func(idx, x, y); 56 } 57 58 X_VEC X0Init() { return me()->X0InitImpl(); } 59 60 X_VEC X0InitImpl() { return hn::Zero(d_in); } 61 62 Y_VEC YInit() { return me()->YInitImpl(); } 63 64 Y_VEC YInitImpl() { return hn::Zero(d_out); } 65 66 X_VEC Load(const ptrdiff_t idx, const IN_T* from) { 67 return me()->LoadImpl(idx, from); 68 } 69 70 X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) { 71 return hn::LoadU(d_in, from + idx); 72 } 73 74 // MaskLoad can take in either a positive or negative number for `places`. if 75 // the number is positive, then it loads the top `places` values, and if it's 76 // negative, it loads the bottom |places| values. example: places = 3 77 // | o | o | o | x | x | x | x | x | 78 // example places = -3 79 // | x | x | x | x | x | o | o | o | 80 X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from, 81 const ptrdiff_t places) { 82 return me()->MaskLoadImpl(idx, from, places); 83 } 84 85 X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from, 86 const ptrdiff_t places) { 87 auto mask = hn::FirstN(d_in, static_cast<size_t>(places)); 88 auto maskneg = hn::Not(hn::FirstN( 89 d_in, 90 static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes())))); 91 if (places < 0) mask = maskneg; 92 93 return hn::MaskedLoad(mask, d_in, from + idx); 94 } 95 96 bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { 97 return me()->StoreAndShortCircuitImpl(idx, to, x); 98 } 99 100 bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { 101 hn::StoreU(x, d_out, to + idx); 102 return true; 103 } 104 105 ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, 106 ptrdiff_t const places) { 107 return me()->MaskStoreImpl(idx, to, x, places); 108 } 109 110 ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, 111 const ptrdiff_t places) { 112 auto mask = hn::FirstN(d_out, static_cast<size_t>(places)); 113 auto maskneg = hn::Not(hn::FirstN( 114 d_out, 115 static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes())))); 116 if (places < 0) mask = maskneg; 117 118 hn::BlendedStore(x, mask, d_out, to + idx); 119 return std::abs(places); 120 } 121 122 ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } 123 124 ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { 125 // default does nothing 126 (void)x; 127 (void)to; 128 return 0; 129 } 130 131 void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { 132 me()->ReduceImpl(x0, x1, x2, y); 133 } 134 135 void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { 136 // default does nothing 137 (void)x0; 138 (void)x1; 139 (void)x2; 140 (void)y; 141 } 142 }; 143 144 template <class DERIVED, typename IN0_T, typename IN1_T, typename OUT_T> 145 struct UnrollerUnit2D { 146 DERIVED* me() { return static_cast<DERIVED*>(this); } 147 148 static constexpr size_t kMaxTSize = 149 HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T))); 150 using LargerT = SignedFromSize<kMaxTSize>; // only the size matters. 151 152 static constexpr size_t MaxUnitLanes() { 153 return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>); 154 } 155 static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); } 156 157 using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>; 158 159 using I0T = hn::Rebind<IN0_T, LargerD>; 160 using I1T = hn::Rebind<IN1_T, LargerD>; 161 using OT = hn::Rebind<OUT_T, LargerD>; 162 I0T d_in0; 163 I1T d_in1; 164 OT d_out; 165 using Y_VEC = hn::Vec<OT>; 166 using X0_VEC = hn::Vec<I0T>; 167 using X1_VEC = hn::Vec<I1T>; 168 169 hn::Vec<OT> Func(const ptrdiff_t idx, const hn::Vec<I0T> x0, 170 const hn::Vec<I1T> x1, const Y_VEC y) { 171 return me()->Func(idx, x0, x1, y); 172 } 173 174 X0_VEC X0Init() { return me()->X0InitImpl(); } 175 176 X0_VEC X0InitImpl() { return hn::Zero(d_in0); } 177 178 X1_VEC X1Init() { return me()->X1InitImpl(); } 179 180 X1_VEC X1InitImpl() { return hn::Zero(d_in1); } 181 182 Y_VEC YInit() { return me()->YInitImpl(); } 183 184 Y_VEC YInitImpl() { return hn::Zero(d_out); } 185 186 X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) { 187 return me()->Load0Impl(idx, from); 188 } 189 190 X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) { 191 return hn::LoadU(d_in0, from + idx); 192 } 193 194 X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) { 195 return me()->Load1Impl(idx, from); 196 } 197 198 X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) { 199 return hn::LoadU(d_in1, from + idx); 200 } 201 202 // maskload can take in either a positive or negative number for `places`. if 203 // the number is positive, then it loads the top `places` values, and if it's 204 // negative, it loads the bottom |places| values. example: places = 3 205 // | o | o | o | x | x | x | x | x | 206 // example places = -3 207 // | x | x | x | x | x | o | o | o | 208 X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from, 209 const ptrdiff_t places) { 210 return me()->MaskLoad0Impl(idx, from, places); 211 } 212 213 X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from, 214 const ptrdiff_t places) { 215 auto mask = hn::FirstN(d_in0, static_cast<size_t>(places)); 216 auto maskneg = hn::Not(hn::FirstN( 217 d_in0, 218 static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes())))); 219 if (places < 0) mask = maskneg; 220 221 return hn::MaskedLoad(mask, d_in0, from + idx); 222 } 223 224 hn::Vec<I1T> MaskLoad1(const ptrdiff_t idx, const IN1_T* from, 225 const ptrdiff_t places) { 226 return me()->MaskLoad1Impl(idx, from, places); 227 } 228 229 hn::Vec<I1T> MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from, 230 const ptrdiff_t places) { 231 auto mask = hn::FirstN(d_in1, static_cast<size_t>(places)); 232 auto maskneg = hn::Not(hn::FirstN( 233 d_in1, 234 static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes())))); 235 if (places < 0) mask = maskneg; 236 237 return hn::MaskedLoad(mask, d_in1, from + idx); 238 } 239 240 // store returns a bool that is `false` when 241 bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { 242 return me()->StoreAndShortCircuitImpl(idx, to, x); 243 } 244 245 bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { 246 hn::StoreU(x, d_out, to + idx); 247 return true; 248 } 249 250 ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, 251 const ptrdiff_t places) { 252 return me()->MaskStoreImpl(idx, to, x, places); 253 } 254 255 ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, 256 const ptrdiff_t places) { 257 auto mask = hn::FirstN(d_out, static_cast<size_t>(places)); 258 auto maskneg = hn::Not(hn::FirstN( 259 d_out, 260 static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes())))); 261 if (places < 0) mask = maskneg; 262 263 hn::BlendedStore(x, mask, d_out, to + idx); 264 return std::abs(places); 265 } 266 267 ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } 268 269 ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { 270 // default does nothing 271 (void)x; 272 (void)to; 273 return 0; 274 } 275 276 void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { 277 me()->ReduceImpl(x0, x1, x2, y); 278 } 279 280 void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { 281 // default does nothing 282 (void)x0; 283 (void)x1; 284 (void)x2; 285 (void)y; 286 } 287 }; 288 289 template <class FUNC, typename IN_T, typename OUT_T> 290 inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y, 291 const ptrdiff_t n) { 292 auto xx = f.X0Init(); 293 auto yy = f.YInit(); 294 ptrdiff_t i = 0; 295 296 #if HWY_MEM_OPS_MIGHT_FAULT 297 constexpr auto lane_sz = 298 static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes()); 299 if (n < lane_sz) { 300 const DFromV<decltype(yy)> d; 301 // this may not fit on the stack for HWY_RVV, but we do not reach this code 302 // there 303 HWY_ALIGN IN_T xtmp[static_cast<size_t>(lane_sz)]; 304 HWY_ALIGN OUT_T ytmp[static_cast<size_t>(lane_sz)]; 305 306 CopyBytes(x, xtmp, static_cast<size_t>(n) * sizeof(IN_T)); 307 xx = f.MaskLoad(0, xtmp, n); 308 yy = f.Func(0, xx, yy); 309 Store(Zero(d), d, ytmp); 310 i += f.MaskStore(0, ytmp, yy, n); 311 i += f.Reduce(yy, ytmp); 312 CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T)); 313 return; 314 } 315 #endif 316 317 const ptrdiff_t actual_lanes = 318 static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes()); 319 if (n > 4 * actual_lanes) { 320 auto xx1 = f.X0Init(); 321 auto yy1 = f.YInit(); 322 auto xx2 = f.X0Init(); 323 auto yy2 = f.YInit(); 324 auto xx3 = f.X0Init(); 325 auto yy3 = f.YInit(); 326 327 while (i + 4 * actual_lanes - 1 < n) { 328 xx = f.Load(i, x); 329 i += actual_lanes; 330 xx1 = f.Load(i, x); 331 i += actual_lanes; 332 xx2 = f.Load(i, x); 333 i += actual_lanes; 334 xx3 = f.Load(i, x); 335 i -= 3 * actual_lanes; 336 337 yy = f.Func(i, xx, yy); 338 yy1 = f.Func(i + actual_lanes, xx1, yy1); 339 yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2); 340 yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3); 341 342 if (!f.StoreAndShortCircuit(i, y, yy)) return; 343 i += actual_lanes; 344 if (!f.StoreAndShortCircuit(i, y, yy1)) return; 345 i += actual_lanes; 346 if (!f.StoreAndShortCircuit(i, y, yy2)) return; 347 i += actual_lanes; 348 if (!f.StoreAndShortCircuit(i, y, yy3)) return; 349 i += actual_lanes; 350 } 351 352 f.Reduce(yy3, yy2, yy1, &yy); 353 } 354 355 while (i + actual_lanes - 1 < n) { 356 xx = f.Load(i, x); 357 yy = f.Func(i, xx, yy); 358 if (!f.StoreAndShortCircuit(i, y, yy)) return; 359 i += actual_lanes; 360 } 361 362 if (i != n) { 363 xx = f.MaskLoad(n - actual_lanes, x, i - n); 364 yy = f.Func(n - actual_lanes, xx, yy); 365 f.MaskStore(n - actual_lanes, y, yy, i - n); 366 } 367 368 f.Reduce(yy, y); 369 } 370 371 template <class FUNC, typename IN0_T, typename IN1_T, typename OUT_T> 372 inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0, 373 IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y, 374 const ptrdiff_t n) { 375 const ptrdiff_t lane_sz = 376 static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes()); 377 378 auto xx00 = f.X0Init(); 379 auto xx10 = f.X1Init(); 380 auto yy = f.YInit(); 381 382 ptrdiff_t i = 0; 383 384 #if HWY_MEM_OPS_MIGHT_FAULT 385 if (n < lane_sz) { 386 const DFromV<decltype(yy)> d; 387 // this may not fit on the stack for HWY_RVV, but we do not reach this code 388 // there 389 constexpr auto max_lane_sz = 390 static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes()); 391 HWY_ALIGN IN0_T xtmp0[static_cast<size_t>(max_lane_sz)]; 392 HWY_ALIGN IN1_T xtmp1[static_cast<size_t>(max_lane_sz)]; 393 HWY_ALIGN OUT_T ytmp[static_cast<size_t>(max_lane_sz)]; 394 395 CopyBytes(x0, xtmp0, static_cast<size_t>(n) * sizeof(IN0_T)); 396 CopyBytes(x1, xtmp1, static_cast<size_t>(n) * sizeof(IN1_T)); 397 xx00 = f.MaskLoad0(0, xtmp0, n); 398 xx10 = f.MaskLoad1(0, xtmp1, n); 399 yy = f.Func(0, xx00, xx10, yy); 400 Store(Zero(d), d, ytmp); 401 i += f.MaskStore(0, ytmp, yy, n); 402 i += f.Reduce(yy, ytmp); 403 CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T)); 404 return; 405 } 406 #endif 407 408 if (n > 4 * lane_sz) { 409 auto xx01 = f.X0Init(); 410 auto xx11 = f.X1Init(); 411 auto yy1 = f.YInit(); 412 auto xx02 = f.X0Init(); 413 auto xx12 = f.X1Init(); 414 auto yy2 = f.YInit(); 415 auto xx03 = f.X0Init(); 416 auto xx13 = f.X1Init(); 417 auto yy3 = f.YInit(); 418 419 while (i + 4 * lane_sz - 1 < n) { 420 xx00 = f.Load0(i, x0); 421 xx10 = f.Load1(i, x1); 422 i += lane_sz; 423 xx01 = f.Load0(i, x0); 424 xx11 = f.Load1(i, x1); 425 i += lane_sz; 426 xx02 = f.Load0(i, x0); 427 xx12 = f.Load1(i, x1); 428 i += lane_sz; 429 xx03 = f.Load0(i, x0); 430 xx13 = f.Load1(i, x1); 431 i -= 3 * lane_sz; 432 433 yy = f.Func(i, xx00, xx10, yy); 434 yy1 = f.Func(i + lane_sz, xx01, xx11, yy1); 435 yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2); 436 yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3); 437 438 if (!f.StoreAndShortCircuit(i, y, yy)) return; 439 i += lane_sz; 440 if (!f.StoreAndShortCircuit(i, y, yy1)) return; 441 i += lane_sz; 442 if (!f.StoreAndShortCircuit(i, y, yy2)) return; 443 i += lane_sz; 444 if (!f.StoreAndShortCircuit(i, y, yy3)) return; 445 i += lane_sz; 446 } 447 448 f.Reduce(yy3, yy2, yy1, &yy); 449 } 450 451 while (i + lane_sz - 1 < n) { 452 xx00 = f.Load0(i, x0); 453 xx10 = f.Load1(i, x1); 454 yy = f.Func(i, xx00, xx10, yy); 455 if (!f.StoreAndShortCircuit(i, y, yy)) return; 456 i += lane_sz; 457 } 458 459 if (i != n) { 460 xx00 = f.MaskLoad0(n - lane_sz, x0, i - n); 461 xx10 = f.MaskLoad1(n - lane_sz, x1, i - n); 462 yy = f.Func(n - lane_sz, xx00, xx10, yy); 463 f.MaskStore(n - lane_sz, y, yy, i - n); 464 } 465 466 f.Reduce(yy, y); 467 } 468 469 } // namespace HWY_NAMESPACE 470 } // namespace hwy 471 HWY_AFTER_NAMESPACE(); 472 473 #endif // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_