#![allow(async_fn_in_trait)] // Implementations may use native async fn for Worker/WorkerFactory

//! Unified worker supervision and lifecycle management
//!
//! This module provides a trait-based interface for background workers with:
//! - Consistent initialization and shutdown
//! - Automatic restart on failure with configurable policies
//! - Error collection and reporting
//! - Health monitoring
//!
//! ## Usage
//!
//! ```rust,ignore
//! // Define a worker
//! struct MyWorker { pool: Pool }
//!
//! impl Worker for MyWorker {
//!     fn name(&self) -> &'static str { "my_worker" }
//!
//!     async fn run(self, stop: WatchReceiver<bool>) -> eyre::Result<()> {
//!         // Worker logic here
//!         Ok(())
//!     }
//! }
//!
//! // Define a factory so the supervisor can recreate the worker on restart
//! #[derive(Clone)]
//! struct MyWorkerFactory { pool: Pool }
//!
//! impl WorkerFactory for MyWorkerFactory {
//!     type Worker = MyWorker;
//!
//!     fn name(&self) -> &'static str { "my_worker" }
//!
//!     async fn create(&self) -> eyre::Result<MyWorker> {
//!         Ok(MyWorker { pool: self.pool.clone() })
//!     }
//! }
//!
//! // Spawn with supervision
//! let supervisor = WorkerSupervisor::new();
//! supervisor.spawn(
//!     MyWorkerFactory { pool },
//!     stop.clone(),
//!     RestartPolicy::Backoff { max_retries: 5 },
//! );
//! ```

use eyre::Result;
use std::time::Duration;
use tokio::sync::watch::Receiver as WatchReceiver;
use tokio::task::JoinHandle;
use tokio_util::task::TaskTracker;
use tracing::{error, info, warn};

/// Worker trait for background tasks
///
/// All workers implement this trait to provide consistent lifecycle management.
///
/// Note: `run` is declared with return-position `impl Future ... + Send` (stable since
/// Rust 1.75), so the returned future is guaranteed to be `Send`; implementations can
/// still write it as a native `async fn`.
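///
/// A minimal sketch of an implementation (the `do_work` helper and the 30-second
/// interval are illustrative assumptions, not part of this module):
///
/// ```rust,ignore
/// struct TickWorker;
///
/// impl Worker for TickWorker {
///     fn name(&self) -> &'static str { "tick_worker" }
///
///     async fn run(self, mut stop: WatchReceiver<bool>) -> eyre::Result<()> {
///         let mut interval = tokio::time::interval(Duration::from_secs(30));
///         loop {
///             tokio::select! {
///                 _ = interval.tick() => {
///                     // Handle transient errors internally; only return Err for fatal ones
///                     do_work().await?;
///                 }
///                 _ = stop.changed() => {
///                     if *stop.borrow() {
///                         return Ok(()); // graceful shutdown
///                     }
///                 }
///             }
///         }
///     }
/// }
/// ```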
pub trait Worker: Send + 'static {
    /// Worker name for logging and metrics
    ///
    /// Should be a static string like "cleanup_worker" or "handle_resolution"
    fn name(&self) -> &'static str;

    /// Run the worker until stopped
    ///
    /// This is the main worker loop. Workers should:
    /// - Poll the stop signal regularly (via select!)
    /// - Return Ok(()) on graceful shutdown
    /// - Return Err(_) only for unrecoverable errors
    ///
    /// Transient errors should be handled internally with retries.
    fn run(self, stop: WatchReceiver<bool>) -> impl std::future::Future<Output = Result<()>> + Send
    where
        Self: Sized;
}

/// Factory trait for creating restartable workers
///
/// Workers that support restart must provide a factory; the supervisor calls
/// `create` for the initial spawn and again before every restart attempt, so each
/// attempt runs a fresh worker instance.
///
/// Note: like [`Worker::run`], `create` is declared with return-position
/// `impl Future ... + Send`, so implementations can write it as a native `async fn`.
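///
/// A sketch of a factory that re-acquires its resources on every attempt
/// (`CleanupWorker`, `Pool::connect`, and the `database_url` field are
/// illustrative assumptions):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct CleanupWorkerFactory { database_url: String }
///
/// impl WorkerFactory for CleanupWorkerFactory {
///     type Worker = CleanupWorker;
///
///     fn name(&self) -> &'static str { "cleanup_worker" }
///
///     async fn create(&self) -> eyre::Result<CleanupWorker> {
///         // Reconnect on every attempt so a restarted worker gets a fresh pool
///         let pool = Pool::connect(&self.database_url).await?;
///         Ok(CleanupWorker { pool })
///     }
/// }
/// ```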
pub trait WorkerFactory: Clone + Send + 'static {
    /// The worker type this factory creates
    type Worker: Worker;

    /// Worker name (used for logging before worker is created)
    fn name(&self) -> &'static str;

    /// Create a new worker instance
    ///
    /// Called on initial spawn and each restart attempt.
    /// Should return Err only if worker creation is impossible.
    fn create(&self) -> impl std::future::Future<Output = Result<Self::Worker>> + Send;
}

/// Restart policy for worker supervision
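///
/// For example (the retry counts are illustrative):
///
/// ```rust,ignore
/// // Critical infrastructure: restart forever, 1s between attempts
/// let critical = RestartPolicy::Always;
/// // Network-dependent worker: exponential backoff, give up after 5 failures
/// let network = RestartPolicy::Backoff { max_retries: 5 };
/// // One-shot task: never restart
/// let oneshot = RestartPolicy::Never;
/// ```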
#[derive(Debug, Clone, Copy)]
pub enum RestartPolicy {
    /// Never restart - worker runs once
    ///
    /// Use for:
    /// - One-shot initialization tasks
    /// - Workers that should fail-fast
    Never,

    /// Restart immediately up to max_retries times
    ///
    /// Use for:
    /// - Quick recovery from transient errors
    /// - Workers with external rate limiting
    Immediate { max_retries: u32 },

    /// Restart with exponential backoff (1s, 2s, 4s, 8s... up to 60s)
    ///
    /// Use for:
    /// - Network-dependent workers
    /// - Workers that might overwhelm external services
    Backoff { max_retries: u32 },

    /// Always restart, no limit (a fixed 1s delay between attempts)
    ///
    /// Use for:
    /// - Critical infrastructure workers
    /// - Workers that must never stop
    Always,
}

impl RestartPolicy {
    /// Get delay for given attempt number (0-indexed)
    fn delay_for_attempt(&self, attempt: u32) -> Option<Duration> {
        match self {
            RestartPolicy::Never => None,
            RestartPolicy::Immediate { max_retries } => {
                if attempt < *max_retries {
                    Some(Duration::from_millis(0))
                } else {
                    None
                }
            }
            RestartPolicy::Backoff { max_retries } => {
                if attempt >= *max_retries {
                    return None;
                }
                // Exponential: 1s, 2s, 4s, 8s, 16s, 32s, capped at 60s.
                // Clamp the shift amount so large attempt counts cannot overflow.
                let delay_secs = (1_u64 << attempt.min(6)).min(60);
                Some(Duration::from_secs(delay_secs))
            }
            RestartPolicy::Always => Some(Duration::from_secs(1)),
        }
    }

    /// Check if restart is allowed for given attempt
    fn should_restart(&self, attempt: u32) -> bool {
        self.delay_for_attempt(attempt).is_some()
    }
}

/// Worker status after completion
#[derive(Debug)]
pub enum WorkerStatus {
    /// Worker completed successfully
    Completed { name: &'static str },

    /// Worker failed and will not restart
    Failed {
        name: &'static str,
        error: eyre::Report,
        attempts: u32,
    },

    /// Worker stopped by signal
    Stopped { name: &'static str },

    /// Worker panicked
    Panicked { name: &'static str },
}

/// Handle to a supervised worker
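///
/// A short usage sketch (the `supervisor`, `factory`, and `stop` values are
/// assumed to exist in the caller):
///
/// ```rust,ignore
/// let handle = supervisor.spawn(factory, stop.clone(), RestartPolicy::Always);
/// match handle.wait().await {
///     WorkerStatus::Completed { name } => info!(worker = name, "worker finished"),
///     WorkerStatus::Failed { name, error, attempts } => {
///         error!(worker = name, error = ?error, attempts, "worker gave up");
///     }
///     WorkerStatus::Stopped { name } => info!(worker = name, "worker stopped"),
///     WorkerStatus::Panicked { name } => error!(worker = name, "worker panicked"),
/// }
/// ```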
pub struct WorkerHandle {
    name: &'static str,
    join_handle: JoinHandle<WorkerStatus>,
}

impl WorkerHandle {
    /// Wait for worker to complete and get final status
    pub async fn wait(self) -> WorkerStatus {
        match self.join_handle.await {
            Ok(status) => status,
            Err(_) => WorkerStatus::Panicked { name: self.name },
        }
    }

    /// Worker name
    pub fn name(&self) -> &'static str {
        self.name
    }
}

/// Supervisor for managing worker lifecycle with restart policies
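///
/// A typical lifecycle, assuming the caller owns a `tokio::sync::watch` stop
/// channel (names are illustrative):
///
/// ```rust,ignore
/// let (stop_tx, stop_rx) = tokio::sync::watch::channel(false);
/// let supervisor = WorkerSupervisor::new();
/// let handle = supervisor.spawn_oneshot(MyWorker { pool }, stop_rx.clone());
///
/// // ... later, during shutdown:
/// stop_tx.send(true)?;              // signal every worker to stop
/// supervisor.close();               // no new workers may be spawned
/// supervisor.wait().await;          // wait for all supervised tasks to exit
/// let status = handle.wait().await; // inspect an individual worker's outcome
/// ```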
pub struct WorkerSupervisor {
    tracker: TaskTracker,
}

impl WorkerSupervisor {
    /// Create a new worker supervisor
    pub fn new() -> Self {
        Self {
            tracker: TaskTracker::new(),
        }
    }

    /// Spawn a one-shot worker (no restart support)
    ///
    /// Use this for workers that should never restart.
    /// The worker is consumed on first run.
    pub fn spawn_oneshot<W: Worker>(&self, worker: W, stop: WatchReceiver<bool>) -> WorkerHandle {
        let name = worker.name();
        let handle = self.tracker.spawn(Self::supervise_oneshot(worker, stop));

        WorkerHandle {
            name,
            join_handle: handle,
        }
    }

    /// Spawn a restartable worker using a factory
    ///
    /// The factory will be used to create new worker instances on restart.
    pub fn spawn<F: WorkerFactory>(
        &self,
        factory: F,
        stop: WatchReceiver<bool>,
        policy: RestartPolicy,
    ) -> WorkerHandle {
        let name = factory.name();
        let handle = self.tracker.spawn(Self::supervise(factory, stop, policy));

        WorkerHandle {
            name,
            join_handle: handle,
        }
    }

    /// Supervise a one-shot worker (no restart)
    async fn supervise_oneshot<W: Worker>(
        worker: W,
        stop: WatchReceiver<bool>,
    ) -> WorkerStatus {
        let name = worker.name();
        info!(worker = name, "Worker starting (one-shot)");

        match worker.run(stop).await {
            Ok(()) => {
                info!(worker = name, "Worker completed successfully");
                WorkerStatus::Completed { name }
            }
            Err(e) => {
                error!(worker = name, error = ?e, "Worker failed");
                WorkerStatus::Failed {
                    name,
                    error: e,
                    attempts: 1,
                }
            }
        }
    }

    /// Supervise a restartable worker with restart logic
    async fn supervise<F: WorkerFactory>(
        factory: F,
        stop: WatchReceiver<bool>,
        policy: RestartPolicy,
    ) -> WorkerStatus {
        let name = factory.name();
        let mut attempt = 0_u32;

        loop {
            // Create worker instance
            let worker = match factory.create().await {
                Ok(w) => w,
                Err(e) => {
                    error!(
                        worker = name,
                        attempt = attempt + 1,
                        error = ?e,
                        "Failed to create worker"
                    );

                    // If we can't create the worker, check restart policy
                    if !policy.should_restart(attempt) {
                        return WorkerStatus::Failed {
                            name,
                            error: e.wrap_err("Failed to create worker"),
                            attempts: attempt + 1,
                        };
                    }

                    // Wait before retry
                    if let Some(delay) = policy.delay_for_attempt(attempt) {
                        if delay > Duration::from_millis(0) {
                            warn!(
                                worker = name,
                                attempt = attempt + 2,
                                delay_secs = delay.as_secs(),
                                "Waiting before worker creation retry"
                            );
                            tokio::time::sleep(delay).await;
                        }
                    }

                    attempt += 1;
                    continue;
                }
            };

            // Run worker
            info!(worker = name, attempt = attempt + 1, "Worker starting");
            match worker.run(stop.clone()).await {
                Ok(()) => {
                    info!(worker = name, "Worker completed successfully");
                    return WorkerStatus::Completed { name };
                }
                Err(e) => {
                    error!(
                        worker = name,
                        attempt = attempt + 1,
                        error = ?e,
                        "Worker failed"
                    );

                    // Check if we should restart
                    if !policy.should_restart(attempt) {
                        error!(
                            worker = name,
                            attempts = attempt + 1,
                            "Worker will not restart (max attempts reached)"
                        );
                        return WorkerStatus::Failed {
                            name,
                            error: e,
                            attempts: attempt + 1,
                        };
                    }

                    // Wait before restart
                    if let Some(delay) = policy.delay_for_attempt(attempt) {
                        if delay > Duration::from_millis(0) {
                            warn!(
                                worker = name,
                                attempt = attempt + 2,
                                delay_secs = delay.as_secs(),
                                "Waiting before worker restart"
                            );
                            tokio::time::sleep(delay).await;
                        }
                    }

                    attempt += 1;
                }
            }
        }
    }

    /// Wait for all workers to complete
    ///
    /// Note: this currently returns an empty Vec because TaskTracker does not
    /// expose per-task results. Use [`WorkerHandle::wait`] to collect the status
    /// of an individual worker.
    pub async fn wait_all(self) -> Vec<WorkerStatus> {
        self.tracker.close();
        self.tracker.wait().await;
        Vec::new()
    }

    /// Close the supervisor (no new workers can be spawned)
    pub fn close(&self) {
        self.tracker.close();
    }

    /// Wait for all workers without consuming self
    pub async fn wait(&self) {
        self.tracker.wait().await;
    }
}

impl Default for WorkerSupervisor {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_restart_policy_never() {
        let policy = RestartPolicy::Never;
        assert!(!policy.should_restart(0));
        assert!(!policy.should_restart(1));
    }

    #[test]
    fn test_restart_policy_immediate() {
        let policy = RestartPolicy::Immediate { max_retries: 3 };
        assert!(policy.should_restart(0));
        assert!(policy.should_restart(1));
        assert!(policy.should_restart(2));
        assert!(!policy.should_restart(3));
    }

    #[test]
    fn test_restart_policy_backoff() {
        let policy = RestartPolicy::Backoff { max_retries: 5 };

        // Delays should be: 1s, 2s, 4s, 8s, 16s
        assert_eq!(policy.delay_for_attempt(0), Some(Duration::from_secs(1)));
        assert_eq!(policy.delay_for_attempt(1), Some(Duration::from_secs(2)));
        assert_eq!(policy.delay_for_attempt(2), Some(Duration::from_secs(4)));
        assert_eq!(policy.delay_for_attempt(3), Some(Duration::from_secs(8)));
        assert_eq!(policy.delay_for_attempt(4), Some(Duration::from_secs(16)));
        assert_eq!(policy.delay_for_attempt(5), None); // max_retries reached
    }

    #[test]
    fn test_restart_policy_backoff_capped() {
        let policy = RestartPolicy::Backoff { max_retries: 10 };

        // 1s, 2s, 4s, 8s, 16s, 32s, then capped at 60s
        assert_eq!(policy.delay_for_attempt(5), Some(Duration::from_secs(32)));
        assert_eq!(policy.delay_for_attempt(6), Some(Duration::from_secs(60)));
        assert_eq!(policy.delay_for_attempt(7), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_restart_policy_always() {
        let policy = RestartPolicy::Always;
        assert!(policy.should_restart(0));
        assert!(policy.should_restart(100));
        assert!(policy.should_restart(1000));
    }
}