dev vouch dev on at. that's about it atvouch.dev
at appview 236 lines 5.7 kB view raw
1package main 2 3import ( 4 "context" 5 "database/sql" 6 "encoding/json" 7 "fmt" 8 "os" 9 "path/filepath" 10 11 "github.com/bluesky-social/indigo/atproto/auth/oauth" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 _ "github.com/mattn/go-sqlite3" 14) 15 16type Store struct { 17 db *sql.DB 18} 19 20var _ oauth.ClientAuthStore = &Store{} 21 22func NewStore() (*Store, error) { 23 configDir, err := os.UserConfigDir() 24 if err != nil { 25 return nil, fmt.Errorf("getting config dir: %w", err) 26 } 27 dir := filepath.Join(configDir, "atvouch") 28 if err := os.MkdirAll(dir, 0700); err != nil { 29 return nil, err 30 } 31 32 db, err := sql.Open("sqlite3", filepath.Join(dir, "state.db")) 33 if err != nil { 34 return nil, err 35 } 36 37 for _, pragma := range []string{ 38 "PRAGMA journal_mode = WAL", 39 "PRAGMA busy_timeout = 5000", 40 "PRAGMA synchronous = NORMAL", 41 "PRAGMA cache_size = -6000", 42 "PRAGMA foreign_keys = true", 43 "PRAGMA temp_store = memory", 44 } { 45 if _, err := db.Exec(pragma); err != nil { 46 db.Close() 47 return nil, fmt.Errorf("setting pragma: %w", err) 48 } 49 } 50 51 if err := migrate(db); err != nil { 52 db.Close() 53 return nil, err 54 } 55 56 return &Store{db: db}, nil 57} 58 59var migrations = []struct { 60 version int 61 sql string 62}{ 63 {1, ` 64 CREATE TABLE sessions ( 65 did TEXT NOT NULL, 66 session_id TEXT NOT NULL, 67 data TEXT NOT NULL, 68 PRIMARY KEY (did, session_id) 69 ) STRICT; 70 71 CREATE TABLE auth_requests ( 72 state TEXT NOT NULL PRIMARY KEY, 73 data TEXT NOT NULL 74 ) STRICT; 75 76 CREATE TABLE active_session ( 77 id INTEGER NOT NULL PRIMARY KEY CHECK (id = 1), 78 did TEXT NOT NULL, 79 session_id TEXT NOT NULL 80 ) STRICT; 81 `}, 82} 83 84func migrate(db *sql.DB) error { 85 _, err := db.Exec(` 86 CREATE TABLE IF NOT EXISTS migration_log ( 87 version INTEGER NOT NULL PRIMARY KEY, 88 applied_at TEXT NOT NULL DEFAULT (datetime('now')) 89 ) STRICT; 90 `) 91 if err != nil { 92 return fmt.Errorf("creating migration_log: 
%w", err) 93 } 94 95 for _, m := range migrations { 96 var exists int 97 err := db.QueryRow("SELECT 1 FROM migration_log WHERE version = ?", m.version).Scan(&exists) 98 if err == nil { 99 continue 100 } 101 if err != sql.ErrNoRows { 102 return fmt.Errorf("checking migration %d: %w", m.version, err) 103 } 104 105 tx, err := db.Begin() 106 if err != nil { 107 return err 108 } 109 if _, err := tx.Exec(m.sql); err != nil { 110 tx.Rollback() 111 return fmt.Errorf("migration %d: %w", m.version, err) 112 } 113 if _, err := tx.Exec("INSERT INTO migration_log (version) VALUES (?)", m.version); err != nil { 114 tx.Rollback() 115 return fmt.Errorf("recording migration %d: %w", m.version, err) 116 } 117 if err := tx.Commit(); err != nil { 118 return fmt.Errorf("committing migration %d: %w", m.version, err) 119 } 120 } 121 122 return nil 123} 124 125func (s *Store) Close() error { 126 return s.db.Close() 127} 128 129func (s *Store) SetActive(did syntax.DID, sessionID string) error { 130 _, err := s.db.Exec( 131 `INSERT INTO active_session (id, did, session_id) VALUES (1, ?, ?) 
132 ON CONFLICT (id) DO UPDATE SET did = excluded.did, session_id = excluded.session_id`, 133 did.String(), sessionID, 134 ) 135 return err 136} 137 138type activeSession struct { 139 DID syntax.DID 140 SessionID string 141} 142 143func (s *Store) GetActive() (*activeSession, error) { 144 var didStr, sessionID string 145 err := s.db.QueryRow("SELECT did, session_id FROM active_session WHERE id = 1").Scan(&didStr, &sessionID) 146 if err == sql.ErrNoRows { 147 return nil, fmt.Errorf("no active session (run 'atvouch login' first)") 148 } 149 if err != nil { 150 return nil, err 151 } 152 did, err := syntax.ParseDID(didStr) 153 if err != nil { 154 return nil, err 155 } 156 return &activeSession{DID: did, SessionID: sessionID}, nil 157} 158 159func (s *Store) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 160 var data string 161 err := s.db.QueryRowContext(ctx, 162 "SELECT data FROM sessions WHERE did = ? AND session_id = ?", 163 did.String(), sessionID, 164 ).Scan(&data) 165 if err == sql.ErrNoRows { 166 return nil, fmt.Errorf("session not found for %s", did) 167 } 168 if err != nil { 169 return nil, err 170 } 171 var sess oauth.ClientSessionData 172 if err := json.Unmarshal([]byte(data), &sess); err != nil { 173 return nil, err 174 } 175 return &sess, nil 176} 177 178func (s *Store) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 179 data, err := json.Marshal(sess) 180 if err != nil { 181 return err 182 } 183 _, err = s.db.ExecContext(ctx, 184 `INSERT INTO sessions (did, session_id, data) VALUES (?, ?, ?) 185 ON CONFLICT (did, session_id) DO UPDATE SET data = excluded.data`, 186 sess.AccountDID.String(), sess.SessionID, string(data), 187 ) 188 return err 189} 190 191func (s *Store) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 192 _, err := s.db.ExecContext(ctx, 193 "DELETE FROM sessions WHERE did = ? 
AND session_id = ?", 194 did.String(), sessionID, 195 ) 196 return err 197} 198 199func (s *Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 200 var data string 201 err := s.db.QueryRowContext(ctx, 202 "SELECT data FROM auth_requests WHERE state = ?", 203 state, 204 ).Scan(&data) 205 if err == sql.ErrNoRows { 206 return nil, fmt.Errorf("request info not found: %s", state) 207 } 208 if err != nil { 209 return nil, err 210 } 211 var req oauth.AuthRequestData 212 if err := json.Unmarshal([]byte(data), &req); err != nil { 213 return nil, err 214 } 215 return &req, nil 216} 217 218func (s *Store) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 219 data, err := json.Marshal(info) 220 if err != nil { 221 return err 222 } 223 _, err = s.db.ExecContext(ctx, 224 "INSERT INTO auth_requests (state, data) VALUES (?, ?)", 225 info.State, string(data), 226 ) 227 return err 228} 229 230func (s *Store) DeleteAuthRequestInfo(ctx context.Context, state string) error { 231 _, err := s.db.ExecContext(ctx, 232 "DELETE FROM auth_requests WHERE state = ?", 233 state, 234 ) 235 return err 236}