Like malachite (atproto-lastfm-importer), but in Go and bluer.
go
spotify
tealfm
lastfm
atproto
1package sync
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "iter"
9 "maps"
10 "sync"
11 "testing"
12 "testing/synctest"
13 "time"
14
15 "github.com/bluesky-social/indigo/atproto/atclient"
16
17 "tangled.org/karitham.dev/lazuli/atproto"
18 "tangled.org/karitham.dev/lazuli/cache"
19)
20
21// Mock Storage
22type mockStorage struct {
23 cache.RecordStore
24 unpublished map[string][]byte
25 published map[string]bool
26 failed map[string]string
27 kv map[string]int
28 mu sync.Mutex
29}
30
31func newMockStorage() *mockStorage {
32 return &mockStorage{
33 unpublished: make(map[string][]byte),
34 published: make(map[string]bool),
35 failed: make(map[string]string),
36 kv: make(map[string]int),
37 }
38}
39
40func (m *mockStorage) SaveRecords(did string, records map[string][]byte) error {
41 m.mu.Lock()
42 defer m.mu.Unlock()
43 maps.Copy(m.unpublished, records)
44 return nil
45}
46
47func (m *mockStorage) IterateUnpublished(did string, reverse bool) iter.Seq2[string, []byte] {
48 return func(yield func(key string, rec []byte) bool) {
49 m.mu.Lock()
50 keys := make([]string, 0, len(m.unpublished))
51 for k := range m.unpublished {
52 keys = append(keys, k)
53 }
54 m.mu.Unlock()
55
56 for _, k := range keys {
57 m.mu.Lock()
58 rec, ok := m.unpublished[k]
59 m.mu.Unlock()
60 if ok {
61 if !yield(k, rec) {
62 return
63 }
64 }
65 }
66 }
67}
68
69func (m *mockStorage) IteratePublished(did string, reverse bool) iter.Seq2[string, []byte] {
70 return func(yield func(key string, rec []byte) bool) {
71 m.mu.Lock()
72 keys := make([]string, 0, len(m.published))
73 for k := range m.published {
74 keys = append(keys, k)
75 }
76 m.mu.Unlock()
77
78 for _, k := range keys {
79 m.mu.Lock()
80 rec, ok := m.unpublished[k] // Get record data from unpublished map
81 m.mu.Unlock()
82 if ok {
83 if !yield(k, rec) {
84 return
85 }
86 }
87 }
88 }
89}
90
91func (m *mockStorage) MarkPublished(did string, keys ...string) error {
92 m.mu.Lock()
93 defer m.mu.Unlock()
94 for _, k := range keys {
95 delete(m.unpublished, k)
96 m.published[k] = true
97 }
98 return nil
99}
100
101func (m *mockStorage) MarkFailed(did string, keys []string, err string) error {
102 m.mu.Lock()
103 defer m.mu.Unlock()
104 for _, k := range keys {
105 m.failed[k] = err
106 }
107 return nil
108}
109
110func (m *mockStorage) Get(key string) (int, error) {
111 m.mu.Lock()
112 defer m.mu.Unlock()
113 return m.kv[key], nil
114}
115
116func (m *mockStorage) IncrBy(key string, n int) (int, error) {
117 m.mu.Lock()
118 defer m.mu.Unlock()
119 m.kv[key] += n
120 return m.kv[key], nil
121}
122
123// Mock ATProtoClient
124type mockATProtoClient struct {
125 applyWritesFunc func(ctx context.Context, collection string, records []*PlayRecord) error
126 listRecordsFunc func(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[*PlayRecord], string, error)
127 deleteRecordFunc func(ctx context.Context, collection, rkey string) error
128}
129
130func (m *mockATProtoClient) ApplyWrites(ctx context.Context, collection string, records []*PlayRecord) error {
131 if m.applyWritesFunc != nil {
132 return m.applyWritesFunc(ctx, collection, records)
133 }
134 return nil
135}
136
137func (m *mockATProtoClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[*PlayRecord], string, error) {
138 if m.listRecordsFunc != nil {
139 return m.listRecordsFunc(ctx, collection, limit, cursor)
140 }
141 return nil, "", nil
142}
143
144func (m *mockATProtoClient) DeleteRecord(ctx context.Context, collection, rkey string) error {
145 if m.deleteRecordFunc != nil {
146 return m.deleteRecordFunc(ctx, collection, rkey)
147 }
148 return nil
149}
150
151// Mock AuthClient
152type mockAuthClient struct {
153 did string
154}
155
156func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil }
157func (m *mockAuthClient) DID() string { return m.did }
158
// timeoutError is a zero-size error that reports itself as both a timeout and
// temporary. Its method set matches net.Error, which lets tests exercise
// timeout classification without performing real network I/O.
type timeoutError struct{}

