auto_tune.h (19788B)
1 // Copyright 2025 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #ifndef HIGHWAY_HWY_AUTO_TUNE_H_ 17 #define HIGHWAY_HWY_AUTO_TUNE_H_ 18 19 #include <stddef.h> 20 #include <stdint.h> 21 #include <string.h> // memmove 22 23 #include <cmath> 24 #include <vector> 25 26 #include "hwy/aligned_allocator.h" // Span 27 #include "hwy/base.h" // HWY_MIN 28 29 // configuration to allow auto_tune to use std::sort instead of VQSort 30 // (also enabled in header only mode). 31 #if defined(HWY_HEADER_ONLY) 32 #define HWY_AUTOTUNE_STDSORT 33 #endif 34 35 #ifdef HWY_AUTOTUNE_STDSORT 36 #include <algorithm> // std::sort 37 #else 38 #include "hwy/contrib/sort/vqsort.h" // VQSort 39 #endif 40 41 // Infrastructure for auto-tuning (choosing optimal parameters at runtime). 42 43 namespace hwy { 44 45 // O(1) storage to estimate the central tendency of hundreds of independent 46 // distributions (one per configuration). The number of samples per distribution 47 // (`kMinSamples`) varies from few to dozens. We support both by first storing 48 // values in a buffer, and when full, switching to online variance estimation. 49 // Modified from `hwy/stats.h`. 
class CostDistribution {
 public:
  // 14 stored doubles; together with the counters below this keeps the
  // instance at exactly 128 bytes (see static_assert after the class).
  static constexpr size_t kMaxValues = 14;  // for total size of 128 bytes

  // Records one cost sample. Negative costs are rejected with a warning
  // because later stages (sorting as U64, variance math) assume x >= 0.
  void Notify(const double x) {
    if (HWY_UNLIKELY(x < 0.0)) {
      HWY_WARN("Ignoring negative cost %f.", x);
      return;
    }

    // Online phase after filling and warm-up.
    if (HWY_LIKELY(IsOnline())) return OnlineNotify(x);

    // Fill phase: store up to `kMaxValues` values.
    values_[num_values_++] = x;
    HWY_DASSERT(num_values_ <= kMaxValues);
    if (HWY_UNLIKELY(num_values_ == kMaxValues)) {
      WarmUpOnline();
      HWY_DASSERT(IsOnline());
    }
  }

  // Returns an estimate of the true cost, mitigating the impact of noise.
  //
  // Background and observations from time measurements in `thread_pool.h`:
  // - We aim for O(1) storage because there may be hundreds of instances.
  // - The mean is biased upwards by mostly additive noise: particularly
  //   interruptions such as context switches, but also contention.
  // - The minimum is not a robust estimator because there are also "lucky
  //   shots" (1.2-1.6x lower values) where interruptions or contention happen
  //   to be low.
  // - We want to preserve information about contention and a configuration's
  //   sensitivity to it. Otherwise, we are optimizing for the best-case, not
  //   the common case.
  // - It is still important to minimize the influence of outliers, such as
  //   page faults, which can cause multiple times larger measurements.
  // - Detecting outliers based only on the initial variance is too brittle.
  //   If the sample is narrow, measurements will fluctuate across runs
  //   because too many measurements are considered outliers. This would cause
  //   the 'best' configuration to vary.
  //
  // Approach:
  // - Use Winsorization to reduce the impact of outliers, while preserving
  //   information on the central tendency.
  // - Continually update the thresholds based on the online variance, with
  //   exponential smoothing for stability.
  // - Trim the initial sample via MAD or skewness for a robust estimate of
  //   the variance.
  double EstimateCost() {
    // Force the switch to the online estimator even if the buffer is not yet
    // full, so that `Mean()` below refers to initialized online state.
    if (!IsOnline()) {
      WarmUpOnline();
      HWY_DASSERT(IsOnline());
    }
    return Mean();
  }

  // Multiplex online state into values_ to allow higher `kMaxValues`.
  // Public for inspection in tests. Do not use directly.
  double& M1() { return values_[0]; }  // Moments for variance.
  double& M2() { return values_[1]; }
  double& Mean() { return values_[2]; }  // Exponential smoothing.
  double& Stddev() { return values_[3]; }
  double& Lower() { return values_[4]; }  // Winsorization thresholds.
  double& Upper() { return values_[5]; }

