1package auth
2
3import (
4 "net/http"
5 "net/http/httptest"
6 "testing"
7 "time"
8
9 "github.com/bluesky-social/indigo/atproto/crypto"
10 "github.com/bluesky-social/indigo/atproto/identity"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17func webHome(w http.ResponseWriter, r *http.Request) {
18 ctx := r.Context()
19
20 w.WriteHeader(http.StatusOK)
21 did, ok := ctx.Value("did").(syntax.DID)
22 if ok {
23 w.Write([]byte(did.String()))
24 } else {
25 w.Write([]byte("hello world"))
26 }
27}
28
29func TestAdminAuthMiddleware(t *testing.T) {
30 assert := assert.New(t)
31
32 pw1 := "secret123"
33 pw2 := "secret789"
34
35 req := httptest.NewRequest(http.MethodGet, "/", nil)
36 middle := AdminAuthMiddleware(webHome, []string{pw1, pw2})
37
38 {
39 resp := httptest.NewRecorder()
40 middle(resp, req)
41 assert.Equal(http.StatusUnauthorized, resp.Code)
42 }
43
44 {
45 resp := httptest.NewRecorder()
46 req.SetBasicAuth("admin", pw1)
47 middle(resp, req)
48 assert.Equal(http.StatusOK, resp.Code)
49 }
50
51 {
52 resp := httptest.NewRecorder()
53 req.SetBasicAuth("admin", pw2)
54 middle(resp, req)
55 assert.Equal(http.StatusOK, resp.Code)
56 }
57
58 {
59 resp := httptest.NewRecorder()
60 req.SetBasicAuth("wrong", pw2)
61 middle(resp, req)
62 assert.Equal(http.StatusUnauthorized, resp.Code)
63 }
64
65 {
66 resp := httptest.NewRecorder()
67 req.SetBasicAuth("admin", "wrong")
68 middle(resp, req)
69 assert.Equal(http.StatusUnauthorized, resp.Code)
70 }
71}
72
73func TestServiceAuthMiddleware(t *testing.T) {
74 assert := assert.New(t)
75 require := require.New(t)
76
77 iss := syntax.DID("did:example:iss")
78 aud := "did:example:aud#svc"
79 lxm := syntax.NSID("com.example.api")
80
81 priv, err := crypto.GeneratePrivateKeyP256()
82 require.NoError(err)
83 pub, err := priv.PublicKey()
84 require.NoError(err)
85
86 dir := identity.NewMockDirectory()
87 dir.Insert(identity.Identity{
88 DID: iss,
89 Keys: map[string]identity.VerificationMethod{
90 "atproto": {
91 Type: "Multikey",
92 PublicKeyMultibase: pub.Multibase(),
93 },
94 },
95 })
96
97 v := ServiceAuthValidator{
98 Audience: aud,
99 Dir: &dir,
100 }
101
102 {
103 // optional middleware, no auth
104 req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
105 middle := v.Middleware(webHome, false)
106 resp := httptest.NewRecorder()
107 middle(resp, req)
108 assert.Equal(http.StatusOK, resp.Code)
109 assert.Equal("hello world", string(resp.Body.Bytes()))
110 }
111
112 {
113 // mandatory middleware, no auth
114 req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
115 middle := v.Middleware(webHome, true)
116 resp := httptest.NewRecorder()
117 middle(resp, req)
118 assert.Equal(http.StatusUnauthorized, resp.Code)
119 }
120
121 {
122 // mandatory middleware, valid auth
123 tok, err := SignServiceAuth(iss, aud, time.Minute, &lxm, priv)
124 require.NoError(err)
125 req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
126 req.Header.Set("Authorization", "Bearer "+tok)
127 middle := v.Middleware(webHome, true)
128 resp := httptest.NewRecorder()
129 middle(resp, req)
130 assert.Equal(http.StatusOK, resp.Code)
131 assert.Equal(iss.String(), string(resp.Body.Bytes()))
132 }
133
134 {
135 // mangled header
136 req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
137 req.Header.Set("Authorization", "Bearer dummy")
138 middle := v.Middleware(webHome, false)
139 resp := httptest.NewRecorder()
140 middle(resp, req)
141 assert.Equal(http.StatusUnauthorized, resp.Code)
142 }
143
144 {
145 // wrong path
146 tok, err := SignServiceAuth(iss, aud, time.Minute, &lxm, priv)
147 require.NoError(err)
148 req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.other.api", nil)
149 req.Header.Set("Authorization", "Bearer "+tok)
150 middle := v.Middleware(webHome, true)
151 resp := httptest.NewRecorder()
152 middle(resp, req)
153 assert.Equal(http.StatusUnauthorized, resp.Code)
154 }
155}