// Package sync publishes locally cached play records to an atproto
// repository in rate-limited batches, and fetches previously published
// records back into the local cache.
package sync

import (
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/bluesky-social/indigo/atproto/atclient"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/failsafe-go/failsafe-go"
	"github.com/failsafe-go/failsafe-go/retrypolicy"

	"tangled.org/karitham.dev/lazuli/atproto"
	"tangled.org/karitham.dev/lazuli/cache"
)

const (
	// Re-exported daily limits from the atproto package.
	WriteLimitDay  = atproto.WriteLimitDay
	GlobalLimitDay = atproto.GlobalLimitDay

	// RecordType is the collection NSID used for play records.
	RecordType = "fm.teal.alpha.feed.play"

	// DefaultBatchSize is the number of records written per applyWrites call.
	DefaultBatchSize = 20

	// DefaultCrossSourceTolerance is the window within which plays from
	// different sources are considered the same play.
	DefaultCrossSourceTolerance = 5 * time.Minute
	CrossSourceTolerance        = DefaultCrossSourceTolerance

	// CacheTTL / CacheVersion govern local cache validity.
	CacheTTL     = 24 * time.Hour
	CacheVersion = 1

	// SlingshotResolverURL is the identity resolution endpoint.
	SlingshotResolverURL = "https://slingshot.microcosm.blue/xrpc/com.bad-example.identity.resolveMiniDoc"

	// Retry tuning. NOTE(review): MaxRetryDelay and MaxRetries are exported
	// but the policies below use their own literals (10 attempts, 5m cap) —
	// presumably consumed elsewhere; confirm before removing.
	MaxRetryDelay  = 15 * time.Minute
	MaxRetries     = 1000
	BaseRetryDelay = 2 * time.Second
)

// DefaultRetryPolicy retries transient publish failures with exponential
// backoff: up to 10 attempts, starting at BaseRetryDelay and capped at
// 5 minutes. Non-transient errors are not retried.
var DefaultRetryPolicy = retrypolicy.NewBuilder[struct{}]().
	WithMaxRetries(10).
	WithBackoff(BaseRetryDelay, 5*time.Minute).
	HandleIf(func(_ struct{}, err error) bool { return atproto.IsTransientError(err) }).
	OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[struct{}]) {
		slog.Warn("batch failed with transient error, retrying",
			slog.Duration("retryDelay", e.Delay),
			ErrorAttr(e.LastError()),
			slog.Int("attempt", e.Attempts()))
	}).
	Build()

type (
	// Aliases re-exporting atproto client types for this package's callers.
	ATProtoClient     = atproto.RepoClient[*PlayRecord]
	AuthClient        = atproto.AuthClient
	RateLimiter       = atproto.RateLimiter
	Client            = atproto.Client
	RepoClient[T any] = atproto.RepoClient[T]
	RecordRef         = atproto.RecordRef[PlayRecord]

	// PublishOptions configures a Publish run.
	PublishOptions struct {
		BatchSize     int  // records per applyWrites call; DefaultBatchSize when zero
		DryRun        bool // log what would be published without writing
		Reverse       bool // iterate unpublished records in reverse order
		ATProtoClient ATProtoClient
		ProgressLog   func(ProgressReport) // optional; a slog-based default is used when nil
		Storage       cache.RecordStore
		Limiter       RateLimiter
		ClientAgent   string
		RetryDelay    time.Duration
	}

	// publishResult summarizes a completed (or cancelled) Publish run.
	publishResult struct {
		SuccessCount     int           `json:"successCount"`
		ErrorCount       int           `json:"errorCount"`
		Cancelled        bool          `json:"cancelled"`
		Duration         time.Duration `json:"duration"`
		TotalRecords     int           `json:"totalRecords"`
		RecordsPerMinute float64       `json:"recordsPerMinute"`
		FirstRecordTime  time.Time     `json:"firstRecordTime"`
		LastRecordTime   time.Time     `json:"lastRecordTime"`
	}

	// recordBatch pairs decoded records with their storage keys so a failed
	// batch can be marked failed by key.
	recordBatch struct {
		Records []*PlayRecord
		Keys    []string
	}

	// batchProcessor carries everything processBatch needs for one batch.
	batchProcessor struct {
		Client      ATProtoClient
		Storage     cache.RecordStore
		DID         string
		ClientAgent string
		DryRun      bool
	}

	// batchResult is the outcome of processing a single batch.
	batchResult struct {
		SuccessCount int
		ErrorCount   int
		Duration     time.Duration
		Errors       []error
	}
)

