Yōten: A social tracker for your language learning journey built on the atproto.

fix(oauth): invalidate sessions if inactive for too long

Signed-off-by: brookjeynes <me@brookjeynes.dev>

authored by brookjeynes.dev and committed by Tangled 768b8a70 2420234e

Changed files
+125 -20
internal
server
handlers
oauth
+1 -1
internal/server/handlers/login.go
··· 83 83 if err != nil { 84 84 l.Error("failed to logout", "err", err) 85 85 } else { 86 - l.Error("logged out successfully") 86 + l.Debug("logged out successfully") 87 87 } 88 88 89 89 if !h.Config.Core.Dev && did != "" {
+10 -5
internal/server/oauth/oauth.go
··· 47 47 48 48 jwksUri := clientUri + "/oauth/jwks.json" 49 49 50 - authStore, err := NewRedisStore(config.Redis.ToURL()) 50 + authStore, err := NewRedisStore(&RedisStoreConfig{ 51 + RedisURL: config.Redis.ToURL(), 52 + SessionExpiryDuration: time.Hour * 24 * 90, 53 + SessionInactivityDuration: time.Hour * 24 * 14, 54 + AuthRequestExpiryDuration: time.Minute * 30, 55 + }) 51 56 if err != nil { 52 57 return nil, err 53 58 } ··· 138 143 } 139 144 140 145 func (o *OAuth) GetUser(r *http.Request) *types.OauthUser { 141 - clientSession, err := o.SessionStore.Get(r, SessionName) 142 - if err != nil || clientSession.IsNew { 146 + sess, err := o.ResumeSession(r) 147 + if err != nil { 143 148 return nil 144 149 } 145 150 146 151 return &types.OauthUser{ 147 - Did: clientSession.Values[SessionDid].(string), 148 - Pds: clientSession.Values[SessionPds].(string), 152 + Did: sess.Data.AccountDID.String(), 153 + Pds: sess.Data.HostURL, 149 154 } 150 155 } 151 156
+114 -14
internal/server/oauth/store.go
··· 11 11 "github.com/redis/go-redis/v9" 12 12 ) 13 13 14 - // redis-backed implementation of ClientAuthStore. 14 + type RedisStoreConfig struct { 15 + RedisURL string 16 + 17 + // The purpose of these limits is to avoid dead sessions hanging around in 18 + // the db indefinitely. The durations here should be *at least as long as* 19 + // the expected duration of the oauth session itself. 20 + SessionExpiryDuration time.Duration // duration since session creation (max TTL) 21 + SessionInactivityDuration time.Duration // duration since last session update 22 + AuthRequestExpiryDuration time.Duration // duration since auth request creation 23 + } 24 + 25 + // Redis-backed implementation of ClientAuthStore 15 26 type RedisStore struct { 16 - client *redis.Client 17 - SessionTTL time.Duration 18 - AuthRequestTTL time.Duration 27 + client *redis.Client 28 + cfg *RedisStoreConfig 29 + } 30 + 31 + type sessionMetadata struct { 32 + CreatedAt time.Time `json:"created_at"` 33 + UpdatedAt time.Time `json:"updated_at"` 19 34 } 20 35 21 36 var _ oauth.ClientAuthStore = &RedisStore{} 22 37 23 - func NewRedisStore(redisURL string) (*RedisStore, error) { 24 - opts, err := redis.ParseURL(redisURL) 38 + func NewRedisStore(cfg *RedisStoreConfig) (*RedisStore, error) { 39 + if cfg == nil { 40 + return nil, fmt.Errorf("missing cfg") 41 + } 42 + if cfg.RedisURL == "" { 43 + return nil, fmt.Errorf("missing RedisURL") 44 + } 45 + if cfg.SessionExpiryDuration == 0 { 46 + return nil, fmt.Errorf("missing SessionExpiryDuration") 47 + } 48 + if cfg.SessionInactivityDuration == 0 { 49 + return nil, fmt.Errorf("missing SessionInactivityDuration") 50 + } 51 + if cfg.AuthRequestExpiryDuration == 0 { 52 + return nil, fmt.Errorf("missing AuthRequestExpiryDuration") 53 + } 54 + 55 + opts, err := redis.ParseURL(cfg.RedisURL) 25 56 if err != nil { 26 57 return nil, fmt.Errorf("failed to parse redis URL: %w", err) 27 58 } 28 59 29 60 client := redis.NewClient(opts) 30 61 31 - // Test the connection. 62 + // Test the connection 32 63 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 33 64 defer cancel() 34 65 ··· 37 68 } 38 69 39 70 return &RedisStore{ 40 - client: client, 41 - SessionTTL: 30 * 24 * time.Hour, // 30 days 42 - AuthRequestTTL: 10 * time.Minute, // 10 minutes 71 + client: client, 72 + cfg: cfg, 43 73 }, nil 44 74 } 45 75 ··· 51 81 return fmt.Sprintf("oauth:session:%s:%s", did, sessionID) 52 82 } 53 83 84 + func sessionMetadataKey(did syntax.DID, sessionID string) string { 85 + return fmt.Sprintf("oauth:session_meta:%s:%s", did, sessionID) 86 + } 87 + 54 88 func authRequestKey(state string) string { 55 89 return fmt.Sprintf("oauth:auth_request:%s", state) 56 90 } 57 91 58 92 func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 59 93 key := sessionKey(did, sessionID) 94 + metaKey := sessionMetadataKey(did, sessionID) 95 + 96 + // Check metadata for inactivity expiry 97 + metaData, err := r.client.Get(ctx, metaKey).Bytes() 98 + if err == redis.Nil { 99 + return nil, fmt.Errorf("session not found: %s", did) 100 + } 101 + if err != nil { 102 + return nil, fmt.Errorf("failed to get session metadata: %w", err) 103 + } 104 + 105 + var meta sessionMetadata 106 + if err := json.Unmarshal(metaData, &meta); err != nil { 107 + return nil, fmt.Errorf("failed to unmarshal session metadata: %w", err) 108 + } 109 + 110 + // Check if session has been inactive for too long 111 + inactiveThreshold := time.Now().Add(-r.cfg.SessionInactivityDuration) 112 + if meta.UpdatedAt.Before(inactiveThreshold) { 113 + // Session is inactive, delete it 114 + r.client.Del(ctx, key, metaKey) 115 + return nil, fmt.Errorf("session expired due to inactivity: %s", did) 116 + } 117 + 118 + // Get the actual session data 60 119 data, err := r.client.Get(ctx, key).Bytes() 61 120 if err == redis.Nil { 62 121 return nil, fmt.Errorf("session not found: %s", did) ··· 75 134 76 135 func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 77 136 key := sessionKey(sess.AccountDID, sess.SessionID) 137 + metaKey := sessionMetadataKey(sess.AccountDID, sess.SessionID) 78 138 79 139 data, err := json.Marshal(sess) 80 140 if err != nil { 81 141 return fmt.Errorf("failed to marshal session: %w", err) 82 142 } 83 143 84 - if err := r.client.Set(ctx, key, data, r.SessionTTL).Err(); err != nil { 144 + // Check if session already exists to preserve CreatedAt 145 + var meta sessionMetadata 146 + existingMetaData, err := r.client.Get(ctx, metaKey).Bytes() 147 + if err == redis.Nil { 148 + // New session 149 + meta = sessionMetadata{ 150 + CreatedAt: time.Now(), 151 + UpdatedAt: time.Now(), 152 + } 153 + } else if err != nil { 154 + return fmt.Errorf("failed to check existing session metadata: %w", err) 155 + } else { 156 + // Existing session - preserve CreatedAt, update UpdatedAt 157 + if err := json.Unmarshal(existingMetaData, &meta); err != nil { 158 + return fmt.Errorf("failed to unmarshal existing session metadata: %w", err) 159 + } 160 + meta.UpdatedAt = time.Now() 161 + } 162 + 163 + // Calculate remaining TTL based on creation time 164 + remainingTTL := r.cfg.SessionExpiryDuration - time.Since(meta.CreatedAt) 165 + if remainingTTL <= 0 { 166 + return fmt.Errorf("session has expired") 167 + } 168 + 169 + // Use the shorter of: remaining TTL or inactivity duration 170 + ttl := min(r.cfg.SessionInactivityDuration, remainingTTL) 171 + 172 + // Save session data 173 + if err := r.client.Set(ctx, key, data, ttl).Err(); err != nil { 85 174 return fmt.Errorf("failed to save session: %w", err) 86 175 } 87 176 177 + // Save metadata 178 + metaData, err := json.Marshal(meta) 179 + if err != nil { 180 + return fmt.Errorf("failed to marshal session metadata: %w", err) 181 + } 182 + if err := r.client.Set(ctx, metaKey, metaData, ttl).Err(); err != nil { 183 + return fmt.Errorf("failed to save session metadata: %w", err) 184 + } 185 + 88 186 return nil 89 187 } 90 188 91 189 func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 92 190 key := sessionKey(did, sessionID) 93 - if err := r.client.Del(ctx, key).Err(); err != nil { 191 + metaKey := sessionMetadataKey(did, sessionID) 192 + 193 + if err := r.client.Del(ctx, key, metaKey).Err(); err != nil { 94 194 return fmt.Errorf("failed to delete session: %w", err) 95 195 } 96 196 return nil ··· 117 217 func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 118 218 key := authRequestKey(info.State) 119 219 120 - // check if already exists (to match MemStore behavior) 220 + // Check if already exists (to match MemStore behavior) 121 221 exists, err := r.client.Exists(ctx, key).Result() 122 222 if err != nil { 123 223 return fmt.Errorf("failed to check auth request existence: %w", err) ··· 131 231 return fmt.Errorf("failed to marshal auth request: %w", err) 132 232 } 133 233 134 - if err := r.client.Set(ctx, key, data, r.AuthRequestTTL).Err(); err != nil { 234 + if err := r.client.Set(ctx, key, data, r.cfg.AuthRequestExpiryDuration).Err(); err != nil { 135 235 return fmt.Errorf("failed to save auth request: %w", err) 136 236 } 137 237