tangled
alpha
login
or
join now
stream.place
/
streamplace
Live video on the AT Protocol
74
fork
atom
overview
issues
1
pulls
pipelines
fix: untrusted resolution blocking all v4 addresses
Natalie B.
2 months ago
559b4ac1
68cb0b5a
+182
-2
3 changed files
expand all
collapse all
unified
split
pkg
aqhttp
aqhttp.go
aqhttp_test.go
resolv.go
+9
-1
pkg/aqhttp/aqhttp.go
···
20
20
// where the validation overhead is problematic.
21
21
var TrustedClient http.Client
22
22
23
23
+
type ClientOptions struct {
24
24
+
OverrideInTest bool
25
25
+
}
26
26
+
27
27
+
var defaultClientOptions = ClientOptions{
28
28
+
OverrideInTest: true,
29
29
+
}
30
30
+
23
31
func init() {
24
32
// Initialize the trusted client first.
25
33
TrustedClient = http.Client{
···
38
46
39
47
// When running under `go test` the test binary name typically ends with ".test".
40
48
// In that case, use the trusted client to avoid SSRF blocking for localhost tests.
41
41
-
if len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test") {
49
49
+
if defaultClientOptions.OverrideInTest && len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test") {
42
50
Client = TrustedClient
43
51
}
44
52
}
+173
pkg/aqhttp/aqhttp_test.go
···
1
1
+
package aqhttp
2
2
+
3
3
+
import (
4
4
+
"context"
5
5
+
"net/http"
6
6
+
"net/http/httptest"
7
7
+
"testing"
8
8
+
"time"
9
9
+
10
10
+
"github.com/stretchr/testify/require"
11
11
+
)
12
12
+
13
13
+
func TestClientRedirects(t *testing.T) {
14
14
+
// Temporarily disable the test override to test actual Client behavior
15
15
+
originalOverride := defaultClientOptions.OverrideInTest
16
16
+
defaultClientOptions.OverrideInTest = false
17
17
+
18
18
+
// Reinitialize the Client with SSRF protection
19
19
+
Client = http.Client{
20
20
+
Transport: NewUntrustedTransport(),
21
21
+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
22
22
+
return http.ErrUseLastResponse
23
23
+
},
24
24
+
Timeout: 30 * time.Second,
25
25
+
}
26
26
+
27
27
+
defer func() {
28
28
+
defaultClientOptions.OverrideInTest = originalOverride
29
29
+
// Restore to TrustedClient for other tests
30
30
+
Client = TrustedClient
31
31
+
}()
32
32
+
33
33
+
redirectCount := 0
34
34
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
35
35
+
if r.URL.Path == "/start" {
36
36
+
redirectCount++
37
37
+
http.Redirect(w, r, "/end", http.StatusTemporaryRedirect)
38
38
+
return
39
39
+
}
40
40
+
w.WriteHeader(http.StatusOK)
41
41
+
}))
42
42
+
defer server.Close()
43
43
+
44
44
+
req, err := http.NewRequest("GET", server.URL+"/start", nil)
45
45
+
require.NoError(t, err)
46
46
+
47
47
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
48
48
+
defer cancel()
49
49
+
50
50
+
// The Client should block localhost/127.0.0.1 due to SSRF protection
51
51
+
_, err = Do(ctx, req)
52
52
+
require.Error(t, err, "Client should block requests to localhost")
53
53
+
require.Contains(t, err.Error(), "private/invalid", "Error should mention private/invalid IPs")
54
54
+
}
55
55
+
56
56
+
func TestClientCanAccessExternal(t *testing.T) {
57
57
+
// Temporarily disable the test override to test actual Client behavior
58
58
+
originalOverride := defaultClientOptions.OverrideInTest
59
59
+
defaultClientOptions.OverrideInTest = false
60
60
+
61
61
+
// Reinitialize the Client with SSRF protection
62
62
+
Client = http.Client{
63
63
+
Transport: NewUntrustedTransport(),
64
64
+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
65
65
+
return http.ErrUseLastResponse
66
66
+
},
67
67
+
Timeout: 30 * time.Second,
68
68
+
}
69
69
+
70
70
+
defer func() {
71
71
+
defaultClientOptions.OverrideInTest = originalOverride
72
72
+
// Restore to TrustedClient for other tests
73
73
+
Client = TrustedClient
74
74
+
}()
75
75
+
76
76
+
req, err := http.NewRequest("GET", "https://plc.directory", nil)
77
77
+
require.NoError(t, err)
78
78
+
79
79
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
80
80
+
defer cancel()
81
81
+
82
82
+
_, err = Do(ctx, req)
83
83
+
require.NoError(t, err, "Client shouldn't block requests to plc.directory")
84
84
+
//require.Contains(t, err.Error(), "private/invalid", "Error should mention private/invalid IPs")
85
85
+
}
86
86
+
87
87
+
func TestTrustedClientFollowsRedirects(t *testing.T) {
88
88
+
redirectCount := 0
89
89
+
finalCount := 0
90
90
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91
91
+
if r.URL.Path == "/start" {
92
92
+
redirectCount++
93
93
+
http.Redirect(w, r, "/end", http.StatusTemporaryRedirect)
94
94
+
return
95
95
+
}
96
96
+
finalCount++
97
97
+
w.WriteHeader(http.StatusOK)
98
98
+
}))
99
99
+
defer server.Close()
100
100
+
101
101
+
req, err := http.NewRequest("GET", server.URL+"/start", nil)
102
102
+
require.NoError(t, err)
103
103
+
104
104
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
105
105
+
defer cancel()
106
106
+
107
107
+
resp, err := DoTrusted(ctx, req)
108
108
+
require.NoError(t, err)
109
109
+
require.NotNil(t, resp)
110
110
+
defer resp.Body.Close()
111
111
+
112
112
+
require.Equal(t, http.StatusOK, resp.StatusCode, "TrustedClient should follow redirects")
113
113
+
require.Equal(t, 1, redirectCount, "Redirect handler should have been called once")
114
114
+
require.Equal(t, 1, finalCount, "Final handler should have been called once")
115
115
+
}
116
116
+
117
117
+
func TestClientTimeout(t *testing.T) {
118
118
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
119
119
+
time.Sleep(2 * time.Second)
120
120
+
w.WriteHeader(http.StatusOK)
121
121
+
}))
122
122
+
defer server.Close()
123
123
+
124
124
+
req, err := http.NewRequest("GET", server.URL, nil)
125
125
+
require.NoError(t, err)
126
126
+
127
127
+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
128
128
+
defer cancel()
129
129
+
130
130
+
_, err = Do(ctx, req)
131
131
+
require.Error(t, err, "Request should timeout")
132
132
+
}
133
133
+
134
134
+
func TestTrustedClientTimeout(t *testing.T) {
135
135
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
136
136
+
time.Sleep(2 * time.Second)
137
137
+
w.WriteHeader(http.StatusOK)
138
138
+
}))
139
139
+
defer server.Close()
140
140
+
141
141
+
req, err := http.NewRequest("GET", server.URL, nil)
142
142
+
require.NoError(t, err)
143
143
+
144
144
+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
145
145
+
defer cancel()
146
146
+
147
147
+
_, err = DoTrusted(ctx, req)
148
148
+
require.Error(t, err, "Request should timeout")
149
149
+
}
150
150
+
151
151
+
func TestSuccessfulRequest(t *testing.T) {
152
152
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
153
153
+
w.WriteHeader(http.StatusOK)
154
154
+
_, err := w.Write([]byte("success"))
155
155
+
if err != nil {
156
156
+
http.Error(w, "failed to write response", http.StatusInternalServerError)
157
157
+
}
158
158
+
}))
159
159
+
defer server.Close()
160
160
+
161
161
+
req, err := http.NewRequest("GET", server.URL, nil)
162
162
+
require.NoError(t, err)
163
163
+
164
164
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
165
165
+
defer cancel()
166
166
+
167
167
+
resp, err := Do(ctx, req)
168
168
+
require.NoError(t, err)
169
169
+
require.NotNil(t, resp)
170
170
+
defer resp.Body.Close()
171
171
+
172
172
+
require.Equal(t, http.StatusOK, resp.StatusCode)
173
173
+
}
-1
pkg/aqhttp/resolv.go
···
43
43
ipv6Bogons := []string{
44
44
"::/128", // Unspecified
45
45
"::1/128", // Loopback
46
46
-
"::ffff:0:0/96", // IPv4-mapped addresses
47
46
"100::/64", // Discard prefix
48
47
"2001::/32", // TEREDO
49
48
"2001:10::/28", // Deprecated (ORCHID)