// NewRateLimiter returns a rate limiter backed by kv, allowed to consume
// maxPercent of the daily write budget.
func NewRateLimiter(kv atproto.KVStore, maxPercent float32) RateLimiter {
	return atproto.NewRateLimiter(kv, maxPercent)
}

// NewRateClient wraps client in a rate-limited repo client for PlayRecords.
func NewRateClient(client *atclient.APIClient, did string, limiter RateLimiter) *atproto.RateClient[*PlayRecord] {
	return atproto.NewRateClient[*PlayRecord](client, did, limiter)
}

// IsTransientError reports whether err is worth retrying.
func IsTransientError(err error) bool {
	return atproto.IsTransientError(err)
}

// NewClient authenticates with handle/password and returns a client.
func NewClient(ctx context.Context, handle, password string, opts ...func(*atproto.ClientOptions)) (*Client, error) {
	return atproto.NewClient(ctx, handle, password, opts...)
}

// batchRecords iterates unpublished records in storage for did and groups
// them into batches of batchSize. Malformed records are marked failed and
// skipped. Returns ctx.Err() if the context is cancelled mid-iteration.
func batchRecords(ctx context.Context, storage cache.RecordStore, did string, batchSize int, reverse bool) ([]recordBatch, error) {
	var batches []recordBatch
	var current recordBatch

	for key, raw := range storage.IterateUnpublished(did, reverse) {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		var record PlayRecord
		if err := json.Unmarshal(raw, &record); err != nil {
			slog.Error("malformed record in storage", slog.String("key", key), ErrorAttr(err))
			// Best effort: storage is known non-nil here (we are iterating
			// it), so the previous nil guard was dead code.
			_ = storage.MarkFailed(did, []string{key}, "malformed record")
			continue // skip malformed records
		}

		current.Records = append(current.Records, &record)
		current.Keys = append(current.Keys, key)

		if len(current.Records) >= batchSize {
			// Resetting current below detaches its backing arrays, so the
			// stored batch needs no defensive copy.
			batches = append(batches, current)
			current = recordBatch{}
		}
	}

	// Flush the final partial batch, if any.
	if len(current.Records) > 0 {
		batches = append(batches, current)
	}
	return batches, nil
}

// processBatch publishes one batch (with retries via DefaultRetryPolicy),
// or just logs it when DryRun is set. On terminal failure the batch's keys
// are marked failed in storage.
func processBatch(ctx context.Context, batch recordBatch, processor batchProcessor) batchResult {
	if len(batch.Records) == 0 {
		return batchResult{}
	}
	start := time.Now()

	if processor.DryRun {
		for _, r := range batch.Records {
			tid := syntax.NewTIDFromTime(r.PlayedTime.Time, 0)
			slog.Info("would publish record (dry run)", trackAttr(*r), slog.String("rkey", string(tid)))
		}
		return batchResult{
			SuccessCount: len(batch.Records),
			Duration:     time.Since(start),
		}
	}

	err := failsafe.With(DefaultRetryPolicy).WithContext(ctx).Run(func() error {
		return PublishBatch(ctx, processor.Client, processor.DID, batch.Records, processor.Storage, processor.ClientAgent)
	})
	if err != nil {
		slog.Error("batch failed after retries", ErrorAttr(err), slog.Int("count", len(batch.Records)))
		if processor.Storage != nil {
			if markErr := processor.Storage.MarkFailed(processor.DID, batch.Keys, err.Error()); markErr != nil {
				slog.Error("failed to mark records as failed", ErrorAttr(markErr))
			}
		}
		return batchResult{
			ErrorCount: len(batch.Records),
			Duration:   time.Since(start),
			Errors:     []error{err},
		}
	}

	return batchResult{
		SuccessCount: len(batch.Records),
		Duration:     time.Since(start),
	}
}

