1package oauth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rand"
8 "crypto/x509"
9 "encoding/json"
10 "encoding/pem"
11 "fmt"
12 "log"
13 "net/http"
14 "net/url"
15 "os"
16 "sync"
17 "time"
18
19 "margin.at/internal/db"
20 "margin.at/internal/xrpc"
21)
22
// Handler serves the OAuth login flow (login, start, callback), session
// inspection and logout, and the client-metadata and JWKS documents.
type Handler struct {
	// db stores sessions and is queried for orphaned replies after login.
	db *db.DB
	// configuredBaseURL is the BASE_URL environment value; when empty,
	// getDynamicClient derives the base URL from each incoming request.
	configuredBaseURL string
	// privateKey is the client's P-256 key; it is passed to every Client
	// built by getDynamicClient.
	privateKey *ecdsa.PrivateKey
	// pending holds in-flight authorization attempts keyed by OAuth state.
	pending map[string]*PendingAuth
	// pendingMu guards pending.
	pendingMu sync.RWMutex
}
30
31func NewHandler(database *db.DB) (*Handler, error) {
32
33 configuredBaseURL := os.Getenv("BASE_URL")
34
35 privateKey, err := loadOrGenerateKey()
36 if err != nil {
37 return nil, fmt.Errorf("failed to load/generate key: %w", err)
38 }
39
40 return &Handler{
41 db: database,
42 configuredBaseURL: configuredBaseURL,
43 privateKey: privateKey,
44 pending: make(map[string]*PendingAuth),
45 }, nil
46}
47
// loadOrGenerateKey returns the OAuth client's ECDSA signing key.
//
// The key is read from the PEM file named by OAUTH_KEY_PATH (default
// "./oauth_private_key.pem"). Both SEC 1 ("EC PRIVATE KEY") and PKCS #8
// ("PRIVATE KEY") encodings are accepted.
//
// When the file does not exist, a new P-256 key is generated and written there
// with mode 0600. Failure to write is only logged, so the process still starts,
// but with a key that will not survive a restart.
//
// A file that exists but cannot be read or parsed is reported as an error
// instead of being replaced, so an operator-provided key is never overwritten.
func loadOrGenerateKey() (*ecdsa.PrivateKey, error) {
	keyPath := os.Getenv("OAUTH_KEY_PATH")
	if keyPath == "" {
		keyPath = "./oauth_private_key.pem"
	}

	data, err := os.ReadFile(keyPath)
	if err == nil {
		block, _ := pem.Decode(data)
		if block == nil {
			return nil, fmt.Errorf("key file %s contains no PEM block", keyPath)
		}
		if key, perr := x509.ParseECPrivateKey(block.Bytes); perr == nil {
			return key, nil
		}
		// Fall back to PKCS #8, the default output of `openssl genpkey`.
		parsed, perr := x509.ParsePKCS8PrivateKey(block.Bytes)
		if perr != nil {
			return nil, fmt.Errorf("parsing key file %s: %w", keyPath, perr)
		}
		key, ok := parsed.(*ecdsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("key file %s holds a %T, want an ECDSA key", keyPath, parsed)
		}
		return key, nil
	}
	if !os.IsNotExist(err) {
		return nil, fmt.Errorf("reading key file %s: %w", keyPath, err)
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}

	keyBytes, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return nil, err
	}

	block := &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: keyBytes,
	}

	// Best effort: a read-only filesystem should not prevent startup.
	if err := os.WriteFile(keyPath, pem.EncodeToMemory(block), 0600); err != nil {
		log.Printf("Warning: could not save key to %s: %v\n", keyPath, err)
	}

	return key, nil
}
85
86func (h *Handler) getDynamicClient(r *http.Request) *Client {
87 baseURL := h.configuredBaseURL
88 if baseURL == "" {
89 scheme := "http"
90 if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
91 scheme = "https"
92 }
93 baseURL = fmt.Sprintf("%s://%s", scheme, r.Host)
94 }
95
96 if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
97 baseURL = baseURL[:len(baseURL)-1]
98 }
99
100 clientID := baseURL + "/client-metadata.json"
101 redirectURI := baseURL + "/auth/callback"
102
103 return NewClient(clientID, redirectURI, h.privateKey)
104}
105
106func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
107 client := h.getDynamicClient(r)
108
109 handle := r.URL.Query().Get("handle")
110 if handle == "" {
111 http.Redirect(w, r, "/login", http.StatusFound)
112 return
113 }
114
115 ctx := r.Context()
116
117 did, err := client.ResolveHandle(ctx, handle)
118 if err != nil {
119 http.Error(w, fmt.Sprintf("Failed to resolve handle: %v", err), http.StatusBadRequest)
120 return
121 }
122
123 pds, err := client.ResolveDIDToPDS(ctx, did)
124 if err != nil {
125 http.Error(w, fmt.Sprintf("Failed to resolve PDS: %v", err), http.StatusBadRequest)
126 return
127 }
128
129 meta, err := client.GetAuthServerMetadata(ctx, pds)
130 if err != nil {
131 http.Error(w, fmt.Sprintf("Failed to get auth server metadata: %v", err), http.StatusBadRequest)
132 return
133 }
134
135 dpopKey, err := client.GenerateDPoPKey()
136 if err != nil {
137 http.Error(w, fmt.Sprintf("Failed to generate DPoP key: %v", err), http.StatusInternalServerError)
138 return
139 }
140
141 pkceVerifier, pkceChallenge := client.GeneratePKCE()
142
143 scope := "atproto transition:generic"
144
145 parResp, state, dpopNonce, err := client.SendPAR(meta, handle, scope, dpopKey, pkceChallenge)
146 if err != nil {
147 http.Error(w, fmt.Sprintf("PAR request failed: %v", err), http.StatusInternalServerError)
148 return
149 }
150
151 pending := &PendingAuth{
152 State: state,
153 DID: did,
154 PDS: pds,
155 AuthServer: meta.TokenEndpoint,
156 Issuer: meta.Issuer,
157 PKCEVerifier: pkceVerifier,
158 DPoPKey: dpopKey,
159 DPoPNonce: dpopNonce,
160 CreatedAt: time.Now(),
161 }
162
163 h.pendingMu.Lock()
164 h.pending[state] = pending
165 h.pendingMu.Unlock()
166
167 authURL, _ := url.Parse(meta.AuthorizationEndpoint)
168 q := authURL.Query()
169 q.Set("client_id", client.ClientID)
170 q.Set("request_uri", parResp.RequestURI)
171 authURL.RawQuery = q.Encode()
172
173 http.Redirect(w, r, authURL.String(), http.StatusFound)
174}
175
176func (h *Handler) HandleStart(w http.ResponseWriter, r *http.Request) {
177 if r.Method != "POST" {
178 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
179 return
180 }
181
182 var req struct {
183 Handle string `json:"handle"`
184 InviteCode string `json:"invite_code"`
185 }
186 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
187 http.Error(w, "Invalid request body", http.StatusBadRequest)
188 return
189 }
190
191 if req.Handle == "" {
192 http.Error(w, "Handle is required", http.StatusBadRequest)
193 return
194 }
195
196 requiredCode := os.Getenv("INVITE_CODE")
197 if requiredCode != "" && req.InviteCode != requiredCode {
198 w.Header().Set("Content-Type", "application/json")
199 w.WriteHeader(http.StatusForbidden)
200 json.NewEncoder(w).Encode(map[string]string{
201 "error": "Invite code required",
202 "code": "invite_required",
203 })
204 return
205 }
206
207 client := h.getDynamicClient(r)
208 ctx := r.Context()
209
210 did, err := client.ResolveHandle(ctx, req.Handle)
211 if err != nil {
212 w.Header().Set("Content-Type", "application/json")
213 w.WriteHeader(http.StatusBadRequest)
214 json.NewEncoder(w).Encode(map[string]string{"error": "Could not find that Bluesky account"})
215 return
216 }
217
218 pds, err := client.ResolveDIDToPDS(ctx, did)
219 if err != nil {
220 w.Header().Set("Content-Type", "application/json")
221 w.WriteHeader(http.StatusBadRequest)
222 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to resolve PDS"})
223 return
224 }
225
226 meta, err := client.GetAuthServerMetadata(ctx, pds)
227 if err != nil {
228 w.Header().Set("Content-Type", "application/json")
229 w.WriteHeader(http.StatusInternalServerError)
230 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to get auth server"})
231 return
232 }
233
234 dpopKey, err := client.GenerateDPoPKey()
235 if err != nil {
236 w.Header().Set("Content-Type", "application/json")
237 w.WriteHeader(http.StatusInternalServerError)
238 json.NewEncoder(w).Encode(map[string]string{"error": "Internal error"})
239 return
240 }
241
242 pkceVerifier, pkceChallenge := client.GeneratePKCE()
243 scope := "atproto transition:generic"
244
245 parResp, state, dpopNonce, err := client.SendPAR(meta, req.Handle, scope, dpopKey, pkceChallenge)
246 if err != nil {
247 log.Printf("PAR request failed: %v", err)
248 w.Header().Set("Content-Type", "application/json")
249 w.WriteHeader(http.StatusInternalServerError)
250 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to initiate authentication"})
251 return
252 }
253
254 pending := &PendingAuth{
255 State: state,
256 DID: did,
257 Handle: req.Handle,
258 PDS: pds,
259 AuthServer: meta.TokenEndpoint,
260 Issuer: meta.Issuer,
261 PKCEVerifier: pkceVerifier,
262 DPoPKey: dpopKey,
263 DPoPNonce: dpopNonce,
264 CreatedAt: time.Now(),
265 }
266
267 h.pendingMu.Lock()
268 h.pending[state] = pending
269 h.pendingMu.Unlock()
270
271 authURL, _ := url.Parse(meta.AuthorizationEndpoint)
272 q := authURL.Query()
273 q.Set("client_id", client.ClientID)
274 q.Set("request_uri", parResp.RequestURI)
275 authURL.RawQuery = q.Encode()
276
277 w.Header().Set("Content-Type", "application/json")
278 json.NewEncoder(w).Encode(map[string]string{
279 "authorizationUrl": authURL.String(),
280 })
281}
282
283func (h *Handler) HandleCallback(w http.ResponseWriter, r *http.Request) {
284 client := h.getDynamicClient(r)
285
286 state := r.URL.Query().Get("state")
287 code := r.URL.Query().Get("code")
288 iss := r.URL.Query().Get("iss")
289
290 if state == "" || code == "" {
291 http.Error(w, "Missing state or code parameter", http.StatusBadRequest)
292 return
293 }
294
295 h.pendingMu.Lock()
296 pending, ok := h.pending[state]
297 if ok {
298 delete(h.pending, state)
299 }
300 h.pendingMu.Unlock()
301
302 if !ok {
303 http.Error(w, "Invalid or expired state", http.StatusBadRequest)
304 return
305 }
306
307 if time.Since(pending.CreatedAt) > 10*time.Minute {
308 http.Error(w, "Authentication request expired", http.StatusBadRequest)
309 return
310 }
311
312 if iss != "" && iss != pending.Issuer {
313 http.Error(w, "Issuer mismatch", http.StatusBadRequest)
314 return
315 }
316
317 ctx := r.Context()
318 meta, err := client.GetAuthServerMetadata(ctx, pending.PDS)
319 if err != nil {
320 http.Error(w, fmt.Sprintf("Failed to get auth metadata: %v", err), http.StatusInternalServerError)
321 return
322 }
323
324 tokenResp, newNonce, err := client.ExchangeCode(meta, code, pending.PKCEVerifier, pending.DPoPKey, pending.DPoPNonce)
325 if err != nil {
326 http.Error(w, fmt.Sprintf("Token exchange failed: %v", err), http.StatusInternalServerError)
327 return
328 }
329
330 _ = newNonce
331
332 sessionID := generateSessionID()
333 expiresAt := time.Now().Add(7 * 24 * time.Hour)
334
335 dpopKeyBytes, err := x509.MarshalECPrivateKey(pending.DPoPKey)
336 if err != nil {
337 http.Error(w, "Failed to marshal DPoP key", http.StatusInternalServerError)
338 return
339 }
340 dpopKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: dpopKeyBytes})
341
342 err = h.db.SaveSession(
343 sessionID,
344 tokenResp.Sub,
345 pending.Handle,
346 tokenResp.AccessToken,
347 tokenResp.RefreshToken,
348 string(dpopKeyPEM),
349 expiresAt,
350 )
351 if err != nil {
352 http.Error(w, "Failed to save session", http.StatusInternalServerError)
353 return
354 }
355
356 http.SetCookie(w, &http.Cookie{
357 Name: "margin_session",
358 Value: sessionID,
359 Path: "/",
360 HttpOnly: true,
361 Secure: true,
362 SameSite: http.SameSiteNoneMode,
363 MaxAge: 86400 * 7,
364 })
365
366 go h.cleanupOrphanedReplies(tokenResp.Sub, tokenResp.AccessToken, string(dpopKeyPEM), pending.PDS)
367
368 http.Redirect(w, r, "/?logged_in=true", http.StatusFound)
369}
370
371func (h *Handler) cleanupOrphanedReplies(did, accessToken, dpopKeyPEM, pds string) {
372 orphans, err := h.db.GetOrphanedRepliesByAuthor(did)
373 if err != nil || len(orphans) == 0 {
374 return
375 }
376
377 block, _ := pem.Decode([]byte(dpopKeyPEM))
378 if block == nil {
379 return
380 }
381 dpopKey, err := x509.ParseECPrivateKey(block.Bytes)
382 if err != nil {
383 return
384 }
385
386 for _, reply := range orphans {
387
388 parts := url.PathEscape(reply.URI)
389 _ = parts
390 uriParts := splitURI(reply.URI)
391 if len(uriParts) < 2 {
392 continue
393 }
394 rkey := uriParts[len(uriParts)-1]
395
396 deleteFromPDS(pds, accessToken, dpopKey, "at.margin.reply", did, rkey)
397
398 h.db.DeleteReply(reply.URI)
399 }
400}
401
402func splitURI(uri string) []string {
403
404 return splitBySlash(uri)
405}
406
// splitBySlash splits s on '/' and returns the non-empty segments in order,
// or nil when there are none.
//
// Segments are substrings of s, so bytes (including invalid UTF-8) are kept
// exactly, and the work is linear in len(s). Scanning bytes is safe because
// '/' (0x2F) never occurs inside a multi-byte UTF-8 sequence.
func splitBySlash(s string) []string {
	var result []string
	start := 0
	for i := 0; i < len(s); i++ {
		if s[i] != '/' {
			continue
		}
		if i > start {
			result = append(result, s[start:i])
		}
		start = i + 1
	}
	if start < len(s) {
		result = append(result, s[start:])
	}
	return result
}
425
426func deleteFromPDS(pds, accessToken string, dpopKey *ecdsa.PrivateKey, collection, did, rkey string) {
427
428 client := xrpc.NewClient(pds, accessToken, dpopKey)
429 err := client.DeleteRecord(context.Background(), collection, did, rkey)
430 if err != nil {
431 log.Printf("Failed to delete orphaned reply from PDS: %v", err)
432 } else {
433 log.Printf("Cleaned up orphaned reply %s/%s from PDS", collection, rkey)
434 }
435}
436
437func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
438 cookie, err := r.Cookie("margin_session")
439 if err == nil {
440 h.db.DeleteSession(cookie.Value)
441 }
442
443 http.SetCookie(w, &http.Cookie{
444 Name: "margin_session",
445 Value: "",
446 Path: "/",
447 HttpOnly: true,
448 MaxAge: -1,
449 })
450
451 w.Header().Set("Content-Type", "application/json")
452 json.NewEncoder(w).Encode(map[string]bool{"success": true})
453}
454
455func (h *Handler) HandleSession(w http.ResponseWriter, r *http.Request) {
456 cookie, err := r.Cookie("margin_session")
457 if err != nil {
458 w.Header().Set("Content-Type", "application/json")
459 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false})
460 return
461 }
462
463 did, handle, _, _, _, err := h.db.GetSession(cookie.Value)
464 if err != nil {
465 w.Header().Set("Content-Type", "application/json")
466 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false})
467 return
468 }
469
470 w.Header().Set("Content-Type", "application/json")
471 json.NewEncoder(w).Encode(map[string]interface{}{
472 "authenticated": true,
473 "did": did,
474 "handle": handle,
475 })
476}
477
478func (h *Handler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) {
479 client := h.getDynamicClient(r)
480 baseURL := client.ClientID[:len(client.ClientID)-len("/client-metadata.json")]
481
482 w.Header().Set("Content-Type", "application/json")
483 json.NewEncoder(w).Encode(map[string]interface{}{
484 "client_id": client.ClientID,
485 "client_name": "Margin",
486 "client_uri": baseURL,
487 "logo_uri": baseURL + "/logo.svg",
488 "tos_uri": baseURL + "/terms",
489 "policy_uri": baseURL + "/privacy",
490 "redirect_uris": []string{client.RedirectURI},
491 "grant_types": []string{"authorization_code", "refresh_token"},
492 "response_types": []string{"code"},
493 "scope": "atproto transition:generic",
494 "token_endpoint_auth_method": "private_key_jwt",
495 "token_endpoint_auth_signing_alg": "ES256",
496 "dpop_bound_access_tokens": true,
497 "jwks_uri": baseURL + "/jwks.json",
498 "application_type": "web",
499 })
500}
501
502func (h *Handler) HandleJWKS(w http.ResponseWriter, r *http.Request) {
503 client := h.getDynamicClient(r)
504 w.Header().Set("Content-Type", "application/json")
505 json.NewEncoder(w).Encode(client.GetPublicJWKS())
506}
507
508func (h *Handler) GetPrivateKey() *ecdsa.PrivateKey {
509 return h.privateKey
510}
511
// generateSessionID returns a new session identifier: 32 bytes from
// crypto/rand, hex-encoded as 64 lowercase characters.
//
// It panics if the system random source fails. Ignoring that error would hand
// out a predictable, all-zero session ID.
func generateSessionID() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateSessionID: reading crypto/rand: %v", err))
	}
	return fmt.Sprintf("%x", b)
}