Write on the margins of the internet. Powered by the AT Protocol. margin.at
extension web atproto comments
at main 5.0 kB view raw
1package api 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/x509" 8 "encoding/pem" 9 "fmt" 10 "log" 11 "net/http" 12 "os" 13 "time" 14 15 "margin.at/internal/db" 16 "margin.at/internal/oauth" 17 "margin.at/internal/xrpc" 18) 19 20type TokenRefresher struct { 21 db *db.DB 22 privateKey *ecdsa.PrivateKey 23 baseURL string 24} 25 26func NewTokenRefresher(database *db.DB, privateKey *ecdsa.PrivateKey) *TokenRefresher { 27 return &TokenRefresher{ 28 db: database, 29 privateKey: privateKey, 30 baseURL: os.Getenv("BASE_URL"), 31 } 32} 33 34func (tr *TokenRefresher) getOAuthClient(r *http.Request) *oauth.Client { 35 baseURL := tr.baseURL 36 if baseURL == "" { 37 scheme := "http" 38 if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { 39 scheme = "https" 40 } 41 baseURL = fmt.Sprintf("%s://%s", scheme, r.Host) 42 } 43 44 if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { 45 baseURL = baseURL[:len(baseURL)-1] 46 } 47 48 clientID := baseURL + "/client-metadata.json" 49 redirectURI := baseURL + "/auth/callback" 50 51 return oauth.NewClient(clientID, redirectURI, tr.privateKey) 52} 53 54type SessionData struct { 55 ID string 56 DID string 57 Handle string 58 AccessToken string 59 RefreshToken string 60 DPoPKey *ecdsa.PrivateKey 61 PDS string 62} 63 64func (tr *TokenRefresher) GetSessionWithAutoRefresh(r *http.Request) (*SessionData, error) { 65 sessionID := "" 66 67 cookie, err := r.Cookie("margin_session") 68 if err == nil { 69 sessionID = cookie.Value 70 } else { 71 sessionID = r.Header.Get("X-Session-Token") 72 } 73 74 if sessionID == "" { 75 return nil, fmt.Errorf("not authenticated") 76 } 77 78 did, handle, accessToken, refreshToken, dpopKeyStr, err := tr.db.GetSession(sessionID) 79 if err != nil { 80 return nil, fmt.Errorf("session expired") 81 } 82 83 block, _ := pem.Decode([]byte(dpopKeyStr)) 84 if block == nil { 85 return nil, fmt.Errorf("invalid session DPoP key") 86 } 87 dpopKey, err := x509.ParseECPrivateKey(block.Bytes) 88 if err != nil { 89 return nil, fmt.Errorf("invalid session DPoP key") 90 } 91 92 pds, err := resolveDIDToPDS(did) 93 if err != nil { 94 return nil, fmt.Errorf("failed to resolve PDS") 95 } 96 97 return &SessionData{ 98 ID: sessionID, 99 DID: did, 100 Handle: handle, 101 AccessToken: accessToken, 102 RefreshToken: refreshToken, 103 DPoPKey: dpopKey, 104 PDS: pds, 105 }, nil 106} 107 108func (tr *TokenRefresher) RefreshSessionToken(r *http.Request, session *SessionData) (*SessionData, error) { 109 if session.ID == "" { 110 return nil, fmt.Errorf("invalid session ID") 111 } 112 113 oauthClient := tr.getOAuthClient(r) 114 ctx := context.Background() 115 116 meta, err := oauthClient.GetAuthServerMetadata(ctx, session.PDS) 117 if err != nil { 118 return nil, fmt.Errorf("failed to get auth server metadata: %w", err) 119 } 120 121 tokenResp, _, err := oauthClient.RefreshToken(meta, session.RefreshToken, session.DPoPKey, "") 122 if err != nil { 123 return nil, fmt.Errorf("failed to refresh token: %w", err) 124 } 125 126 dpopKeyBytes, err := x509.MarshalECPrivateKey(session.DPoPKey) 127 if err != nil { 128 return nil, fmt.Errorf("failed to marshal DPoP key: %w", err) 129 } 130 dpopKeyPEM := pem.EncodeToMemory(&pem.Block{ 131 Type: "EC PRIVATE KEY", 132 Bytes: dpopKeyBytes, 133 }) 134 135 newRefreshToken := tokenResp.RefreshToken 136 if newRefreshToken == "" { 137 newRefreshToken = session.RefreshToken 138 } 139 140 expiresAt := time.Now().Add(7 * 24 * time.Hour) 141 if err := tr.db.SaveSession( 142 session.ID, 143 session.DID, 144 session.Handle, 145 tokenResp.AccessToken, 146 newRefreshToken, 147 string(dpopKeyPEM), 148 expiresAt, 149 ); err != nil { 150 return nil, fmt.Errorf("failed to save refreshed session: %w", err) 151 } 152 153 log.Printf("Successfully refreshed token for user %s", session.Handle) 154 155 return &SessionData{ 156 ID: session.ID, 157 DID: session.DID, 158 Handle: session.Handle, 159 AccessToken: tokenResp.AccessToken, 160 RefreshToken: newRefreshToken, 161 DPoPKey: session.DPoPKey, 162 PDS: session.PDS, 163 }, nil 164} 165 166func IsTokenExpiredError(err error) bool { 167 if err == nil { 168 return false 169 } 170 errStr := err.Error() 171 return bytes.Contains([]byte(errStr), []byte("invalid_token")) && 172 bytes.Contains([]byte(errStr), []byte("exp")) 173} 174 175func (tr *TokenRefresher) ExecuteWithAutoRefresh( 176 r *http.Request, 177 session *SessionData, 178 fn func(client *xrpc.Client, did string) error, 179) error { 180 client := xrpc.NewClient(session.PDS, session.AccessToken, session.DPoPKey) 181 182 err := fn(client, session.DID) 183 if err == nil { 184 return nil 185 } 186 187 if !IsTokenExpiredError(err) { 188 return err 189 } 190 191 log.Printf("Token expired for user %s, attempting refresh...", session.Handle) 192 193 newSession, refreshErr := tr.RefreshSessionToken(r, session) 194 if refreshErr != nil { 195 return fmt.Errorf("original error: %w; refresh failed: %v", err, refreshErr) 196 } 197 198 client = xrpc.NewClient(newSession.PDS, newSession.AccessToken, newSession.DPoPKey) 199 return fn(client, newSession.DID) 200} 201 202func (tr *TokenRefresher) CreateClientFromSession(session *SessionData) *xrpc.Client { 203 return xrpc.NewClient(session.PDS, session.AccessToken, session.DPoPKey) 204}