Write on the margins of the internet. Powered by the AT Protocol. margin.at
extension web atproto comments
at main 13 kB view raw
1package oauth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rand" 8 "crypto/x509" 9 "encoding/json" 10 "encoding/pem" 11 "fmt" 12 "log" 13 "net/http" 14 "net/url" 15 "os" 16 "sync" 17 "time" 18 19 "margin.at/internal/db" 20 "margin.at/internal/xrpc" 21) 22 23type Handler struct { 24 db *db.DB 25 configuredBaseURL string 26 privateKey *ecdsa.PrivateKey 27 pending map[string]*PendingAuth 28 pendingMu sync.RWMutex 29} 30 31func NewHandler(database *db.DB) (*Handler, error) { 32 33 configuredBaseURL := os.Getenv("BASE_URL") 34 35 privateKey, err := loadOrGenerateKey() 36 if err != nil { 37 return nil, fmt.Errorf("failed to load/generate key: %w", err) 38 } 39 40 return &Handler{ 41 db: database, 42 configuredBaseURL: configuredBaseURL, 43 privateKey: privateKey, 44 pending: make(map[string]*PendingAuth), 45 }, nil 46} 47 48func loadOrGenerateKey() (*ecdsa.PrivateKey, error) { 49 keyPath := os.Getenv("OAUTH_KEY_PATH") 50 if keyPath == "" { 51 keyPath = "./oauth_private_key.pem" 52 } 53 54 if data, err := os.ReadFile(keyPath); err == nil { 55 block, _ := pem.Decode(data) 56 if block != nil { 57 key, err := x509.ParseECPrivateKey(block.Bytes) 58 if err == nil { 59 return key, nil 60 } 61 } 62 } 63 64 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 65 if err != nil { 66 return nil, err 67 } 68 69 keyBytes, err := x509.MarshalECPrivateKey(key) 70 if err != nil { 71 return nil, err 72 } 73 74 block := &pem.Block{ 75 Type: "EC PRIVATE KEY", 76 Bytes: keyBytes, 77 } 78 79 if err := os.WriteFile(keyPath, pem.EncodeToMemory(block), 0600); err != nil { 80 log.Printf("Warning: could not save key to %s: %v\n", keyPath, err) 81 } 82 83 return key, nil 84} 85 86func (h *Handler) getDynamicClient(r *http.Request) *Client { 87 baseURL := h.configuredBaseURL 88 if baseURL == "" { 89 scheme := "http" 90 if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { 91 scheme = "https" 92 } 93 baseURL = fmt.Sprintf("%s://%s", scheme, r.Host) 94 } 95 96 if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { 97 baseURL = baseURL[:len(baseURL)-1] 98 } 99 100 clientID := baseURL + "/client-metadata.json" 101 redirectURI := baseURL + "/auth/callback" 102 103 return NewClient(clientID, redirectURI, h.privateKey) 104} 105 106func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) { 107 client := h.getDynamicClient(r) 108 109 handle := r.URL.Query().Get("handle") 110 if handle == "" { 111 http.Redirect(w, r, "/login", http.StatusFound) 112 return 113 } 114 115 ctx := r.Context() 116 117 did, err := client.ResolveHandle(ctx, handle) 118 if err != nil { 119 http.Error(w, fmt.Sprintf("Failed to resolve handle: %v", err), http.StatusBadRequest) 120 return 121 } 122 123 pds, err := client.ResolveDIDToPDS(ctx, did) 124 if err != nil { 125 http.Error(w, fmt.Sprintf("Failed to resolve PDS: %v", err), http.StatusBadRequest) 126 return 127 } 128 129 meta, err := client.GetAuthServerMetadata(ctx, pds) 130 if err != nil { 131 http.Error(w, fmt.Sprintf("Failed to get auth server metadata: %v", err), http.StatusBadRequest) 132 return 133 } 134 135 dpopKey, err := client.GenerateDPoPKey() 136 if err != nil { 137 http.Error(w, fmt.Sprintf("Failed to generate DPoP key: %v", err), http.StatusInternalServerError) 138 return 139 } 140 141 pkceVerifier, pkceChallenge := client.GeneratePKCE() 142 143 scope := "atproto transition:generic" 144 145 parResp, state, dpopNonce, err := client.SendPAR(meta, handle, scope, dpopKey, pkceChallenge) 146 if err != nil { 147 http.Error(w, fmt.Sprintf("PAR request failed: %v", err), http.StatusInternalServerError) 148 return 149 } 150 151 pending := &PendingAuth{ 152 State: state, 153 DID: did, 154 PDS: pds, 155 AuthServer: meta.TokenEndpoint, 156 Issuer: meta.Issuer, 157 PKCEVerifier: pkceVerifier, 158 DPoPKey: dpopKey, 159 DPoPNonce: dpopNonce, 160 CreatedAt: time.Now(), 161 } 162 163 h.pendingMu.Lock() 164 h.pending[state] = pending 165 h.pendingMu.Unlock() 166 167 authURL, _ := url.Parse(meta.AuthorizationEndpoint) 168 q := authURL.Query() 169 q.Set("client_id", client.ClientID) 170 q.Set("request_uri", parResp.RequestURI) 171 authURL.RawQuery = q.Encode() 172 173 http.Redirect(w, r, authURL.String(), http.StatusFound) 174} 175 176func (h *Handler) HandleStart(w http.ResponseWriter, r *http.Request) { 177 if r.Method != "POST" { 178 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 179 return 180 } 181 182 var req struct { 183 Handle string `json:"handle"` 184 InviteCode string `json:"invite_code"` 185 } 186 if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 187 http.Error(w, "Invalid request body", http.StatusBadRequest) 188 return 189 } 190 191 if req.Handle == "" { 192 http.Error(w, "Handle is required", http.StatusBadRequest) 193 return 194 } 195 196 requiredCode := os.Getenv("INVITE_CODE") 197 if requiredCode != "" && req.InviteCode != requiredCode { 198 w.Header().Set("Content-Type", "application/json") 199 w.WriteHeader(http.StatusForbidden) 200 json.NewEncoder(w).Encode(map[string]string{ 201 "error": "Invite code required", 202 "code": "invite_required", 203 }) 204 return 205 } 206 207 client := h.getDynamicClient(r) 208 ctx := r.Context() 209 210 did, err := client.ResolveHandle(ctx, req.Handle) 211 if err != nil { 212 w.Header().Set("Content-Type", "application/json") 213 w.WriteHeader(http.StatusBadRequest) 214 json.NewEncoder(w).Encode(map[string]string{"error": "Could not find that Bluesky account"}) 215 return 216 } 217 218 pds, err := client.ResolveDIDToPDS(ctx, did) 219 if err != nil { 220 w.Header().Set("Content-Type", "application/json") 221 w.WriteHeader(http.StatusBadRequest) 222 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to resolve PDS"}) 223 return 224 } 225 226 meta, err := client.GetAuthServerMetadata(ctx, pds) 227 if err != nil { 228 w.Header().Set("Content-Type", "application/json") 229 w.WriteHeader(http.StatusInternalServerError) 230 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to get auth server"}) 231 return 232 } 233 234 dpopKey, err := client.GenerateDPoPKey() 235 if err != nil { 236 w.Header().Set("Content-Type", "application/json") 237 w.WriteHeader(http.StatusInternalServerError) 238 json.NewEncoder(w).Encode(map[string]string{"error": "Internal error"}) 239 return 240 } 241 242 pkceVerifier, pkceChallenge := client.GeneratePKCE() 243 scope := "atproto transition:generic" 244 245 parResp, state, dpopNonce, err := client.SendPAR(meta, req.Handle, scope, dpopKey, pkceChallenge) 246 if err != nil { 247 log.Printf("PAR request failed: %v", err) 248 w.Header().Set("Content-Type", "application/json") 249 w.WriteHeader(http.StatusInternalServerError) 250 json.NewEncoder(w).Encode(map[string]string{"error": "Failed to initiate authentication"}) 251 return 252 } 253 254 pending := &PendingAuth{ 255 State: state, 256 DID: did, 257 Handle: req.Handle, 258 PDS: pds, 259 AuthServer: meta.TokenEndpoint, 260 Issuer: meta.Issuer, 261 PKCEVerifier: pkceVerifier, 262 DPoPKey: dpopKey, 263 DPoPNonce: dpopNonce, 264 CreatedAt: time.Now(), 265 } 266 267 h.pendingMu.Lock() 268 h.pending[state] = pending 269 h.pendingMu.Unlock() 270 271 authURL, _ := url.Parse(meta.AuthorizationEndpoint) 272 q := authURL.Query() 273 q.Set("client_id", client.ClientID) 274 q.Set("request_uri", parResp.RequestURI) 275 authURL.RawQuery = q.Encode() 276 277 w.Header().Set("Content-Type", "application/json") 278 json.NewEncoder(w).Encode(map[string]string{ 279 "authorizationUrl": authURL.String(), 280 }) 281} 282 283func (h *Handler) HandleCallback(w http.ResponseWriter, r *http.Request) { 284 client := h.getDynamicClient(r) 285 286 state := r.URL.Query().Get("state") 287 code := r.URL.Query().Get("code") 288 iss := r.URL.Query().Get("iss") 289 290 if state == "" || code == "" { 291 http.Error(w, "Missing state or code parameter", http.StatusBadRequest) 292 return 293 } 294 295 h.pendingMu.Lock() 296 pending, ok := h.pending[state] 297 if ok { 298 delete(h.pending, state) 299 } 300 h.pendingMu.Unlock() 301 302 if !ok { 303 http.Error(w, "Invalid or expired state", http.StatusBadRequest) 304 return 305 } 306 307 if time.Since(pending.CreatedAt) > 10*time.Minute { 308 http.Error(w, "Authentication request expired", http.StatusBadRequest) 309 return 310 } 311 312 if iss != "" && iss != pending.Issuer { 313 http.Error(w, "Issuer mismatch", http.StatusBadRequest) 314 return 315 } 316 317 ctx := r.Context() 318 meta, err := client.GetAuthServerMetadata(ctx, pending.PDS) 319 if err != nil { 320 http.Error(w, fmt.Sprintf("Failed to get auth metadata: %v", err), http.StatusInternalServerError) 321 return 322 } 323 324 tokenResp, newNonce, err := client.ExchangeCode(meta, code, pending.PKCEVerifier, pending.DPoPKey, pending.DPoPNonce) 325 if err != nil { 326 http.Error(w, fmt.Sprintf("Token exchange failed: %v", err), http.StatusInternalServerError) 327 return 328 } 329 330 _ = newNonce 331 332 sessionID := generateSessionID() 333 expiresAt := time.Now().Add(7 * 24 * time.Hour) 334 335 dpopKeyBytes, err := x509.MarshalECPrivateKey(pending.DPoPKey) 336 if err != nil { 337 http.Error(w, "Failed to marshal DPoP key", http.StatusInternalServerError) 338 return 339 } 340 dpopKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: dpopKeyBytes}) 341 342 err = h.db.SaveSession( 343 sessionID, 344 tokenResp.Sub, 345 pending.Handle, 346 tokenResp.AccessToken, 347 tokenResp.RefreshToken, 348 string(dpopKeyPEM), 349 expiresAt, 350 ) 351 if err != nil { 352 http.Error(w, "Failed to save session", http.StatusInternalServerError) 353 return 354 } 355 356 http.SetCookie(w, &http.Cookie{ 357 Name: "margin_session", 358 Value: sessionID, 359 Path: "/", 360 HttpOnly: true, 361 Secure: true, 362 SameSite: http.SameSiteNoneMode, 363 MaxAge: 86400 * 7, 364 }) 365 366 go h.cleanupOrphanedReplies(tokenResp.Sub, tokenResp.AccessToken, string(dpopKeyPEM), pending.PDS) 367 368 http.Redirect(w, r, "/?logged_in=true", http.StatusFound) 369} 370 371func (h *Handler) cleanupOrphanedReplies(did, accessToken, dpopKeyPEM, pds string) { 372 orphans, err := h.db.GetOrphanedRepliesByAuthor(did) 373 if err != nil || len(orphans) == 0 { 374 return 375 } 376 377 block, _ := pem.Decode([]byte(dpopKeyPEM)) 378 if block == nil { 379 return 380 } 381 dpopKey, err := x509.ParseECPrivateKey(block.Bytes) 382 if err != nil { 383 return 384 } 385 386 for _, reply := range orphans { 387 388 parts := url.PathEscape(reply.URI) 389 _ = parts 390 uriParts := splitURI(reply.URI) 391 if len(uriParts) < 2 { 392 continue 393 } 394 rkey := uriParts[len(uriParts)-1] 395 396 deleteFromPDS(pds, accessToken, dpopKey, "at.margin.reply", did, rkey) 397 398 h.db.DeleteReply(reply.URI) 399 } 400} 401 402func splitURI(uri string) []string { 403 404 return splitBySlash(uri) 405} 406 407func splitBySlash(s string) []string { 408 var result []string 409 current := "" 410 for _, c := range s { 411 if c == '/' { 412 if current != "" { 413 result = append(result, current) 414 } 415 current = "" 416 } else { 417 current += string(c) 418 } 419 } 420 if current != "" { 421 result = append(result, current) 422 } 423 return result 424} 425 426func deleteFromPDS(pds, accessToken string, dpopKey *ecdsa.PrivateKey, collection, did, rkey string) { 427 428 client := xrpc.NewClient(pds, accessToken, dpopKey) 429 err := client.DeleteRecord(context.Background(), collection, did, rkey) 430 if err != nil { 431 log.Printf("Failed to delete orphaned reply from PDS: %v", err) 432 } else { 433 log.Printf("Cleaned up orphaned reply %s/%s from PDS", collection, rkey) 434 } 435} 436 437func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) { 438 cookie, err := r.Cookie("margin_session") 439 if err == nil { 440 h.db.DeleteSession(cookie.Value) 441 } 442 443 http.SetCookie(w, &http.Cookie{ 444 Name: "margin_session", 445 Value: "", 446 Path: "/", 447 HttpOnly: true, 448 MaxAge: -1, 449 }) 450 451 w.Header().Set("Content-Type", "application/json") 452 json.NewEncoder(w).Encode(map[string]bool{"success": true}) 453} 454 455func (h *Handler) HandleSession(w http.ResponseWriter, r *http.Request) { 456 cookie, err := r.Cookie("margin_session") 457 if err != nil { 458 w.Header().Set("Content-Type", "application/json") 459 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false}) 460 return 461 } 462 463 did, handle, _, _, _, err := h.db.GetSession(cookie.Value) 464 if err != nil { 465 w.Header().Set("Content-Type", "application/json") 466 json.NewEncoder(w).Encode(map[string]interface{}{"authenticated": false}) 467 return 468 } 469 470 w.Header().Set("Content-Type", "application/json") 471 json.NewEncoder(w).Encode(map[string]interface{}{ 472 "authenticated": true, 473 "did": did, 474 "handle": handle, 475 }) 476} 477 478func (h *Handler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) { 479 client := h.getDynamicClient(r) 480 baseURL := client.ClientID[:len(client.ClientID)-len("/client-metadata.json")] 481 482 w.Header().Set("Content-Type", "application/json") 483 json.NewEncoder(w).Encode(map[string]interface{}{ 484 "client_id": client.ClientID, 485 "client_name": "Margin", 486 "client_uri": baseURL, 487 "logo_uri": baseURL + "/logo.svg", 488 "tos_uri": baseURL + "/terms", 489 "policy_uri": baseURL + "/privacy", 490 "redirect_uris": []string{client.RedirectURI}, 491 "grant_types": []string{"authorization_code", "refresh_token"}, 492 "response_types": []string{"code"}, 493 "scope": "atproto transition:generic", 494 "token_endpoint_auth_method": "private_key_jwt", 495 "token_endpoint_auth_signing_alg": "ES256", 496 "dpop_bound_access_tokens": true, 497 "jwks_uri": baseURL + "/jwks.json", 498 "application_type": "web", 499 }) 500} 501 502func (h *Handler) HandleJWKS(w http.ResponseWriter, r *http.Request) { 503 client := h.getDynamicClient(r) 504 w.Header().Set("Content-Type", "application/json") 505 json.NewEncoder(w).Encode(client.GetPublicJWKS()) 506} 507 508func (h *Handler) GetPrivateKey() *ecdsa.PrivateKey { 509 return h.privateKey 510} 511 512func generateSessionID() string { 513 b := make([]byte, 32) 514 rand.Read(b) 515 return fmt.Sprintf("%x", b) 516}