oauth: add multi-account infrastructure

Changed files
+461 -1
appview
+191
appview/oauth/accounts.go
··· 1 + package oauth 2 + 3 + import ( 4 + "encoding/json" 5 + "errors" 6 + "net/http" 7 + "time" 8 + ) 9 + 10 + const MaxAccounts = 20 11 + 12 + var ErrMaxAccountsReached = errors.New("maximum number of linked accounts reached") 13 + 14 + type AccountInfo struct { 15 + Did string `json:"did"` 16 + Handle string `json:"handle"` 17 + SessionId string `json:"session_id"` 18 + AddedAt int64 `json:"added_at"` 19 + } 20 + 21 + type AccountRegistry struct { 22 + Accounts []AccountInfo `json:"accounts"` 23 + } 24 + 25 + type MultiAccountUser struct { 26 + Active *User 27 + Accounts []AccountInfo 28 + } 29 + 30 + func (m *MultiAccountUser) Did() string { 31 + if m.Active == nil { 32 + return "" 33 + } 34 + return m.Active.Did 35 + } 36 + 37 + func (m *MultiAccountUser) Pds() string { 38 + if m.Active == nil { 39 + return "" 40 + } 41 + return m.Active.Pds 42 + } 43 + 44 + func (o *OAuth) GetAccounts(r *http.Request) *AccountRegistry { 45 + session, err := o.SessStore.Get(r, AccountsName) 46 + if err != nil || session.IsNew { 47 + return &AccountRegistry{Accounts: []AccountInfo{}} 48 + } 49 + 50 + data, ok := session.Values["accounts"].(string) 51 + if !ok { 52 + return &AccountRegistry{Accounts: []AccountInfo{}} 53 + } 54 + 55 + var registry AccountRegistry 56 + if err := json.Unmarshal([]byte(data), &registry); err != nil { 57 + return &AccountRegistry{Accounts: []AccountInfo{}} 58 + } 59 + 60 + return &registry 61 + } 62 + 63 + func (o *OAuth) SaveAccounts(w http.ResponseWriter, r *http.Request, registry *AccountRegistry) error { 64 + session, err := o.SessStore.Get(r, AccountsName) 65 + if err != nil { 66 + return err 67 + } 68 + 69 + data, err := json.Marshal(registry) 70 + if err != nil { 71 + return err 72 + } 73 + 74 + session.Values["accounts"] = string(data) 75 + session.Options.MaxAge = 60 * 60 * 24 * 365 76 + session.Options.HttpOnly = true 77 + session.Options.Secure = !o.Config.Core.Dev 78 + session.Options.SameSite = http.SameSiteLaxMode 79 + 80 + return session.Save(r, w) 81 + } 82 + 83 + func (r *AccountRegistry) AddAccount(did, handle, sessionId string) error { 84 + for i, acc := range r.Accounts { 85 + if acc.Did == did { 86 + r.Accounts[i].SessionId = sessionId 87 + r.Accounts[i].Handle = handle 88 + return nil 89 + } 90 + } 91 + 92 + if len(r.Accounts) >= MaxAccounts { 93 + return ErrMaxAccountsReached 94 + } 95 + 96 + r.Accounts = append(r.Accounts, AccountInfo{ 97 + Did: did, 98 + Handle: handle, 99 + SessionId: sessionId, 100 + AddedAt: time.Now().Unix(), 101 + }) 102 + return nil 103 + } 104 + 105 + func (r *AccountRegistry) RemoveAccount(did string) { 106 + filtered := make([]AccountInfo, 0, len(r.Accounts)) 107 + for _, acc := range r.Accounts { 108 + if acc.Did != did { 109 + filtered = append(filtered, acc) 110 + } 111 + } 112 + r.Accounts = filtered 113 + } 114 + 115 + func (r *AccountRegistry) FindAccount(did string) *AccountInfo { 116 + for i := range r.Accounts { 117 + if r.Accounts[i].Did == did { 118 + return &r.Accounts[i] 119 + } 120 + } 121 + return nil 122 + } 123 + 124 + func (r *AccountRegistry) OtherAccounts(activeDid string) []AccountInfo { 125 + result := make([]AccountInfo, 0, len(r.Accounts)) 126 + for _, acc := range r.Accounts { 127 + if acc.Did != activeDid { 128 + result = append(result, acc) 129 + } 130 + } 131 + return result 132 + } 133 + 134 + func (o *OAuth) GetMultiAccountUser(r *http.Request) *MultiAccountUser { 135 + user := o.GetUser(r) 136 + if user == nil { 137 + return nil 138 + } 139 + 140 + registry := o.GetAccounts(r) 141 + return &MultiAccountUser{ 142 + Active: user, 143 + Accounts: registry.Accounts, 144 + } 145 + } 146 + 147 + type AuthReturnInfo struct { 148 + ReturnURL string 149 + AddAccount bool 150 + } 151 + 152 + func (o *OAuth) SetAuthReturn(w http.ResponseWriter, r *http.Request, returnURL string, addAccount bool) error { 153 + session, err := o.SessStore.Get(r, AuthReturnName) 154 + if err != nil { 155 + return err 156 + } 157 + 158 + session.Values[AuthReturnURL] = returnURL 159 + session.Values[AuthAddAccount] = addAccount 160 + session.Options.MaxAge = 60 * 30 161 + session.Options.HttpOnly = true 162 + session.Options.Secure = !o.Config.Core.Dev 163 + session.Options.SameSite = http.SameSiteLaxMode 164 + 165 + return session.Save(r, w) 166 + } 167 + 168 + func (o *OAuth) GetAuthReturn(r *http.Request) *AuthReturnInfo { 169 + session, err := o.SessStore.Get(r, AuthReturnName) 170 + if err != nil || session.IsNew { 171 + return &AuthReturnInfo{} 172 + } 173 + 174 + returnURL, _ := session.Values[AuthReturnURL].(string) 175 + addAccount, _ := session.Values[AuthAddAccount].(bool) 176 + 177 + return &AuthReturnInfo{ 178 + ReturnURL: returnURL, 179 + AddAccount: addAccount, 180 + } 181 + } 182 + 183 + func (o *OAuth) ClearAuthReturn(w http.ResponseWriter, r *http.Request) error { 184 + session, err := o.SessStore.Get(r, AuthReturnName) 185 + if err != nil { 186 + return err 187 + } 188 + 189 + session.Options.MaxAge = -1 190 + return session.Save(r, w) 191 + }
+265
appview/oauth/accounts_test.go
··· 1 + package oauth 2 + 3 + import ( 4 + "testing" 5 + ) 6 + 7 + func TestAccountRegistry_AddAccount(t *testing.T) { 8 + tests := []struct { 9 + name string 10 + initial []AccountInfo 11 + addDid string 12 + addHandle string 13 + addSessionId string 14 + wantErr error 15 + wantLen int 16 + wantSessionId string 17 + }{ 18 + { 19 + name: "add first account", 20 + initial: []AccountInfo{}, 21 + addDid: "did:plc:abc123", 22 + addHandle: "alice.bsky.social", 23 + addSessionId: "session-1", 24 + wantErr: nil, 25 + wantLen: 1, 26 + wantSessionId: "session-1", 27 + }, 28 + { 29 + name: "add second account", 30 + initial: []AccountInfo{ 31 + {Did: "did:plc:abc123", Handle: "alice.bsky.social", SessionId: "session-1", AddedAt: 1000}, 32 + }, 33 + addDid: "did:plc:def456", 34 + addHandle: "bob.bsky.social", 35 + addSessionId: "session-2", 36 + wantErr: nil, 37 + wantLen: 2, 38 + wantSessionId: "session-2", 39 + }, 40 + { 41 + name: "update existing account session", 42 + initial: []AccountInfo{ 43 + {Did: "did:plc:abc123", Handle: "alice.bsky.social", SessionId: "old-session", AddedAt: 1000}, 44 + }, 45 + addDid: "did:plc:abc123", 46 + addHandle: "alice.bsky.social", 47 + addSessionId: "new-session", 48 + wantErr: nil, 49 + wantLen: 1, 50 + wantSessionId: "new-session", 51 + }, 52 + } 53 + 54 + for _, tt := range tests { 55 + t.Run(tt.name, func(t *testing.T) { 56 + registry := &AccountRegistry{Accounts: tt.initial} 57 + err := registry.AddAccount(tt.addDid, tt.addHandle, tt.addSessionId) 58 + 59 + if err != tt.wantErr { 60 + t.Errorf("AddAccount() error = %v, want %v", err, tt.wantErr) 61 + } 62 + 63 + if len(registry.Accounts) != tt.wantLen { 64 + t.Errorf("AddAccount() len = %d, want %d", len(registry.Accounts), tt.wantLen) 65 + } 66 + 67 + found := registry.FindAccount(tt.addDid) 68 + if found == nil { 69 + t.Errorf("AddAccount() account not found after add") 70 + return 71 + } 72 + 73 + if found.SessionId != tt.wantSessionId { 74 + t.Errorf("AddAccount() sessionId = %s, want %s", found.SessionId, tt.wantSessionId) 75 + } 76 + }) 77 + } 78 + } 79 + 80 + func TestAccountRegistry_AddAccount_MaxLimit(t *testing.T) { 81 + registry := &AccountRegistry{Accounts: make([]AccountInfo, 0, MaxAccounts)} 82 + 83 + for i := range MaxAccounts { 84 + err := registry.AddAccount("did:plc:user"+string(rune('a'+i)), "handle", "session") 85 + if err != nil { 86 + t.Fatalf("AddAccount() unexpected error on account %d: %v", i, err) 87 + } 88 + } 89 + 90 + if len(registry.Accounts) != MaxAccounts { 91 + t.Errorf("expected %d accounts, got %d", MaxAccounts, len(registry.Accounts)) 92 + } 93 + 94 + err := registry.AddAccount("did:plc:overflow", "overflow", "session-overflow") 95 + if err != ErrMaxAccountsReached { 96 + t.Errorf("AddAccount() error = %v, want %v", err, ErrMaxAccountsReached) 97 + } 98 + 99 + if len(registry.Accounts) != MaxAccounts { 100 + t.Errorf("account added despite max limit, got %d", len(registry.Accounts)) 101 + } 102 + } 103 + 104 + func TestAccountRegistry_RemoveAccount(t *testing.T) { 105 + tests := []struct { 106 + name string 107 + initial []AccountInfo 108 + removeDid string 109 + wantLen int 110 + wantDids []string 111 + }{ 112 + { 113 + name: "remove existing account", 114 + initial: []AccountInfo{ 115 + {Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"}, 116 + {Did: "did:plc:def456", Handle: "bob", SessionId: "s2"}, 117 + }, 118 + removeDid: "did:plc:abc123", 119 + wantLen: 1, 120 + wantDids: []string{"did:plc:def456"}, 121 + }, 122 + { 123 + name: "remove non-existing account", 124 + initial: []AccountInfo{ 125 + {Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"}, 126 + }, 127 + removeDid: "did:plc:notfound", 128 + wantLen: 1, 129 + wantDids: []string{"did:plc:abc123"}, 130 + }, 131 + { 132 + name: "remove last account", 133 + initial: []AccountInfo{ 134 + {Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"}, 135 + }, 136 + removeDid: "did:plc:abc123", 137 + wantLen: 0, 138 + wantDids: []string{}, 139 + }, 140 + { 141 + name: "remove from empty registry", 142 + initial: []AccountInfo{}, 143 + removeDid: "did:plc:abc123", 144 + wantLen: 0, 145 + wantDids: []string{}, 146 + }, 147 + } 148 + 149 + for _, tt := range tests { 150 + t.Run(tt.name, func(t *testing.T) { 151 + registry := &AccountRegistry{Accounts: tt.initial} 152 + registry.RemoveAccount(tt.removeDid) 153 + 154 + if len(registry.Accounts) != tt.wantLen { 155 + t.Errorf("RemoveAccount() len = %d, want %d", len(registry.Accounts), tt.wantLen) 156 + } 157 + 158 + for _, wantDid := range tt.wantDids { 159 + if registry.FindAccount(wantDid) == nil { 160 + t.Errorf("RemoveAccount() expected %s to remain", wantDid) 161 + } 162 + } 163 + 164 + if registry.FindAccount(tt.removeDid) != nil && tt.wantLen < len(tt.initial) { 165 + t.Errorf("RemoveAccount() %s should have been removed", tt.removeDid) 166 + } 167 + }) 168 + } 169 + } 170 + 171 + func TestAccountRegistry_FindAccount(t *testing.T) { 172 + registry := &AccountRegistry{ 173 + Accounts: []AccountInfo{ 174 + {Did: "did:plc:first", Handle: "first", SessionId: "s1", AddedAt: 1000}, 175 + {Did: "did:plc:second", Handle: "second", SessionId: "s2", AddedAt: 2000}, 176 + {Did: "did:plc:third", Handle: "third", SessionId: "s3", AddedAt: 3000}, 177 + }, 178 + } 179 + 180 + t.Run("find existing account", func(t *testing.T) { 181 + found := registry.FindAccount("did:plc:second") 182 + if found == nil { 183 + t.Fatal("FindAccount() returned nil for existing account") 184 + } 185 + if found.Handle != "second" { 186 + t.Errorf("FindAccount() handle = %s, want second", found.Handle) 187 + } 188 + if found.SessionId != "s2" { 189 + t.Errorf("FindAccount() sessionId = %s, want s2", found.SessionId) 190 + } 191 + }) 192 + 193 + t.Run("find non-existing account", func(t *testing.T) { 194 + found := registry.FindAccount("did:plc:notfound") 195 + if found != nil { 196 + t.Errorf("FindAccount() = %v, want nil", found) 197 + } 198 + }) 199 + 200 + t.Run("returned pointer is mutable", func(t *testing.T) { 201 + found := registry.FindAccount("did:plc:first") 202 + if found == nil { 203 + t.Fatal("FindAccount() returned nil") 204 + } 205 + found.SessionId = "modified" 206 + 207 + refetch := registry.FindAccount("did:plc:first") 208 + if refetch.SessionId != "modified" { 209 + t.Errorf("FindAccount() pointer not referencing original, got %s", refetch.SessionId) 210 + } 211 + }) 212 + } 213 + 214 + func TestAccountRegistry_OtherAccounts(t *testing.T) { 215 + registry := &AccountRegistry{ 216 + Accounts: []AccountInfo{ 217 + {Did: "did:plc:active", Handle: "active", SessionId: "s1"}, 218 + {Did: "did:plc:other1", Handle: "other1", SessionId: "s2"}, 219 + {Did: "did:plc:other2", Handle: "other2", SessionId: "s3"}, 220 + }, 221 + } 222 + 223 + others := registry.OtherAccounts("did:plc:active") 224 + 225 + if len(others) != 2 { 226 + t.Errorf("OtherAccounts() len = %d, want 2", len(others)) 227 + } 228 + 229 + for _, acc := range others { 230 + if acc.Did == "did:plc:active" { 231 + t.Errorf("OtherAccounts() should not include active account") 232 + } 233 + } 234 + 235 + hasDid := func(did string) bool { 236 + for _, acc := range others { 237 + if acc.Did == did { 238 + return true 239 + } 240 + } 241 + return false 242 + } 243 + 244 + if !hasDid("did:plc:other1") || !hasDid("did:plc:other2") { 245 + t.Errorf("OtherAccounts() missing expected accounts") 246 + } 247 + } 248 + 249 + func TestMultiAccountUser_Did(t *testing.T) { 250 + t.Run("with active user", func(t *testing.T) { 251 + user := &MultiAccountUser{ 252 + Active: &User{Did: "did:plc:test", Pds: "https://bsky.social"}, 253 + } 254 + if user.Did() != "did:plc:test" { 255 + t.Errorf("Did() = %s, want did:plc:test", user.Did()) 256 + } 257 + }) 258 + 259 + t.Run("with nil active", func(t *testing.T) { 260 + user := &MultiAccountUser{Active: nil} 261 + if user.Did() != "" { 262 + t.Errorf("Did() = %s, want empty string", user.Did()) 263 + } 264 + }) 265 + }
+5 -1
appview/oauth/consts.go
··· 1 1 package oauth 2 2 3 3 const ( 4 - SessionName = "appview-session-v2" 4 + SessionName = "appview-session-v2" 5 + AccountsName = "appview-accounts-v2" 6 + AuthReturnName = "appview-auth-return" 7 + AuthReturnURL = "return_url" 8 + AuthAddAccount = "add_account" 5 9 SessionHandle = "handle" 6 10 SessionDid = "did" 7 11 SessionId = "id"