1package oauth
2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"slices"
	"time"

	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	"github.com/go-chi/chi/v5"
	"github.com/posthog/posthog-go"
	"tangled.org/core/api/tangled"
	"tangled.org/core/appview/db"
	"tangled.org/core/consts"
	"tangled.org/core/orm"
	"tangled.org/core/tid"
)
22
23func (o *OAuth) Router() http.Handler {
24 r := chi.NewRouter()
25
26 r.Get("/oauth/client-metadata.json", o.clientMetadata)
27 r.Get("/oauth/jwks.json", o.jwks)
28 r.Get("/oauth/callback", o.callback)
29 return r
30}
31
32func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) {
33 doc := o.ClientApp.Config.ClientMetadata()
34 doc.JWKSURI = &o.JwksUri
35 doc.ClientName = &o.ClientName
36 doc.ClientURI = &o.ClientUri
37
38 w.Header().Set("Content-Type", "application/json")
39 if err := json.NewEncoder(w).Encode(doc); err != nil {
40 http.Error(w, err.Error(), http.StatusInternalServerError)
41 return
42 }
43}
44
45func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) {
46 w.Header().Set("Content-Type", "application/json")
47 body := o.ClientApp.Config.PublicJWKS()
48 if err := json.NewEncoder(w).Encode(body); err != nil {
49 http.Error(w, err.Error(), http.StatusInternalServerError)
50 return
51 }
52}
53
54func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) {
55 ctx := r.Context()
56 l := o.Logger.With("query", r.URL.Query())
57
58 authReturn := o.GetAuthReturn(r)
59 _ = o.ClearAuthReturn(w, r)
60
61 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
62 if err != nil {
63 var callbackErr *oauth.AuthRequestCallbackError
64 if errors.As(err, &callbackErr) {
65 l.Debug("callback error", "err", callbackErr)
66 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound)
67 return
68 }
69 l.Error("failed to process callback", "err", err)
70 http.Redirect(w, r, "/login?error=oauth", http.StatusFound)
71 return
72 }
73
74 if err := o.SaveSession(w, r, sessData); err != nil {
75 l.Error("failed to save session", "data", sessData, "err", err)
76 errorCode := "session"
77 if errors.Is(err, ErrMaxAccountsReached) {
78 errorCode = "max_accounts"
79 }
80 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound)
81 return
82 }
83
84 o.Logger.Debug("session saved successfully")
85 go o.addToDefaultKnot(sessData.AccountDID.String())
86 go o.addToDefaultSpindle(sessData.AccountDID.String())
87
88 if !o.Config.Core.Dev {
89 err = o.Posthog.Enqueue(posthog.Capture{
90 DistinctId: sessData.AccountDID.String(),
91 Event: "signin",
92 })
93 if err != nil {
94 o.Logger.Error("failed to enqueue posthog event", "err", err)
95 }
96 }
97
98 redirectURL := "/"
99 if authReturn.ReturnURL != "" {
100 redirectURL = authReturn.ReturnURL
101 }
102
103 http.Redirect(w, r, redirectURL, http.StatusFound)
104}
105
106func (o *OAuth) addToDefaultSpindle(did string) {
107 l := o.Logger.With("subject", did)
108
109 // use the tangled.sh app password to get an accessJwt
110 // and create an sh.tangled.spindle.member record with that
111 spindleMembers, err := db.GetSpindleMembers(
112 o.Db,
113 orm.FilterEq("instance", "spindle.tangled.sh"),
114 orm.FilterEq("subject", did),
115 )
116 if err != nil {
117 l.Error("failed to get spindle members", "err", err)
118 return
119 }
120
121 if len(spindleMembers) != 0 {
122 l.Warn("already a member of the default spindle")
123 return
124 }
125
126 l.Debug("adding to default spindle")
127 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid)
128 if err != nil {
129 l.Error("failed to create session", "err", err)
130 return
131 }
132
133 record := tangled.SpindleMember{
134 LexiconTypeID: "sh.tangled.spindle.member",
135 Subject: did,
136 Instance: consts.DefaultSpindle,
137 CreatedAt: time.Now().Format(time.RFC3339),
138 }
139
140 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
141 l.Error("failed to add to default spindle", "err", err)
142 return
143 }
144
145 l.Debug("successfully added to default spindle", "did", did)
146}
147
148func (o *OAuth) addToDefaultKnot(did string) {
149 l := o.Logger.With("subject", did)
150
151 // use the tangled.sh app password to get an accessJwt
152 // and create an sh.tangled.spindle.member record with that
153
154 allKnots, err := o.Enforcer.GetKnotsForUser(did)
155 if err != nil {
156 l.Error("failed to get knot members for did", "err", err)
157 return
158 }
159
160 if slices.Contains(allKnots, consts.DefaultKnot) {
161 l.Warn("already a member of the default knot")
162 return
163 }
164
165 l.Debug("addings to default knot")
166 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid)
167 if err != nil {
168 l.Error("failed to create session", "err", err)
169 return
170 }
171
172 record := tangled.KnotMember{
173 LexiconTypeID: "sh.tangled.knot.member",
174 Subject: did,
175 Domain: consts.DefaultKnot,
176 CreatedAt: time.Now().Format(time.RFC3339),
177 }
178
179 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
180 l.Error("failed to add to default knot", "err", err)
181 return
182 }
183
184 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
185 l.Error("failed to set up enforcer rules", "err", err)
186 return
187 }
188
189 l.Debug("successfully addeds to default Knot")
190}
191
// session is an app-password login on a PDS. It holds the access JWT returned
// by com.atproto.server.createSession, together with the PDS endpoint and DID
// the login was made for; createAppPasswordSession sets those last two after
// decoding the response.
type session struct {
	AccessJwt string `json:"accessJwt"`

	PdsEndpoint, Did string
}
198
199func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) {
200 if appPassword == "" {
201 return nil, fmt.Errorf("no app password configured, skipping member addition")
202 }
203
204 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did)
205 if err != nil {
206 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
207 }
208
209 pdsEndpoint := resolved.PDSEndpoint()
210 if pdsEndpoint == "" {
211 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
212 }
213
214 sessionPayload := map[string]string{
215 "identifier": did,
216 "password": appPassword,
217 }
218 sessionBytes, err := json.Marshal(sessionPayload)
219 if err != nil {
220 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
221 }
222
223 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
224 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
225 if err != nil {
226 return nil, fmt.Errorf("failed to create session request: %v", err)
227 }
228 sessionReq.Header.Set("Content-Type", "application/json")
229
230 client := &http.Client{Timeout: 30 * time.Second}
231 sessionResp, err := client.Do(sessionReq)
232 if err != nil {
233 return nil, fmt.Errorf("failed to create session: %v", err)
234 }
235 defer sessionResp.Body.Close()
236
237 if sessionResp.StatusCode != http.StatusOK {
238 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
239 }
240
241 var session session
242 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
243 return nil, fmt.Errorf("failed to decode session response: %v", err)
244 }
245
246 session.PdsEndpoint = pdsEndpoint
247 session.Did = did
248
249 return &session, nil
250}
251
252func (s *session) putRecord(record any, collection string) error {
253 recordBytes, err := json.Marshal(record)
254 if err != nil {
255 return fmt.Errorf("failed to marshal knot member record: %w", err)
256 }
257
258 payload := map[string]any{
259 "repo": s.Did,
260 "collection": collection,
261 "rkey": tid.TID(),
262 "record": json.RawMessage(recordBytes),
263 }
264
265 payloadBytes, err := json.Marshal(payload)
266 if err != nil {
267 return fmt.Errorf("failed to marshal request payload: %w", err)
268 }
269
270 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
271 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
272 if err != nil {
273 return fmt.Errorf("failed to create HTTP request: %w", err)
274 }
275
276 req.Header.Set("Content-Type", "application/json")
277 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
278
279 client := &http.Client{Timeout: 30 * time.Second}
280 resp, err := client.Do(req)
281 if err != nil {
282 return fmt.Errorf("failed to add user to default service: %w", err)
283 }
284 defer resp.Body.Close()
285
286 if resp.StatusCode != http.StatusOK {
287 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode)
288 }
289
290 return nil
291}