// aggregate combines per-batch results into the final publish result and
// logs a completion summary.
func aggregate(results []batchResult, startTime time.Time) publishResult {
	totalSuccess := 0
	totalErrors := 0
	for _, result := range results {
		totalSuccess += result.SuccessCount
		totalErrors += result.ErrorCount
	}
	logResult(totalSuccess, totalErrors, startTime)
	return newPublishResult(totalSuccess, totalErrors, totalSuccess+totalErrors, startTime, false)
}

// Publish drains all unpublished records for client's DID from opts.Storage
// and writes them to the repo in batches, reporting progress along the way.
// Cancellation returns a partial result with Cancelled set where applicable.
func Publish(ctx context.Context, client AuthClient, opts PublishOptions) publishResult {
	startTime := time.Now()
	batchSize := cmp.Or(opts.BatchSize, DefaultBatchSize)

	atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient)
	if err != nil {
		// Fix: this error used to be silently discarded, leaving the caller
		// with an empty result and no diagnostic.
		slog.Error("failed to build atproto client", ErrorAttr(err))
		return errorResult(startTime)
	}

	batches, err := batchRecords(ctx, opts.Storage, client.DID(), batchSize, opts.Reverse)
	if err != nil {
		cancelled := ctx.Err() != nil
		return newPublishResult(0, 0, 0, startTime, cancelled)
	}
	if len(batches) == 0 {
		return publishResult{}
	}

	totalRecords := countTotalRecords(batches)
	slog.Info("starting iterative import",
		slog.Int("total_records", totalRecords),
		slog.Int("batch_size", batchSize),
		slog.Int("daily_write_limit", atproto.WriteLimitDay),
		slog.Int("daily_token_limit", atproto.GlobalLimitDay),
		slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/atproto.WriteLimitDay)))

	tracker := NewProgressTracker(totalRecords, opts.Limiter)
	progressLog := defaultProgressLog(opts.ProgressLog)
	processor := batchProcessor{
		Client:      atprotoClient,
		Storage:     opts.Storage,
		DID:         client.DID(),
		ClientAgent: opts.ClientAgent,
		DryRun:      opts.DryRun,
	}

	var results []batchResult
	for _, batch := range batches {
		select {
		case <-ctx.Done():
			// Report whatever completed before cancellation.
			return aggregate(results, startTime)
		default:
		}

		result := processBatch(ctx, batch, processor)
		results = append(results, result)

		// Update progress tracking.
		tracker.Increment(result.SuccessCount + result.ErrorCount)
		tracker.IncrementErrors(result.ErrorCount)
		if tracker.ShouldLog() {
			progressLog(tracker.Report())
		}
	}

	return aggregate(results, startTime)
}

// countTotalRecords sums the record counts across batches.
func countTotalRecords(batches []recordBatch) int {
	total := 0
	for _, batch := range batches {
		total += len(batch.Records)
	}
	return total
}

// errorResult is the zero-progress result returned when setup fails before
// any record is attempted.
func errorResult(startTime time.Time) publishResult {
	return publishResult{
		SuccessCount: 0,
		ErrorCount:   0,
		Cancelled:    false,
		Duration:     time.Since(startTime),
		TotalRecords: 0,
	}
}

// defaultProgressLog returns f, or a slog-based progress logger when f is nil.
func defaultProgressLog(f func(ProgressReport)) func(ProgressReport) {
	if f != nil {
		return f
	}
	return func(pr ProgressReport) {
		slog.Info("sync progress",
			slog.Int("completed", pr.Completed),
			slog.Int("total", pr.Total),
			slog.Float64("percent", pr.Percent),
			slog.String("elapsed", pr.Elapsed),
			slog.String("eta", pr.ETA),
			slog.String("rate", pr.Rate),
			slog.Int("errors", pr.Errors),
		)
	}
}

