bit_set_test.cc (7967B)
1 // Copyright 2024 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #include "hwy/bit_set.h" 17 18 #include <stddef.h> 19 #include <stdint.h> 20 #include <stdio.h> 21 22 #include <algorithm> // std::find 23 #include <map> 24 #include <utility> // std::make_pair 25 #include <vector> 26 27 #include "hwy/base.h" 28 #include "hwy/tests/hwy_gtest.h" 29 #include "hwy/tests/test_util-inl.h" 30 #include "hwy/tests/test_util.h" 31 32 namespace hwy { 33 namespace { 34 35 template <class Set> 36 void SmokeTest() { 37 constexpr size_t kMax = Set().MaxSize() - 1; 38 39 Set set; 40 // Defaults to empty. 41 HWY_ASSERT(!set.Any()); 42 HWY_ASSERT(!set.All()); 43 HWY_ASSERT(!set.Get(0)); 44 HWY_ASSERT(!set.Get(kMax)); 45 HWY_ASSERT(set.First0() == 0); 46 set.Foreach( 47 [](size_t i) { HWY_ABORT("Set should be empty but got %zu\n", i); }); 48 HWY_ASSERT(set.Count() == 0); 49 50 // After setting, we can retrieve it. 51 set.Set(kMax); 52 HWY_ASSERT(set.Get(kMax)); 53 HWY_ASSERT(set.Any()); 54 HWY_ASSERT(!set.All()); 55 HWY_ASSERT(set.First() == kMax); 56 HWY_ASSERT(set.First0() == 0); 57 set.Foreach([](size_t i) { HWY_ASSERT(i == kMax); }); 58 HWY_ASSERT(set.Count() == 1); 59 60 // After clearing, it is empty again. 61 set.Clear(kMax); 62 set.Clear(0); // was not set 63 HWY_ASSERT(!set.Get(0)); 64 HWY_ASSERT(!set.Get(kMax)); 65 HWY_ASSERT(!set.Any()); 66 HWY_ASSERT(!set.All()); 67 HWY_ASSERT(set.First0() == 0); 68 set.Foreach( 69 [](size_t i) { HWY_ABORT("Set should be empty but got %zu\n", i); }); 70 HWY_ASSERT(set.Count() == 0); 71 } 72 73 TEST(BitSetTest, SmokeTestSet64) { SmokeTest<BitSet64>(); } 74 TEST(BitSetTest, SmokeTestSet) { SmokeTest<BitSet<320>>(); } 75 TEST(BitSetTest, SmokeTestAtomicSet) { SmokeTest<AtomicBitSet<400>>(); } 76 TEST(BitSetTest, SmokeTestSet4096) { SmokeTest<BitSet4096<>>(); } 77 78 template <class Set> 79 void TestSetNonzeroBitsFrom64() { 80 constexpr size_t kMin = 0; 81 Set set; 82 set.SetNonzeroBitsFrom64(1ull << kMin); 83 HWY_ASSERT(set.Any()); 84 HWY_ASSERT(!set.All()); 85 HWY_ASSERT(set.Get(kMin)); 86 HWY_ASSERT(set.First() == kMin); 87 HWY_ASSERT(set.First0() == kMin + 1); 88 set.Foreach([](size_t i) { HWY_ASSERT(i == kMin); }); 89 HWY_ASSERT(set.Count() == 1); 90 91 set.SetNonzeroBitsFrom64(0x70ULL); 92 HWY_ASSERT(set.Get(kMin) && set.Get(4) && set.Get(5) && set.Get(6)); 93 HWY_ASSERT(set.Any()); 94 HWY_ASSERT(!set.All()); 95 HWY_ASSERT(set.First() == kMin); // does not clear existing bits 96 HWY_ASSERT(set.First0() == kMin + 1); 97 set.Foreach([](size_t i) { HWY_ASSERT(i == kMin || (4 <= i && i <= 6)); }); 98 HWY_ASSERT(set.Count() == 4); 99 } 100 101 TEST(BitSetTest, TestSetNonzeroBits64) { TestSetNonzeroBitsFrom64<BitSet64>(); } 102 TEST(BitSetTest, TestSetNonzeroBits4096) { 103 TestSetNonzeroBitsFrom64<BitSet4096<>>(); 104 } 105 106 // Reference implementation using map (for sparse `BitSet4096`) and vector for 107 // random choice of elements. 108 class SlowSet { 109 public: 110 // Inserting multiple times is a no-op. 111 void Set(size_t i) { 112 const auto ib = idx_for_i_.insert(std::make_pair(i, vec_.size())); 113 if (ib.second) { // inserted 114 vec_.push_back(i); 115 HWY_ASSERT(idx_for_i_.size() == vec_.size()); 116 } else { 117 // Already have `i` and it can be found at the stored index. 118 HWY_ASSERT(ib.first->first == i); 119 const size_t idx = ib.first->second; 120 HWY_ASSERT(vec_[idx] == i); 121 } 122 HWY_ASSERT(Get(i)); 123 } 124 125 bool Get(size_t i) const { 126 const auto it = idx_for_i_.find(i); 127 if (it == idx_for_i_.end()) { 128 HWY_ASSERT(std::find(vec_.begin(), vec_.end(), i) == vec_.end()); 129 return false; 130 } 131 HWY_ASSERT(vec_[it->second] == i); 132 return true; 133 } 134 135 void Clear(size_t i) { 136 if (!Get(i)) return; 137 const size_t idx = idx_for_i_[i]; 138 idx_for_i_.erase(i); 139 // Move last into gap, unless it was equal to `i`. 140 const size_t last = vec_.back(); 141 vec_.pop_back(); 142 if (last == i) { 143 HWY_ASSERT(idx == vec_.size()); // was the last item 144 } else { 145 HWY_ASSERT(vec_[idx] == i); 146 vec_[idx] = last; 147 idx_for_i_[last] = idx; 148 HWY_ASSERT(Get(last)); // can still find `last` 149 } 150 HWY_ASSERT(!Get(i)); 151 } 152 153 size_t Count() const { 154 HWY_ASSERT(idx_for_i_.size() == vec_.size()); 155 return vec_.size(); 156 } 157 158 // Must not call if Count() == 0. 159 size_t RandomChoice(RandomState& rng) const { 160 HWY_ASSERT(Count() != 0); 161 const size_t idx = static_cast<size_t>(hwy::Random32(&rng)) % vec_.size(); 162 return vec_[idx]; 163 } 164 165 template <class Set> 166 void CheckSame(const Set& set) { 167 HWY_ASSERT(set.Any() == (set.Count() != 0)); 168 HWY_ASSERT(set.All() == (set.Count() == set.MaxSize())); 169 HWY_ASSERT(Count() == set.Count()); 170 // Everything set has, we also have. 171 set.Foreach([this](size_t i) { HWY_ASSERT(Get(i)); }); 172 // Everything we have, set also has. 173 std::for_each(vec_.begin(), vec_.end(), 174 [&set](size_t i) { HWY_ASSERT(set.Get(i)); }); 175 // First matches first in the map 176 if (set.Any()) { 177 HWY_ASSERT(set.First() == idx_for_i_.begin()->first); 178 } 179 if (!set.All()) { 180 const size_t idx0 = set.First0(); 181 HWY_ASSERT(idx0 < set.MaxSize()); 182 HWY_ASSERT(!set.Get(idx0)); 183 HWY_ASSERT(!Get(idx0)); 184 } 185 } 186 187 private: 188 std::vector<size_t> vec_; 189 std::map<size_t, size_t> idx_for_i_; 190 }; 191 192 template <class Set> 193 void TestSetWithGrowProb(uint64_t grow_prob) { 194 constexpr uint32_t max_size = static_cast<uint32_t>(Set().MaxSize()); 195 RandomState rng; 196 197 // Multiple independent random tests: 198 for (size_t rep = 0; rep < AdjustedReps(100); ++rep) { 199 Set set; 200 SlowSet slow_set; 201 // Mutate sets via random walk and ensure they are the same afterwards. 202 for (size_t iter = 0; iter < AdjustedReps(1000); ++iter) { 203 const uint64_t bits = (Random64(&rng) >> 10) & 0x3FF; 204 if (bits > 980 && slow_set.Count() != 0) { 205 // Small chance of reinsertion: already present, unchanged after. 206 const size_t i = slow_set.RandomChoice(rng); 207 const size_t count = set.Count(); 208 HWY_ASSERT(set.Get(i)); 209 slow_set.Set(i); 210 set.Set(i); 211 HWY_ASSERT(set.Get(i)); 212 HWY_ASSERT(count == set.Count()); 213 } else if (bits < grow_prob) { 214 // Set random value; no harm if already set. 215 const size_t i = static_cast<size_t>(Random32(&rng) % max_size); 216 slow_set.Set(i); 217 set.Set(i); 218 HWY_ASSERT(set.Get(i)); 219 } else if (slow_set.Count() != 0) { 220 // Remove existing item. 221 const size_t i = slow_set.RandomChoice(rng); 222 const size_t count = set.Count(); 223 HWY_ASSERT(set.Get(i)); 224 slow_set.Clear(i); 225 set.Clear(i); 226 HWY_ASSERT(!set.Get(i)); 227 HWY_ASSERT(count == set.Count() + 1); 228 } 229 } 230 slow_set.CheckSame(set); 231 } 232 } 233 234 template <class Set> 235 void TestSetRandom() { 236 // Lower probability of growth so that the set is often nearly empty. 237 TestSetWithGrowProb<Set>(400); 238 239 TestSetWithGrowProb<Set>(600); 240 } 241 242 TEST(BitSetTest, TestSet64) { TestSetRandom<BitSet64>(); } 243 TEST(BitSetTest, TestSet41) { TestSetRandom<BitSet<41>>(); } 244 TEST(BitSetTest, TestSet) { TestSetRandom<BitSet<199>>(); } 245 // One partial u64 246 TEST(BitSetTest, TestAtomicSet32) { TestSetRandom<AtomicBitSet<32>>(); } 247 // 3 whole u64 248 TEST(BitSetTest, TestAtomicSet192) { TestSetRandom<AtomicBitSet<192>>(); } 249 TEST(BitSetTest, TestSet3000) { TestSetRandom<BitSet4096<3000>>(); } 250 TEST(BitSetTest, TestSet4096) { TestSetRandom<BitSet4096<>>(); } 251 252 } // namespace 253 } // namespace hwy 254 255 HWY_TEST_MAIN();