// Package oauth implements the HTTP handlers for the atproto OAuth login flow:
// starting an authorization request, handling the callback, managing the
// session cookie, and publishing the client metadata and JWKS documents.
package oauth

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/subtle"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"margin.at/internal/db"
	"margin.at/internal/xrpc"
)

const (
	// marginSessionCookie is the name of the cookie carrying the session ID.
	marginSessionCookie = "margin_session"

	// atprotoScope is the OAuth scope requested in the PAR request and
	// advertised in the client metadata document.
	atprotoScope = "atproto transition:generic"

	// clientMetadataSuffix is appended to the base URL to form the client ID.
	clientMetadataSuffix = "/client-metadata.json"

	// fallbackKeyPath is used when OAUTH_KEY_PATH is not set.
	fallbackKeyPath = "./oauth_private_key.pem"

	// pendingAuthTTL is how long a started login may wait for its callback.
	pendingAuthTTL = 10 * time.Minute

	// sessionLifetime is the validity of both the stored session and its cookie.
	sessionLifetime = 7 * 24 * time.Hour

	// pdsDeleteTimeout bounds each background record deletion on the PDS.
	pdsDeleteTimeout = 30 * time.Second
)

// Handler serves the OAuth endpoints. Pending authorization requests are kept
// in memory, keyed by their state value, and guarded by pendingMu.
type Handler struct {
	db                *db.DB
	configuredBaseURL string
	privateKey        *ecdsa.PrivateKey
	pending           map[string]*PendingAuth
	pendingMu         sync.RWMutex
}

// NewHandler builds a Handler backed by database. The public base URL is read
// from the BASE_URL environment variable (it may be empty, in which case it is
// derived per request) and the client signing key is loaded from disk or
// generated.
func NewHandler(database *db.DB) (*Handler, error) {
	configuredBaseURL := os.Getenv("BASE_URL")

	privateKey, err := loadOrGenerateKey()
	if err != nil {
		return nil, fmt.Errorf("failed to load/generate key: %w", err)
	}

	return &Handler{
		db:                database,
		configuredBaseURL: configuredBaseURL,
		privateKey:        privateKey,
		pending:           make(map[string]*PendingAuth),
	}, nil
}

// loadOrGenerateKey returns the EC private key stored in the PEM file named by
// OAUTH_KEY_PATH (default ./oauth_private_key.pem). If the file is missing,
// unreadable, or does not contain a parseable EC key, a fresh P-256 key is
// generated and written to that path with mode 0600. Failure to persist the
// new key is logged but is not an error.
func loadOrGenerateKey() (*ecdsa.PrivateKey, error) {
	keyPath := os.Getenv("OAUTH_KEY_PATH")
	if keyPath == "" {
		keyPath = fallbackKeyPath
	}

	data, readErr := os.ReadFile(keyPath)
	switch {
	case readErr == nil:
		block, _ := pem.Decode(data)
		if block == nil {
			log.Printf("Warning: no PEM block found in %s; generating a new key", keyPath)
			break
		}
		key, parseErr := x509.ParseECPrivateKey(block.Bytes)
		if parseErr == nil {
			return key, nil
		}
		log.Printf("Warning: could not parse key in %s: %v; generating a new key", keyPath, parseErr)
	case !errors.Is(readErr, os.ErrNotExist):
		// A missing file is the expected first-run case; anything else is
		// worth surfacing because the file is about to be overwritten.
		log.Printf("Warning: could not read key from %s: %v; generating a new key", keyPath, readErr)
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}

	keyBytes, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return nil, err
	}

	block := &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: keyBytes,
	}
	if err := os.WriteFile(keyPath, pem.EncodeToMemory(block), 0600); err != nil {
		log.Printf("Warning: could not save key to %s: %v\n", keyPath, err)
	}

	return key, nil
}

