// Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * Copyright (c) 2023, Jelle Raaijmakers <jelle@gmta.nl>
4 *
5 * SPDX-License-Identifier: BSD-2-Clause
6 */
7
8#pragma once
9
10#include <AK/Concepts.h>
11#include <AK/Error.h>
12#include <AK/StdLibExtras.h>
13#include <AK/Traits.h>
14#include <AK/Types.h>
15#include <AK/kmalloc.h>
16
17namespace AK {
18
// Outcome of a set()/try_set() operation (see HashTable::write_value).
enum class HashSetResult {
    InsertedNewEntry,       // no equal entry existed; a new one was created
    ReplacedExistingEntry,  // an equal entry existed and was overwritten
    KeptExistingEntry,      // an equal entry existed and was left untouched
};
24
// Controls what set()/try_set() does when an equal entry is already present.
enum class HashSetExistingEntryBehavior {
    Keep,    // leave the existing entry alone (result: KeptExistingEntry)
    Replace, // overwrite it with the new value (result: ReplacedExistingEntry)
};
29
30// BucketState doubles as both an enum and a probe length value.
31// - Free: empty bucket
32// - Used (implicit, values 1..254): value-1 represents probe length
33// - CalculateLength: same as Used but probe length > 253, so we calculate the actual probe length
enum class BucketState : u8 {
    Free = 0,              // bucket holds no value
    CalculateLength = 255, // used, but probe length > 253: recompute from the hash
};
38
39template<typename HashTableType, typename T, typename BucketType>
40class HashTableIterator {
41 friend HashTableType;
42
43public:
44 bool operator==(HashTableIterator const& other) const { return m_bucket == other.m_bucket; }
45 bool operator!=(HashTableIterator const& other) const { return m_bucket != other.m_bucket; }
46 T& operator*() { return *m_bucket->slot(); }
47 T* operator->() { return m_bucket->slot(); }
48 void operator++() { skip_to_next(); }
49
50private:
51 void skip_to_next()
52 {
53 if (!m_bucket)
54 return;
55 do {
56 ++m_bucket;
57 if (m_bucket == m_end_bucket) {
58 m_bucket = nullptr;
59 return;
60 }
61 } while (m_bucket->state == BucketState::Free);
62 }
63
64 HashTableIterator(BucketType* bucket, BucketType* end_bucket)
65 : m_bucket(bucket)
66 , m_end_bucket(end_bucket)
67 {
68 }
69
70 BucketType* m_bucket { nullptr };
71 BucketType* m_end_bucket { nullptr };
72};
73
// Iterator for ordered hash tables: instead of scanning the bucket array,
// it follows the doubly-linked list threaded through the buckets, so
// iteration visits values in insertion order.
template<typename OrderedHashTableType, typename T, typename BucketType>
class OrderedHashTableIterator {
    friend OrderedHashTableType;

public:
    bool operator==(OrderedHashTableIterator const& other) const { return m_bucket == other.m_bucket; }
    bool operator!=(OrderedHashTableIterator const& other) const { return m_bucket != other.m_bucket; }
    T& operator*() { return *m_bucket->slot(); }
    T* operator->() { return m_bucket->slot(); }
    // The tail bucket's `next` is null (buckets are zero-initialized), so
    // incrementing past the tail compares equal to the null end iterator.
    void operator++() { m_bucket = m_bucket->next; }
    // NOTE: decrementing begin() or end() dereferences/loads from a bucket
    // that may be null; callers must not step outside the list.
    void operator--() { m_bucket = m_bucket->previous; }

private:
    // The second argument (end bucket) is deliberately unused; it exists so
    // this iterator is constructible the same way as HashTableIterator.
    OrderedHashTableIterator(BucketType* bucket, BucketType*)
        : m_bucket(bucket)
    {
    }

    BucketType* m_bucket { nullptr };
};
94
95template<typename T, typename TraitsForT, bool IsOrdered>
96class HashTable {
97 static constexpr size_t grow_capacity_at_least = 8;
98 static constexpr size_t grow_at_load_factor_percent = 80;
99 static constexpr size_t grow_capacity_increase_percent = 60;
100
101 struct Bucket {
102 BucketState state;
103 alignas(T) u8 storage[sizeof(T)];
104 T* slot() { return reinterpret_cast<T*>(storage); }
105 T const* slot() const { return reinterpret_cast<T const*>(storage); }
106 };
107
108 struct OrderedBucket {
109 OrderedBucket* previous;
110 OrderedBucket* next;
111 BucketState state;
112 alignas(T) u8 storage[sizeof(T)];
113 T* slot() { return reinterpret_cast<T*>(storage); }
114 T const* slot() const { return reinterpret_cast<T const*>(storage); }
115 };
116
117 using BucketType = Conditional<IsOrdered, OrderedBucket, Bucket>;
118
119 struct CollectionData {
120 };
121
122 struct OrderedCollectionData {
123 BucketType* head { nullptr };
124 BucketType* tail { nullptr };
125 };
126
127 using CollectionDataType = Conditional<IsOrdered, OrderedCollectionData, CollectionData>;
128
129public:
130 HashTable() = default;
131 explicit HashTable(size_t capacity) { rehash(capacity); }
132
133 ~HashTable()
134 {
135 if (!m_buckets)
136 return;
137
138 if constexpr (!IsTriviallyDestructible<T>) {
139 for (size_t i = 0; i < m_capacity; ++i) {
140 if (m_buckets[i].state != BucketState::Free)
141 m_buckets[i].slot()->~T();
142 }
143 }
144
145 kfree_sized(m_buckets, size_in_bytes(m_capacity));
146 }
147
148 HashTable(HashTable const& other)
149 {
150 rehash(other.capacity());
151 for (auto& it : other)
152 set(it);
153 }
154
155 HashTable& operator=(HashTable const& other)
156 {
157 HashTable temporary(other);
158 swap(*this, temporary);
159 return *this;
160 }
161
162 HashTable(HashTable&& other) noexcept
163 : m_buckets(other.m_buckets)
164 , m_collection_data(other.m_collection_data)
165 , m_size(other.m_size)
166 , m_capacity(other.m_capacity)
167 {
168 other.m_size = 0;
169 other.m_capacity = 0;
170 other.m_buckets = nullptr;
171 if constexpr (IsOrdered)
172 other.m_collection_data = { nullptr, nullptr };
173 }
174
175 HashTable& operator=(HashTable&& other) noexcept
176 {
177 HashTable temporary { move(other) };
178 swap(*this, temporary);
179 return *this;
180 }
181
182 friend void swap(HashTable& a, HashTable& b) noexcept
183 {
184 swap(a.m_buckets, b.m_buckets);
185 swap(a.m_size, b.m_size);
186 swap(a.m_capacity, b.m_capacity);
187
188 if constexpr (IsOrdered)
189 swap(a.m_collection_data, b.m_collection_data);
190 }
191
192 [[nodiscard]] bool is_empty() const { return m_size == 0; }
193 [[nodiscard]] size_t size() const { return m_size; }
194 [[nodiscard]] size_t capacity() const { return m_capacity; }
195
196 template<typename U, size_t N>
197 ErrorOr<void> try_set_from(U (&from_array)[N])
198 {
199 for (size_t i = 0; i < N; ++i)
200 TRY(try_set(from_array[i]));
201 return {};
202 }
203 template<typename U, size_t N>
204 void set_from(U (&from_array)[N])
205 {
206 MUST(try_set_from(from_array));
207 }
208
209 ErrorOr<void> try_ensure_capacity(size_t capacity)
210 {
211 // The user usually expects "capacity" to mean the number of values that can be stored in a
212 // container without it needing to reallocate. Our definition of "capacity" is the number of
213 // buckets we can store, but we reallocate earlier because of `grow_at_load_factor_percent`.
214 // This calculates the required internal capacity to store `capacity` number of values.
215 size_t required_capacity = capacity * 100 / grow_at_load_factor_percent + 1;
216 if (required_capacity <= m_capacity)
217 return {};
218 return try_rehash(required_capacity);
219 }
220 void ensure_capacity(size_t capacity)
221 {
222 MUST(try_ensure_capacity(capacity));
223 }
224
225 [[nodiscard]] bool contains(T const& value) const
226 {
227 return find(value) != end();
228 }
229
230 template<Concepts::HashCompatible<T> K>
231 requires(IsSame<TraitsForT, Traits<T>>) [[nodiscard]] bool contains(K const& value) const
232 {
233 return find(value) != end();
234 }
235
236 using Iterator = Conditional<IsOrdered,
237 OrderedHashTableIterator<HashTable, T, BucketType>,
238 HashTableIterator<HashTable, T, BucketType>>;
239
240 [[nodiscard]] Iterator begin()
241 {
242 if constexpr (IsOrdered)
243 return Iterator(m_collection_data.head, end_bucket());
244
245 for (size_t i = 0; i < m_capacity; ++i) {
246 if (m_buckets[i].state != BucketState::Free)
247 return Iterator(&m_buckets[i], end_bucket());
248 }
249 return end();
250 }
251
252 [[nodiscard]] Iterator end()
253 {
254 return Iterator(nullptr, nullptr);
255 }
256
257 using ConstIterator = Conditional<IsOrdered,
258 OrderedHashTableIterator<const HashTable, const T, BucketType const>,
259 HashTableIterator<const HashTable, const T, BucketType const>>;
260
261 [[nodiscard]] ConstIterator begin() const
262 {
263 if constexpr (IsOrdered)
264 return ConstIterator(m_collection_data.head, end_bucket());
265
266 for (size_t i = 0; i < m_capacity; ++i) {
267 if (m_buckets[i].state != BucketState::Free)
268 return ConstIterator(&m_buckets[i], end_bucket());
269 }
270 return end();
271 }
272
273 [[nodiscard]] ConstIterator end() const
274 {
275 return ConstIterator(nullptr, nullptr);
276 }
277
278 void clear()
279 {
280 *this = HashTable();
281 }
282
283 void clear_with_capacity()
284 {
285 if (m_capacity == 0)
286 return;
287 if constexpr (!IsTriviallyDestructible<T>) {
288 for (auto* bucket : *this)
289 bucket->~T();
290 }
291 __builtin_memset(m_buckets, 0, size_in_bytes(m_capacity));
292 m_size = 0;
293
294 if constexpr (IsOrdered)
295 m_collection_data = { nullptr, nullptr };
296 }
297
298 template<typename U = T>
299 ErrorOr<HashSetResult> try_set(U&& value, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Replace)
300 {
301 if (should_grow())
302 TRY(try_rehash(m_capacity * (100 + grow_capacity_increase_percent) / 100));
303
304 return write_value(forward<U>(value), existing_entry_behavior);
305 }
306 template<typename U = T>
307 HashSetResult set(U&& value, HashSetExistingEntryBehavior existing_entry_behaviour = HashSetExistingEntryBehavior::Replace)
308 {
309 return MUST(try_set(forward<U>(value), existing_entry_behaviour));
310 }
311
312 template<typename TUnaryPredicate>
313 [[nodiscard]] Iterator find(unsigned hash, TUnaryPredicate predicate)
314 {
315 return Iterator(lookup_with_hash(hash, move(predicate)), end_bucket());
316 }
317
318 [[nodiscard]] Iterator find(T const& value)
319 {
320 return find(TraitsForT::hash(value), [&](auto& other) { return TraitsForT::equals(value, other); });
321 }
322
323 template<typename TUnaryPredicate>
324 [[nodiscard]] ConstIterator find(unsigned hash, TUnaryPredicate predicate) const
325 {
326 return ConstIterator(lookup_with_hash(hash, move(predicate)), end_bucket());
327 }
328
329 [[nodiscard]] ConstIterator find(T const& value) const
330 {
331 return find(TraitsForT::hash(value), [&](auto& other) { return TraitsForT::equals(value, other); });
332 }
333 // FIXME: Support for predicates, while guaranteeing that the predicate call
334 // does not call a non trivial constructor each time invoked
335 template<Concepts::HashCompatible<T> K>
336 requires(IsSame<TraitsForT, Traits<T>>) [[nodiscard]] Iterator find(K const& value)
337 {
338 return find(Traits<K>::hash(value), [&](auto& other) { return Traits<T>::equals(other, value); });
339 }
340
341 template<Concepts::HashCompatible<T> K, typename TUnaryPredicate>
342 requires(IsSame<TraitsForT, Traits<T>>) [[nodiscard]] Iterator find(K const& value, TUnaryPredicate predicate)
343 {
344 return find(Traits<K>::hash(value), move(predicate));
345 }
346
347 template<Concepts::HashCompatible<T> K>
348 requires(IsSame<TraitsForT, Traits<T>>) [[nodiscard]] ConstIterator find(K const& value) const
349 {
350 return find(Traits<K>::hash(value), [&](auto& other) { return Traits<T>::equals(other, value); });
351 }
352
353 template<Concepts::HashCompatible<T> K, typename TUnaryPredicate>
354 requires(IsSame<TraitsForT, Traits<T>>) [[nodiscard]] ConstIterator find(K const& value, TUnaryPredicate predicate) const
355 {
356 return find(Traits<K>::hash(value), move(predicate));
357 }
358
359 bool remove(T const& value)
360 {
361 auto it = find(value);
362 if (it != end()) {
363 remove(it);
364 return true;
365 }
366 return false;
367 }
368
369 template<Concepts::HashCompatible<T> K>
370 requires(IsSame<TraitsForT, Traits<T>>) bool remove(K const& value)
371 {
372 auto it = find(value);
373 if (it != end()) {
374 remove(it);
375 return true;
376 }
377 return false;
378 }
379
380 // This invalidates the iterator
381 void remove(Iterator& iterator)
382 {
383 auto* bucket = iterator.m_bucket;
384 VERIFY(bucket);
385 delete_bucket(*bucket);
386 iterator.m_bucket = nullptr;
387 }
388
389 template<typename TUnaryPredicate>
390 bool remove_all_matching(TUnaryPredicate const& predicate)
391 {
392 bool has_removed_anything = false;
393 for (size_t i = 0; i < m_capacity; ++i) {
394 auto& bucket = m_buckets[i];
395 if (bucket.state == BucketState::Free || !predicate(*bucket.slot()))
396 continue;
397
398 delete_bucket(bucket);
399 has_removed_anything = true;
400
401 // If a bucket was shifted up, reevaluate this bucket index
402 if (bucket.state != BucketState::Free)
403 --i;
404 }
405 return has_removed_anything;
406 }
407
408 T take_last()
409 requires(IsOrdered)
410 {
411 VERIFY(!is_empty());
412 T element = move(*m_collection_data.tail->slot());
413 delete_bucket(*m_collection_data.tail);
414 return element;
415 }
416
417 T take_first()
418 requires(IsOrdered)
419 {
420 VERIFY(!is_empty());
421 T element = move(*m_collection_data.head->slot());
422 delete_bucket(*m_collection_data.head);
423 return element;
424 }
425
426private:
427 bool should_grow() const { return ((m_size + 1) * 100) >= (m_capacity * grow_at_load_factor_percent); }
428 static constexpr size_t size_in_bytes(size_t capacity) { return sizeof(BucketType) * capacity; }
429
430 BucketType* end_bucket()
431 {
432 if constexpr (IsOrdered)
433 return m_collection_data.tail;
434 else
435 return &m_buckets[m_capacity];
436 }
437 BucketType const* end_bucket() const
438 {
439 return const_cast<HashTable*>(this)->end_bucket();
440 }
441
442 ErrorOr<void> try_rehash(size_t new_capacity)
443 {
444 new_capacity = max(new_capacity, m_capacity + grow_capacity_at_least);
445 new_capacity = kmalloc_good_size(size_in_bytes(new_capacity)) / sizeof(BucketType);
446 VERIFY(new_capacity >= size());
447
448 auto* old_buckets = m_buckets;
449 auto old_buckets_size = size_in_bytes(m_capacity);
450 Iterator old_iter = begin();
451
452 auto* new_buckets = kcalloc(1, size_in_bytes(new_capacity));
453 if (!new_buckets)
454 return Error::from_errno(ENOMEM);
455
456 m_buckets = static_cast<BucketType*>(new_buckets);
457 m_capacity = new_capacity;
458
459 if constexpr (IsOrdered)
460 m_collection_data = { nullptr, nullptr };
461
462 if (!old_buckets)
463 return {};
464
465 m_size = 0;
466 for (auto it = move(old_iter); it != end(); ++it) {
467 write_value(move(*it), HashSetExistingEntryBehavior::Keep);
468 it->~T();
469 }
470
471 kfree_sized(old_buckets, old_buckets_size);
472 return {};
473 }
474 void rehash(size_t new_capacity)
475 {
476 MUST(try_rehash(new_capacity));
477 }
478
479 template<typename TUnaryPredicate>
480 [[nodiscard]] BucketType* lookup_with_hash(unsigned hash, TUnaryPredicate predicate) const
481 {
482 if (is_empty())
483 return nullptr;
484
485 hash %= m_capacity;
486 for (;;) {
487 auto* bucket = &m_buckets[hash];
488 if (bucket->state == BucketState::Free)
489 return nullptr;
490 if (predicate(*bucket->slot()))
491 return bucket;
492 if (++hash == m_capacity) [[unlikely]]
493 hash = 0;
494 }
495 }
496
497 size_t used_bucket_probe_length(BucketType const& bucket) const
498 {
499 VERIFY(bucket.state != BucketState::Free);
500
501 if (bucket.state == BucketState::CalculateLength) {
502 size_t ideal_bucket_index = TraitsForT::hash(*bucket.slot()) % m_capacity;
503
504 VERIFY(&bucket >= m_buckets);
505 size_t actual_bucket_index = &bucket - m_buckets;
506
507 if (actual_bucket_index < ideal_bucket_index)
508 return m_capacity + actual_bucket_index - ideal_bucket_index;
509 return actual_bucket_index - ideal_bucket_index;
510 }
511
512 return static_cast<u8>(bucket.state) - 1;
513 }
514
515 ALWAYS_INLINE constexpr BucketState bucket_state_for_probe_length(size_t probe_length)
516 {
517 if (probe_length > 253)
518 return BucketState::CalculateLength;
519 return static_cast<BucketState>(probe_length + 1);
520 }
521
522 template<typename U = T>
523 HashSetResult write_value(U&& value, HashSetExistingEntryBehavior existing_entry_behavior)
524 {
525 auto update_collection_for_new_bucket = [&](BucketType& bucket) {
526 if constexpr (IsOrdered) {
527 if (!m_collection_data.head) [[unlikely]] {
528 m_collection_data.head = &bucket;
529 } else {
530 bucket.previous = m_collection_data.tail;
531 m_collection_data.tail->next = &bucket;
532 }
533 m_collection_data.tail = &bucket;
534 }
535 };
536 auto update_collection_for_swapped_buckets = [&](BucketType* left_bucket, BucketType* right_bucket) {
537 if constexpr (IsOrdered) {
538 if (m_collection_data.head == left_bucket)
539 m_collection_data.head = right_bucket;
540 else if (m_collection_data.head == right_bucket)
541 m_collection_data.head = left_bucket;
542 if (m_collection_data.tail == left_bucket)
543 m_collection_data.tail = right_bucket;
544 else if (m_collection_data.tail == right_bucket)
545 m_collection_data.tail = left_bucket;
546
547 if (left_bucket->previous) {
548 if (left_bucket->previous == left_bucket)
549 left_bucket->previous = right_bucket;
550 left_bucket->previous->next = left_bucket;
551 }
552 if (left_bucket->next) {
553 if (left_bucket->next == left_bucket)
554 left_bucket->next = right_bucket;
555 left_bucket->next->previous = left_bucket;
556 }
557
558 if (right_bucket->previous && right_bucket->previous != left_bucket)
559 right_bucket->previous->next = right_bucket;
560 if (right_bucket->next && right_bucket->next != left_bucket)
561 right_bucket->next->previous = right_bucket;
562 }
563 };
564
565 auto bucket_index = TraitsForT::hash(value) % m_capacity;
566 size_t probe_length = 0;
567 for (;;) {
568 auto* bucket = &m_buckets[bucket_index];
569
570 // We found a free bucket, write to it and stop
571 if (bucket->state == BucketState::Free) {
572 new (bucket->slot()) T(forward<U>(value));
573 bucket->state = bucket_state_for_probe_length(probe_length);
574 update_collection_for_new_bucket(*bucket);
575 ++m_size;
576 return HashSetResult::InsertedNewEntry;
577 }
578
579 // The bucket is already used, does it have an identical value?
580 if (TraitsForT::equals(*bucket->slot(), static_cast<T const&>(value))) {
581 if (existing_entry_behavior == HashSetExistingEntryBehavior::Replace) {
582 (*bucket->slot()) = forward<U>(value);
583 return HashSetResult::ReplacedExistingEntry;
584 }
585 return HashSetResult::KeptExistingEntry;
586 }
587
588 // Robin hood: if our probe length is larger (poor) than this bucket's (rich), steal its position!
589 // This ensures that we will always traverse buckets in order of probe length.
590 auto target_probe_length = used_bucket_probe_length(*bucket);
591 if (probe_length > target_probe_length) {
592 // Copy out bucket
593 BucketType bucket_to_move = move(*bucket);
594 update_collection_for_swapped_buckets(bucket, &bucket_to_move);
595
596 // Write new bucket
597 new (bucket->slot()) T(forward<U>(value));
598 bucket->state = bucket_state_for_probe_length(probe_length);
599 probe_length = target_probe_length;
600 if constexpr (IsOrdered)
601 bucket->next = nullptr;
602 update_collection_for_new_bucket(*bucket);
603 ++m_size;
604
605 // Find a free bucket, swapping with smaller probe length buckets along the way
606 for (;;) {
607 if (++bucket_index == m_capacity) [[unlikely]]
608 bucket_index = 0;
609 bucket = &m_buckets[bucket_index];
610 ++probe_length;
611
612 if (bucket->state == BucketState::Free) {
613 *bucket = move(bucket_to_move);
614 bucket->state = bucket_state_for_probe_length(probe_length);
615 update_collection_for_swapped_buckets(&bucket_to_move, bucket);
616 break;
617 }
618
619 target_probe_length = used_bucket_probe_length(*bucket);
620 if (probe_length > target_probe_length) {
621 swap(bucket_to_move, *bucket);
622 bucket->state = bucket_state_for_probe_length(probe_length);
623 probe_length = target_probe_length;
624 update_collection_for_swapped_buckets(&bucket_to_move, bucket);
625 }
626 }
627
628 return HashSetResult::InsertedNewEntry;
629 }
630
631 // Try next bucket
632 if (++bucket_index == m_capacity) [[unlikely]]
633 bucket_index = 0;
634 ++probe_length;
635 }
636 }
637
638 void delete_bucket(auto& bucket)
639 {
640 VERIFY(bucket.state != BucketState::Free);
641
642 // Delete the bucket
643 bucket.slot()->~T();
644 if constexpr (IsOrdered) {
645 if (bucket.previous)
646 bucket.previous->next = bucket.next;
647 else
648 m_collection_data.head = bucket.next;
649 if (bucket.next)
650 bucket.next->previous = bucket.previous;
651 else
652 m_collection_data.tail = bucket.previous;
653 bucket.previous = nullptr;
654 bucket.next = nullptr;
655 }
656 --m_size;
657
658 // If we deleted a bucket, we need to make sure to shift up all buckets after it to ensure
659 // that we can still probe for buckets with collisions, and we automatically optimize the
660 // probe lengths. To do so, we shift the following buckets up until we reach a free bucket,
661 // or a bucket with a probe length of 0 (the ideal index for that bucket).
662 auto update_bucket_neighbours = [&](BucketType* bucket) {
663 if constexpr (IsOrdered) {
664 if (bucket->previous)
665 bucket->previous->next = bucket;
666 else
667 m_collection_data.head = bucket;
668 if (bucket->next)
669 bucket->next->previous = bucket;
670 else
671 m_collection_data.tail = bucket;
672 }
673 };
674
675 VERIFY(&bucket >= m_buckets);
676 size_t shift_to_index = &bucket - m_buckets;
677 VERIFY(shift_to_index < m_capacity);
678 size_t shift_from_index = shift_to_index;
679 for (;;) {
680 if (++shift_from_index == m_capacity) [[unlikely]]
681 shift_from_index = 0;
682
683 auto* shift_from_bucket = &m_buckets[shift_from_index];
684 if (shift_from_bucket->state == BucketState::Free)
685 break;
686
687 auto shift_from_probe_length = used_bucket_probe_length(*shift_from_bucket);
688 if (shift_from_probe_length == 0)
689 break;
690
691 auto* shift_to_bucket = &m_buckets[shift_to_index];
692 *shift_to_bucket = move(*shift_from_bucket);
693 shift_to_bucket->state = bucket_state_for_probe_length(shift_from_probe_length - 1);
694 update_bucket_neighbours(shift_to_bucket);
695
696 if (++shift_to_index == m_capacity) [[unlikely]]
697 shift_to_index = 0;
698 }
699
700 // Mark last bucket as free
701 m_buckets[shift_to_index].state = BucketState::Free;
702 }
703
704 BucketType* m_buckets { nullptr };
705
706 [[no_unique_address]] CollectionDataType m_collection_data;
707 size_t m_size { 0 };
708 size_t m_capacity { 0 };
709};
710}
711
712#if USING_AK_GLOBALLY
713using AK::HashSetResult;
714using AK::HashTable;
715using AK::OrderedHashTable;
716#endif