package token

import (
	"context"
	"crypto/rsa"
	"crypto/tls"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"atcr.io/pkg/appview/db"
)

// Shared test key to avoid generating a new RSA key for each test.
// Generating a 2048-bit RSA key takes ~0.15s, so reusing one key saves ~4.5s for 32 tests.
var (
	sharedTestKey     *rsa.PrivateKey
	sharedTestKeyPath string
	sharedTestKeyOnce sync.Once
	sharedTestKeyDir  string
	// sharedTestKeyErr records a failure of the one-time key setup so that
	// every caller of getSharedTestKey fails, not only the first one.
	sharedTestKeyErr error
)

// getSharedTestKey returns the file path of an RSA key shared by all tests.
// The key is generated once (via NewIssuer) and reused across all tests in
// this package. If the one-time setup failed, every calling test is failed.
//
// The temporary directory holding the key is not removed here; it lives for
// the duration of the test process.
func getSharedTestKey(t *testing.T) string {
	t.Helper()

	sharedTestKeyOnce.Do(func() {
		// Do not call t.Fatalf in here: sync.Once marks itself done even when
		// the function exits early, so later callers would never notice the
		// failure. Record the error instead and check it below.
		dir, err := os.MkdirTemp("", "atcr-test-keys-*")
		if err != nil {
			sharedTestKeyErr = fmt.Errorf("creating test key directory: %w", err)
			return
		}
		sharedTestKeyDir = dir
		keyPath := filepath.Join(dir, "test-key.pem")

		// Generate the key once (this is the expensive operation we want to
		// avoid repeating). This will also generate the certificate via NewIssuer.
		if _, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute); err != nil {
			sharedTestKeyErr = fmt.Errorf("generating shared test key: %w", err)
			return
		}
		sharedTestKeyPath = keyPath
	})

	if sharedTestKeyErr != nil {
		t.Fatalf("Failed to set up shared test key: %v", sharedTestKeyErr)
	}
	return sharedTestKeyPath
}

// setupTestDeviceStore creates an in-memory SQLite database for testing and
// returns a DeviceStore backed by it together with the raw handle. The
// database is closed automatically when the test finishes.
func setupTestDeviceStore(t *testing.T) (*db.DeviceStore, *sql.DB) {
	t.Helper()

	testDB, err := db.InitDB(":memory:", true)
	if err != nil {
		t.Fatalf("Failed to initialize test database: %v", err)
	}
	t.Cleanup(func() {
		// Best-effort close; a failure here does not affect the test outcome.
		_ = testDB.Close()
	})
	return db.NewDeviceStore(testDB), testDB
}

// createTestDevice creates a device in the test database and returns its secret.
// It requires both the DeviceStore and the sql.DB because a user record must be
// inserted before the device can be approved.
func createTestDevice(t *testing.T, store *db.DeviceStore, testDB *sql.DB, did, handle string) string {
	t.Helper()

	// First create a user record (required by foreign key constraint).
	user := &db.User{
		DID:         did,
		Handle:      handle,
		PDSEndpoint: "https://pds.example.com",
	}
	if err := db.UpsertUser(testDB, user); err != nil {
		t.Fatalf("Failed to create user: %v", err)
	}

	// Create pending authorization.
	pending, err := store.CreatePendingAuth("Test Device", "127.0.0.1", "test-agent")
	if err != nil {
		t.Fatalf("Failed to create pending auth: %v", err)
	}

	// Approve the pending authorization.
	secret, err := store.ApprovePending(pending.UserCode, did, handle)
	if err != nil {
		t.Fatalf("Failed to approve pending auth: %v", err)
	}
	return secret
}

// TestNewHandler verifies that NewHandler returns a handler with its issuer
// and validator populated.
func TestNewHandler(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	handler := NewHandler(issuer, nil)
	if handler == nil {
		t.Fatal("Expected non-nil handler")
	}
	if handler.issuer == nil {
		t.Error("Expected issuer to be set")
	}
	if handler.validator == nil {
		t.Error("Expected validator to be initialized")
	}
}

// TestHandler_SetPostAuthCallback verifies that SetPostAuthCallback stores the callback.
func TestHandler_SetPostAuthCallback(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	handler := NewHandler(issuer, nil)
	handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
		return nil
	})

	if handler.postAuthCallback == nil {
		t.Error("Expected post-auth callback to be set")
	}
}

// TestHandler_ServeHTTP_NoAuth verifies that a request without credentials is
// rejected with 401 and a WWW-Authenticate challenge.
func TestHandler_ServeHTTP_NoAuth(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	handler := NewHandler(issuer, nil)
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
	}

	// Check for WWW-Authenticate header.
	if w.Header().Get("WWW-Authenticate") == "" {
		t.Error("Expected WWW-Authenticate header")
	}
}