// getDynamicClient builds an OAuth client whose client ID and redirect URI are
// rooted at the configured base URL or, when none is configured, at the scheme
// and host of the incoming request. The scheme is https when the request
// arrived over TLS or carries "X-Forwarded-Proto: https".
func (h *Handler) getDynamicClient(r *http.Request) *Client {
	baseURL := h.configuredBaseURL
	if baseURL == "" {
		scheme := "http"
		if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
			scheme = "https"
		}
		baseURL = fmt.Sprintf("%s://%s", scheme, r.Host)
	}

	// Drop a single trailing slash so the joined paths do not contain "//".
	baseURL = strings.TrimSuffix(baseURL, "/")

	clientID := baseURL + clientMetadataSuffix
	redirectURI := baseURL + "/auth/callback"

	return NewClient(clientID, redirectURI, h.privateKey)
}

// respondJSON writes v as a JSON body with the given status code. Encoding
// failures are logged; at that point the header has already been sent.
func respondJSON(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(v); err != nil {
		log.Printf("Failed to write JSON response: %v", err)
	}
}

// pushedAuthorizationURL returns endpoint with the client_id and request_uri
// query parameters set, or an error if endpoint is not a parseable URL.
func pushedAuthorizationURL(endpoint, clientID, requestURI string) (string, error) {
	u, err := url.Parse(endpoint)
	if err != nil {
		return "", fmt.Errorf("parsing authorization endpoint %q: %w", endpoint, err)
	}

	q := u.Query()
	q.Set("client_id", clientID)
	q.Set("request_uri", requestURI)
	u.RawQuery = q.Encode()

	return u.String(), nil
}

// storePending records p under its state value. Entries older than
// pendingAuthTTL are evicted first so abandoned logins do not accumulate.
func (h *Handler) storePending(p *PendingAuth) {
	now := time.Now()

	h.pendingMu.Lock()
	defer h.pendingMu.Unlock()

	for state, old := range h.pending {
		if old == nil || now.Sub(old.CreatedAt) > pendingAuthTTL {
			delete(h.pending, state)
		}
	}
	h.pending[p.State] = p
}

// takePending removes and returns the pending authorization stored under
// state. The entry is consumed, so a state value can be redeemed only once.
func (h *Handler) takePending(state string) (*PendingAuth, bool) {
	h.pendingMu.Lock()
	defer h.pendingMu.Unlock()

	p, ok := h.pending[state]
	if ok {
		delete(h.pending, state)
	}
	return p, ok && p != nil
}

// HandleLogin starts the OAuth flow for the account named by the "handle"
// query parameter and redirects the browser to the authorization server.
// Requests without a handle are redirected to /login.
func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
	client := h.getDynamicClient(r)

	handle := r.URL.Query().Get("handle")
	if handle == "" {
		http.Redirect(w, r, "/login", http.StatusFound)
		return
	}

	ctx := r.Context()

	did, err := client.ResolveHandle(ctx, handle)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to resolve handle: %v", err), http.StatusBadRequest)
		return
	}

	pds, err := client.ResolveDIDToPDS(ctx, did)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to resolve PDS: %v", err), http.StatusBadRequest)
		return
	}

	meta, err := client.GetAuthServerMetadata(ctx, pds)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get auth server metadata: %v", err), http.StatusBadRequest)
		return
	}

	dpopKey, err := client.GenerateDPoPKey()
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to generate DPoP key: %v", err), http.StatusInternalServerError)
		return
	}

	pkceVerifier, pkceChallenge := client.GeneratePKCE()

	parResp, state, dpopNonce, err := client.SendPAR(meta, handle, atprotoScope, dpopKey, pkceChallenge)
	if err != nil {
		http.Error(w, fmt.Sprintf("PAR request failed: %v", err), http.StatusInternalServerError)
		return
	}

	// Build the redirect target before storing state so a malformed endpoint
	// does not leave a pending entry behind.
	authURL, err := pushedAuthorizationURL(meta.AuthorizationEndpoint, client.ClientID, parResp.RequestURI)
	if err != nil {
		http.Error(w, fmt.Sprintf("Invalid authorization endpoint: %v", err), http.StatusInternalServerError)
		return
	}

	h.storePending(&PendingAuth{
		State:        state,
		DID:          did,
		Handle:       handle, // needed by HandleCallback when saving the session
		PDS:          pds,
		AuthServer:   meta.TokenEndpoint,
		Issuer:       meta.Issuer,
		PKCEVerifier: pkceVerifier,
		DPoPKey:      dpopKey,
		DPoPNonce:    dpopNonce,
		CreatedAt:    time.Now(),
	})

	http.Redirect(w, r, authURL, http.StatusFound)
}