// newPublishResult assembles a publishResult, deriving duration and
// per-minute rate from start.
func newPublishResult(success, errors, total int, start time.Time, cancelled bool) publishResult {
	elapsed := time.Since(start)
	return publishResult{
		SuccessCount:     success,
		ErrorCount:       errors,
		Cancelled:        cancelled,
		Duration:         elapsed,
		TotalRecords:     total,
		RecordsPerMinute: ratePerMinute(success, elapsed),
	}
}

// logResult emits exactly one completion log line: a warning when errors
// occurred, an info line otherwise. (Previously the warn path ALSO emitted
// the info line, duplicating the completion message, and the warning lacked
// duration/rate.)
func logResult(success, errors int, startTime time.Time) {
	elapsed := time.Since(startTime)
	attrs := []any{
		slog.Int("success", success),
		slog.Int("errors", errors),
		slog.Duration("duration", elapsed),
		slog.String("rate", formatRate(ratePerMinute(success, elapsed))),
	}
	if errors > 0 {
		slog.Warn("import completed with errors", attrs...)
		return
	}
	slog.Info("import completed", attrs...)
}

// PublishBatch writes batch to the repo via applyWrites, then persists the
// records to storage and marks them published. A nil storage or empty did
// skips the caching step. An empty batch is a no-op.
func PublishBatch(ctx context.Context, client ATProtoClient, did string, batch []*PlayRecord, storage cache.RecordStore, clientAgent string) error {
	if len(batch) == 0 {
		return nil
	}

	atprotoRecords := prepareRecords(batch, clientAgent)
	if err := client.ApplyWrites(ctx, RecordType, atprotoRecords); err != nil {
		slog.Error("batch publish failed", ErrorAttr(err))
		return err
	}

	if storage != nil && did != "" {
		keys := CreateRecordKeys(batch)
		cacheEntries := make(map[string][]byte, len(batch))
		for i, rec := range batch {
			value, err := json.Marshal(rec)
			if err != nil {
				// Fix: previously the marshal error was ignored and a nil
				// value could be cached. The record WAS published, so its key
				// is still marked published below; only the cache copy is
				// skipped.
				slog.Warn("failed to marshal record for cache", slog.String("key", keys[i]), ErrorAttr(err))
				continue
			}
			cacheEntries[keys[i]] = value
		}
		if err := storage.SaveRecords(did, cacheEntries); err != nil {
			return fmt.Errorf("failed to save records to storage: %w", err)
		}
		if err := storage.MarkPublished(did, keys...); err != nil {
			return fmt.Errorf("failed to mark records as published: %w", err)
		}
	}
	return nil
}

// prepareRecords stamps each record (in place) with the collection type and
// the submitting client agent, returning the same pointers in a new slice.
func prepareRecords(batch []*PlayRecord, clientAgent string) []*PlayRecord {
	atprotoRecords := make([]*PlayRecord, 0, len(batch))
	for _, record := range batch {
		record.Type = RecordType
		record.SubmissionClientAgent = clientAgent
		atprotoRecords = append(atprotoRecords, record)
	}
	return atprotoRecords
}

// ratePerMinute returns count/duration in events per minute, or 0 for a
// zero duration.
func ratePerMinute(count int, duration time.Duration) float64 {
	if duration == 0 {
		return 0
	}
	return float64(count) / duration.Minutes()
}