// TestHandler_ServeHTTP_WrongMethod verifies that a POST request is rejected
// with 405.
func TestHandler_ServeHTTP_WrongMethod(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	handler := NewHandler(issuer, nil)

	// Try POST instead of GET.
	req := httptest.NewRequest(http.MethodPost, "/auth/token", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
	}
}

// TestHandler_ServeHTTP_DeviceAuth_Valid verifies that a valid device secret
// yields a 200 response carrying matching token and access_token fields.
func TestHandler_ServeHTTP_DeviceAuth_Valid(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	// Create real device store with in-memory database.
	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")

	handler := NewHandler(issuer, deviceStore)

	// Create request with device secret.
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull,push", nil)
	req.SetBasicAuth("alice.bsky.social", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// Stop here on a non-200: the body is an error payload, not a TokenResponse.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status %d, got %d. Response body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	// Parse response.
	var resp TokenResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if resp.Token == "" {
		t.Error("Expected non-empty token")
	}
	if resp.AccessToken == "" {
		t.Error("Expected non-empty access_token")
	}
	if resp.ExpiresIn == 0 {
		t.Error("Expected non-zero expires_in")
	}

	// Verify token and access_token are the same.
	if resp.Token != resp.AccessToken {
		t.Error("Expected token and access_token to be the same")
	}
}

// TestHandler_ServeHTTP_DeviceAuth_Invalid verifies that an unknown device
// secret is rejected with 401.
func TestHandler_ServeHTTP_DeviceAuth_Invalid(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	// Create device store but don't add any devices.
	deviceStore, _ := setupTestDeviceStore(t)
	handler := NewHandler(issuer, deviceStore)

	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
	req.SetBasicAuth("alice", "atcr_device_invalid")
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
	}
}

// TestHandler_ServeHTTP_InvalidScope verifies that a malformed scope parameter
// is rejected with 400 and an "invalid scope" message.
func TestHandler_ServeHTTP_InvalidScope(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Invalid scope format (missing colons).
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=invalid", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
	}

	body := w.Body.String()
	if !strings.Contains(body, "invalid scope") {
		t.Errorf("Expected error message to contain 'invalid scope', got: %s", body)
	}
}

// TestHandler_ServeHTTP_AccessDenied verifies that requesting push access to
// another user's repository is rejected with 403.
func TestHandler_ServeHTTP_AccessDenied(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Try to push to someone else's repository.
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:push", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusForbidden {
		t.Errorf("Expected status %d, got %d", http.StatusForbidden, w.Code)
	}

	body := w.Body.String()
	if !strings.Contains(body, "access denied") {
		t.Errorf("Expected error message to contain 'access denied', got: %s", body)
	}
}

// TestHandler_ServeHTTP_WithCallback verifies that the post-auth callback is
// NOT invoked when the request authenticates with a device secret.
func TestHandler_ServeHTTP_WithCallback(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Set callback to track if it's called.
	callbackCalled := false
	handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
		callbackCalled = true
		// Note: We don't check the values because callback shouldn't be called for device auth.
		return nil
	})

	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// Make sure authentication actually succeeded; otherwise the callback
	// assertion below would pass vacuously.
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	// Note: Callback is only called for app password auth, not device auth,
	// so callbackCalled should be false for this test.
	if callbackCalled {
		t.Error("Expected callback NOT to be called for device auth")
	}
}

// TestHandler_ServeHTTP_MultipleScopes verifies that several space-separated
// scopes in one request are accepted.
func TestHandler_ServeHTTP_MultipleScopes(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Multiple scopes separated by space (URL encoded).
	scopes := "repository%3Aalice.bsky.social%2Fapp1%3Apull+repository%3Aalice.bsky.social%2Fapp2%3Apush"
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope="+scopes, nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}
}

// TestHandler_ServeHTTP_WildcardScope verifies that a wildcard repository
// scope is accepted.
func TestHandler_ServeHTTP_WildcardScope(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Wildcard scope should be allowed.
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:*:pull,push", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}
}

// TestHandler_ServeHTTP_NoScope verifies that a request without a scope
// parameter still receives a token.
func TestHandler_ServeHTTP_NoScope(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// No scope parameter - should still work (empty access).
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}

	var resp TokenResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if resp.Token == "" {
		t.Error("Expected non-empty token even with no scope")
	}
}

