Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * Copyright (c) 2023, Kenneth Myhra <kennethmyhra@serenityos.org>
4 *
5 * SPDX-License-Identifier: BSD-2-Clause
6 */
7
8#pragma once
9
10#include <AK/HashTable.h>
11#include <AK/Optional.h>
12#include <AK/Vector.h>
13#include <initializer_list>
14
15namespace AK {
16
17template<typename K, typename V, typename KeyTraits, typename ValueTraits, bool IsOrdered>
18class HashMap {
19private:
20 struct Entry {
21 K key;
22 V value;
23 };
24
25 struct EntryTraits {
26 static unsigned hash(Entry const& entry) { return KeyTraits::hash(entry.key); }
27 static bool equals(Entry const& a, Entry const& b) { return KeyTraits::equals(a.key, b.key); }
28 };
29
30public:
31 using KeyType = K;
32 using ValueType = V;
33
34 HashMap() = default;
35
36 HashMap(std::initializer_list<Entry> list)
37 {
38 MUST(try_ensure_capacity(list.size()));
39 for (auto& item : list)
40 set(item.key, item.value);
41 }
42
43 [[nodiscard]] bool is_empty() const
44 {
45 return m_table.is_empty();
46 }
47 [[nodiscard]] size_t size() const { return m_table.size(); }
48 [[nodiscard]] size_t capacity() const { return m_table.capacity(); }
49 void clear() { m_table.clear(); }
50 void clear_with_capacity() { m_table.clear_with_capacity(); }
51
52 HashSetResult set(K const& key, V const& value) { return m_table.set({ key, value }); }
53 HashSetResult set(K const& key, V&& value) { return m_table.set({ key, move(value) }); }
54 HashSetResult set(K&& key, V&& value) { return m_table.set({ move(key), move(value) }); }
55 ErrorOr<HashSetResult> try_set(K const& key, V const& value) { return m_table.try_set({ key, value }); }
56 ErrorOr<HashSetResult> try_set(K const& key, V&& value) { return m_table.try_set({ key, move(value) }); }
57 ErrorOr<HashSetResult> try_set(K&& key, V&& value) { return m_table.try_set({ move(key), move(value) }); }
58
59 bool remove(K const& key)
60 {
61 auto it = find(key);
62 if (it != end()) {
63 m_table.remove(it);
64 return true;
65 }
66 return false;
67 }
68
69 template<Concepts::HashCompatible<K> Key>
70 requires(IsSame<KeyTraits, Traits<K>>) bool remove(Key const& key)
71 {
72 auto it = find(key);
73 if (it != end()) {
74 m_table.remove(it);
75 return true;
76 }
77 return false;
78 }
79
80 template<typename TUnaryPredicate>
81 bool remove_all_matching(TUnaryPredicate const& predicate)
82 {
83 return m_table.template remove_all_matching([&](auto& entry) {
84 return predicate(entry.key, entry.value);
85 });
86 }
87
88 using HashTableType = HashTable<Entry, EntryTraits, IsOrdered>;
89 using IteratorType = typename HashTableType::Iterator;
90 using ConstIteratorType = typename HashTableType::ConstIterator;
91
92 [[nodiscard]] IteratorType begin() { return m_table.begin(); }
93 [[nodiscard]] IteratorType end() { return m_table.end(); }
94 [[nodiscard]] IteratorType find(K const& key)
95 {
96 return m_table.find(KeyTraits::hash(key), [&](auto& entry) { return KeyTraits::equals(key, entry.key); });
97 }
98 template<typename TUnaryPredicate>
99 [[nodiscard]] IteratorType find(unsigned hash, TUnaryPredicate predicate)
100 {
101 return m_table.find(hash, predicate);
102 }
103
104 [[nodiscard]] ConstIteratorType begin() const { return m_table.begin(); }
105 [[nodiscard]] ConstIteratorType end() const { return m_table.end(); }
106 [[nodiscard]] ConstIteratorType find(K const& key) const
107 {
108 return m_table.find(KeyTraits::hash(key), [&](auto& entry) { return KeyTraits::equals(key, entry.key); });
109 }
110 template<typename TUnaryPredicate>
111 [[nodiscard]] ConstIteratorType find(unsigned hash, TUnaryPredicate predicate) const
112 {
113 return m_table.find(hash, predicate);
114 }
115
116 template<Concepts::HashCompatible<K> Key>
117 requires(IsSame<KeyTraits, Traits<K>>) [[nodiscard]] IteratorType find(Key const& key)
118 {
119 return m_table.find(Traits<Key>::hash(key), [&](auto& entry) { return Traits<K>::equals(key, entry.key); });
120 }
121
122 template<Concepts::HashCompatible<K> Key>
123 requires(IsSame<KeyTraits, Traits<K>>) [[nodiscard]] ConstIteratorType find(Key const& key) const
124 {
125 return m_table.find(Traits<Key>::hash(key), [&](auto& entry) { return Traits<K>::equals(key, entry.key); });
126 }
127
128 ErrorOr<void> try_ensure_capacity(size_t capacity) { return m_table.try_ensure_capacity(capacity); }
129
130 Optional<typename ValueTraits::ConstPeekType> get(K const& key) const
131 requires(!IsPointer<typename ValueTraits::PeekType>)
132 {
133 auto it = find(key);
134 if (it == end())
135 return {};
136 return (*it).value;
137 }
138
139 Optional<typename ValueTraits::ConstPeekType> get(K const& key) const
140 requires(IsPointer<typename ValueTraits::PeekType>)
141 {
142 auto it = find(key);
143 if (it == end())
144 return {};
145 return (*it).value;
146 }
147
148 Optional<typename ValueTraits::PeekType> get(K const& key)
149 requires(!IsConst<typename ValueTraits::PeekType>)
150 {
151 auto it = find(key);
152 if (it == end())
153 return {};
154 return (*it).value;
155 }
156
157 template<Concepts::HashCompatible<K> Key>
158 requires(IsSame<KeyTraits, Traits<K>>) Optional<typename ValueTraits::ConstPeekType> get(Key const& key) const
159 requires(!IsPointer<typename ValueTraits::PeekType>)
160 {
161 auto it = find(key);
162 if (it == end())
163 return {};
164 return (*it).value;
165 }
166
167 template<Concepts::HashCompatible<K> Key>
168 requires(IsSame<KeyTraits, Traits<K>>) Optional<typename ValueTraits::ConstPeekType> get(Key const& key) const
169 requires(IsPointer<typename ValueTraits::PeekType>)
170 {
171 auto it = find(key);
172 if (it == end())
173 return {};
174 return (*it).value;
175 }
176
177 template<Concepts::HashCompatible<K> Key>
178 requires(IsSame<KeyTraits, Traits<K>>) Optional<typename ValueTraits::PeekType> get(Key const& key)
179 requires(!IsConst<typename ValueTraits::PeekType>)
180 {
181 auto it = find(key);
182 if (it == end())
183 return {};
184 return (*it).value;
185 }
186
187 [[nodiscard]] bool contains(K const& key) const
188 {
189 return find(key) != end();
190 }
191
192 template<Concepts::HashCompatible<K> Key>
193 requires(IsSame<KeyTraits, Traits<K>>) [[nodiscard]] bool contains(Key const& value) const
194 {
195 return find(value) != end();
196 }
197
198 void remove(IteratorType it)
199 {
200 m_table.remove(it);
201 }
202
203 Optional<V> take(K const& key)
204 {
205 if (auto it = find(key); it != end()) {
206 auto value = move(it->value);
207 m_table.remove(it);
208
209 return value;
210 }
211
212 return {};
213 }
214
215 template<Concepts::HashCompatible<K> Key>
216 requires(IsSame<KeyTraits, Traits<K>>) Optional<V> take(Key const& key)
217 {
218 if (auto it = find(key); it != end()) {
219 auto value = move(it->value);
220 m_table.remove(it);
221
222 return value;
223 }
224
225 return {};
226 }
227
228 V& ensure(K const& key)
229 {
230 auto it = find(key);
231 if (it != end())
232 return it->value;
233 auto result = set(key, V());
234 VERIFY(result == HashSetResult::InsertedNewEntry);
235 return find(key)->value;
236 }
237
238 template<typename Callback>
239 V& ensure(K const& key, Callback initialization_callback)
240 {
241 auto it = find(key);
242 if (it != end())
243 return it->value;
244 auto result = set(key, initialization_callback());
245 VERIFY(result == HashSetResult::InsertedNewEntry);
246 return find(key)->value;
247 }
248
249 template<typename Callback>
250 ErrorOr<V> try_ensure(K const& key, Callback initialization_callback)
251 {
252 auto it = find(key);
253 if (it != end())
254 return it->value;
255 if constexpr (FallibleFunction<Callback>) {
256 auto result = TRY(try_set(key, TRY(initialization_callback())));
257 VERIFY(result == HashSetResult::InsertedNewEntry);
258 } else {
259 auto result = TRY(try_set(key, initialization_callback()));
260 VERIFY(result == HashSetResult::InsertedNewEntry);
261 }
262 return find(key)->value;
263 }
264
265 [[nodiscard]] Vector<K> keys() const
266 {
267 Vector<K> list;
268 list.ensure_capacity(size());
269 for (auto& it : *this)
270 list.unchecked_append(it.key);
271 return list;
272 }
273
274 [[nodiscard]] u32 hash() const
275 {
276 u32 hash = 0;
277 for (auto& it : *this) {
278 auto entry_hash = pair_int_hash(it.key.hash(), it.value.hash());
279 hash = pair_int_hash(hash, entry_hash);
280 }
281 return hash;
282 }
283
284 ErrorOr<HashMap<K, V>> clone()
285 {
286 HashMap<K, V> hash_map_clone;
287 for (auto& it : *this)
288 TRY(hash_map_clone.try_set(it.key, it.value));
289 return hash_map_clone;
290 }
291
292private:
293 HashTableType m_table;
294};
295
296}
297
#if USING_AK_GLOBALLY
// Re-export both map templates into the global namespace so that code built
// with USING_AK_GLOBALLY enabled can name them without the AK:: prefix.
using AK::OrderedHashMap;
using AK::HashMap;
#endif