Discover books, shows, and movies at your level. Track your progress by filling your Shelf with what you find, and share with other language learners. *No dusting required.
shlf.space
1// MIT License
2//
3// Copyright (c) 2025 Anirudh Oppiliappan, Akshay Oppiliappan and
4// contributors.
5//
6// Permission is hereby granted, free of charge, to any person obtaining a copy
7// of this software and associated documentation files (the "Software"), to deal
8// in the Software without restriction, including without limitation the rights
9// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10// copies of the Software, and to permit persons to whom the Software is
11// furnished to do so, subject to the following conditions:
12//
13// The above copyright notice and this permission notice shall be included in all
14// copies or substantial portions of the Software.
15//
16// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22// SOFTWARE.
23
24package session
25
26import (
27 "context"
28 "encoding/json"
29 "fmt"
30 "time"
31
32 "shlf.space/internal/cache"
33)
34
// OAuthSession is the per-account OAuth state kept after a successful login.
// SessionStore stores it with encoding/json and no struct tags, so the Go
// field names below are the JSON keys of the stored record; renaming or
// reordering fields changes the stored format.
type OAuthSession struct {
	// Account identity and the PDS URL recorded for it.
	Handle, Did, PdsUrl string

	// Token pair plus the issuer of the authorization server.
	AccessJwt, RefreshJwt, AuthServerIss string

	// DPoP-related state: the PDS nonce, the authorization-server nonce
	// (updated by UpdateNonce) and the private key as a JWK string.
	DpopPdsNonce, DpopAuthserverNonce, DpopPrivateJwk string

	// Expiry is an opaque string supplied by callers (see RefreshSession);
	// its format is not interpreted here.
	Expiry string
}
47
// OAuthRequest holds the state of an in-progress OAuth login. SaveRequest
// stores it (JSON-encoded, untagged, so the Go field names are the JSON keys)
// and GetRequestByState reads it back using the State value from the callback.
type OAuthRequest struct {
	// Issuer of the authorization server handling this login. Note the
	// spelling differs from OAuthSession.AuthServerIss; both are part of the
	// stored JSON format, so neither can be renamed casually.
	AuthserverIss string

	// Account being logged in, the state value for this flow, and the
	// account's PDS URL. State and Did also form the cache keys.
	Handle, State, Did, PdsUrl string

	// PKCE verifier and DPoP material for this particular flow.
	PkceVerifier, DpopAuthserverNonce, DpopPrivateJwk string

	// ReturnUrl is carried through the flow for the caller's use
	// (presumably a post-login redirect target — confirm with callers).
	ReturnUrl string
}
59
60type SessionStore struct {
61 cache *cache.Cache
62}
63
64const (
65 stateKey = "oauthstate:%s"
66 requestKey = "oauthrequest:%s"
67 sessionKey = "oauthsession:%s"
68)
69
70func New(cache *cache.Cache) *SessionStore {
71 return &SessionStore{cache: cache}
72}
73
74func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error {
75 key := fmt.Sprintf(sessionKey, session.Did)
76 data, err := json.Marshal(session)
77 if err != nil {
78 return err
79 }
80
81 // set with ttl (7 days)
82 ttl := 7 * 24 * time.Hour
83
84 return s.cache.Set(ctx, key, data, ttl).Err()
85}
86
87// SaveRequest stores the OAuth request to be later fetched in the callback. Since
88// the fetching happens by comparing the state we get in the callback params, we
89// store an additional state->did mapping which then lets us fetch the whole OAuth request.
90func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error {
91 key := fmt.Sprintf(requestKey, request.Did)
92 data, err := json.Marshal(request)
93 if err != nil {
94 return err
95 }
96
97 // oauth flow must complete within 30 minutes
98 err = s.cache.Set(ctx, key, data, 30*time.Minute).Err()
99 if err != nil {
100 return fmt.Errorf("error saving request: %w", err)
101 }
102
103 stateKey := fmt.Sprintf(stateKey, request.State)
104 err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err()
105 if err != nil {
106 return fmt.Errorf("error saving state->did mapping: %w", err)
107 }
108
109 return nil
110}
111
112func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) {
113 key := fmt.Sprintf(sessionKey, did)
114 val, err := s.cache.Get(ctx, key).Result()
115 if err != nil {
116 return nil, err
117 }
118
119 var session OAuthSession
120 err = json.Unmarshal([]byte(val), &session)
121 if err != nil {
122 return nil, err
123 }
124 return &session, nil
125}
126
// GetRequestByState returns the pending OAuth request for the state value
// received in the OAuth callback. It follows the state->did mapping written by
// SaveRequest and then loads the request stored under that DID.
//
// Requests are keyed by DID, so a later login attempt for the same DID
// replaces the stored request while the earlier attempt's state mapping is
// still live. The loaded request's State is therefore checked against state,
// and a mismatch is reported as an error. Otherwise the caller would get
// another flow's request, with a different PKCE verifier and DPoP key.
// Cache and JSON errors are returned unchanged.
func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) {
	reqKey, err := s.getRequestKeyFromState(ctx, state)
	if err != nil {
		return nil, err
	}

	val, err := s.cache.Get(ctx, reqKey).Result()
	if err != nil {
		return nil, err
	}

	var request OAuthRequest
	if err = json.Unmarshal([]byte(val), &request); err != nil {
		return nil, err
	}

	if request.State != state {
		return nil, fmt.Errorf("oauth request for state %q was superseded by a newer request for the same did", state)
	}

	return &request, nil
}
146
// DeleteSession removes the session stored for did, returning the cache's
// delete error, if any.
func (s *SessionStore) DeleteSession(ctx context.Context, did string) error {
	return s.cache.Del(ctx, fmt.Sprintf(sessionKey, did)).Err()
}
151
// DeleteRequestByState removes both records written by SaveRequest for state:
// first the state->did mapping, then the DID-keyed request it points to. If
// the mapping cannot be read, or a delete fails, that error is returned and
// later steps are skipped.
//
// NOTE(review): the request record is keyed by DID, so if a newer flow for the
// same DID has overwritten it, this deletes the newer flow's request too.
func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error {
	reqKey, err := s.getRequestKeyFromState(ctx, state)
	if err != nil {
		return err
	}

	for _, k := range []string{fmt.Sprintf(stateKey, state), reqKey} {
		if err = s.cache.Del(ctx, k).Err(); err != nil {
			return err
		}
	}
	return nil
}
165
// RefreshSession loads the session for did, replaces its access token, refresh
// token and expiry, and saves it back via SaveSession. Errors from loading or
// saving are returned unchanged.
//
// NOTE(review): this is an unlocked read-modify-write; a concurrent
// UpdateNonce or RefreshSession for the same did can lose one of the updates.
func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error {
	current, err := s.GetSession(ctx, did)
	if err != nil {
		return err
	}

	updated := *current
	updated.AccessJwt, updated.RefreshJwt, updated.Expiry = access, refresh, expiry
	return s.SaveSession(ctx, updated)
}
176
// UpdateNonce stores nonce in the DpopAuthserverNonce field of the session for
// did and saves it back; DpopPdsNonce is left untouched. Errors from loading or
// saving the session are returned unchanged.
func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error {
	session, err := s.GetSession(ctx, did)
	if err == nil {
		session.DpopAuthserverNonce = nonce
		err = s.SaveSession(ctx, *session)
	}
	return err
}
185
// getRequestKeyFromState resolves the state->did mapping written by
// SaveRequest and returns the cache key under which that DID's request is
// stored. A lookup error is returned unchanged along with an empty key.
func (s *SessionStore) getRequestKeyFromState(ctx context.Context, state string) (string, error) {
	did, err := s.cache.Get(ctx, fmt.Sprintf(stateKey, state)).Result()
	if err != nil {
		return "", err
	}
	return fmt.Sprintf(requestKey, did), nil
}