// TestGetBaseURL checks the URL getBaseURL derives from the request host, the
// presence of TLS on the request, and the X-Forwarded-Host / X-Forwarded-Proto
// headers.
func TestGetBaseURL(t *testing.T) {
	tests := []struct {
		name        string
		host        string
		headers     map[string]string
		useTLS      bool // attach a non-nil TLS connection state to the request
		expectedURL string
	}{
		{
			name:        "simple host",
			host:        "registry.example.com",
			expectedURL: "http://registry.example.com",
		},
		{
			name:        "with TLS",
			host:        "registry.example.com",
			useTLS:      true,
			expectedURL: "https://registry.example.com",
		},
		{
			name: "with X-Forwarded-Host",
			host: "internal-host",
			headers: map[string]string{
				"X-Forwarded-Host": "registry.example.com",
			},
			expectedURL: "http://registry.example.com",
		},
		{
			name: "with X-Forwarded-Proto",
			host: "registry.example.com",
			headers: map[string]string{
				"X-Forwarded-Proto": "https",
			},
			expectedURL: "https://registry.example.com",
		},
		{
			name: "with both forwarded headers",
			host: "internal",
			headers: map[string]string{
				"X-Forwarded-Host":  "registry.example.com",
				"X-Forwarded-Proto": "https",
			},
			expectedURL: "https://registry.example.com",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.Host = tt.host
			for key, value := range tt.headers {
				req.Header.Set(key, value)
			}
			if tt.useTLS {
				req.TLS = &tls.ConnectionState{} // Non-nil TLS indicates HTTPS
			}

			baseURL := getBaseURL(req)
			if baseURL != tt.expectedURL {
				t.Errorf("Expected URL %q, got %q", tt.expectedURL, baseURL)
			}
		})
	}
}

// TestTokenResponse_JSONFormat verifies the JSON field names produced when a
// TokenResponse is marshalled.
func TestTokenResponse_JSONFormat(t *testing.T) {
	resp := TokenResponse{
		Token:       "jwt_token_here",
		AccessToken: "jwt_token_here",
		ExpiresIn:   900,
		IssuedAt:    "2025-01-01T00:00:00Z",
	}

	data, err := json.Marshal(resp)
	if err != nil {
		t.Fatalf("Failed to marshal response: %v", err)
	}

	// Verify JSON structure.
	var decoded map[string]interface{}
	if err := json.Unmarshal(data, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal JSON: %v", err)
	}

	if decoded["token"] != "jwt_token_here" {
		t.Error("Expected token field in JSON")
	}
	if decoded["access_token"] != "jwt_token_here" {
		t.Error("Expected access_token field in JSON")
	}
	// encoding/json decodes JSON numbers into float64 when the target is interface{}.
	if decoded["expires_in"] != float64(900) {
		t.Error("Expected expires_in field in JSON")
	}
	if decoded["issued_at"] != "2025-01-01T00:00:00Z" {
		t.Error("Expected issued_at field in JSON")
	}
}

// TestHandler_ServeHTTP_AuthHeader exercises the handler with a manually
// constructed Basic Authorization header. It is a smoke test for header
// parsing only: the resulting status is logged, not asserted.
func TestHandler_ServeHTTP_AuthHeader(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	handler := NewHandler(issuer, nil)

	// Test with manually constructed auth header.
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
	auth := base64.StdEncoding.EncodeToString([]byte("username:password"))
	req.Header.Set("Authorization", "Basic "+auth)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// Should fail because we don't have valid credentials, but we're testing the header parsing.
	if w.Code != http.StatusUnauthorized {
		t.Logf("Got status %d (this is fine, we're just testing header parsing)", w.Code)
	}
}

// TestHandler_ServeHTTP_ContentType verifies that a successful token response
// is served as application/json.
func TestHandler_ServeHTTP_ContentType(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("Expected status %d, got %d", http.StatusOK, w.Code)
	}

	contentType := w.Header().Get("Content-Type")
	if contentType != "application/json" {
		t.Errorf("Expected Content-Type 'application/json', got %q", contentType)
	}
}

// TestHandler_ServeHTTP_ExpiresIn verifies that expires_in in the response
// matches the expiration the issuer was created with.
func TestHandler_ServeHTTP_ExpiresIn(t *testing.T) {
	keyPath := getSharedTestKey(t)

	// Create issuer with specific expiration.
	expiration := 10 * time.Minute
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", expiration)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// Check the status first so an auth failure is reported as such rather
	// than as a confusing decode or expires_in mismatch.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	var resp TokenResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	expectedExpiresIn := int(expiration.Seconds())
	if resp.ExpiresIn != expectedExpiresIn {
		t.Errorf("Expected expires_in %d, got %d", expectedExpiresIn, resp.ExpiresIn)
	}
}

// TestHandler_ServeHTTP_PullOnlyAccess verifies that pull access to another
// user's repository is granted.
func TestHandler_ServeHTTP_PullOnlyAccess(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Pull from someone else's repo should be allowed.
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:pull", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d for pull-only access, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}
}