 private:
  // Sorts the first `n` elements of `to_sort` in-place and returns their
  // median (average of the two middle elements for even `n`).
  static double Median(double* to_sort, size_t n) {
    HWY_DASSERT(n >= 2);

#ifdef HWY_AUTOTUNE_STDSORT
    std::sort(to_sort, to_sort + n);
#else
    // F64 is supported everywhere except Armv7.
#if !HWY_ARCH_ARM_V7
    VQSort(to_sort, n, SortAscending());
#else
    // Values are known to be finite and non-negative, hence sorting as U64 is
    // equivalent.
    VQSort(reinterpret_cast<uint64_t*>(to_sort), n, SortAscending());
#endif
#endif

    if (n & 1) return to_sort[n / 2];
    // Even length: average of two middle elements.
    return (to_sort[n / 2] + to_sort[n / 2 - 1]) * 0.5;
  }

  // Median absolute deviation: a robust estimator of spread, tolerating up to
  // almost half the sample being outliers.
  static double MAD(const double* values, size_t n, const double median) {
    double abs_dev[kMaxValues];
    for (size_t i = 0; i < n; ++i) {
      abs_dev[i] = ScalarAbs(values[i] - median);
    }
    return Median(abs_dev, n);
  }

  // If `num_values_` is large enough, sorts and discards outliers: either via
  // MAD, or if too many values are equal, by trimming according to skewness.
  void RemoveOutliers() {
    if (num_values_ < 3) return;  // Not enough to discard two.
    HWY_DASSERT(num_values_ <= kMaxValues);

    // Given the noise level in `auto_tune_test`, it can happen that 1/4 of
    // the sample is an outlier *in either direction*. Use median absolute
    // deviation, which is robust to almost half of the sample being outliers.
    const double median = Median(values_, num_values_);  // sorts in-place.
    const double mad = MAD(values_, num_values_, median);
    // At least half the sample is equal.
    if (mad == 0.0) {
      // Estimate skewness to decide which side to trim more.
      const double skewness =
          (values_[num_values_ - 1] - median) - (median - values_[0]);

      const size_t trim = HWY_MAX(num_values_ / 2, size_t{2});
      // Right-skewed: trim mostly from the right (small `left`); left-skewed:
      // trim mostly from the left (large `left`).
      const size_t left =
          HWY_MAX(skewness < 0.0 ? trim * 3 / 4 : trim / 4, size_t{1});
      num_values_ -= trim;
      HWY_DASSERT(num_values_ >= 1);
      memmove(values_, values_ + left, num_values_ * sizeof(values_[0]));
      return;
    }

    // Keep only values within 5 MAD of the median.
    const double upper = median + 5.0 * mad;
    const double lower = median - 5.0 * mad;
    size_t right = num_values_ - 1;
    while (values_[right] > upper) --right;
    // Nonzero MAD implies no more than half are equal, so we did not advance
    // beyond the median.
    HWY_DASSERT(right >= num_values_ / 2);

    size_t left = 0;
    while (left < right && values_[left] < lower) ++left;
    HWY_DASSERT(left <= num_values_ / 2);
    num_values_ = right - left + 1;
    memmove(values_, values_ + left, num_values_ * sizeof(values_[0]));
  }

