[WIP] music platform user data scraper
teal-fm atproto
32
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 6d516e12d5b78a43a6a940dfb8ca739120fae28b 355 lines 8.6 kB view raw
1package session 2 3import ( 4 "context" 5 "crypto/rand" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "log" 10 "net/http" 11 "sync" 12 "time" 13 14 "github.com/teal-fm/piper/db" 15 "github.com/teal-fm/piper/db/apikey" 16) 17 18// session/session.go 19type Session struct { 20 ID string 21 UserID int64 22 ATprotoDID string 23 ATprotoAccessToken string 24 ATprotoRefreshToken string 25 CreatedAt time.Time 26 ExpiresAt time.Time 27} 28 29type SessionManager struct { 30 db *db.DB 31 sessions map[string]*Session // use in memory cache if necessary 32 apiKeyMgr *apikey.ApiKeyManager 33 mu sync.RWMutex 34} 35 36func NewSessionManager(database *db.DB) *SessionManager { 37 38 _, err := database.Exec(` 39 CREATE TABLE IF NOT EXISTS sessions ( 40 id TEXT PRIMARY KEY, 41 user_id INTEGER NOT NULL, 42 created_at TIMESTAMP, 43 expires_at TIMESTAMP, 44 FOREIGN KEY (user_id) REFERENCES users(id) 45 )`) 46 47 if err != nil { 48 log.Printf("Error creating sessions table: %v", err) 49 } 50 51 apiKeyMgr := apikey.NewApiKeyManager(database) 52 53 return &SessionManager{ 54 db: database, 55 sessions: make(map[string]*Session), 56 apiKeyMgr: apiKeyMgr, 57 } 58} 59 60// create a new session for a user 61func (sm *SessionManager) CreateSession(userID int64) *Session { 62 sm.mu.Lock() 63 defer sm.mu.Unlock() 64 65 // random session id 66 b := make([]byte, 32) 67 rand.Read(b) 68 sessionID := base64.URLEncoding.EncodeToString(b) 69 70 now := time.Now().UTC() 71 expiresAt := now.Add(24 * time.Hour) // 24-hour session 72 73 session := &Session{ 74 ID: sessionID, 75 UserID: userID, 76 CreatedAt: now, 77 ExpiresAt: expiresAt, 78 } 79 80 // store session in memory 81 sm.sessions[sessionID] = session 82 83 // store session in database if available 84 if sm.db != nil { 85 _, err := sm.db.Exec(` 86 INSERT INTO sessions (id, user_id, created_at, expires_at) 87 VALUES (?, ?, ?, ?)`, 88 sessionID, userID, now, expiresAt) 89 90 if err != nil { 91 log.Printf("Error storing session in database: %v", err) 92 } 93 } 94 95 return session 96} 97 98// retrieve a session by ID 99func (sm *SessionManager) GetSession(sessionID string) (*Session, bool) { 100 // First check in-memory cache 101 sm.mu.RLock() 102 session, exists := sm.sessions[sessionID] 103 sm.mu.RUnlock() 104 105 if exists { 106 // Check if session is expired 107 if time.Now().UTC().After(session.ExpiresAt) { 108 sm.DeleteSession(sessionID) 109 return nil, false 110 } 111 return session, true 112 } 113 114 // if not in memory and we have a database, check there 115 if sm.db != nil { 116 session = &Session{ID: sessionID} 117 118 err := sm.db.QueryRow(` 119 SELECT user_id, created_at, expires_at 120 FROM sessions WHERE id = ?`, sessionID).Scan( 121 &session.UserID, &session.CreatedAt, &session.ExpiresAt) 122 123 if err != nil { 124 return nil, false 125 } 126 127 if time.Now().UTC().After(session.ExpiresAt) { 128 sm.DeleteSession(sessionID) 129 return nil, false 130 } 131 132 // add to in-memory cache 133 sm.mu.Lock() 134 sm.sessions[sessionID] = session 135 sm.mu.Unlock() 136 137 return session, true 138 } 139 140 return nil, false 141} 142 143// remove a session 144func (sm *SessionManager) DeleteSession(sessionID string) { 145 sm.mu.Lock() 146 delete(sm.sessions, sessionID) 147 sm.mu.Unlock() 148 149 if sm.db != nil { 150 _, err := sm.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) 151 if err != nil { 152 log.Printf("Error deleting session from database: %v", err) 153 } 154 } 155} 156 157// set a session cookie for the user 158func (sm *SessionManager) SetSessionCookie(w http.ResponseWriter, session *Session) { 159 cookie := &http.Cookie{ 160 Name: "session", 161 Value: session.ID, 162 Path: "/", 163 HttpOnly: true, 164 Secure: false, 165 Expires: session.ExpiresAt, 166 } 167 http.SetCookie(w, cookie) 168} 169 170// ClearSessionCookie clears the session cookie 171func (sm *SessionManager) ClearSessionCookie(w http.ResponseWriter) { 172 cookie := &http.Cookie{ 173 Name: "session", 174 Value: "", 175 Path: "/", 176 HttpOnly: true, 177 Secure: false, 178 MaxAge: -1, 179 } 180 http.SetCookie(w, cookie) 181} 182 183func (sm *SessionManager) HandleLogout(w http.ResponseWriter, r *http.Request) { 184 cookie, err := r.Cookie("session") 185 if err == nil { 186 sm.DeleteSession(cookie.Value) 187 } 188 189 sm.ClearSessionCookie(w) 190 191 http.Redirect(w, r, "/", http.StatusSeeOther) 192} 193 194func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager { 195 return sm.apiKeyMgr 196} 197 198func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) { 199 return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays) 200} 201 202// middleware that checks if a user is authenticated via cookies or API key 203func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 204 return func(w http.ResponseWriter, r *http.Request) { 205 // first: check API keys 206 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 207 if apiKeyErr == nil && apiKeyStr != "" { 208 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 209 if valid { 210 ctx := WithUserID(r.Context(), apiKey.UserID) 211 r = r.WithContext(ctx) 212 213 // set a flag for api requests 214 ctx = WithAPIRequest(r.Context(), true) 215 r = r.WithContext(ctx) 216 217 handler(w, r) 218 return 219 } 220 } 221 222 // if not found, check cookies for session value 223 cookie, err := r.Cookie("session") 224 if err != nil { 225 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 226 return 227 } 228 229 session, exists := sm.GetSession(cookie.Value) 230 if !exists { 231 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 232 return 233 } 234 235 ctx := WithUserID(r.Context(), session.UserID) 236 r = r.WithContext(ctx) 237 238 handler(w, r) 239 } 240} 241 242// middleware that checks if a user is authenticated but doesn't error out if not 243func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 244 return func(w http.ResponseWriter, r *http.Request) { 245 ctx := r.Context() 246 authenticated := false 247 248 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 249 if apiKeyErr == nil && apiKeyStr != "" { 250 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 251 if valid { 252 ctx = WithUserID(ctx, apiKey.UserID) 253 ctx = WithAPIRequest(ctx, true) 254 authenticated = true 255 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 256 handler(w, r) 257 return 258 } 259 } 260 261 if !authenticated { 262 cookie, err := r.Cookie("session") 263 if err == nil { 264 session, exists := sm.GetSession(cookie.Value) 265 if exists { 266 ctx = WithUserID(ctx, session.UserID) 267 authenticated = true 268 } 269 } 270 } 271 272 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 273 handler(w, r) 274 } 275} 276 277// middleware that only accepts API keys 278func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 279 return func(w http.ResponseWriter, r *http.Request) { 280 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 281 if apiKeyErr != nil || apiKeyStr == "" { 282 w.Header().Set("Content-Type", "application/json") 283 w.WriteHeader(http.StatusUnauthorized) 284 w.Write([]byte(`{"error": "API key is required"}`)) 285 return 286 } 287 288 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 289 if !valid { 290 w.Header().Set("Content-Type", "application/json") 291 w.WriteHeader(http.StatusUnauthorized) 292 w.Write([]byte(`{"error": "Invalid or expired API key"}`)) 293 return 294 } 295 296 ctx := WithUserID(r.Context(), apiKey.UserID) 297 ctx = WithAPIRequest(ctx, true) 298 r = r.WithContext(ctx) 299 300 handler(w, r) 301 } 302} 303 304func (sm *SessionManager) HandleDebug(w http.ResponseWriter, r *http.Request) { 305 ctx := r.Context() 306 userID, ok := GetUserID(ctx) 307 if !ok { 308 w.Header().Set("Content-Type", "application/json") 309 w.WriteHeader(http.StatusUnauthorized) 310 w.Write([]byte(`{"error": "User ID not found in context"}`)) 311 return 312 } 313 314 res, err := sm.db.DebugViewUserInformation(userID) 315 if err != nil { 316 w.Header().Set("Content-Type", "application/json") 317 w.WriteHeader(http.StatusInternalServerError) 318 w.Write([]byte(fmt.Sprintf(`{"error": "Failed to retrieve user information: %v"}`, err))) 319 return 320 } 321 322 w.Header().Set("Content-Type", "application/json") 323 w.WriteHeader(http.StatusOK) 324 json.NewEncoder(w).Encode(res) 325} 326 327type contextKey int 328 329const ( 330 userIDKey contextKey = iota 331 apiRequestKey 332 authStatusKey 333) 334 335func WithUserID(ctx context.Context, userID int64) context.Context { 336 return context.WithValue(ctx, userIDKey, userID) 337} 338 339func GetUserID(ctx context.Context) (int64, bool) { 340 userID, ok := ctx.Value(userIDKey).(int64) 341 return userID, ok 342} 343 344func WithAuthStatus(ctx context.Context, isAuthed bool) context.Context { 345 return context.WithValue(ctx, authStatusKey, isAuthed) 346} 347 348func WithAPIRequest(ctx context.Context, isAPI bool) context.Context { 349 return context.WithValue(ctx, apiRequestKey, isAPI) 350} 351 352func IsAPIRequest(ctx context.Context) bool { 353 isAPI, ok := ctx.Value(apiRequestKey).(bool) 354 return ok && isAPI 355}