+191
appview/oauth/accounts.go
+191
appview/oauth/accounts.go
···
1
+
package oauth
2
+
3
+
import (
4
+
"encoding/json"
5
+
"errors"
6
+
"net/http"
7
+
"time"
8
+
)
9
+
10
+
const MaxAccounts = 20
11
+
12
+
var ErrMaxAccountsReached = errors.New("maximum number of linked accounts reached")
13
+
14
+
type AccountInfo struct {
15
+
Did string `json:"did"`
16
+
Handle string `json:"handle"`
17
+
SessionId string `json:"session_id"`
18
+
AddedAt int64 `json:"added_at"`
19
+
}
20
+
21
+
type AccountRegistry struct {
22
+
Accounts []AccountInfo `json:"accounts"`
23
+
}
24
+
25
+
type MultiAccountUser struct {
26
+
Active *User
27
+
Accounts []AccountInfo
28
+
}
29
+
30
+
func (m *MultiAccountUser) Did() string {
31
+
if m.Active == nil {
32
+
return ""
33
+
}
34
+
return m.Active.Did
35
+
}
36
+
37
+
func (m *MultiAccountUser) Pds() string {
38
+
if m.Active == nil {
39
+
return ""
40
+
}
41
+
return m.Active.Pds
42
+
}
43
+
44
+
func (o *OAuth) GetAccounts(r *http.Request) *AccountRegistry {
45
+
session, err := o.SessStore.Get(r, AccountsName)
46
+
if err != nil || session.IsNew {
47
+
return &AccountRegistry{Accounts: []AccountInfo{}}
48
+
}
49
+
50
+
data, ok := session.Values["accounts"].(string)
51
+
if !ok {
52
+
return &AccountRegistry{Accounts: []AccountInfo{}}
53
+
}
54
+
55
+
var registry AccountRegistry
56
+
if err := json.Unmarshal([]byte(data), ®istry); err != nil {
57
+
return &AccountRegistry{Accounts: []AccountInfo{}}
58
+
}
59
+
60
+
return ®istry
61
+
}
62
+
63
+
func (o *OAuth) SaveAccounts(w http.ResponseWriter, r *http.Request, registry *AccountRegistry) error {
64
+
session, err := o.SessStore.Get(r, AccountsName)
65
+
if err != nil {
66
+
return err
67
+
}
68
+
69
+
data, err := json.Marshal(registry)
70
+
if err != nil {
71
+
return err
72
+
}
73
+
74
+
session.Values["accounts"] = string(data)
75
+
session.Options.MaxAge = 60 * 60 * 24 * 365
76
+
session.Options.HttpOnly = true
77
+
session.Options.Secure = !o.Config.Core.Dev
78
+
session.Options.SameSite = http.SameSiteLaxMode
79
+
80
+
return session.Save(r, w)
81
+
}
82
+
83
+
func (r *AccountRegistry) AddAccount(did, handle, sessionId string) error {
84
+
for i, acc := range r.Accounts {
85
+
if acc.Did == did {
86
+
r.Accounts[i].SessionId = sessionId
87
+
r.Accounts[i].Handle = handle
88
+
return nil
89
+
}
90
+
}
91
+
92
+
if len(r.Accounts) >= MaxAccounts {
93
+
return ErrMaxAccountsReached
94
+
}
95
+
96
+
r.Accounts = append(r.Accounts, AccountInfo{
97
+
Did: did,
98
+
Handle: handle,
99
+
SessionId: sessionId,
100
+
AddedAt: time.Now().Unix(),
101
+
})
102
+
return nil
103
+
}
104
+
105
+
func (r *AccountRegistry) RemoveAccount(did string) {
106
+
filtered := make([]AccountInfo, 0, len(r.Accounts))
107
+
for _, acc := range r.Accounts {
108
+
if acc.Did != did {
109
+
filtered = append(filtered, acc)
110
+
}
111
+
}
112
+
r.Accounts = filtered
113
+
}
114
+
115
+
func (r *AccountRegistry) FindAccount(did string) *AccountInfo {
116
+
for i := range r.Accounts {
117
+
if r.Accounts[i].Did == did {
118
+
return &r.Accounts[i]
119
+
}
120
+
}
121
+
return nil
122
+
}
123
+
124
+
func (r *AccountRegistry) OtherAccounts(activeDid string) []AccountInfo {
125
+
result := make([]AccountInfo, 0, len(r.Accounts))
126
+
for _, acc := range r.Accounts {
127
+
if acc.Did != activeDid {
128
+
result = append(result, acc)
129
+
}
130
+
}
131
+
return result
132
+
}
133
+
134
+
func (o *OAuth) GetMultiAccountUser(r *http.Request) *MultiAccountUser {
135
+
user := o.GetUser(r)
136
+
if user == nil {
137
+
return nil
138
+
}
139
+
140
+
registry := o.GetAccounts(r)
141
+
return &MultiAccountUser{
142
+
Active: user,
143
+
Accounts: registry.Accounts,
144
+
}
145
+
}
146
+
147
+
type AuthReturnInfo struct {
148
+
ReturnURL string
149
+
AddAccount bool
150
+
}
151
+
152
+
func (o *OAuth) SetAuthReturn(w http.ResponseWriter, r *http.Request, returnURL string, addAccount bool) error {
153
+
session, err := o.SessStore.Get(r, AuthReturnName)
154
+
if err != nil {
155
+
return err
156
+
}
157
+
158
+
session.Values[AuthReturnURL] = returnURL
159
+
session.Values[AuthAddAccount] = addAccount
160
+
session.Options.MaxAge = 60 * 30
161
+
session.Options.HttpOnly = true
162
+
session.Options.Secure = !o.Config.Core.Dev
163
+
session.Options.SameSite = http.SameSiteLaxMode
164
+
165
+
return session.Save(r, w)
166
+
}
167
+
168
+
func (o *OAuth) GetAuthReturn(r *http.Request) *AuthReturnInfo {
169
+
session, err := o.SessStore.Get(r, AuthReturnName)
170
+
if err != nil || session.IsNew {
171
+
return &AuthReturnInfo{}
172
+
}
173
+
174
+
returnURL, _ := session.Values[AuthReturnURL].(string)
175
+
addAccount, _ := session.Values[AuthAddAccount].(bool)
176
+
177
+
return &AuthReturnInfo{
178
+
ReturnURL: returnURL,
179
+
AddAccount: addAccount,
180
+
}
181
+
}
182
+
183
+
func (o *OAuth) ClearAuthReturn(w http.ResponseWriter, r *http.Request) error {
184
+
session, err := o.SessStore.Get(r, AuthReturnName)
185
+
if err != nil {
186
+
return err
187
+
}
188
+
189
+
session.Options.MaxAge = -1
190
+
return session.Save(r, w)
191
+
}
+265
appview/oauth/accounts_test.go
+265
appview/oauth/accounts_test.go
···
1
+
package oauth
2
+
3
+
import (
4
+
"testing"
5
+
)
6
+
7
+
func TestAccountRegistry_AddAccount(t *testing.T) {
8
+
tests := []struct {
9
+
name string
10
+
initial []AccountInfo
11
+
addDid string
12
+
addHandle string
13
+
addSessionId string
14
+
wantErr error
15
+
wantLen int
16
+
wantSessionId string
17
+
}{
18
+
{
19
+
name: "add first account",
20
+
initial: []AccountInfo{},
21
+
addDid: "did:plc:abc123",
22
+
addHandle: "alice.bsky.social",
23
+
addSessionId: "session-1",
24
+
wantErr: nil,
25
+
wantLen: 1,
26
+
wantSessionId: "session-1",
27
+
},
28
+
{
29
+
name: "add second account",
30
+
initial: []AccountInfo{
31
+
{Did: "did:plc:abc123", Handle: "alice.bsky.social", SessionId: "session-1", AddedAt: 1000},
32
+
},
33
+
addDid: "did:plc:def456",
34
+
addHandle: "bob.bsky.social",
35
+
addSessionId: "session-2",
36
+
wantErr: nil,
37
+
wantLen: 2,
38
+
wantSessionId: "session-2",
39
+
},
40
+
{
41
+
name: "update existing account session",
42
+
initial: []AccountInfo{
43
+
{Did: "did:plc:abc123", Handle: "alice.bsky.social", SessionId: "old-session", AddedAt: 1000},
44
+
},
45
+
addDid: "did:plc:abc123",
46
+
addHandle: "alice.bsky.social",
47
+
addSessionId: "new-session",
48
+
wantErr: nil,
49
+
wantLen: 1,
50
+
wantSessionId: "new-session",
51
+
},
52
+
}
53
+
54
+
for _, tt := range tests {
55
+
t.Run(tt.name, func(t *testing.T) {
56
+
registry := &AccountRegistry{Accounts: tt.initial}
57
+
err := registry.AddAccount(tt.addDid, tt.addHandle, tt.addSessionId)
58
+
59
+
if err != tt.wantErr {
60
+
t.Errorf("AddAccount() error = %v, want %v", err, tt.wantErr)
61
+
}
62
+
63
+
if len(registry.Accounts) != tt.wantLen {
64
+
t.Errorf("AddAccount() len = %d, want %d", len(registry.Accounts), tt.wantLen)
65
+
}
66
+
67
+
found := registry.FindAccount(tt.addDid)
68
+
if found == nil {
69
+
t.Errorf("AddAccount() account not found after add")
70
+
return
71
+
}
72
+
73
+
if found.SessionId != tt.wantSessionId {
74
+
t.Errorf("AddAccount() sessionId = %s, want %s", found.SessionId, tt.wantSessionId)
75
+
}
76
+
})
77
+
}
78
+
}
79
+
80
+
func TestAccountRegistry_AddAccount_MaxLimit(t *testing.T) {
81
+
registry := &AccountRegistry{Accounts: make([]AccountInfo, 0, MaxAccounts)}
82
+
83
+
for i := range MaxAccounts {
84
+
err := registry.AddAccount("did:plc:user"+string(rune('a'+i)), "handle", "session")
85
+
if err != nil {
86
+
t.Fatalf("AddAccount() unexpected error on account %d: %v", i, err)
87
+
}
88
+
}
89
+
90
+
if len(registry.Accounts) != MaxAccounts {
91
+
t.Errorf("expected %d accounts, got %d", MaxAccounts, len(registry.Accounts))
92
+
}
93
+
94
+
err := registry.AddAccount("did:plc:overflow", "overflow", "session-overflow")
95
+
if err != ErrMaxAccountsReached {
96
+
t.Errorf("AddAccount() error = %v, want %v", err, ErrMaxAccountsReached)
97
+
}
98
+
99
+
if len(registry.Accounts) != MaxAccounts {
100
+
t.Errorf("account added despite max limit, got %d", len(registry.Accounts))
101
+
}
102
+
}
103
+
104
+
func TestAccountRegistry_RemoveAccount(t *testing.T) {
105
+
tests := []struct {
106
+
name string
107
+
initial []AccountInfo
108
+
removeDid string
109
+
wantLen int
110
+
wantDids []string
111
+
}{
112
+
{
113
+
name: "remove existing account",
114
+
initial: []AccountInfo{
115
+
{Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"},
116
+
{Did: "did:plc:def456", Handle: "bob", SessionId: "s2"},
117
+
},
118
+
removeDid: "did:plc:abc123",
119
+
wantLen: 1,
120
+
wantDids: []string{"did:plc:def456"},
121
+
},
122
+
{
123
+
name: "remove non-existing account",
124
+
initial: []AccountInfo{
125
+
{Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"},
126
+
},
127
+
removeDid: "did:plc:notfound",
128
+
wantLen: 1,
129
+
wantDids: []string{"did:plc:abc123"},
130
+
},
131
+
{
132
+
name: "remove last account",
133
+
initial: []AccountInfo{
134
+
{Did: "did:plc:abc123", Handle: "alice", SessionId: "s1"},
135
+
},
136
+
removeDid: "did:plc:abc123",
137
+
wantLen: 0,
138
+
wantDids: []string{},
139
+
},
140
+
{
141
+
name: "remove from empty registry",
142
+
initial: []AccountInfo{},
143
+
removeDid: "did:plc:abc123",
144
+
wantLen: 0,
145
+
wantDids: []string{},
146
+
},
147
+
}
148
+
149
+
for _, tt := range tests {
150
+
t.Run(tt.name, func(t *testing.T) {
151
+
registry := &AccountRegistry{Accounts: tt.initial}
152
+
registry.RemoveAccount(tt.removeDid)
153
+
154
+
if len(registry.Accounts) != tt.wantLen {
155
+
t.Errorf("RemoveAccount() len = %d, want %d", len(registry.Accounts), tt.wantLen)
156
+
}
157
+
158
+
for _, wantDid := range tt.wantDids {
159
+
if registry.FindAccount(wantDid) == nil {
160
+
t.Errorf("RemoveAccount() expected %s to remain", wantDid)
161
+
}
162
+
}
163
+
164
+
if registry.FindAccount(tt.removeDid) != nil && tt.wantLen < len(tt.initial) {
165
+
t.Errorf("RemoveAccount() %s should have been removed", tt.removeDid)
166
+
}
167
+
})
168
+
}
169
+
}
170
+
171
+
func TestAccountRegistry_FindAccount(t *testing.T) {
172
+
registry := &AccountRegistry{
173
+
Accounts: []AccountInfo{
174
+
{Did: "did:plc:first", Handle: "first", SessionId: "s1", AddedAt: 1000},
175
+
{Did: "did:plc:second", Handle: "second", SessionId: "s2", AddedAt: 2000},
176
+
{Did: "did:plc:third", Handle: "third", SessionId: "s3", AddedAt: 3000},
177
+
},
178
+
}
179
+
180
+
t.Run("find existing account", func(t *testing.T) {
181
+
found := registry.FindAccount("did:plc:second")
182
+
if found == nil {
183
+
t.Fatal("FindAccount() returned nil for existing account")
184
+
}
185
+
if found.Handle != "second" {
186
+
t.Errorf("FindAccount() handle = %s, want second", found.Handle)
187
+
}
188
+
if found.SessionId != "s2" {
189
+
t.Errorf("FindAccount() sessionId = %s, want s2", found.SessionId)
190
+
}
191
+
})
192
+
193
+
t.Run("find non-existing account", func(t *testing.T) {
194
+
found := registry.FindAccount("did:plc:notfound")
195
+
if found != nil {
196
+
t.Errorf("FindAccount() = %v, want nil", found)
197
+
}
198
+
})
199
+
200
+
t.Run("returned pointer is mutable", func(t *testing.T) {
201
+
found := registry.FindAccount("did:plc:first")
202
+
if found == nil {
203
+
t.Fatal("FindAccount() returned nil")
204
+
}
205
+
found.SessionId = "modified"
206
+
207
+
refetch := registry.FindAccount("did:plc:first")
208
+
if refetch.SessionId != "modified" {
209
+
t.Errorf("FindAccount() pointer not referencing original, got %s", refetch.SessionId)
210
+
}
211
+
})
212
+
}
213
+
214
+
func TestAccountRegistry_OtherAccounts(t *testing.T) {
215
+
registry := &AccountRegistry{
216
+
Accounts: []AccountInfo{
217
+
{Did: "did:plc:active", Handle: "active", SessionId: "s1"},
218
+
{Did: "did:plc:other1", Handle: "other1", SessionId: "s2"},
219
+
{Did: "did:plc:other2", Handle: "other2", SessionId: "s3"},
220
+
},
221
+
}
222
+
223
+
others := registry.OtherAccounts("did:plc:active")
224
+
225
+
if len(others) != 2 {
226
+
t.Errorf("OtherAccounts() len = %d, want 2", len(others))
227
+
}
228
+
229
+
for _, acc := range others {
230
+
if acc.Did == "did:plc:active" {
231
+
t.Errorf("OtherAccounts() should not include active account")
232
+
}
233
+
}
234
+
235
+
hasDid := func(did string) bool {
236
+
for _, acc := range others {
237
+
if acc.Did == did {
238
+
return true
239
+
}
240
+
}
241
+
return false
242
+
}
243
+
244
+
if !hasDid("did:plc:other1") || !hasDid("did:plc:other2") {
245
+
t.Errorf("OtherAccounts() missing expected accounts")
246
+
}
247
+
}
248
+
249
+
func TestMultiAccountUser_Did(t *testing.T) {
250
+
t.Run("with active user", func(t *testing.T) {
251
+
user := &MultiAccountUser{
252
+
Active: &User{Did: "did:plc:test", Pds: "https://bsky.social"},
253
+
}
254
+
if user.Did() != "did:plc:test" {
255
+
t.Errorf("Did() = %s, want did:plc:test", user.Did())
256
+
}
257
+
})
258
+
259
+
t.Run("with nil active", func(t *testing.T) {
260
+
user := &MultiAccountUser{Active: nil}
261
+
if user.Did() != "" {
262
+
t.Errorf("Did() = %s, want empty string", user.Did())
263
+
}
264
+
})
265
+
}
+5
-1
appview/oauth/consts.go
+5
-1
appview/oauth/consts.go
···
1
1
package oauth
2
2
3
3
const (
4
-
SessionName = "appview-session-v2"
4
+
SessionName = "appview-session-v2"
5
+
AccountsName = "appview-accounts-v2"
6
+
AuthReturnName = "appview-auth-return"
7
+
AuthReturnURL = "return_url"
8
+
AuthAddAccount = "add_account"
5
9
SessionHandle = "handle"
6
10
SessionDid = "did"
7
11
SessionId = "id"