// Error implements the error interface.
func (timeoutError) Error() string { return "timeout" }

// Timeout always reports true.
func (timeoutError) Timeout() bool { return true }

// Temporary always reports true.
func (timeoutError) Temporary() bool { return true }
164
165type failingStorage struct {
166 *mockStorage
167}
168
169func newFailingStorage() *failingStorage {
170 return &failingStorage{
171 mockStorage: newMockStorage(),
172 }
173}
174
175func (s *failingStorage) SaveRecords(did string, records map[string][]byte) error {
176 return errors.New("failed to save records")
177}
178
179func TestBuildRecordBatches(t *testing.T) {
180 tests := []struct {
181 name string
182 records []PlayRecord
183 batchSize int
184 wantBatches int
185 wantErr bool
186 ctxCancel bool
187 }{
188 {
189 name: "empty storage",
190 records: []PlayRecord{},
191 batchSize: 2,
192 wantBatches: 0,
193 wantErr: false,
194 },
195 {
196 name: "single batch",
197 records: []PlayRecord{
198 {TrackName: "Song 1"},
199 {TrackName: "Song 2"},
200 },
201 batchSize: 5,
202 wantBatches: 1,
203 wantErr: false,
204 },
205 {
206 name: "multiple exact batches",
207 records: []PlayRecord{
208 {TrackName: "Song 1"},
209 {TrackName: "Song 2"},
210 {TrackName: "Song 3"},
211 {TrackName: "Song 4"},
212 },
213 batchSize: 2,
214 wantBatches: 2,
215 wantErr: false,
216 },
217 {
218 name: "partial final batch",
219 records: []PlayRecord{
220 {TrackName: "Song 1"},
221 {TrackName: "Song 2"},
222 {TrackName: "Song 3"},
223 },
224 batchSize: 2,
225 wantBatches: 2,
226 wantErr: false,
227 },
228 {
229 name: "context cancelled",
230 records: []PlayRecord{
231 {TrackName: "Song 1"},
232 {TrackName: "Song 2"},
233 },
234 batchSize: 2,
235 wantBatches: 0,
236 wantErr: true,
237 ctxCancel: true,
238 },
239 {
240 name: "malformed records skipped",
241 records: []PlayRecord{
242 {TrackName: "Song 1"},
243 {TrackName: "Song 2"},
244 },
245 batchSize: 2,
246 wantBatches: 1,
247 wantErr: false,
248 },
249 }
250
251 for _, tt := range tests {
252 t.Run(tt.name, func(t *testing.T) {
253 t.Parallel()
254
255 ctx := t.Context()
256 if tt.ctxCancel {
257 var cancel context.CancelFunc
258 ctx, cancel = context.WithCancel(ctx)
259 cancel()
260 }
261
262 storage := newMockStorage()
263 did := "did:example:123"
264
265 // Add records to storage
266 for i, record := range tt.records {
267 data, _ := json.Marshal(record)
268 if tt.name == "malformed records skipped" && i == 1 {
269 data = []byte("invalid json")
270 }
271 storage.unpublished[fmt.Sprintf("key%d", i)] = data
272 }
273
274 batches, err := batchRecords(ctx, storage, did, tt.batchSize, false)
275
276 if (err != nil) != tt.wantErr {
277 t.Errorf("BuildRecordBatches() error = %v, wantErr %v", err, tt.wantErr)
278 return
279 }
280
281 if len(batches) != tt.wantBatches {
282 t.Errorf("BuildRecordBatches() batches = %d, want %d", len(batches), tt.wantBatches)
283 }
284
285 if !tt.wantErr {
286 totalRecords := 0
287 for _, batch := range batches {
288 totalRecords += len(batch.Records)
289 }
290
291 expectedRecords := len(tt.records)
292 if tt.name == "malformed records skipped" {
293 expectedRecords = len(tt.records) - 1 // Skip malformed record
294 }
295
296 if totalRecords != expectedRecords {
297 t.Errorf("BuildRecordBatches() total records = %d, want %d", totalRecords, expectedRecords)
298 }
299 }
300 })
301 }
302}
303
304func TestIteratePublished(t *testing.T) {
305 tests := []struct {
306 name string
307 setupStorage func() *mockStorage
308 reverse bool
309 wantKeys []string
310 }{
311 {
312 name: "returns only published records",
313 setupStorage: func() *mockStorage {
314 s := newMockStorage()
315 s.unpublished["key1"] = []byte(`{"trackName":"a"}`)
316 s.unpublished["key2"] = []byte(`{"trackName":"b"}`)
317 s.unpublished["key3"] = []byte(`{"trackName":"c"}`)
318 s.published["key1"] = true
319 s.published["key3"] = true
320 return s
321 },
322 reverse: false,
323 wantKeys: []string{"key1", "key3"},
324 },
325 {
326 name: "handles empty published set",
327 setupStorage: func() *mockStorage {
328 s := newMockStorage()
329 s.unpublished["key1"] = []byte(`{"trackName":"a"}`)
330 return s
331 },
332 reverse: false,
333 wantKeys: nil,
334 },
335 }
336
337 for _, tt := range tests {
338 t.Run(tt.name, func(t *testing.T) {
339 t.Parallel()
340
341 storage := tt.setupStorage()
342
343 var gotKeys []string
344 for key, rec := range storage.IteratePublished("did:test", tt.reverse) {
345 gotKeys = append(gotKeys, key)
346 _ = rec
347 }
348
349 if len(gotKeys) != len(tt.wantKeys) {
350 t.Errorf("IteratePublished() returned %d keys, want %d", len(gotKeys), len(tt.wantKeys))
351 return
352 }
353
354 wantSet := make(map[string]bool)
355 for _, k := range tt.wantKeys {
356 wantSet[k] = true
357 }
358 for _, k := range gotKeys {
359 if !wantSet[k] {
360 t.Errorf("IteratePublished() got unexpected key %s", k)
361 }
362 }
363 })
364 }
365}
366
367func TestProcessBatch(t *testing.T) {
368 tests := []struct {
369 name string
370 batch recordBatch
371 processor batchProcessor
372 wantSuccess int
373 wantError int
374 wantErr bool
375 setupClient func() *mockATProtoClient
376 setupStorage func() cache.Storage
377 }{
378 {
379 name: "empty batch",
380 batch: recordBatch{
381 Records: []*PlayRecord{},
382 Keys: []string{},
383 },
384 processor: batchProcessor{
385 Client: &mockATProtoClient{},
386 Storage: newMockStorage(),
387 },
388 wantSuccess: 0,
389 wantError: 0,
390 wantErr: false,
391 },
392 {
393 name: "successful batch",
394 batch: recordBatch{
395 Records: []*PlayRecord{{TrackName: "Song 1"}, {TrackName: "Song 2"}},
396 Keys: []string{"key1", "key2"},
397 },
398 processor: batchProcessor{
399 Client: &mockATProtoClient{},
400 Storage: newMockStorage(),
401 DID: "did:example:123",
402 ClientAgent: "test-agent",
403 },
404 wantSuccess: 2,
405 wantError: 0,
406 wantErr: false,
407 },
408 {
409 name: "dry run batch",
410 batch: recordBatch{
411 Records: []*PlayRecord{{TrackName: "Song 1"}, {TrackName: "Song 2"}},
412 Keys: []string{"key1", "key2"},
413 },
414 processor: batchProcessor{
415 Client: &mockATProtoClient{},
416 Storage: newMockStorage(),
417 DID: "did:example:123",
418 ClientAgent: "test-agent",
419 DryRun: true,
420 },
421 wantSuccess: 2,
422 wantError: 0,
423 wantErr: false,
424 },
425 {
426 name: "batch with apply writes failure",
427 batch: recordBatch{
428 Records: []*PlayRecord{{TrackName: "Song 1"}},
429 Keys: []string{"key1"},
430 },
431 processor: batchProcessor{
432 Client: func() *mockATProtoClient {
433 return &mockATProtoClient{
434 applyWritesFunc: func(ctx context.Context, collection string, records []*PlayRecord) error {
435 return errors.New("apply writes failed")
436 },
437 }
438 }(),
439 Storage: newMockStorage(),
440 DID: "did:example:123",
441 ClientAgent: "test-agent",
442 },
443 wantSuccess: 0,
444 wantError: 1,
445 wantErr: false,
446 },
447 {
448 name: "batch with storage failure",
449 batch: recordBatch{
450 Records: []*PlayRecord{{TrackName: "Song 1"}},
451 Keys: []string{"key1"},
452 },
453 processor: batchProcessor{
454 Client: &mockATProtoClient{},
455 Storage: newFailingStorage(),
456 DID: "did:example:123",
457 ClientAgent: "test-agent",
458 },
459 wantSuccess: 0,
460 wantError: 1,
461 wantErr: false,
462 },
463 }
464
465 for _, tt := range tests {
466 t.Run(tt.name, func(t *testing.T) {
467 t.Parallel()
468
469 ctx := t.Context()
470 result := processBatch(ctx, tt.batch, tt.processor)
471
472 if result.SuccessCount != tt.wantSuccess {
473 t.Errorf("ProcessBatch() success count = %d, want %d", result.SuccessCount, tt.wantSuccess)
474 }
475
476 if result.ErrorCount != tt.wantError {
477 t.Errorf("ProcessBatch() error count = %d, want %d", result.ErrorCount, tt.wantError)
478 }
479
480 if tt.wantError > 0 && len(result.Errors) == 0 {
481 t.Error("ProcessBatch() expected errors but got none")
482 }
483 })
484 }
485}
486
487func TestAggregateResults(t *testing.T) {
488 tests := []struct {
489 name string
490 results []batchResult
491 startTime time.Time
492 wantSuccess int
493 wantErrors int
494 wantTotal int
495 wantDuration bool
496 wantRatePerMin bool
497 }{
498 {
499 name: "empty results",
500 results: []batchResult{},
501 wantSuccess: 0,
502 wantErrors: 0,
503 wantTotal: 0,
504 },
505 {
506 name: "single successful result",
507 results: []batchResult{
508 {SuccessCount: 5, ErrorCount: 0, Duration: time.Second},
509 },
510 wantSuccess: 5,
511 wantErrors: 0,
512 wantTotal: 5,
513 wantDuration: true,
514 wantRatePerMin: true,
515 },
516 {
517 name: "multiple mixed results",
518 results: []batchResult{
519 {SuccessCount: 3, ErrorCount: 1, Duration: time.Second},
520 {SuccessCount: 2, ErrorCount: 0, Duration: time.Second},
521 {SuccessCount: 0, ErrorCount: 2, Duration: time.Second},
522 },
523 wantSuccess: 5,
524 wantErrors: 3,
525 wantTotal: 8,
526 wantDuration: true,
527 wantRatePerMin: true,
528 },
529 {
530 name: "all errors",
531 results: []batchResult{
532 {SuccessCount: 0, ErrorCount: 3, Duration: time.Second},
533 {SuccessCount: 0, ErrorCount: 2, Duration: time.Second},
534 },
535 wantSuccess: 0,
536 wantErrors: 5,
537 wantTotal: 5,
538 wantDuration: true,
539 wantRatePerMin: false, // 0 success rate
540 },
541 }
542
543 for _, tt := range tests {
544 t.Run(tt.name, func(t *testing.T) {
545 t.Parallel()
546
547 startTime := time.Now()
548 if !tt.startTime.IsZero() {
549 startTime = tt.startTime
550 }
551
552 result := aggregate(tt.results, startTime)
553
554 if result.SuccessCount != tt.wantSuccess {
555 t.Errorf("AggregateResults() success = %d, want %d", result.SuccessCount, tt.wantSuccess)
556 }
557
558 if result.ErrorCount != tt.wantErrors {
559 t.Errorf("AggregateResults() errors = %d, want %d", result.ErrorCount, tt.wantErrors)
560 }
561
562 if result.TotalRecords != tt.wantTotal {
563 t.Errorf("AggregateResults() total = %d, want %d", result.TotalRecords, tt.wantTotal)
564 }
565
566 if tt.wantDuration && result.Duration == 0 {
567 t.Error("AggregateResults() expected non-zero duration")
568 }
569
570 if tt.wantRatePerMin && result.RecordsPerMinute == 0 && tt.wantSuccess > 0 {
571 t.Error("AggregateResults() expected non-zero rate per minute")
572 }
573 })
574 }
575}
576
577func TestPublish(t *testing.T) {
578 tests := []struct {
579 name string
580 opts PublishOptions
581 records []PlayRecord
582 setupClient func() *mockATProtoClient
583 wantSuccess int
584 wantErrors int
585 wantCancelled bool
586 }{
587 {
588 name: "successful publish",
589 opts: PublishOptions{
590 BatchSize: 2,
591 Storage: newMockStorage(),
592 ClientAgent: "test-agent",
593 },
594 records: []PlayRecord{
595 {TrackName: "Song 1"},
596 {TrackName: "Song 2"},
597 },
598 wantSuccess: 2,
599 wantErrors: 0,
600 },
601 {
602 name: "dry run publish",
603 opts: PublishOptions{
604 BatchSize: 2,
605 DryRun: true,
606 Storage: newMockStorage(),
607 ClientAgent: "test-agent",
608 },
609 records: []PlayRecord{
610 {TrackName: "Song 1"},
611 {TrackName: "Song 2"},
612 },
613 wantSuccess: 2,
614 wantErrors: 0,
615 },
616 {
617 name: "publish with client errors",
618 opts: PublishOptions{
619 BatchSize: 1,
620 Storage: newMockStorage(),
621 ClientAgent: "test-agent",
622 },
623 records: []PlayRecord{
624 {TrackName: "Song 1"},
625 {TrackName: "Song 2"},
626 },
627 setupClient: func() *mockATProtoClient {
628 return &mockATProtoClient{
629 applyWritesFunc: func(ctx context.Context, collection string, records []*PlayRecord) error {
630 return &atclient.APIError{StatusCode: 400} // Non-transient error
631 },
632 }
633 },
634 wantSuccess: 0,
635 wantErrors: 2,
636 },
637 {
638 name: "publish with reverse iteration",
639 opts: PublishOptions{
640 BatchSize: 2,
641 Reverse: true,
642 Storage: newMockStorage(),
643 ClientAgent: "test-agent",
644 },
645 records: []PlayRecord{
646 {TrackName: "Song 1"},
647 {TrackName: "Song 2"},
648 {TrackName: "Song 3"},
649 },
650 wantSuccess: 3,
651 wantErrors: 0,
652 },
653 {
654 name: "publish with empty records",
655 opts: PublishOptions{
656 BatchSize: 2,
657 Storage: newMockStorage(),
658 ClientAgent: "test-agent",
659 },
660 records: []PlayRecord{},
661 wantSuccess: 0,
662 wantErrors: 0,
663 },
664 {
665 name: "publish with transient errors and retry",
666 opts: PublishOptions{
667 BatchSize: 1,
668 Storage: newMockStorage(),
669 ClientAgent: "test-agent",
670 },
671 records: []PlayRecord{
672 {TrackName: "Song 1"},
673 },
674 setupClient: func() *mockATProtoClient {
675 return &mockATProtoClient{
676 applyWritesFunc: func(ctx context.Context, collection string, records []*PlayRecord) error {
677 return &atclient.APIError{StatusCode: 500} // Transient error
678 },
679 }
680 },
681 wantSuccess: 0,
682 wantErrors: 1,
683 },
684 {
685 name: "publish with storage failure",
686 opts: PublishOptions{
687 BatchSize: 1,
688 Storage: newFailingStorage(),
689 ClientAgent: "test-agent",
690 },
691 records: []PlayRecord{
692 {TrackName: "Song 1"},
693 },
694 wantSuccess: 0,
695 wantErrors: 1,
696 },
697 }
698
699 for _, tt := range tests {
700 t.Run(tt.name, func(t *testing.T) {
701 synctest.Test(t, func(t *testing.T) {
702 ctx := t.Context()
703 did := "did:example:123"
704
705 var storage *mockStorage
706 if fs, ok := tt.opts.Storage.(*failingStorage); ok {
707 storage = fs.mockStorage
708 } else {
709 storage = tt.opts.Storage.(*mockStorage)
710 }
711 // Add records to storage
712 for i, record := range tt.records {
713 data, _ := json.Marshal(record)
714 storage.unpublished[fmt.Sprintf("key%d", i)] = data
715 }
716
717 client := &mockAuthClient{did: did}
718 if tt.setupClient != nil {
719 tt.opts.ATProtoClient = tt.setupClient()
720 } else {
721 tt.opts.ATProtoClient = &mockATProtoClient{}
722 }
723
724 result := Publish(ctx, client, tt.opts)
725
726 if result.SuccessCount != tt.wantSuccess {
727 t.Errorf("Publish() success = %d, want %d", result.SuccessCount, tt.wantSuccess)
728 }
729
730 if result.ErrorCount != tt.wantErrors {
731 t.Errorf("Publish() errors = %d, want %d", result.ErrorCount, tt.wantErrors)
732 }
733
734 if result.Cancelled != tt.wantCancelled {
735 t.Errorf("Publish() cancelled = %v, want %v", result.Cancelled, tt.wantCancelled)
736 }
737 })
738 })
739 }
740}
741
742func TestIsTransientError(t *testing.T) {
743 tests := []struct {
744 name string
745 err error
746 want bool
747 }{
748 {"nil", nil, false},
749 {"generic error", errors.New("some error"), false},
750 {"API 400", &atclient.APIError{StatusCode: 400}, false},
751 {"API 429", &atclient.APIError{StatusCode: 429}, true},
752 {"API 500", &atclient.APIError{StatusCode: 500}, true},
753 {"API 503", &atclient.APIError{StatusCode: 503}, true},
754 {"net timeout", timeoutError{}, true},
755 {"net non-timeout", errors.New("network is down"), false},
756 }
757
758 for _, tt := range tests {
759 t.Run(tt.name, func(t *testing.T) {
760 t.Parallel()
761 if got := IsTransientError(tt.err); got != tt.want {
762 t.Errorf("IsTransientError() = %v, want %v", got, tt.want)
763 }
764 })
765 }
766}