// bit_set.h (11984B)
1 // Copyright 2024 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #ifndef HIGHWAY_HWY_BIT_SET_H_ 17 #define HIGHWAY_HWY_BIT_SET_H_ 18 19 // Various BitSet for 64, up to 4096, or any number of bits. 20 21 #include <stddef.h> 22 23 #include <atomic> 24 25 #include "hwy/base.h" 26 27 namespace hwy { 28 29 // 64-bit specialization of `std::bitset`, which lacks `Foreach`. 30 class BitSet64 { 31 public: 32 constexpr size_t MaxSize() const { return 64; } 33 34 // No harm if `i` is already set. 35 void Set(size_t i) { 36 HWY_DASSERT(i < 64); 37 bits_ |= (1ULL << i); 38 HWY_DASSERT(Get(i)); 39 } 40 41 // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does 42 // not clear any existing bits. 43 void SetNonzeroBitsFrom64(uint64_t bits) { bits_ |= bits; } 44 45 void Clear(size_t i) { 46 HWY_DASSERT(i < 64); 47 bits_ &= ~(1ULL << i); 48 } 49 50 bool Get(size_t i) const { 51 HWY_DASSERT(i < 64); 52 return (bits_ & (1ULL << i)) != 0; 53 } 54 55 // Returns true if Get(i) would return true for any i in [0, 64). 56 bool Any() const { return bits_ != 0; } 57 58 // Returns true if Get(i) would return true for all i in [0, 64). 59 bool All() const { return bits_ == ~uint64_t{0}; } 60 61 // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! 
62 size_t First() const { 63 HWY_DASSERT(Any()); 64 return Num0BitsBelowLS1Bit_Nonzero64(bits_); 65 } 66 67 // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`! 68 size_t First0() const { 69 HWY_DASSERT(!All()); 70 return Num0BitsBelowLS1Bit_Nonzero64(~bits_); 71 } 72 73 // Returns uint64_t(Get(i)) << i for i in [0, 64). 74 uint64_t Get64() const { return bits_; } 75 76 // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify 77 // the set, but the current Foreach call is unaffected. 78 template <class Func> 79 void Foreach(const Func& func) const { 80 uint64_t remaining_bits = bits_; 81 while (remaining_bits != 0) { 82 const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits); 83 remaining_bits &= remaining_bits - 1; // clear LSB 84 func(i); 85 } 86 } 87 88 size_t Count() const { return PopCount(bits_); } 89 90 private: 91 uint64_t bits_ = 0; 92 }; 93 94 // Any number of bits, flat array. 95 template <size_t kMaxSize> 96 class BitSet { 97 static_assert(kMaxSize != 0, "BitSet requires non-zero size"); 98 99 public: 100 constexpr size_t MaxSize() const { return kMaxSize; } 101 102 // No harm if `i` is already set. 103 void Set(size_t i) { 104 HWY_DASSERT(i < kMaxSize); 105 const size_t idx = i / 64; 106 const size_t mod = i % 64; 107 bits_[idx].Set(mod); 108 } 109 110 void Clear(size_t i) { 111 HWY_DASSERT(i < kMaxSize); 112 const size_t idx = i / 64; 113 const size_t mod = i % 64; 114 bits_[idx].Clear(mod); 115 HWY_DASSERT(!Get(i)); 116 } 117 118 bool Get(size_t i) const { 119 HWY_DASSERT(i < kMaxSize); 120 const size_t idx = i / 64; 121 const size_t mod = i % 64; 122 return bits_[idx].Get(mod); 123 } 124 125 // Returns true if Get(i) would return true for any i in [0, kMaxSize). 126 bool Any() const { 127 for (const BitSet64& bits : bits_) { 128 if (bits.Any()) return true; 129 } 130 return false; 131 } 132 133 // Returns true if Get(i) would return true for all i in [0, kMaxSize). 
134 bool All() const { 135 for (size_t idx = 0; idx < kNum64 - 1; ++idx) { 136 if (!bits_[idx].All()) return false; 137 } 138 139 constexpr size_t kRemainder = kMaxSize % 64; 140 if (kRemainder == 0) { 141 return bits_[kNum64 - 1].All(); 142 } 143 return bits_[kNum64 - 1].Count() == kRemainder; 144 } 145 146 // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! 147 size_t First() const { 148 HWY_DASSERT(Any()); 149 for (size_t idx = 0;; ++idx) { 150 HWY_DASSERT(idx < kNum64); 151 if (bits_[idx].Any()) return idx * 64 + bits_[idx].First(); 152 } 153 } 154 155 // Returns lowest i such that `!Get(i)`. Caller must first ensure `All()`! 156 size_t First0() const { 157 HWY_DASSERT(!All()); 158 for (size_t idx = 0;; ++idx) { 159 HWY_DASSERT(idx < kNum64); 160 if (!bits_[idx].All()) { 161 const size_t first0 = idx * 64 + bits_[idx].First0(); 162 HWY_DASSERT(first0 < kMaxSize); 163 return first0; 164 } 165 } 166 } 167 168 // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify 169 // the set, but the current Foreach call is only affected if changing one of 170 // the not yet visited BitSet64. 171 template <class Func> 172 void Foreach(const Func& func) const { 173 for (size_t idx = 0; idx < kNum64; ++idx) { 174 bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); }); 175 } 176 } 177 178 size_t Count() const { 179 size_t total = 0; 180 for (const BitSet64& bits : bits_) { 181 total += bits.Count(); 182 } 183 return total; 184 } 185 186 private: 187 static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64}); 188 BitSet64 bits_[kNum64]; 189 }; 190 191 // Any number of bits, flat array, atomic updates to the u64. 192 template <size_t kMaxSize> 193 class AtomicBitSet { 194 static_assert(kMaxSize != 0, "AtomicBitSet requires non-zero size"); 195 196 // Bits may signal something to other threads, hence relaxed is insufficient. 197 // Acq/Rel ensures a happens-before relationship. 
198 static constexpr auto kAcq = std::memory_order_acquire; 199 static constexpr auto kRel = std::memory_order_release; 200 201 public: 202 constexpr size_t MaxSize() const { return kMaxSize; } 203 204 // No harm if `i` is already set. 205 void Set(size_t i) { 206 HWY_DASSERT(i < kMaxSize); 207 const size_t idx = i / 64; 208 const size_t mod = i % 64; 209 bits_[idx].fetch_or(1ULL << mod, kRel); 210 } 211 212 void Clear(size_t i) { 213 HWY_DASSERT(i < kMaxSize); 214 const size_t idx = i / 64; 215 const size_t mod = i % 64; 216 bits_[idx].fetch_and(~(1ULL << mod), kRel); 217 HWY_DASSERT(!Get(i)); 218 } 219 220 bool Get(size_t i) const { 221 HWY_DASSERT(i < kMaxSize); 222 const size_t idx = i / 64; 223 const size_t mod = i % 64; 224 return ((bits_[idx].load(kAcq) & (1ULL << mod))) != 0; 225 } 226 227 // Returns true if Get(i) would return true for any i in [0, kMaxSize). 228 bool Any() const { 229 for (const std::atomic<uint64_t>& bits : bits_) { 230 if (bits.load(kAcq)) return true; 231 } 232 return false; 233 } 234 235 // Returns true if Get(i) would return true for all i in [0, kMaxSize). 236 bool All() const { 237 for (size_t idx = 0; idx < kNum64 - 1; ++idx) { 238 if (bits_[idx].load(kAcq) != ~uint64_t{0}) return false; 239 } 240 241 constexpr size_t kRemainder = kMaxSize % 64; 242 const uint64_t last_bits = bits_[kNum64 - 1].load(kAcq); 243 if (kRemainder == 0) { 244 return last_bits == ~uint64_t{0}; 245 } 246 return PopCount(last_bits) == kRemainder; 247 } 248 249 // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! 250 size_t First() const { 251 HWY_DASSERT(Any()); 252 for (size_t idx = 0;; ++idx) { 253 HWY_DASSERT(idx < kNum64); 254 const uint64_t bits = bits_[idx].load(kAcq); 255 if (bits != 0) { 256 return idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits); 257 } 258 } 259 } 260 261 // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`! 
262 size_t First0() const { 263 HWY_DASSERT(!All()); 264 for (size_t idx = 0;; ++idx) { 265 HWY_DASSERT(idx < kNum64); 266 const uint64_t inv_bits = ~bits_[idx].load(kAcq); 267 if (inv_bits != 0) { 268 const size_t first0 = 269 idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(inv_bits); 270 HWY_DASSERT(first0 < kMaxSize); 271 return first0; 272 } 273 } 274 } 275 276 // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify 277 // the set, but the current Foreach call is only affected if changing one of 278 // the not yet visited uint64_t. 279 template <class Func> 280 void Foreach(const Func& func) const { 281 for (size_t idx = 0; idx < kNum64; ++idx) { 282 uint64_t remaining_bits = bits_[idx].load(kAcq); 283 while (remaining_bits != 0) { 284 const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits); 285 remaining_bits &= remaining_bits - 1; // clear LSB 286 func(idx * 64 + i); 287 } 288 } 289 } 290 291 size_t Count() const { 292 size_t total = 0; 293 for (const std::atomic<uint64_t>& bits : bits_) { 294 total += PopCount(bits.load(kAcq)); 295 } 296 return total; 297 } 298 299 private: 300 static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64}); 301 std::atomic<uint64_t> bits_[kNum64] = {}; 302 }; 303 304 // Two-level bitset for up to `kMaxSize` <= 4096 values. The iterators 305 // (`Any/First/Foreach/Count`) are more efficient than `BitSet` for sparse sets. 306 // This comes at the cost of slightly slower mutators (`Set/Clear`). 307 template <size_t kMaxSize = 4096> 308 class BitSet4096 { 309 static_assert(kMaxSize != 0, "BitSet4096 requires non-zero size"); 310 311 public: 312 constexpr size_t MaxSize() const { return kMaxSize; } 313 314 // No harm if `i` is already set. 315 void Set(size_t i) { 316 HWY_DASSERT(i < kMaxSize); 317 const size_t idx = i / 64; 318 const size_t mod = i % 64; 319 bits_[idx].Set(mod); 320 nonzero_.Set(idx); 321 HWY_DASSERT(Get(i)); 322 } 323 324 // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. 
This does 325 // not clear any existing bits. 326 void SetNonzeroBitsFrom64(uint64_t bits) { 327 bits_[0].SetNonzeroBitsFrom64(bits); 328 if (bits) nonzero_.Set(0); 329 } 330 331 void Clear(size_t i) { 332 HWY_DASSERT(i < kMaxSize); 333 const size_t idx = i / 64; 334 const size_t mod = i % 64; 335 bits_[idx].Clear(mod); 336 if (!bits_[idx].Any()) { 337 nonzero_.Clear(idx); 338 } 339 HWY_DASSERT(!Get(i)); 340 } 341 342 bool Get(size_t i) const { 343 HWY_DASSERT(i < kMaxSize); 344 const size_t idx = i / 64; 345 const size_t mod = i % 64; 346 return bits_[idx].Get(mod); 347 } 348 349 // Returns true if `Get(i)` would return true for any i in [0, kMaxSize). 350 bool Any() const { return nonzero_.Any(); } 351 352 // Returns true if `Get(i)` would return true for all i in [0, kMaxSize). 353 bool All() const { 354 // Do not check `nonzero_.All()` - that only works if `kMaxSize` is 4096. 355 if (nonzero_.Count() != kNum64) return false; 356 return Count() == kMaxSize; 357 } 358 359 // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! 360 size_t First() const { 361 HWY_DASSERT(Any()); 362 const size_t idx = nonzero_.First(); 363 return idx * 64 + bits_[idx].First(); 364 } 365 366 // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`! 367 size_t First0() const { 368 HWY_DASSERT(!All()); 369 // It is likely not worthwhile to have a separate `BitSet64` for `not_all_`, 370 // hence iterate over all u64. 371 for (size_t idx = 0;; ++idx) { 372 HWY_DASSERT(idx < kNum64); 373 if (!bits_[idx].All()) { 374 const size_t first0 = idx * 64 + bits_[idx].First0(); 375 HWY_DASSERT(first0 < kMaxSize); 376 return first0; 377 } 378 } 379 } 380 381 // Returns uint64_t(Get(i)) << i for i in [0, 64). 382 uint64_t Get64() const { return bits_[0].Get64(); } 383 384 // Calls `func(i)` for each `i` in the set. 
It is safe for `func` to modify 385 // the set, but the current Foreach call is only affected if changing one of 386 // the not yet visited BitSet64 for which Any() is true. 387 template <class Func> 388 void Foreach(const Func& func) const { 389 nonzero_.Foreach([&func, this](size_t idx) { 390 bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); }); 391 }); 392 } 393 394 size_t Count() const { 395 size_t total = 0; 396 nonzero_.Foreach( 397 [&total, this](size_t idx) { total += bits_[idx].Count(); }); 398 return total; 399 } 400 401 private: 402 static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient"); 403 static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64}); 404 BitSet64 nonzero_; 405 BitSet64 bits_[kNum64]; 406 }; 407 408 } // namespace hwy 409 410 #endif // HIGHWAY_HWY_BIT_SET_H_