+79
atproto/auth/http.go
+79
atproto/auth/http.go
···
1
+
package auth
2
+
3
+
import (
4
+
"context"
5
+
"crypto/subtle"
6
+
"net/http"
7
+
"strings"
8
+
9
+
"github.com/bluesky-social/indigo/atproto/syntax"
10
+
)
11
+
12
+
// HTTP Middleware for atproto admin auth, which is HTTP Basic auth with the username "admin".
13
+
//
14
+
// This supports multiple admin passwords, which makes it easier to rotate service secrets.
15
+
//
16
+
// This can be used with `echo.WrapMiddleware` (part of the echo web framework)
17
+
func AdminAuthMiddleware(handler http.HandlerFunc, adminPasswords []string) http.HandlerFunc {
18
+
return func(w http.ResponseWriter, r *http.Request) {
19
+
username, password, ok := r.BasicAuth()
20
+
if ok && username == "admin" {
21
+
for _, pw := range adminPasswords {
22
+
if subtle.ConstantTimeCompare([]byte(pw), []byte(password)) == 1 {
23
+
handler(w, r)
24
+
return
25
+
}
26
+
}
27
+
}
28
+
w.Header().Set("WWW-Authenticate", `Basic realm="admin", charset="UTF-8"`)
29
+
// TODO: XRPC error body?
30
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
31
+
}
32
+
}
33
+
34
+
// HTTP Middleware for inter-service auth, which is HTTP Bearer with JWT.
35
+
//
36
+
// 'mandatory' indicates whether valid inter-service auth must be present, or just optional.
37
+
func (v *ServiceAuthValidator) Middleware(handler http.HandlerFunc, mandatory bool) http.HandlerFunc {
38
+
return func(w http.ResponseWriter, r *http.Request) {
39
+
40
+
if hdr := r.Header.Get("Authorization"); hdr != "" {
41
+
parts := strings.Split(hdr, " ")
42
+
if parts[0] != "Bearer" || len(parts) != 2 {
43
+
// TODO: XRPC error body?
44
+
w.Header().Set("WWW-Authenticate", "Bearer")
45
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
46
+
return
47
+
}
48
+
49
+
var lxm *syntax.NSID
50
+
uparts := strings.Split(r.URL.Path, "/")
51
+
// TODO: should this "fail closed"? eg, reject if not a valid XRPC endpoint
52
+
if len(uparts) >= 3 && uparts[1] == "xrpc" {
53
+
nsid, err := syntax.ParseNSID(uparts[2])
54
+
if nil == err {
55
+
lxm = &nsid
56
+
}
57
+
}
58
+
59
+
did, err := v.Validate(r.Context(), parts[1], lxm)
60
+
if err != nil {
61
+
w.Header().Set("WWW-Authenticate", "Bearer")
62
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
63
+
// TODO: XRPC error body?
64
+
return
65
+
}
66
+
ctx := context.WithValue(r.Context(), "did", did)
67
+
handler(w, r.WithContext(ctx))
68
+
return
69
+
}
70
+
71
+
if mandatory {
72
+
// TODO: XRPC error body?
73
+
w.Header().Set("WWW-Authenticate", "Bearer")
74
+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
75
+
return
76
+
}
77
+
handler(w, r)
78
+
}
79
+
}
+161
atproto/auth/http_test.go
+161
atproto/auth/http_test.go
···
1
+
package auth
2
+
3
+
import (
4
+
"net/http"
5
+
"net/http/httptest"
6
+
"testing"
7
+
"time"
8
+
9
+
"github.com/bluesky-social/indigo/atproto/crypto"
10
+
"github.com/bluesky-social/indigo/atproto/identity"
11
+
"github.com/bluesky-social/indigo/atproto/syntax"
12
+
13
+
"github.com/stretchr/testify/assert"
14
+
)
15
+
16
+
func webHome(w http.ResponseWriter, r *http.Request) {
17
+
ctx := r.Context()
18
+
19
+
w.WriteHeader(http.StatusOK)
20
+
did, ok := ctx.Value("did").(syntax.DID)
21
+
if ok {
22
+
w.Write([]byte(did.String()))
23
+
} else {
24
+
w.Write([]byte("hello world"))
25
+
}
26
+
}
27
+
28
+
func TestAdminAuthMiddleware(t *testing.T) {
29
+
assert := assert.New(t)
30
+
31
+
pw1 := "secret123"
32
+
pw2 := "secret789"
33
+
34
+
req := httptest.NewRequest(http.MethodGet, "/", nil)
35
+
middle := AdminAuthMiddleware(webHome, []string{pw1, pw2})
36
+
37
+
{
38
+
resp := httptest.NewRecorder()
39
+
middle(resp, req)
40
+
assert.Equal(http.StatusUnauthorized, resp.Code)
41
+
}
42
+
43
+
{
44
+
resp := httptest.NewRecorder()
45
+
req.SetBasicAuth("admin", pw1)
46
+
middle(resp, req)
47
+
assert.Equal(http.StatusOK, resp.Code)
48
+
}
49
+
50
+
{
51
+
resp := httptest.NewRecorder()
52
+
req.SetBasicAuth("admin", pw2)
53
+
middle(resp, req)
54
+
assert.Equal(http.StatusOK, resp.Code)
55
+
}
56
+
57
+
{
58
+
resp := httptest.NewRecorder()
59
+
req.SetBasicAuth("wrong", pw2)
60
+
middle(resp, req)
61
+
assert.Equal(http.StatusUnauthorized, resp.Code)
62
+
}
63
+
64
+
{
65
+
resp := httptest.NewRecorder()
66
+
req.SetBasicAuth("admin", "wrong")
67
+
middle(resp, req)
68
+
assert.Equal(http.StatusUnauthorized, resp.Code)
69
+
}
70
+
}
71
+
72
+
func TestServiceAuthMiddleware(t *testing.T) {
73
+
assert := assert.New(t)
74
+
75
+
iss := syntax.DID("did:example:iss")
76
+
aud := "did:example:aud#svc"
77
+
lxm := syntax.NSID("com.example.api")
78
+
79
+
priv, err := crypto.GeneratePrivateKeyP256()
80
+
if err != nil {
81
+
t.Fatal(err)
82
+
}
83
+
pub, err := priv.PublicKey()
84
+
if err != nil {
85
+
t.Fatal(err)
86
+
}
87
+
88
+
dir := identity.NewMockDirectory()
89
+
dir.Insert(identity.Identity{
90
+
DID: iss,
91
+
Keys: map[string]identity.Key{
92
+
"atproto": identity.Key{
93
+
Type: "Multikey",
94
+
PublicKeyMultibase: pub.Multibase(),
95
+
},
96
+
},
97
+
})
98
+
99
+
v := ServiceAuthValidator{
100
+
Audience: aud,
101
+
Dir: &dir,
102
+
}
103
+
104
+
{
105
+
// optional middleware, no auth
106
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
107
+
middle := v.Middleware(webHome, false)
108
+
resp := httptest.NewRecorder()
109
+
middle(resp, req)
110
+
assert.Equal(http.StatusOK, resp.Code)
111
+
assert.Equal("hello world", string(resp.Body.Bytes()))
112
+
}
113
+
114
+
{
115
+
// mandatory middleware, no auth
116
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
117
+
middle := v.Middleware(webHome, true)
118
+
resp := httptest.NewRecorder()
119
+
middle(resp, req)
120
+
assert.Equal(http.StatusUnauthorized, resp.Code)
121
+
}
122
+
123
+
{
124
+
// mandatory middleware, valid auth
125
+
tok, err := SignServiceAuth(iss, aud, time.Minute, &lxm, priv)
126
+
if err != nil {
127
+
t.Fatal(err)
128
+
}
129
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
130
+
req.Header.Set("Authorization", "Bearer "+tok)
131
+
middle := v.Middleware(webHome, true)
132
+
resp := httptest.NewRecorder()
133
+
middle(resp, req)
134
+
assert.Equal(http.StatusOK, resp.Code)
135
+
assert.Equal(iss.String(), string(resp.Body.Bytes()))
136
+
}
137
+
138
+
{
139
+
// mangled header
140
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.api", nil)
141
+
req.Header.Set("Authorization", "Bearer dummy")
142
+
middle := v.Middleware(webHome, false)
143
+
resp := httptest.NewRecorder()
144
+
middle(resp, req)
145
+
assert.Equal(http.StatusUnauthorized, resp.Code)
146
+
}
147
+
148
+
{
149
+
// wrong path
150
+
tok, err := SignServiceAuth(iss, aud, time.Minute, &lxm, priv)
151
+
if err != nil {
152
+
t.Fatal(err)
153
+
}
154
+
req := httptest.NewRequest(http.MethodGet, "/xrpc/com.example.other.api", nil)
155
+
req.Header.Set("Authorization", "Bearer "+tok)
156
+
middle := v.Middleware(webHome, true)
157
+
resp := httptest.NewRecorder()
158
+
middle(resp, req)
159
+
assert.Equal(http.StatusUnauthorized, resp.Code)
160
+
}
161
+
}