A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
1// Package auth provides authentication and authorization for ATCR, including
2// ATProto session validation, hold authorization (captain/crew membership),
3// scope parsing, and token caching for OAuth and service tokens.
4package auth
5
6import (
7 "bytes"
8 "context"
9 "crypto/sha256"
10 "encoding/hex"
11 "encoding/json"
12 "fmt"
13 "io"
14 "log/slog"
15 "net/http"
16 "sync"
17 "time"
18
19 "atcr.io/pkg/atproto"
20)
21
// CachedSession is a session previously created via createSession and kept
// in the validator's in-memory cache, so repeated logins with the same
// credentials can skip identity resolution and the createSession call.
type CachedSession struct {
	// Account identity and PDS location for the session, plus the access
	// token the PDS issued.
	DID, Handle, PDS, AccessToken string
	// ExpiresAt is the point the cache measures freshness against.
	ExpiresAt time.Time
}
30
31// SessionValidator validates ATProto credentials
32type SessionValidator struct {
33 httpClient *http.Client
34 cache map[string]*CachedSession
35 cacheMu sync.RWMutex
36}
37
38// NewSessionValidator creates a new ATProto session validator
39func NewSessionValidator() *SessionValidator {
40 return &SessionValidator{
41 httpClient: &http.Client{},
42 cache: make(map[string]*CachedSession),
43 }
44}
45
// getCacheKey derives the session-cache key for a username/password pair as
// the hex-encoded SHA-256 digest of both values; the password itself is not
// kept as the key.
//
// Each field is length-prefixed before hashing so that distinct pairs can
// never produce the same input. A plain "username:password" join is
// ambiguous whenever the identifier contains ':' (every DID does), e.g.
// ("did:plc:abc", "pw") and ("did:plc", "abc:pw") would share one key.
func getCacheKey(username, password string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so Fprintf's results are ignored.
	fmt.Fprintf(h, "%d:%s%d:%s", len(username), username, len(password), password)
	return hex.EncodeToString(h.Sum(nil))
}
52
53// getCachedSession retrieves a cached session if valid
54func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) {
55 v.cacheMu.RLock()
56 defer v.cacheMu.RUnlock()
57
58 session, ok := v.cache[cacheKey]
59 if !ok {
60 return nil, false
61 }
62
63 // Check if expired (with 5 minute buffer)
64 if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) {
65 return nil, false
66 }
67
68 return session, true
69}
70
71// setCachedSession stores a session in the cache
72func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) {
73 v.cacheMu.Lock()
74 defer v.cacheMu.Unlock()
75 v.cache[cacheKey] = session
76}
77
// SessionResponse is the JSON body decoded from a
// com.atproto.server.createSession call.
type SessionResponse struct {
	DID         string `json:"did"`
	Handle      string `json:"handle"`
	AccessJWT   string `json:"accessJwt"`
	RefreshJWT  string `json:"refreshJwt"`
	Email       string `json:"email,omitempty"`
	// AccessToken receives the access token when a server spells the field
	// "access_token" instead of "accessJwt".
	AccessToken string `json:"access_token,omitempty"`
}
87
88// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token
89func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) {
90 // Check cache first
91 cacheKey := getCacheKey(identifier, password)
92 if cached, ok := v.getCachedSession(cacheKey); ok {
93 slog.Debug("Using cached session", "identifier", identifier, "did", cached.DID)
94 return cached.DID, cached.Handle, cached.AccessToken, nil
95 }
96
97 slog.Debug("No cached session, creating new session", "identifier", identifier)
98
99 // Resolve identifier to PDS endpoint
100 _, _, pds, err := atproto.ResolveIdentity(ctx, identifier)
101 if err != nil {
102 return "", "", "", err
103 }
104
105 // Create session
106 sessionResp, err := v.createSession(ctx, pds, identifier, password)
107 if err != nil {
108 return "", "", "", fmt.Errorf("authentication failed: %w", err)
109 }
110
111 // Cache the session (ATProto sessions typically last 2 hours)
112 v.setCachedSession(cacheKey, &CachedSession{
113 DID: sessionResp.DID,
114 Handle: sessionResp.Handle,
115 PDS: pds,
116 AccessToken: sessionResp.AccessJWT,
117 ExpiresAt: time.Now().Add(2 * time.Hour),
118 })
119 slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", sessionResp.DID)
120
121 return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil
122}
123
124// createSession calls com.atproto.server.createSession
125func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) {
126 payload := map[string]string{
127 "identifier": identifier,
128 "password": password,
129 }
130
131 body, err := json.Marshal(payload)
132 if err != nil {
133 return nil, fmt.Errorf("failed to marshal request: %w", err)
134 }
135
136 url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession)
137 slog.Debug("Creating ATProto session", "url", url)
138
139 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
140 if err != nil {
141 return nil, err
142 }
143
144 req.Header.Set("Content-Type", "application/json")
145
146 resp, err := v.httpClient.Do(req)
147 if err != nil {
148 slog.Debug("Session creation HTTP request failed", "error", err)
149 return nil, fmt.Errorf("failed to create session: %w", err)
150 }
151 defer resp.Body.Close()
152
153 slog.Debug("Received session creation response", "status", resp.StatusCode)
154
155 if resp.StatusCode == http.StatusUnauthorized {
156 bodyBytes, _ := io.ReadAll(resp.Body)
157 slog.Debug("Session creation unauthorized", "response", string(bodyBytes))
158 return nil, fmt.Errorf("invalid credentials")
159 }
160
161 if resp.StatusCode != http.StatusOK {
162 bodyBytes, _ := io.ReadAll(resp.Body)
163 slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes))
164 return nil, fmt.Errorf("create session failed with status %d: %s", resp.StatusCode, string(bodyBytes))
165 }
166
167 var sessionResp SessionResponse
168 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
169 return nil, fmt.Errorf("failed to decode response: %w", err)
170 }
171
172 return &sessionResp, nil
173}