Forked from evan.jarrett.net/at-container-registry.
A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
1package token
2
3import (
4 "context"
5 "crypto/rsa"
6 "crypto/tls"
7 "database/sql"
8 "encoding/base64"
9 "encoding/json"
10 "net/http"
11 "net/http/httptest"
12 "os"
13 "path/filepath"
14 "strings"
15 "sync"
16 "testing"
17 "time"
18
19 "atcr.io/pkg/appview/db"
20)
21
// Package-level state for the RSA key shared by every test in this package.
// Generating a 2048-bit RSA key takes ~0.15s, so producing it once instead of
// once per test saves ~4.5s across 32 tests.
var (
	// sharedTestKeyOnce guards the one-time key generation in getSharedTestKey.
	sharedTestKeyOnce sync.Once

	// sharedTestKeyDir is the temporary directory created to hold the key.
	sharedTestKeyDir string

	// sharedTestKeyPath is the key file ("test-key.pem") inside sharedTestKeyDir.
	sharedTestKeyPath string

	// sharedTestKey presumably was meant to hold the parsed key.
	// NOTE(review): nothing in this file assigns or reads it — confirm whether
	// another file in the package uses it before removing.
	sharedTestKey *rsa.PrivateKey
)
30
31// getSharedTestKey returns a shared RSA key and its file path for all tests
32// The key is generated once and reused across all tests in this package
33func getSharedTestKey(t *testing.T) string {
34 sharedTestKeyOnce.Do(func() {
35 // Create a persistent temp directory for the shared key
36 var err error
37 sharedTestKeyDir, err = os.MkdirTemp("", "atcr-test-keys-*")
38 if err != nil {
39 t.Fatalf("Failed to create test key directory: %v", err)
40 }
41
42 sharedTestKeyPath = filepath.Join(sharedTestKeyDir, "test-key.pem")
43
44 // Generate the key once (this is the expensive operation we want to avoid repeating)
45 // This will also generate the certificate via NewIssuer
46 _, err = NewIssuer(sharedTestKeyPath, "atcr.io", "registry", 15*time.Minute)
47 if err != nil {
48 t.Fatalf("Failed to generate shared test key: %v", err)
49 }
50 })
51
52 return sharedTestKeyPath
53}
54
55// setupTestDeviceStore creates an in-memory SQLite database for testing
56func setupTestDeviceStore(t *testing.T) (*db.DeviceStore, *sql.DB) {
57 testDB, err := db.InitDB(":memory:", true)
58 if err != nil {
59 t.Fatalf("Failed to initialize test database: %v", err)
60 }
61 return db.NewDeviceStore(testDB), testDB
62}
63
64// createTestDevice creates a device in the test database and returns its secret
65// Requires both DeviceStore and sql.DB to insert user record first
66func createTestDevice(t *testing.T, store *db.DeviceStore, testDB *sql.DB, did, handle string) string {
67 // First create a user record (required by foreign key constraint)
68 user := &db.User{
69 DID: did,
70 Handle: handle,
71 PDSEndpoint: "https://pds.example.com",
72 }
73 err := db.UpsertUser(testDB, user)
74 if err != nil {
75 t.Fatalf("Failed to create user: %v", err)
76 }
77
78 // Create pending authorization
79 pending, err := store.CreatePendingAuth("Test Device", "127.0.0.1", "test-agent")
80 if err != nil {
81 t.Fatalf("Failed to create pending auth: %v", err)
82 }
83
84 // Approve the pending authorization
85 secret, err := store.ApprovePending(pending.UserCode, did, handle)
86 if err != nil {
87 t.Fatalf("Failed to approve pending auth: %v", err)
88 }
89
90 return secret
91}
92
93func TestNewHandler(t *testing.T) {
94 keyPath := getSharedTestKey(t)
95
96 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
97 if err != nil {
98 t.Fatalf("NewIssuer() error = %v", err)
99 }
100
101 handler := NewHandler(issuer, nil)
102 if handler == nil {
103 t.Fatal("Expected non-nil handler")
104 }
105
106 if handler.issuer == nil {
107 t.Error("Expected issuer to be set")
108 }
109
110 if handler.validator == nil {
111 t.Error("Expected validator to be initialized")
112 }
113}
114
115func TestHandler_SetPostAuthCallback(t *testing.T) {
116 keyPath := getSharedTestKey(t)
117
118 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
119 if err != nil {
120 t.Fatalf("NewIssuer() error = %v", err)
121 }
122
123 handler := NewHandler(issuer, nil)
124
125 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
126 return nil
127 })
128
129 if handler.postAuthCallback == nil {
130 t.Error("Expected post-auth callback to be set")
131 }
132}
133
134func TestHandler_ServeHTTP_NoAuth(t *testing.T) {
135 keyPath := getSharedTestKey(t)
136
137 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
138 if err != nil {
139 t.Fatalf("NewIssuer() error = %v", err)
140 }
141
142 handler := NewHandler(issuer, nil)
143
144 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
145 w := httptest.NewRecorder()
146
147 handler.ServeHTTP(w, req)
148
149 if w.Code != http.StatusUnauthorized {
150 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
151 }
152
153 // Check for WWW-Authenticate header
154 if w.Header().Get("WWW-Authenticate") == "" {
155 t.Error("Expected WWW-Authenticate header")
156 }
157}
158
159func TestHandler_ServeHTTP_WrongMethod(t *testing.T) {
160 keyPath := getSharedTestKey(t)
161
162 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
163 if err != nil {
164 t.Fatalf("NewIssuer() error = %v", err)
165 }
166
167 handler := NewHandler(issuer, nil)
168
169 // Try POST instead of GET
170 req := httptest.NewRequest(http.MethodPost, "/auth/token", nil)
171 w := httptest.NewRecorder()
172
173 handler.ServeHTTP(w, req)
174
175 if w.Code != http.StatusMethodNotAllowed {
176 t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
177 }
178}
179
180func TestHandler_ServeHTTP_DeviceAuth_Valid(t *testing.T) {
181 keyPath := getSharedTestKey(t)
182
183 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
184 if err != nil {
185 t.Fatalf("NewIssuer() error = %v", err)
186 }
187
188 // Create real device store with in-memory database
189 deviceStore, database := setupTestDeviceStore(t)
190 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
191
192 handler := NewHandler(issuer, deviceStore)
193
194 // Create request with device secret
195 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull,push", nil)
196 req.SetBasicAuth("alice.bsky.social", deviceSecret)
197 w := httptest.NewRecorder()
198
199 handler.ServeHTTP(w, req)
200
201 if w.Code != http.StatusOK {
202 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
203 t.Logf("Response body: %s", w.Body.String())
204 }
205
206 // Parse response
207 var resp TokenResponse
208 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
209 t.Fatalf("Failed to decode response: %v", err)
210 }
211
212 if resp.Token == "" {
213 t.Error("Expected non-empty token")
214 }
215
216 if resp.AccessToken == "" {
217 t.Error("Expected non-empty access_token")
218 }
219
220 if resp.ExpiresIn == 0 {
221 t.Error("Expected non-zero expires_in")
222 }
223
224 // Verify token and access_token are the same
225 if resp.Token != resp.AccessToken {
226 t.Error("Expected token and access_token to be the same")
227 }
228}
229
230func TestHandler_ServeHTTP_DeviceAuth_Invalid(t *testing.T) {
231 keyPath := getSharedTestKey(t)
232
233 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
234 if err != nil {
235 t.Fatalf("NewIssuer() error = %v", err)
236 }
237
238 // Create device store but don't add any devices
239 deviceStore, _ := setupTestDeviceStore(t)
240
241 handler := NewHandler(issuer, deviceStore)
242
243 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
244 req.SetBasicAuth("alice", "atcr_device_invalid")
245 w := httptest.NewRecorder()
246
247 handler.ServeHTTP(w, req)
248
249 if w.Code != http.StatusUnauthorized {
250 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
251 }
252}
253
254func TestHandler_ServeHTTP_InvalidScope(t *testing.T) {
255 keyPath := getSharedTestKey(t)
256
257 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
258 if err != nil {
259 t.Fatalf("NewIssuer() error = %v", err)
260 }
261
262 deviceStore, database := setupTestDeviceStore(t)
263 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
264
265 handler := NewHandler(issuer, deviceStore)
266
267 // Invalid scope format (missing colons)
268 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=invalid", nil)
269 req.SetBasicAuth("alice", deviceSecret)
270 w := httptest.NewRecorder()
271
272 handler.ServeHTTP(w, req)
273
274 if w.Code != http.StatusBadRequest {
275 t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
276 }
277
278 body := w.Body.String()
279 if !strings.Contains(body, "invalid scope") {
280 t.Errorf("Expected error message to contain 'invalid scope', got: %s", body)
281 }
282}
283
284func TestHandler_ServeHTTP_AccessDenied(t *testing.T) {
285 keyPath := getSharedTestKey(t)
286
287 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
288 if err != nil {
289 t.Fatalf("NewIssuer() error = %v", err)
290 }
291
292 deviceStore, database := setupTestDeviceStore(t)
293 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
294
295 handler := NewHandler(issuer, deviceStore)
296
297 // Try to push to someone else's repository
298 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:push", nil)
299 req.SetBasicAuth("alice", deviceSecret)
300 w := httptest.NewRecorder()
301
302 handler.ServeHTTP(w, req)
303
304 if w.Code != http.StatusForbidden {
305 t.Errorf("Expected status %d, got %d", http.StatusForbidden, w.Code)
306 }
307
308 body := w.Body.String()
309 if !strings.Contains(body, "access denied") {
310 t.Errorf("Expected error message to contain 'access denied', got: %s", body)
311 }
312}
313
314func TestHandler_ServeHTTP_WithCallback(t *testing.T) {
315 keyPath := getSharedTestKey(t)
316
317 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
318 if err != nil {
319 t.Fatalf("NewIssuer() error = %v", err)
320 }
321
322 deviceStore, database := setupTestDeviceStore(t)
323 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
324
325 handler := NewHandler(issuer, deviceStore)
326
327 // Set callback to track if it's called
328 callbackCalled := false
329 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
330 callbackCalled = true
331 // Note: We don't check the values because callback shouldn't be called for device auth
332 return nil
333 })
334
335 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
336 req.SetBasicAuth("alice", deviceSecret)
337 w := httptest.NewRecorder()
338
339 handler.ServeHTTP(w, req)
340
341 // Note: Callback is only called for app password auth, not device auth
342 // So callbackCalled should be false for this test
343 if callbackCalled {
344 t.Error("Expected callback NOT to be called for device auth")
345 }
346}
347
348func TestHandler_ServeHTTP_MultipleScopes(t *testing.T) {
349 keyPath := getSharedTestKey(t)
350
351 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
352 if err != nil {
353 t.Fatalf("NewIssuer() error = %v", err)
354 }
355
356 deviceStore, database := setupTestDeviceStore(t)
357 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
358
359 handler := NewHandler(issuer, deviceStore)
360
361 // Multiple scopes separated by space (URL encoded)
362 scopes := "repository%3Aalice.bsky.social%2Fapp1%3Apull+repository%3Aalice.bsky.social%2Fapp2%3Apush"
363 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope="+scopes, nil)
364 req.SetBasicAuth("alice", deviceSecret)
365 w := httptest.NewRecorder()
366
367 handler.ServeHTTP(w, req)
368
369 if w.Code != http.StatusOK {
370 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
371 }
372}
373
374func TestHandler_ServeHTTP_WildcardScope(t *testing.T) {
375 keyPath := getSharedTestKey(t)
376
377 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
378 if err != nil {
379 t.Fatalf("NewIssuer() error = %v", err)
380 }
381
382 deviceStore, database := setupTestDeviceStore(t)
383 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
384
385 handler := NewHandler(issuer, deviceStore)
386
387 // Wildcard scope should be allowed
388 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:*:pull,push", nil)
389 req.SetBasicAuth("alice", deviceSecret)
390 w := httptest.NewRecorder()
391
392 handler.ServeHTTP(w, req)
393
394 if w.Code != http.StatusOK {
395 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
396 }
397}
398
399func TestHandler_ServeHTTP_NoScope(t *testing.T) {
400 keyPath := getSharedTestKey(t)
401
402 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
403 if err != nil {
404 t.Fatalf("NewIssuer() error = %v", err)
405 }
406
407 deviceStore, database := setupTestDeviceStore(t)
408 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
409
410 handler := NewHandler(issuer, deviceStore)
411
412 // No scope parameter - should still work (empty access)
413 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
414 req.SetBasicAuth("alice", deviceSecret)
415 w := httptest.NewRecorder()
416
417 handler.ServeHTTP(w, req)
418
419 if w.Code != http.StatusOK {
420 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
421 }
422
423 var resp TokenResponse
424 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
425 t.Fatalf("Failed to decode response: %v", err)
426 }
427
428 if resp.Token == "" {
429 t.Error("Expected non-empty token even with no scope")
430 }
431}
432
433func TestGetBaseURL(t *testing.T) {
434 tests := []struct {
435 name string
436 host string
437 headers map[string]string
438 expectedURL string
439 }{
440 {
441 name: "simple host",
442 host: "registry.example.com",
443 headers: map[string]string{},
444 expectedURL: "http://registry.example.com",
445 },
446 {
447 name: "with TLS",
448 host: "registry.example.com",
449 headers: map[string]string{},
450 expectedURL: "https://registry.example.com", // Would need TLS in request
451 },
452 {
453 name: "with X-Forwarded-Host",
454 host: "internal-host",
455 headers: map[string]string{
456 "X-Forwarded-Host": "registry.example.com",
457 },
458 expectedURL: "http://registry.example.com",
459 },
460 {
461 name: "with X-Forwarded-Proto",
462 host: "registry.example.com",
463 headers: map[string]string{
464 "X-Forwarded-Proto": "https",
465 },
466 expectedURL: "https://registry.example.com",
467 },
468 {
469 name: "with both forwarded headers",
470 host: "internal",
471 headers: map[string]string{
472 "X-Forwarded-Host": "registry.example.com",
473 "X-Forwarded-Proto": "https",
474 },
475 expectedURL: "https://registry.example.com",
476 },
477 }
478
479 for _, tt := range tests {
480 t.Run(tt.name, func(t *testing.T) {
481 req := httptest.NewRequest(http.MethodGet, "/", nil)
482 req.Host = tt.host
483
484 for key, value := range tt.headers {
485 req.Header.Set(key, value)
486 }
487
488 // For TLS test
489 if tt.expectedURL == "https://registry.example.com" && len(tt.headers) == 0 {
490 req.TLS = &tls.ConnectionState{} // Non-nil TLS indicates HTTPS
491 }
492
493 baseURL := getBaseURL(req)
494
495 if baseURL != tt.expectedURL {
496 t.Errorf("Expected URL %q, got %q", tt.expectedURL, baseURL)
497 }
498 })
499 }
500}
501
502func TestTokenResponse_JSONFormat(t *testing.T) {
503 resp := TokenResponse{
504 Token: "jwt_token_here",
505 AccessToken: "jwt_token_here",
506 ExpiresIn: 900,
507 IssuedAt: "2025-01-01T00:00:00Z",
508 }
509
510 data, err := json.Marshal(resp)
511 if err != nil {
512 t.Fatalf("Failed to marshal response: %v", err)
513 }
514
515 // Verify JSON structure
516 var decoded map[string]interface{}
517 if err := json.Unmarshal(data, &decoded); err != nil {
518 t.Fatalf("Failed to unmarshal JSON: %v", err)
519 }
520
521 if decoded["token"] != "jwt_token_here" {
522 t.Error("Expected token field in JSON")
523 }
524
525 if decoded["access_token"] != "jwt_token_here" {
526 t.Error("Expected access_token field in JSON")
527 }
528
529 if decoded["expires_in"] != float64(900) {
530 t.Error("Expected expires_in field in JSON")
531 }
532
533 if decoded["issued_at"] != "2025-01-01T00:00:00Z" {
534 t.Error("Expected issued_at field in JSON")
535 }
536}
537
538func TestHandler_ServeHTTP_AuthHeader(t *testing.T) {
539 keyPath := getSharedTestKey(t)
540
541 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
542 if err != nil {
543 t.Fatalf("NewIssuer() error = %v", err)
544 }
545
546 handler := NewHandler(issuer, nil)
547
548 // Test with manually constructed auth header
549 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
550 auth := base64.StdEncoding.EncodeToString([]byte("username:password"))
551 req.Header.Set("Authorization", "Basic "+auth)
552 w := httptest.NewRecorder()
553
554 handler.ServeHTTP(w, req)
555
556 // Should fail because we don't have valid credentials, but we're testing the header parsing
557 if w.Code != http.StatusUnauthorized {
558 t.Logf("Got status %d (this is fine, we're just testing header parsing)", w.Code)
559 }
560}
561
562func TestHandler_ServeHTTP_ContentType(t *testing.T) {
563 keyPath := getSharedTestKey(t)
564
565 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
566 if err != nil {
567 t.Fatalf("NewIssuer() error = %v", err)
568 }
569
570 deviceStore, database := setupTestDeviceStore(t)
571 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
572
573 handler := NewHandler(issuer, deviceStore)
574
575 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
576 req.SetBasicAuth("alice", deviceSecret)
577 w := httptest.NewRecorder()
578
579 handler.ServeHTTP(w, req)
580
581 if w.Code != http.StatusOK {
582 t.Fatalf("Expected status %d, got %d", http.StatusOK, w.Code)
583 }
584
585 contentType := w.Header().Get("Content-Type")
586 if contentType != "application/json" {
587 t.Errorf("Expected Content-Type 'application/json', got %q", contentType)
588 }
589}
590
591func TestHandler_ServeHTTP_ExpiresIn(t *testing.T) {
592 keyPath := getSharedTestKey(t)
593
594 // Create issuer with specific expiration
595 expiration := 10 * time.Minute
596 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", expiration)
597 if err != nil {
598 t.Fatalf("NewIssuer() error = %v", err)
599 }
600
601 deviceStore, database := setupTestDeviceStore(t)
602 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
603
604 handler := NewHandler(issuer, deviceStore)
605
606 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
607 req.SetBasicAuth("alice", deviceSecret)
608 w := httptest.NewRecorder()
609
610 handler.ServeHTTP(w, req)
611
612 var resp TokenResponse
613 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
614 t.Fatalf("Failed to decode response: %v", err)
615 }
616
617 expectedExpiresIn := int(expiration.Seconds())
618 if resp.ExpiresIn != expectedExpiresIn {
619 t.Errorf("Expected expires_in %d, got %d", expectedExpiresIn, resp.ExpiresIn)
620 }
621}
622
623func TestHandler_ServeHTTP_PullOnlyAccess(t *testing.T) {
624 keyPath := getSharedTestKey(t)
625
626 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
627 if err != nil {
628 t.Fatalf("NewIssuer() error = %v", err)
629 }
630
631 deviceStore, database := setupTestDeviceStore(t)
632 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
633
634 handler := NewHandler(issuer, deviceStore)
635
636 // Pull from someone else's repo should be allowed
637 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:pull", nil)
638 req.SetBasicAuth("alice", deviceSecret)
639 w := httptest.NewRecorder()
640
641 handler.ServeHTTP(w, req)
642
643 if w.Code != http.StatusOK {
644 t.Errorf("Expected status %d for pull-only access, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
645 }
646}