+14
-2
appview/oauth/handler.go
+14
-2
appview/oauth/handler.go
···
55
55
ctx := r.Context()
56
56
l := o.Logger.With("query", r.URL.Query())
57
57
58
+
authReturn := o.GetAuthReturn(r)
59
+
_ = o.ClearAuthReturn(w, r)
60
+
58
61
sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
59
62
if err != nil {
60
63
var callbackErr *oauth.AuthRequestCallbackError
···
70
73
71
74
if err := o.SaveSession(w, r, sessData); err != nil {
72
75
l.Error("failed to save session", "data", sessData, "err", err)
73
-
http.Redirect(w, r, "/login?error=session", http.StatusFound)
76
+
errorCode := "session"
77
+
if errors.Is(err, ErrMaxAccountsReached) {
78
+
errorCode = "max_accounts"
79
+
}
80
+
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound)
74
81
return
75
82
}
76
83
···
88
95
}
89
96
}
90
97
91
-
http.Redirect(w, r, "/", http.StatusFound)
98
+
redirectURL := "/"
99
+
if authReturn.ReturnURL != "" {
100
+
redirectURL = authReturn.ReturnURL
101
+
}
102
+
103
+
http.Redirect(w, r, redirectURL, http.StatusFound)
92
104
}
93
105
94
106
func (o *OAuth) addToDefaultSpindle(did string) {
+66
-4
appview/oauth/oauth.go
+66
-4
appview/oauth/oauth.go
···
98
98
}
99
99
100
100
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
101
-
// first we save the did in the user session
102
101
userSession, err := o.SessStore.Get(r, SessionName)
103
102
if err != nil {
104
103
return err
···
108
107
userSession.Values[SessionPds] = sessData.HostURL
109
108
userSession.Values[SessionId] = sessData.SessionID
110
109
userSession.Values[SessionAuthenticated] = true
111
-
return userSession.Save(r, w)
110
+
111
+
if err := userSession.Save(r, w); err != nil {
112
+
return err
113
+
}
114
+
115
+
handle := ""
116
+
resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
117
+
if err == nil && resolved.Handle.String() != "" {
118
+
handle = resolved.Handle.String()
119
+
}
120
+
121
+
registry := o.GetAccounts(r)
122
+
if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
123
+
return err
124
+
}
125
+
return o.SaveAccounts(w, r, registry)
112
126
}
113
127
114
128
func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
···
163
177
return errors.Join(err1, err2)
164
178
}
165
179
180
+
func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
181
+
registry := o.GetAccounts(r)
182
+
account := registry.FindAccount(targetDid)
183
+
if account == nil {
184
+
return fmt.Errorf("account not found in registry: %s", targetDid)
185
+
}
186
+
187
+
did, err := syntax.ParseDID(targetDid)
188
+
if err != nil {
189
+
return fmt.Errorf("invalid DID: %w", err)
190
+
}
191
+
192
+
sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId)
193
+
if err != nil {
194
+
registry.RemoveAccount(targetDid)
195
+
_ = o.SaveAccounts(w, r, registry)
196
+
return fmt.Errorf("session expired for account: %w", err)
197
+
}
198
+
199
+
userSession, err := o.SessStore.Get(r, SessionName)
200
+
if err != nil {
201
+
return err
202
+
}
203
+
204
+
userSession.Values[SessionDid] = sess.Data.AccountDID.String()
205
+
userSession.Values[SessionPds] = sess.Data.HostURL
206
+
userSession.Values[SessionId] = sess.Data.SessionID
207
+
userSession.Values[SessionAuthenticated] = true
208
+
209
+
return userSession.Save(r, w)
210
+
}
211
+
212
+
func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
213
+
registry := o.GetAccounts(r)
214
+
account := registry.FindAccount(targetDid)
215
+
if account == nil {
216
+
return nil
217
+
}
218
+
219
+
did, err := syntax.ParseDID(targetDid)
220
+
if err == nil {
221
+
_ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
222
+
}
223
+
224
+
registry.RemoveAccount(targetDid)
225
+
return o.SaveAccounts(w, r, registry)
226
+
}
227
+
166
228
type User struct {
167
229
Did string
168
230
Pds string
···
181
243
}
182
244
183
245
func (o *OAuth) GetDid(r *http.Request) string {
184
-
if u := o.GetUser(r); u != nil {
185
-
return u.Did
246
+
if u := o.GetMultiAccountUser(r); u != nil {
247
+
return u.Did()
186
248
}
187
249
188
250
return ""