tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

thread_pool_test.cc (16609B)


      1 // Copyright 2023 Google LLC
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 // Modified from BSD-licensed code
     17 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
     18 // See https://github.com/libjxl/libjxl/blob/main/LICENSE.
     19 
     20 #include "hwy/contrib/thread_pool/thread_pool.h"
     21 
     22 #include <math.h>  // sqrtf
     23 #include <stddef.h>
     24 #include <stdint.h>
     25 #include <stdio.h>
     26 
     27 #include <atomic>
     28 #include <thread>  // NOLINT
     29 #include <vector>
     30 
     31 #include "hwy/base.h"  // PopCount
     32 #include "hwy/contrib/thread_pool/spin.h"
     33 #include "hwy/contrib/thread_pool/topology.h"
     34 #include "hwy/tests/hwy_gtest.h"
     35 #include "hwy/tests/test_util-inl.h"  // AdjustedReps
     36 
     37 namespace hwy {
     38 namespace pool {
     39 namespace {
     40 
     41 TEST(ThreadPoolTest, TestCoprime) {
     42  // 1 is coprime with anything
     43  for (uint32_t i = 1; i < 500; ++i) {
     44    HWY_ASSERT(ShuffledIota::CoprimeNonzero(1, i));
     45    HWY_ASSERT(ShuffledIota::CoprimeNonzero(i, 1));
     46  }
     47 
     48  // Powers of two >= 2 are not coprime
     49  for (size_t i = 1; i < 20; ++i) {
     50    for (size_t j = 1; j < 20; ++j) {
     51      HWY_ASSERT(!ShuffledIota::CoprimeNonzero(1u << i, 1u << j));
     52    }
     53  }
     54 
     55  // 2^x and 2^x +/- 1 are coprime
     56  for (size_t i = 1; i < 30; ++i) {
     57    const uint32_t pow2 = 1u << i;
     58    HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2, pow2 + 1));
     59    HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2, pow2 - 1));
     60    HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2 + 1, pow2));
     61    HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2 - 1, pow2));
     62  }
     63 
     64  // Random number x * random y (both >= 2) is not co-prime with x nor y.
     65  RandomState rng;
     66  for (size_t i = 1; i < 5000; ++i) {
     67    const uint32_t x = (Random32(&rng) & 0xFFF7) + 2;
     68    const uint32_t y = (Random32(&rng) & 0xFFF7) + 2;
     69    HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x * y, x));
     70    HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x * y, y));
     71    HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x, x * y));
     72    HWY_ASSERT(!ShuffledIota::CoprimeNonzero(y, x * y));
     73  }
     74 
     75  // Primes are all coprime (list from https://oeis.org/A000040)
     76  static constexpr uint32_t primes[] = {
     77      2,   3,   5,   7,   11,  13,  17,  19,  23,  29,  31,  37,  41,  43,  47,
     78      53,  59,  61,  67,  71,  73,  79,  83,  89,  97,  101, 103, 107, 109, 113,
     79      127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197,
     80      199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271};
     81  for (size_t i = 0; i < sizeof(primes) / sizeof(primes[0]); ++i) {
     82    for (size_t j = i + 1; j < sizeof(primes) / sizeof(primes[0]); ++j) {
     83      HWY_ASSERT(ShuffledIota::CoprimeNonzero(primes[i], primes[j]));
     84      HWY_ASSERT(ShuffledIota::CoprimeNonzero(primes[j], primes[i]));
     85    }
     86  }
     87 }
     88 
     89 // Ensures `shuffled` visits [0, size) exactly once starting from `current`.
     90 void VerifyPermutation(uint32_t size, const Divisor64& divisor,
     91                       const ShuffledIota& shuffled, uint32_t current,
     92                       uint32_t* visited) {
     93  for (size_t i = 0; i < size; i++) {
     94    visited[i] = 0;
     95  }
     96 
     97  for (size_t i = 0; i < size; i++) {
     98    ++visited[current];
     99    current = shuffled.Next(current, divisor);
    100  }
    101 
    102  for (size_t i = 0; i < size; i++) {
    103    HWY_ASSERT(visited[i] == 1);
    104  }
    105 }
    106 
    107 // Verifies ShuffledIota generates a permutation of [0, size).
    108 TEST(ThreadPoolTest, TestRandomPermutation) {
    109  constexpr size_t kMaxSize = 40;
    110  uint32_t visited[kMaxSize];
    111 
    112  // Exhaustive enumeration of size and starting point.
    113  for (uint32_t size = 1; size < kMaxSize; ++size) {
    114    const Divisor64 divisor(size);
    115 
    116    const uint32_t coprime = ShuffledIota::FindAnotherCoprime(size, 1);
    117    const ShuffledIota shuffled(coprime);
    118 
    119    for (uint32_t start = 0; start < size; ++start) {
    120      VerifyPermutation(size, divisor, shuffled, start, visited);
    121    }
    122  }
    123 }
    124 
