forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
Monorepo for Tangled
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "net/url"
8 "time"
9
10 "github.com/gorilla/sessions"
11 oauth "tangled.sh/icyphox.sh/atproto-oauth"
12 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
13 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
14 "tangled.sh/tangled.sh/core/appview/config"
15 "tangled.sh/tangled.sh/core/appview/oauth/client"
16 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
17)
18
19type OAuth struct {
20 store *sessions.CookieStore
21 config *config.Config
22 sess *sessioncache.SessionStore
23}
24
25func NewOAuth(config *config.Config, sess *sessioncache.SessionStore) *OAuth {
26 return &OAuth{
27 store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
28 config: config,
29 sess: sess,
30 }
31}
32
33func (o *OAuth) Stores() *sessions.CookieStore {
34 return o.store
35}
36
// SaveSession records a completed OAuth login.
//
// It first persists the full token session (DID, handle, PDS URL, issuer,
// DPoP key and auth-server nonce, access/refresh tokens, and an RFC 3339
// expiry) to the session cache. It then stores the DID, handle, PDS URL, and
// an authenticated flag in the browser cookie session.
//
// The cache write happens first. If it fails, the response has not yet been
// given a cookie that points at a missing server-side session.
//
// A cookie that cannot be decoded (for example, one signed with a different
// key) does not abort the login. gorilla/sessions still returns a fresh
// session alongside the decode error, and that session is overwritten here.
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq sessioncache.OAuthRequest, oresp *oauth.TokenResponse) error {
	userSession, err := o.store.Get(r, SessionName)
	if userSession == nil {
		return fmt.Errorf("error getting user session: %w", err)
	}

	session := sessioncache.OAuthSession{
		Did:                 oreq.Did,
		Handle:              oreq.Handle,
		PdsUrl:              oreq.PdsUrl,
		DpopAuthserverNonce: oreq.DpopAuthserverNonce,
		AuthServerIss:       oreq.AuthserverIss,
		DpopPrivateJwk:      oreq.DpopPrivateJwk,
		AccessJwt:           oresp.AccessToken,
		RefreshJwt:          oresp.RefreshToken,
		Expiry:              time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
	if err := o.sess.SaveSession(r.Context(), session); err != nil {
		return fmt.Errorf("error saving oauth session: %w", err)
	}

	userSession.Values[SessionDid] = oreq.Did
	userSession.Values[SessionHandle] = oreq.Handle
	userSession.Values[SessionPds] = oreq.PdsUrl
	userSession.Values[SessionAuthenticated] = true
	if err := userSession.Save(r, w); err != nil {
		return fmt.Errorf("error saving user session: %w", err)
	}

	return nil
}
68
// ClearSession logs the current user out.
//
// It deletes the server-side OAuth session for the DID recorded in the
// cookie, then expires the cookie by saving it with MaxAge -1.
//
// It returns an error in three cases:
//   - the cookie session cannot be read;
//   - there is no existing cookie session;
//   - deleting the server-side session fails.
//
// A cookie session that carries no DID is still expired, since there is
// nothing server-side to delete.
func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
	userSession, err := o.store.Get(r, SessionName)
	if err != nil {
		return fmt.Errorf("error getting user session: %w", err)
	}
	if userSession.IsNew {
		// Reported separately: wrapping a nil err with %w would print "%!w(<nil>)".
		return fmt.Errorf("error getting user session: no existing session")
	}

	// Checked assertion: an unchecked .(string) panics when the key is absent.
	if did, ok := userSession.Values[SessionDid].(string); ok {
		if err := o.sess.DeleteSession(r.Context(), did); err != nil {
			return fmt.Errorf("error deleting oauth session: %w", err)
		}
	}

	// A negative MaxAge makes gorilla/sessions emit a cookie that deletes itself.
	userSession.Options.MaxAge = -1

	return userSession.Save(r, w)
}
86
// GetSession loads the server-side OAuth session for the user named in the
// request's cookie session.
//
// It returns the cached session, the authenticated flag stored in the cookie
// (false when absent), and an error. If the access token expires within five
// minutes, or has already expired, it is refreshed first. The new tokens and
// expiry are written back via RefreshSession and also applied to the returned
// session.
//
// It returns an error in any of these cases:
//   - the cookie cannot be read;
//   - there is no existing cookie session;
//   - the cookie carries no DID;
//   - the cached session cannot be loaded;
//   - its expiry cannot be parsed;
//   - any step of the refresh fails.
func (o *OAuth) GetSession(r *http.Request) (*sessioncache.OAuthSession, bool, error) {
	userSession, err := o.store.Get(r, SessionName)
	if err != nil {
		return nil, false, fmt.Errorf("error getting user session: %w", err)
	}
	if userSession.IsNew {
		return nil, false, fmt.Errorf("error getting user session: no existing session")
	}

	// Checked assertions: a cookie missing these keys must not panic.
	did, ok := userSession.Values[SessionDid].(string)
	if !ok {
		return nil, false, fmt.Errorf("error getting user session: no did in session")
	}
	// A missing flag is treated as "not authenticated" (the bool zero value).
	auth, _ := userSession.Values[SessionAuthenticated].(bool)

	session, err := o.sess.GetSession(r.Context(), did)
	if err != nil {
		return nil, false, fmt.Errorf("error getting oauth session: %w", err)
	}

	expiry, err := time.Parse(time.RFC3339, session.Expiry)
	if err != nil {
		return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
	}

	// Refresh proactively so callers never receive a token about to lapse.
	if time.Until(expiry) <= 5*time.Minute {
		privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
		if err != nil {
			return nil, false, fmt.Errorf("error parsing private jwk: %w", err)
		}

		meta := o.ClientMetadata()
		oauthClient, err := client.NewClient(
			meta.ClientID,
			o.config.OAuth.Jwks,
			meta.RedirectURIs[0],
		)
		if err != nil {
			return nil, false, fmt.Errorf("error creating oauth client: %w", err)
		}

		resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
		if err != nil {
			return nil, false, fmt.Errorf("error requesting token refresh: %w", err)
		}

		newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
		if err := o.sess.RefreshSession(r.Context(), did, resp.AccessToken, resp.RefreshToken, newExpiry); err != nil {
			return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
		}

		// Apply the refreshed values to the session returned to the caller.
		// NOTE(review): RefreshSession is not given the new auth-server nonce,
		// so that nonce is only updated in memory here. Confirm whether the
		// store should persist it as well.
		session.AccessJwt = resp.AccessToken
		session.RefreshJwt = resp.RefreshToken
		session.DpopAuthserverNonce = resp.DpopAuthserverNonce
		session.Expiry = newExpiry
	}

	return session, auth, nil
}
143
// User is the lightweight identity read back from the cookie session.
type User struct {
	// Handle is the user's handle as stored at login.
	Handle string
	// Did is the user's DID.
	Did string
	// Pds is the URL of the user's PDS.
	Pds string
}
149
// GetUser returns the identity stored in the request's cookie session.
//
// It returns nil when the cookie cannot be read, no existing session is
// present, or the session carries no DID. A missing handle or PDS value is
// returned as an empty string.
func (o *OAuth) GetUser(r *http.Request) *User {
	clientSession, err := o.store.Get(r, SessionName)
	if err != nil || clientSession.IsNew {
		return nil
	}

	// Checked assertions: an unchecked .(string) panics on a missing key.
	did, ok := clientSession.Values[SessionDid].(string)
	if !ok {
		return nil
	}
	handle, _ := clientSession.Values[SessionHandle].(string)
	pds, _ := clientSession.Values[SessionPds].(string)

	return &User{
		Handle: handle,
		Did:    did,
		Pds:    pds,
	}
}
163
// GetDid returns the DID stored in the request's cookie session.
//
// It returns "" when the cookie cannot be read, no existing session is
// present, or the session carries no DID.
func (o *OAuth) GetDid(r *http.Request) string {
	clientSession, err := o.store.Get(r, SessionName)
	if err != nil || clientSession.IsNew {
		return ""
	}

	// Checked assertion: yields "" instead of panicking when the key is absent.
	did, _ := clientSession.Values[SessionDid].(string)
	return did
}
173
// AuthorizedClient returns an XRPC client that sends authenticated,
// DPoP-bound requests to the current user's PDS.
//
// The session comes from GetSession, so a nearly expired access token is
// refreshed first. It returns an error when the session cannot be loaded,
// the cookie is not marked authenticated, or the stored DPoP private key
// cannot be parsed.
//
// NOTE(review): the nonce callback captures r.Context(). If the returned
// client is used after the request finishes, UpdateNonce will run with a
// cancelled context. Confirm callers only use it within the request.
func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
	session, auth, err := o.GetSession(r)
	if err != nil {
		return nil, fmt.Errorf("error getting session: %w", err)
	}
	if !auth {
		return nil, fmt.Errorf("not authorized")
	}

	privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
	if err != nil {
		return nil, fmt.Errorf("error parsing private jwk: %w", err)
	}

	// Named baseClient (not client) so it does not shadow the imported
	// appview/oauth/client package.
	baseClient := &oauth.XrpcClient{
		// Forward PDS nonce changes to the session store.
		OnDpopPdsNonceChanged: func(did, newNonce string) {
			if err := o.sess.UpdateNonce(r.Context(), did, newNonce); err != nil {
				log.Printf("error updating dpop pds nonce: %v", err)
			}
		},
	}

	return xrpc.NewClient(baseClient, &oauth.XrpcAuthedRequestArgs{
		Did:    session.Did,
		PdsUrl: session.PdsUrl,
		// The DPoP nonce must be the server-issued PDS nonce, never the PDS URL.
		// Assumes OAuthSession stores it in DpopPdsNonce (kept current via
		// UpdateNonce). TODO confirm the field name in appview/cache/session.
		DpopPdsNonce:   session.DpopPdsNonce,
		AccessToken:    session.AccessJwt,
		Issuer:         session.AuthServerIss,
		DpopPrivateJwk: privateJwk,
	}), nil
}
208
// ClientMetadata is the OAuth client metadata document served for this
// appview. The JSON tags give the wire field names. Field order is kept
// as-is because it also determines the key order of the marshalled JSON.
type ClientMetadata struct {
	// Identity and presentation.
	ClientID    string `json:"client_id"`
	ClientName  string `json:"client_name"`
	SubjectType string `json:"subject_type"`
	ClientURI   string `json:"client_uri"`

	// Flow configuration.
	RedirectURIs          []string `json:"redirect_uris"`
	GrantTypes            []string `json:"grant_types"`
	ResponseTypes         []string `json:"response_types"`
	ApplicationType       string   `json:"application_type"`
	DpopBoundAccessTokens bool     `json:"dpop_bound_access_tokens"`

	// Keys, scope, and token-endpoint authentication.
	JwksURI                     string `json:"jwks_uri"`
	Scope                       string `json:"scope"`
	TokenEndpointAuthMethod     string `json:"token_endpoint_auth_method"`
	TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
}
224
225func (o *OAuth) ClientMetadata() ClientMetadata {
226 makeRedirectURIs := func(c string) []string {
227 return []string{fmt.Sprintf("%s/oauth/callback", c)}
228 }
229
230 clientURI := o.config.Core.AppviewHost
231 clientID := fmt.Sprintf("%s/oauth/client-metadata.json", clientURI)
232 redirectURIs := makeRedirectURIs(clientURI)
233
234 if o.config.Core.Dev {
235 clientURI = fmt.Sprintf("http://127.0.0.1:3000")
236 redirectURIs = makeRedirectURIs(clientURI)
237
238 query := url.Values{}
239 query.Add("redirect_uri", redirectURIs[0])
240 query.Add("scope", "atproto transition:generic")
241 clientID = fmt.Sprintf("http://localhost?%s", query.Encode())
242 }
243
244 jwksURI := fmt.Sprintf("%s/oauth/jwks.json", clientURI)
245
246 return ClientMetadata{
247 ClientID: clientID,
248 ClientName: "Tangled",
249 SubjectType: "public",
250 ClientURI: clientURI,
251 RedirectURIs: redirectURIs,
252 GrantTypes: []string{"authorization_code", "refresh_token"},
253 ResponseTypes: []string{"code"},
254 ApplicationType: "web",
255 DpopBoundAccessTokens: true,
256 JwksURI: jwksURI,
257 Scope: "atproto transition:generic",
258 TokenEndpointAuthMethod: "private_key_jwt",
259 TokenEndpointAuthSigningAlg: "ES256",
260 }
261}