1package oauth
2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"slices"
	"time"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	lexutil "github.com/bluesky-social/indigo/lex/util"
	"github.com/go-chi/chi/v5"
	"github.com/posthog/posthog-go"
	"tangled.org/core/api/tangled"
	"tangled.org/core/appview/db"
	"tangled.org/core/appview/models"
	"tangled.org/core/consts"
	"tangled.org/core/orm"
	"tangled.org/core/tid"
)
25
26func (o *OAuth) Router() http.Handler {
27 r := chi.NewRouter()
28
29 r.Get("/oauth/client-metadata.json", o.clientMetadata)
30 r.Get("/oauth/jwks.json", o.jwks)
31 r.Get("/oauth/callback", o.callback)
32 return r
33}
34
35func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) {
36 doc := o.ClientApp.Config.ClientMetadata()
37 doc.JWKSURI = &o.JwksUri
38 doc.ClientName = &o.ClientName
39 doc.ClientURI = &o.ClientUri
40
41 w.Header().Set("Content-Type", "application/json")
42 if err := json.NewEncoder(w).Encode(doc); err != nil {
43 http.Error(w, err.Error(), http.StatusInternalServerError)
44 return
45 }
46}
47
48func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) {
49 w.Header().Set("Content-Type", "application/json")
50 body := o.ClientApp.Config.PublicJWKS()
51 if err := json.NewEncoder(w).Encode(body); err != nil {
52 http.Error(w, err.Error(), http.StatusInternalServerError)
53 return
54 }
55}
56
57func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) {
58 ctx := r.Context()
59 l := o.Logger.With("query", r.URL.Query())
60
61 authReturn := o.GetAuthReturn(r)
62 _ = o.ClearAuthReturn(w, r)
63
64 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
65 if err != nil {
66 var callbackErr *oauth.AuthRequestCallbackError
67 if errors.As(err, &callbackErr) {
68 l.Debug("callback error", "err", callbackErr)
69 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound)
70 return
71 }
72 l.Error("failed to process callback", "err", err)
73 http.Redirect(w, r, "/login?error=oauth", http.StatusFound)
74 return
75 }
76
77 if err := o.SaveSession(w, r, sessData); err != nil {
78 l.Error("failed to save session", "data", sessData, "err", err)
79 errorCode := "session"
80 if errors.Is(err, ErrMaxAccountsReached) {
81 errorCode = "max_accounts"
82 }
83 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound)
84 return
85 }
86
87 o.Logger.Debug("session saved successfully")
88
89 go o.addToDefaultKnot(sessData.AccountDID.String())
90 go o.addToDefaultSpindle(sessData.AccountDID.String())
91 go o.ensureTangledProfile(sessData)
92
93 if !o.Config.Core.Dev {
94 err = o.Posthog.Enqueue(posthog.Capture{
95 DistinctId: sessData.AccountDID.String(),
96 Event: "signin",
97 })
98 if err != nil {
99 o.Logger.Error("failed to enqueue posthog event", "err", err)
100 }
101 }
102
103 redirectURL := "/"
104 if authReturn.ReturnURL != "" {
105 redirectURL = authReturn.ReturnURL
106 }
107
108 http.Redirect(w, r, redirectURL, http.StatusFound)
109}
110
111func (o *OAuth) addToDefaultSpindle(did string) {
112 l := o.Logger.With("subject", did)
113
114 // use the tangled.sh app password to get an accessJwt
115 // and create an sh.tangled.spindle.member record with that
116 spindleMembers, err := db.GetSpindleMembers(
117 o.Db,
118 orm.FilterEq("instance", "spindle.tangled.sh"),
119 orm.FilterEq("subject", did),
120 )
121 if err != nil {
122 l.Error("failed to get spindle members", "err", err)
123 return
124 }
125
126 if len(spindleMembers) != 0 {
127 l.Warn("already a member of the default spindle")
128 return
129 }
130
131 l.Debug("adding to default spindle")
132 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid)
133 if err != nil {
134 l.Error("failed to create session", "err", err)
135 return
136 }
137
138 record := tangled.SpindleMember{
139 LexiconTypeID: "sh.tangled.spindle.member",
140 Subject: did,
141 Instance: consts.DefaultSpindle,
142 CreatedAt: time.Now().Format(time.RFC3339),
143 }
144
145 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
146 l.Error("failed to add to default spindle", "err", err)
147 return
148 }
149
150 l.Debug("successfully added to default spindle", "did", did)
151}
152
153func (o *OAuth) addToDefaultKnot(did string) {
154 l := o.Logger.With("subject", did)
155
156 // use the tangled.sh app password to get an accessJwt
157 // and create an sh.tangled.spindle.member record with that
158
159 allKnots, err := o.Enforcer.GetKnotsForUser(did)
160 if err != nil {
161 l.Error("failed to get knot members for did", "err", err)
162 return
163 }
164
165 if slices.Contains(allKnots, consts.DefaultKnot) {
166 l.Warn("already a member of the default knot")
167 return
168 }
169
170 l.Debug("addings to default knot")
171 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid)
172 if err != nil {
173 l.Error("failed to create session", "err", err)
174 return
175 }
176
177 record := tangled.KnotMember{
178 LexiconTypeID: "sh.tangled.knot.member",
179 Subject: did,
180 Domain: consts.DefaultKnot,
181 CreatedAt: time.Now().Format(time.RFC3339),
182 }
183
184 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
185 l.Error("failed to add to default knot", "err", err)
186 return
187 }
188
189 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
190 l.Error("failed to set up enforcer rules", "err", err)
191 return
192 }
193
194 l.Debug("successfully addeds to default Knot")
195}
196
197func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) {
198 ctx := context.Background()
199 did := sessData.AccountDID.String()
200 l := o.Logger.With("did", did)
201
202 _, err := db.GetProfile(o.Db, did)
203 if err == nil {
204 l.Debug("profile already exists in DB")
205 return
206 }
207
208 l.Debug("creating empty Tangled profile")
209
210 sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID)
211 if err != nil {
212 l.Error("failed to resume session for profile creation", "err", err)
213 return
214 }
215 client := sess.APIClient()
216
217 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{
218 Collection: tangled.ActorProfileNSID,
219 Repo: did,
220 Rkey: "self",
221 Record: &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}},
222 })
223
224 if err != nil {
225 l.Error("failed to create empty profile on PDS", "err", err)
226 return
227 }
228
229 tx, err := o.Db.BeginTx(ctx, nil)
230 if err != nil {
231 l.Error("failed to start transaction", "err", err)
232 return
233 }
234
235 emptyProfile := &models.Profile{Did: did}
236 if err := db.UpsertProfile(tx, emptyProfile); err != nil {
237 l.Error("failed to create empty profile in DB", "err", err)
238 return
239 }
240
241 l.Debug("successfully created empty Tangled profile on PDS and DB")
242}
243
// session is an app-password session against a PDS.
//
// Only AccessJwt is decoded from the createSession response. PdsEndpoint and
// Did are filled in locally by createAppPasswordSession, so they are excluded
// from JSON handling: encoding/json matches keys case-insensitively, and
// without the tags a response key such as "did" would populate them.
type session struct {
	// AccessJwt is the bearer token sent on authenticated requests.
	AccessJwt string `json:"accessJwt"`
	// PdsEndpoint is the base URL of the PDS the session was created on.
	PdsEndpoint string `json:"-"`
	// Did is the account the session belongs to; used as the target repo.
	Did string `json:"-"`
}
250
251func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) {
252 if appPassword == "" {
253 return nil, fmt.Errorf("no app password configured, skipping member addition")
254 }
255
256 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did)
257 if err != nil {
258 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
259 }
260
261 pdsEndpoint := resolved.PDSEndpoint()
262 if pdsEndpoint == "" {
263 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
264 }
265
266 sessionPayload := map[string]string{
267 "identifier": did,
268 "password": appPassword,
269 }
270 sessionBytes, err := json.Marshal(sessionPayload)
271 if err != nil {
272 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
273 }
274
275 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
276 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
277 if err != nil {
278 return nil, fmt.Errorf("failed to create session request: %v", err)
279 }
280 sessionReq.Header.Set("Content-Type", "application/json")
281
282 client := &http.Client{Timeout: 30 * time.Second}
283 sessionResp, err := client.Do(sessionReq)
284 if err != nil {
285 return nil, fmt.Errorf("failed to create session: %v", err)
286 }
287 defer sessionResp.Body.Close()
288
289 if sessionResp.StatusCode != http.StatusOK {
290 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
291 }
292
293 var session session
294 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
295 return nil, fmt.Errorf("failed to decode session response: %v", err)
296 }
297
298 session.PdsEndpoint = pdsEndpoint
299 session.Did = did
300
301 return &session, nil
302}
303
304func (s *session) putRecord(record any, collection string) error {
305 recordBytes, err := json.Marshal(record)
306 if err != nil {
307 return fmt.Errorf("failed to marshal knot member record: %w", err)
308 }
309
310 payload := map[string]any{
311 "repo": s.Did,
312 "collection": collection,
313 "rkey": tid.TID(),
314 "record": json.RawMessage(recordBytes),
315 }
316
317 payloadBytes, err := json.Marshal(payload)
318 if err != nil {
319 return fmt.Errorf("failed to marshal request payload: %w", err)
320 }
321
322 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
323 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
324 if err != nil {
325 return fmt.Errorf("failed to create HTTP request: %w", err)
326 }
327
328 req.Header.Set("Content-Type", "application/json")
329 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
330
331 client := &http.Client{Timeout: 30 * time.Second}
332 resp, err := client.Do(req)
333 if err != nil {
334 return fmt.Errorf("failed to add user to default service: %w", err)
335 }
336 defer resp.Body.Close()
337
338 if resp.StatusCode != http.StatusOK {
339 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode)
340 }
341
342 return nil
343}