appview: cache/session: init redis session store #210

closed
opened by anirudh.fi targeting master from push-ruoqnsmttnxx

And a high-level cache package for future use.

Signed-off-by: Anirudh Oppiliappan anirudh@tangled.sh

Changed files
+223 -42
appview
cache
session
oauth
state
+14
appview/cache/cache.go
··· 1 + package cache 2 + 3 + import "github.com/redis/go-redis/v9" 4 + 5 + type Cache struct { 6 + *redis.Client 7 + } 8 + 9 + func New(addr string) *Cache { 10 + rdb := redis.NewClient(&redis.Options{ 11 + Addr: addr, 12 + }) 13 + return &Cache{rdb} 14 + }
+163
appview/cache/session/store.go
··· 1 + package session 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "time" 8 + 9 + "tangled.sh/tangled.sh/core/appview/cache" 10 + ) 11 + 12 + type OAuthSession struct { 13 + Handle string 14 + Did string 15 + PdsUrl string 16 + AccessJwt string 17 + RefreshJwt string 18 + AuthServerIss string 19 + DpopPdsNonce string 20 + DpopAuthserverNonce string 21 + DpopPrivateJwk string 22 + Expiry string 23 + } 24 + 25 + type OAuthRequest struct { 26 + AuthserverIss string 27 + Handle string 28 + State string 29 + Did string 30 + PdsUrl string 31 + PkceVerifier string 32 + DpopAuthserverNonce string 33 + DpopPrivateJwk string 34 + } 35 + 36 + type SessionStore struct { 37 + cache *cache.Cache 38 + } 39 + 40 + func New(cache *cache.Cache) *SessionStore { 41 + return &SessionStore{cache: cache} 42 + } 43 + 44 + func (s *SessionStore) SaveSession(ctx context.Context, session OAuthSession) error { 45 + key := fmt.Sprintf("oauthsession:%s", session.Did) 46 + data, err := json.Marshal(session) 47 + if err != nil { 48 + return err 49 + } 50 + 51 + // Set with TTL = expires in + buffer 52 + expiry, _ := time.Parse(time.RFC3339, session.Expiry) 53 + ttl := time.Until(expiry) + time.Minute 54 + 55 + return s.cache.Set(ctx, key, data, ttl).Err() 56 + } 57 + 58 + // SaveRequest stores the OAuth request to be later fetched in the callback. Since 59 + // the fetching happens by comparing the state we get in the callback params, we 60 + // store an additional state->did mapping which then lets us fetch the whole OAuth request. 61 + func (s *SessionStore) SaveRequest(ctx context.Context, request OAuthRequest) error { 62 + key := fmt.Sprintf("oauthrequest:%s", request.Did) 63 + data, err := json.Marshal(request) 64 + if err != nil { 65 + return err 66 + } 67 + 68 + // oauth flow must complete within 30 minutes 69 + err = s.cache.Set(ctx, key, data, 30*time.Minute).Err() 70 + if err != nil { 71 + return fmt.Errorf("error saving request: %w", err) 72 + } 73 + 74 + stateKey := fmt.Sprintf("oauthstate:%s", request.State) 75 + err = s.cache.Set(ctx, stateKey, request.Did, 30*time.Minute).Err() 76 + if err != nil { 77 + return fmt.Errorf("error saving state->did mapping: %w", err) 78 + } 79 + 80 + return nil 81 + } 82 + 83 + func (s *SessionStore) GetSession(ctx context.Context, did string) (*OAuthSession, error) { 84 + key := fmt.Sprintf("oauthsession:%s", did) 85 + val, err := s.cache.Get(ctx, key).Result() 86 + if err != nil { 87 + return nil, err 88 + } 89 + 90 + var session OAuthSession 91 + err = json.Unmarshal([]byte(val), &session) 92 + if err != nil { 93 + return nil, err 94 + } 95 + return &session, nil 96 + } 97 + 98 + func (s *SessionStore) GetRequestByState(ctx context.Context, state string) (*OAuthRequest, error) { 99 + didKey, err := s.getRequestKey(ctx, state) 100 + if err != nil { 101 + return nil, err 102 + } 103 + 104 + val, err := s.cache.Get(ctx, didKey).Result() 105 + if err != nil { 106 + return nil, err 107 + } 108 + 109 + var request OAuthRequest 110 + err = json.Unmarshal([]byte(val), &request) 111 + if err != nil { 112 + return nil, err 113 + } 114 + 115 + return &request, nil 116 + } 117 + 118 + func (s *SessionStore) DeleteSession(ctx context.Context, did string) error { 119 + key := fmt.Sprintf("oauthsession:%s", did) 120 + return s.cache.Del(ctx, key).Err() 121 + } 122 + 123 + func (s *SessionStore) DeleteRequestByState(ctx context.Context, state string) error { 124 + key := fmt.Sprintf("oauthstate:%s", state) 125 + did, err := s.cache.Get(ctx, key).Result() 126 + if err != nil { 127 + return err 128 + } 129 + 130 + didKey := fmt.Sprintf("oauthrequest:%s", did) 131 + return s.cache.Del(ctx, didKey).Err() 132 + } 133 + 134 + func (s *SessionStore) RefreshSession(ctx context.Context, did, access, refresh, expiry string) error { 135 + session, err := s.GetSession(ctx, did) 136 + if err != nil { 137 + return err 138 + } 139 + session.AccessJwt = access 140 + session.RefreshJwt = refresh 141 + session.Expiry = expiry 142 + return s.SaveSession(ctx, *session) 143 + } 144 + 145 + func (s *SessionStore) UpdateNonce(ctx context.Context, did, nonce string) error { 146 + session, err := s.GetSession(ctx, did) 147 + if err != nil { 148 + return err 149 + } 150 + session.DpopAuthserverNonce = nonce 151 + return s.SaveSession(ctx, *session) 152 + } 153 + 154 + func (s *SessionStore) getRequestKey(ctx context.Context, state string) (string, error) { 155 + key := fmt.Sprintf("oauthstate:%s", state) 156 + did, err := s.cache.Get(ctx, key).Result() 157 + if err != nil { 158 + return "", err 159 + } 160 + 161 + didKey := fmt.Sprintf("oauthrequest:%s", did) 162 + return didKey, nil 163 + }
+8 -4
appview/oauth/handler/handler.go
··· 13 13 "github.com/lestrrat-go/jwx/v2/jwk" 14 14 "github.com/posthog/posthog-go" 15 15 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 16 + sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 16 17 "tangled.sh/tangled.sh/core/appview/config" 17 18 "tangled.sh/tangled.sh/core/appview/db" 18 19 "tangled.sh/tangled.sh/core/appview/idresolver" ··· 32 33 config *config.Config 33 34 pages *pages.Pages 34 35 idResolver *idresolver.Resolver 36 + sess *sessioncache.SessionStore 35 37 db *db.DB 36 38 store *sessions.CookieStore 37 39 oauth *oauth.OAuth ··· 44 46 pages *pages.Pages, 45 47 idResolver *idresolver.Resolver, 46 48 db *db.DB, 49 + sess *sessioncache.SessionStore, 47 50 store *sessions.CookieStore, 48 51 oauth *oauth.OAuth, 49 52 enforcer *rbac.Enforcer, ··· 54 57 pages: pages, 55 58 idResolver: idResolver, 56 59 db: db, 60 + sess: sess, 57 61 store: store, 58 62 oauth: oauth, 59 63 enforcer: enforcer, ··· 158 162 return 159 163 } 160 164 161 - err = db.SaveOAuthRequest(o.db, db.OAuthRequest{ 165 + err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{ 162 166 Did: resolved.DID.String(), 163 167 PdsUrl: resolved.PDSEndpoint(), 164 168 Handle: handle, ··· 186 190 func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) { 187 191 state := r.FormValue("state") 188 192 189 - oauthRequest, err := db.GetOAuthRequestByState(o.db, state) 193 + oauthRequest, err := o.sess.GetRequestByState(r.Context(), state) 190 194 if err != nil { 191 195 log.Println("failed to get oauth request:", err) 192 196 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.") ··· 194 198 } 195 199 196 200 defer func() { 197 - err := db.DeleteOAuthRequestByState(o.db, state) 201 + err := o.sess.DeleteRequestByState(r.Context(), state) 198 202 if err != nil { 199 203 log.Println("failed to delete oauth request for state:", state, err) 200 204 } ··· 263 267 return 264 268 } 265 269 266 - err = o.oauth.SaveSession(w, r, oauthRequest, tokenResp) 270 + err = o.oauth.SaveSession(w, r, *oauthRequest, tokenResp) 267 271 if err != nil { 268 272 log.Println("failed to save session:", err) 269 273 o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
+28 -35
appview/oauth/oauth.go
··· 10 10 "github.com/gorilla/sessions" 11 11 oauth "tangled.sh/icyphox.sh/atproto-oauth" 12 12 "tangled.sh/icyphox.sh/atproto-oauth/helpers" 13 + sessioncache "tangled.sh/tangled.sh/core/appview/cache/session" 13 14 "tangled.sh/tangled.sh/core/appview/config" 14 - "tangled.sh/tangled.sh/core/appview/db" 15 15 "tangled.sh/tangled.sh/core/appview/oauth/client" 16 16 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" 17 17 ) 18 18 19 - type OAuthRequest struct { 20 - ID uint 21 - AuthserverIss string 22 - State string 23 - Did string 24 - PdsUrl string 25 - PkceVerifier string 26 - DpopAuthserverNonce string 27 - DpopPrivateJwk string 28 - } 29 - 30 19 type OAuth struct { 31 - Store *sessions.CookieStore 32 - Db *db.DB 33 - Config *config.Config 20 + store *sessions.CookieStore 21 + config *config.Config 22 + sess *sessioncache.SessionStore 34 23 } 35 24 36 - func NewOAuth(db *db.DB, config *config.Config) *OAuth { 25 + func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth { 37 26 return &OAuth{ 38 - Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 39 - Db: db, 40 - Config: config, 27 + store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 28 + config: config, 29 + sess: sess, 41 30 } 42 31 } 43 32 44 - func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error { 33 + func (o *OAuth) Stores() *sessions.CookieStore { 34 + return o.store 35 + } 36 + 37 + func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error { 45 38 // first we save the did in the user session 46 - userSession, err := o.Store.Get(r, SessionName) 39 + userSession, err := o.store.Get(r, SessionName) 47 40 if err != nil { 48 41 return err 49 42 } ··· 58 51 } 59 52 60 53 // then save the whole thing in the db 61 - session := db.OAuthSession{ 54 + session := sessioncache.OAuthSession{ 62 55 Did: oreq.Did, 63 56 Handle: oreq.Handle, 64 57 PdsUrl: oreq.PdsUrl, ··· 70 63 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 71 64 } 72 65 73 - return db.SaveOAuthSession(o.Db, session) 66 + return o.sess.SaveSession(r.Context(), session) 74 67 } 75 68 76 69 func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 77 - userSession, err := o.Store.Get(r, SessionName) 70 + userSession, err := o.store.Get(r, SessionName) 78 71 if err != nil || userSession.IsNew { 79 72 return fmt.Errorf("error getting user session (or new session?): %w", err) 80 73 } 81 74 82 75 did := userSession.Values[SessionDid].(string) 83 76 84 - err = db.DeleteOAuthSessionByDid(o.Db, did) 77 + err = o.sess.DeleteSession(r.Context(), did) 85 78 if err != nil { 86 79 return fmt.Errorf("error deleting oauth session: %w", err) 87 80 } ··· 91 84 return userSession.Save(r, w) 92 85 } 93 86 94 - func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { 95 - userSession, err := o.Store.Get(r, SessionName) 87 + func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) { 88 + userSession, err := o.store.Get(r, SessionName) 96 89 if err != nil || userSession.IsNew { 97 90 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 98 91 } ··· 100 93 did := userSession.Values[SessionDid].(string) 101 94 auth := userSession.Values[SessionAuthenticated].(bool) 102 95 103 - session, err := db.GetOAuthSessionByDid(o.Db, did) 96 + session, err := o.sess.GetSession(r.Context(), did) 104 97 if err != nil { 105 98 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 106 99 } ··· 119 112 120 113 oauthClient, err := client.NewClient( 121 114 self.ClientID, 122 - o.Config.OAuth.Jwks, 115 + o.config.OAuth.Jwks, 123 116 self.RedirectURIs[0], 124 117 ) 125 118 ··· 133 126 } 134 127 135 128 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 136 - err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry) 129 + err = o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry) 137 130 if err != nil { 138 131 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 139 132 } ··· 155 148 } 156 149 157 150 func (a *OAuth) GetUser(r *http.Request) *User { 158 - clientSession, err := a.Store.Get(r, SessionName) 151 + clientSession, err := a.store.Get(r, SessionName) 159 152 160 153 if err != nil || clientSession.IsNew { 161 154 return nil ··· 169 162 } 170 163 171 164 func (a *OAuth) GetDid(r *http.Request) string { 172 - clientSession, err := a.Store.Get(r, SessionName) 165 + clientSession, err := a.store.Get(r, SessionName) 173 166 174 167 if err != nil || clientSession.IsNew { 175 168 return "" ··· 189 182 190 183 client := &oauth.XrpcClient{ 191 184 OnDpopPdsNonceChanged: func(did, newNonce string) { 192 - err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) 185 + err := o.sess.UpdateNonce(r.Context(), did, newNonce) 193 186 if err != nil { 194 187 log.Printf("error updating dpop pds nonce: %v", err) 195 188 } ··· 234 227 return []string{fmt.Sprintf("%s/oauth/callback", c)} 235 228 } 236 229 237 - clientURI := o.Config.Core.AppviewHost 230 + clientURI := o.config.Core.AppviewHost 238 231 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI) 239 232 redirectURIs := makeRedirectURIs(clientURI) 240 233 241 - if o.Config.Core.Dev { 234 + if o.config.Core.Dev { 242 235 clientURI = fmt.Sprintf("http://127.0.0.1:3000") 243 236 redirectURIs = makeRedirectURIs(clientURI) 244 237
+1 -1
appview/state/router.go
··· 156 156 157 157 func (s *State) OAuthRouter() http.Handler { 158 158 store := sessions.NewCookieStore([]byte(s.config.Core.CookieSecret)) 159 - oauth := oauthhandler.New(s.config, s.pages, s.idResolver, s.db, store, s.oauth, s.enforcer, s.posthog) 159 + oauth := oauthhandler.New(s.config, s.pages, s.idResolver, s.db, s.sess, store, s.oauth, s.enforcer, s.posthog) 160 160 return oauth.Router() 161 161 } 162 162
+9 -2
appview/state/state.go
··· 20 20 "github.com/posthog/posthog-go" 21 21 "tangled.sh/tangled.sh/core/api/tangled" 22 22 "tangled.sh/tangled.sh/core/appview" 23 + "tangled.sh/tangled.sh/core/appview/cache" 24 + "tangled.sh/tangled.sh/core/appview/cache/session" 23 25 "tangled.sh/tangled.sh/core/appview/config" 24 26 "tangled.sh/tangled.sh/core/appview/db" 25 27 "tangled.sh/tangled.sh/core/appview/idresolver" ··· 37 39 enforcer *rbac.Enforcer 38 40 tidClock syntax.TIDClock 39 41 pages *pages.Pages 42 + sess *session.SessionStore 40 43 idResolver *idresolver.Resolver 41 44 posthog posthog.Client 42 45 jc *jetstream.JetstreamClient ··· 65 68 res = idresolver.DefaultResolver() 66 69 } 67 70 68 - oauth := oauth.NewOAuth(d, config) 71 + cache := cache.New(config.Redis.Addr) 72 + sess := session.New(cache) 73 + 74 + oauth := oauth.NewOAuth(config, sess) 69 75 70 76 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint}) 71 77 if err != nil { ··· 104 110 enforcer, 105 111 clock, 106 112 pgs, 113 + sess, 107 114 res, 108 115 posthog, 109 116 jc, ··· 176 183 177 184 return 178 185 case http.MethodPost: 179 - session, err := s.oauth.Store.Get(r, oauth.SessionName) 186 + session, err := s.oauth.Stores().Get(r, oauth.SessionName) 180 187 if err != nil || session.IsNew { 181 188 log.Println("unauthorized attempt to generate registration key") 182 189 http.Error(w, "Forbidden", http.StatusUnauthorized)