+14
appview/cache/cache.go
+14
appview/cache/cache.go
+172
appview/cache/session/store.go
+172
appview/cache/session/store.go
···
1
+
package session
2
+
3
+
import (
4
+
"context"
5
+
"encoding/json"
6
+
"fmt"
7
+
"time"
8
+
9
+
"tangled.sh/tangled.sh/core/appview/cache"
10
+
)
11
+
12
+
type OAuthSession struct {
13
+
Handle string
14
+
Did string
15
+
PdsUrl string
16
+
AccessJwt string
17
+
RefreshJwt string
18
+
AuthServerIss string
19
+
DpopPdsNonce string
20
+
DpopAuthserverNonce string
21
+
DpopPrivateJwk string
22
+
Expiry string
23
+
}
24
+
25
+
type OAuthRequest struct {
26
+
AuthserverIss string
27
+
Handle string
28
+
State string
29
+
Did string
30
+
PdsUrl string
31
+
PkceVerifier string
32
+
DpopAuthserverNonce string
33
+
DpopPrivateJwk string
34
+
}
35
+
36
+
type SessionStore struct {
37
+
cache *cache.Cache
38
+
}
39
+
40
+
const (
41
+
stateKey = "oauthstate:%s"
42
+
requestKey = "oauthrequest:%s"
43
+
sessionKey = "oauthsession:%s"
44
+
)
45
+
46
+
func New(cache *cache.Cache) *SessionStore {
47
+
return &SessionStore{cache: cache}
48
+
}
49
+
50
+
func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error {
51
+
key := fmt.Sprintf(sessionKey, session.Did)
52
+
data, err := json.Marshal(session)
53
+
if err != nil {
54
+
return err
55
+
}
56
+
57
+
// set with ttl (expires in + buffer)
58
+
expiry, _ := time.Parse(time.RFC3339, session.Expiry)
59
+
ttl := time.Until(expiry) + time.Minute
60
+
61
+
return s.cache.Set(ctx, key, data, ttl).Err()
62
+
}
63
+
64
+
// SaveRequest stores the OAuth request to be later fetched in the callback. Since
65
+
// the fetching happens by comparing the state we get in the callback params, we
66
+
// store an additional state->did mapping which then lets us fetch the whole OAuth request.
67
+
func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error {
68
+
key := fmt.Sprintf(requestKey, request.Did)
69
+
data, err := json.Marshal(request)
70
+
if err != nil {
71
+
return err
72
+
}
73
+
74
+
// oauth flow must complete within 30 minutes
75
+
err = s.cache.Set(ctx, key, data, 30*time.Minute).Err()
76
+
if err != nil {
77
+
return fmt.Errorf("error saving request: %w", err)
78
+
}
79
+
80
+
stateKey := fmt.Sprintf(stateKey, request.State)
81
+
err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err()
82
+
if err != nil {
83
+
return fmt.Errorf("error saving state->did mapping: %w", err)
84
+
}
85
+
86
+
return nil
87
+
}
88
+
89
+
func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) {
90
+
key := fmt.Sprintf(sessionKey, did)
91
+
val, err := s.cache.Get(ctx, key).Result()
92
+
if err != nil {
93
+
return nil, err
94
+
}
95
+
96
+
var session OAuthSession
97
+
err = json.Unmarshal([]byte(val), &session)
98
+
if err != nil {
99
+
return nil, err
100
+
}
101
+
return &session, nil
102
+
}
103
+
104
+
func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) {
105
+
didKey, err := s.getRequestKey(ctx, state)
106
+
if err != nil {
107
+
return nil, err
108
+
}
109
+
110
+
val, err := s.cache.Get(ctx, didKey).Result()
111
+
if err != nil {
112
+
return nil, err
113
+
}
114
+
115
+
var request OAuthRequest
116
+
err = json.Unmarshal([]byte(val), &request)
117
+
if err != nil {
118
+
return nil, err
119
+
}
120
+
121
+
return &request, nil
122
+
}
123
+
124
+
func (s *SessionStore) DeleteSession(ctx context.Context, did string) error {
125
+
key := fmt.Sprintf(sessionKey, did)
126
+
return s.cache.Del(ctx, key).Err()
127
+
}
128
+
129
+
func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error {
130
+
didKey, err := s.getRequestKey(ctx, state)
131
+
if err != nil {
132
+
return err
133
+
}
134
+
135
+
err = s.cache.Del(ctx, fmt.Sprintf(stateKey, "state")).Err()
136
+
if err != nil {
137
+
return err
138
+
}
139
+
140
+
return s.cache.Del(ctx, didKey).Err()
141
+
}
142
+
143
+
func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error {
144
+
session, err := s.GetSession(ctx, did)
145
+
if err != nil {
146
+
return err
147
+
}
148
+
session.AccessJwt = access
149
+
session.RefreshJwt = refresh
150
+
session.Expiry = expiry
151
+
return s.SaveSession(ctx, *session)
152
+
}
153
+
154
+
func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error {
155
+
session, err := s.GetSession(ctx, did)
156
+
if err != nil {
157
+
return err
158
+
}
159
+
session.DpopAuthserverNonce = nonce
160
+
return s.SaveSession(ctx, *session)
161
+
}
162
+
163
+
func (s *SessionStore) getRequestKeyFromState(ctx context.Context, state string) (string, error) {
164
+
key := fmt.Sprintf(stateKey, state)
165
+
did, err := s.cache.Get(ctx, key).Result()
166
+
if err != nil {
167
+
return "", err
168
+
}
169
+
170
+
didKey := fmt.Sprintf(requestKey, did)
171
+
return didKey, nil
172
+
}