Like malachite (atproto-lastfm-importer), but written in Go and bluer.
go spotify tealfm lastfm atproto
at main 532 lines 15 kB view raw
1package sync 2 3import ( 4 "cmp" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log/slog" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/atclient" 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 "github.com/failsafe-go/failsafe-go" 15 "github.com/failsafe-go/failsafe-go/retrypolicy" 16 17 "tangled.org/karitham.dev/lazuli/atproto" 18 "tangled.org/karitham.dev/lazuli/cache" 19) 20 21const ( 22 WriteLimitDay = atproto.WriteLimitDay 23 GlobalLimitDay = atproto.GlobalLimitDay 24 25 RecordType = "fm.teal.alpha.feed.play" 26 DefaultBatchSize = 20 27 DefaultCrossSourceTolerance = 5 * time.Minute 28 CrossSourceTolerance = DefaultCrossSourceTolerance 29 CacheTTL = 24 * time.Hour 30 CacheVersion = 1 31 SlingshotResolverURL = "https://slingshot.microcosm.blue/xrpc/com.bad-example.identity.resolveMiniDoc" 32 MaxRetryDelay = 15 * time.Minute 33 MaxRetries = 1000 34 BaseRetryDelay = 2 * time.Second 35) 36 37var DefaultRetryPolicy = retrypolicy.NewBuilder[struct{}](). 38 WithMaxRetries(10). 39 WithBackoff(BaseRetryDelay, 5*time.Minute). 40 HandleIf(func(_ struct{}, err error) bool { 41 return atproto.IsTransientError(err) 42 }). 43 OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[struct{}]) { 44 slog.Warn("batch failed with transient error, retrying", 45 slog.Duration("retryDelay", e.Delay), 46 ErrorAttr(e.LastError()), 47 slog.Int("attempt", e.Attempts())) 48 }). 
49 Build() 50 51type ( 52 ATProtoClient = atproto.RepoClient[*PlayRecord] 53 AuthClient = atproto.AuthClient 54 RateLimiter = atproto.RateLimiter 55 Client = atproto.Client 56 RepoClient[T any] = atproto.RepoClient[T] 57 RecordRef = atproto.RecordRef[PlayRecord] 58 59 PublishOptions struct { 60 BatchSize int 61 DryRun bool 62 Reverse bool 63 ATProtoClient ATProtoClient 64 ProgressLog func(ProgressReport) 65 Storage cache.RecordStore 66 Limiter RateLimiter 67 ClientAgent string 68 RetryDelay time.Duration 69 } 70 71 publishResult struct { 72 SuccessCount int `json:"successCount"` 73 ErrorCount int `json:"errorCount"` 74 Cancelled bool `json:"cancelled"` 75 Duration time.Duration `json:"duration"` 76 TotalRecords int `json:"totalRecords"` 77 RecordsPerMinute float64 `json:"recordsPerMinute"` 78 FirstRecordTime time.Time `json:"firstRecordTime"` 79 LastRecordTime time.Time `json:"lastRecordTime"` 80 } 81 82 recordBatch struct { 83 Records []*PlayRecord 84 Keys []string 85 } 86 87 batchProcessor struct { 88 Client ATProtoClient 89 Storage cache.RecordStore 90 DID string 91 ClientAgent string 92 DryRun bool 93 } 94 95 batchResult struct { 96 SuccessCount int 97 ErrorCount int 98 Duration time.Duration 99 Errors []error 100 } 101) 102 103func NewRateLimiter(kv atproto.KVStore, maxPercent float32) RateLimiter { 104 return atproto.NewRateLimiter(kv, maxPercent) 105} 106 107func NewRateClient(client *atclient.APIClient, did string, limiter RateLimiter) *atproto.RateClient[*PlayRecord] { 108 return atproto.NewRateClient[*PlayRecord](client, did, limiter) 109} 110 111func IsTransientError(err error) bool { 112 return atproto.IsTransientError(err) 113} 114 115func NewClient(ctx context.Context, handle, password string, opts ...func(*atproto.ClientOptions)) (*Client, error) { 116 return atproto.NewClient(ctx, handle, password, opts...) 
117} 118 119// batchRecords iterates through storage and builds record batches 120func batchRecords(ctx context.Context, storage cache.RecordStore, did string, batchSize int, reverse bool) ([]recordBatch, error) { 121 var batches []recordBatch 122 var currentBatch recordBatch 123 124 for key, rec := range storage.IterateUnpublished(did, reverse) { 125 select { 126 case <-ctx.Done(): 127 return nil, ctx.Err() 128 default: 129 } 130 131 var record PlayRecord 132 if err := json.Unmarshal(rec, &record); err != nil { 133 slog.Error("malformed record in storage", slog.String("key", key), ErrorAttr(err)) 134 if storage != nil { 135 _ = storage.MarkFailed(did, []string{key}, "malformed record") 136 } 137 continue // Skip malformed records 138 } 139 140 currentBatch.Records = append(currentBatch.Records, &record) 141 currentBatch.Keys = append(currentBatch.Keys, key) 142 143 if len(currentBatch.Records) >= batchSize { 144 records := make([]*PlayRecord, len(currentBatch.Records)) 145 copy(records, currentBatch.Records) 146 keys := make([]string, len(currentBatch.Keys)) 147 copy(keys, currentBatch.Keys) 148 batches = append(batches, recordBatch{ 149 Records: records, 150 Keys: keys, 151 }) 152 currentBatch = recordBatch{} 153 } 154 } 155 156 // Add the last partial batch if it has records 157 if len(currentBatch.Records) > 0 { 158 batches = append(batches, currentBatch) 159 } 160 161 return batches, nil 162} 163 164// processBatch processes a single batch of records with retries 165func processBatch(ctx context.Context, batch recordBatch, processor batchProcessor) batchResult { 166 if len(batch.Records) == 0 { 167 return batchResult{} 168 } 169 170 start := time.Now() 171 172 if processor.DryRun { 173 for _, r := range batch.Records { 174 tid := syntax.NewTIDFromTime(r.PlayedTime.Time, 0) 175 slog.Info("would publish record (dry run)", trackAttr(*r), slog.String("rkey", string(tid))) 176 } 177 return batchResult{ 178 SuccessCount: len(batch.Records), 179 Duration: 
time.Since(start), 180 } 181 } 182 183 policy := DefaultRetryPolicy 184 185 err := failsafe.With(policy).WithContext(ctx).Run(func() error { 186 return PublishBatch(ctx, processor.Client, processor.DID, batch.Records, processor.Storage, processor.ClientAgent) 187 }) 188 if err != nil { 189 slog.Error("batch failed after retries", 190 ErrorAttr(err), 191 slog.Int("count", len(batch.Records))) 192 193 if processor.Storage != nil { 194 if markErr := processor.Storage.MarkFailed(processor.DID, batch.Keys, err.Error()); markErr != nil { 195 slog.Error("failed to mark records as failed", ErrorAttr(markErr)) 196 } 197 } 198 199 return batchResult{ 200 ErrorCount: len(batch.Records), 201 Duration: time.Since(start), 202 Errors: []error{err}, 203 } 204 } 205 206 return batchResult{ 207 SuccessCount: len(batch.Records), 208 Duration: time.Since(start), 209 } 210} 211 212// aggregate combines batch results into final publish result 213func aggregate(results []batchResult, startTime time.Time) publishResult { 214 totalSuccess := 0 215 totalErrors := 0 216 217 for _, result := range results { 218 totalSuccess += result.SuccessCount 219 totalErrors += result.ErrorCount 220 } 221 222 logResult(totalSuccess, totalErrors, startTime) 223 return newPublishResult(totalSuccess, totalErrors, totalSuccess+totalErrors, startTime, false) 224} 225 226func Publish(ctx context.Context, client AuthClient, opts PublishOptions) publishResult { 227 startTime := time.Now() 228 batchSize := cmp.Or(opts.BatchSize, DefaultBatchSize) 229 230 atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient) 231 if err != nil { 232 return errorResult(startTime) 233 } 234 235 batches, err := batchRecords(ctx, opts.Storage, client.DID(), batchSize, opts.Reverse) 236 if err != nil { 237 cancelled := ctx.Err() != nil 238 return newPublishResult(0, 0, 0, startTime, cancelled) 239 } 240 241 if len(batches) == 0 { 242 return publishResult{} 243 } 244 245 slog.Info("starting iterative import", 246 
slog.Int("total_records", countTotalRecords(batches)), 247 slog.Int("batch_size", batchSize), 248 slog.Int("daily_write_limit", atproto.WriteLimitDay), 249 slog.Int("daily_token_limit", atproto.GlobalLimitDay), 250 slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/atproto.WriteLimitDay))) 251 252 tracker := NewProgressTracker(countTotalRecords(batches), opts.Limiter) 253 progressLog := defaultProgressLog(opts.ProgressLog) 254 255 processor := batchProcessor{ 256 Client: atprotoClient, 257 Storage: opts.Storage, 258 DID: client.DID(), 259 ClientAgent: opts.ClientAgent, 260 DryRun: opts.DryRun, 261 } 262 263 var results []batchResult 264 for _, batch := range batches { 265 select { 266 case <-ctx.Done(): 267 return aggregate(results, startTime) 268 default: 269 } 270 271 result := processBatch(ctx, batch, processor) 272 results = append(results, result) 273 274 // Update progress tracking 275 tracker.Increment(result.SuccessCount + result.ErrorCount) 276 tracker.IncrementErrors(result.ErrorCount) 277 278 if tracker.ShouldLog() { 279 progressLog(tracker.Report()) 280 } 281 } 282 283 return aggregate(results, startTime) 284} 285 286func countTotalRecords(batches []recordBatch) int { 287 total := 0 288 for _, batch := range batches { 289 total += len(batch.Records) 290 } 291 return total 292} 293 294func errorResult(startTime time.Time) publishResult { 295 return publishResult{ 296 SuccessCount: 0, 297 ErrorCount: 0, 298 Cancelled: false, 299 Duration: time.Since(startTime), 300 TotalRecords: 0, 301 } 302} 303 304func defaultProgressLog(f func(ProgressReport)) func(ProgressReport) { 305 if f != nil { 306 return f 307 } 308 return func(pr ProgressReport) { 309 slog.Info("sync progress", 310 slog.Int("completed", pr.Completed), 311 slog.Int("total", pr.Total), 312 slog.Float64("percent", pr.Percent), 313 slog.String("elapsed", pr.Elapsed), 314 slog.String("eta", pr.ETA), 315 slog.String("rate", pr.Rate), 316 slog.Int("errors", pr.Errors), 317 ) 318 } 319} 
320 321func newPublishResult(success, errors, total int, start time.Time, cancelled bool) publishResult { 322 return publishResult{ 323 SuccessCount: success, 324 ErrorCount: errors, 325 Cancelled: cancelled, 326 Duration: time.Since(start), 327 TotalRecords: total, 328 RecordsPerMinute: ratePerMinute(success, time.Since(start)), 329 } 330} 331 332func logResult(success, errors int, startTime time.Time) { 333 if errors > 0 { 334 slog.Warn("import completed with errors", 335 slog.Int("success", success), 336 slog.Int("errors", errors)) 337 } 338 slog.Info("import completed", 339 slog.Int("success", success), 340 slog.Int("errors", errors), 341 slog.Duration("duration", time.Since(startTime)), 342 slog.String("rate", formatRate(ratePerMinute(success, time.Since(startTime))))) 343} 344 345func PublishBatch(ctx context.Context, client ATProtoClient, did string, batch []*PlayRecord, storage cache.RecordStore, clientAgent string) error { 346 if len(batch) == 0 { 347 return nil 348 } 349 350 atprotoRecords := prepareRecords(batch, clientAgent) 351 err := client.ApplyWrites(ctx, RecordType, atprotoRecords) 352 if err != nil { 353 slog.Error("batch publish failed", ErrorAttr(err)) 354 return err 355 } 356 357 if storage != nil && did != "" { 358 keys := CreateRecordKeys(batch) 359 cacheEntries := make(map[string][]byte) 360 for i, rec := range batch { 361 key := keys[i] 362 value, _ := json.Marshal(rec) 363 cacheEntries[key] = value 364 } 365 366 if err := storage.SaveRecords(did, cacheEntries); err != nil { 367 return fmt.Errorf("failed to save records to storage: %w", err) 368 } 369 370 if err := storage.MarkPublished(did, keys...); err != nil { 371 return fmt.Errorf("failed to mark records as published: %w", err) 372 } 373 } 374 375 return nil 376} 377 378func prepareRecords(batch []*PlayRecord, clientAgent string) []*PlayRecord { 379 atprotoRecords := make([]*PlayRecord, 0, len(batch)) 380 for _, record := range batch { 381 record.Type = RecordType 382 
record.SubmissionClientAgent = clientAgent 383 atprotoRecords = append(atprotoRecords, record) 384 } 385 return atprotoRecords 386} 387 388func ratePerMinute(count int, duration time.Duration) float64 { 389 if duration == 0 { 390 return 0 391 } 392 return float64(count) / duration.Minutes() 393} 394 395func FetchExisting(ctx context.Context, client RepoClient[*PlayRecord], did string, storage cache.Storage, forceRefresh bool) ([]ExistingRecord, error) { 396 if !forceRefresh && storage != nil { 397 published, err := storage.GetPublished(did) 398 if err == nil && len(published) > 0 && storage.IsValid(did) { 399 records := make([]ExistingRecord, 0, len(published)) 400 for _, data := range storage.IteratePublished(did, false) { 401 var value PlayRecord 402 if err := json.Unmarshal(data, &value); err != nil { 403 slog.Debug("failed to unmarshal cached record", ErrorAttr(err)) 404 continue 405 } 406 records = append(records, ExistingRecord{ 407 URI: generateRecordURI(did, &value), 408 Value: &value, 409 }) 410 } 411 if len(records) > 0 { 412 slog.Debug("loaded from cache", slog.Int("count", len(records))) 413 return records, nil 414 } 415 } 416 } 417 418 select { 419 case <-ctx.Done(): 420 return nil, ctx.Err() 421 default: 422 } 423 424 allRecords := make([]ExistingRecord, 0, 1024) 425 return fetchExistingLoop(ctx, client, did, storage, allRecords) 426} 427 428func fetchExistingLoop(ctx context.Context, client RepoClient[*PlayRecord], did string, storage cache.Storage, allRecords []ExistingRecord) ([]ExistingRecord, error) { 429 const batchSize = 100 430 var cursor string 431 432 type fetchResult struct { 433 records []atproto.RecordRef[*PlayRecord] 434 cursor string 435 } 436 437 fetchRetryPolicy := retrypolicy.NewBuilder[fetchResult](). 438 WithMaxRetries(10). 439 WithBackoff(BaseRetryDelay, 5*time.Minute). 440 HandleIf(func(_ fetchResult, err error) bool { 441 return IsTransientError(err) 442 }). 
443 OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[fetchResult]) { 444 slog.Warn("fetch failed with transient error, retrying", 445 slog.Duration("retryDelay", e.Delay), 446 ErrorAttr(e.LastError()), 447 slog.Int("attempt", e.Attempts())) 448 }). 449 Build() 450 451 for { 452 select { 453 case <-ctx.Done(): 454 return nil, ctx.Err() 455 default: 456 } 457 458 result, err := failsafe.With(fetchRetryPolicy). 459 WithContext(ctx). 460 Get(func() (fetchResult, error) { 461 recs, next, err := client.ListRecords(ctx, RecordType, batchSize, cursor) 462 if err != nil { 463 return fetchResult{}, err 464 } 465 466 return fetchResult{records: recs, cursor: next}, nil 467 }) 468 if err != nil { 469 return nil, err 470 } 471 472 for _, rec := range result.records { 473 allRecords = append(allRecords, ExistingRecord(rec)) 474 } 475 476 if result.cursor == "" || len(result.records) < batchSize { 477 break 478 } 479 cursor = result.cursor 480 } 481 482 if storage != nil { 483 cacheEntries := make(map[string][]byte) 484 keys := make([]string, 0, len(allRecords)) 485 for _, rec := range allRecords { 486 parts := strings.Split(rec.URI, "/") 487 key := parts[len(parts)-1] 488 if key == "" { 489 key = CreateRecordKey(rec.Value) 490 } 491 value, _ := json.Marshal(rec.Value) 492 cacheEntries[key] = value 493 keys = append(keys, key) 494 } 495 496 if err := storage.SaveRecords(did, cacheEntries); err != nil { 497 return nil, err 498 } 499 500 if err := storage.MarkPublished(did, keys...); err != nil { 501 return nil, err 502 } 503 504 slog.Debug("saved to cache and marked as published", slog.Int("count", len(allRecords))) 505 } 506 507 return allRecords, nil 508} 509 510func generateRecordURI(did string, record *PlayRecord) string { 511 return fmt.Sprintf("at://%s/%s/%s", did, RecordType, CreateRecordKey(record)) 512} 513 514func prepareWrites(records []*PlayRecord, collection string) ([]map[string]any, error) { 515 if len(records) == 0 { 516 return nil, nil 517 } 518 519 writes 
:= make([]map[string]any, len(records)) 520 keys := CreateRecordKeys(records) 521 522 for i, rec := range records { 523 writes[i] = map[string]any{ 524 "$type": "com.atproto.repo.applyWrites#create", 525 "collection": collection, 526 "rkey": keys[i], 527 "value": rec, 528 } 529 } 530 531 return writes, nil 532}