thread_pool_test.cc (16609B)
1 // Copyright 2023 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // Modified from BSD-licensed code 17 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 18 // See https://github.com/libjxl/libjxl/blob/main/LICENSE. 19 20 #include "hwy/contrib/thread_pool/thread_pool.h" 21 22 #include <math.h> // sqrtf 23 #include <stddef.h> 24 #include <stdint.h> 25 #include <stdio.h> 26 27 #include <atomic> 28 #include <thread> // NOLINT 29 #include <vector> 30 31 #include "hwy/base.h" // PopCount 32 #include "hwy/contrib/thread_pool/spin.h" 33 #include "hwy/contrib/thread_pool/topology.h" 34 #include "hwy/tests/hwy_gtest.h" 35 #include "hwy/tests/test_util-inl.h" // AdjustedReps 36 37 namespace hwy { 38 namespace pool { 39 namespace { 40 41 TEST(ThreadPoolTest, TestCoprime) { 42 // 1 is coprime with anything 43 for (uint32_t i = 1; i < 500; ++i) { 44 HWY_ASSERT(ShuffledIota::CoprimeNonzero(1, i)); 45 HWY_ASSERT(ShuffledIota::CoprimeNonzero(i, 1)); 46 } 47 48 // Powers of two >= 2 are not coprime 49 for (size_t i = 1; i < 20; ++i) { 50 for (size_t j = 1; j < 20; ++j) { 51 HWY_ASSERT(!ShuffledIota::CoprimeNonzero(1u << i, 1u << j)); 52 } 53 } 54 55 // 2^x and 2^x +/- 1 are coprime 56 for (size_t i = 1; i < 30; ++i) { 57 const uint32_t pow2 = 1u << i; 58 HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2, pow2 + 1)); 59 HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2, pow2 - 1)); 60 HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2 + 1, pow2)); 61 HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2 - 1, pow2)); 62 } 63 64 // Random number x * random y (both >= 2) is not co-prime with x nor y. 65 RandomState rng; 66 for (size_t i = 1; i < 5000; ++i) { 67 const uint32_t x = (Random32(&rng) & 0xFFF7) + 2; 68 const uint32_t y = (Random32(&rng) & 0xFFF7) + 2; 69 HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x * y, x)); 70 HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x * y, y)); 71 HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x, x * y)); 72 HWY_ASSERT(!ShuffledIota::CoprimeNonzero(y, x * y)); 73 } 74 75 // Primes are all coprime (list from https://oeis.org/A000040) 76 static constexpr uint32_t primes[] = { 77 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 78 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 79 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 80 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271}; 81 for (size_t i = 0; i < sizeof(primes) / sizeof(primes[0]); ++i) { 82 for (size_t j = i + 1; j < sizeof(primes) / sizeof(primes[0]); ++j) { 83 HWY_ASSERT(ShuffledIota::CoprimeNonzero(primes[i], primes[j])); 84 HWY_ASSERT(ShuffledIota::CoprimeNonzero(primes[j], primes[i])); 85 } 86 } 87 } 88 89 // Ensures `shuffled` visits [0, size) exactly once starting from `current`. 90 void VerifyPermutation(uint32_t size, const Divisor64& divisor, 91 const ShuffledIota& shuffled, uint32_t current, 92 uint32_t* visited) { 93 for (size_t i = 0; i < size; i++) { 94 visited[i] = 0; 95 } 96 97 for (size_t i = 0; i < size; i++) { 98 ++visited[current]; 99 current = shuffled.Next(current, divisor); 100 } 101 102 for (size_t i = 0; i < size; i++) { 103 HWY_ASSERT(visited[i] == 1); 104 } 105 } 106 107 // Verifies ShuffledIota generates a permutation of [0, size). 108 TEST(ThreadPoolTest, TestRandomPermutation) { 109 constexpr size_t kMaxSize = 40; 110 uint32_t visited[kMaxSize]; 111 112 // Exhaustive enumeration of size and starting point. 113 for (uint32_t size = 1; size < kMaxSize; ++size) { 114 const Divisor64 divisor(size); 115 116 const uint32_t coprime = ShuffledIota::FindAnotherCoprime(size, 1); 117 const ShuffledIota shuffled(coprime); 118 119 for (uint32_t start = 0; start < size; ++start) { 120 VerifyPermutation(size, divisor, shuffled, start, visited); 121 } 122 } 123 } 124 125 // Verifies multiple ShuffledIota are relatively independent. 126 TEST(ThreadPoolTest, TestMultiplePermutations) { 127 constexpr size_t kMaxSize = 40; 128 uint32_t coprimes[kMaxSize]; 129 // One per ShuffledIota; initially the starting value, then its Next(). 130 uint32_t current[kMaxSize]; 131 132 for (uint32_t size = 1; size < kMaxSize; ++size) { 133 const Divisor64 divisor(size); 134 135 // Create `size` ShuffledIota instances with unique coprimes. 136 std::vector<ShuffledIota> shuffled; 137 for (size_t i = 0; i < size; ++i) { 138 coprimes[i] = ShuffledIota::FindAnotherCoprime( 139 size, static_cast<uint32_t>((i + 1) * 257 + i * 13)); 140 shuffled.emplace_back(coprimes[i]); 141 } 142 143 // ShuffledIota[i] starts at i to match the worker thread use case. 144 for (uint32_t i = 0; i < size; ++i) { 145 current[i] = i; 146 } 147 148 size_t num_bad = 0; 149 uint32_t all_visited[kMaxSize] = {0}; 150 151 // For each step, ensure there are few non-unique current[]. 152 for (size_t step = 0; step < size; ++step) { 153 // How many times is each number visited? 154 uint32_t visited[kMaxSize] = {0}; 155 for (size_t i = 0; i < size; ++i) { 156 visited[current[i]] += 1; 157 all_visited[current[i]] = 1; // visited at all across all steps? 158 } 159 160 // How many numbers are visited multiple times? 161 size_t num_contended = 0; 162 uint32_t max_contention = 0; 163 for (size_t i = 0; i < size; ++i) { 164 num_contended += visited[i] > 1; 165 max_contention = HWY_MAX(max_contention, visited[i]); 166 } 167 168 // Count/print if excessive collisions. 169 const size_t expected = 170 static_cast<size_t>(sqrtf(static_cast<float>(size)) * 2.0f); 171 if ((num_contended > expected) && (max_contention > 3)) { 172 ++num_bad; 173 if (true) { 174 fprintf(stderr, "size %u step %zu contended %zu max contention %u\n", 175 size, step, num_contended, max_contention); 176 for (size_t i = 0; i < size; ++i) { 177 fprintf(stderr, " %u\n", current[i]); 178 } 179 fprintf(stderr, "coprimes\n"); 180 for (size_t i = 0; i < size; ++i) { 181 fprintf(stderr, " %u\n", coprimes[i]); 182 } 183 } 184 } 185 186 // Advance all ShuffledIota generators. 187 for (size_t i = 0; i < size; ++i) { 188 current[i] = shuffled[i].Next(current[i], divisor); 189 } 190 } // step 191 192 // Ensure each task was visited during at least one step. 193 for (size_t i = 0; i < size; ++i) { 194 HWY_ASSERT(all_visited[i] != 0); 195 } 196 197 if (num_bad != 0) { 198 fprintf(stderr, "size %u total bad: %zu\n", size, num_bad); 199 } 200 HWY_ASSERT(num_bad < kMaxSize / 10); 201 } // size 202 } 203 204 class DoWait { 205 public: 206 explicit DoWait(Worker& worker) : worker_(worker) {} 207 208 template <class Spin, class Wait> 209 void operator()(const Spin& spin, const Wait& wait) const { 210 wait.UntilWoken(worker_, spin); 211 } 212 213 private: 214 Worker& worker_; 215 }; 216 217 class DoWakeWorkers { 218 public: 219 explicit DoWakeWorkers(Worker* workers) : workers_(workers) {} 220 221 template <class Spin, class Wait> 222 void operator()(const Spin&, const Wait& wait) const { 223 wait.WakeWorkers(workers_, workers_[0].WorkerEpoch()); 224 } 225 226 private: 227 Worker* const workers_; 228 }; 229 230 // Verifies that waiter(s) can be woken by another thread. 231 TEST(ThreadPoolTest, TestWaiter) { 232 if (!hwy::HaveThreadingSupport()) return; 233 234 // Not actual threads, but we allocate and loop over this many workers. 235 for (size_t num_threads = 1; num_threads < 6; ++num_threads) { 236 const size_t num_workers = 1 + num_threads; 237 auto storage = hwy::AllocateAligned<uint8_t>(num_workers * sizeof(Worker)); 238 HWY_ASSERT(storage); 239 const Divisor64 div_workers(num_workers); 240 Shared& shared = Shared::Get(); // already calls ReserveWorker(0). 241 242 for (WaitType wait_type : 243 {WaitType::kBlock, WaitType::kSpin1, WaitType::kSpinSeparate}) { 244 Worker* workers = pool::WorkerLifecycle::Init( 245 storage.get(), num_threads, PoolWorkerMapping(), div_workers, shared); 246 247 alignas(8) const Config config(SpinType::kPause, wait_type); 248 249 // This thread acts as the "main thread", which will wake the actual main 250 // and all its worker instances. 251 std::thread thread( 252 [&]() { CallWithConfig(config, DoWakeWorkers(workers)); }); 253 254 // main is 0 255 for (size_t worker = 1; worker < num_workers; ++worker) { 256 CallWithConfig(config, DoWait(workers[1])); 257 } 258 thread.join(); 259 260 pool::WorkerLifecycle::Destroy(workers, num_workers); 261 } 262 } 263 } 264 265 // Ensures all tasks are run. Similar to TestPool below but without threads. 266 TEST(ThreadPoolTest, TestTasks) { 267 for (size_t num_threads = 1; num_threads <= 8; ++num_threads) { 268 const size_t num_workers = num_threads + 1; 269 auto storage = hwy::AllocateAligned<uint8_t>(num_workers * sizeof(Worker)); 270 HWY_ASSERT(storage); 271 const Divisor64 div_workers(num_workers); 272 Shared& shared = Shared::Get(); 273 Stats stats; 274 Worker* workers = WorkerLifecycle::Init( 275 storage.get(), num_threads, PoolWorkerMapping(), div_workers, shared); 276 277 constexpr uint64_t kMaxTasks = 20; 278 uint64_t mementos[kMaxTasks]; // non-atomic, no threads involved. 279 for (uint64_t num_tasks = 0; num_tasks < 20; ++num_tasks) { 280 for (uint64_t begin = 0; begin < AdjustedReps(32); ++begin) { 281 const uint64_t end = begin + num_tasks; 282 283 ZeroBytes(mementos, sizeof(mementos)); 284 const auto func = [begin, end, &mementos](uint64_t task, 285 size_t /*worker*/) { 286 HWY_ASSERT(begin <= task && task < end); 287 288 // Store mementos ensure we visited each task. 289 mementos[task - begin] = 1000 + task; 290 }; 291 Tasks tasks; 292 tasks.Set(begin, end, func); 293 294 Tasks::DivideRangeAmongWorkers(begin, end, div_workers, workers); 295 // The `tasks < workers` special case requires running by all workers. 296 for (size_t worker = 0; worker < num_workers; ++worker) { 297 tasks.WorkerRun(workers + worker, shared, stats); 298 } 299 300 // Ensure all tasks were run. 301 for (uint64_t task = begin; task < end; ++task) { 302 HWY_ASSERT_EQ(1000 + task, mementos[task - begin]); 303 } 304 } 305 } 306 307 WorkerLifecycle::Destroy(workers, num_workers); 308 } 309 } 310 311 // Ensures task parameter is in bounds, every parameter is reached, 312 // pool can be reused (multiple consecutive Run calls), pool can be destroyed 313 // (joining with its threads), num_threads=0 works (runs on current thread). 314 TEST(ThreadPoolTest, TestPool) { 315 if (!hwy::HaveThreadingSupport()) return; 316 317 constexpr uint64_t kMaxTasks = 20; 318 static std::atomic<uint64_t> mementos[kMaxTasks]; 319 static std::atomic<uint64_t> a_begin; 320 static std::atomic<uint64_t> a_end; 321 static std::atomic<uint64_t> a_num_workers; 322 323 // Called by pool; sets mementos and runs a nested but serial Run. 324 const auto func = [](uint64_t task, size_t worker) { 325 HWY_ASSERT(worker < a_num_workers.load()); 326 const uint64_t begin = a_begin.load(std::memory_order_acquire); 327 const uint64_t end = a_end.load(std::memory_order_acquire); 328 329 if (!(begin <= task && task < end)) { 330 HWY_ABORT("Task %d not in [%d, %d]", static_cast<int>(task), 331 static_cast<int>(begin), static_cast<int>(end)); 332 } 333 334 // Store mementos ensure we visited each task. 335 mementos[task - begin].store(1000 + task); 336 337 // Re-entering Run is fine on a 0-worker pool. Note that this must be 338 // per-thread so that it gets the `global_idx` it is running on. 339 hwy::ThreadPool inner(0); 340 inner.Run(begin, end, 341 [begin, end](uint64_t inner_task, size_t inner_worker) { 342 HWY_ASSERT(inner_worker == 0); 343 HWY_ASSERT(begin <= inner_task && inner_task < end); 344 }); 345 }; 346 347 for (size_t num_threads = 0; num_threads <= 6; num_threads += 3) { 348 hwy::ThreadPool pool(HWY_MIN(ThreadPool::MaxThreads(), num_threads)); 349 a_num_workers.store(pool.NumWorkers()); 350 for (bool spin : {true, false}) { 351 pool.SetWaitMode(spin ? PoolWaitMode::kSpin : PoolWaitMode::kBlock); 352 353 for (uint64_t num_tasks = 0; num_tasks < kMaxTasks; ++num_tasks) { 354 for (uint64_t all_begin = 0; all_begin < AdjustedReps(32); 355 ++all_begin) { 356 const uint64_t all_end = all_begin + num_tasks; 357 a_begin.store(all_begin, std::memory_order_release); 358 a_end.store(all_end, std::memory_order_release); 359 360 for (size_t i = 0; i < kMaxTasks; ++i) { 361 mementos[i].store(0); 362 } 363 364 pool.Run(all_begin, all_end, func); 365 366 for (uint64_t task = all_begin; task < all_end; ++task) { 367 const uint64_t expected = 1000 + task; 368 const uint64_t actual = mementos[task - all_begin].load(); 369 if (expected != actual) { 370 HWY_ABORT( 371 "threads %zu, tasks %d: task not run, expected %d, got %d\n", 372 num_threads, static_cast<int>(num_tasks), 373 static_cast<int>(expected), static_cast<int>(actual)); 374 } 375 } 376 } 377 } 378 } 379 } 380 } 381 382 // Debug tsan builds seem to generate incorrect codegen for [&] of atomics, so 383 // use a pointer to a state object instead. 384 struct SmallAssignmentState { 385 // (Avoid mutex because it may perturb the worker thread scheduling) 386 std::atomic<uint64_t> num_tasks{0}; 387 std::atomic<uint64_t> num_workers{0}; 388 std::atomic<uint64_t> id_bits{0}; 389 std::atomic<uint64_t> num_calls{0}; 390 }; 391 392 // Verify "thread" parameter when processing few tasks. 393 TEST(ThreadPoolTest, TestSmallAssignments) { 394 if (!hwy::HaveThreadingSupport()) return; 395 396 static SmallAssignmentState state; 397 398 for (size_t num_threads : 399 {size_t{0}, size_t{1}, size_t{3}, size_t{5}, size_t{8}}) { 400 ThreadPool pool(HWY_MIN(ThreadPool::MaxThreads(), num_threads)); 401 state.num_workers.store(pool.NumWorkers()); 402 403 for (size_t mul = 1; mul <= 2; ++mul) { 404 const size_t num_tasks = pool.NumWorkers() * mul; 405 state.num_tasks.store(num_tasks); 406 state.id_bits.store(0); 407 state.num_calls.store(0); 408 409 pool.Run(0, num_tasks, [](uint64_t task, size_t worker) { 410 HWY_ASSERT(task < state.num_tasks.load()); 411 HWY_ASSERT(worker < state.num_workers.load()); 412 413 state.num_calls.fetch_add(1); 414 415 uint64_t bits = state.id_bits.load(); 416 while (!state.id_bits.compare_exchange_weak(bits, 417 bits | (1ULL << worker))) { 418 } 419 }); 420 421 // Correct number of tasks. 422 const uint64_t actual_calls = state.num_calls.load(); 423 HWY_ASSERT(num_tasks == actual_calls); 424 425 const size_t num_participants = PopCount(state.id_bits.load()); 426 // <= because some workers may not manage to run any tasks. 427 HWY_ASSERT(num_participants <= pool.NumWorkers()); 428 } 429 } 430 } 431 432 struct Counter { 433 Counter() { 434 // Suppress "unused-field" warning. 435 (void)padding; 436 } 437 void Assimilate(const Counter& victim) { counter += victim.counter; } 438 std::atomic<uint64_t> counter{0}; 439 uint64_t padding[15]; 440 }; 441 442 // Can switch between any wait mode, and multiple times. 443 TEST(ThreadPoolTest, TestWaitMode) { 444 if (!hwy::HaveThreadingSupport()) return; 445 446 ThreadPool pool(9); 447 RandomState rng; 448 for (size_t i = 0; i < 100; ++i) { 449 pool.SetWaitMode((Random32(&rng) & 1u) ? PoolWaitMode::kSpin 450 : PoolWaitMode::kBlock); 451 } 452 } 453 454 TEST(ThreadPoolTest, TestCounter) { 455 if (!hwy::HaveThreadingSupport()) return; 456 457 const size_t kNumThreads = 12; 458 ThreadPool pool(kNumThreads); 459 for (PoolWaitMode mode : {PoolWaitMode::kSpin, PoolWaitMode::kBlock}) { 460 pool.SetWaitMode(mode); 461 alignas(128) Counter counters[1+kNumThreads]; 462 463 const uint64_t kNumTasks = kNumThreads * 19; 464 pool.Run(0, kNumTasks, 465 [&counters](const uint64_t task, const size_t worker) { 466 counters[worker].counter.fetch_add(task); 467 }); 468 469 uint64_t expected = 0; 470 for (uint64_t i = 0; i < kNumTasks; ++i) { 471 expected += i; 472 } 473 474 for (size_t i = 1; i < pool.NumWorkers(); ++i) { 475 counters[0].Assimilate(counters[i]); 476 } 477 HWY_ASSERT_EQ(expected, counters[0].counter.load()); 478 } 479 } 480 481 } // namespace 482 } // namespace pool 483 } // namespace hwy 484 485 HWY_TEST_MAIN();