at default-knot 343 lines 9.5 kB view raw
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "net/http" 10 "slices" 11 "time" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/auth/oauth" 15 lexutil "github.com/bluesky-social/indigo/lex/util" 16 "github.com/go-chi/chi/v5" 17 "github.com/posthog/posthog-go" 18 "tangled.org/core/api/tangled" 19 "tangled.org/core/appview/db" 20 "tangled.org/core/appview/models" 21 "tangled.org/core/consts" 22 "tangled.org/core/orm" 23 "tangled.org/core/tid" 24) 25 26func (o *OAuth) Router() http.Handler { 27 r := chi.NewRouter() 28 29 r.Get("/oauth/client-metadata.json", o.clientMetadata) 30 r.Get("/oauth/jwks.json", o.jwks) 31 r.Get("/oauth/callback", o.callback) 32 return r 33} 34 35func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 36 doc := o.ClientApp.Config.ClientMetadata() 37 doc.JWKSURI = &o.JwksUri 38 doc.ClientName = &o.ClientName 39 doc.ClientURI = &o.ClientUri 40 41 w.Header().Set("Content-Type", "application/json") 42 if err := json.NewEncoder(w).Encode(doc); err != nil { 43 http.Error(w, err.Error(), http.StatusInternalServerError) 44 return 45 } 46} 47 48func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 49 w.Header().Set("Content-Type", "application/json") 50 body := o.ClientApp.Config.PublicJWKS() 51 if err := json.NewEncoder(w).Encode(body); err != nil { 52 http.Error(w, err.Error(), http.StatusInternalServerError) 53 return 54 } 55} 56 57func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { 58 ctx := r.Context() 59 l := o.Logger.With("query", r.URL.Query()) 60 61 authReturn := o.GetAuthReturn(r) 62 _ = o.ClearAuthReturn(w, r) 63 64 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 65 if err != nil { 66 var callbackErr *oauth.AuthRequestCallbackError 67 if errors.As(err, &callbackErr) { 68 l.Debug("callback error", "err", callbackErr) 69 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound) 70 return 71 } 72 l.Error("failed to process callback", "err", err) 73 http.Redirect(w, r, "/login?error=oauth", http.StatusFound) 74 return 75 } 76 77 if err := o.SaveSession(w, r, sessData); err != nil { 78 l.Error("failed to save session", "data", sessData, "err", err) 79 errorCode := "session" 80 if errors.Is(err, ErrMaxAccountsReached) { 81 errorCode = "max_accounts" 82 } 83 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound) 84 return 85 } 86 87 o.Logger.Debug("session saved successfully") 88 89 go o.addToDefaultKnot(sessData.AccountDID.String()) 90 go o.addToDefaultSpindle(sessData.AccountDID.String()) 91 go o.ensureTangledProfile(sessData) 92 93 if !o.Config.Core.Dev { 94 err = o.Posthog.Enqueue(posthog.Capture{ 95 DistinctId: sessData.AccountDID.String(), 96 Event: "signin", 97 }) 98 if err != nil { 99 o.Logger.Error("failed to enqueue posthog event", "err", err) 100 } 101 } 102 103 redirectURL := "/" 104 if authReturn.ReturnURL != "" { 105 redirectURL = authReturn.ReturnURL 106 } 107 108 http.Redirect(w, r, redirectURL, http.StatusFound) 109} 110 111func (o *OAuth) addToDefaultSpindle(did string) { 112 l := o.Logger.With("subject", did) 113 114 // use the tangled.sh app password to get an accessJwt 115 // and create an sh.tangled.spindle.member record with that 116 spindleMembers, err := db.GetSpindleMembers( 117 o.Db, 118 orm.FilterEq("instance", "spindle.tangled.sh"), 119 orm.FilterEq("subject", did), 120 ) 121 if err != nil { 122 l.Error("failed to get spindle members", "err", err) 123 return 124 } 125 126 if len(spindleMembers) != 0 { 127 l.Warn("already a member of the default spindle") 128 return 129 } 130 131 l.Debug("adding to default spindle") 132 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid) 133 if err != nil { 134 l.Error("failed to create session", "err", err) 135 return 136 } 137 138 record := tangled.SpindleMember{ 139 LexiconTypeID: "sh.tangled.spindle.member", 140 Subject: did, 141 Instance: consts.DefaultSpindle, 142 CreatedAt: time.Now().Format(time.RFC3339), 143 } 144 145 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 146 l.Error("failed to add to default spindle", "err", err) 147 return 148 } 149 150 l.Debug("successfully added to default spindle", "did", did) 151} 152 153func (o *OAuth) addToDefaultKnot(did string) { 154 l := o.Logger.With("subject", did) 155 156 // use the tangled.sh app password to get an accessJwt 157 // and create an sh.tangled.spindle.member record with that 158 159 allKnots, err := o.Enforcer.GetKnotsForUser(did) 160 if err != nil { 161 l.Error("failed to get knot members for did", "err", err) 162 return 163 } 164 165 if slices.Contains(allKnots, consts.DefaultKnot) { 166 l.Warn("already a member of the default knot") 167 return 168 } 169 170 l.Debug("addings to default knot") 171 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid) 172 if err != nil { 173 l.Error("failed to create session", "err", err) 174 return 175 } 176 177 record := tangled.KnotMember{ 178 LexiconTypeID: "sh.tangled.knot.member", 179 Subject: did, 180 Domain: consts.DefaultKnot, 181 CreatedAt: time.Now().Format(time.RFC3339), 182 } 183 184 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 185 l.Error("failed to add to default knot", "err", err) 186 return 187 } 188 189 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 190 l.Error("failed to set up enforcer rules", "err", err) 191 return 192 } 193 194 l.Debug("successfully addeds to default Knot") 195} 196 197func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) { 198 ctx := context.Background() 199 did := sessData.AccountDID.String() 200 l := o.Logger.With("did", did) 201 202 _, err := db.GetProfile(o.Db, did) 203 if err == nil { 204 l.Debug("profile already exists in DB") 205 return 206 } 207 208 l.Debug("creating empty Tangled profile") 209 210 sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID) 211 if err != nil { 212 l.Error("failed to resume session for profile creation", "err", err) 213 return 214 } 215 client := sess.APIClient() 216 217 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{ 218 Collection: tangled.ActorProfileNSID, 219 Repo: did, 220 Rkey: "self", 221 Record: &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}}, 222 }) 223 224 if err != nil { 225 l.Error("failed to create empty profile on PDS", "err", err) 226 return 227 } 228 229 tx, err := o.Db.BeginTx(ctx, nil) 230 if err != nil { 231 l.Error("failed to start transaction", "err", err) 232 return 233 } 234 235 emptyProfile := &models.Profile{Did: did} 236 if err := db.UpsertProfile(tx, emptyProfile); err != nil { 237 l.Error("failed to create empty profile in DB", "err", err) 238 return 239 } 240 241 l.Debug("successfully created empty Tangled profile on PDS and DB") 242} 243 244// create a session using apppasswords 245type session struct { 246 AccessJwt string `json:"accessJwt"` 247 PdsEndpoint string 248 Did string 249} 250 251func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) { 252 if appPassword == "" { 253 return nil, fmt.Errorf("no app password configured, skipping member addition") 254 } 255 256 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did) 257 if err != nil { 258 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 259 } 260 261 pdsEndpoint := resolved.PDSEndpoint() 262 if pdsEndpoint == "" { 263 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 264 } 265 266 sessionPayload := map[string]string{ 267 "identifier": did, 268 "password": appPassword, 269 } 270 sessionBytes, err := json.Marshal(sessionPayload) 271 if err != nil { 272 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 273 } 274 275 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 276 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 277 if err != nil { 278 return nil, fmt.Errorf("failed to create session request: %v", err) 279 } 280 sessionReq.Header.Set("Content-Type", "application/json") 281 282 client := &http.Client{Timeout: 30 * time.Second} 283 sessionResp, err := client.Do(sessionReq) 284 if err != nil { 285 return nil, fmt.Errorf("failed to create session: %v", err) 286 } 287 defer sessionResp.Body.Close() 288 289 if sessionResp.StatusCode != http.StatusOK { 290 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 291 } 292 293 var session session 294 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 295 return nil, fmt.Errorf("failed to decode session response: %v", err) 296 } 297 298 session.PdsEndpoint = pdsEndpoint 299 session.Did = did 300 301 return &session, nil 302} 303 304func (s *session) putRecord(record any, collection string) error { 305 recordBytes, err := json.Marshal(record) 306 if err != nil { 307 return fmt.Errorf("failed to marshal knot member record: %w", err) 308 } 309 310 payload := map[string]any{ 311 "repo": s.Did, 312 "collection": collection, 313 "rkey": tid.TID(), 314 "record": json.RawMessage(recordBytes), 315 } 316 317 payloadBytes, err := json.Marshal(payload) 318 if err != nil { 319 return fmt.Errorf("failed to marshal request payload: %w", err) 320 } 321 322 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 323 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 324 if err != nil { 325 return fmt.Errorf("failed to create HTTP request: %w", err) 326 } 327 328 req.Header.Set("Content-Type", "application/json") 329 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 330 331 client := &http.Client{Timeout: 30 * time.Second} 332 resp, err := client.Do(req) 333 if err != nil { 334 return fmt.Errorf("failed to add user to default service: %w", err) 335 } 336 defer resp.Body.Close() 337 338 if resp.StatusCode != http.StatusOK { 339 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode) 340 } 341 342 return nil 343}