+1
-1
internal/server/handlers/login.go
+1
-1
internal/server/handlers/login.go
+10
-5
internal/server/oauth/oauth.go
+10
-5
internal/server/oauth/oauth.go
···
47
47
48
48
jwksUri := clientUri + "/oauth/jwks.json"
49
49
50
-
authStore, err := NewRedisStore(config.Redis.ToURL())
50
+
authStore, err := NewRedisStore(&RedisStoreConfig{
51
+
RedisURL: config.Redis.ToURL(),
52
+
SessionExpiryDuration: time.Hour * 24 * 90,
53
+
SessionInactivityDuration: time.Hour * 24 * 14,
54
+
AuthRequestExpiryDuration: time.Minute * 30,
55
+
})
51
56
if err != nil {
52
57
return nil, err
53
58
}
···
138
143
}
139
144
140
145
func (o *OAuth) GetUser(r *http.Request) *types.OauthUser {
141
-
clientSession, err := o.SessionStore.Get(r, SessionName)
142
-
if err != nil || clientSession.IsNew {
146
+
sess, err := o.ResumeSession(r)
147
+
if err != nil {
143
148
return nil
144
149
}
145
150
146
151
return &types.OauthUser{
147
-
Did: clientSession.Values[SessionDid].(string),
148
-
Pds: clientSession.Values[SessionPds].(string),
152
+
Did: sess.Data.AccountDID.String(),
153
+
Pds: sess.Data.HostURL,
149
154
}
150
155
}
151
156
+114
-14
internal/server/oauth/store.go
+114
-14
internal/server/oauth/store.go
···
11
11
"github.com/redis/go-redis/v9"
12
12
)
13
13
14
-
// redis-backed implementation of ClientAuthStore.
14
+
type RedisStoreConfig struct {
15
+
RedisURL string
16
+
17
+
// The purpose of these limits is to avoid dead sessions hanging around in
18
+
// the db indefinitely. The durations here should be *at least as long as*
19
+
// the expected duration of the oauth session itself.
20
+
SessionExpiryDuration time.Duration // duration since session creation (max TTL)
21
+
SessionInactivityDuration time.Duration // duration since last session update
22
+
AuthRequestExpiryDuration time.Duration // duration since auth request creation
23
+
}
24
+
25
+
// Redis-backed implementation of ClientAuthStore
15
26
type RedisStore struct {
16
-
client *redis.Client
17
-
SessionTTL time.Duration
18
-
AuthRequestTTL time.Duration
27
+
client *redis.Client
28
+
cfg *RedisStoreConfig
29
+
}
30
+
31
+
type sessionMetadata struct {
32
+
CreatedAt time.Time `json:"created_at"`
33
+
UpdatedAt time.Time `json:"updated_at"`
19
34
}
20
35
21
36
var _ oauth.ClientAuthStore = &RedisStore{}
22
37
23
-
func NewRedisStore(redisURL string) (*RedisStore, error) {
24
-
opts, err := redis.ParseURL(redisURL)
38
+
func NewRedisStore(cfg *RedisStoreConfig) (*RedisStore, error) {
39
+
if cfg == nil {
40
+
return nil, fmt.Errorf("missing cfg")
41
+
}
42
+
if cfg.RedisURL == "" {
43
+
return nil, fmt.Errorf("missing RedisURL")
44
+
}
45
+
if cfg.SessionExpiryDuration == 0 {
46
+
return nil, fmt.Errorf("missing SessionExpiryDuration")
47
+
}
48
+
if cfg.SessionInactivityDuration == 0 {
49
+
return nil, fmt.Errorf("missing SessionInactivityDuration")
50
+
}
51
+
if cfg.AuthRequestExpiryDuration == 0 {
52
+
return nil, fmt.Errorf("missing AuthRequestExpiryDuration")
53
+
}
54
+
55
+
opts, err := redis.ParseURL(cfg.RedisURL)
25
56
if err != nil {
26
57
return nil, fmt.Errorf("failed to parse redis URL: %w", err)
27
58
}
28
59
29
60
client := redis.NewClient(opts)
30
61
31
-
// Test the connection.
62
+
// Test the connection
32
63
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
33
64
defer cancel()
34
65
···
37
68
}
38
69
39
70
return &RedisStore{
40
-
client: client,
41
-
SessionTTL: 30 * 24 * time.Hour, // 30 days
42
-
AuthRequestTTL: 10 * time.Minute, // 10 minutes
71
+
client: client,
72
+
cfg: cfg,
43
73
}, nil
44
74
}
45
75
···
51
81
return fmt.Sprintf("oauth:session:%s:%s", did, sessionID)
52
82
}
53
83
84
+
func sessionMetadataKey(did syntax.DID, sessionID string) string {
85
+
return fmt.Sprintf("oauth:session_meta:%s:%s", did, sessionID)
86
+
}
87
+
54
88
func authRequestKey(state string) string {
55
89
return fmt.Sprintf("oauth:auth_request:%s", state)
56
90
}
57
91
58
92
func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
59
93
key := sessionKey(did, sessionID)
94
+
metaKey := sessionMetadataKey(did, sessionID)
95
+
96
+
// Check metadata for inactivity expiry
97
+
metaData, err := r.client.Get(ctx, metaKey).Bytes()
98
+
if err == redis.Nil {
99
+
return nil, fmt.Errorf("session not found: %s", did)
100
+
}
101
+
if err != nil {
102
+
return nil, fmt.Errorf("failed to get session metadata: %w", err)
103
+
}
104
+
105
+
var meta sessionMetadata
106
+
if err := json.Unmarshal(metaData, &meta); err != nil {
107
+
return nil, fmt.Errorf("failed to unmarshal session metadata: %w", err)
108
+
}
109
+
110
+
// Check if session has been inactive for too long
111
+
inactiveThreshold := time.Now().Add(-r.cfg.SessionInactivityDuration)
112
+
if meta.UpdatedAt.Before(inactiveThreshold) {
113
+
// Session is inactive, delete it
114
+
r.client.Del(ctx, key, metaKey)
115
+
return nil, fmt.Errorf("session expired due to inactivity: %s", did)
116
+
}
117
+
118
+
// Get the actual session data
60
119
data, err := r.client.Get(ctx, key).Bytes()
61
120
if err == redis.Nil {
62
121
return nil, fmt.Errorf("session not found: %s", did)
···
75
134
76
135
func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
77
136
key := sessionKey(sess.AccountDID, sess.SessionID)
137
+
metaKey := sessionMetadataKey(sess.AccountDID, sess.SessionID)
78
138
79
139
data, err := json.Marshal(sess)
80
140
if err != nil {
81
141
return fmt.Errorf("failed to marshal session: %w", err)
82
142
}
83
143
84
-
if err := r.client.Set(ctx, key, data, r.SessionTTL).Err(); err != nil {
144
+
// Check if session already exists to preserve CreatedAt
145
+
var meta sessionMetadata
146
+
existingMetaData, err := r.client.Get(ctx, metaKey).Bytes()
147
+
if err == redis.Nil {
148
+
// New session
149
+
meta = sessionMetadata{
150
+
CreatedAt: time.Now(),
151
+
UpdatedAt: time.Now(),
152
+
}
153
+
} else if err != nil {
154
+
return fmt.Errorf("failed to check existing session metadata: %w", err)
155
+
} else {
156
+
// Existing session - preserve CreatedAt, update UpdatedAt
157
+
if err := json.Unmarshal(existingMetaData, &meta); err != nil {
158
+
return fmt.Errorf("failed to unmarshal existing session metadata: %w", err)
159
+
}
160
+
meta.UpdatedAt = time.Now()
161
+
}
162
+
163
+
// Calculate remaining TTL based on creation time
164
+
remainingTTL := r.cfg.SessionExpiryDuration - time.Since(meta.CreatedAt)
165
+
if remainingTTL <= 0 {
166
+
return fmt.Errorf("session has expired")
167
+
}
168
+
169
+
// Use the shorter of: remaining TTL or inactivity duration
170
+
ttl := min(r.cfg.SessionInactivityDuration, remainingTTL)
171
+
172
+
// Save session data
173
+
if err := r.client.Set(ctx, key, data, ttl).Err(); err != nil {
85
174
return fmt.Errorf("failed to save session: %w", err)
86
175
}
87
176
177
+
// Save metadata
178
+
metaData, err := json.Marshal(meta)
179
+
if err != nil {
180
+
return fmt.Errorf("failed to marshal session metadata: %w", err)
181
+
}
182
+
if err := r.client.Set(ctx, metaKey, metaData, ttl).Err(); err != nil {
183
+
return fmt.Errorf("failed to save session metadata: %w", err)
184
+
}
185
+
88
186
return nil
89
187
}
90
188
91
189
func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
92
190
key := sessionKey(did, sessionID)
93
-
if err := r.client.Del(ctx, key).Err(); err != nil {
191
+
metaKey := sessionMetadataKey(did, sessionID)
192
+
193
+
if err := r.client.Del(ctx, key, metaKey).Err(); err != nil {
94
194
return fmt.Errorf("failed to delete session: %w", err)
95
195
}
96
196
return nil
···
117
217
func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
118
218
key := authRequestKey(info.State)
119
219
120
-
// check if already exists (to match MemStore behavior)
220
+
// Check if already exists (to match MemStore behavior)
121
221
exists, err := r.client.Exists(ctx, key).Result()
122
222
if err != nil {
123
223
return fmt.Errorf("failed to check auth request existence: %w", err)
···
131
231
return fmt.Errorf("failed to marshal auth request: %w", err)
132
232
}
133
233
134
-
if err := r.client.Set(ctx, key, data, r.AuthRequestTTL).Err(); err != nil {
234
+
if err := r.client.Set(ctx, key, data, r.cfg.AuthRequestExpiryDuration).Err(); err != nil {
135
235
return fmt.Errorf("failed to save auth request: %w", err)
136
236
}
137
237