// Verifies multiple ShuffledIota are relatively independent: when `size`
// generators advance in lockstep, their current values should rarely collide.
TEST(ThreadPoolTest, TestMultiplePermutations) {
  constexpr size_t kMaxSize = 40;
  uint32_t coprimes[kMaxSize];
  // One per ShuffledIota; initially the starting value, then its Next().
  uint32_t current[kMaxSize];

  for (uint32_t size = 1; size < kMaxSize; ++size) {
    const Divisor64 divisor(size);

    // Create `size` ShuffledIota instances with unique coprimes.
    std::vector<ShuffledIota> shuffled;
    for (size_t i = 0; i < size; ++i) {
      coprimes[i] = ShuffledIota::FindAnotherCoprime(
          size, static_cast<uint32_t>((i + 1) * 257 + i * 13));
      shuffled.emplace_back(coprimes[i]);
    }

    // ShuffledIota[i] starts at i to match the worker thread use case.
    for (uint32_t i = 0; i < size; ++i) {
      current[i] = i;
    }

    size_t num_bad = 0;                   // steps with excessive collisions
    uint32_t all_visited[kMaxSize] = {0};  // union of values seen in any step

    // For each step, ensure there are few non-unique current[].
    for (size_t step = 0; step < size; ++step) {
      // How many times is each number visited?
      uint32_t visited[kMaxSize] = {0};
      for (size_t i = 0; i < size; ++i) {
        visited[current[i]] += 1;
        all_visited[current[i]] = 1;  // visited at all across all steps?
      }

      // How many numbers are visited multiple times?
      size_t num_contended = 0;
      uint32_t max_contention = 0;
      for (size_t i = 0; i < size; ++i) {
        num_contended += visited[i] > 1;
        max_contention = HWY_MAX(max_contention, visited[i]);
      }

      // Count/print if excessive collisions. Heuristic tolerance: up to
      // ~2*sqrt(size) contended values per step is acceptable.
      const size_t expected =
          static_cast<size_t>(sqrtf(static_cast<float>(size)) * 2.0f);
      if ((num_contended > expected) && (max_contention > 3)) {
        ++num_bad;
        if (true) {  // diagnostics are always printed for a bad step
          fprintf(stderr, "size %u step %zu contended %zu max contention %u\n",
                  size, step, num_contended, max_contention);
          for (size_t i = 0; i < size; ++i) {
            fprintf(stderr, "  %u\n", current[i]);
          }
          fprintf(stderr, "coprimes\n");
          for (size_t i = 0; i < size; ++i) {
            fprintf(stderr, "  %u\n", coprimes[i]);
          }
        }
      }

      // Advance all ShuffledIota generators.
      for (size_t i = 0; i < size; ++i) {
        current[i] = shuffled[i].Next(current[i], divisor);
      }
    }  // step

    // Ensure each task was visited during at least one step.
    for (size_t i = 0; i < size; ++i) {
      HWY_ASSERT(all_visited[i] != 0);
    }

    if (num_bad != 0) {
      fprintf(stderr, "size %u total bad: %zu\n", size, num_bad);
    }
    // Only a small fraction of steps may be bad for any given size.
    HWY_ASSERT(num_bad < kMaxSize / 10);
  }  // size
}
    203 
// Functor for CallWithConfig: blocks the calling thread until the given
// Worker is woken, using the Spin/Wait policies selected by the Config.
class DoWait {
 public:
  explicit DoWait(Worker& worker) : worker_(worker) {}