  double SampleMean() const {
    // Only called in non-online phase, but buffer might not be full.
    HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues);
    double sum = 0.0;
    for (size_t i = 0; i < num_values_; ++i) {
      sum += values_[i];
    }
    return sum / static_cast<double>(num_values_);
  }

  // Unbiased estimator for population variance even for small `num_values_`.
  double SampleVariance(double sample_mean) const {
    HWY_DASSERT(sample_mean >= 0.0);  // we checked costs are non-negative.
    // Only called in non-online phase, but buffer might not be full.
    HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues);
    if (HWY_UNLIKELY(num_values_ == 1)) return 0.0;  // prevent divide-by-zero.
    double sum2 = 0.0;
    for (size_t i = 0; i < num_values_; ++i) {
      const double d = values_[i] - sample_mean;
      sum2 += d * d;
    }
    // Bessel's correction: divide by N-1.
    return sum2 / static_cast<double>(num_values_ - 1);
  }

  // True once `OnlineNotify` has been called at least once (via warm-up).
  bool IsOnline() const { return online_n_ > 0.0; }

  void OnlineNotify(double x) {
    // Winsorize: clamp instead of discarding, to preserve sample count.
    x = HWY_MIN(HWY_MAX(Lower(), x), Upper());

    // Welford's online variance estimator.
    // https://media.thinkbrg.com/wp-content/uploads/2020/06/19094655/720_720_McCrary_ImplementingAlgorithms_Whitepaper_20151119_WEB.pdf#page=7.09
    const double n_minus_1 = online_n_;
    online_n_ += 1.0;
    const double d = x - M1();
    const double d_div_n = d / online_n_;
    M1() += d_div_n;
    HWY_DASSERT(M1() >= Lower());
    M2() += d * n_minus_1 * d_div_n;  // d^2 * (N-1)/N
    // HWY_MAX avoids divide-by-zero.
    const double stddev = std::sqrt(M2() / HWY_MAX(1.0, n_minus_1));

    // Exponential smoothing.
    constexpr double kNew = 0.2;  // relatively fast update
    constexpr double kOld = 1.0 - kNew;
    Mean() = M1() * kNew + Mean() * kOld;
    Stddev() = stddev * kNew + Stddev() * kOld;

    // Update thresholds from smoothed mean and stddev to enable recovering
    // from a too narrow initial range due to excessive trimming.
    Lower() = Mean() - 3.5 * Stddev();
    Upper() = Mean() + 3.5 * Stddev();
  }

  // Transitions from the fill phase to the online phase: trims outliers,
  // seeds the online state from the trimmed sample, then replays the sample.
  void WarmUpOnline() {
    RemoveOutliers();

    // Compute and copy before writing to `M1`, which overwrites `values_`!
    const double sample_mean = SampleMean();
    const double sample_variance = SampleVariance(sample_mean);
    double copy[kMaxValues];
    hwy::CopyBytes(values_, copy, num_values_ * sizeof(values_[0]));

    M1() = M2() = 0.0;
    Mean() = sample_mean;
    Stddev() = std::sqrt(sample_variance);
    // For single-value or all-equal sample, widen the range, else we will
    // only accept the same value.
    if (Stddev() == 0.0) Stddev() = Mean() / 2;

    // High tolerance because the distribution is not actually Gaussian, and
    // we trimmed up to *half*, and do not want to reject too many values in
    // the online phase.
    Lower() = Mean() - 4.0 * Stddev();
    Upper() = Mean() + 4.0 * Stddev();
    // Feed copied values into online estimator.
    for (size_t i = 0; i < num_values_; ++i) {
      OnlineNotify(copy[i]);
    }
    HWY_DASSERT(IsOnline());
  }

  size_t num_values_ = 0;  // size of `values_` <= `kMaxValues`
#if SIZE_MAX == 0xFFFFFFFFu
  // Keeps `values_` 8-byte aligned and sizeof == 128 on 32-bit targets.
  HWY_MAYBE_UNUSED uint32_t padding_ = 0;
#endif

  double online_n_ = 0.0;  // number of calls to `OnlineNotify`.

  double values_[kMaxValues];
};
static_assert(sizeof(CostDistribution) == 128, "");

// Implements a counter with wrap-around, plus the ability to skip values.
// O(1) time, O(N) space via doubly-linked list of indices.
class NextWithSkip {
 public:
  NextWithSkip() {}
  // Builds a circular doubly-linked list over indices [0, num).
  explicit NextWithSkip(size_t num) {
    links_.reserve(num);
    for (size_t i = 0; i < num; ++i) {
      links_.emplace_back(i, num);
    }
  }

  // Returns the next non-skipped index after `pos`, with wrap-around.
  // `pos` itself must not have been skipped.
  size_t Next(size_t pos) {
    HWY_DASSERT(pos < links_.size());
    HWY_DASSERT(!links_[pos].IsRemoved());
    return links_[pos].Next();
  }

  // Must not be called for an already skipped position. Ignores an attempt to
  // skip the last remaining position.
  void Skip(size_t pos) {
    HWY_DASSERT(!links_[pos].IsRemoved());  // not already skipped.
    const size_t prev = links_[pos].Prev();
    const size_t next = links_[pos].Next();
    if (prev == pos || next == pos) return;  // last remaining position.
    // Unlink `pos` from the circular list and mark it removed.
    links_[next].SetPrev(prev);
    links_[prev].SetNext(next);
    links_[pos].Remove();
  }

 private:
  // Combine prev/next into one array to improve locality/reduce allocations.
  class Link {
    // Bit-shifts avoid potentially expensive 16-bit loads. Store `next` at
    // the top and `prev` at the bottom for extraction with a single
    // shift/AND. There may be hundreds of configurations, so 8 bits are not
    // enough.
    static constexpr size_t kBits = 14;
    static constexpr size_t kShift = 32 - kBits;
    static constexpr uint32_t kMaxNum = 1u << kBits;