// FetchExisting returns every published play record for did. Unless
// forceRefresh is set, a valid local cache is used; otherwise (or when the
// cache yields nothing) records are paged from the repo and re-cached.
func FetchExisting(ctx context.Context, client RepoClient[*PlayRecord], did string, storage cache.Storage, forceRefresh bool) ([]ExistingRecord, error) {
	if !forceRefresh && storage != nil {
		published, err := storage.GetPublished(did)
		if err == nil && len(published) > 0 && storage.IsValid(did) {
			records := make([]ExistingRecord, 0, len(published))
			for _, data := range storage.IteratePublished(did, false) {
				var value PlayRecord
				if err := json.Unmarshal(data, &value); err != nil {
					slog.Debug("failed to unmarshal cached record", ErrorAttr(err))
					continue
				}
				records = append(records, ExistingRecord{
					URI:   generateRecordURI(did, &value),
					Value: &value,
				})
			}
			if len(records) > 0 {
				slog.Debug("loaded from cache", slog.Int("count", len(records)))
				return records, nil
			}
		}
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	allRecords := make([]ExistingRecord, 0, 1024)
	return fetchExistingLoop(ctx, client, did, storage, allRecords)
}

// fetchExistingLoop pages all records of RecordType from the repo (with a
// transient-error retry policy), appends them to allRecords, and — when
// storage is provided — caches them keyed by rkey and marks them published.
func fetchExistingLoop(ctx context.Context, client RepoClient[*PlayRecord], did string, storage cache.Storage, allRecords []ExistingRecord) ([]ExistingRecord, error) {
	const batchSize = 100
	var cursor string

	type fetchResult struct {
		records []atproto.RecordRef[*PlayRecord]
		cursor  string
	}

	fetchRetryPolicy := retrypolicy.NewBuilder[fetchResult]().
		WithMaxRetries(10).
		WithBackoff(BaseRetryDelay, 5*time.Minute).
		HandleIf(func(_ fetchResult, err error) bool { return IsTransientError(err) }).
		OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[fetchResult]) {
			slog.Warn("fetch failed with transient error, retrying",
				slog.Duration("retryDelay", e.Delay),
				ErrorAttr(e.LastError()),
				slog.Int("attempt", e.Attempts()))
		}).
		Build()

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		result, err := failsafe.With(fetchRetryPolicy).
			WithContext(ctx).
			Get(func() (fetchResult, error) {
				recs, next, err := client.ListRecords(ctx, RecordType, batchSize, cursor)
				if err != nil {
					return fetchResult{}, err
				}
				return fetchResult{records: recs, cursor: next}, nil
			})
		if err != nil {
			return nil, err
		}

		for _, rec := range result.records {
			allRecords = append(allRecords, ExistingRecord(rec))
		}

		// A missing cursor or short page means we've reached the end.
		if result.cursor == "" || len(result.records) < batchSize {
			break
		}
		cursor = result.cursor
	}

	if storage != nil {
		cacheEntries := make(map[string][]byte, len(allRecords))
		keys := make([]string, 0, len(allRecords))
		for _, rec := range allRecords {
			// The rkey is the last path segment of the at:// URI; fall back
			// to a derived key when the URI has a trailing slash.
			parts := strings.Split(rec.URI, "/")
			key := parts[len(parts)-1]
			if key == "" {
				key = CreateRecordKey(rec.Value)
			}
			value, err := json.Marshal(rec.Value)
			if err != nil {
				// Fix: previously the marshal error was ignored and a nil
				// value could be cached (and marked published).
				slog.Warn("failed to marshal record for cache", slog.String("key", key), ErrorAttr(err))
				continue
			}
			cacheEntries[key] = value
			keys = append(keys, key)
		}
		if err := storage.SaveRecords(did, cacheEntries); err != nil {
			return nil, err
		}
		if err := storage.MarkPublished(did, keys...); err != nil {
			return nil, err
		}
		slog.Debug("saved to cache and marked as published", slog.Int("count", len(allRecords)))
	}

	return allRecords, nil
}

// generateRecordURI builds the at:// URI for record under did.
func generateRecordURI(did string, record *PlayRecord) string {
	return fmt.Sprintf("at://%s/%s/%s", did, RecordType, CreateRecordKey(record))
}

// prepareWrites converts records into applyWrites#create operations for
// collection, keyed by their derived rkeys. Returns nil for empty input.
func prepareWrites(records []*PlayRecord, collection string) ([]map[string]any, error) {
	if len(records) == 0 {
		return nil, nil
	}
	writes := make([]map[string]any, len(records))
	keys := CreateRecordKeys(records)
	for i, rec := range records {
		writes[i] = map[string]any{
			"$type":      "com.atproto.repo.applyWrites#create",
			"collection": collection,
			"rkey":       keys[i],
			"value":      rec,
		}
	}
	return writes, nil
}