this repo has no description
at main 9.4 kB view raw
//! A thread pool adapted from `std.Thread.Pool`.
//!
//! The main difference from the standard library version is `spawn`: the
//! submitted function returns an error union, which lets tasks use `try` and
//! `errdefer` for cleanup. Any error a task returns is logged (scope
//! `.thread_pool`) and then discarded; it is never propagated to the caller of
//! `spawn`.
const std = @import("std");
const builtin = @import("builtin");
const Pool = @This();
const WaitGroup = std.Thread.WaitGroup;
const log = std.log.scoped(.thread_pool);

/// Guards `run_queue`, `is_running`, `ids`, and uses of `allocator` while
/// worker threads are running.
mutex: std.Thread.Mutex = .{},
/// Signalled when a task is queued; broadcast by `join` on shutdown.
cond: std.Thread.Condition = .{},
/// Pending tasks, stored intrusively inside their heap-allocated closures.
run_queue: RunQueue = .{},
/// Cleared by `join`; idle workers exit once it is false and the queue is empty.
is_running: bool = true,
allocator: std.mem.Allocator,
threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread,
/// Maps OS thread ids to dense indices in insertion order. Only populated when
/// `Options.track_ids` is set; index 0 is the thread that called `init`.
ids: if (builtin.single_threaded) struct {
    inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
    fn getIndex(_: @This(), _: std.Thread.Id) usize {
        return 0;
    }
} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),

const RunQueue = std.SinglyLinkedList(Runnable);
const Runnable = struct {
    runFn: RunProto,
};

/// Type-erased task entry point. `id` is the dense index of the executing
/// thread when ids are tracked, otherwise null.
const RunProto = *const fn (*Runnable, id: ?usize) void;

pub const Options = struct {
    allocator: std.mem.Allocator,
    /// Number of worker threads; defaults to the CPU count (at least 1).
    n_jobs: ?usize = null,
    /// Record dense thread ids so `spawnWgId` can pass them to tasks.
    track_ids: bool = false,
    stack_size: usize = std.Thread.SpawnConfig.default_stack_size,
};

/// Initializes `pool` in place and spawns its worker threads.
///
/// On single-threaded builds no threads are created. If this returns an
/// error, every thread spawned so far has been stopped and joined and all
/// memory allocated so far has been freed, so the caller should not call
/// `deinit` afterwards.
pub fn init(pool: *Pool, options: Options) !void {
    const allocator = options.allocator;

    pool.* = .{
        .allocator = allocator,
        .threads = if (builtin.single_threaded) .{} else &.{},
        .ids = .{},
    };

    if (builtin.single_threaded) {
        return;
    }

    const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
    if (options.track_ids) {
        // One slot per worker plus one for the calling thread.
        try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count);
        pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
    }
    // Free the id map if anything below fails. Registered before the `join`
    // errdefer so that (errdefers run in reverse order) it runs only after the
    // workers, which insert into `ids`, have been joined. Harmless when the
    // map is empty.
    errdefer pool.ids.deinit(allocator);

    pool.threads = try allocator.alloc(std.Thread, thread_count);
    // Kill and join any threads we spawned and free `threads` on error.
    var spawned: usize = 0;
    errdefer pool.join(spawned);

    for (pool.threads) |*thread| {
        thread.* = try std.Thread.spawn(.{
            .stack_size = options.stack_size,
            .allocator = allocator,
        }, worker, .{pool});
        spawned += 1;
    }
}

/// Stops and joins all worker threads, then frees the pool's memory.
/// Workers run any tasks still queued before they exit.
pub fn deinit(pool: *Pool) void {
    pool.join(pool.threads.len); // kill and join all threads.
    pool.ids.deinit(pool.allocator);
    pool.* = undefined;
}

/// Signals shutdown, joins the first `spawned` threads, and frees `threads`.
fn join(pool: *Pool, spawned: usize) void {
    if (builtin.single_threaded) {
        return;
    }

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        // Ensure future worker threads exit the dequeue loop.
        pool.is_running = false;
    }

    // Wake up any sleeping threads (this can be done outside the mutex),
    // then wait for all the threads we know are spawned to complete.
    pool.cond.broadcast();
    for (pool.threads[0..spawned]) |thread| {
        thread.join();
    }

    pool.allocator.free(pool.threads);
}

/// Reports an error returned by a task submitted through `spawn`.
fn logTaskError(err: anyerror) void {
    log.err("thread pool task failed: {s}", .{@errorName(err)});
}

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
            @call(.auto, func, closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            pool.mutex.unlock();
            @call(.auto, func, args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.run_node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// The first argument passed to `func` is a dense `usize` thread id, the rest
/// of the arguments are passed from `args`. Requires the pool to have been
/// initialized with `.track_ids = true`.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, .{0} ++ args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, id: ?usize) void {
            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
            @call(.auto, func, .{id.?} ++ closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            // Run inline on the calling thread, which must itself be tracked.
            const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            @call(.auto, func, .{id.?} ++ args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.run_node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

/// Queues `func` to run on a worker thread with `args`.
///
/// `func` must return an error union. If the task returns an error, the error
/// is logged and discarded. Returns `error.OutOfMemory` if the task closure
/// cannot be allocated, in which case `func` is not called. On single-threaded
/// builds `func` is called directly.
pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
    if (builtin.single_threaded) {
        @call(.auto, func, args) catch |err| logTaskError(err);
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
            @call(.auto, func, closure.arguments) catch |err| logTaskError(err);

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        const closure = try pool.allocator.create(Closure);
        closure.* = .{
            .arguments = args,
            .pool = pool,
        };

        pool.run_queue.prepend(&closure.run_node);
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

/// Worker thread main loop. Registers a dense id when ids are tracked, then
/// runs queued tasks until `is_running` is cleared and the queue is empty.
fn worker(pool: *Pool) void {
    pool.mutex.lock();
    defer pool.mutex.unlock();

    // `init` inserts the calling thread's id only when tracking is enabled, so
    // a non-empty map means this worker claims the next dense index.
    const id: ?usize = if (pool.ids.count() > 0) @intCast(pool.ids.count()) else null;
    if (id) |_| pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});

    while (true) {
        while (pool.run_queue.popFirst()) |run_node| {
            // Temporarily unlock the mutex in order to execute the run_node.
            pool.mutex.unlock();
            defer pool.mutex.lock();

            run_node.data.runFn(&run_node.data, id);
        }

        // Stop executing instead of waiting if the thread pool is no longer running.
        if (pool.is_running) {
            pool.cond.wait(&pool.mutex);
        } else {
            break;
        }
    }
}

/// Runs queued tasks on the calling thread until `wait_group` is done. Once
/// the queue is empty, blocks in `wait_group.wait()` instead. Tasks popped
/// here may belong to any wait group.
pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
    var id: ?usize = null;

    while (!wait_group.isDone()) {
        pool.mutex.lock();
        if (pool.run_queue.popFirst()) |run_node| {
            id = id orelse pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            run_node.data.runFn(&run_node.data, id);
            continue;
        }

        pool.mutex.unlock();
        wait_group.wait();
        return;
    }
}

/// Number of distinct ids `spawnWgId` tasks may receive: one per worker plus
/// one for the thread that called `init`. Ids are in `0..getIdCount()`.
pub fn getIdCount(pool: *Pool) usize {
    return 1 + pool.threads.len;
}