From 46a766c4f1d0ff9f48ddfc85a6aadb0eded4ae29 Mon Sep 17 00:00:00 2001 From: Lewis Date: Sat, 10 Jan 2026 10:29:09 +0200 Subject: [PATCH] oauth: integrate multi-account into session management Change-Id: rnxuwmomkoslrszywxwtlrmtvvxmxxks --- appview/oauth/handler.go | 16 +++++++-- appview/oauth/oauth.go | 70 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/appview/oauth/handler.go b/appview/oauth/handler.go index f95e2aa5..824543cc 100644 --- a/appview/oauth/handler.go +++ b/appview/oauth/handler.go @@ -55,6 +55,9 @@ func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { ctx := r.Context() l := o.Logger.With("query", r.URL.Query()) + authReturn := o.GetAuthReturn(r) + _ = o.ClearAuthReturn(w, r) + sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) if err != nil { var callbackErr *oauth.AuthRequestCallbackError @@ -70,7 +73,11 @@ func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { if err := o.SaveSession(w, r, sessData); err != nil { l.Error("failed to save session", "data", sessData, "err", err) - http.Redirect(w, r, "/login?error=session", http.StatusFound) + errorCode := "session" + if errors.Is(err, ErrMaxAccountsReached) { + errorCode = "max_accounts" + } + http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound) return } @@ -88,7 +95,12 @@ func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { } } - http.Redirect(w, r, "/", http.StatusFound) + redirectURL := "/" + if authReturn.ReturnURL != "" { + redirectURL = authReturn.ReturnURL + } + + http.Redirect(w, r, redirectURL, http.StatusFound) } func (o *OAuth) addToDefaultSpindle(did string) { diff --git a/appview/oauth/oauth.go b/appview/oauth/oauth.go index 8e94fc23..1fa3e68d 100644 --- a/appview/oauth/oauth.go +++ b/appview/oauth/oauth.go @@ -98,7 +98,6 @@ func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enf } func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { - // first we save the did in the user session userSession, err := o.SessStore.Get(r, SessionName) if err != nil { return err @@ -108,7 +107,22 @@ func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oa userSession.Values[SessionPds] = sessData.HostURL userSession.Values[SessionId] = sessData.SessionID userSession.Values[SessionAuthenticated] = true - return userSession.Save(r, w) + + if err := userSession.Save(r, w); err != nil { + return err + } + + handle := "" + resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) + if err == nil && resolved.Handle.String() != "" { + handle = resolved.Handle.String() + } + + registry := o.GetAccounts(r) + if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { + return err + } + return o.SaveAccounts(w, r, registry) } func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { @@ -163,6 +177,54 @@ func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { return errors.Join(err1, err2) } +func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { + registry := o.GetAccounts(r) + account := registry.FindAccount(targetDid) + if account == nil { + return fmt.Errorf("account not found in registry: %s", targetDid) + } + + did, err := syntax.ParseDID(targetDid) + if err != nil { + return fmt.Errorf("invalid DID: %w", err) + } + + sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) + if err != nil { + registry.RemoveAccount(targetDid) + _ = o.SaveAccounts(w, r, registry) + return fmt.Errorf("session expired for account: %w", err) + } + + userSession, err := o.SessStore.Get(r, SessionName) + if err != nil { + return err + } + + userSession.Values[SessionDid] = sess.Data.AccountDID.String() + userSession.Values[SessionPds] = sess.Data.HostURL + userSession.Values[SessionId] = sess.Data.SessionID + userSession.Values[SessionAuthenticated] = true + + return userSession.Save(r, w) +} + +func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { + registry := o.GetAccounts(r) + account := registry.FindAccount(targetDid) + if account == nil { + return nil + } + + did, err := syntax.ParseDID(targetDid) + if err == nil { + _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) + } + + registry.RemoveAccount(targetDid) + return o.SaveAccounts(w, r, registry) +} + type User struct { Did string Pds string @@ -181,8 +243,8 @@ func (o *OAuth) GetUser(r *http.Request) *User { } func (o *OAuth) GetDid(r *http.Request) string { - if u := o.GetUser(r); u != nil { - return u.Did + if u := o.GetMultiAccountUser(r); u != nil { + return u.Did() } return "" -- 2.43.0