[WIP] music platform user data scraper
teal-fm atproto
at main 9.4 kB view raw
1package db 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 "github.com/bluesky-social/indigo/atproto/syntax" 12 "github.com/teal-fm/piper/models" 13) 14 15func (db *DB) FindOrCreateUserByDID(did string) (*models.User, error) { 16 var user models.User 17 err := db.QueryRow(` 18 SELECT id, atproto_did, created_at, updated_at 19 FROM users 20 WHERE atproto_did = ?`, 21 did).Scan(&user.ID, &user.ATProtoDID, &user.CreatedAt, &user.UpdatedAt) 22 23 if err == sql.ErrNoRows { 24 now := time.Now().UTC() 25 // create user! 26 result, insertErr := db.Exec(` 27 INSERT INTO users (atproto_did, created_at, updated_at) 28 VALUES (?, ?, ?) 29 `, 30 did, 31 now, 32 now) 33 if insertErr != nil { 34 return nil, fmt.Errorf("failed to create user: %w", insertErr) 35 } 36 lastID, idErr := result.LastInsertId() 37 if idErr != nil { 38 return nil, fmt.Errorf("failed to get last insert id: %w", idErr) 39 } 40 user.ID = lastID 41 user.ATProtoDID = &did 42 user.CreatedAt = now 43 user.UpdatedAt = now 44 return &user, nil 45 } else if err != nil { 46 return nil, fmt.Errorf("failed to find user by DID: %w", err) 47 } 48 49 return &user, err 50} 51 52func (db *DB) SetLatestATProtoSessionId(did string, atProtoSessionID string) error { 53 db.logger.Printf("Setting latest atproto session id for did %s to %s", did, atProtoSessionID) 54 now := time.Now().UTC() 55 56 result, err := db.Exec(` 57 UPDATE users 58 SET 59 most_recent_at_session_id = ?, 60 updated_at = ? 61 WHERE atproto_did = ?`, 62 atProtoSessionID, 63 now, 64 did, 65 ) 66 if err != nil { 67 db.logger.Printf("%v", err) 68 return fmt.Errorf("failed to update atproto session for did %s: %w", did, atProtoSessionID) 69 } 70 71 rowsAffected, err := result.RowsAffected() 72 if err != nil { 73 // it's possible the update succeeded here? 74 return fmt.Errorf("failed to check rows affected after updating atproto session for did %s: %w", did, atProtoSessionID) 75 } 76 77 if rowsAffected == 0 { 78 return fmt.Errorf("no user found with did %s to update session, creating new session", did) 79 } 80 81 return nil 82} 83 84type SqliteATProtoStore struct { 85 db *sql.DB 86} 87 88var _ oauth.ClientAuthStore = (*SqliteATProtoStore)(nil) 89 90func NewSqliteATProtoStore(db *sql.DB) *SqliteATProtoStore { 91 return &SqliteATProtoStore{ 92 db: db, 93 } 94} 95 96func sessionKey(did syntax.DID, sessionID string) string { 97 return fmt.Sprintf("%s/%s", did, sessionID) 98} 99 100func splitScopes(s string) []string { 101 if s == "" { 102 return nil 103 } 104 return strings.Fields(s) 105} 106 107func joinScopes(scopes []string) string { 108 if len(scopes) == 0 { 109 return "" 110 } 111 return strings.Join(scopes, " ") 112} 113 114func (s *SqliteATProtoStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 115 lookUpKey := sessionKey(did, sessionID) 116 117 var ( 118 accountDIDStr string 119 lookUpKeyStr string 120 sessionIDStr string 121 hostURL string 122 authServerURL string 123 authServerTokenEndpoint string 124 authServerRevocationEndpoint string 125 scopesStr string 126 accessToken string 127 refreshToken string 128 dpopAuthServerNonce string 129 dpopHostNonce string 130 dpopPrivateKeyMultibase string 131 ) 132 133 err := s.db.QueryRow(` 134 SELECT account_did, 135 look_up_key, 136 session_id, 137 host_url, 138 authserver_url, 139 authserver_token_endpoint, 140 authserver_revocation_endpoint, 141 scopes, 142 access_token, 143 refresh_token, 144 dpop_authserver_nonce, 145 dpop_host_nonce, 146 dpop_privatekey_multibase 147 FROM atproto_sessions 148 WHERE look_up_key = ? 149 `, lookUpKey).Scan( 150 &accountDIDStr, 151 &lookUpKeyStr, 152 &sessionIDStr, 153 &hostURL, 154 &authServerURL, 155 &authServerTokenEndpoint, 156 &authServerRevocationEndpoint, 157 &scopesStr, 158 &accessToken, 159 &refreshToken, 160 &dpopAuthServerNonce, 161 &dpopHostNonce, 162 &dpopPrivateKeyMultibase, 163 ) 164 165 if err == sql.ErrNoRows { 166 return nil, fmt.Errorf("session not found: %s", lookUpKey) 167 } 168 if err != nil { 169 return nil, err 170 } 171 172 accDID, err := syntax.ParseDID(accountDIDStr) 173 if err != nil { 174 return nil, fmt.Errorf("invalid account DID in session: %w", err) 175 } 176 177 sess := oauth.ClientSessionData{ 178 AccountDID: accDID, 179 SessionID: sessionIDStr, 180 HostURL: hostURL, 181 AuthServerURL: authServerURL, 182 AuthServerTokenEndpoint: authServerTokenEndpoint, 183 AuthServerRevocationEndpoint: authServerRevocationEndpoint, 184 Scopes: splitScopes(scopesStr), 185 AccessToken: accessToken, 186 RefreshToken: refreshToken, 187 DPoPAuthServerNonce: dpopAuthServerNonce, 188 DPoPHostNonce: dpopHostNonce, 189 DPoPPrivateKeyMultibase: dpopPrivateKeyMultibase, 190 } 191 192 return &sess, nil 193} 194 195func (s *SqliteATProtoStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 196 lookUpKey := sessionKey(sess.AccountDID, sess.SessionID) 197 // simple upsert: delete then insert 198 _, _ = s.db.Exec(`DELETE FROM atproto_sessions WHERE look_up_key = ?`, lookUpKey) 199 _, err := s.db.Exec(` 200 INSERT INTO atproto_sessions ( 201 look_up_key, 202 account_did, 203 session_id, 204 host_url, 205 authserver_url, 206 authserver_token_endpoint, 207 authserver_revocation_endpoint, 208 scopes, 209 access_token, 210 refresh_token, 211 dpop_authserver_nonce, 212 dpop_host_nonce, 213 dpop_privatekey_multibase 214 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 215 `, 216 lookUpKey, 217 sess.AccountDID.String(), 218 sess.SessionID, 219 sess.HostURL, 220 sess.AuthServerURL, 221 sess.AuthServerTokenEndpoint, 222 sess.AuthServerRevocationEndpoint, 223 joinScopes(sess.Scopes), 224 sess.AccessToken, 225 sess.RefreshToken, 226 sess.DPoPAuthServerNonce, 227 sess.DPoPHostNonce, 228 sess.DPoPPrivateKeyMultibase, 229 ) 230 return err 231} 232 233func (s *SqliteATProtoStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 234 lookUpKey := sessionKey(did, sessionID) 235 _, err := s.db.Exec(`DELETE FROM atproto_sessions WHERE look_up_key = ?`, lookUpKey) 236 return err 237} 238 239func (s *SqliteATProtoStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 240 var ( 241 authServerURL string 242 accountDIDStr sql.NullString 243 scopesStr string 244 requestURI string 245 authServerTokenEndpoint string 246 authServerRevocationEndpoint string 247 pkceVerifier string 248 dpopAuthServerNonce string 249 dpopPrivateKeyMultibase string 250 ) 251 err := s.db.QueryRow(` 252 SELECT authserver_url, 253 account_did, 254 scopes, 255 request_uri, 256 authserver_token_endpoint, 257 authserver_revocation_endpoint, 258 pkce_verifier, 259 dpop_authserver_nonce, 260 dpop_privatekey_multibase 261 FROM atproto_state 262 WHERE state = ? 263 `, state).Scan( 264 &authServerURL, 265 &accountDIDStr, 266 &scopesStr, 267 &requestURI, 268 &authServerTokenEndpoint, 269 &authServerRevocationEndpoint, 270 &pkceVerifier, 271 &dpopAuthServerNonce, 272 &dpopPrivateKeyMultibase, 273 ) 274 if err == sql.ErrNoRows { 275 return nil, fmt.Errorf("request info not found: %s", state) 276 } 277 if err != nil { 278 return nil, err 279 } 280 var accountDIDPtr *syntax.DID 281 if accountDIDStr.Valid && accountDIDStr.String != "" { 282 acc, err := syntax.ParseDID(accountDIDStr.String) 283 if err != nil { 284 return nil, fmt.Errorf("invalid account DID in auth request: %w", err) 285 } 286 accountDIDPtr = &acc 287 } 288 info := oauth.AuthRequestData{ 289 State: state, 290 AuthServerURL: authServerURL, 291 AccountDID: accountDIDPtr, 292 Scopes: splitScopes(scopesStr), 293 RequestURI: requestURI, 294 AuthServerTokenEndpoint: authServerTokenEndpoint, 295 AuthServerRevocationEndpoint: authServerRevocationEndpoint, 296 PKCEVerifier: pkceVerifier, 297 DPoPAuthServerNonce: dpopAuthServerNonce, 298 DPoPPrivateKeyMultibase: dpopPrivateKeyMultibase, 299 } 300 return &info, nil 301} 302 303func (s *SqliteATProtoStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 304 // ensure not already exists 305 var exists int 306 err := s.db.QueryRow(`SELECT 1 FROM atproto_state WHERE state = ?`, info.State).Scan(&exists) 307 if err == nil { 308 return fmt.Errorf("auth request already saved for state %s", info.State) 309 } 310 if err != nil && err != sql.ErrNoRows { 311 return err 312 } 313 var accountDIDStr interface{} 314 if info.AccountDID != nil { 315 accountDIDStr = info.AccountDID.String() 316 } else { 317 accountDIDStr = nil 318 } 319 _, err = s.db.Exec(` 320 INSERT INTO atproto_state ( 321 state, 322 authserver_url, 323 account_did, 324 scopes, 325 request_uri, 326 authserver_token_endpoint, 327 authserver_revocation_endpoint, 328 pkce_verifier, 329 dpop_authserver_nonce, 330 dpop_privatekey_multibase 331 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 332 `, 333 info.State, 334 info.AuthServerURL, 335 accountDIDStr, 336 joinScopes(info.Scopes), 337 info.RequestURI, 338 info.AuthServerTokenEndpoint, 339 info.AuthServerRevocationEndpoint, 340 info.PKCEVerifier, 341 info.DPoPAuthServerNonce, 342 info.DPoPPrivateKeyMultibase, 343 ) 344 return err 345} 346 347func (s *SqliteATProtoStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 348 _, err := s.db.Exec(`DELETE FROM atproto_state WHERE state = ?`, state) 349 return err 350}