1package client
2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/bluesky-social/indigo/atproto/identity"
	"github.com/bluesky-social/indigo/atproto/syntax"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
21
// pwHandler is a fake PDS used by TestPasswordAuth. It implements the
// com.atproto.server session endpoints plus a few com.example.* endpoints
// that require a valid access token.
//
// Behavior by endpoint:
//   - createSession: requires a JSON body with identifier
//     "did:web:account.example.com" and password "password1"; returns
//     accessJwt "access1" / refreshJwt "refresh1".
//   - refreshSession: requires "Bearer refresh1"; returns "access2" / "refresh2".
//   - deleteSession: requires "Bearer refresh1"; succeeds with an empty body.
//   - com.example.get / com.example.post: accept "Bearer access1" or
//     "Bearer access2".
//   - com.example.expire: answers "Bearer access1" with a 400 ExpiredToken
//     error and "Bearer access2" with success, so a client has to refresh and
//     retry to get through.
//
// Unknown paths return 404.
func pwHandler(w http.ResponseWriter, r *http.Request) {
	// writeJSON marshals v before touching the response, so an encoding
	// failure can still be reported as a 500 instead of a truncated 200.
	writeJSON := func(status int, v any) {
		body, err := json.Marshal(v)
		if err != nil {
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// Trailing newline matches json.Encoder / fmt.Fprintln output. A write
		// error means the client went away; a test handler cannot do anything
		// useful about that.
		_, _ = w.Write(append(body, '\n'))
	}

	// unauthorized logs the rejected Authorization header (to help debug a
	// failing test) and replies 401 with a Bearer challenge.
	unauthorized := func(endpoint, hdr string) {
		fmt.Printf("%s header: %q\n", endpoint, hdr)
		w.Header().Set("WWW-Authenticate", `Bearer`)
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	}

	hdr := r.Header.Get("Authorization")

	switch r.URL.Path {
	case "/xrpc/com.atproto.server.refreshSession":
		if hdr != "Bearer refresh1" {
			unauthorized("refreshSession", hdr)
			return
		}
		writeJSON(http.StatusOK, map[string]string{
			"did":        "did:web:account.example.com",
			"accessJwt":  "access2",
			"refreshJwt": "refresh2",
		})

	case "/xrpc/com.atproto.server.deleteSession":
		if hdr != "Bearer refresh1" {
			unauthorized("deleteSession", hdr)
			return
		}
		// Success with no response body (implicit 200).
		w.Header().Set("Content-Type", "application/json")

	case "/xrpc/com.atproto.server.createSession":
		if !strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") {
			fmt.Println("createSession Content-Type")
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}
		var body map[string]string
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			fmt.Println("createSession JSON")
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}
		if body["identifier"] != "did:web:account.example.com" || body["password"] != "password1" {
			fmt.Println("createSession wrong password")
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		writeJSON(http.StatusOK, map[string]string{
			"did":        body["identifier"],
			"accessJwt":  "access1",
			"refreshJwt": "refresh1",
		})

	case "/xrpc/com.example.get", "/xrpc/com.example.post":
		if hdr != "Bearer access1" && hdr != "Bearer access2" {
			unauthorized("get", hdr)
			return
		}
		writeJSON(http.StatusOK, map[string]string{"status": "success"})

	case "/xrpc/com.example.expire":
		switch hdr {
		case "Bearer access1":
			// Pretend the first access token has expired so the client must
			// refresh its session and retry the request.
			writeJSON(http.StatusBadRequest, map[string]string{"error": "ExpiredToken"})
		case "Bearer access2":
			writeJSON(http.StatusOK, map[string]string{"status": "success"})
		default:
			unauthorized("expire", hdr)
		}

	default:
		http.NotFound(w, r)
	}
}
111
112func TestPasswordAuth(t *testing.T) {
113 assert := assert.New(t)
114 require := require.New(t)
115 ctx := context.Background()
116
117 srv := httptest.NewServer(http.HandlerFunc(pwHandler))
118 defer srv.Close()
119
120 dir := identity.NewMockDirectory()
121 dir.Insert(identity.Identity{
122 DID: "did:web:account.example.com",
123 Handle: "user1.example.com",
124 Services: map[string]identity.ServiceEndpoint{
125 "atproto_pds": {
126 Type: "AtprotoPersonalDataServer",
127 URL: srv.URL,
128 },
129 },
130 })
131
132 {
133 // simple GET requests, with token expire/retry
134 c, err := LoginWithPassword(ctx, &dir, syntax.Handle("user1.example.com").AtIdentifier(), "password1", "", nil)
135 require.NoError(err)
136 err = c.Get(ctx, syntax.NSID("com.example.get"), nil, nil)
137 assert.NoError(err)
138 err = c.Get(ctx, syntax.NSID("com.example.expire"), nil, nil)
139 assert.NoError(err)
140 }
141
142 {
143 // test resume session, and session data callback mechanism
144 ch := make(chan string, 10)
145 cb := func(ctx context.Context, data PasswordSessionData) {
146 assert.Equal("refresh2", data.RefreshToken)
147 ch <- "refreshed"
148 }
149 c := ResumePasswordSession(PasswordSessionData{
150 AccessToken: "access1",
151 RefreshToken: "refresh1",
152 AccountDID: syntax.DID("did:web:account.example.com"),
153 Host: srv.URL,
154 }, cb)
155
156 err := c.Get(ctx, syntax.NSID("com.example.get"), nil, nil)
157 assert.NoError(err)
158 err = c.Get(ctx, syntax.NSID("com.example.expire"), nil, nil)
159 assert.NoError(err)
160
161 select {
162 case msg := <-ch:
163 assert.Equal("refreshed", msg)
164 }
165 }
166
167 {
168 // logout
169 c, err := LoginWithPassword(ctx, &dir, syntax.Handle("user1.example.com").AtIdentifier(), "password1", "", nil)
170 require.NoError(err)
171
172 passAuth, ok := c.Auth.(*PasswordAuth)
173 require.True(ok)
174 err = passAuth.Logout(ctx, c.Client)
175 assert.NoError(err)
176 }
177
178 {
179 // simple POST request, with token expire/retry
180 c, err := LoginWithPassword(ctx, &dir, syntax.Handle("user1.example.com").AtIdentifier(), "password1", "", nil)
181 require.NoError(err)
182 body := map[string]any{
183 "a": 123,
184 "b": "hello",
185 }
186 var out json.RawMessage
187 err = c.Post(ctx, syntax.NSID("com.example.post"), body, &out)
188 assert.NoError(err)
189 err = c.Post(ctx, syntax.NSID("com.example.expire"), body, &out)
190 assert.NoError(err)
191 }
192
193 {
194 // POST with bytes.Buffer body
195 c, err := LoginWithPassword(ctx, &dir, syntax.Handle("user1.example.com").AtIdentifier(), "password1", "", nil)
196 require.NoError(err)
197 body := bytes.NewBufferString("some text")
198 req := NewAPIRequest(MethodProcedure, syntax.NSID("com.example.expire"), body)
199 req.Headers.Set("Content-Type", "text/plain")
200 resp, err := c.Do(ctx, req)
201 require.NoError(err)
202 assert.Equal(200, resp.StatusCode)
203 }
204
205 {
206 // POST with file on disk (can seek and retry)
207 c, err := LoginWithPassword(ctx, &dir, syntax.Handle("user1.example.com").AtIdentifier(), "password1", "", nil)
208 require.NoError(err)
209 f, err := os.Open("testdata/body.json")
210 require.NoError(err)
211 req := NewAPIRequest(MethodProcedure, syntax.NSID("com.example.expire"), f)
212 req.Headers.Set("Content-Type", "application/json")
213 resp, err := c.Do(ctx, req)
214 require.NoError(err)
215 assert.Equal(200, resp.StatusCode)
216 }
217
218 {
219 // POST with pipe reader (can *not* retry)
220 c, err := LoginWithPassword(ctx, &dir, syntax.Handle("user1.example.com").AtIdentifier(), "password1", "", nil)
221 require.NoError(err)
222 r1, w1 := io.Pipe()
223 go func() {
224 fmt.Fprintf(w1, "some data")
225 w1.Close()
226 }()
227 req1 := NewAPIRequest(MethodProcedure, syntax.NSID("com.example.post"), r1)
228 req1.Headers.Set("Content-Type", "text/plain")
229 resp, err := c.Do(ctx, req1)
230 require.NoError(err)
231 assert.Equal(200, resp.StatusCode)
232
233 // expect this to fail (can't re-read from Pipe)
234 r2, w2 := io.Pipe()
235 go func() {
236 fmt.Fprintf(w2, "some data")
237 w2.Close()
238 }()
239 req2 := NewAPIRequest(MethodProcedure, syntax.NSID("com.example.expire"), r2)
240 req2.Headers.Set("Content-Type", "text/plain")
241 _, err = c.Do(ctx, req2)
242 assert.Error(err)
243 }
244}