Monorepo for Tangled
tangled.org
1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "time"
8
9 "github.com/gorilla/sessions"
10 oauth "github.com/haileyok/atproto-oauth-golang"
11 "github.com/haileyok/atproto-oauth-golang/helpers"
12 "tangled.sh/tangled.sh/core/appview"
13 "tangled.sh/tangled.sh/core/appview/db"
14 "tangled.sh/tangled.sh/core/appview/oauth/client"
15 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
16)
17
// OAuthRequest groups the fields of a pending OAuth authorization request:
// an ID, the authorization server issuer, the state value, the account's DID
// and PDS URL, the PKCE verifier, the DPoP auth-server nonce, and the DPoP
// private key kept as a string (presumably a serialized JWK — TODO confirm).
//
// NOTE(review): SaveSession in this file takes db.OAuthRequest rather than
// this type, and this struct has no Handle field; confirm it is still used.
type OAuthRequest struct {
	ID uint

	AuthserverIss, State, Did, PdsUrl string

	PkceVerifier, DpopAuthserverNonce, DpopPrivateJwk string
}
28
29type OAuth struct {
30 Store *sessions.CookieStore
31 Db *db.DB
32 Config *appview.Config
33}
34
// NewOAuth returns an OAuth helper backed by the given database and config.
//
// The cookie store is keyed with config.Core.CookieSecret. Its cookies are
// marked HttpOnly, because page scripts never need the encoded session value
// and this keeps it out of reach of injected script. They are also marked
// SameSite=Lax: browsers send them on top-level navigations (such as the OAuth
// callback redirect) but withhold them from cross-site subrequests and POSTs.
// Secure is left unset because AppviewHost may be plain http in development.
func NewOAuth(db *db.DB, config *appview.Config) *OAuth {
	store := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
	store.Options.HttpOnly = true
	store.Options.SameSite = http.SameSiteLaxMode

	return &OAuth{
		Store:  store,
		Db:     db,
		Config: config,
	}
}
42
// SaveSession records a completed OAuth login.
//
// The server-side session (tokens, DPoP key and nonce, issuer, expiry) is
// written to the database first. Only after that succeeds is the user's
// cookie session stamped with their DID, handle and PDS URL and marked
// authenticated. This ordering means a database failure can never leave the
// browser holding an "authenticated" cookie that GetSession cannot resolve.
//
// oreq is the OAuth request this login originated from; oresp is the token
// response. An error is returned if oresp is nil, if the database write fails,
// or if the cookie cannot be obtained or saved.
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error {
	if oresp == nil {
		return fmt.Errorf("error saving oauth session: nil token response")
	}

	// Prefer the nonce carried by the token response, as GetSession does after
	// a refresh. Fall back to the one captured with the request if it is empty.
	authserverNonce := oresp.DpopAuthserverNonce
	if authserverNonce == "" {
		authserverNonce = oreq.DpopAuthserverNonce
	}

	session := db.OAuthSession{
		Did:                 oreq.Did,
		Handle:              oreq.Handle,
		PdsUrl:              oreq.PdsUrl,
		DpopAuthserverNonce: authserverNonce,
		AuthServerIss:       oreq.AuthserverIss,
		DpopPrivateJwk:      oreq.DpopPrivateJwk,
		AccessJwt:           oresp.AccessToken,
		RefreshJwt:          oresp.RefreshToken,
		Expiry:              time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
	if err := db.SaveOAuthSession(o.Db, session); err != nil {
		return fmt.Errorf("error saving oauth session: %w", err)
	}

	// When the existing cookie cannot be decoded (e.g. it was signed with a
	// rotated secret), gorilla's Store.Get returns a fresh session together
	// with the error. A login should just overwrite such a cookie, so only give
	// up when no session was returned at all.
	userSession, err := o.Store.Get(r, appview.SessionName)
	if userSession == nil {
		return fmt.Errorf("error getting user session: %w", err)
	}
	if err != nil {
		log.Printf("discarding unreadable user session cookie: %v", err)
	}

	userSession.Values[appview.SessionDid] = oreq.Did
	userSession.Values[appview.SessionHandle] = oreq.Handle
	userSession.Values[appview.SessionPds] = oreq.PdsUrl
	userSession.Values[appview.SessionAuthenticated] = true
	if err := userSession.Save(r, w); err != nil {
		return fmt.Errorf("error saving user session: %w", err)
	}
	return nil
}
74
// ClearSession logs the current user out. It deletes the server-side OAuth
// session for the DID stored in the cookie session, then expires the cookie
// (MaxAge = -1) and writes it back to w.
//
// It fails if the cookie session cannot be read or does not already exist.
// It also fails if the database delete fails; in that case the cookie is left
// untouched. A cookie that carries no DID is still expired, with no database
// delete attempted.
func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
	userSession, err := o.Store.Get(r, appview.SessionName)
	if err != nil {
		return fmt.Errorf("error getting user session: %w", err)
	}
	// Kept separate from the err check so we never wrap a nil error
	// (which would render as "%!w(<nil>)").
	if userSession.IsNew {
		return fmt.Errorf("error getting user session: no existing session")
	}

	// Comma-ok instead of a bare assertion, so a cookie without a DID cannot
	// panic the handler.
	if did, ok := userSession.Values[appview.SessionDid].(string); ok && did != "" {
		if err := db.DeleteOAuthSessionByDid(o.Db, did); err != nil {
			return fmt.Errorf("error deleting oauth session: %w", err)
		}
	}

	userSession.Options.MaxAge = -1
	return userSession.Save(r, w)
}
92
93func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) {
94 userSession, err := o.Store.Get(r, appview.SessionName)
95 if err != nil || userSession.IsNew {
96 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
97 }
98
99 did := userSession.Values[appview.SessionDid].(string)
100 auth := userSession.Values[appview.SessionAuthenticated].(bool)
101
102 session, err := db.GetOAuthSessionByDid(o.Db, did)
103 if err != nil {
104 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
105 }
106
107 expiry, err := time.Parse(time.RFC3339, session.Expiry)
108 if err != nil {
109 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
110 }
111 if expiry.Sub(time.Now()) <= 5*time.Minute {
112 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
113 if err != nil {
114 return nil, false, err
115 }
116 oauthClient, err := client.NewClient(o.Config.OAuth.ServerMetadataUrl,
117 o.Config.OAuth.Jwks,
118 fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost))
119
120 if err != nil {
121 return nil, false, err
122 }
123
124 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
125 if err != nil {
126 return nil, false, err
127 }
128
129 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
130 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry)
131 if err != nil {
132 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
133 }
134
135 // update the current session
136 session.AccessJwt = resp.AccessToken
137 session.RefreshJwt = resp.RefreshToken
138 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
139 session.Expiry = newExpiry
140 }
141
142 return session, auth, nil
143}
144
// User is the identity read back from the browser's cookie session: the
// account handle, its DID, and its PDS URL.
type User struct {
	Handle, Did, Pds string
}
150
// GetUser returns the identity recorded in the request's cookie session.
//
// It returns nil when the cookie session is unreadable, new, or carries no
// DID. A missing handle or PDS URL becomes "". The values come from the cookie
// only; they are not re-checked against the database.
func (o *OAuth) GetUser(r *http.Request) *User {
	clientSession, err := o.Store.Get(r, appview.SessionName)
	if err != nil || clientSession.IsNew {
		return nil
	}

	// Comma-ok assertions: a cookie missing a key, or holding an unexpected
	// type, must not panic the request.
	did, ok := clientSession.Values[appview.SessionDid].(string)
	if !ok || did == "" {
		return nil
	}
	handle, _ := clientSession.Values[appview.SessionHandle].(string)
	pds, _ := clientSession.Values[appview.SessionPds].(string)

	return &User{
		Handle: handle,
		Did:    did,
		Pds:    pds,
	}
}
164
// GetDid returns the DID stored in the request's cookie session. It returns ""
// when the cookie session is unreadable or new, or when the DID value is
// missing or not a string (instead of panicking on a bare type assertion).
func (o *OAuth) GetDid(r *http.Request) string {
	clientSession, err := o.Store.Get(r, appview.SessionName)
	if err != nil || clientSession.IsNew {
		return ""
	}

	did, _ := clientSession.Values[appview.SessionDid].(string)
	return did
}
174
// AuthorizedClient builds an XRPC client that makes authenticated requests to
// the current user's PDS, using the session's access token and DPoP key.
//
// The session is obtained via GetSession, which may refresh it. The call fails
// when the cookie session is not marked authenticated. Whenever the client
// reports a new PDS DPoP nonce, it is persisted with db.UpdateDpopPdsNonce;
// persistence failures are only logged.
func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
	session, auth, err := o.GetSession(r)
	if err != nil {
		return nil, fmt.Errorf("error getting session: %w", err)
	}
	if !auth {
		return nil, fmt.Errorf("not authorized")
	}

	privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
	if err != nil {
		return nil, fmt.Errorf("error parsing private jwk: %w", err)
	}

	// Not named "client": that would shadow the imported client package.
	baseClient := &oauth.XrpcClient{
		OnDpopPdsNonceChanged: func(did, newNonce string) {
			if err := db.UpdateDpopPdsNonce(o.Db, did, newNonce); err != nil {
				log.Printf("error updating dpop pds nonce: %v", err)
			}
		},
	}

	return xrpc.NewClient(baseClient, &oauth.XrpcAuthedRequestArgs{
		Did:    session.Did,
		PdsUrl: session.PdsUrl,
		// The stored PDS nonce (the one UpdateDpopPdsNonce writes), not the
		// PDS URL.
		DpopPdsNonce:   session.DpopPdsNonce,
		AccessToken:    session.AccessJwt,
		Issuer:         session.AuthServerIss,
		DpopPrivateJwk: privateJwk,
	}), nil
}