   public:
    Link(size_t pos, size_t num) {
      HWY_DASSERT(num < kMaxNum);
      const size_t prev = pos == 0 ? num - 1 : pos - 1;
      const size_t next = pos == num - 1 ? 0 : pos + 1;
      bits_ =
          (static_cast<uint32_t>(next) << kShift) | static_cast<uint32_t>(prev);
      HWY_DASSERT(Next() == next && Prev() == prev);
      HWY_DASSERT(!IsRemoved());
    }

    // Bit `kBits` sits between the prev (low) and next (high) fields and
    // serves as the 'removed' flag.
    bool IsRemoved() const { return (bits_ & kMaxNum) != 0; }
    void Remove() { bits_ |= kMaxNum; }

    size_t Next() const { return bits_ >> kShift; }
    size_t Prev() const { return bits_ & (kMaxNum - 1); }

    void SetNext(size_t next) {
      HWY_DASSERT(next < kMaxNum);
      bits_ &= (~0u >> kBits);  // clear old next
      bits_ |= static_cast<uint32_t>(next) << kShift;
      HWY_DASSERT(Next() == next);
      HWY_DASSERT(!IsRemoved());
    }
    void SetPrev(size_t prev) {
      HWY_DASSERT(prev < kMaxNum);
      bits_ &= ~(kMaxNum - 1);  // clear old prev
      bits_ |= static_cast<uint32_t>(prev);
      HWY_DASSERT(Prev() == prev);
      HWY_DASSERT(!IsRemoved());
    }

   private:
    uint32_t bits_;
  };
  std::vector<Link> links_;
};

