//! This is pulled from std.Thread.Pool. We've modified it to allow the spawned
//! function to return an error. This helps with cleanup. All errors are logged.
const std = @import("std");
const builtin = @import("builtin");
const Pool = @This();
const WaitGroup = std.Thread.WaitGroup;

mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: RunQueue = .{},
is_running: bool = true,
allocator: std.mem.Allocator,
threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread,
ids: if (builtin.single_threaded) struct {
    inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
    fn getIndex(_: @This(), _: std.Thread.Id) usize {
        return 0;
    }
} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),

const RunQueue = std.SinglyLinkedList(Runnable);
const Runnable = struct {
    runFn: RunProto,
};

const RunProto = *const fn (*Runnable, id: ?usize) void;

pub const Options = struct {
    allocator: std.mem.Allocator,
    n_jobs: ?usize = null,
    track_ids: bool = false,
    stack_size: usize = std.Thread.SpawnConfig.default_stack_size,
};

pub fn init(pool: *Pool, options: Options) !void {
    const allocator = options.allocator;

    pool.* = .{
        .allocator = allocator,
        .threads = if (builtin.single_threaded) .{} else &.{},
        .ids = .{},
    };

    if (builtin.single_threaded) {
        return;
    }

    const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
    if (options.track_ids) {
        try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count);
        pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
    }

    // kill and join any threads we spawned and free memory on error.
    pool.threads = try allocator.alloc(std.Thread, thread_count);
    var spawned: usize = 0;
    errdefer pool.join(spawned);

    for (pool.threads) |*thread| {
        thread.* = try std.Thread.spawn(.{
            .stack_size = options.stack_size,
            .allocator = allocator,
        }, worker, .{pool});
        spawned += 1;
    }
}

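// A minimal lifecycle sketch, not part of the upstream file: a pool is
// initialized in place and must be deinitialized to join its worker threads.
// The `n_jobs` value below is an arbitrary example.
test "Pool: init and deinit (usage sketch)" {
    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator, .n_jobs = 2 });
    pool.deinit();
}
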
pub fn deinit(pool: *Pool) void {
    pool.join(pool.threads.len); // kill and join all threads.
    pool.ids.deinit(pool.allocator);
    pool.* = undefined;
}

fn join(pool: *Pool, spawned: usize) void {
    if (builtin.single_threaded) {
        return;
    }

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        // ensure future worker threads exit the dequeue loop
        pool.is_running = false;
    }

    // wake up any sleeping threads (this can be done outside the mutex)
    // then wait for all the threads we know are spawned to complete.
    pool.cond.broadcast();
    for (pool.threads[0..spawned]) |thread| {
        thread.join();
    }

    pool.allocator.free(pool.threads);
}

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
            @call(.auto, func, closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            pool.mutex.unlock();
            @call(.auto, func, args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.run_node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

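// A usage sketch for spawnWg, not part of the upstream file: work items are
// tracked by a WaitGroup, and waitAndWork lets the calling thread help drain
// the queue. The counter task here is purely illustrative.
test "Pool: spawnWg and waitAndWork (usage sketch)" {
    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator, .n_jobs = 2 });
    defer pool.deinit();

    const Task = struct {
        fn run(counter: *std.atomic.Value(usize)) void {
            _ = counter.fetchAdd(1, .monotonic);
        }
    };

    var counter = std.atomic.Value(usize).init(0);
    var wg: WaitGroup = .{};
    for (0..4) |_| pool.spawnWg(&wg, Task.run, .{&counter});
    pool.waitAndWork(&wg);
    try std.testing.expectEqual(@as(usize, 4), counter.load(.monotonic));
}
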
/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// The first argument passed to `func` is a dense `usize` thread id, the rest
/// of the arguments are passed from `args`. Requires the pool to have been
/// initialized with `.track_ids = true`.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, .{0} ++ args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, id: ?usize) void {
            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
            @call(.auto, func, .{id.?} ++ closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            @call(.auto, func, .{id.?} ++ args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.run_node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

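// A usage sketch for spawnWgId, not part of the upstream file: with
// `.track_ids = true`, each task receives a dense thread id that stays below
// `getIdCount()`. The bounds-checking task is purely illustrative.
test "Pool: spawnWgId passes a dense thread id (usage sketch)" {
    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator, .n_jobs = 2, .track_ids = true });
    defer pool.deinit();

    const Task = struct {
        fn run(id: usize, max: usize, ok: *std.atomic.Value(bool)) void {
            if (id >= max) ok.store(false, .monotonic);
        }
    };

    var ok = std.atomic.Value(bool).init(true);
    var wg: WaitGroup = .{};
    for (0..8) |_| pool.spawnWgId(&wg, Task.run, .{ pool.getIdCount(), &ok });
    pool.waitAndWork(&wg);
    try std.testing.expect(ok.load(.monotonic));
}
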
pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
    if (builtin.single_threaded) {
        @call(.auto, func, args) catch |err| std.log.err("thread pool task failed: {}", .{err});
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
            // Per the module doc comment, errors from the spawned function are
            // logged rather than silently discarded.
            @call(.auto, func, closure.arguments) catch |err| std.log.err("thread pool task failed: {}", .{err});

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        const closure = try pool.allocator.create(Closure);
        closure.* = .{
            .arguments = args,
            .pool = pool,
        };

        pool.run_queue.prepend(&closure.run_node);
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

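// A usage sketch for spawn, not part of the upstream file: unlike spawnWg, the
// function may return an error union, and a returned error is logged rather
// than propagated. deinit drains the queue, so the task has run by the time it
// returns. The task below is purely illustrative.
test "Pool: spawn with an error-returning function (usage sketch)" {
    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator, .n_jobs = 2 });

    const Task = struct {
        fn run(counter: *std.atomic.Value(usize)) !void {
            if (counter.fetchAdd(1, .monotonic) > 100) return error.TooManyRuns;
        }
    };

    var counter = std.atomic.Value(usize).init(0);
    try pool.spawn(Task.run, .{&counter});

    pool.deinit();
    try std.testing.expectEqual(@as(usize, 1), counter.load(.monotonic));
}
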
fn worker(pool: *Pool) void {
    pool.mutex.lock();
    defer pool.mutex.unlock();

    const id: ?usize = if (pool.ids.count() > 0) @intCast(pool.ids.count()) else null;
    if (id) |_| pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});

    while (true) {
        while (pool.run_queue.popFirst()) |run_node| {
            // Temporarily unlock the mutex in order to execute the run_node
            pool.mutex.unlock();
            defer pool.mutex.lock();

            run_node.data.runFn(&run_node.data, id);
        }

        // Stop executing instead of waiting if the thread pool is no longer running.
        if (pool.is_running) {
            pool.cond.wait(&pool.mutex);
        } else {
            break;
        }
    }
}

pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
    var id: ?usize = null;

    while (!wait_group.isDone()) {
        pool.mutex.lock();
        if (pool.run_queue.popFirst()) |run_node| {
            id = id orelse pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            run_node.data.runFn(&run_node.data, id);
            continue;
        }

        pool.mutex.unlock();
        wait_group.wait();
        return;
    }
}

pub fn getIdCount(pool: *Pool) usize {
    return @intCast(1 + pool.threads.len);
}
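
// A small sketch, not part of the upstream file: getIdCount reports the size
// of the dense id space (the registering thread plus every worker), which
// bounds the ids handed to spawnWgId. The `n_jobs` value is an arbitrary example.
test "Pool: getIdCount (usage sketch)" {
    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator, .n_jobs = 3, .track_ids = true });
    defer pool.deinit();

    const expected: usize = if (builtin.single_threaded) 1 else 4;
    try std.testing.expectEqual(expected, pool.getIdCount());
}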