Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "gorm.io/gorm"
12)
13
// TaskStatus is the lifecycle state of an AppTask. It is persisted as its
// string value in the status column.
type TaskStatus string

// Task lifecycle states.
const (
	// TaskStatusPending marks a task that is waiting to be claimed by a worker.
	TaskStatusPending TaskStatus = "PENDING"
	// TaskStatusProcessing marks a task currently claimed (locked) by a worker.
	TaskStatusProcessing TaskStatus = "PROCESSING"
	// TaskStatusCompleted marks a task that finished successfully.
	TaskStatusCompleted TaskStatus = "COMPLETED"
	// TaskStatusFailed marks a task that ran out of tries.
	TaskStatusFailed TaskStatus = "FAILED"
	// TaskStatusRetrying is declared for retrying tasks.
	// NOTE(review): no function in this chunk assigns it — confirm it is used.
	TaskStatusRetrying TaskStatus = "RETRYING"
)
24
25// AppTask represents a task in the queue
26type AppTask struct {
27 ID uint `gorm:"column:id;primarykey"`
28 Type string `gorm:"column:type;not null;index"`
29 TaskKey *string `gorm:"column:task_key;index:idx_task_dedup,unique"`
30 Status TaskStatus `gorm:"column:status;not null;index;default:'PENDING'"`
31 Payload json.RawMessage `gorm:"column:payload;type:jsonb"`
32 Priority int `gorm:"column:priority;default:0;index"`
33 TryCount int `gorm:"column:try_count;default:0"`
34 MaxTries int `gorm:"column:max_tries;default:3"`
35 LockExpires *time.Time `gorm:"column:lock_expires"`
36 WorkerID *string `gorm:"column:worker_id"`
37 Error *string `gorm:"column:error"`
38 CreatedAt time.Time `gorm:"column:created_at"`
39 UpdatedAt time.Time `gorm:"column:updated_at"`
40 ScheduledAt *time.Time `gorm:"column:scheduled_at"` // for delayed tasks
41}
42
43// EnqueueTask adds a new task to the queue
44func (state *StatefulDB) EnqueueTask(ctx context.Context, taskType string, payload any, options ...TaskOption) (*AppTask, error) {
45 payloadBytes, err := json.Marshal(payload)
46 if err != nil {
47 return nil, fmt.Errorf("failed to marshal payload: %w", err)
48 }
49
50 task := &AppTask{
51 Type: taskType,
52 Status: TaskStatusPending,
53 Payload: payloadBytes,
54 Priority: 0,
55 MaxTries: 3,
56 }
57
58 // Apply options
59 for _, opt := range options {
60 opt(task)
61 }
62
63 // If task has a key, check for deduplication
64 if task.TaskKey != nil {
65 existingTask, err := state.GetTaskByKey(ctx, *task.TaskKey)
66 if err != nil {
67 return nil, fmt.Errorf("failed to check for existing task: %w", err)
68 }
69 if existingTask != nil {
70 // Task already exists, return the existing one
71 return existingTask, nil
72 }
73 }
74
75 if err := state.DB.WithContext(ctx).Create(task).Error; err != nil {
76 // Handle unique constraint violation gracefully
77 if strings.Contains(err.Error(), "duplicate") || strings.Contains(err.Error(), "UNIQUE constraint") {
78 // Another node beat us to it, try to fetch the existing task
79 if task.TaskKey != nil {
80 existingTask, fetchErr := state.GetTaskByKey(ctx, *task.TaskKey)
81 if fetchErr == nil && existingTask != nil {
82 return existingTask, nil
83 }
84 }
85 }
86 return nil, fmt.Errorf("failed to enqueue task: %w", err)
87 }
88
89 go func() {
90 select {
91 case state.pokeQueue <- struct{}{}:
92 // wake up the queue processor
93 default:
94 // queue is already awake, do nothing
95 }
96 }()
97
98 return task, nil
99}
100
101// DequeueTask retrieves the next available task from the queue and locks it
102func (state *StatefulDB) DequeueTask(ctx context.Context, workerID string, taskTypes ...string) (*AppTask, error) {
103 var task AppTask
104
105 err := state.DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
106 query := tx.Where("status = ?", TaskStatusPending).
107 Where("try_count < max_tries").
108 Where("(lock_expires IS NULL OR lock_expires < ?)", time.Now()).
109 Where("(scheduled_at IS NULL OR scheduled_at <= ?)", time.Now())
110
111 if len(taskTypes) > 0 {
112 query = query.Where("type IN ?", taskTypes)
113 }
114
115 // Use raw SQL for PostgreSQL-specific locking
116 if state.Type == DBTypePostgres {
117 baseQuery := "SELECT * FROM app_tasks WHERE status = ? AND try_count < max_tries AND (lock_expires IS NULL OR lock_expires < ?) AND (scheduled_at IS NULL OR scheduled_at <= ?)"
118 if len(taskTypes) > 0 {
119 baseQuery += " AND type IN ?"
120 params := []interface{}{TaskStatusPending, time.Now(), time.Now(), taskTypes}
121 err := tx.Raw(baseQuery+" ORDER BY priority DESC, created_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED", params...).
122 Scan(&task).Error
123 if err != nil {
124 return err
125 }
126 } else {
127 err := tx.Raw(baseQuery+" ORDER BY priority DESC, created_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED",
128 TaskStatusPending, time.Now(), time.Now()).
129 Scan(&task).Error
130 if err != nil {
131 return err
132 }
133 }
134 } else {
135 // Fallback for SQLite (no SKIP LOCKED support)
136 err := query.Order("priority DESC, created_at ASC").First(&task).Error
137 if err != nil {
138 return err
139 }
140 }
141
142 if task.ID == 0 {
143 return gorm.ErrRecordNotFound
144 }
145
146 // Lock the task
147 lockExpires := time.Now().Add(30 * time.Minute) // 30-minute lock
148 updates := map[string]interface{}{
149 "status": TaskStatusProcessing,
150 "worker_id": workerID,
151 "lock_expires": lockExpires,
152 "try_count": task.TryCount + 1,
153 }
154
155 return tx.Model(&task).Updates(updates).Error
156 })
157
158 if err != nil {
159 if errors.Is(err, gorm.ErrRecordNotFound) {
160 return nil, nil // No tasks available
161 }
162 return nil, fmt.Errorf("failed to dequeue task: %w", err)
163 }
164
165 // Reload the task to get updated fields
166 if err := state.DB.WithContext(ctx).First(&task, task.ID).Error; err != nil {
167 return nil, fmt.Errorf("failed to reload task: %w", err)
168 }
169
170 return &task, nil
171}
172
173// CompleteTask marks a task as completed
174func (state *StatefulDB) CompleteTask(ctx context.Context, taskID uint) error {
175 result := state.DB.WithContext(ctx).Model(&AppTask{}).
176 Where("id = ?", taskID).
177 Updates(map[string]interface{}{
178 "status": TaskStatusCompleted,
179 "lock_expires": nil,
180 "worker_id": nil,
181 })
182
183 if result.Error != nil {
184 return fmt.Errorf("failed to complete task: %w", result.Error)
185 }
186
187 if result.RowsAffected == 0 {
188 return errors.New("task not found")
189 }
190
191 return nil
192}
193
194// FailTask marks a task as failed and optionally retries it
195func (state *StatefulDB) FailTask(ctx context.Context, taskID uint, errorMsg string) error {
196 var task AppTask
197 err := state.DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
198 if err := tx.First(&task, taskID).Error; err != nil {
199 return err
200 }
201
202 updates := map[string]interface{}{
203 "error": errorMsg,
204 "lock_expires": nil,
205 "worker_id": nil,
206 }
207
208 if task.TryCount >= task.MaxTries {
209 updates["status"] = TaskStatusFailed
210 } else {
211 updates["status"] = TaskStatusPending
212 }
213
214 return tx.Model(&task).Updates(updates).Error
215 })
216
217 if err != nil {
218 return fmt.Errorf("failed to mark task as failed: %w", err)
219 }
220
221 return nil
222}
223
224// ReleaseTask releases a locked task back to the queue (e.g., worker shutdown)
225func (state *StatefulDB) ReleaseTask(ctx context.Context, taskID uint) error {
226 result := state.DB.WithContext(ctx).Model(&AppTask{}).
227 Where("id = ?", taskID).
228 Updates(map[string]interface{}{
229 "status": TaskStatusPending,
230 "lock_expires": nil,
231 "worker_id": nil,
232 })
233
234 if result.Error != nil {
235 return fmt.Errorf("failed to release task: %w", result.Error)
236 }
237
238 if result.RowsAffected == 0 {
239 return errors.New("task not found")
240 }
241
242 return nil
243}
244
245// GetTask retrieves a task by ID
246func (state *StatefulDB) GetTask(ctx context.Context, taskID uint) (*AppTask, error) {
247 var task AppTask
248 if err := state.DB.WithContext(ctx).First(&task, taskID).Error; err != nil {
249 if errors.Is(err, gorm.ErrRecordNotFound) {
250 return nil, nil
251 }
252 return nil, fmt.Errorf("failed to get task: %w", err)
253 }
254 return &task, nil
255}
256
257// GetTaskByKey retrieves a task by its unique task key
258func (state *StatefulDB) GetTaskByKey(ctx context.Context, taskKey string) (*AppTask, error) {
259 var task AppTask
260 if err := state.DB.WithContext(ctx).Where("task_key = ?", taskKey).First(&task).Error; err != nil {
261 if errors.Is(err, gorm.ErrRecordNotFound) {
262 return nil, nil
263 }
264 return nil, fmt.Errorf("failed to get task by key: %w", err)
265 }
266 return &task, nil
267}
268
269// ListTasks retrieves tasks with optional filters
270func (state *StatefulDB) ListTasks(ctx context.Context, filters TaskFilters) ([]AppTask, error) {
271 var tasks []AppTask
272 query := state.DB.WithContext(ctx).Model(&AppTask{})
273
274 if filters.Status != "" {
275 query = query.Where("status = ?", filters.Status)
276 }
277 if filters.Type != "" {
278 query = query.Where("type = ?", filters.Type)
279 }
280 if filters.TaskKey != "" {
281 query = query.Where("task_key = ?", filters.TaskKey)
282 }
283 if filters.WorkerID != "" {
284 query = query.Where("worker_id = ?", filters.WorkerID)
285 }
286 if filters.Limit > 0 {
287 query = query.Limit(filters.Limit)
288 }
289 if filters.Offset > 0 {
290 query = query.Offset(filters.Offset)
291 }
292
293 query = query.Order("created_at DESC")
294
295 if err := query.Find(&tasks).Error; err != nil {
296 return nil, fmt.Errorf("failed to list tasks: %w", err)
297 }
298
299 return tasks, nil
300}
301
302// CleanupExpiredLocks releases tasks with expired locks
303func (state *StatefulDB) CleanupExpiredLocks(ctx context.Context) (int64, error) {
304 result := state.DB.WithContext(ctx).Model(&AppTask{}).
305 Where("status = ? AND lock_expires < ?", TaskStatusProcessing, time.Now()).
306 Updates(map[string]interface{}{
307 "status": TaskStatusPending,
308 "lock_expires": nil,
309 "worker_id": nil,
310 })
311
312 if result.Error != nil {
313 return 0, fmt.Errorf("failed to cleanup expired locks: %w", result.Error)
314 }
315
316 return result.RowsAffected, nil
317}
318
319// TaskOption is a function that configures a task
320type TaskOption func(*AppTask)
321
322// WithPriority sets the task priority (higher numbers = higher priority)
323func WithPriority(priority int) TaskOption {
324 return func(t *AppTask) {
325 t.Priority = priority
326 }
327}
328
329// WithMaxTries sets the maximum number of retry attempts
330func WithMaxTries(maxTries int) TaskOption {
331 return func(t *AppTask) {
332 t.MaxTries = maxTries
333 }
334}
335
336// WithScheduledAt sets when the task should be processed (for delayed tasks)
337func WithScheduledAt(scheduledAt time.Time) TaskOption {
338 return func(t *AppTask) {
339 t.ScheduledAt = &scheduledAt
340 }
341}
342
343// WithTaskKey sets a unique key for task deduplication
344func WithTaskKey(taskKey string) TaskOption {
345 return func(t *AppTask) {
346 t.TaskKey = &taskKey
347 }
348}
349
350// TaskFilters holds filters for listing tasks
351type TaskFilters struct {
352 Status TaskStatus
353 Type string
354 TaskKey string
355 WorkerID string
356 Limit int
357 Offset int
358}