Monorepo for Tangled tangled.org
1package oauth 2 3import ( 4 "fmt" 5 "log" 6 "net/http" 7 "time" 8 9 "github.com/gorilla/sessions" 10 oauth "github.com/haileyok/atproto-oauth-golang" 11 "github.com/haileyok/atproto-oauth-golang/helpers" 12 "tangled.sh/tangled.sh/core/appview" 13 "tangled.sh/tangled.sh/core/appview/db" 14 "tangled.sh/tangled.sh/core/appview/oauth/client" 15 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient" 16) 17 18type OAuthRequest struct { 19 ID uint 20 AuthserverIss string 21 State string 22 Did string 23 PdsUrl string 24 PkceVerifier string 25 DpopAuthserverNonce string 26 DpopPrivateJwk string 27} 28 29type OAuth struct { 30 Store *sessions.CookieStore 31 Db *db.DB 32 Config *appview.Config 33} 34 35func NewOAuth(db *db.DB, config *appview.Config) *OAuth { 36 return &OAuth{ 37 Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)), 38 Db: db, 39 Config: config, 40 } 41} 42 43func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error { 44 // first we save the did in the user session 45 userSession, err := o.Store.Get(r, appview.SessionName) 46 if err != nil { 47 return err 48 } 49 50 userSession.Values[appview.SessionDid] = oreq.Did 51 userSession.Values[appview.SessionHandle] = oreq.Handle 52 userSession.Values[appview.SessionPds] = oreq.PdsUrl 53 userSession.Values[appview.SessionAuthenticated] = true 54 err = userSession.Save(r, w) 55 if err != nil { 56 return fmt.Errorf("error saving user session: %w", err) 57 } 58 59 // then save the whole thing in the db 60 session := db.OAuthSession{ 61 Did: oreq.Did, 62 Handle: oreq.Handle, 63 PdsUrl: oreq.PdsUrl, 64 DpopAuthserverNonce: oreq.DpopAuthserverNonce, 65 AuthServerIss: oreq.AuthserverIss, 66 DpopPrivateJwk: oreq.DpopPrivateJwk, 67 AccessJwt: oresp.AccessToken, 68 RefreshJwt: oresp.RefreshToken, 69 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339), 70 } 71 72 return db.SaveOAuthSession(o.Db, session) 73} 74 75func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error { 76 userSession, err := o.Store.Get(r, appview.SessionName) 77 if err != nil || userSession.IsNew { 78 return fmt.Errorf("error getting user session (or new session?): %w", err) 79 } 80 81 did := userSession.Values[appview.SessionDid].(string) 82 83 err = db.DeleteOAuthSessionByDid(o.Db, did) 84 if err != nil { 85 return fmt.Errorf("error deleting oauth session: %w", err) 86 } 87 88 userSession.Options.MaxAge = -1 89 90 return userSession.Save(r, w) 91} 92 93func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) { 94 userSession, err := o.Store.Get(r, appview.SessionName) 95 if err != nil || userSession.IsNew { 96 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err) 97 } 98 99 did := userSession.Values[appview.SessionDid].(string) 100 auth := userSession.Values[appview.SessionAuthenticated].(bool) 101 102 session, err := db.GetOAuthSessionByDid(o.Db, did) 103 if err != nil { 104 return nil, false, fmt.Errorf("error getting oauth session: %w", err) 105 } 106 107 expiry, err := time.Parse(time.RFC3339, session.Expiry) 108 if err != nil { 109 return nil, false, fmt.Errorf("error parsing expiry time: %w", err) 110 } 111 if expiry.Sub(time.Now()) <= 5*time.Minute { 112 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 113 if err != nil { 114 return nil, false, err 115 } 116 oauthClient, err := client.NewClient(o.Config.OAuth.ServerMetadataUrl, 117 o.Config.OAuth.Jwks, 118 fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost)) 119 120 if err != nil { 121 return nil, false, err 122 } 123 124 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk) 125 if err != nil { 126 return nil, false, err 127 } 128 129 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) 130 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry) 131 if err != nil { 132 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err) 133 } 134 135 // update the current session 136 session.AccessJwt = resp.AccessToken 137 session.RefreshJwt = resp.RefreshToken 138 session.DpopAuthserverNonce = resp.DpopAuthserverNonce 139 session.Expiry = newExpiry 140 } 141 142 return session, auth, nil 143} 144 145type User struct { 146 Handle string 147 Did string 148 Pds string 149} 150 151func (a *OAuth) GetUser(r *http.Request) *User { 152 clientSession, err := a.Store.Get(r, appview.SessionName) 153 154 if err != nil || clientSession.IsNew { 155 return nil 156 } 157 158 return &User{ 159 Handle: clientSession.Values[appview.SessionHandle].(string), 160 Did: clientSession.Values[appview.SessionDid].(string), 161 Pds: clientSession.Values[appview.SessionPds].(string), 162 } 163} 164 165func (a *OAuth) GetDid(r *http.Request) string { 166 clientSession, err := a.Store.Get(r, appview.SessionName) 167 168 if err != nil || clientSession.IsNew { 169 return "" 170 } 171 172 return clientSession.Values[appview.SessionDid].(string) 173} 174 175func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) { 176 session, auth, err := o.GetSession(r) 177 if err != nil { 178 return nil, fmt.Errorf("error getting session: %w", err) 179 } 180 if !auth { 181 return nil, fmt.Errorf("not authorized") 182 } 183 184 client := &oauth.XrpcClient{ 185 OnDpopPdsNonceChanged: func(did, newNonce string) { 186 err := db.UpdateDpopPdsNonce(o.Db, did, newNonce) 187 if err != nil { 188 log.Printf("error updating dpop pds nonce: %v", err) 189 } 190 }, 191 } 192 193 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk)) 194 if err != nil { 195 return nil, fmt.Errorf("error parsing private jwk: %w", err) 196 } 197 198 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{ 199 Did: session.Did, 200 PdsUrl: session.PdsUrl, 201 DpopPdsNonce: session.PdsUrl, 202 AccessToken: session.AccessJwt, 203 Issuer: session.AuthServerIss, 204 DpopPrivateJwk: privateJwk, 205 }) 206 207 return xrpcClient, nil 208}