package atproto import ( "encoding/json" "fmt" "io" "log" "net/http" "net/url" "os" "strings" "time" "github.com/limeleaf/diffdown/internal/atproto/dpop" "github.com/limeleaf/diffdown/internal/db" "github.com/limeleaf/diffdown/internal/model" ) // RefreshAccessToken refreshes the access token for the given ATProto session. func RefreshAccessToken(database *db.DB, session *model.ATProtoSession) error { kp, err := dpop.UnmarshalPrivate([]byte(session.DPoPKeyJWK)) if err != nil { return fmt.Errorf("unmarshal DPoP key: %w", err) } base := os.Getenv("DIFFDOWN_BASE_URL") if base == "" { base = "http://127.0.0.1:8080" } clientID := strings.TrimRight(base, "/") + "/client-metadata.json" params := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {session.RefreshToken}, "client_id": {clientID}, } proof, err := kp.Proof("POST", session.TokenEndpoint, session.DPoPNonce, "") if err != nil { return fmt.Errorf("build DPoP proof: %w", err) } req, _ := http.NewRequest("POST", session.TokenEndpoint, strings.NewReader(params.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("DPoP", proof) resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("refresh request: %w", err) } defer resp.Body.Close() // Handle nonce retry newNonce := resp.Header.Get("DPoP-Nonce") if resp.StatusCode == http.StatusBadRequest && newNonce != "" { proof, err = kp.Proof("POST", session.TokenEndpoint, newNonce, "") if err != nil { return fmt.Errorf("build DPoP proof (retry): %w", err) } req2, _ := http.NewRequest("POST", session.TokenEndpoint, strings.NewReader(params.Encode())) req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") req2.Header.Set("DPoP", proof) resp.Body.Close() resp, err = http.DefaultClient.Do(req2) if err != nil { return fmt.Errorf("refresh retry: %w", err) } defer resp.Body.Close() newNonce = resp.Header.Get("DPoP-Nonce") } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return fmt.Errorf("refresh failed (HTTP %d): %s", resp.StatusCode, body) } var tokenBody struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` } if err := json.NewDecoder(resp.Body).Decode(&tokenBody); err != nil { return fmt.Errorf("decode refresh response: %w", err) } if newNonce == "" { newNonce = session.DPoPNonce } expiresAt := time.Now().Add(time.Duration(tokenBody.ExpiresIn) * time.Second) if err := database.UpdateATProtoTokens(session.UserID, tokenBody.AccessToken, tokenBody.RefreshToken, newNonce, expiresAt); err != nil { return fmt.Errorf("update tokens in DB: %w", err) } // Update in-memory session session.AccessToken = tokenBody.AccessToken session.RefreshToken = tokenBody.RefreshToken session.DPoPNonce = newNonce session.ExpiresAt = expiresAt log.Printf("ATProto: refreshed token for user %s", session.UserID) return nil }