Like malachite (atproto-lastfm-importer), but written in Go and a bit more Bluesky-flavored.
go
spotify
tealfm
lastfm
atproto
1package sync
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/atclient"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "github.com/failsafe-go/failsafe-go"
15 "github.com/failsafe-go/failsafe-go/retrypolicy"
16
17 "tangled.org/karitham.dev/lazuli/atproto"
18 "tangled.org/karitham.dev/lazuli/cache"
19)
20
const (
	// Daily write and global token budgets, re-exported from the atproto package.
	WriteLimitDay  = atproto.WriteLimitDay
	GlobalLimitDay = atproto.GlobalLimitDay

	// RecordType is the lexicon NSID written for each play record.
	RecordType = "fm.teal.alpha.feed.play"
	// DefaultBatchSize is the number of records sent per applyWrites batch.
	DefaultBatchSize = 20
	// DefaultCrossSourceTolerance is a time window for matching plays across
	// sources. NOTE(review): semantics inferred from the name — confirm against callers.
	DefaultCrossSourceTolerance = 5 * time.Minute
	CrossSourceTolerance        = DefaultCrossSourceTolerance
	// CacheTTL bounds how long cached repo state is treated as fresh.
	// presumably enforced by cache.Storage.IsValid — verify in the cache package.
	CacheTTL = 24 * time.Hour
	// CacheVersion can be bumped to invalidate previously written cache layouts.
	CacheVersion = 1
	// SlingshotResolverURL is the Slingshot mini-doc identity resolution endpoint.
	SlingshotResolverURL = "https://slingshot.microcosm.blue/xrpc/com.bad-example.identity.resolveMiniDoc"
	// Retry tuning shared by long-running sync operations.
	MaxRetryDelay  = 15 * time.Minute
	MaxRetries     = 1000
	BaseRetryDelay = 2 * time.Second
)
36
// DefaultRetryPolicy retries transient atproto errors up to 10 times with
// exponential backoff starting at BaseRetryDelay and capped at 5 minutes,
// logging a warning each time a retry is scheduled.
var DefaultRetryPolicy = retrypolicy.NewBuilder[struct{}]().
	WithMaxRetries(10).
	WithBackoff(BaseRetryDelay, 5*time.Minute).
	HandleIf(func(_ struct{}, err error) bool {
		// Only errors atproto classifies as transient are retried.
		return atproto.IsTransientError(err)
	}).
	OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[struct{}]) {
		slog.Warn("batch failed with transient error, retrying",
			slog.Duration("retryDelay", e.Delay),
			ErrorAttr(e.LastError()),
			slog.Int("attempt", e.Attempts()))
	}).
	Build()
