1package db
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8 "time"
9
10 "github.com/bluesky-social/indigo/atproto/auth/oauth"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "github.com/teal-fm/piper/models"
13)
14
15func (db *DB) FindOrCreateUserByDID(did string) (*models.User, error) {
16 var user models.User
17 err := db.QueryRow(`
18 SELECT id, atproto_did, created_at, updated_at
19 FROM users
20 WHERE atproto_did = ?`,
21 did).Scan(&user.ID, &user.ATProtoDID, &user.CreatedAt, &user.UpdatedAt)
22
23 if err == sql.ErrNoRows {
24 now := time.Now().UTC()
25 // create user!
26 result, insertErr := db.Exec(`
27 INSERT INTO users (atproto_did, created_at, updated_at)
28 VALUES (?, ?, ?)
29 `,
30 did,
31 now,
32 now)
33 if insertErr != nil {
34 return nil, fmt.Errorf("failed to create user: %w", insertErr)
35 }
36 lastID, idErr := result.LastInsertId()
37 if idErr != nil {
38 return nil, fmt.Errorf("failed to get last insert id: %w", idErr)
39 }
40 user.ID = lastID
41 user.ATProtoDID = &did
42 user.CreatedAt = now
43 user.UpdatedAt = now
44 return &user, nil
45 } else if err != nil {
46 return nil, fmt.Errorf("failed to find user by DID: %w", err)
47 }
48
49 return &user, err
50}
51
52func (db *DB) SetLatestATProtoSessionId(did string, atProtoSessionID string) error {
53 db.logger.Printf("Setting latest atproto session id for did %s to %s", did, atProtoSessionID)
54 now := time.Now().UTC()
55
56 result, err := db.Exec(`
57 UPDATE users
58 SET
59 most_recent_at_session_id = ?,
60 updated_at = ?
61 WHERE atproto_did = ?`,
62 atProtoSessionID,
63 now,
64 did,
65 )
66 if err != nil {
67 db.logger.Printf("%v", err)
68 return fmt.Errorf("failed to update atproto session for did %s: %w", did, atProtoSessionID)
69 }
70
71 rowsAffected, err := result.RowsAffected()
72 if err != nil {
73 // it's possible the update succeeded here?
74 return fmt.Errorf("failed to check rows affected after updating atproto session for did %s: %w", did, atProtoSessionID)
75 }
76
77 if rowsAffected == 0 {
78 return fmt.Errorf("no user found with did %s to update session, creating new session", did)
79 }
80
81 return nil
82}
83
84type SqliteATProtoStore struct {
85 db *sql.DB
86}
87
88var _ oauth.ClientAuthStore = (*SqliteATProtoStore)(nil)
89
90func NewSqliteATProtoStore(db *sql.DB) *SqliteATProtoStore {
91 return &SqliteATProtoStore{
92 db: db,
93 }
94}
95
96func sessionKey(did syntax.DID, sessionID string) string {
97 return fmt.Sprintf("%s/%s", did, sessionID)
98}
99
// splitScopes parses a space-separated scope string as stored in the database
// into individual scopes. An empty string yields a nil slice; otherwise the
// input is split on any run of whitespace.
func splitScopes(s string) []string {
	if len(s) > 0 {
		return strings.Fields(s)
	}
	return nil
}
106
// joinScopes renders a scope list as the single space-separated string stored
// in the database (the storage counterpart of splitScopes). A nil or empty
// slice yields "".
func joinScopes(scopes []string) string {
	// strings.Join already returns "" for zero elements, so no special case is needed.
	return strings.Join(scopes, " ")
}
113
114func (s *SqliteATProtoStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
115 lookUpKey := sessionKey(did, sessionID)
116
117 var (
118 accountDIDStr string
119 lookUpKeyStr string
120 sessionIDStr string
121 hostURL string
122 authServerURL string
123 authServerTokenEndpoint string
124 authServerRevocationEndpoint string
125 scopesStr string
126 accessToken string
127 refreshToken string
128 dpopAuthServerNonce string
129 dpopHostNonce string
130 dpopPrivateKeyMultibase string
131 )
132
133 err := s.db.QueryRow(`
134 SELECT account_did,
135 look_up_key,
136 session_id,
137 host_url,
138 authserver_url,
139 authserver_token_endpoint,
140 authserver_revocation_endpoint,
141 scopes,
142 access_token,
143 refresh_token,
144 dpop_authserver_nonce,
145 dpop_host_nonce,
146 dpop_privatekey_multibase
147 FROM atproto_sessions
148 WHERE look_up_key = ?
149 `, lookUpKey).Scan(
150 &accountDIDStr,
151 &lookUpKeyStr,
152 &sessionIDStr,
153 &hostURL,
154 &authServerURL,
155 &authServerTokenEndpoint,
156 &authServerRevocationEndpoint,
157 &scopesStr,
158 &accessToken,
159 &refreshToken,
160 &dpopAuthServerNonce,
161 &dpopHostNonce,
162 &dpopPrivateKeyMultibase,
163 )
164
165 if err == sql.ErrNoRows {
166 return nil, fmt.Errorf("session not found: %s", lookUpKey)
167 }
168 if err != nil {
169 return nil, err
170 }
171
172 accDID, err := syntax.ParseDID(accountDIDStr)
173 if err != nil {
174 return nil, fmt.Errorf("invalid account DID in session: %w", err)
175 }
176
177 sess := oauth.ClientSessionData{
178 AccountDID: accDID,
179 SessionID: sessionIDStr,
180 HostURL: hostURL,
181 AuthServerURL: authServerURL,
182 AuthServerTokenEndpoint: authServerTokenEndpoint,
183 AuthServerRevocationEndpoint: authServerRevocationEndpoint,
184 Scopes: splitScopes(scopesStr),
185 AccessToken: accessToken,
186 RefreshToken: refreshToken,
187 DPoPAuthServerNonce: dpopAuthServerNonce,
188 DPoPHostNonce: dpopHostNonce,
189 DPoPPrivateKeyMultibase: dpopPrivateKeyMultibase,
190 }
191
192 return &sess, nil
193}
194
195func (s *SqliteATProtoStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
196 lookUpKey := sessionKey(sess.AccountDID, sess.SessionID)
197 // simple upsert: delete then insert
198 _, _ = s.db.Exec(`DELETE FROM atproto_sessions WHERE look_up_key = ?`, lookUpKey)
199 _, err := s.db.Exec(`
200 INSERT INTO atproto_sessions (
201 look_up_key,
202 account_did,
203 session_id,
204 host_url,
205 authserver_url,
206 authserver_token_endpoint,
207 authserver_revocation_endpoint,
208 scopes,
209 access_token,
210 refresh_token,
211 dpop_authserver_nonce,
212 dpop_host_nonce,
213 dpop_privatekey_multibase
214 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
215 `,
216 lookUpKey,
217 sess.AccountDID.String(),
218 sess.SessionID,
219 sess.HostURL,
220 sess.AuthServerURL,
221 sess.AuthServerTokenEndpoint,
222 sess.AuthServerRevocationEndpoint,
223 joinScopes(sess.Scopes),
224 sess.AccessToken,
225 sess.RefreshToken,
226 sess.DPoPAuthServerNonce,
227 sess.DPoPHostNonce,
228 sess.DPoPPrivateKeyMultibase,
229 )
230 return err
231}
232
233func (s *SqliteATProtoStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
234 lookUpKey := sessionKey(did, sessionID)
235 _, err := s.db.Exec(`DELETE FROM atproto_sessions WHERE look_up_key = ?`, lookUpKey)
236 return err
237}
238
239func (s *SqliteATProtoStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
240 var (
241 authServerURL string
242 accountDIDStr sql.NullString
243 scopesStr string
244 requestURI string
245 authServerTokenEndpoint string
246 authServerRevocationEndpoint string
247 pkceVerifier string
248 dpopAuthServerNonce string
249 dpopPrivateKeyMultibase string
250 )
251 err := s.db.QueryRow(`
252 SELECT authserver_url,
253 account_did,
254 scopes,
255 request_uri,
256 authserver_token_endpoint,
257 authserver_revocation_endpoint,
258 pkce_verifier,
259 dpop_authserver_nonce,
260 dpop_privatekey_multibase
261 FROM atproto_state
262 WHERE state = ?
263 `, state).Scan(
264 &authServerURL,
265 &accountDIDStr,
266 &scopesStr,
267 &requestURI,
268 &authServerTokenEndpoint,
269 &authServerRevocationEndpoint,
270 &pkceVerifier,
271 &dpopAuthServerNonce,
272 &dpopPrivateKeyMultibase,
273 )
274 if err == sql.ErrNoRows {
275 return nil, fmt.Errorf("request info not found: %s", state)
276 }
277 if err != nil {
278 return nil, err
279 }
280 var accountDIDPtr *syntax.DID
281 if accountDIDStr.Valid && accountDIDStr.String != "" {
282 acc, err := syntax.ParseDID(accountDIDStr.String)
283 if err != nil {
284 return nil, fmt.Errorf("invalid account DID in auth request: %w", err)
285 }
286 accountDIDPtr = &acc
287 }
288 info := oauth.AuthRequestData{
289 State: state,
290 AuthServerURL: authServerURL,
291 AccountDID: accountDIDPtr,
292 Scopes: splitScopes(scopesStr),
293 RequestURI: requestURI,
294 AuthServerTokenEndpoint: authServerTokenEndpoint,
295 AuthServerRevocationEndpoint: authServerRevocationEndpoint,
296 PKCEVerifier: pkceVerifier,
297 DPoPAuthServerNonce: dpopAuthServerNonce,
298 DPoPPrivateKeyMultibase: dpopPrivateKeyMultibase,
299 }
300 return &info, nil
301}
302
303func (s *SqliteATProtoStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
304 // ensure not already exists
305 var exists int
306 err := s.db.QueryRow(`SELECT 1 FROM atproto_state WHERE state = ?`, info.State).Scan(&exists)
307 if err == nil {
308 return fmt.Errorf("auth request already saved for state %s", info.State)
309 }
310 if err != nil && err != sql.ErrNoRows {
311 return err
312 }
313 var accountDIDStr interface{}
314 if info.AccountDID != nil {
315 accountDIDStr = info.AccountDID.String()
316 } else {
317 accountDIDStr = nil
318 }
319 _, err = s.db.Exec(`
320 INSERT INTO atproto_state (
321 state,
322 authserver_url,
323 account_did,
324 scopes,
325 request_uri,
326 authserver_token_endpoint,
327 authserver_revocation_endpoint,
328 pkce_verifier,
329 dpop_authserver_nonce,
330 dpop_privatekey_multibase
331 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
332 `,
333 info.State,
334 info.AuthServerURL,
335 accountDIDStr,
336 joinScopes(info.Scopes),
337 info.RequestURI,
338 info.AuthServerTokenEndpoint,
339 info.AuthServerRevocationEndpoint,
340 info.PKCEVerifier,
341 info.DPoPAuthServerNonce,
342 info.DPoPPrivateKeyMultibase,
343 )
344 return err
345}
346
347func (s *SqliteATProtoStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
348 _, err := s.db.Exec(`DELETE FROM atproto_state WHERE state = ?`, state)
349 return err
350}