1package session
2
3import (
4 "context"
5 "crypto/rand"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "log"
10 "net/http"
11 "sync"
12 "time"
13
14 "github.com/teal-fm/piper/db"
15 "github.com/teal-fm/piper/db/apikey"
16)
17
18// session/session.go
19type Session struct {
20
21 //need to re work this. May add onto it for atproto oauth. But need to be careful about that expiresd
22 //Maybe a speerate oauth session store table and it has a created date? yeah do that then can look it up by session id from this table for user actions
23
24 ID string
25 UserID int64
26 ATProtoSessionID string
27 CreatedAt time.Time
28 ExpiresAt time.Time
29}
30
// SessionManager creates, looks up and deletes user sessions. It keeps an
// in-memory cache in front of the optional sessions table and holds the API
// key manager consulted by the auth middleware in this file.
type SessionManager struct {
	// db is the persistent store. CreateSession, GetSession and DeleteSession
	// skip persistence when it is nil.
	db *db.DB
	// sessions caches sessions by ID; guarded by mu. Entries are added by
	// CreateSession and GetSession (on a database hit) and removed by
	// DeleteSession, which GetSession also calls lazily for expired entries.
	sessions map[string]*Session
	// apiKeyMgr validates API keys for WithAuth, WithPossibleAuth and WithAPIAuth.
	apiKeyMgr *apikey.ApiKeyManager
	// mu guards sessions.
	mu sync.RWMutex
}
37
// NewSessionManager returns a SessionManager that persists sessions in
// database and manages API keys through an apikey.ApiKeyManager built on the
// same handle.
//
// When database is non-nil, the sessions table is created if it does not
// already exist. A failure there is only logged: the manager still serves
// sessions from its in-memory cache, and later database operations log their
// own errors. A nil database is accepted, matching CreateSession, GetSession
// and DeleteSession, which skip persistence when no database is configured.
func NewSessionManager(database *db.DB) *SessionManager {
	// Guard against a nil handle: calling Exec on it would panic, even though
	// the rest of SessionManager explicitly supports running without a database.
	if database != nil {
		_, err := database.Exec(`
	CREATE TABLE IF NOT EXISTS sessions (
		id TEXT PRIMARY KEY,
		user_id INTEGER NOT NULL,
		at_proto_session_id TEXT NOT NULL,
		created_at TIMESTAMP,
		expires_at TIMESTAMP,
		FOREIGN KEY (user_id) REFERENCES users(id)
	)`)
		if err != nil {
			log.Printf("Error creating sessions table: %v", err)
		}
	}

	// NOTE(review): apikey.NewApiKeyManager is still handed a possibly-nil
	// database; confirm it tolerates that before relying on nil support.
	apiKeyMgr := apikey.NewApiKeyManager(database)

	return &SessionManager{
		db:        database,
		sessions:  make(map[string]*Session),
		apiKeyMgr: apiKeyMgr,
	}
}
62
// CreateSession starts a new 24-hour session for userID, linked to the given
// ATProto session ID, and returns it.
//
// The session ID is 32 bytes from crypto/rand, encoded with URL-safe base64.
// The session is always added to the in-memory cache. When a database is
// configured it is also inserted into the sessions table; a failed insert is
// logged, and the session then exists only in this process's cache.
//
// CreateSession panics if the secure random source fails, because continuing
// would issue a predictable session ID.
func (sm *SessionManager) CreateSession(userID int64, atProtoSessionID string) *Session {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		// Never hand out a non-random session token. (Since Go 1.24,
		// crypto/rand aborts the program itself here; this covers older
		// toolchains, where the error was returned and silently ignored.)
		panic(fmt.Sprintf("session: generating session ID: %v", err))
	}
	sessionID := base64.URLEncoding.EncodeToString(b)

	now := time.Now().UTC()
	expiresAt := now.Add(24 * time.Hour) // 24-hour session

	session := &Session{
		ID:               sessionID,
		UserID:           userID,
		ATProtoSessionID: atProtoSessionID,
		CreatedAt:        now,
		ExpiresAt:        expiresAt,
	}

	sm.mu.Lock()
	sm.sessions[sessionID] = session
	sm.mu.Unlock()

	// Persist outside the lock so a slow database does not block concurrent
	// GetSession/DeleteSession calls. The ID has not been handed to anyone
	// yet, so no other goroutine can race on this entry.
	if sm.db != nil {
		_, err := sm.db.Exec(`
		INSERT INTO sessions (id, user_id, at_proto_session_id, created_at, expires_at)
		VALUES (?, ?, ?, ?, ?)`,
			sessionID, userID, atProtoSessionID, now, expiresAt)
		if err != nil {
			log.Printf("Error storing session in database: %v", err)
		}
	}

	return session
}
101
// GetSession returns the session with the given ID and whether it was found.
//
// The in-memory cache is consulted first. On a miss, and when a database is
// configured, the sessions table is queried, and a live row is added to the
// cache. Expired sessions, from either source, are removed via DeleteSession
// and reported as not found. Any database error, including "no such row",
// also yields (nil, false).
func (sm *SessionManager) GetSession(sessionID string) (*Session, bool) {
	sm.mu.RLock()
	cached, hit := sm.sessions[sessionID]
	sm.mu.RUnlock()

	if hit {
		if time.Now().UTC().After(cached.ExpiresAt) {
			sm.DeleteSession(sessionID)
			return nil, false
		}
		return cached, true
	}

	if sm.db == nil {
		return nil, false
	}

	loaded := &Session{ID: sessionID}
	row := sm.db.QueryRow(`
		SELECT user_id, at_proto_session_id, created_at, expires_at
		FROM sessions WHERE id = ?`, sessionID)
	if err := row.Scan(&loaded.UserID, &loaded.ATProtoSessionID, &loaded.CreatedAt, &loaded.ExpiresAt); err != nil {
		return nil, false
	}

	if time.Now().UTC().After(loaded.ExpiresAt) {
		sm.DeleteSession(sessionID)
		return nil, false
	}

	sm.mu.Lock()
	sm.sessions[sessionID] = loaded
	sm.mu.Unlock()

	return loaded, true
}
146
// DeleteSession removes the session with the given ID from the in-memory
// cache and, when a database is configured, from the sessions table.
// A database failure is logged rather than returned.
func (sm *SessionManager) DeleteSession(sessionID string) {
	func() {
		sm.mu.Lock()
		defer sm.mu.Unlock()
		delete(sm.sessions, sessionID)
	}()

	if sm.db == nil {
		return
	}

	if _, err := sm.db.Exec("DELETE FROM sessions WHERE id = ?", sessionID); err != nil {
		log.Printf("Error deleting session from database: %v", err)
	}
}
160
// SetSessionCookie writes the "session" cookie carrying session.ID, scoped to
// the whole site ("/") and expiring at session.ExpiresAt.
//
// The cookie is HttpOnly, so scripts cannot read it. It is also SameSite=Lax,
// so browsers do not attach it to cross-site subresource requests or
// cross-site POSTs. That mitigates CSRF, while top-level navigations back to
// this site (such as OAuth redirects) still carry it.
//
// NOTE(review): Secure is false, so the cookie is also sent over plain HTTP.
// This is presumably for local development; it should be true in production.
func (sm *SessionManager) SetSessionCookie(w http.ResponseWriter, session *Session) {
	cookie := &http.Cookie{
		Name:     "session",
		Value:    session.ID,
		Path:     "/",
		HttpOnly: true,
		Secure:   false,
		SameSite: http.SameSiteLaxMode,
		Expires:  session.ExpiresAt,
	}
	http.SetCookie(w, cookie)
}
173
// ClearSessionCookie instructs the browser to drop the "session" cookie.
// It re-issues the cookie with an empty value, the same path and a negative
// MaxAge, which net/http sends as "Max-Age=0" (delete immediately).
func (sm *SessionManager) ClearSessionCookie(w http.ResponseWriter) {
	expired := http.Cookie{
		Name:     "session",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   false,
	}
	http.SetCookie(w, &expired)
}
186
187func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager {
188 return sm.apiKeyMgr
189}
190
191func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) {
192 return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays)
193}
194
195// middleware that checks if a user is authenticated via cookies or API key
196func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
197 return func(w http.ResponseWriter, r *http.Request) {
198 // first: check API keys
199 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
200 if apiKeyErr == nil && apiKeyStr != "" {
201 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
202 if valid {
203 ctx := WithUserID(r.Context(), apiKey.UserID)
204 r = r.WithContext(ctx)
205
206 // set a flag for api requests
207 ctx = WithAPIRequest(r.Context(), true)
208 r = r.WithContext(ctx)
209
210 handler(w, r)
211 return
212 }
213 }
214
215 // if not found, check cookies for session value
216 cookie, err := r.Cookie("session")
217 if err != nil {
218 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther)
219 return
220 }
221
222 session, exists := sm.GetSession(cookie.Value)
223 if !exists {
224 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther)
225 return
226 }
227
228 ctx := WithUserID(r.Context(), session.UserID)
229 r = r.WithContext(ctx)
230
231 handler(w, r)
232 }
233}
234
// WithPossibleAuth wraps handler with optional authentication: the handler
// always runs, and the context records whether auth succeeded (WithAuthStatus).
//
// A valid API key takes precedence. It stores the key owner's user ID, marks
// the request as an API request and sets the auth status to true. Otherwise a
// "session" cookie naming a live session stores that session's user ID and
// sets the status to true. If neither applies, the status is false and no
// user ID is stored.
func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		if keyStr, keyErr := apikey.ExtractApiKey(r); keyErr == nil && keyStr != "" {
			if key, ok := sm.apiKeyMgr.GetApiKey(keyStr); ok {
				ctx = WithAPIRequest(WithUserID(ctx, key.UserID), true)
				handler(w, r.WithContext(WithAuthStatus(ctx, true)))
				return
			}
		}

		// The API-key path returns early, so only the cookie remains to check.
		authenticated := false
		if c, err := r.Cookie("session"); err == nil {
			if sess, ok := sm.GetSession(c.Value); ok {
				ctx = WithUserID(ctx, sess.UserID)
				authenticated = true
			}
		}

		handler(w, r.WithContext(WithAuthStatus(ctx, authenticated)))
	}
}
269
// WithAPIAuth wraps handler so that it only runs for requests carrying a valid
// API key. Cookies are ignored.
//
// A missing key, or a key that apikey.ExtractApiKey fails to extract, gets a
// 401 JSON error, and so does a key the manager rejects. On success, the key
// owner's user ID is stored in the context and the request is marked as an
// API request.
func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		reject := func(body string) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte(body))
		}

		keyStr, keyErr := apikey.ExtractApiKey(r)
		if keyErr != nil || keyStr == "" {
			reject(`{"error": "API key is required"}`)
			return
		}

		key, ok := sm.apiKeyMgr.GetApiKey(keyStr)
		if !ok {
			reject(`{"error": "Invalid or expired API key"}`)
			return
		}

		ctx := WithAPIRequest(WithUserID(r.Context(), key.UserID), true)
		handler(w, r.WithContext(ctx))
	}
}
296
// HandleDebug responds with the JSON encoding of
// db.DebugViewUserInformation for the user whose ID an upstream middleware
// (e.g. WithAuth) stored in the request context.
//
// Responses:
//   - 401 with a JSON error if the context has no user ID;
//   - 500 with a JSON error if no database is configured or the lookup fails;
//   - 200 with the JSON-encoded result otherwise.
func (sm *SessionManager) HandleDebug(w http.ResponseWriter, r *http.Request) {
	writeError := func(status int, msg string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// Encode rather than format by hand, so quotes, backslashes or control
		// characters in msg (e.g. from a database error) cannot yield invalid JSON.
		if err := json.NewEncoder(w).Encode(map[string]string{"error": msg}); err != nil {
			log.Printf("Error writing debug error response: %v", err)
		}
	}

	userID, ok := GetUserID(r.Context())
	if !ok {
		writeError(http.StatusUnauthorized, "User ID not found in context")
		return
	}

	// The other SessionManager methods tolerate a nil database; do not
	// dereference it here.
	if sm.db == nil {
		writeError(http.StatusInternalServerError, "Database not available")
		return
	}

	res, err := sm.db.DebugViewUserInformation(userID)
	if err != nil {
		writeError(http.StatusInternalServerError, fmt.Sprintf("Failed to retrieve user information: %v", err))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(res); err != nil {
		// The status line is already sent; logging is all that is left.
		log.Printf("Error encoding debug response: %v", err)
	}
}
319
320type contextKey int
321
322const (
323 userIDKey contextKey = iota
324 apiRequestKey
325 authStatusKey
326)
327
328func WithUserID(ctx context.Context, userID int64) context.Context {
329 return context.WithValue(ctx, userIDKey, userID)
330}
331
332func GetUserID(ctx context.Context) (int64, bool) {
333 userID, ok := ctx.Value(userIDKey).(int64)
334 return userID, ok
335}
336
337func WithAuthStatus(ctx context.Context, isAuthed bool) context.Context {
338 return context.WithValue(ctx, authStatusKey, isAuthed)
339}
340
341func WithAPIRequest(ctx context.Context, isAPI bool) context.Context {
342 return context.WithValue(ctx, apiRequestKey, isAPI)
343}
344
345func IsAPIRequest(ctx context.Context) bool {
346 isAPI, ok := ctx.Value(apiRequestKey).(bool)
347 return ok && isAPI
348}