[WIP] music platform user data scraper
teal-fm atproto
at main 8.7 kB view raw
1package session 2 3import ( 4 "context" 5 "crypto/rand" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "log" 10 "net/http" 11 "sync" 12 "time" 13 14 "github.com/teal-fm/piper/db" 15 "github.com/teal-fm/piper/db/apikey" 16) 17 18// session/session.go 19type Session struct { 20 21 //need to re work this. May add onto it for atproto oauth. But need to be careful about that expiresd 22 //Maybe a speerate oauth session store table and it has a created date? yeah do that then can look it up by session id from this table for user actions 23 24 ID string 25 UserID int64 26 ATProtoSessionID string 27 CreatedAt time.Time 28 ExpiresAt time.Time 29} 30 31type SessionManager struct { 32 db *db.DB 33 sessions map[string]*Session // use in memory cache if necessary 34 apiKeyMgr *apikey.ApiKeyManager 35 mu sync.RWMutex 36} 37 38func NewSessionManager(database *db.DB) *SessionManager { 39 40 _, err := database.Exec(` 41 CREATE TABLE IF NOT EXISTS sessions ( 42 id TEXT PRIMARY KEY, 43 user_id INTEGER NOT NULL, 44 at_proto_session_id TEXT NOT NULL, 45 created_at TIMESTAMP, 46 expires_at TIMESTAMP, 47 FOREIGN KEY (user_id) REFERENCES users(id) 48 )`) 49 50 if err != nil { 51 log.Printf("Error creating sessions table: %v", err) 52 } 53 54 apiKeyMgr := apikey.NewApiKeyManager(database) 55 56 return &SessionManager{ 57 db: database, 58 sessions: make(map[string]*Session), 59 apiKeyMgr: apiKeyMgr, 60 } 61} 62 63// create a new session for a user 64func (sm *SessionManager) CreateSession(userID int64, atProtoSessionId string) *Session { 65 sm.mu.Lock() 66 defer sm.mu.Unlock() 67 68 // random session id 69 b := make([]byte, 32) 70 rand.Read(b) 71 sessionID := base64.URLEncoding.EncodeToString(b) 72 73 now := time.Now().UTC() 74 expiresAt := now.Add(24 * time.Hour) // 24-hour session 75 76 session := &Session{ 77 ID: sessionID, 78 UserID: userID, 79 ATProtoSessionID: atProtoSessionId, 80 CreatedAt: now, 81 ExpiresAt: expiresAt, 82 } 83 84 // store session in memory 85 sm.sessions[sessionID] = session 86 87 // store session in database if available 88 if sm.db != nil { 89 _, err := sm.db.Exec(` 90 INSERT INTO sessions (id, user_id, at_proto_session_id, created_at, expires_at) 91 VALUES (?, ?, ?, ?, ?)`, 92 sessionID, userID, atProtoSessionId, now, expiresAt) 93 94 if err != nil { 95 log.Printf("Error storing session in database: %v", err) 96 } 97 } 98 99 return session 100} 101 102// retrieve a session by ID 103func (sm *SessionManager) GetSession(sessionID string) (*Session, bool) { 104 // First check in-memory cache 105 sm.mu.RLock() 106 session, exists := sm.sessions[sessionID] 107 sm.mu.RUnlock() 108 109 if exists { 110 // Check if session is expired 111 if time.Now().UTC().After(session.ExpiresAt) { 112 sm.DeleteSession(sessionID) 113 return nil, false 114 } 115 return session, true 116 } 117 118 // if not in memory and we have a database, check there 119 if sm.db != nil { 120 session = &Session{ID: sessionID} 121 122 err := sm.db.QueryRow(` 123 SELECT user_id, at_proto_session_id, created_at, expires_at 124 FROM sessions WHERE id = ?`, sessionID).Scan( 125 &session.UserID, &session.ATProtoSessionID, &session.CreatedAt, &session.ExpiresAt) 126 127 if err != nil { 128 return nil, false 129 } 130 131 if time.Now().UTC().After(session.ExpiresAt) { 132 sm.DeleteSession(sessionID) 133 return nil, false 134 } 135 136 // add to in-memory cache 137 sm.mu.Lock() 138 sm.sessions[sessionID] = session 139 sm.mu.Unlock() 140 141 return session, true 142 } 143 144 return nil, false 145} 146 147// remove a session 148func (sm *SessionManager) DeleteSession(sessionID string) { 149 sm.mu.Lock() 150 delete(sm.sessions, sessionID) 151 sm.mu.Unlock() 152 153 if sm.db != nil { 154 _, err := sm.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) 155 if err != nil { 156 log.Printf("Error deleting session from database: %v", err) 157 } 158 } 159} 160 161// set a session cookie for the user 162func (sm *SessionManager) SetSessionCookie(w http.ResponseWriter, session *Session) { 163 cookie := &http.Cookie{ 164 Name: "session", 165 Value: session.ID, 166 Path: "/", 167 HttpOnly: true, 168 Secure: false, 169 Expires: session.ExpiresAt, 170 } 171 http.SetCookie(w, cookie) 172} 173 174// ClearSessionCookie clears the session cookie 175func (sm *SessionManager) ClearSessionCookie(w http.ResponseWriter) { 176 cookie := &http.Cookie{ 177 Name: "session", 178 Value: "", 179 Path: "/", 180 HttpOnly: true, 181 Secure: false, 182 MaxAge: -1, 183 } 184 http.SetCookie(w, cookie) 185} 186 187func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager { 188 return sm.apiKeyMgr 189} 190 191func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) { 192 return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays) 193} 194 195// middleware that checks if a user is authenticated via cookies or API key 196func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 197 return func(w http.ResponseWriter, r *http.Request) { 198 // first: check API keys 199 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 200 if apiKeyErr == nil && apiKeyStr != "" { 201 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 202 if valid { 203 ctx := WithUserID(r.Context(), apiKey.UserID) 204 r = r.WithContext(ctx) 205 206 // set a flag for api requests 207 ctx = WithAPIRequest(r.Context(), true) 208 r = r.WithContext(ctx) 209 210 handler(w, r) 211 return 212 } 213 } 214 215 // if not found, check cookies for session value 216 cookie, err := r.Cookie("session") 217 if err != nil { 218 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 219 return 220 } 221 222 session, exists := sm.GetSession(cookie.Value) 223 if !exists { 224 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 225 return 226 } 227 228 ctx := WithUserID(r.Context(), session.UserID) 229 r = r.WithContext(ctx) 230 231 handler(w, r) 232 } 233} 234 235// middleware that checks if a user is authenticated but doesn't error out if not 236func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 237 return func(w http.ResponseWriter, r *http.Request) { 238 ctx := r.Context() 239 authenticated := false 240 241 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 242 if apiKeyErr == nil && apiKeyStr != "" { 243 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 244 if valid { 245 ctx = WithUserID(ctx, apiKey.UserID) 246 ctx = WithAPIRequest(ctx, true) 247 authenticated = true 248 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 249 handler(w, r) 250 return 251 } 252 } 253 254 if !authenticated { 255 cookie, err := r.Cookie("session") 256 if err == nil { 257 session, exists := sm.GetSession(cookie.Value) 258 if exists { 259 ctx = WithUserID(ctx, session.UserID) 260 authenticated = true 261 } 262 } 263 } 264 265 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 266 handler(w, r) 267 } 268} 269 270// middleware that only accepts API keys 271func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 272 return func(w http.ResponseWriter, r *http.Request) { 273 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 274 if apiKeyErr != nil || apiKeyStr == "" { 275 w.Header().Set("Content-Type", "application/json") 276 w.WriteHeader(http.StatusUnauthorized) 277 w.Write([]byte(`{"error": "API key is required"}`)) 278 return 279 } 280 281 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 282 if !valid { 283 w.Header().Set("Content-Type", "application/json") 284 w.WriteHeader(http.StatusUnauthorized) 285 w.Write([]byte(`{"error": "Invalid or expired API key"}`)) 286 return 287 } 288 289 ctx := WithUserID(r.Context(), apiKey.UserID) 290 ctx = WithAPIRequest(ctx, true) 291 r = r.WithContext(ctx) 292 293 handler(w, r) 294 } 295} 296 297func (sm *SessionManager) HandleDebug(w http.ResponseWriter, r *http.Request) { 298 ctx := r.Context() 299 userID, ok := GetUserID(ctx) 300 if !ok { 301 w.Header().Set("Content-Type", "application/json") 302 w.WriteHeader(http.StatusUnauthorized) 303 w.Write([]byte(`{"error": "User ID not found in context"}`)) 304 return 305 } 306 307 res, err := sm.db.DebugViewUserInformation(userID) 308 if err != nil { 309 w.Header().Set("Content-Type", "application/json") 310 w.WriteHeader(http.StatusInternalServerError) 311 w.Write([]byte(fmt.Sprintf(`{"error": "Failed to retrieve user information: %v"}`, err))) 312 return 313 } 314 315 w.Header().Set("Content-Type", "application/json") 316 w.WriteHeader(http.StatusOK) 317 json.NewEncoder(w).Encode(res) 318} 319 320type contextKey int 321 322const ( 323 userIDKey contextKey = iota 324 apiRequestKey 325 authStatusKey 326) 327 328func WithUserID(ctx context.Context, userID int64) context.Context { 329 return context.WithValue(ctx, userIDKey, userID) 330} 331 332func GetUserID(ctx context.Context) (int64, bool) { 333 userID, ok := ctx.Value(userIDKey).(int64) 334 return userID, ok 335} 336 337func WithAuthStatus(ctx context.Context, isAuthed bool) context.Context { 338 return context.WithValue(ctx, authStatusKey, isAuthed) 339} 340 341func WithAPIRequest(ctx context.Context, isAPI bool) context.Context { 342 return context.WithValue(ctx, apiRequestKey, isAPI) 343} 344 345func IsAPIRequest(ctx context.Context) bool { 346 isAPI, ok := ctx.Value(apiRequestKey).(bool) 347 return ok && isAPI 348}