50
type (
	// Aliases re-exporting atproto client types for this package's callers.
	ATProtoClient = atproto.RepoClient[*PlayRecord]
	AuthClient    = atproto.AuthClient
	RateLimiter   = atproto.RateLimiter
	Client        = atproto.Client
	RepoClient[T any] = atproto.RepoClient[T]
	RecordRef         = atproto.RecordRef[PlayRecord]

	// PublishOptions configures a Publish run.
	PublishOptions struct {
		BatchSize     int  // records per batch; DefaultBatchSize when zero
		DryRun        bool // log what would be published instead of writing
		Reverse       bool // iterate unpublished records in reverse order
		ATProtoClient ATProtoClient
		ProgressLog   func(ProgressReport) // optional; slog-based default when nil
		Storage       cache.RecordStore
		Limiter       RateLimiter
		ClientAgent   string
		// RetryDelay is not read by Publish in this file — TODO confirm intent.
		RetryDelay time.Duration
	}

	// publishResult summarizes a completed (or cancelled) publish run.
	publishResult struct {
		SuccessCount     int           `json:"successCount"`
		ErrorCount       int           `json:"errorCount"`
		Cancelled        bool          `json:"cancelled"`
		Duration         time.Duration `json:"duration"`
		TotalRecords     int           `json:"totalRecords"`
		RecordsPerMinute float64       `json:"recordsPerMinute"`
		FirstRecordTime  time.Time     `json:"firstRecordTime"`
		LastRecordTime   time.Time     `json:"lastRecordTime"`
	}

	// recordBatch pairs decoded records with their storage keys so outcomes
	// can be written back (MarkFailed/MarkPublished) after a publish attempt.
	recordBatch struct {
		Records []*PlayRecord
		Keys    []string
	}

	// batchProcessor carries the dependencies processBatch needs per batch.
	batchProcessor struct {
		Client      ATProtoClient
		Storage     cache.RecordStore
		DID         string
		ClientAgent string
		DryRun      bool
	}

	// batchResult is the outcome of publishing a single batch.
	batchResult struct {
		SuccessCount int
		ErrorCount   int
		Duration     time.Duration
		Errors       []error
	}
)
102
// NewRateLimiter constructs a rate limiter backed by kv, capped at maxPercent
// of the rate budget. Thin wrapper over atproto.NewRateLimiter.
func NewRateLimiter(kv atproto.KVStore, maxPercent float32) RateLimiter {
	return atproto.NewRateLimiter(kv, maxPercent)
}
106
// NewRateClient wraps client in a rate-limited repo client for PlayRecords
// belonging to did. Thin wrapper over atproto.NewRateClient.
func NewRateClient(client *atclient.APIClient, did string, limiter RateLimiter) *atproto.RateClient[*PlayRecord] {
	return atproto.NewRateClient[*PlayRecord](client, did, limiter)
}
110
// IsTransientError reports whether err is retryable.
// Thin wrapper over atproto.IsTransientError.
func IsTransientError(err error) bool {
	return atproto.IsTransientError(err)
}
114
// NewClient authenticates with handle/password and returns a Client.
// Thin wrapper over atproto.NewClient.
func NewClient(ctx context.Context, handle, password string, opts ...func(*atproto.ClientOptions)) (*Client, error) {
	return atproto.NewClient(ctx, handle, password, opts...)
}
118
119// batchRecords iterates through storage and builds record batches
120func batchRecords(ctx context.Context, storage cache.RecordStore, did string, batchSize int, reverse bool) ([]recordBatch, error) {
121 var batches []recordBatch
122 var currentBatch recordBatch
123
124 for key, rec := range storage.IterateUnpublished(did, reverse) {
125 select {
126 case <-ctx.Done():
127 return nil, ctx.Err()
128 default:
129 }
130
131 var record PlayRecord
132 if err := json.Unmarshal(rec, &record); err != nil {
133 slog.Error("malformed record in storage", slog.String("key", key), ErrorAttr(err))
134 if storage != nil {
135 _ = storage.MarkFailed(did, []string{key}, "malformed record")
136 }
137 continue // Skip malformed records
138 }
139
140 currentBatch.Records = append(currentBatch.Records, &record)
141 currentBatch.Keys = append(currentBatch.Keys, key)
142
143 if len(currentBatch.Records) >= batchSize {
144 records := make([]*PlayRecord, len(currentBatch.Records))
145 copy(records, currentBatch.Records)
146 keys := make([]string, len(currentBatch.Keys))
147 copy(keys, currentBatch.Keys)
148 batches = append(batches, recordBatch{
149 Records: records,
150 Keys: keys,
151 })
152 currentBatch = recordBatch{}
153 }
154 }
155
156 // Add the last partial batch if it has records
157 if len(currentBatch.Records) > 0 {
158 batches = append(batches, currentBatch)
159 }
160
161 return batches, nil
162}
163
164// processBatch processes a single batch of records with retries
165func processBatch(ctx context.Context, batch recordBatch, processor batchProcessor) batchResult {
166 if len(batch.Records) == 0 {
167 return batchResult{}
168 }
169
170 start := time.Now()
171
172 if processor.DryRun {
173 for _, r := range batch.Records {
174 tid := syntax.NewTIDFromTime(r.PlayedTime.Time, 0)
175 slog.Info("would publish record (dry run)", trackAttr(*r), slog.String("rkey", string(tid)))
176 }
177 return batchResult{
178 SuccessCount: len(batch.Records),
179 Duration: time.Since(start),
180 }
181 }
182
183 policy := DefaultRetryPolicy
184
185 err := failsafe.With(policy).WithContext(ctx).Run(func() error {
186 return PublishBatch(ctx, processor.Client, processor.DID, batch.Records, processor.Storage, processor.ClientAgent)
187 })
188 if err != nil {
189 slog.Error("batch failed after retries",
190 ErrorAttr(err),
191 slog.Int("count", len(batch.Records)))
192
193 if processor.Storage != nil {
194 if markErr := processor.Storage.MarkFailed(processor.DID, batch.Keys, err.Error()); markErr != nil {
195 slog.Error("failed to mark records as failed", ErrorAttr(markErr))
196 }
197 }
198
199 return batchResult{
200 ErrorCount: len(batch.Records),
201 Duration: time.Since(start),
202 Errors: []error{err},
203 }
204 }
205
206 return batchResult{
207 SuccessCount: len(batch.Records),
208 Duration: time.Since(start),
209 }
210}
211
212// aggregate combines batch results into final publish result
213func aggregate(results []batchResult, startTime time.Time) publishResult {
214 totalSuccess := 0
215 totalErrors := 0
216
217 for _, result := range results {
218 totalSuccess += result.SuccessCount
219 totalErrors += result.ErrorCount
220 }
221
222 logResult(totalSuccess, totalErrors, startTime)
223 return newPublishResult(totalSuccess, totalErrors, totalSuccess+totalErrors, startTime, false)
224}
225
226func Publish(ctx context.Context, client AuthClient, opts PublishOptions) publishResult {
227 startTime := time.Now()
228 batchSize := cmp.Or(opts.BatchSize, DefaultBatchSize)
229
230 atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient)
231 if err != nil {
232 return errorResult(startTime)
233 }
234
235 batches, err := batchRecords(ctx, opts.Storage, client.DID(), batchSize, opts.Reverse)
236 if err != nil {
237 cancelled := ctx.Err() != nil
238 return newPublishResult(0, 0, 0, startTime, cancelled)
239 }
240
241 if len(batches) == 0 {
242 return publishResult{}
243 }
244
245 slog.Info("starting iterative import",
246 slog.Int("total_records", countTotalRecords(batches)),
247 slog.Int("batch_size", batchSize),
248 slog.Int("daily_write_limit", atproto.WriteLimitDay),
249 slog.Int("daily_token_limit", atproto.GlobalLimitDay),
250 slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/atproto.WriteLimitDay)))
251
252 tracker := NewProgressTracker(countTotalRecords(batches), opts.Limiter)
253 progressLog := defaultProgressLog(opts.ProgressLog)
254
255 processor := batchProcessor{
256 Client: atprotoClient,
257 Storage: opts.Storage,
258 DID: client.DID(),
259 ClientAgent: opts.ClientAgent,
260 DryRun: opts.DryRun,
261 }
262
263 var results []batchResult
264 for _, batch := range batches {
265 select {
266 case <-ctx.Done():
267 return aggregate(results, startTime)
268 default:
269 }
270
271 result := processBatch(ctx, batch, processor)
272 results = append(results, result)
273
274 // Update progress tracking
275 tracker.Increment(result.SuccessCount + result.ErrorCount)
276 tracker.IncrementErrors(result.ErrorCount)
277
278 if tracker.ShouldLog() {
279 progressLog(tracker.Report())
280 }
281 }
282
283 return aggregate(results, startTime)
284}
285
286func countTotalRecords(batches []recordBatch) int {
287 total := 0
288 for _, batch := range batches {
289 total += len(batch.Records)
290 }
291 return total
292}
293
294func errorResult(startTime time.Time) publishResult {
295 return publishResult{
296 SuccessCount: 0,
297 ErrorCount: 0,
298 Cancelled: false,
299 Duration: time.Since(startTime),
300 TotalRecords: 0,
301 }
302}
303
304func defaultProgressLog(f func(ProgressReport)) func(ProgressReport) {
305 if f != nil {
306 return f
307 }
308 return func(pr ProgressReport) {
309 slog.Info("sync progress",
310 slog.Int("completed", pr.Completed),
311 slog.Int("total", pr.Total),
312 slog.Float64("percent", pr.Percent),
313 slog.String("elapsed", pr.Elapsed),
314 slog.String("eta", pr.ETA),
315 slog.String("rate", pr.Rate),
316 slog.Int("errors", pr.Errors),
317 )
318 }
319}
320
321func newPublishResult(success, errors, total int, start time.Time, cancelled bool) publishResult {
322 return publishResult{
323 SuccessCount: success,
324 ErrorCount: errors,
325 Cancelled: cancelled,
326 Duration: time.Since(start),
327 TotalRecords: total,
328 RecordsPerMinute: ratePerMinute(success, time.Since(start)),
329 }
330}
331
332func logResult(success, errors int, startTime time.Time) {
333 if errors > 0 {
334 slog.Warn("import completed with errors",
335 slog.Int("success", success),
336 slog.Int("errors", errors))
337 }
338 slog.Info("import completed",
339 slog.Int("success", success),
340 slog.Int("errors", errors),
341 slog.Duration("duration", time.Since(startTime)),
342 slog.String("rate", formatRate(ratePerMinute(success, time.Since(startTime)))))
343}
344
345func PublishBatch(ctx context.Context, client ATProtoClient, did string, batch []*PlayRecord, storage cache.RecordStore, clientAgent string) error {
346 if len(batch) == 0 {
347 return nil
348 }
349
350 atprotoRecords := prepareRecords(batch, clientAgent)
351 err := client.ApplyWrites(ctx, RecordType, atprotoRecords)
352 if err != nil {
353 slog.Error("batch publish failed", ErrorAttr(err))
354 return err
355 }
356
357 if storage != nil && did != "" {
358 keys := CreateRecordKeys(batch)
359 cacheEntries := make(map[string][]byte)
360 for i, rec := range batch {
361 key := keys[i]
362 value, _ := json.Marshal(rec)
363 cacheEntries[key] = value
364 }
365
366 if err := storage.SaveRecords(did, cacheEntries); err != nil {
367 return fmt.Errorf("failed to save records to storage: %w", err)
368 }
369
370 if err := storage.MarkPublished(did, keys...); err != nil {
371 return fmt.Errorf("failed to mark records as published: %w", err)
372 }
373 }
374
375 return nil
376}
377
378func prepareRecords(batch []*PlayRecord, clientAgent string) []*PlayRecord {
379 atprotoRecords := make([]*PlayRecord, 0, len(batch))
380 for _, record := range batch {
381 record.Type = RecordType
382 record.SubmissionClientAgent = clientAgent
383 atprotoRecords = append(atprotoRecords, record)
384 }
385 return atprotoRecords
386}
387
// ratePerMinute converts a count over a duration into a per-minute rate.
// A zero duration yields 0 rather than dividing by zero.
func ratePerMinute(count int, duration time.Duration) float64 {
	if duration != 0 {
		return float64(count) / duration.Minutes()
	}
	return 0
}
394
395func FetchExisting(ctx context.Context, client RepoClient[*PlayRecord], did string, storage cache.Storage, forceRefresh bool) ([]ExistingRecord, error) {
396 if !forceRefresh && storage != nil {
397 published, err := storage.GetPublished(did)
398 if err == nil && len(published) > 0 && storage.IsValid(did) {
399 records := make([]ExistingRecord, 0, len(published))
400 for _, data := range storage.IteratePublished(did, false) {
401 var value PlayRecord
402 if err := json.Unmarshal(data, &value); err != nil {
403 slog.Debug("failed to unmarshal cached record", ErrorAttr(err))
404 continue
405 }
406 records = append(records, ExistingRecord{
407 URI: generateRecordURI(did, &value),
408 Value: &value,
409 })
410 }
411 if len(records) > 0 {
412 slog.Debug("loaded from cache", slog.Int("count", len(records)))
413 return records, nil
414 }
415 }
416 }
417
418 select {
419 case <-ctx.Done():
420 return nil, ctx.Err()
421 default:
422 }
423
424 allRecords := make([]ExistingRecord, 0, 1024)
425 return fetchExistingLoop(ctx, client, did, storage, allRecords)
426}
427
// fetchExistingLoop pages through did's RecordType collection via
// client.ListRecords, retrying transient errors per page, and appends every
// record to allRecords. When storage is non-nil, the fetched records are then
// cached and marked published. Returns the accumulated records.
func fetchExistingLoop(ctx context.Context, client RepoClient[*PlayRecord], did string, storage cache.Storage, allRecords []ExistingRecord) ([]ExistingRecord, error) {
	const batchSize = 100
	var cursor string

	// fetchResult bundles one page of records with its continuation cursor.
	type fetchResult struct {
		records []atproto.RecordRef[*PlayRecord]
		cursor  string
	}

	// Same tuning as DefaultRetryPolicy, but parameterized for fetchResult.
	fetchRetryPolicy := retrypolicy.NewBuilder[fetchResult]().
		WithMaxRetries(10).
		WithBackoff(BaseRetryDelay, 5*time.Minute).
		HandleIf(func(_ fetchResult, err error) bool {
			return IsTransientError(err)
		}).
		OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[fetchResult]) {
			slog.Warn("fetch failed with transient error, retrying",
				slog.Duration("retryDelay", e.Delay),
				ErrorAttr(e.LastError()),
				slog.Int("attempt", e.Attempts()))
		}).
		Build()

	for {
		// Bail out promptly between pages if the context is cancelled.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		// Each attempt re-reads the same cursor, so retries re-fetch the
		// current page rather than skipping ahead.
		result, err := failsafe.With(fetchRetryPolicy).
			WithContext(ctx).
			Get(func() (fetchResult, error) {
				recs, next, err := client.ListRecords(ctx, RecordType, batchSize, cursor)
				if err != nil {
					return fetchResult{}, err
				}

				return fetchResult{records: recs, cursor: next}, nil
			})
		if err != nil {
			return nil, err
		}

		for _, rec := range result.records {
			allRecords = append(allRecords, ExistingRecord(rec))
		}

		// An empty cursor or a short page signals the final page.
		if result.cursor == "" || len(result.records) < batchSize {
			break
		}
		cursor = result.cursor
	}

	if storage != nil {
		cacheEntries := make(map[string][]byte)
		keys := make([]string, 0, len(allRecords))
		for _, rec := range allRecords {
			// Use the rkey (last path segment of the at:// URI) as the cache
			// key, falling back to a derived key for an empty segment.
			parts := strings.Split(rec.URI, "/")
			key := parts[len(parts)-1]
			if key == "" {
				key = CreateRecordKey(rec.Value)
			}
			// Marshal error deliberately ignored: the value round-trips from
			// data the server already served as JSON.
			value, _ := json.Marshal(rec.Value)
			cacheEntries[key] = value
			keys = append(keys, key)
		}

		if err := storage.SaveRecords(did, cacheEntries); err != nil {
			return nil, err
		}

		if err := storage.MarkPublished(did, keys...); err != nil {
			return nil, err
		}

		slog.Debug("saved to cache and marked as published", slog.Int("count", len(allRecords)))
	}

	return allRecords, nil
}
509
510func generateRecordURI(did string, record *PlayRecord) string {
511 return fmt.Sprintf("at://%s/%s/%s", did, RecordType, CreateRecordKey(record))
512}
513
514func prepareWrites(records []*PlayRecord, collection string) ([]map[string]any, error) {
515 if len(records) == 0 {
516 return nil, nil
517 }
518
519 writes := make([]map[string]any, len(records))
520 keys := CreateRecordKeys(records)
521
522 for i, rec := range records {
523 writes[i] = map[string]any{
524 "$type": "com.atproto.repo.applyWrites#create",
525 "collection": collection,
526 "rkey": keys[i],
527 "value": rec,
528 }
529 }
530
531 return writes, nil
532}