From 1e0845772c689444f1ca53f3cded300230c9ab7a Mon Sep 17 00:00:00 2001 From: Lewis Date: Sat, 10 Jan 2026 10:30:10 +0200 Subject: [PATCH] state: add account switch/remove endpoints and login flow Change-Id: nwtvruzyxunylpksttluoyprkuktzvqq --- appview/state/accounts.go | 83 +++++++++++++++++++++++++++++++++++++++ appview/state/login.go | 64 ++++++++++++++++++++++++++---- appview/state/router.go | 5 +++ 3 files changed, 145 insertions(+), 7 deletions(-) create mode 100644 appview/state/accounts.go diff --git a/appview/state/accounts.go b/appview/state/accounts.go new file mode 100644 index 00000000..078340bc --- /dev/null +++ b/appview/state/accounts.go @@ -0,0 +1,83 @@ +package state + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +func (s *State) SwitchAccount(w http.ResponseWriter, r *http.Request) { + l := s.logger.With("handler", "SwitchAccount") + + if err := r.ParseForm(); err != nil { + l.Error("failed to parse form", "err", err) + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + + did := r.FormValue("did") + if did == "" { + http.Error(w, "missing did", http.StatusBadRequest) + return + } + + if err := s.oauth.SwitchAccount(w, r, did); err != nil { + l.Error("failed to switch account", "err", err) + s.pages.HxRedirect(w, "/login?error=session") + return + } + + l.Info("switched account", "did", did) + s.pages.HxRedirect(w, "/") +} + +func (s *State) RemoveAccount(w http.ResponseWriter, r *http.Request) { + l := s.logger.With("handler", "RemoveAccount") + + did := chi.URLParam(r, "did") + if did == "" { + http.Error(w, "missing did", http.StatusBadRequest) + return + } + + currentUser := s.oauth.GetMultiAccountUser(r) + isCurrentAccount := currentUser != nil && currentUser.Active.Did == did + + var remainingAccounts []string + if currentUser != nil { + for _, acc := range currentUser.Accounts { + if acc.Did != did { + remainingAccounts = append(remainingAccounts, acc.Did) + } + } + } + + if err := s.oauth.RemoveAccount(w, r, did); err != nil { + l.Error("failed to remove account", "err", err) + http.Error(w, "failed to remove account", http.StatusInternalServerError) + return + } + + l.Info("removed account", "did", did) + + if isCurrentAccount { + if len(remainingAccounts) > 0 { + nextDid := remainingAccounts[0] + if err := s.oauth.SwitchAccount(w, r, nextDid); err != nil { + l.Error("failed to switch to next account", "err", err) + s.pages.HxRedirect(w, "/login") + return + } + s.pages.HxRefresh(w) + return + } + + if err := s.oauth.DeleteSession(w, r); err != nil { + l.Error("failed to delete session", "err", err) + } + s.pages.HxRedirect(w, "/login") + return + } + + s.pages.HxRefresh(w) +} diff --git a/appview/state/login.go b/appview/state/login.go index 309ef2f5..a68348e6 100644 --- a/appview/state/login.go +++ b/appview/state/login.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "tangled.org/core/appview/oauth" "tangled.org/core/appview/pages" ) @@ -15,12 +16,28 @@ func (s *State) Login(w http.ResponseWriter, r *http.Request) { case http.MethodGet: returnURL := r.URL.Query().Get("return_url") errorCode := r.URL.Query().Get("error") + addAccount := r.URL.Query().Get("mode") == "add_account" + + user := s.oauth.GetMultiAccountUser(r) + if user == nil { + registry := s.oauth.GetAccounts(r) + if len(registry.Accounts) > 0 { + user = &oauth.MultiAccountUser{ + Active: nil, + Accounts: registry.Accounts, + } + } + } s.pages.Login(w, pages.LoginParams{ - ReturnUrl: returnURL, - ErrorCode: errorCode, + ReturnUrl: returnURL, + ErrorCode: errorCode, + AddAccount: addAccount, + LoggedInUser: user, }) case http.MethodPost: handle := r.FormValue("handle") + returnURL := r.FormValue("return_url") + addAccount := r.FormValue("add_account") == "true" // when users copy their handle from bsky.app, it tends to have these characters around it: // @@ -44,6 +61,10 @@ func (s *State) Login(w http.ResponseWriter, r *http.Request) { return } + if err := s.oauth.SetAuthReturn(w, r, returnURL, addAccount); err != nil { + l.Error("failed to set auth return", "err", err) + } + redirectURL, err := s.oauth.ClientApp.StartAuthFlow(r.Context(), handle) if err != nil { l.Error("failed to start auth", "err", err) @@ -58,12 +79,41 @@ func (s *State) Login(w http.ResponseWriter, r *http.Request) { func (s *State) Logout(w http.ResponseWriter, r *http.Request) { l := s.logger.With("handler", "Logout") - err := s.oauth.DeleteSession(w, r) - if err != nil { - l.Error("failed to logout", "err", err) - } else { - l.Info("logged out successfully") + currentUser := s.oauth.GetMultiAccountUser(r) + if currentUser == nil || currentUser.Active == nil { + s.pages.HxRedirect(w, "/login") + return + } + + currentDid := currentUser.Active.Did + + var remainingAccounts []string + for _, acc := range currentUser.Accounts { + if acc.Did != currentDid { + remainingAccounts = append(remainingAccounts, acc.Did) + } + } + + if err := s.oauth.RemoveAccount(w, r, currentDid); err != nil { + l.Error("failed to remove account from registry", "err", err) + } + + if err := s.oauth.DeleteSession(w, r); err != nil { + l.Error("failed to delete session", "err", err) + } + + if len(remainingAccounts) > 0 { + nextDid := remainingAccounts[0] + if err := s.oauth.SwitchAccount(w, r, nextDid); err != nil { + l.Error("failed to switch to next account", "err", err) + s.pages.HxRedirect(w, "/login") + return + } + l.Info("switched to next account after logout", "did", nextDid) + s.pages.HxRefresh(w) + return } + l.Info("logged out last account") s.pages.HxRedirect(w, "/login") } diff --git a/appview/state/router.go b/appview/state/router.go index 6b14e647..62806db2 100644 --- a/appview/state/router.go +++ b/appview/state/router.go @@ -132,6 +132,11 @@ func (s *State) StandardRouter(mw *middleware.Middleware) http.Handler { r.Post("/login", s.Login) r.Post("/logout", s.Logout) + r.With(middleware.AuthMiddleware(s.oauth)).Route("/account", func(r chi.Router) { + r.Post("/switch", s.SwitchAccount) + r.Delete("/{did}", s.RemoveAccount) + }) + r.Route("/repo", func(r chi.Router) { r.Route("/new", func(r chi.Router) { r.Use(middleware.AuthMiddleware(s.oauth)) -- 2.43.0