  // Invoked by CallWithConfig with the concrete Spin and Wait implementations.
  template <class Spin, class Wait>
  void operator()(const Spin& spin, const Wait& wait) const {
    wait.UntilWoken(worker_, spin);
  }

 private:
  Worker& worker_;  // the worker instance to wait on; not owned
};
    216 
// Functor for CallWithConfig: wakes all workers in the array passed to the
// constructor, using the Wait policy selected by the Config.
class DoWakeWorkers {
 public:
  explicit DoWakeWorkers(Worker* workers) : workers_(workers) {}

  template <class Spin, class Wait>
  void operator()(const Spin&, const Wait& wait) const {
    // The epoch is read from workers_[0]; per the callers, index 0 is "main".
    wait.WakeWorkers(workers_, workers_[0].WorkerEpoch());
  }

 private:
  Worker* const workers_;  // array of worker instances; not owned
};
    229 
    230 // Verifies that waiter(s) can be woken by another thread.
    231 TEST(ThreadPoolTest, TestWaiter) {
    232  if (!hwy::HaveThreadingSupport()) return;
    233 
    234  // Not actual threads, but we allocate and loop over this many workers.
    235  for (size_t num_threads = 1; num_threads < 6; ++num_threads) {
    236    const size_t num_workers = 1 + num_threads;
    237    auto storage = hwy::AllocateAligned<uint8_t>(num_workers * sizeof(Worker));
    238    HWY_ASSERT(storage);
    239    const Divisor64 div_workers(num_workers);
    240    Shared& shared = Shared::Get();  // already calls ReserveWorker(0).
    241 
    242    for (WaitType wait_type :
    243         {WaitType::kBlock, WaitType::kSpin1, WaitType::kSpinSeparate}) {
    244      Worker* workers = pool::WorkerLifecycle::Init(
    245          storage.get(), num_threads, PoolWorkerMapping(), div_workers, shared);
    246 
    247      alignas(8) const Config config(SpinType::kPause, wait_type);
    248 
    249      // This thread acts as the "main thread", which will wake the actual main
    250      // and all its worker instances.
    251      std::thread thread(
    252          [&]() { CallWithConfig(config, DoWakeWorkers(workers)); });
    253 
    254      // main is 0
    255      for (size_t worker = 1; worker < num_workers; ++worker) {
    256        CallWithConfig(config, DoWait(workers[1]));
    257      }
    258      thread.join();
    259 
    260      pool::WorkerLifecycle::Destroy(workers, num_workers);
    261    }
    262  }
    263 }
    264 
    265 // Ensures all tasks are run. Similar to TestPool below but without threads.
    266 TEST(ThreadPoolTest, TestTasks) {
    267  for (size_t num_threads = 1; num_threads <= 8; ++num_threads) {
    268    const size_t num_workers = num_threads + 1;
    269    auto storage = hwy::AllocateAligned<uint8_t>(num_workers * sizeof(Worker));
    270    HWY_ASSERT(storage);
    271    const Divisor64 div_workers(num_workers);
    272    Shared& shared = Shared::Get();
    273    Stats stats;
    274    Worker* workers = WorkerLifecycle::Init(
    275        storage.get(), num_threads, PoolWorkerMapping(), div_workers, shared);
    276 
    277    constexpr uint64_t kMaxTasks = 20;
    278    uint64_t mementos[kMaxTasks];  // non-atomic, no threads involved.
    279    for (uint64_t num_tasks = 0; num_tasks < 20; ++num_tasks) {
    280      for (uint64_t begin = 0; begin < AdjustedReps(32); ++begin) {
    281        const uint64_t end = begin + num_tasks;
    282 
    283        ZeroBytes(mementos, sizeof(mementos));
    284        const auto func = [begin, end, &mementos](uint64_t task,
    285                                                  size_t /*worker*/) {
    286          HWY_ASSERT(begin <= task && task < end);
    287 
    288          // Store mementos ensure we visited each task.
    289          mementos[task - begin] = 1000 + task;
    290        };
    291        Tasks tasks;
    292        tasks.Set(begin, end, func);
    293 
    294        Tasks::DivideRangeAmongWorkers(begin, end, div_workers, workers);
    295        // The `tasks < workers` special case requires running by all workers.
    296        for (size_t worker = 0; worker < num_workers; ++worker) {
    297          tasks.WorkerRun(workers + worker, shared, stats);
    298        }
    299 
    300        // Ensure all tasks were run.
    301        for (uint64_t task = begin; task < end; ++task) {
    302          HWY_ASSERT_EQ(1000 + task, mementos[task - begin]);
    303        }
    304      }
    305    }
    306 
    307    WorkerLifecycle::Destroy(workers, num_workers);
    308  }
    309 }
    310 
    311 // Ensures task parameter is in bounds, every parameter is reached,
    312 // pool can be reused (multiple consecutive Run calls), pool can be destroyed
    313 // (joining with its threads), num_threads=0 works (runs on current thread).
    314 TEST(ThreadPoolTest, TestPool) {
    315  if (!hwy::HaveThreadingSupport()) return;
    316 
    317  constexpr uint64_t kMaxTasks = 20;
    318  static std::atomic<uint64_t> mementos[kMaxTasks];
    319  static std::atomic<uint64_t> a_begin;
    320  static std::atomic<uint64_t> a_end;
    321  static std::atomic<uint64_t> a_num_workers;
    322 
    323  // Called by pool; sets mementos and runs a nested but serial Run.
    324  const auto func = [](uint64_t task, size_t worker) {
    325    HWY_ASSERT(worker < a_num_workers.load());
    326    const uint64_t begin = a_begin.load(std::memory_order_acquire);
    327    const uint64_t end = a_end.load(std::memory_order_acquire);
    328 
    329    if (!(begin <= task && task < end)) {
    330      HWY_ABORT("Task %d not in [%d, %d]", static_cast<int>(task),
    331                static_cast<int>(begin), static_cast<int>(end));
    332    }
    333 
    334    // Store mementos ensure we visited each task.
    335    mementos[task - begin].store(1000 + task);
    336 
    337    // Re-entering Run is fine on a 0-worker pool. Note that this must be
    338    // per-thread so that it gets the `global_idx` it is running on.
    339    hwy::ThreadPool inner(0);
    340    inner.Run(begin, end,
    341              [begin, end](uint64_t inner_task, size_t inner_worker) {
    342                HWY_ASSERT(inner_worker == 0);
    343                HWY_ASSERT(begin <= inner_task && inner_task < end);
    344              });
    345  };
    346 
    347  for (size_t num_threads = 0; num_threads <= 6; num_threads += 3) {
    348    hwy::ThreadPool pool(HWY_MIN(ThreadPool::MaxThreads(), num_threads));
    349    a_num_workers.store(pool.NumWorkers());
    350    for (bool spin : {true, false}) {
    351      pool.SetWaitMode(spin ? PoolWaitMode::kSpin : PoolWaitMode::kBlock);
    352 
    353      for (uint64_t num_tasks = 0; num_tasks < kMaxTasks; ++num_tasks) {
    354        for (uint64_t all_begin = 0; all_begin < AdjustedReps(32);
    355             ++all_begin) {
    356          const uint64_t all_end = all_begin + num_tasks;
    357          a_begin.store(all_begin, std::memory_order_release);
    358          a_end.store(all_end, std::memory_order_release);
    359 
    360          for (size_t i = 0; i < kMaxTasks; ++i) {
    361            mementos[i].store(0);
    362          }
    363 
    364          pool.Run(all_begin, all_end, func);
    365 
    366          for (uint64_t task = all_begin; task < all_end; ++task) {
    367            const uint64_t expected = 1000 + task;
    368            const uint64_t actual = mementos[task - all_begin].load();
    369            if (expected != actual) {
    370              HWY_ABORT(
    371                  "threads %zu, tasks %d: task not run, expected %d, got %d\n",
    372                  num_threads, static_cast<int>(num_tasks),
    373                  static_cast<int>(expected), static_cast<int>(actual));
    374            }
    375          }
    376        }
    377      }
    378    }
    379  }
    380 }
    381 
