A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
at codeberg-source 646 lines 19 kB view raw
1package token 2 3import ( 4 "context" 5 "crypto/rsa" 6 "crypto/tls" 7 "database/sql" 8 "encoding/base64" 9 "encoding/json" 10 "net/http" 11 "net/http/httptest" 12 "os" 13 "path/filepath" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18 19 "atcr.io/pkg/appview/db" 20) 21 22// Shared test key to avoid generating a new RSA key for each test 23// Generating a 2048-bit RSA key takes ~0.15s, so reusing one key saves ~4.5s for 32 tests 24var ( 25 sharedTestKey *rsa.PrivateKey 26 sharedTestKeyPath string 27 sharedTestKeyOnce sync.Once 28 sharedTestKeyDir string 29) 30 31// getSharedTestKey returns a shared RSA key and its file path for all tests 32// The key is generated once and reused across all tests in this package 33func getSharedTestKey(t *testing.T) string { 34 sharedTestKeyOnce.Do(func() { 35 // Create a persistent temp directory for the shared key 36 var err error 37 sharedTestKeyDir, err = os.MkdirTemp("", "atcr-test-keys-*") 38 if err != nil { 39 t.Fatalf("Failed to create test key directory: %v", err) 40 } 41 42 sharedTestKeyPath = filepath.Join(sharedTestKeyDir, "test-key.pem") 43 44 // Generate the key once (this is the expensive operation we want to avoid repeating) 45 // This will also generate the certificate via NewIssuer 46 _, err = NewIssuer(sharedTestKeyPath, "atcr.io", "registry", 15*time.Minute) 47 if err != nil { 48 t.Fatalf("Failed to generate shared test key: %v", err) 49 } 50 }) 51 52 return sharedTestKeyPath 53} 54 55// setupTestDeviceStore creates an in-memory SQLite database for testing 56func setupTestDeviceStore(t *testing.T) (*db.DeviceStore, *sql.DB) { 57 testDB, err := db.InitDB(":memory:", true) 58 if err != nil { 59 t.Fatalf("Failed to initialize test database: %v", err) 60 } 61 return db.NewDeviceStore(testDB), testDB 62} 63 64// createTestDevice creates a device in the test database and returns its secret 65// Requires both DeviceStore and sql.DB to insert user record first 66func createTestDevice(t *testing.T, store *db.DeviceStore, testDB *sql.DB, did, handle string) string { 67 // First create a user record (required by foreign key constraint) 68 user := &db.User{ 69 DID: did, 70 Handle: handle, 71 PDSEndpoint: "https://pds.example.com", 72 } 73 err := db.UpsertUser(testDB, user) 74 if err != nil { 75 t.Fatalf("Failed to create user: %v", err) 76 } 77 78 // Create pending authorization 79 pending, err := store.CreatePendingAuth("Test Device", "127.0.0.1", "test-agent") 80 if err != nil { 81 t.Fatalf("Failed to create pending auth: %v", err) 82 } 83 84 // Approve the pending authorization 85 secret, err := store.ApprovePending(pending.UserCode, did, handle) 86 if err != nil { 87 t.Fatalf("Failed to approve pending auth: %v", err) 88 } 89 90 return secret 91} 92 93func TestNewHandler(t *testing.T) { 94 keyPath := getSharedTestKey(t) 95 96 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 97 if err != nil { 98 t.Fatalf("NewIssuer() error = %v", err) 99 } 100 101 handler := NewHandler(issuer, nil) 102 if handler == nil { 103 t.Fatal("Expected non-nil handler") 104 } 105 106 if handler.issuer == nil { 107 t.Error("Expected issuer to be set") 108 } 109 110 if handler.validator == nil { 111 t.Error("Expected validator to be initialized") 112 } 113} 114 115func TestHandler_SetPostAuthCallback(t *testing.T) { 116 keyPath := getSharedTestKey(t) 117 118 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 119 if err != nil { 120 t.Fatalf("NewIssuer() error = %v", err) 121 } 122 123 handler := NewHandler(issuer, nil) 124 125 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error { 126 return nil 127 }) 128 129 if handler.postAuthCallback == nil { 130 t.Error("Expected post-auth callback to be set") 131 } 132} 133 134func TestHandler_ServeHTTP_NoAuth(t *testing.T) { 135 keyPath := getSharedTestKey(t) 136 137 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 138 if err != nil { 139 t.Fatalf("NewIssuer() error = %v", err) 140 } 141 142 handler := NewHandler(issuer, nil) 143 144 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 145 w := httptest.NewRecorder() 146 147 handler.ServeHTTP(w, req) 148 149 if w.Code != http.StatusUnauthorized { 150 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) 151 } 152 153 // Check for WWW-Authenticate header 154 if w.Header().Get("WWW-Authenticate") == "" { 155 t.Error("Expected WWW-Authenticate header") 156 } 157} 158 159func TestHandler_ServeHTTP_WrongMethod(t *testing.T) { 160 keyPath := getSharedTestKey(t) 161 162 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 163 if err != nil { 164 t.Fatalf("NewIssuer() error = %v", err) 165 } 166 167 handler := NewHandler(issuer, nil) 168 169 // Try POST instead of GET 170 req := httptest.NewRequest(http.MethodPost, "/auth/token", nil) 171 w := httptest.NewRecorder() 172 173 handler.ServeHTTP(w, req) 174 175 if w.Code != http.StatusMethodNotAllowed { 176 t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code) 177 } 178} 179 180func TestHandler_ServeHTTP_DeviceAuth_Valid(t *testing.T) { 181 keyPath := getSharedTestKey(t) 182 183 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 184 if err != nil { 185 t.Fatalf("NewIssuer() error = %v", err) 186 } 187 188 // Create real device store with in-memory database 189 deviceStore, database := setupTestDeviceStore(t) 190 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social") 191 192 handler := NewHandler(issuer, deviceStore) 193 194 // Create request with device secret 195 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull,push", nil) 196 req.SetBasicAuth("alice.bsky.social", deviceSecret) 197 w := httptest.NewRecorder() 198 199 handler.ServeHTTP(w, req) 200 201 if w.Code != http.StatusOK { 202 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) 203 t.Logf("Response body: %s", w.Body.String()) 204 } 205 206 // Parse response 207 var resp TokenResponse 208 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { 209 t.Fatalf("Failed to decode response: %v", err) 210 } 211 212 if resp.Token == "" { 213 t.Error("Expected non-empty token") 214 } 215 216 if resp.AccessToken == "" { 217 t.Error("Expected non-empty access_token") 218 } 219 220 if resp.ExpiresIn == 0 { 221 t.Error("Expected non-zero expires_in") 222 } 223 224 // Verify token and access_token are the same 225 if resp.Token != resp.AccessToken { 226 t.Error("Expected token and access_token to be the same") 227 } 228} 229 230func TestHandler_ServeHTTP_DeviceAuth_Invalid(t *testing.T) { 231 keyPath := getSharedTestKey(t) 232 233 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 234 if err != nil { 235 t.Fatalf("NewIssuer() error = %v", err) 236 } 237 238 // Create device store but don't add any devices 239 deviceStore, _ := setupTestDeviceStore(t) 240 241 handler := NewHandler(issuer, deviceStore) 242 243 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 244 req.SetBasicAuth("alice", "atcr_device_invalid") 245 w := httptest.NewRecorder() 246 247 handler.ServeHTTP(w, req) 248 249 if w.Code != http.StatusUnauthorized { 250 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) 251 } 252} 253 254func TestHandler_ServeHTTP_InvalidScope(t *testing.T) { 255 keyPath := getSharedTestKey(t) 256 257 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 258 if err != nil { 259 t.Fatalf("NewIssuer() error = %v", err) 260 } 261 262 deviceStore, database := setupTestDeviceStore(t) 263 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social") 264 265 handler := NewHandler(issuer, deviceStore) 266 267 // Invalid scope format (missing colons) 268 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=invalid", nil) 269 req.SetBasicAuth("alice", deviceSecret) 270 w := httptest.NewRecorder() 271 272 handler.ServeHTTP(w, req) 273 274 if w.Code != http.StatusBadRequest { 275 t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code) 276 } 277 278 body := w.Body.String() 279 if !strings.Contains(body, "invalid scope") { 280 t.Errorf("Expected error message to contain 'invalid scope', got: %s", body) 281 } 282} 283 284func TestHandler_ServeHTTP_AccessDenied(t *testing.T) { 285 keyPath := getSharedTestKey(t) 286 287 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 288 if err != nil { 289 t.Fatalf("NewIssuer() error = %v", err) 290 } 291 292 deviceStore, database := setupTestDeviceStore(t) 293 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 294 295 handler := NewHandler(issuer, deviceStore) 296 297 // Try to push to someone else's repository 298 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:push", nil) 299 req.SetBasicAuth("alice", deviceSecret) 300 w := httptest.NewRecorder() 301 302 handler.ServeHTTP(w, req) 303 304 if w.Code != http.StatusForbidden { 305 t.Errorf("Expected status %d, got %d", http.StatusForbidden, w.Code) 306 } 307 308 body := w.Body.String() 309 if !strings.Contains(body, "access denied") { 310 t.Errorf("Expected error message to contain 'access denied', got: %s", body) 311 } 312} 313 314func TestHandler_ServeHTTP_WithCallback(t *testing.T) { 315 keyPath := getSharedTestKey(t) 316 317 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 318 if err != nil { 319 t.Fatalf("NewIssuer() error = %v", err) 320 } 321 322 deviceStore, database := setupTestDeviceStore(t) 323 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social") 324 325 handler := NewHandler(issuer, deviceStore) 326 327 // Set callback to track if it's called 328 callbackCalled := false 329 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error { 330 callbackCalled = true 331 // Note: We don't check the values because callback shouldn't be called for device auth 332 return nil 333 }) 334 335 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil) 336 req.SetBasicAuth("alice", deviceSecret) 337 w := httptest.NewRecorder() 338 339 handler.ServeHTTP(w, req) 340 341 // Note: Callback is only called for app password auth, not device auth 342 // So callbackCalled should be false for this test 343 if callbackCalled { 344 t.Error("Expected callback NOT to be called for device auth") 345 } 346} 347 348func TestHandler_ServeHTTP_MultipleScopes(t *testing.T) { 349 keyPath := getSharedTestKey(t) 350 351 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 352 if err != nil { 353 t.Fatalf("NewIssuer() error = %v", err) 354 } 355 356 deviceStore, database := setupTestDeviceStore(t) 357 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 358 359 handler := NewHandler(issuer, deviceStore) 360 361 // Multiple scopes separated by space (URL encoded) 362 scopes := "repository%3Aalice.bsky.social%2Fapp1%3Apull+repository%3Aalice.bsky.social%2Fapp2%3Apush" 363 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope="+scopes, nil) 364 req.SetBasicAuth("alice", deviceSecret) 365 w := httptest.NewRecorder() 366 367 handler.ServeHTTP(w, req) 368 369 if w.Code != http.StatusOK { 370 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) 371 } 372} 373 374func TestHandler_ServeHTTP_WildcardScope(t *testing.T) { 375 keyPath := getSharedTestKey(t) 376 377 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 378 if err != nil { 379 t.Fatalf("NewIssuer() error = %v", err) 380 } 381 382 deviceStore, database := setupTestDeviceStore(t) 383 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 384 385 handler := NewHandler(issuer, deviceStore) 386 387 // Wildcard scope should be allowed 388 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:*:pull,push", nil) 389 req.SetBasicAuth("alice", deviceSecret) 390 w := httptest.NewRecorder() 391 392 handler.ServeHTTP(w, req) 393 394 if w.Code != http.StatusOK { 395 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) 396 } 397} 398 399func TestHandler_ServeHTTP_NoScope(t *testing.T) { 400 keyPath := getSharedTestKey(t) 401 402 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 403 if err != nil { 404 t.Fatalf("NewIssuer() error = %v", err) 405 } 406 407 deviceStore, database := setupTestDeviceStore(t) 408 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 409 410 handler := NewHandler(issuer, deviceStore) 411 412 // No scope parameter - should still work (empty access) 413 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 414 req.SetBasicAuth("alice", deviceSecret) 415 w := httptest.NewRecorder() 416 417 handler.ServeHTTP(w, req) 418 419 if w.Code != http.StatusOK { 420 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) 421 } 422 423 var resp TokenResponse 424 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { 425 t.Fatalf("Failed to decode response: %v", err) 426 } 427 428 if resp.Token == "" { 429 t.Error("Expected non-empty token even with no scope") 430 } 431} 432 433func TestGetBaseURL(t *testing.T) { 434 tests := []struct { 435 name string 436 host string 437 headers map[string]string 438 expectedURL string 439 }{ 440 { 441 name: "simple host", 442 host: "registry.example.com", 443 headers: map[string]string{}, 444 expectedURL: "http://registry.example.com", 445 }, 446 { 447 name: "with TLS", 448 host: "registry.example.com", 449 headers: map[string]string{}, 450 expectedURL: "https://registry.example.com", // Would need TLS in request 451 }, 452 { 453 name: "with X-Forwarded-Host", 454 host: "internal-host", 455 headers: map[string]string{ 456 "X-Forwarded-Host": "registry.example.com", 457 }, 458 expectedURL: "http://registry.example.com", 459 }, 460 { 461 name: "with X-Forwarded-Proto", 462 host: "registry.example.com", 463 headers: map[string]string{ 464 "X-Forwarded-Proto": "https", 465 }, 466 expectedURL: "https://registry.example.com", 467 }, 468 { 469 name: "with both forwarded headers", 470 host: "internal", 471 headers: map[string]string{ 472 "X-Forwarded-Host": "registry.example.com", 473 "X-Forwarded-Proto": "https", 474 }, 475 expectedURL: "https://registry.example.com", 476 }, 477 } 478 479 for _, tt := range tests { 480 t.Run(tt.name, func(t *testing.T) { 481 req := httptest.NewRequest(http.MethodGet, "/", nil) 482 req.Host = tt.host 483 484 for key, value := range tt.headers { 485 req.Header.Set(key, value) 486 } 487 488 // For TLS test 489 if tt.expectedURL == "https://registry.example.com" && len(tt.headers) == 0 { 490 req.TLS = &tls.ConnectionState{} // Non-nil TLS indicates HTTPS 491 } 492 493 baseURL := getBaseURL(req) 494 495 if baseURL != tt.expectedURL { 496 t.Errorf("Expected URL %q, got %q", tt.expectedURL, baseURL) 497 } 498 }) 499 } 500} 501 502func TestTokenResponse_JSONFormat(t *testing.T) { 503 resp := TokenResponse{ 504 Token: "jwt_token_here", 505 AccessToken: "jwt_token_here", 506 ExpiresIn: 900, 507 IssuedAt: "2025-01-01T00:00:00Z", 508 } 509 510 data, err := json.Marshal(resp) 511 if err != nil { 512 t.Fatalf("Failed to marshal response: %v", err) 513 } 514 515 // Verify JSON structure 516 var decoded map[string]interface{} 517 if err := json.Unmarshal(data, &decoded); err != nil { 518 t.Fatalf("Failed to unmarshal JSON: %v", err) 519 } 520 521 if decoded["token"] != "jwt_token_here" { 522 t.Error("Expected token field in JSON") 523 } 524 525 if decoded["access_token"] != "jwt_token_here" { 526 t.Error("Expected access_token field in JSON") 527 } 528 529 if decoded["expires_in"] != float64(900) { 530 t.Error("Expected expires_in field in JSON") 531 } 532 533 if decoded["issued_at"] != "2025-01-01T00:00:00Z" { 534 t.Error("Expected issued_at field in JSON") 535 } 536} 537 538func TestHandler_ServeHTTP_AuthHeader(t *testing.T) { 539 keyPath := getSharedTestKey(t) 540 541 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 542 if err != nil { 543 t.Fatalf("NewIssuer() error = %v", err) 544 } 545 546 handler := NewHandler(issuer, nil) 547 548 // Test with manually constructed auth header 549 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 550 auth := base64.StdEncoding.EncodeToString([]byte("username:password")) 551 req.Header.Set("Authorization", "Basic "+auth) 552 w := httptest.NewRecorder() 553 554 handler.ServeHTTP(w, req) 555 556 // Should fail because we don't have valid credentials, but we're testing the header parsing 557 if w.Code != http.StatusUnauthorized { 558 t.Logf("Got status %d (this is fine, we're just testing header parsing)", w.Code) 559 } 560} 561 562func TestHandler_ServeHTTP_ContentType(t *testing.T) { 563 keyPath := getSharedTestKey(t) 564 565 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 566 if err != nil { 567 t.Fatalf("NewIssuer() error = %v", err) 568 } 569 570 deviceStore, database := setupTestDeviceStore(t) 571 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 572 573 handler := NewHandler(issuer, deviceStore) 574 575 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil) 576 req.SetBasicAuth("alice", deviceSecret) 577 w := httptest.NewRecorder() 578 579 handler.ServeHTTP(w, req) 580 581 if w.Code != http.StatusOK { 582 t.Fatalf("Expected status %d, got %d", http.StatusOK, w.Code) 583 } 584 585 contentType := w.Header().Get("Content-Type") 586 if contentType != "application/json" { 587 t.Errorf("Expected Content-Type 'application/json', got %q", contentType) 588 } 589} 590 591func TestHandler_ServeHTTP_ExpiresIn(t *testing.T) { 592 keyPath := getSharedTestKey(t) 593 594 // Create issuer with specific expiration 595 expiration := 10 * time.Minute 596 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", expiration) 597 if err != nil { 598 t.Fatalf("NewIssuer() error = %v", err) 599 } 600 601 deviceStore, database := setupTestDeviceStore(t) 602 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 603 604 handler := NewHandler(issuer, deviceStore) 605 606 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil) 607 req.SetBasicAuth("alice", deviceSecret) 608 w := httptest.NewRecorder() 609 610 handler.ServeHTTP(w, req) 611 612 var resp TokenResponse 613 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { 614 t.Fatalf("Failed to decode response: %v", err) 615 } 616 617 expectedExpiresIn := int(expiration.Seconds()) 618 if resp.ExpiresIn != expectedExpiresIn { 619 t.Errorf("Expected expires_in %d, got %d", expectedExpiresIn, resp.ExpiresIn) 620 } 621} 622 623func TestHandler_ServeHTTP_PullOnlyAccess(t *testing.T) { 624 keyPath := getSharedTestKey(t) 625 626 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 627 if err != nil { 628 t.Fatalf("NewIssuer() error = %v", err) 629 } 630 631 deviceStore, database := setupTestDeviceStore(t) 632 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 633 634 handler := NewHandler(issuer, deviceStore) 635 636 // Pull from someone else's repo should be allowed 637 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:pull", nil) 638 req.SetBasicAuth("alice", deviceSecret) 639 w := httptest.NewRecorder() 640 641 handler.ServeHTTP(w, req) 642 643 if w.Code != http.StatusOK { 644 t.Errorf("Expected status %d for pull-only access, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) 645 } 646}