[WIP] music platform user data scraper
teal-fm
atproto
1package session
2
3import (
4 "context"
5 "crypto/rand"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "log"
10 "net/http"
11 "sync"
12 "time"
13
14 "github.com/teal-fm/piper/db"
15 "github.com/teal-fm/piper/db/apikey"
16)
17
// session/session.go

// Session is a login session owned by a single user. Its ID is the value
// carried in the "session" cookie (see SetSessionCookie).
type Session struct {
	// ID is the random, URL-safe base64 identifier generated by CreateSession.
	ID string
	// UserID identifies the user who owns the session.
	UserID int64
	// ATprotoDID is the user's ATproto DID, when one has been attached.
	// NOTE(review): the sessions table does not store this or the two token
	// fields below, so they only survive in the in-memory cache.
	ATprotoDID string
	// ATprotoAccessToken is the ATproto access token tied to this session.
	ATprotoAccessToken string
	// ATprotoRefreshToken is the ATproto refresh token tied to this session.
	ATprotoRefreshToken string
	// CreatedAt is when the session was issued (CreateSession uses UTC).
	CreatedAt time.Time
	// ExpiresAt is the instant after which GetSession treats the session as gone.
	ExpiresAt time.Time
}
28
29type SessionManager struct {
30 db *db.DB
31 sessions map[string]*Session // use in memory cache if necessary
32 apiKeyMgr *apikey.ApiKeyManager
33 mu sync.RWMutex
34}
35
36func NewSessionManager(database *db.DB) *SessionManager {
37
38 _, err := database.Exec(`
39 CREATE TABLE IF NOT EXISTS sessions (
40 id TEXT PRIMARY KEY,
41 user_id INTEGER NOT NULL,
42 created_at TIMESTAMP,
43 expires_at TIMESTAMP,
44 FOREIGN KEY (user_id) REFERENCES users(id)
45 )`)
46
47 if err != nil {
48 log.Printf("Error creating sessions table: %v", err)
49 }
50
51 apiKeyMgr := apikey.NewApiKeyManager(database)
52
53 return &SessionManager{
54 db: database,
55 sessions: make(map[string]*Session),
56 apiKeyMgr: apiKeyMgr,
57 }
58}
59
60// create a new session for a user
61func (sm *SessionManager) CreateSession(userID int64) *Session {
62 sm.mu.Lock()
63 defer sm.mu.Unlock()
64
65 // random session id
66 b := make([]byte, 32)
67 rand.Read(b)
68 sessionID := base64.URLEncoding.EncodeToString(b)
69
70 now := time.Now().UTC()
71 expiresAt := now.Add(24 * time.Hour) // 24-hour session
72
73 session := &Session{
74 ID: sessionID,
75 UserID: userID,
76 CreatedAt: now,
77 ExpiresAt: expiresAt,
78 }
79
80 // store session in memory
81 sm.sessions[sessionID] = session
82
83 // store session in database if available
84 if sm.db != nil {
85 _, err := sm.db.Exec(`
86 INSERT INTO sessions (id, user_id, created_at, expires_at)
87 VALUES (?, ?, ?, ?)`,
88 sessionID, userID, now, expiresAt)
89
90 if err != nil {
91 log.Printf("Error storing session in database: %v", err)
92 }
93 }
94
95 return session
96}
97
98// retrieve a session by ID
99func (sm *SessionManager) GetSession(sessionID string) (*Session, bool) {
100 // First check in-memory cache
101 sm.mu.RLock()
102 session, exists := sm.sessions[sessionID]
103 sm.mu.RUnlock()
104
105 if exists {
106 // Check if session is expired
107 if time.Now().UTC().After(session.ExpiresAt) {
108 sm.DeleteSession(sessionID)
109 return nil, false
110 }
111 return session, true
112 }
113
114 // if not in memory and we have a database, check there
115 if sm.db != nil {
116 session = &Session{ID: sessionID}
117
118 err := sm.db.QueryRow(`
119 SELECT user_id, created_at, expires_at
120 FROM sessions WHERE id = ?`, sessionID).Scan(
121 &session.UserID, &session.CreatedAt, &session.ExpiresAt)
122
123 if err != nil {
124 return nil, false
125 }
126
127 if time.Now().UTC().After(session.ExpiresAt) {
128 sm.DeleteSession(sessionID)
129 return nil, false
130 }
131
132 // add to in-memory cache
133 sm.mu.Lock()
134 sm.sessions[sessionID] = session
135 sm.mu.Unlock()
136
137 return session, true
138 }
139
140 return nil, false
141}
142
143// remove a session
144func (sm *SessionManager) DeleteSession(sessionID string) {
145 sm.mu.Lock()
146 delete(sm.sessions, sessionID)
147 sm.mu.Unlock()
148
149 if sm.db != nil {
150 _, err := sm.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
151 if err != nil {
152 log.Printf("Error deleting session from database: %v", err)
153 }
154 }
155}
156
157// set a session cookie for the user
158func (sm *SessionManager) SetSessionCookie(w http.ResponseWriter, session *Session) {
159 cookie := &http.Cookie{
160 Name: "session",
161 Value: session.ID,
162 Path: "/",
163 HttpOnly: true,
164 Secure: false,
165 Expires: session.ExpiresAt,
166 }
167 http.SetCookie(w, cookie)
168}
169
170// ClearSessionCookie clears the session cookie
171func (sm *SessionManager) ClearSessionCookie(w http.ResponseWriter) {
172 cookie := &http.Cookie{
173 Name: "session",
174 Value: "",
175 Path: "/",
176 HttpOnly: true,
177 Secure: false,
178 MaxAge: -1,
179 }
180 http.SetCookie(w, cookie)
181}
182
183func (sm *SessionManager) HandleLogout(w http.ResponseWriter, r *http.Request) {
184 cookie, err := r.Cookie("session")
185 if err == nil {
186 sm.DeleteSession(cookie.Value)
187 }
188
189 sm.ClearSessionCookie(w)
190
191 http.Redirect(w, r, "/", http.StatusSeeOther)
192}
193
194func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager {
195 return sm.apiKeyMgr
196}
197
198func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) {
199 return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays)
200}
201
202// middleware that checks if a user is authenticated via cookies or API key
203func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
204 return func(w http.ResponseWriter, r *http.Request) {
205 // first: check API keys
206 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
207 if apiKeyErr == nil && apiKeyStr != "" {
208 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
209 if valid {
210 ctx := WithUserID(r.Context(), apiKey.UserID)
211 r = r.WithContext(ctx)
212
213 // set a flag for api requests
214 ctx = WithAPIRequest(r.Context(), true)
215 r = r.WithContext(ctx)
216
217 handler(w, r)
218 return
219 }
220 }
221
222 // if not found, check cookies for session value
223 cookie, err := r.Cookie("session")
224 if err != nil {
225 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther)
226 return
227 }
228
229 session, exists := sm.GetSession(cookie.Value)
230 if !exists {
231 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther)
232 return
233 }
234
235 ctx := WithUserID(r.Context(), session.UserID)
236 r = r.WithContext(ctx)
237
238 handler(w, r)
239 }
240}
241
242// middleware that checks if a user is authenticated but doesn't error out if not
243func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
244 return func(w http.ResponseWriter, r *http.Request) {
245 ctx := r.Context()
246 authenticated := false
247
248 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
249 if apiKeyErr == nil && apiKeyStr != "" {
250 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
251 if valid {
252 ctx = WithUserID(ctx, apiKey.UserID)
253 ctx = WithAPIRequest(ctx, true)
254 authenticated = true
255 r = r.WithContext(WithAuthStatus(ctx, authenticated))
256 handler(w, r)
257 return
258 }
259 }
260
261 if !authenticated {
262 cookie, err := r.Cookie("session")
263 if err == nil {
264 session, exists := sm.GetSession(cookie.Value)
265 if exists {
266 ctx = WithUserID(ctx, session.UserID)
267 authenticated = true
268 }
269 }
270 }
271
272 r = r.WithContext(WithAuthStatus(ctx, authenticated))
273 handler(w, r)
274 }
275}
276
277// middleware that only accepts API keys
278func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
279 return func(w http.ResponseWriter, r *http.Request) {
280 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
281 if apiKeyErr != nil || apiKeyStr == "" {
282 w.Header().Set("Content-Type", "application/json")
283 w.WriteHeader(http.StatusUnauthorized)
284 w.Write([]byte(`{"error": "API key is required"}`))
285 return
286 }
287
288 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
289 if !valid {
290 w.Header().Set("Content-Type", "application/json")
291 w.WriteHeader(http.StatusUnauthorized)
292 w.Write([]byte(`{"error": "Invalid or expired API key"}`))
293 return
294 }
295
296 ctx := WithUserID(r.Context(), apiKey.UserID)
297 ctx = WithAPIRequest(ctx, true)
298 r = r.WithContext(ctx)
299
300 handler(w, r)
301 }
302}
303
304func (sm *SessionManager) HandleDebug(w http.ResponseWriter, r *http.Request) {
305 ctx := r.Context()
306 userID, ok := GetUserID(ctx)
307 if !ok {
308 w.Header().Set("Content-Type", "application/json")
309 w.WriteHeader(http.StatusUnauthorized)
310 w.Write([]byte(`{"error": "User ID not found in context"}`))
311 return
312 }
313
314 res, err := sm.db.DebugViewUserInformation(userID)
315 if err != nil {
316 w.Header().Set("Content-Type", "application/json")
317 w.WriteHeader(http.StatusInternalServerError)
318 w.Write([]byte(fmt.Sprintf(`{"error": "Failed to retrieve user information: %v"}`, err)))
319 return
320 }
321
322 w.Header().Set("Content-Type", "application/json")
323 w.WriteHeader(http.StatusOK)
324 json.NewEncoder(w).Encode(res)
325}
326
// contextKey is an unexported type for this package's context keys. Keys of
// this type cannot collide with context keys defined by other packages.
type contextKey int

// Context keys used by WithUserID/GetUserID, WithAPIRequest/IsAPIRequest and
// WithAuthStatus.
const (
	userIDKey     contextKey = 0 // authenticated user's ID (int64)
	apiRequestKey contextKey = 1 // request authenticated via API key (bool)
	authStatusKey contextKey = 2 // request is authenticated (bool)
)
334
335func WithUserID(ctx context.Context, userID int64) context.Context {
336 return context.WithValue(ctx, userIDKey, userID)
337}
338
339func GetUserID(ctx context.Context) (int64, bool) {
340 userID, ok := ctx.Value(userIDKey).(int64)
341 return userID, ok
342}
343
344func WithAuthStatus(ctx context.Context, isAuthed bool) context.Context {
345 return context.WithValue(ctx, authStatusKey, isAuthed)
346}
347
348func WithAPIRequest(ctx context.Context, isAPI bool) context.Context {
349 return context.WithValue(ctx, apiRequestKey, isAPI)
350}
351
352func IsAPIRequest(ctx context.Context) bool {
353 isAPI, ok := ctx.Value(apiRequestKey).(bool)
354 return ok && isAPI
355}