From 219ae29631068411c586c0f6d275545fd36e2929 Mon Sep 17 00:00:00 2001 From: Lewis Date: Sat, 10 Jan 2026 10:28:53 +0200 Subject: [PATCH] oauth: add multi-account infrastructure Change-Id: qplmzlnvokrxrzvxwvmpkkkyyrpxoolm --- appview/oauth/accounts.go | 191 ++++++++++++++++++++++++ appview/oauth/accounts_test.go | 265 +++++++++++++++++++++++++++++++++ appview/oauth/consts.go | 6 +- 3 files changed, 461 insertions(+), 1 deletion(-) create mode 100644 appview/oauth/accounts.go create mode 100644 appview/oauth/accounts_test.go diff --git a/appview/oauth/accounts.go b/appview/oauth/accounts.go new file mode 100644 index 00000000..5197715e --- /dev/null +++ b/appview/oauth/accounts.go @@ -0,0 +1,191 @@ +package oauth + +import ( + "encoding/json" + "errors" + "net/http" + "time" +) + +const MaxAccounts = 20 + +var ErrMaxAccountsReached = errors.New("maximum number of linked accounts reached") + +type AccountInfo struct { + Did string `json:"did"` + Handle string `json:"handle"` + SessionId string `json:"session_id"` + AddedAt int64 `json:"added_at"` +} + +type AccountRegistry struct { + Accounts []AccountInfo `json:"accounts"` +} + +type MultiAccountUser struct { + Active *User + Accounts []AccountInfo +} + +func (m *MultiAccountUser) Did() string { + if m.Active == nil { + return "" + } + return m.Active.Did +} + +func (m *MultiAccountUser) Pds() string { + if m.Active == nil { + return "" + } + return m.Active.Pds +} + +func (o *OAuth) GetAccounts(r *http.Request) *AccountRegistry { + session, err := o.SessStore.Get(r, AccountsName) + if err != nil || session.IsNew { + return &AccountRegistry{Accounts: []AccountInfo{}} + } + + data, ok := session.Values["accounts"].(string) + if !ok { + return &AccountRegistry{Accounts: []AccountInfo{}} + } + + var registry AccountRegistry + if err := json.Unmarshal([]byte(data), ®istry); err != nil { + return &AccountRegistry{Accounts: []AccountInfo{}} + } + + return ®istry +} + +func (o *OAuth) SaveAccounts(w http.ResponseWriter, r *http.Request, registry *AccountRegistry) error { + session, err := o.SessStore.Get(r, AccountsName) + if err != nil { + return err + } + + data, err := json.Marshal(registry) + if err != nil { + return err + } + + session.Values["accounts"] = string(data) + session.Options.MaxAge = 60 * 60 * 24 * 365 + session.Options.HttpOnly = true + session.Options.Secure = !o.Config.Core.Dev + session.Options.SameSite = http.SameSiteLaxMode + + return session.Save(r, w) +} + +func (r *AccountRegistry) AddAccount(did, handle, sessionId string) error { + for i, acc := range r.Accounts { + if acc.Did == did { + r.Accounts[i].SessionId = sessionId + r.Accounts[i].Handle = handle + return nil + } + } + + if len(r.Accounts) >= MaxAccounts { + return ErrMaxAccountsReached + } + + r.Accounts = append(r.Accounts, AccountInfo{ + Did: did, + Handle: handle, + SessionId: sessionId, + AddedAt: time.Now().Unix(), + }) + return nil +} + +func (r *AccountRegistry) RemoveAccount(did string) { + filtered := make([]AccountInfo, 0, len(r.Accounts)) + for _, acc := range r.Accounts { + if acc.Did != did { + filtered = append(filtered, acc) + } + } + r.Accounts = filtered +} + +func (r *AccountRegistry) FindAccount(did string) *AccountInfo { + for i := range r.Accounts { + if r.Accounts[i].Did == did { + return &r.Accounts[i] + } + } + return nil +} + +func (r *AccountRegistry) OtherAccounts(activeDid string) []AccountInfo { + result := make([]AccountInfo, 0, len(r.Accounts)) + for _, acc := range r.Accounts { + if acc.Did != activeDid { + result = append(result, acc) + } + } + return result +} + +func (o *OAuth) GetMultiAccountUser(r *http.Request) *MultiAccountUser { + user := o.GetUser(r) + if user == nil { + return nil + } + + registry := o.GetAccounts(r) + return &MultiAccountUser{ + Active: user, + Accounts: registry.Accounts, + } +} + +type AuthReturnInfo struct { + ReturnURL string + AddAccount bool +} + +func (o *OAuth) SetAuthReturn(w http.ResponseWriter, r *http.Request, returnURL string, addAccount bool) error { + session, err := o.SessStore.Get(r, AuthReturnName) + if err != nil { + return err + } + + session.Values[AuthReturnURL] = returnURL + session.Values[AuthAddAccount] = addAccount + session.Options.MaxAge = 60 * 30 + session.Options.HttpOnly = true + session.Options.Secure = !o.Config.Core.Dev + session.Options.SameSite = http.SameSiteLaxMode + + return session.Save(r, w) +} + +func (o *OAuth) GetAuthReturn(r *http.Request) *AuthReturnInfo { + session, err := o.SessStore.Get(r, AuthReturnName) + if err != nil || session.IsNew { + return &AuthReturnInfo{} + } + + returnURL, _ := session.Values[AuthReturnURL].(string) + addAccount, _ := session.Values[AuthAddAccount].(bool) + + return &AuthReturnInfo{ + ReturnURL: returnURL, + AddAccount: addAccount, + } +} + +func (o *OAuth) ClearAuthReturn(w http.ResponseWriter, r *http.Request) error { + session, err := o.SessStore.Get(r, AuthReturnName) + if err != nil { + return err + } + + session.Options.MaxAge = -1 + return session.Save(r, w) +} diff --git a/appview/oauth/accounts_test.go b/appview/oauth/accounts_test.go new file mode 100644 index 00000000..76cb9ae4 --- /dev/null +++ b/appview/oauth/accounts_test.go @@ -0,0 +1,265 @@ +package oauth + +import ( + "testing" +) + +func TestAccountRegistry_AddAccount(t *testing.T) { + tests := []struct { + name string + initial []AccountInfo + addDid string + addHandle string + addSessionId string + wantErr error + wantLen int + wantSessionId string + }{ + { + name: "add first account", + initial: []AccountInfo{}, + addDid: "did:plc:abc123", + addHandle: "alice.bsky.social", + addSessionId: "session-1", + wantErr: nil, + wantLen: 1, + wantSessionId: "session-1", + }, + { + name: "add second account", + initial: []AccountInfo{ + {Did: "did:plc:abc123", Handle: "alice.bsky.social", SessionId: "session-1", AddedAt: 1000}, + }, + addDid: "did:plc:def456", + addHandle: "bob.bsky.social", + addSessionId: "session-2", + wantErr: nil, + wantLen: 2, + wantSessionId: "session-2", + }, + { + name: "update existing account session", + initial: []AccountInfo{ + {Did: "did:plc:abc123", Handle: "alice.bsky.social", SessionId: "old-session", AddedAt: 1000}, + }, + addDid: "did:plc:abc123", + addHandle: "alice.bsky.social", + addSessionId: "new-session", + wantErr: nil, + wantLen: 1, + wantSessionId: "new-session", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := &AccountRegistry{Accounts: tt.initial} + err := registry.AddAccount(tt.addDid, tt.addHandle, tt.addSessionId) + + if err != tt.wantErr { + t.Errorf("AddAccount() error = %v, want %v", err, tt.wantErr) + } + + if len(registry.Accounts) != tt.wantLen { + t.Errorf("AddAccount() len = %d, want %d", len(registry.Accounts), tt.wantLen) + } + + found := registry.FindAccount(tt.addDid) + if found == nil { + t.Errorf("AddAccount() account not found after add") + return + } + + if found.SessionId != tt.wantSessionId { + t.Errorf("AddAccount() sessionId = %s, want %s", found.SessionId, tt.wantSessionId) + } + }) + } +} + +func TestAccountRegistry_AddAccount_MaxLimit(t *testing.T) { + registry := &AccountRegistry{Accounts: make([]AccountInfo, 0, MaxAccounts)} + + for i := range MaxAccounts { + err := registry.AddAccount("did:plc:user"+string(rune('a'+i)), "handle", "session") + if err != nil { + t.Fatalf("AddAccount() unexpected error on account %d: %v", i, err) + } + } + + if len(registry.Accounts) != MaxAccounts { + t.Errorf("expected %d accounts, got %d", MaxAccounts, len(registry.Accounts)) + } + + err := registry.AddAccount("did:plc:overflow", "overflow", "session-overflow") + if err != ErrMaxAccountsReached { + t.Errorf("AddAccount() error = %v, want %v", err, ErrMaxAccountsReached) + } + + if len(registry.Accounts) != MaxAccounts { + t.Errorf("account added despite max limit, got %d", len(registry.Accounts)) + } +} + +func TestAccountRegistry_RemoveAccount(t *testing.T) { + tests := []struct { + name string + initial []AccountInfo + removeDid string + wantLen int + wantDids []string + }{ + { + name: "remove existing account", + initial: []AccountInfo{ + {Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"}, + {Did: "did:plc:def456", Handle: "bob", SessionId: "s2"}, + }, + removeDid: "did:plc:abc123", + wantLen: 1, + wantDids: []string{"did:plc:def456"}, + }, + { + name: "remove non-existing account", + initial: []AccountInfo{ + {Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"}, + }, + removeDid: "did:plc:notfound", + wantLen: 1, + wantDids: []string{"did:plc:abc123"}, + }, + { + name: "remove last account", + initial: []AccountInfo{ + {Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"}, + }, + removeDid: "did:plc:abc123", + wantLen: 0, + wantDids: []string{}, + }, + { + name: "remove from empty registry", + initial: []AccountInfo{}, + removeDid: "did:plc:abc123", + wantLen: 0, + wantDids: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := &AccountRegistry{Accounts: tt.initial} + registry.RemoveAccount(tt.removeDid) + + if len(registry.Accounts) != tt.wantLen { + t.Errorf("RemoveAccount() len = %d, want %d", len(registry.Accounts), tt.wantLen) + } + + for _, wantDid := range tt.wantDids { + if registry.FindAccount(wantDid) == nil { + t.Errorf("RemoveAccount() expected %s to remain", wantDid) + } + } + + if registry.FindAccount(tt.removeDid) != nil && tt.wantLen < len(tt.initial) { + t.Errorf("RemoveAccount() %s should have been removed", tt.removeDid) + } + }) + } +} + +func TestAccountRegistry_FindAccount(t *testing.T) { + registry := &AccountRegistry{ + Accounts: []AccountInfo{ + {Did: "did:plc:first", Handle: "first", SessionId: "s1", AddedAt: 1000}, + {Did: "did:plc:second", Handle: "second", SessionId: "s2", AddedAt: 2000}, + {Did: "did:plc:third", Handle: "third", SessionId: "s3", AddedAt: 3000}, + }, + } + + t.Run("find existing account", func(t *testing.T) { + found := registry.FindAccount("did:plc:second") + if found == nil { + t.Fatal("FindAccount() returned nil for existing account") + } + if found.Handle != "second" { + t.Errorf("FindAccount() handle = %s, want second", found.Handle) + } + if found.SessionId != "s2" { + t.Errorf("FindAccount() sessionId = %s, want s2", found.SessionId) + } + }) + + t.Run("find non-existing account", func(t *testing.T) { + found := registry.FindAccount("did:plc:notfound") + if found != nil { + t.Errorf("FindAccount() = %v, want nil", found) + } + }) + + t.Run("returned pointer is mutable", func(t *testing.T) { + found := registry.FindAccount("did:plc:first") + if found == nil { + t.Fatal("FindAccount() returned nil") + } + found.SessionId = "modified" + + refetch := registry.FindAccount("did:plc:first") + if refetch.SessionId != "modified" { + t.Errorf("FindAccount() pointer not referencing original, got %s", refetch.SessionId) + } + }) +} + +func TestAccountRegistry_OtherAccounts(t *testing.T) { + registry := &AccountRegistry{ + Accounts: []AccountInfo{ + {Did: "did:plc:active", Handle: "active", SessionId: "s1"}, + {Did: "did:plc:other1", Handle: "other1", SessionId: "s2"}, + {Did: "did:plc:other2", Handle: "other2", SessionId: "s3"}, + }, + } + + others := registry.OtherAccounts("did:plc:active") + + if len(others) != 2 { + t.Errorf("OtherAccounts() len = %d, want 2", len(others)) + } + + for _, acc := range others { + if acc.Did == "did:plc:active" { + t.Errorf("OtherAccounts() should not include active account") + } + } + + hasDid := func(did string) bool { + for _, acc := range others { + if acc.Did == did { + return true + } + } + return false + } + + if !hasDid("did:plc:other1") || !hasDid("did:plc:other2") { + t.Errorf("OtherAccounts() missing expected accounts") + } +} + +func TestMultiAccountUser_Did(t *testing.T) { + t.Run("with active user", func(t *testing.T) { + user := &MultiAccountUser{ + Active: &User{Did: "did:plc:test", Pds: "https://bsky.social"}, + } + if user.Did() != "did:plc:test" { + t.Errorf("Did() = %s, want did:plc:test", user.Did()) + } + }) + + t.Run("with nil active", func(t *testing.T) { + user := &MultiAccountUser{Active: nil} + if user.Did() != "" { + t.Errorf("Did() = %s, want empty string", user.Did()) + } + }) +} diff --git a/appview/oauth/consts.go b/appview/oauth/consts.go index f86ca6a0..ab736ce7 100644 --- a/appview/oauth/consts.go +++ b/appview/oauth/consts.go @@ -1,7 +1,11 @@ package oauth const ( - SessionName = "appview-session-v2" + SessionName = "appview-session-v2" + AccountsName = "appview-accounts-v2" + AuthReturnName = "appview-auth-return" + AuthReturnURL = "return_url" + AuthAddAccount = "add_account" SessionHandle = "handle" SessionDid = "did" SessionId = "id" -- 2.43.0