Monorepo for Tangled tangled.org
at sl/shared-stacks 341 lines 9.5 kB view raw
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "net/http" 10 "slices" 11 "time" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/auth/oauth" 15 lexutil "github.com/bluesky-social/indigo/lex/util" 16 "github.com/go-chi/chi/v5" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/api/tangled" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/appview/models" 21 "tangled.org/core/consts" 22 "tangled.org/core/orm" 23 "tangled.org/core/tid" 24) 25 26func (o *OAuth) Router() http.Handler { 27 r := chi.NewRouter() 28 29 r.Get("/oauth/client-metadata.json", o.clientMetadata) 30 r.Get("/oauth/jwks.json", o.jwks) 31 r.Get("/oauth/callback", o.Callback) 32 return r 33} 34 35func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 36 doc := o.ClientApp.Config.ClientMetadata() 37 doc.JWKSURI = &o.JwksUri 38 doc.ClientName = &o.ClientName 39 doc.ClientURI = &o.ClientUri 40 41 w.Header().Set("Content-Type", "application/json") 42 if err := json.NewEncoder(w).Encode(doc); err != nil { 43 http.Error(w, err.Error(), http.StatusInternalServerError) 44 return 45 } 46} 47 48func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 49 w.Header().Set("Content-Type", "application/json") 50 body := o.ClientApp.Config.PublicJWKS() 51 if err := json.NewEncoder(w).Encode(body); err != nil { 52 http.Error(w, err.Error(), http.StatusInternalServerError) 53 return 54 } 55} 56 57func (o *OAuth) Callback(w http.ResponseWriter, r *http.Request) { 58 ctx := r.Context() 59 l := o.Logger.With("query", r.URL.Query()) 60 61 authReturn := o.GetAuthReturn(r) 62 _ = o.ClearAuthReturn(w, r) 63 64 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 65 if err != nil { 66 var callbackErr *oauth.AuthRequestCallbackError 67 if errors.As(err, &callbackErr) { 68 l.Debug("callback error", "err", callbackErr) 69 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound) 70 return 71 } 72 l.Error("failed to process callback", "err", err) 73 http.Redirect(w, r, "/login?error=oauth", http.StatusFound) 74 return 75 } 76 77 if err := o.SaveSession(w, r, sessData); err != nil { 78 l.Error("failed to save session", "data", sessData, "err", err) 79 errorCode := "session" 80 if errors.Is(err, ErrMaxAccountsReached) { 81 errorCode = "max_accounts" 82 } 83 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound) 84 return 85 } 86 87 o.Logger.Debug("session saved successfully") 88 89 go o.addToDefaultKnot(sessData.AccountDID.String()) 90 go o.addToDefaultSpindle(sessData.AccountDID.String()) 91 go o.ensureTangledProfile(sessData) 92 93 if !o.Config.Core.Dev { 94 err = o.Posthog.Enqueue(posthog.Capture{ 95 DistinctId: sessData.AccountDID.String(), 96 Event: "signin", 97 }) 98 if err != nil { 99 o.Logger.Error("failed to enqueue posthog event", "err", err) 100 } 101 } 102 103 if authReturn == "" { 104 authReturn = "/" 105 } 106 http.Redirect(w, r, authReturn, http.StatusFound) 107} 108 109func (o *OAuth) addToDefaultSpindle(did string) { 110 l := o.Logger.With("subject", did) 111 112 // use the tangled.sh app password to get an accessJwt 113 // and create an sh.tangled.spindle.member record with that 114 spindleMembers, err := db.GetSpindleMembers( 115 o.Db, 116 orm.FilterEq("instance", "spindle.tangled.sh"), 117 orm.FilterEq("subject", did), 118 ) 119 if err != nil { 120 l.Error("failed to get spindle members", "err", err) 121 return 122 } 123 124 if len(spindleMembers) != 0 { 125 l.Warn("already a member of the default spindle") 126 return 127 } 128 129 l.Debug("adding to default spindle") 130 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid) 131 if err != nil { 132 l.Error("failed to create session", "err", err) 133 return 134 } 135 136 record := tangled.SpindleMember{ 137 LexiconTypeID: "sh.tangled.spindle.member", 138 Subject: did, 139 Instance: consts.DefaultSpindle, 140 CreatedAt: time.Now().Format(time.RFC3339), 141 } 142 143 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 144 l.Error("failed to add to default spindle", "err", err) 145 return 146 } 147 148 l.Debug("successfully added to default spindle", "did", did) 149} 150 151func (o *OAuth) addToDefaultKnot(did string) { 152 l := o.Logger.With("subject", did) 153 154 // use the tangled.sh app password to get an accessJwt 155 // and create an sh.tangled.spindle.member record with that 156 157 allKnots, err := o.Enforcer.GetKnotsForUser(did) 158 if err != nil { 159 l.Error("failed to get knot members for did", "err", err) 160 return 161 } 162 163 if slices.Contains(allKnots, consts.DefaultKnot) { 164 l.Warn("already a member of the default knot") 165 return 166 } 167 168 l.Debug("adding to default knot") 169 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid) 170 if err != nil { 171 l.Error("failed to create session", "err", err) 172 return 173 } 174 175 record := tangled.KnotMember{ 176 LexiconTypeID: "sh.tangled.knot.member", 177 Subject: did, 178 Domain: consts.DefaultKnot, 179 CreatedAt: time.Now().Format(time.RFC3339), 180 } 181 182 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 183 l.Error("failed to add to default knot", "err", err) 184 return 185 } 186 187 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 188 l.Error("failed to set up enforcer rules", "err", err) 189 return 190 } 191 192 l.Debug("successfully addeds to default Knot") 193} 194 195func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) { 196 ctx := context.Background() 197 did := sessData.AccountDID.String() 198 l := o.Logger.With("did", did) 199 200 _, err := db.GetProfile(o.Db, did) 201 if err == nil { 202 l.Debug("profile already exists in DB") 203 return 204 } 205 206 l.Debug("creating empty Tangled profile") 207 208 sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID) 209 if err != nil { 210 l.Error("failed to resume session for profile creation", "err", err) 211 return 212 } 213 client := sess.APIClient() 214 215 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{ 216 Collection: tangled.ActorProfileNSID, 217 Repo: did, 218 Rkey: "self", 219 Record: &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}}, 220 }) 221 222 if err != nil { 223 l.Error("failed to create empty profile on PDS", "err", err) 224 return 225 } 226 227 tx, err := o.Db.BeginTx(ctx, nil) 228 if err != nil { 229 l.Error("failed to start transaction", "err", err) 230 return 231 } 232 233 emptyProfile := &models.Profile{Did: did} 234 if err := db.UpsertProfile(tx, emptyProfile); err != nil { 235 l.Error("failed to create empty profile in DB", "err", err) 236 return 237 } 238 239 l.Debug("successfully created empty Tangled profile on PDS and DB") 240} 241 242// create a session using apppasswords 243type session struct { 244 AccessJwt string `json:"accessJwt"` 245 PdsEndpoint string 246 Did string 247} 248 249func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) { 250 if appPassword == "" { 251 return nil, fmt.Errorf("no app password configured, skipping member addition") 252 } 253 254 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did) 255 if err != nil { 256 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 257 } 258 259 pdsEndpoint := resolved.PDSEndpoint() 260 if pdsEndpoint == "" { 261 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 262 } 263 264 sessionPayload := map[string]string{ 265 "identifier": did, 266 "password": appPassword, 267 } 268 sessionBytes, err := json.Marshal(sessionPayload) 269 if err != nil { 270 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 271 } 272 273 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 274 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 275 if err != nil { 276 return nil, fmt.Errorf("failed to create session request: %v", err) 277 } 278 sessionReq.Header.Set("Content-Type", "application/json") 279 280 client := &http.Client{Timeout: 30 * time.Second} 281 sessionResp, err := client.Do(sessionReq) 282 if err != nil { 283 return nil, fmt.Errorf("failed to create session: %v", err) 284 } 285 defer sessionResp.Body.Close() 286 287 if sessionResp.StatusCode != http.StatusOK { 288 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 289 } 290 291 var session session 292 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 293 return nil, fmt.Errorf("failed to decode session response: %v", err) 294 } 295 296 session.PdsEndpoint = pdsEndpoint 297 session.Did = did 298 299 return &session, nil 300} 301 302func (s *session) putRecord(record any, collection string) error { 303 recordBytes, err := json.Marshal(record) 304 if err != nil { 305 return fmt.Errorf("failed to marshal knot member record: %w", err) 306 } 307 308 payload := map[string]any{ 309 "repo": s.Did, 310 "collection": collection, 311 "rkey": tid.TID(), 312 "record": json.RawMessage(recordBytes), 313 } 314 315 payloadBytes, err := json.Marshal(payload) 316 if err != nil { 317 return fmt.Errorf("failed to marshal request payload: %w", err) 318 } 319 320 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 321 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 322 if err != nil { 323 return fmt.Errorf("failed to create HTTP request: %w", err) 324 } 325 326 req.Header.Set("Content-Type", "application/json") 327 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 328 329 client := &http.Client{Timeout: 30 * time.Second} 330 resp, err := client.Do(req) 331 if err != nil { 332 return fmt.Errorf("failed to add user to default service: %w", err) 333 } 334 defer resp.Body.Close() 335 336 if resp.StatusCode != http.StatusOK { 337 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode) 338 } 339 340 return nil 341}