1package api
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "crypto/x509"
8 "encoding/pem"
9 "fmt"
10 "log"
11 "net/http"
12 "os"
13 "time"
14
15 "margin.at/internal/db"
16 "margin.at/internal/oauth"
17 "margin.at/internal/xrpc"
18)
19
20type TokenRefresher struct {
21 db *db.DB
22 privateKey *ecdsa.PrivateKey
23 baseURL string
24}
25
26func NewTokenRefresher(database *db.DB, privateKey *ecdsa.PrivateKey) *TokenRefresher {
27 return &TokenRefresher{
28 db: database,
29 privateKey: privateKey,
30 baseURL: os.Getenv("BASE_URL"),
31 }
32}
33
34func (tr *TokenRefresher) getOAuthClient(r *http.Request) *oauth.Client {
35 baseURL := tr.baseURL
36 if baseURL == "" {
37 scheme := "http"
38 if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
39 scheme = "https"
40 }
41 baseURL = fmt.Sprintf("%s://%s", scheme, r.Host)
42 }
43
44 if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
45 baseURL = baseURL[:len(baseURL)-1]
46 }
47
48 clientID := baseURL + "/client-metadata.json"
49 redirectURI := baseURL + "/auth/callback"
50
51 return oauth.NewClient(clientID, redirectURI, tr.privateKey)
52}
53
// SessionData is an authenticated user's session as loaded from the session
// store, with its DPoP key already parsed and its PDS endpoint resolved.
type SessionData struct {
	// ID is the session identifier (cookie or X-Session-Token value);
	// DID and Handle identify the user.
	ID, DID, Handle string
	// AccessToken and RefreshToken are the OAuth tokens for this session.
	AccessToken, RefreshToken string
	// DPoPKey is the session's DPoP private key.
	DPoPKey *ecdsa.PrivateKey
	// PDS is the endpoint resolved from DID, used as the XRPC target.
	PDS string
}
63
64func (tr *TokenRefresher) GetSessionWithAutoRefresh(r *http.Request) (*SessionData, error) {
65 sessionID := ""
66
67 cookie, err := r.Cookie("margin_session")
68 if err == nil {
69 sessionID = cookie.Value
70 } else {
71 sessionID = r.Header.Get("X-Session-Token")
72 }
73
74 if sessionID == "" {
75 return nil, fmt.Errorf("not authenticated")
76 }
77
78 did, handle, accessToken, refreshToken, dpopKeyStr, err := tr.db.GetSession(sessionID)
79 if err != nil {
80 return nil, fmt.Errorf("session expired")
81 }
82
83 block, _ := pem.Decode([]byte(dpopKeyStr))
84 if block == nil {
85 return nil, fmt.Errorf("invalid session DPoP key")
86 }
87 dpopKey, err := x509.ParseECPrivateKey(block.Bytes)
88 if err != nil {
89 return nil, fmt.Errorf("invalid session DPoP key")
90 }
91
92 pds, err := resolveDIDToPDS(did)
93 if err != nil {
94 return nil, fmt.Errorf("failed to resolve PDS")
95 }
96
97 return &SessionData{
98 ID: sessionID,
99 DID: did,
100 Handle: handle,
101 AccessToken: accessToken,
102 RefreshToken: refreshToken,
103 DPoPKey: dpopKey,
104 PDS: pds,
105 }, nil
106}
107
108func (tr *TokenRefresher) RefreshSessionToken(r *http.Request, session *SessionData) (*SessionData, error) {
109 if session.ID == "" {
110 return nil, fmt.Errorf("invalid session ID")
111 }
112
113 oauthClient := tr.getOAuthClient(r)
114 ctx := context.Background()
115
116 meta, err := oauthClient.GetAuthServerMetadata(ctx, session.PDS)
117 if err != nil {
118 return nil, fmt.Errorf("failed to get auth server metadata: %w", err)
119 }
120
121 tokenResp, _, err := oauthClient.RefreshToken(meta, session.RefreshToken, session.DPoPKey, "")
122 if err != nil {
123 return nil, fmt.Errorf("failed to refresh token: %w", err)
124 }
125
126 dpopKeyBytes, err := x509.MarshalECPrivateKey(session.DPoPKey)
127 if err != nil {
128 return nil, fmt.Errorf("failed to marshal DPoP key: %w", err)
129 }
130 dpopKeyPEM := pem.EncodeToMemory(&pem.Block{
131 Type: "EC PRIVATE KEY",
132 Bytes: dpopKeyBytes,
133 })
134
135 newRefreshToken := tokenResp.RefreshToken
136 if newRefreshToken == "" {
137 newRefreshToken = session.RefreshToken
138 }
139
140 expiresAt := time.Now().Add(7 * 24 * time.Hour)
141 if err := tr.db.SaveSession(
142 session.ID,
143 session.DID,
144 session.Handle,
145 tokenResp.AccessToken,
146 newRefreshToken,
147 string(dpopKeyPEM),
148 expiresAt,
149 ); err != nil {
150 return nil, fmt.Errorf("failed to save refreshed session: %w", err)
151 }
152
153 log.Printf("Successfully refreshed token for user %s", session.Handle)
154
155 return &SessionData{
156 ID: session.ID,
157 DID: session.DID,
158 Handle: session.Handle,
159 AccessToken: tokenResp.AccessToken,
160 RefreshToken: newRefreshToken,
161 DPoPKey: session.DPoPKey,
162 PDS: session.PDS,
163 }, nil
164}
165
// IsTokenExpiredError reports whether err looks like an expired-access-token
// failure. Its message must contain both "invalid_token" and "exp", in any
// order. A nil error is never treated as expired.
//
// This is a substring heuristic: "exp" also matches words such as "expected".
func IsTokenExpiredError(err error) bool {
	if err == nil {
		return false
	}
	msg := []byte(err.Error())
	if !bytes.Contains(msg, []byte("invalid_token")) {
		return false
	}
	return bytes.Contains(msg, []byte("exp"))
}
174
175func (tr *TokenRefresher) ExecuteWithAutoRefresh(
176 r *http.Request,
177 session *SessionData,
178 fn func(client *xrpc.Client, did string) error,
179) error {
180 client := xrpc.NewClient(session.PDS, session.AccessToken, session.DPoPKey)
181
182 err := fn(client, session.DID)
183 if err == nil {
184 return nil
185 }
186
187 if !IsTokenExpiredError(err) {
188 return err
189 }
190
191 log.Printf("Token expired for user %s, attempting refresh...", session.Handle)
192
193 newSession, refreshErr := tr.RefreshSessionToken(r, session)
194 if refreshErr != nil {
195 return fmt.Errorf("original error: %w; refresh failed: %v", err, refreshErr)
196 }
197
198 client = xrpc.NewClient(newSession.PDS, newSession.AccessToken, newSession.DPoPKey)
199 return fn(client, newSession.DID)
200}
201
202func (tr *TokenRefresher) CreateClientFromSession(session *SessionData) *xrpc.Client {
203 return xrpc.NewClient(session.PDS, session.AccessToken, session.DPoPKey)
204}