// Debug tsan builds seem to generate incorrect codegen for [&] of atomics, so
// use a pointer to a state object instead.
struct SmallAssignmentState {
  // (Avoid mutex because it may perturb the worker thread scheduling)
  std::atomic<uint64_t> num_tasks{0};    // number of tasks passed to Run
  std::atomic<uint64_t> num_workers{0};  // pool.NumWorkers(), for bound checks
  std::atomic<uint64_t> id_bits{0};      // bit i set iff worker i ran a task
  std::atomic<uint64_t> num_calls{0};    // total task-function invocations
};
    391 
    392 // Verify "thread" parameter when processing few tasks.
    393 TEST(ThreadPoolTest, TestSmallAssignments) {
    394  if (!hwy::HaveThreadingSupport()) return;
    395 
    396  static SmallAssignmentState state;
    397 
    398  for (size_t num_threads :
    399       {size_t{0}, size_t{1}, size_t{3}, size_t{5}, size_t{8}}) {
    400    ThreadPool pool(HWY_MIN(ThreadPool::MaxThreads(), num_threads));
    401    state.num_workers.store(pool.NumWorkers());
    402 
    403    for (size_t mul = 1; mul <= 2; ++mul) {
    404      const size_t num_tasks = pool.NumWorkers() * mul;
    405      state.num_tasks.store(num_tasks);
    406      state.id_bits.store(0);
    407      state.num_calls.store(0);
    408 
    409      pool.Run(0, num_tasks, [](uint64_t task, size_t worker) {
    410        HWY_ASSERT(task < state.num_tasks.load());
    411        HWY_ASSERT(worker < state.num_workers.load());
    412 
    413        state.num_calls.fetch_add(1);
    414 
    415        uint64_t bits = state.id_bits.load();
    416        while (!state.id_bits.compare_exchange_weak(bits,
    417                                                    bits | (1ULL << worker))) {
    418        }
    419      });
    420 
    421      // Correct number of tasks.
    422      const uint64_t actual_calls = state.num_calls.load();
    423      HWY_ASSERT(num_tasks == actual_calls);
    424 
    425      const size_t num_participants = PopCount(state.id_bits.load());
    426      // <= because some workers may not manage to run any tasks.
    427      HWY_ASSERT(num_participants <= pool.NumWorkers());
    428    }
    429  }
    430 }
    431 
    432 struct Counter {
    433  Counter() {
    434    // Suppress "unused-field" warning.
    435    (void)padding;
    436  }
    437  void Assimilate(const Counter& victim) { counter += victim.counter; }
    438  std::atomic<uint64_t> counter{0};
    439  uint64_t padding[15];
    440 };
    441 
    442 // Can switch between any wait mode, and multiple times.
    443 TEST(ThreadPoolTest, TestWaitMode) {
    444  if (!hwy::HaveThreadingSupport()) return;
    445 
    446  ThreadPool pool(9);
    447  RandomState rng;
    448  for (size_t i = 0; i < 100; ++i) {
    449    pool.SetWaitMode((Random32(&rng) & 1u) ? PoolWaitMode::kSpin
    450                                    : PoolWaitMode::kBlock);
    451  }
    452 }
    453 
    454 TEST(ThreadPoolTest, TestCounter) {
    455  if (!hwy::HaveThreadingSupport()) return;
    456 
    457  const size_t kNumThreads = 12;
    458  ThreadPool pool(kNumThreads);
    459  for (PoolWaitMode mode : {PoolWaitMode::kSpin, PoolWaitMode::kBlock}) {
    460    pool.SetWaitMode(mode);
    461    alignas(128) Counter counters[1+kNumThreads];
    462 
    463    const uint64_t kNumTasks = kNumThreads * 19;
    464    pool.Run(0, kNumTasks,
    465             [&counters](const uint64_t task, const size_t worker) {
    466               counters[worker].counter.fetch_add(task);
    467             });
    468 
    469    uint64_t expected = 0;
    470    for (uint64_t i = 0; i < kNumTasks; ++i) {
    471      expected += i;
    472    }
    473 
    474    for (size_t i = 1; i < pool.NumWorkers(); ++i) {
    475      counters[0].Assimilate(counters[i]);
    476    }
    477    HWY_ASSERT_EQ(expected, counters[0].counter.load());
    478  }
    479 }
    480 
    481 }  // namespace
    482 }  // namespace pool
    483 }  // namespace hwy
    484 
    485 HWY_TEST_MAIN();