// HandleStart is the JSON variant of HandleLogin. It accepts a POST body of
// the form {"handle": "...", "invite_code": "..."} and responds with
// {"authorizationUrl": "..."}. When the INVITE_CODE environment variable is
// set, the request must carry a matching invite code.
func (h *Handler) HandleStart(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		Handle     string `json:"handle"`
		InviteCode string `json:"invite_code"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.Handle == "" {
		http.Error(w, "Handle is required", http.StatusBadRequest)
		return
	}

	requiredCode := os.Getenv("INVITE_CODE")
	// Constant-time comparison avoids leaking the code through response timing.
	if requiredCode != "" && subtle.ConstantTimeCompare([]byte(req.InviteCode), []byte(requiredCode)) != 1 {
		respondJSON(w, http.StatusForbidden, map[string]string{
			"error": "Invite code required",
			"code":  "invite_required",
		})
		return
	}

	client := h.getDynamicClient(r)
	ctx := r.Context()

	did, err := client.ResolveHandle(ctx, req.Handle)
	if err != nil {
		respondJSON(w, http.StatusBadRequest, map[string]string{"error": "Could not find that Bluesky account"})
		return
	}

	pds, err := client.ResolveDIDToPDS(ctx, did)
	if err != nil {
		respondJSON(w, http.StatusBadRequest, map[string]string{"error": "Failed to resolve PDS"})
		return
	}

	meta, err := client.GetAuthServerMetadata(ctx, pds)
	if err != nil {
		respondJSON(w, http.StatusInternalServerError, map[string]string{"error": "Failed to get auth server"})
		return
	}

	dpopKey, err := client.GenerateDPoPKey()
	if err != nil {
		respondJSON(w, http.StatusInternalServerError, map[string]string{"error": "Internal error"})
		return
	}

	pkceVerifier, pkceChallenge := client.GeneratePKCE()

	parResp, state, dpopNonce, err := client.SendPAR(meta, req.Handle, atprotoScope, dpopKey, pkceChallenge)
	if err != nil {
		log.Printf("PAR request failed: %v", err)
		respondJSON(w, http.StatusInternalServerError, map[string]string{"error": "Failed to initiate authentication"})
		return
	}

	authURL, err := pushedAuthorizationURL(meta.AuthorizationEndpoint, client.ClientID, parResp.RequestURI)
	if err != nil {
		log.Printf("Invalid authorization endpoint: %v", err)
		respondJSON(w, http.StatusInternalServerError, map[string]string{"error": "Failed to initiate authentication"})
		return
	}

	h.storePending(&PendingAuth{
		State:        state,
		DID:          did,
		Handle:       req.Handle,
		PDS:          pds,
		AuthServer:   meta.TokenEndpoint,
		Issuer:       meta.Issuer,
		PKCEVerifier: pkceVerifier,
		DPoPKey:      dpopKey,
		DPoPNonce:    dpopNonce,
		CreatedAt:    time.Now(),
	})

	respondJSON(w, http.StatusOK, map[string]string{
		"authorizationUrl": authURL,
	})
}

// HandleCallback completes the OAuth flow. It redeems the pending request
// identified by the "state" query parameter, exchanges the authorization code
// for tokens, stores a session, sets the session cookie and redirects to the
// site root. Orphaned replies of the account are cleaned up in the background.
func (h *Handler) HandleCallback(w http.ResponseWriter, r *http.Request) {
	client := h.getDynamicClient(r)

	query := r.URL.Query()
	state := query.Get("state")
	code := query.Get("code")
	iss := query.Get("iss")

	if state == "" || code == "" {
		http.Error(w, "Missing state or code parameter", http.StatusBadRequest)
		return
	}

	pending, ok := h.takePending(state)
	if !ok {
		http.Error(w, "Invalid or expired state", http.StatusBadRequest)
		return
	}

	if time.Since(pending.CreatedAt) > pendingAuthTTL {
		http.Error(w, "Authentication request expired", http.StatusBadRequest)
		return
	}

	if iss != "" && iss != pending.Issuer {
		http.Error(w, "Issuer mismatch", http.StatusBadRequest)
		return
	}

	ctx := r.Context()
	meta, err := client.GetAuthServerMetadata(ctx, pending.PDS)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get auth metadata: %v", err), http.StatusInternalServerError)
		return
	}

	// The refreshed DPoP nonce returned by the exchange is not needed here.
	tokenResp, _, err := client.ExchangeCode(meta, code, pending.PKCEVerifier, pending.DPoPKey, pending.DPoPNonce)
	if err != nil {
		http.Error(w, fmt.Sprintf("Token exchange failed: %v", err), http.StatusInternalServerError)
		return
	}

	// The tokens must belong to the account this flow was started for;
	// otherwise a session would be created for a different identity.
	if tokenResp.Sub != pending.DID {
		http.Error(w, "Account mismatch", http.StatusBadRequest)
		return
	}

	sessionID, err := newSessionID()
	if err != nil {
		http.Error(w, "Failed to create session", http.StatusInternalServerError)
		return
	}
	expiresAt := time.Now().Add(sessionLifetime)

	dpopKeyBytes, err := x509.MarshalECPrivateKey(pending.DPoPKey)
	if err != nil {
		http.Error(w, "Failed to marshal DPoP key", http.StatusInternalServerError)
		return
	}
	dpopKeyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: dpopKeyBytes}))

	err = h.db.SaveSession(
		sessionID,
		tokenResp.Sub,
		pending.Handle,
		tokenResp.AccessToken,
		tokenResp.RefreshToken,
		dpopKeyPEM,
		expiresAt,
	)
	if err != nil {
		http.Error(w, "Failed to save session", http.StatusInternalServerError)
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     marginSessionCookie,
		Value:    sessionID,
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteNoneMode,
		MaxAge:   int(sessionLifetime / time.Second),
	})

	// Best-effort housekeeping; it must not delay the redirect.
	go h.cleanupOrphanedReplies(tokenResp.Sub, tokenResp.AccessToken, dpopKeyPEM, pending.PDS)

	http.Redirect(w, r, "/?logged_in=true", http.StatusFound)
}

// cleanupOrphanedReplies deletes every orphaned reply authored by did, first
// from the PDS and then from the local database. It is best-effort: failures
// are logged and never reported to a caller.
func (h *Handler) cleanupOrphanedReplies(did, accessToken, dpopKeyPEM, pds string) {
	orphans, err := h.db.GetOrphanedRepliesByAuthor(did)
	if err != nil {
		log.Printf("Failed to list orphaned replies for %s: %v", did, err)
		return
	}
	if len(orphans) == 0 {
		return
	}

	block, _ := pem.Decode([]byte(dpopKeyPEM))
	if block == nil {
		return
	}
	dpopKey, err := x509.ParseECPrivateKey(block.Bytes)
	if err != nil {
		return
	}

	for _, reply := range orphans {
		uriParts := splitURI(reply.URI)
		if len(uriParts) < 2 {
			continue
		}
		// The record key is the last path segment of the reply URI.
		rkey := uriParts[len(uriParts)-1]
		deleteFromPDS(pds, accessToken, dpopKey, "at.margin.reply", did, rkey)
		h.db.DeleteReply(reply.URI)
	}
}

// splitURI splits uri into its non-empty slash-separated segments.
func splitURI(uri string) []string {
	return splitBySlash(uri)
}

// splitBySlash splits s on '/' and drops empty segments, so leading, trailing
// and repeated slashes produce no empty strings.
func splitBySlash(s string) []string {
	return strings.FieldsFunc(s, func(r rune) bool { return r == '/' })
}

// deleteFromPDS deletes the record collection/rkey in the repository of did on
// the given PDS and logs the outcome. The call is bounded by pdsDeleteTimeout
// so a stalled PDS cannot pin the calling goroutine indefinitely.
func deleteFromPDS(pds, accessToken string, dpopKey *ecdsa.PrivateKey, collection, did, rkey string) {
	ctx, cancel := context.WithTimeout(context.Background(), pdsDeleteTimeout)
	defer cancel()

	client := xrpc.NewClient(pds, accessToken, dpopKey)
	if err := client.DeleteRecord(ctx, collection, did, rkey); err != nil {
		log.Printf("Failed to delete orphaned reply from PDS: %v", err)
		return
	}
	log.Printf("Cleaned up orphaned reply %s/%s from PDS", collection, rkey)
}

// HandleLogout deletes the session named by the session cookie, if any, and
// instructs the browser to drop the cookie.
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
	if cookie, err := r.Cookie(marginSessionCookie); err == nil {
		h.db.DeleteSession(cookie.Value)
	}

	// Mirror the attributes used when the cookie was set: without
	// SameSite=None and Secure, browsers can refuse the clearing cookie on
	// cross-site responses and the stale cookie would survive.
	http.SetCookie(w, &http.Cookie{
		Name:     marginSessionCookie,
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteNoneMode,
		MaxAge:   -1,
	})

	respondJSON(w, http.StatusOK, map[string]bool{"success": true})
}

// HandleSession reports whether the request carries a valid session. The
// response is {"authenticated": false}, or {"authenticated": true} together
// with the session's did and handle.
func (h *Handler) HandleSession(w http.ResponseWriter, r *http.Request) {
	cookie, err := r.Cookie(marginSessionCookie)
	if err != nil {
		respondJSON(w, http.StatusOK, map[string]interface{}{"authenticated": false})
		return
	}

	did, handle, _, _, _, err := h.db.GetSession(cookie.Value)
	if err != nil {
		respondJSON(w, http.StatusOK, map[string]interface{}{"authenticated": false})
		return
	}

	respondJSON(w, http.StatusOK, map[string]interface{}{
		"authenticated": true,
		"did":           did,
		"handle":        handle,
	})
}

// HandleClientMetadata serves the OAuth client metadata document for the
// client derived from the current request.
func (h *Handler) HandleClientMetadata(w http.ResponseWriter, r *http.Request) {
	client := h.getDynamicClient(r)
	baseURL := strings.TrimSuffix(client.ClientID, clientMetadataSuffix)

	respondJSON(w, http.StatusOK, map[string]interface{}{
		"client_id":                       client.ClientID,
		"client_name":                     "Margin",
		"client_uri":                      baseURL,
		"logo_uri":                        baseURL + "/logo.svg",
		"tos_uri":                         baseURL + "/terms",
		"policy_uri":                      baseURL + "/privacy",
		"redirect_uris":                   []string{client.RedirectURI},
		"grant_types":                     []string{"authorization_code", "refresh_token"},
		"response_types":                  []string{"code"},
		"scope":                           atprotoScope,
		"token_endpoint_auth_method":      "private_key_jwt",
		"token_endpoint_auth_signing_alg": "ES256",
		"dpop_bound_access_tokens":        true,
		"jwks_uri":                        baseURL + "/jwks.json",
		"application_type":                "web",
	})
}

// HandleJWKS serves the client's public key set as JSON.
func (h *Handler) HandleJWKS(w http.ResponseWriter, r *http.Request) {
	client := h.getDynamicClient(r)
	respondJSON(w, http.StatusOK, client.GetPublicJWKS())
}

// GetPrivateKey returns the client signing key held by the handler.
func (h *Handler) GetPrivateKey() *ecdsa.PrivateKey {
	return h.privateKey
}

// newSessionID returns 32 random bytes from crypto/rand, hex-encoded. It fails
// rather than returning an identifier built from an unfilled buffer.
func newSessionID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("reading random bytes: %w", err)
	}
	return hex.EncodeToString(b), nil
}

// generateSessionID returns a random hex session identifier. It panics if the
// system random source fails, because a predictable session ID must never be
// issued. Callers that can report an error should use newSessionID instead.
func generateSessionID() string {
	id, err := newSessionID()
	if err != nil {
		panic(fmt.Sprintf("oauth: generating session ID: %v", err))
	}
	return id
}