[WIP] music platform user data scraper
teal-fm
atproto
1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "github.com/teal-fm/piper/models"
14)
15
16func (db *DB) FindOrCreateUserByDID(did string) (*models.User, error) {
17 var user models.User
18 err := db.QueryRow(`
19 SELECT id, atproto_did, created_at, updated_at
20 FROM users
21 WHERE atproto_did = ?`,
22 did).Scan(&user.ID, &user.ATProtoDID, &user.CreatedAt, &user.UpdatedAt)
23
24 if errors.Is(err, sql.ErrNoRows) {
25 now := time.Now().UTC()
26 // create user!
27 result, insertErr := db.Exec(`
28 INSERT INTO users (atproto_did, created_at, updated_at)
29 VALUES (?, ?, ?)
30 `,
31 did,
32 now,
33 now)
34 if insertErr != nil {
35 return nil, fmt.Errorf("failed to create user: %w", insertErr)
36 }
37 lastID, idErr := result.LastInsertId()
38 if idErr != nil {
39 return nil, fmt.Errorf("failed to get last insert id: %w", idErr)
40 }
41 user.ID = lastID
42 user.ATProtoDID = &did
43 user.CreatedAt = now
44 user.UpdatedAt = now
45 return &user, nil
46 } else if err != nil {
47 return nil, fmt.Errorf("failed to find user by DID: %w", err)
48 }
49
50 return &user, err
51}
52
53func (db *DB) SetLatestATProtoSessionId(did string, atProtoSessionID string) error {
54 db.logger.Printf("Setting latest atproto session id for did %s to %s", did, atProtoSessionID)
55 now := time.Now().UTC()
56
57 result, err := db.Exec(`
58 UPDATE users
59 SET
60 most_recent_at_session_id = ?,
61 updated_at = ?
62 WHERE atproto_did = ?`,
63 atProtoSessionID,
64 now,
65 did,
66 )
67 if err != nil {
68 db.logger.Printf("%v", err)
69 return fmt.Errorf("failed to update atproto session for did %s: %s", did, atProtoSessionID)
70 }
71
72 rowsAffected, err := result.RowsAffected()
73 if err != nil {
74 // it's possible the update succeeded here?
75 return fmt.Errorf("failed to check rows affected after updating atproto session for did %s: %s", did, atProtoSessionID)
76 }
77
78 if rowsAffected == 0 {
79 return fmt.Errorf("no user found with did %s to update session, creating new session", did)
80 }
81
82 return nil
83}
84
85type SqliteATProtoStore struct {
86 db *sql.DB
87}
88
89var _ oauth.ClientAuthStore = (*SqliteATProtoStore)(nil)
90
91func NewSqliteATProtoStore(db *sql.DB) *SqliteATProtoStore {
92 return &SqliteATProtoStore{
93 db: db,
94 }
95}
96
97func sessionKey(did syntax.DID, sessionID string) string {
98 return fmt.Sprintf("%s/%s", did, sessionID)
99}
100
// splitScopes parses a stored scope string (the format written by joinScopes)
// into individual scopes. The empty string yields nil; any other input is
// split on runs of whitespace by strings.Fields.
func splitScopes(s string) []string {
	if len(s) > 0 {
		return strings.Fields(s)
	}
	return nil
}
107
// joinScopes serializes scopes into the single space-separated string that is
// stored in the scopes column. A nil or empty slice yields "".
func joinScopes(scopes []string) string {
	// strings.Join already returns "" for an empty slice, so no guard is needed.
	return strings.Join(scopes, " ")
}
114
115func (s *SqliteATProtoStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
116 lookUpKey := sessionKey(did, sessionID)
117
118 var (
119 accountDIDStr string
120 lookUpKeyStr string
121 sessionIDStr string
122 hostURL string
123 authServerURL string
124 authServerTokenEndpoint string
125 authServerRevocationEndpoint string
126 scopesStr string
127 accessToken string
128 refreshToken string
129 dpopAuthServerNonce string
130 dpopHostNonce string
131 dpopPrivateKeyMultibase string
132 )
133
134 err := s.db.QueryRow(`
135 SELECT account_did,
136 look_up_key,
137 session_id,
138 host_url,
139 authserver_url,
140 authserver_token_endpoint,
141 authserver_revocation_endpoint,
142 scopes,
143 access_token,
144 refresh_token,
145 dpop_authserver_nonce,
146 dpop_host_nonce,
147 dpop_privatekey_multibase
148 FROM atproto_sessions
149 WHERE look_up_key = ?
150 `, lookUpKey).Scan(
151 &accountDIDStr,
152 &lookUpKeyStr,
153 &sessionIDStr,
154 &hostURL,
155 &authServerURL,
156 &authServerTokenEndpoint,
157 &authServerRevocationEndpoint,
158 &scopesStr,
159 &accessToken,
160 &refreshToken,
161 &dpopAuthServerNonce,
162 &dpopHostNonce,
163 &dpopPrivateKeyMultibase,
164 )
165
166 if errors.Is(err, sql.ErrNoRows) {
167 return nil, fmt.Errorf("session not found: %s", lookUpKey)
168 }
169 if err != nil {
170 return nil, err
171 }
172
173 accDID, err := syntax.ParseDID(accountDIDStr)
174 if err != nil {
175 return nil, fmt.Errorf("invalid account DID in session: %w", err)
176 }
177
178 sess := oauth.ClientSessionData{
179 AccountDID: accDID,
180 SessionID: sessionIDStr,
181 HostURL: hostURL,
182 AuthServerURL: authServerURL,
183 AuthServerTokenEndpoint: authServerTokenEndpoint,
184 AuthServerRevocationEndpoint: authServerRevocationEndpoint,
185 Scopes: splitScopes(scopesStr),
186 AccessToken: accessToken,
187 RefreshToken: refreshToken,
188 DPoPAuthServerNonce: dpopAuthServerNonce,
189 DPoPHostNonce: dpopHostNonce,
190 DPoPPrivateKeyMultibase: dpopPrivateKeyMultibase,
191 }
192
193 return &sess, nil
194}
195
196func (s *SqliteATProtoStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
197 lookUpKey := sessionKey(sess.AccountDID, sess.SessionID)
198 // simple upsert: delete then insert
199 _, _ = s.db.Exec(`DELETE FROM atproto_sessions WHERE look_up_key = ?`, lookUpKey)
200 _, err := s.db.Exec(`
201 INSERT INTO atproto_sessions (
202 look_up_key,
203 account_did,
204 session_id,
205 host_url,
206 authserver_url,
207 authserver_token_endpoint,
208 authserver_revocation_endpoint,
209 scopes,
210 access_token,
211 refresh_token,
212 dpop_authserver_nonce,
213 dpop_host_nonce,
214 dpop_privatekey_multibase
215 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
216 `,
217 lookUpKey,
218 sess.AccountDID.String(),
219 sess.SessionID,
220 sess.HostURL,
221 sess.AuthServerURL,
222 sess.AuthServerTokenEndpoint,
223 sess.AuthServerRevocationEndpoint,
224 joinScopes(sess.Scopes),
225 sess.AccessToken,
226 sess.RefreshToken,
227 sess.DPoPAuthServerNonce,
228 sess.DPoPHostNonce,
229 sess.DPoPPrivateKeyMultibase,
230 )
231 return err
232}
233
234func (s *SqliteATProtoStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
235 lookUpKey := sessionKey(did, sessionID)
236 _, err := s.db.Exec(`DELETE FROM atproto_sessions WHERE look_up_key = ?`, lookUpKey)
237 return err
238}
239
240func (s *SqliteATProtoStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
241 var (
242 authServerURL string
243 accountDIDStr sql.NullString
244 scopesStr string
245 requestURI string
246 authServerTokenEndpoint string
247 authServerRevocationEndpoint string
248 pkceVerifier string
249 dpopAuthServerNonce string
250 dpopPrivateKeyMultibase string
251 )
252 err := s.db.QueryRow(`
253 SELECT authserver_url,
254 account_did,
255 scopes,
256 request_uri,
257 authserver_token_endpoint,
258 authserver_revocation_endpoint,
259 pkce_verifier,
260 dpop_authserver_nonce,
261 dpop_privatekey_multibase
262 FROM atproto_state
263 WHERE state = ?
264 `, state).Scan(
265 &authServerURL,
266 &accountDIDStr,
267 &scopesStr,
268 &requestURI,
269 &authServerTokenEndpoint,
270 &authServerRevocationEndpoint,
271 &pkceVerifier,
272 &dpopAuthServerNonce,
273 &dpopPrivateKeyMultibase,
274 )
275 if errors.Is(err, sql.ErrNoRows) {
276 return nil, fmt.Errorf("request info not found: %s", state)
277 }
278 if err != nil {
279 return nil, err
280 }
281 var accountDIDPtr *syntax.DID
282 if accountDIDStr.Valid && accountDIDStr.String != "" {
283 acc, err := syntax.ParseDID(accountDIDStr.String)
284 if err != nil {
285 return nil, fmt.Errorf("invalid account DID in auth request: %w", err)
286 }
287 accountDIDPtr = &acc
288 }
289 info := oauth.AuthRequestData{
290 State: state,
291 AuthServerURL: authServerURL,
292 AccountDID: accountDIDPtr,
293 Scopes: splitScopes(scopesStr),
294 RequestURI: requestURI,
295 AuthServerTokenEndpoint: authServerTokenEndpoint,
296 AuthServerRevocationEndpoint: authServerRevocationEndpoint,
297 PKCEVerifier: pkceVerifier,
298 DPoPAuthServerNonce: dpopAuthServerNonce,
299 DPoPPrivateKeyMultibase: dpopPrivateKeyMultibase,
300 }
301 return &info, nil
302}
303
304func (s *SqliteATProtoStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
305 // ensure not already exists
306 var exists int
307 err := s.db.QueryRow(`SELECT 1 FROM atproto_state WHERE state = ?`, info.State).Scan(&exists)
308 if err == nil {
309 return fmt.Errorf("auth request already saved for state %s", info.State)
310 }
311 if !errors.Is(err, sql.ErrNoRows) {
312 return err
313 }
314 var accountDIDStr interface{}
315 if info.AccountDID != nil {
316 accountDIDStr = info.AccountDID.String()
317 } else {
318 accountDIDStr = nil
319 }
320 _, err = s.db.Exec(`
321 INSERT INTO atproto_state (
322 state,
323 authserver_url,
324 account_did,
325 scopes,
326 request_uri,
327 authserver_token_endpoint,
328 authserver_revocation_endpoint,
329 pkce_verifier,
330 dpop_authserver_nonce,
331 dpop_privatekey_multibase
332 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
333 `,
334 info.State,
335 info.AuthServerURL,
336 accountDIDStr,
337 joinScopes(info.Scopes),
338 info.RequestURI,
339 info.AuthServerTokenEndpoint,
340 info.AuthServerRevocationEndpoint,
341 info.PKCEVerifier,
342 info.DPoPAuthServerNonce,
343 info.DPoPPrivateKeyMultibase,
344 )
345 return err
346}
347
348func (s *SqliteATProtoStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
349 _, err := s.db.Exec(`DELETE FROM atproto_state WHERE state = ?`, state)
350 return err
351}