// State machine for choosing at runtime the lowest-cost `Config`, which is
// typically a struct containing
multiple parameters. For an introduction, see 359 // "Auto-Tuning and Performance Portability on Heterogeneous Hardware". 360 // 361 // **Which parameters** 362 // Note that simple parameters such as the L2 cache size can be directly queried 363 // via `hwy/contrib/thread_pool/topology.h`. Difficult to predict parameters 364 // such as task granularity are more appropriate for auto-tuning. We also 365 // suggest that at least some parameters should also be 'algorithm variants' 366 // such as parallel vs. serial, or 2D tiling vs. 1D striping. 367 // 368 // **Search strategy** 369 // To guarantee the optimal result, we use exhaustive search, which is suitable 370 // for around 10 parameters and a few hundred combinations of 'candidate' 371 // configurations. 372 // 373 // **How to generate candidates** 374 // To keep this framework simple and generic, applications enumerate the search 375 // space and pass the list of all feasible candidates to `SetCandidates` before 376 // the first call to `NextConfig`. Applications should prune the space as much 377 // as possible, e.g. by upper-bounding parameters based on the known cache 378 // sizes, and applying constraints such as one being a multiple of another. 379 // 380 // **Usage** 381 // Applications typically conditionally branch to the code implementing the 382 // configuration returned by `NextConfig`. They measure the cost of running it 383 // and pass that to `NotifyCost`. Branching avoids the complexity and 384 // opaqueness of a JIT. The number of branches can be reduced (at the cost of 385 // code size) by inlining low-level decisions into larger code regions, e.g. by 386 // hoisting them outside hot loops. 387 // 388 // **What is cost** 389 // Cost is an arbitrary `uint64_t`, with lower values being better. Most 390 // applications will use the elapsed time. If the tasks being tuned are short, 391 // it is important to use a high-resolution timer such as `hwy/timer.h`. 
Energy 392 // may also be useful [https://www.osti.gov/servlets/purl/1361296]. 393 // 394 // **Online vs. offline** 395 // Although applications can auto-tune once, offline, it may be difficult to 396 // ensure the stored configuration still applies to the current circumstances. 397 // Thus we recommend online auto-tuning, re-discovering the configuration on 398 // each run. We assume the overhead of bookkeeping and measuring cost is 399 // negligible relative to the actual work. The cost of auto-tuning is then that 400 // of running sub-optimal configurations. Assuming the best configuration is 401 // better than baseline, and the work is performed many thousands of times, the 402 // cost is outweighed by the benefits. 403 // 404 // **kMinSamples** 405 // To further reduce overhead, after `kMinSamples` rounds (= measurements of 406 // each configuration) we start excluding configurations from further 407 // measurements if they are sufficiently worse than the current best. 408 // `kMinSamples` can be several dozen when the tasks being tuned take a few 409 // microseconds. Even for longer tasks, it should be at least 2 for some noise 410 // tolerance. After this, there are another `kMinSamples / 2 + 1` rounds before 411 // declaring the winner. 412 template <typename Config, size_t kMinSamples = 2> 413 class AutoTune { 414 public: 415 // Returns non-null best configuration if auto-tuning has already finished. 416 // Otherwise, callers continue calling `NextConfig` and `NotifyCost`. 417 // Points into `Candidates()`. 418 const Config* Best() const { return best_; } 419 420 // If false, caller must call `SetCandidates` before `NextConfig`. 421 // NOTE: also called after Best() is non-null. 422 bool HasCandidates() const { return !candidates_.empty(); } 423 424 // WARNING: invalidates `Best()`, do not call if that is non-null. 
425 void SetCandidates(std::vector<Config> candidates) { 426 HWY_DASSERT(!Best() && !HasCandidates()); 427 candidates_.swap(candidates); 428 HWY_DASSERT(HasCandidates()); 429 costs_.resize(candidates_.size()); 430 list_ = NextWithSkip(candidates_.size()); 431 } 432 433 // Typically called after Best() is non-null to compare all candidates' costs. 434 Span<const Config> Candidates() const { 435 HWY_DASSERT(HasCandidates()); 436 return Span<const Config>(candidates_.data(), candidates_.size()); 437 } 438 Span<CostDistribution> Costs() { 439 return Span<CostDistribution>(costs_.data(), costs_.size()); 440 } 441 442 // Returns the current `Config` to measure. 443 const Config& NextConfig() const { 444 HWY_DASSERT(HasCandidates()); 445 return candidates_[config_idx_]; 446 } 447 448 // O(1) except at the end of each round, which is O(N). 449 void NotifyCost(uint64_t cost) { 450 HWY_DASSERT(!Best() && HasCandidates()); 451 452 costs_[config_idx_].Notify(static_cast<double>(cost)); 453 // Save now before we update `config_idx_`. 454 const size_t my_idx = config_idx_; 455 // Only retrieve once we have enough samples, otherwise, we switch to 456 // online variance before the buffer is populated. 457 const double my_cost = rounds_complete_ >= kMinSamples 458 ? costs_[config_idx_].EstimateCost() 459 : 0.0; 460 461 // Advance to next non-skipped config with wrap-around. This decorrelates 462 // measurements by not immediately re-measuring the same config. 463 config_idx_ = list_.Next(config_idx_); 464 // Might still equal `my_idx` if this is the only non-skipped config. 465 466 // Disqualify from future `NextConfig` if cost was too far beyond the 467 // current best. This reduces the number of measurements, while tolerating 468 // noise in the first few measurements. Must happen after advancing. 469 if (my_cost > skip_if_above_) { 470 list_.Skip(my_idx); 471 } 472 473 // Wrap-around indicates the round is complete. 
474 if (HWY_UNLIKELY(config_idx_ <= my_idx)) { 475 ++rounds_complete_; 476 477 // Enough samples for stable estimates: update the thresholds. 478 if (rounds_complete_ >= kMinSamples) { 479 double best_cost = HighestValue<double>(); 480 size_t idx_min = 0; 481 for (size_t i = 0; i < candidates_.size(); ++i) { 482 const double estimate = costs_[i].EstimateCost(); 483 if (estimate < best_cost) { 484 best_cost = estimate; 485 idx_min = i; 486 } 487 } 488 skip_if_above_ = best_cost * 1.25; 489 490 // After sufficient rounds, declare the winner. 491 if (HWY_UNLIKELY(rounds_complete_ == 3 * kMinSamples / 2 + 1)) { 492 best_ = &candidates_[idx_min]; 493 HWY_DASSERT(Best()); 494 } 495 } 496 } 497 } 498 499 // Avoid printing during the first few rounds, because those might be noisy 500 // and not yet skipped. 501 bool ShouldPrint() { return rounds_complete_ > kMinSamples; } 502 503 private: 504 const Config* best_ = nullptr; 505 std::vector<Config> candidates_; 506 std::vector<CostDistribution> costs_; // one per candidate 507 size_t config_idx_ = 0; // [0, candidates_.size()) 508 NextWithSkip list_; 509 size_t rounds_complete_ = 0; 510 511 double skip_if_above_ = 0.0; 512 }; 513 514 } // namespace hwy 515 516 #endif // HIGHWAY_HWY_AUTO_TUNE_H_