A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
at codeberg-source 173 lines 5.2 kB view raw
1// Package auth provides authentication and authorization for ATCR, including 2// ATProto session validation, hold authorization (captain/crew membership), 3// scope parsing, and token caching for OAuth and service tokens. 4package auth 5 6import ( 7 "bytes" 8 "context" 9 "crypto/sha256" 10 "encoding/hex" 11 "encoding/json" 12 "fmt" 13 "io" 14 "log/slog" 15 "net/http" 16 "sync" 17 "time" 18 19 "atcr.io/pkg/atproto" 20) 21 22// CachedSession represents a cached session 23type CachedSession struct { 24 DID string 25 Handle string 26 PDS string 27 AccessToken string 28 ExpiresAt time.Time 29} 30 31// SessionValidator validates ATProto credentials 32type SessionValidator struct { 33 httpClient *http.Client 34 cache map[string]*CachedSession 35 cacheMu sync.RWMutex 36} 37 38// NewSessionValidator creates a new ATProto session validator 39func NewSessionValidator() *SessionValidator { 40 return &SessionValidator{ 41 httpClient: &http.Client{}, 42 cache: make(map[string]*CachedSession), 43 } 44} 45 46// getCacheKey generates a cache key from username and password 47func getCacheKey(username, password string) string { 48 h := sha256.New() 49 h.Write([]byte(username + ":" + password)) 50 return hex.EncodeToString(h.Sum(nil)) 51} 52 53// getCachedSession retrieves a cached session if valid 54func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) { 55 v.cacheMu.RLock() 56 defer v.cacheMu.RUnlock() 57 58 session, ok := v.cache[cacheKey] 59 if !ok { 60 return nil, false 61 } 62 63 // Check if expired (with 5 minute buffer) 64 if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) { 65 return nil, false 66 } 67 68 return session, true 69} 70 71// setCachedSession stores a session in the cache 72func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) { 73 v.cacheMu.Lock() 74 defer v.cacheMu.Unlock() 75 v.cache[cacheKey] = session 76} 77 78// SessionResponse represents the response from createSession 79type SessionResponse struct { 80 DID string `json:"did"` 81 Handle string `json:"handle"` 82 AccessJWT string `json:"accessJwt"` 83 RefreshJWT string `json:"refreshJwt"` 84 Email string `json:"email,omitempty"` 85 AccessToken string `json:"access_token,omitempty"` // Alternative field name 86} 87 88// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token 89func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) { 90 // Check cache first 91 cacheKey := getCacheKey(identifier, password) 92 if cached, ok := v.getCachedSession(cacheKey); ok { 93 slog.Debug("Using cached session", "identifier", identifier, "did", cached.DID) 94 return cached.DID, cached.Handle, cached.AccessToken, nil 95 } 96 97 slog.Debug("No cached session, creating new session", "identifier", identifier) 98 99 // Resolve identifier to PDS endpoint 100 _, _, pds, err := atproto.ResolveIdentity(ctx, identifier) 101 if err != nil { 102 return "", "", "", err 103 } 104 105 // Create session 106 sessionResp, err := v.createSession(ctx, pds, identifier, password) 107 if err != nil { 108 return "", "", "", fmt.Errorf("authentication failed: %w", err) 109 } 110 111 // Cache the session (ATProto sessions typically last 2 hours) 112 v.setCachedSession(cacheKey, &CachedSession{ 113 DID: sessionResp.DID, 114 Handle: sessionResp.Handle, 115 PDS: pds, 116 AccessToken: sessionResp.AccessJWT, 117 ExpiresAt: time.Now().Add(2 * time.Hour), 118 }) 119 slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", sessionResp.DID) 120 121 return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil 122} 123 124// createSession calls com.atproto.server.createSession 125func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) { 126 payload := map[string]string{ 127 "identifier": identifier, 128 "password": password, 129 } 130 131 body, err := json.Marshal(payload) 132 if err != nil { 133 return nil, fmt.Errorf("failed to marshal request: %w", err) 134 } 135 136 url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession) 137 slog.Debug("Creating ATProto session", "url", url) 138 139 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) 140 if err != nil { 141 return nil, err 142 } 143 144 req.Header.Set("Content-Type", "application/json") 145 146 resp, err := v.httpClient.Do(req) 147 if err != nil { 148 slog.Debug("Session creation HTTP request failed", "error", err) 149 return nil, fmt.Errorf("failed to create session: %w", err) 150 } 151 defer resp.Body.Close() 152 153 slog.Debug("Received session creation response", "status", resp.StatusCode) 154 155 if resp.StatusCode == http.StatusUnauthorized { 156 bodyBytes, _ := io.ReadAll(resp.Body) 157 slog.Debug("Session creation unauthorized", "response", string(bodyBytes)) 158 return nil, fmt.Errorf("invalid credentials") 159 } 160 161 if resp.StatusCode != http.StatusOK { 162 bodyBytes, _ := io.ReadAll(resp.Body) 163 slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes)) 164 return nil, fmt.Errorf("create session failed with status %d: %s", resp.StatusCode, string(bodyBytes)) 165 } 166 167 var sessionResp SessionResponse 168 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil { 169 return nil, fmt.Errorf("failed to decode response: %w", err) 170 } 171 172 return &sessionResp, nil 173}