like malachite (atproto-lastfm-importer) but in go and bluer
go spotify tealfm lastfm atproto
at main 766 lines 18 kB view raw
1package sync 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "iter" 9 "maps" 10 "sync" 11 "testing" 12 "testing/synctest" 13 "time" 14 15 "github.com/bluesky-social/indigo/atproto/atclient" 16 17 "tangled.org/karitham.dev/lazuli/atproto" 18 "tangled.org/karitham.dev/lazuli/cache" 19) 20 21// Mock Storage 22type mockStorage struct { 23 cache.RecordStore 24 unpublished map[string][]byte 25 published map[string]bool 26 failed map[string]string 27 kv map[string]int 28 mu sync.Mutex 29} 30 31func newMockStorage() *mockStorage { 32 return &mockStorage{ 33 unpublished: make(map[string][]byte), 34 published: make(map[string]bool), 35 failed: make(map[string]string), 36 kv: make(map[string]int), 37 } 38} 39 40func (m *mockStorage) SaveRecords(did string, records map[string][]byte) error { 41 m.mu.Lock() 42 defer m.mu.Unlock() 43 maps.Copy(m.unpublished, records) 44 return nil 45} 46 47func (m *mockStorage) IterateUnpublished(did string, reverse bool) iter.Seq2[string, []byte] { 48 return func(yield func(key string, rec []byte) bool) { 49 m.mu.Lock() 50 keys := make([]string, 0, len(m.unpublished)) 51 for k := range m.unpublished { 52 keys = append(keys, k) 53 } 54 m.mu.Unlock() 55 56 for _, k := range keys { 57 m.mu.Lock() 58 rec, ok := m.unpublished[k] 59 m.mu.Unlock() 60 if ok { 61 if !yield(k, rec) { 62 return 63 } 64 } 65 } 66 } 67} 68 69func (m *mockStorage) IteratePublished(did string, reverse bool) iter.Seq2[string, []byte] { 70 return func(yield func(key string, rec []byte) bool) { 71 m.mu.Lock() 72 keys := make([]string, 0, len(m.published)) 73 for k := range m.published { 74 keys = append(keys, k) 75 } 76 m.mu.Unlock() 77 78 for _, k := range keys { 79 m.mu.Lock() 80 rec, ok := m.unpublished[k] // Get record data from unpublished map 81 m.mu.Unlock() 82 if ok { 83 if !yield(k, rec) { 84 return 85 } 86 } 87 } 88 } 89} 90 91func (m *mockStorage) MarkPublished(did string, keys ...string) error { 92 m.mu.Lock() 93 defer m.mu.Unlock() 94 for _, k := range keys { 95 delete(m.unpublished, k) 96 m.published[k] = true 97 } 98 return nil 99} 100 101func (m *mockStorage) MarkFailed(did string, keys []string, err string) error { 102 m.mu.Lock() 103 defer m.mu.Unlock() 104 for _, k := range keys { 105 m.failed[k] = err 106 } 107 return nil 108} 109 110func (m *mockStorage) Get(key string) (int, error) { 111 m.mu.Lock() 112 defer m.mu.Unlock() 113 return m.kv[key], nil 114} 115 116func (m *mockStorage) IncrBy(key string, n int) (int, error) { 117 m.mu.Lock() 118 defer m.mu.Unlock() 119 m.kv[key] += n 120 return m.kv[key], nil 121} 122 123// Mock ATProtoClient 124type mockATProtoClient struct { 125 applyWritesFunc func(ctx context.Context, collection string, records []*PlayRecord) error 126 listRecordsFunc func(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[*PlayRecord], string, error) 127 deleteRecordFunc func(ctx context.Context, collection, rkey string) error 128} 129 130func (m *mockATProtoClient) ApplyWrites(ctx context.Context, collection string, records []*PlayRecord) error { 131 if m.applyWritesFunc != nil { 132 return m.applyWritesFunc(ctx, collection, records) 133 } 134 return nil 135} 136 137func (m *mockATProtoClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[*PlayRecord], string, error) { 138 if m.listRecordsFunc != nil { 139 return m.listRecordsFunc(ctx, collection, limit, cursor) 140 } 141 return nil, "", nil 142} 143 144func (m *mockATProtoClient) DeleteRecord(ctx context.Context, collection, rkey string) error { 145 if m.deleteRecordFunc != nil { 146 return m.deleteRecordFunc(ctx, collection, rkey) 147 } 148 return nil 149} 150 151// Mock AuthClient 152type mockAuthClient struct { 153 did string 154} 155 156func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil } 157func (m *mockAuthClient) DID() string { return m.did } 158 159type timeoutError struct{} 160 161func (e timeoutError) Error() string { return "timeout" } 162func (e timeoutError) Timeout() bool { return true } 163func (e timeoutError) Temporary() bool { return true } 164 165type failingStorage struct { 166 *mockStorage 167} 168 169func newFailingStorage() *failingStorage { 170 return &failingStorage{ 171 mockStorage: newMockStorage(), 172 } 173} 174 175func (s *failingStorage) SaveRecords(did string, records map[string][]byte) error { 176 return errors.New("failed to save records") 177} 178 179func TestBuildRecordBatches(t *testing.T) { 180 tests := []struct { 181 name string 182 records []PlayRecord 183 batchSize int 184 wantBatches int 185 wantErr bool 186 ctxCancel bool 187 }{ 188 { 189 name: "empty storage", 190 records: []PlayRecord{}, 191 batchSize: 2, 192 wantBatches: 0, 193 wantErr: false, 194 }, 195 { 196 name: "single batch", 197 records: []PlayRecord{ 198 {TrackName: "Song 1"}, 199 {TrackName: "Song 2"}, 200 }, 201 batchSize: 5, 202 wantBatches: 1, 203 wantErr: false, 204 }, 205 { 206 name: "multiple exact batches", 207 records: []PlayRecord{ 208 {TrackName: "Song 1"}, 209 {TrackName: "Song 2"}, 210 {TrackName: "Song 3"}, 211 {TrackName: "Song 4"}, 212 }, 213 batchSize: 2, 214 wantBatches: 2, 215 wantErr: false, 216 }, 217 { 218 name: "partial final batch", 219 records: []PlayRecord{ 220 {TrackName: "Song 1"}, 221 {TrackName: "Song 2"}, 222 {TrackName: "Song 3"}, 223 }, 224 batchSize: 2, 225 wantBatches: 2, 226 wantErr: false, 227 }, 228 { 229 name: "context cancelled", 230 records: []PlayRecord{ 231 {TrackName: "Song 1"}, 232 {TrackName: "Song 2"}, 233 }, 234 batchSize: 2, 235 wantBatches: 0, 236 wantErr: true, 237 ctxCancel: true, 238 }, 239 { 240 name: "malformed records skipped", 241 records: []PlayRecord{ 242 {TrackName: "Song 1"}, 243 {TrackName: "Song 2"}, 244 }, 245 batchSize: 2, 246 wantBatches: 1, 247 wantErr: false, 248 }, 249 } 250 251 for _, tt := range tests { 252 t.Run(tt.name, func(t *testing.T) { 253 t.Parallel() 254 255 ctx := t.Context() 256 if tt.ctxCancel { 257 var cancel context.CancelFunc 258 ctx, cancel = context.WithCancel(ctx) 259 cancel() 260 } 261 262 storage := newMockStorage() 263 did := "did:example:123" 264 265 // Add records to storage 266 for i, record := range tt.records { 267 data, _ := json.Marshal(record) 268 if tt.name == "malformed records skipped" && i == 1 { 269 data = []byte("invalid json") 270 } 271 storage.unpublished[fmt.Sprintf("key%d", i)] = data 272 } 273 274 batches, err := batchRecords(ctx, storage, did, tt.batchSize, false) 275 276 if (err != nil) != tt.wantErr { 277 t.Errorf("BuildRecordBatches() error = %v, wantErr %v", err, tt.wantErr) 278 return 279 } 280 281 if len(batches) != tt.wantBatches { 282 t.Errorf("BuildRecordBatches() batches = %d, want %d", len(batches), tt.wantBatches) 283 } 284 285 if !tt.wantErr { 286 totalRecords := 0 287 for _, batch := range batches { 288 totalRecords += len(batch.Records) 289 } 290 291 expectedRecords := len(tt.records) 292 if tt.name == "malformed records skipped" { 293 expectedRecords = len(tt.records) - 1 // Skip malformed record 294 } 295 296 if totalRecords != expectedRecords { 297 t.Errorf("BuildRecordBatches() total records = %d, want %d", totalRecords, expectedRecords) 298 } 299 } 300 }) 301 } 302} 303 304func TestIteratePublished(t *testing.T) { 305 tests := []struct { 306 name string 307 setupStorage func() *mockStorage 308 reverse bool 309 wantKeys []string 310 }{ 311 { 312 name: "returns only published records", 313 setupStorage: func() *mockStorage { 314 s := newMockStorage() 315 s.unpublished["key1"] = []byte(`{"trackName":"a"}`) 316 s.unpublished["key2"] = []byte(`{"trackName":"b"}`) 317 s.unpublished["key3"] = []byte(`{"trackName":"c"}`) 318 s.published["key1"] = true 319 s.published["key3"] = true 320 return s 321 }, 322 reverse: false, 323 wantKeys: []string{"key1", "key3"}, 324 }, 325 { 326 name: "handles empty published set", 327 setupStorage: func() *mockStorage { 328 s := newMockStorage() 329 s.unpublished["key1"] = []byte(`{"trackName":"a"}`) 330 return s 331 }, 332 reverse: false, 333 wantKeys: nil, 334 }, 335 } 336 337 for _, tt := range tests { 338 t.Run(tt.name, func(t *testing.T) { 339 t.Parallel() 340 341 storage := tt.setupStorage() 342 343 var gotKeys []string 344 for key, rec := range storage.IteratePublished("did:test", tt.reverse) { 345 gotKeys = append(gotKeys, key) 346 _ = rec 347 } 348 349 if len(gotKeys) != len(tt.wantKeys) { 350 t.Errorf("IteratePublished() returned %d keys, want %d", len(gotKeys), len(tt.wantKeys)) 351 return 352 } 353 354 wantSet := make(map[string]bool) 355 for _, k := range tt.wantKeys { 356 wantSet[k] = true 357 } 358 for _, k := range gotKeys { 359 if !wantSet[k] { 360 t.Errorf("IteratePublished() got unexpected key %s", k) 361 } 362 } 363 }) 364 } 365} 366 367func TestProcessBatch(t *testing.T) { 368 tests := []struct { 369 name string 370 batch recordBatch 371 processor batchProcessor 372 wantSuccess int 373 wantError int 374 wantErr bool 375 setupClient func() *mockATProtoClient 376 setupStorage func() cache.Storage 377 }{ 378 { 379 name: "empty batch", 380 batch: recordBatch{ 381 Records: []*PlayRecord{}, 382 Keys: []string{}, 383 }, 384 processor: batchProcessor{ 385 Client: &mockATProtoClient{}, 386 Storage: newMockStorage(), 387 }, 388 wantSuccess: 0, 389 wantError: 0, 390 wantErr: false, 391 }, 392 { 393 name: "successful batch", 394 batch: recordBatch{ 395 Records: []*PlayRecord{{TrackName: "Song 1"}, {TrackName: "Song 2"}}, 396 Keys: []string{"key1", "key2"}, 397 }, 398 processor: batchProcessor{ 399 Client: &mockATProtoClient{}, 400 Storage: newMockStorage(), 401 DID: "did:example:123", 402 ClientAgent: "test-agent", 403 }, 404 wantSuccess: 2, 405 wantError: 0, 406 wantErr: false, 407 }, 408 { 409 name: "dry run batch", 410 batch: recordBatch{ 411 Records: []*PlayRecord{{TrackName: "Song 1"}, {TrackName: "Song 2"}}, 412 Keys: []string{"key1", "key2"}, 413 }, 414 processor: batchProcessor{ 415 Client: &mockATProtoClient{}, 416 Storage: newMockStorage(), 417 DID: "did:example:123", 418 ClientAgent: "test-agent", 419 DryRun: true, 420 }, 421 wantSuccess: 2, 422 wantError: 0, 423 wantErr: false, 424 }, 425 { 426 name: "batch with apply writes failure", 427 batch: recordBatch{ 428 Records: []*PlayRecord{{TrackName: "Song 1"}}, 429 Keys: []string{"key1"}, 430 }, 431 processor: batchProcessor{ 432 Client: func() *mockATProtoClient { 433 return &mockATProtoClient{ 434 applyWritesFunc: func(ctx context.Context, collection string, records []*PlayRecord) error { 435 return errors.New("apply writes failed") 436 }, 437 } 438 }(), 439 Storage: newMockStorage(), 440 DID: "did:example:123", 441 ClientAgent: "test-agent", 442 }, 443 wantSuccess: 0, 444 wantError: 1, 445 wantErr: false, 446 }, 447 { 448 name: "batch with storage failure", 449 batch: recordBatch{ 450 Records: []*PlayRecord{{TrackName: "Song 1"}}, 451 Keys: []string{"key1"}, 452 }, 453 processor: batchProcessor{ 454 Client: &mockATProtoClient{}, 455 Storage: newFailingStorage(), 456 DID: "did:example:123", 457 ClientAgent: "test-agent", 458 }, 459 wantSuccess: 0, 460 wantError: 1, 461 wantErr: false, 462 }, 463 } 464 465 for _, tt := range tests { 466 t.Run(tt.name, func(t *testing.T) { 467 t.Parallel() 468 469 ctx := t.Context() 470 result := processBatch(ctx, tt.batch, tt.processor) 471 472 if result.SuccessCount != tt.wantSuccess { 473 t.Errorf("ProcessBatch() success count = %d, want %d", result.SuccessCount, tt.wantSuccess) 474 } 475 476 if result.ErrorCount != tt.wantError { 477 t.Errorf("ProcessBatch() error count = %d, want %d", result.ErrorCount, tt.wantError) 478 } 479 480 if tt.wantError > 0 && len(result.Errors) == 0 { 481 t.Error("ProcessBatch() expected errors but got none") 482 } 483 }) 484 } 485} 486 487func TestAggregateResults(t *testing.T) { 488 tests := []struct { 489 name string 490 results []batchResult 491 startTime time.Time 492 wantSuccess int 493 wantErrors int 494 wantTotal int 495 wantDuration bool 496 wantRatePerMin bool 497 }{ 498 { 499 name: "empty results", 500 results: []batchResult{}, 501 wantSuccess: 0, 502 wantErrors: 0, 503 wantTotal: 0, 504 }, 505 { 506 name: "single successful result", 507 results: []batchResult{ 508 {SuccessCount: 5, ErrorCount: 0, Duration: time.Second}, 509 }, 510 wantSuccess: 5, 511 wantErrors: 0, 512 wantTotal: 5, 513 wantDuration: true, 514 wantRatePerMin: true, 515 }, 516 { 517 name: "multiple mixed results", 518 results: []batchResult{ 519 {SuccessCount: 3, ErrorCount: 1, Duration: time.Second}, 520 {SuccessCount: 2, ErrorCount: 0, Duration: time.Second}, 521 {SuccessCount: 0, ErrorCount: 2, Duration: time.Second}, 522 }, 523 wantSuccess: 5, 524 wantErrors: 3, 525 wantTotal: 8, 526 wantDuration: true, 527 wantRatePerMin: true, 528 }, 529 { 530 name: "all errors", 531 results: []batchResult{ 532 {SuccessCount: 0, ErrorCount: 3, Duration: time.Second}, 533 {SuccessCount: 0, ErrorCount: 2, Duration: time.Second}, 534 }, 535 wantSuccess: 0, 536 wantErrors: 5, 537 wantTotal: 5, 538 wantDuration: true, 539 wantRatePerMin: false, // 0 success rate 540 }, 541 } 542 543 for _, tt := range tests { 544 t.Run(tt.name, func(t *testing.T) { 545 t.Parallel() 546 547 startTime := time.Now() 548 if !tt.startTime.IsZero() { 549 startTime = tt.startTime 550 } 551 552 result := aggregate(tt.results, startTime) 553 554 if result.SuccessCount != tt.wantSuccess { 555 t.Errorf("AggregateResults() success = %d, want %d", result.SuccessCount, tt.wantSuccess) 556 } 557 558 if result.ErrorCount != tt.wantErrors { 559 t.Errorf("AggregateResults() errors = %d, want %d", result.ErrorCount, tt.wantErrors) 560 } 561 562 if result.TotalRecords != tt.wantTotal { 563 t.Errorf("AggregateResults() total = %d, want %d", result.TotalRecords, tt.wantTotal) 564 } 565 566 if tt.wantDuration && result.Duration == 0 { 567 t.Error("AggregateResults() expected non-zero duration") 568 } 569 570 if tt.wantRatePerMin && result.RecordsPerMinute == 0 && tt.wantSuccess > 0 { 571 t.Error("AggregateResults() expected non-zero rate per minute") 572 } 573 }) 574 } 575} 576 577func TestPublish(t *testing.T) { 578 tests := []struct { 579 name string 580 opts PublishOptions 581 records []PlayRecord 582 setupClient func() *mockATProtoClient 583 wantSuccess int 584 wantErrors int 585 wantCancelled bool 586 }{ 587 { 588 name: "successful publish", 589 opts: PublishOptions{ 590 BatchSize: 2, 591 Storage: newMockStorage(), 592 ClientAgent: "test-agent", 593 }, 594 records: []PlayRecord{ 595 {TrackName: "Song 1"}, 596 {TrackName: "Song 2"}, 597 }, 598 wantSuccess: 2, 599 wantErrors: 0, 600 }, 601 { 602 name: "dry run publish", 603 opts: PublishOptions{ 604 BatchSize: 2, 605 DryRun: true, 606 Storage: newMockStorage(), 607 ClientAgent: "test-agent", 608 }, 609 records: []PlayRecord{ 610 {TrackName: "Song 1"}, 611 {TrackName: "Song 2"}, 612 }, 613 wantSuccess: 2, 614 wantErrors: 0, 615 }, 616 { 617 name: "publish with client errors", 618 opts: PublishOptions{ 619 BatchSize: 1, 620 Storage: newMockStorage(), 621 ClientAgent: "test-agent", 622 }, 623 records: []PlayRecord{ 624 {TrackName: "Song 1"}, 625 {TrackName: "Song 2"}, 626 }, 627 setupClient: func() *mockATProtoClient { 628 return &mockATProtoClient{ 629 applyWritesFunc: func(ctx context.Context, collection string, records []*PlayRecord) error { 630 return &atclient.APIError{StatusCode: 400} // Non-transient error 631 }, 632 } 633 }, 634 wantSuccess: 0, 635 wantErrors: 2, 636 }, 637 { 638 name: "publish with reverse iteration", 639 opts: PublishOptions{ 640 BatchSize: 2, 641 Reverse: true, 642 Storage: newMockStorage(), 643 ClientAgent: "test-agent", 644 }, 645 records: []PlayRecord{ 646 {TrackName: "Song 1"}, 647 {TrackName: "Song 2"}, 648 {TrackName: "Song 3"}, 649 }, 650 wantSuccess: 3, 651 wantErrors: 0, 652 }, 653 { 654 name: "publish with empty records", 655 opts: PublishOptions{ 656 BatchSize: 2, 657 Storage: newMockStorage(), 658 ClientAgent: "test-agent", 659 }, 660 records: []PlayRecord{}, 661 wantSuccess: 0, 662 wantErrors: 0, 663 }, 664 { 665 name: "publish with transient errors and retry", 666 opts: PublishOptions{ 667 BatchSize: 1, 668 Storage: newMockStorage(), 669 ClientAgent: "test-agent", 670 }, 671 records: []PlayRecord{ 672 {TrackName: "Song 1"}, 673 }, 674 setupClient: func() *mockATProtoClient { 675 return &mockATProtoClient{ 676 applyWritesFunc: func(ctx context.Context, collection string, records []*PlayRecord) error { 677 return &atclient.APIError{StatusCode: 500} // Transient error 678 }, 679 } 680 }, 681 wantSuccess: 0, 682 wantErrors: 1, 683 }, 684 { 685 name: "publish with storage failure", 686 opts: PublishOptions{ 687 BatchSize: 1, 688 Storage: newFailingStorage(), 689 ClientAgent: "test-agent", 690 }, 691 records: []PlayRecord{ 692 {TrackName: "Song 1"}, 693 }, 694 wantSuccess: 0, 695 wantErrors: 1, 696 }, 697 } 698 699 for _, tt := range tests { 700 t.Run(tt.name, func(t *testing.T) { 701 synctest.Test(t, func(t *testing.T) { 702 ctx := t.Context() 703 did := "did:example:123" 704 705 var storage *mockStorage 706 if fs, ok := tt.opts.Storage.(*failingStorage); ok { 707 storage = fs.mockStorage 708 } else { 709 storage = tt.opts.Storage.(*mockStorage) 710 } 711 // Add records to storage 712 for i, record := range tt.records { 713 data, _ := json.Marshal(record) 714 storage.unpublished[fmt.Sprintf("key%d", i)] = data 715 } 716 717 client := &mockAuthClient{did: did} 718 if tt.setupClient != nil { 719 tt.opts.ATProtoClient = tt.setupClient() 720 } else { 721 tt.opts.ATProtoClient = &mockATProtoClient{} 722 } 723 724 result := Publish(ctx, client, tt.opts) 725 726 if result.SuccessCount != tt.wantSuccess { 727 t.Errorf("Publish() success = %d, want %d", result.SuccessCount, tt.wantSuccess) 728 } 729 730 if result.ErrorCount != tt.wantErrors { 731 t.Errorf("Publish() errors = %d, want %d", result.ErrorCount, tt.wantErrors) 732 } 733 734 if result.Cancelled != tt.wantCancelled { 735 t.Errorf("Publish() cancelled = %v, want %v", result.Cancelled, tt.wantCancelled) 736 } 737 }) 738 }) 739 } 740} 741 742func TestIsTransientError(t *testing.T) { 743 tests := []struct { 744 name string 745 err error 746 want bool 747 }{ 748 {"nil", nil, false}, 749 {"generic error", errors.New("some error"), false}, 750 {"API 400", &atclient.APIError{StatusCode: 400}, false}, 751 {"API 429", &atclient.APIError{StatusCode: 429}, true}, 752 {"API 500", &atclient.APIError{StatusCode: 500}, true}, 753 {"API 503", &atclient.APIError{StatusCode: 503}, true}, 754 {"net timeout", timeoutError{}, true}, 755 {"net non-timeout", errors.New("network is down"), false}, 756 } 757 758 for _, tt := range tests { 759 t.Run(tt.name, func(t *testing.T) { 760 t.Parallel() 761 if got := IsTransientError(tt.err); got != tt.want { 762 t.Errorf("IsTransientError() = %v, want %v", got, tt.want) 763 } 764 }) 765 } 766}