Live video on the AT Protocol

fix: untrusted resolution blocking all v4 addresses

+182 -2
+9 -1
pkg/aqhttp/aqhttp.go
··· 20 20 // where the validation overhead is problematic. 21 21 var TrustedClient http.Client 22 22 23 + type ClientOptions struct { 24 + OverrideInTest bool 25 + } 26 + 27 + var defaultClientOptions = ClientOptions{ 28 + OverrideInTest: true, 29 + } 30 + 23 31 func init() { 24 32 // Initialize the trusted client first. 25 33 TrustedClient = http.Client{ ··· 38 46 39 47 // When running under `go test` the test binary name typically ends with ".test". 40 48 // In that case, use the trusted client to avoid SSRF blocking for localhost tests. 41 - if len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test") { 49 + if defaultClientOptions.OverrideInTest && len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test") { 42 50 Client = TrustedClient 43 51 } 44 52 }
+173
pkg/aqhttp/aqhttp_test.go
··· 1 + package aqhttp 2 + 3 + import ( 4 + "context" 5 + "net/http" 6 + "net/http/httptest" 7 + "testing" 8 + "time" 9 + 10 + "github.com/stretchr/testify/require" 11 + ) 12 + 13 + func TestClientRedirects(t *testing.T) { 14 + // Temporarily disable the test override to test actual Client behavior 15 + originalOverride := defaultClientOptions.OverrideInTest 16 + defaultClientOptions.OverrideInTest = false 17 + 18 + // Reinitialize the Client with SSRF protection 19 + Client = http.Client{ 20 + Transport: NewUntrustedTransport(), 21 + CheckRedirect: func(req *http.Request, via []*http.Request) error { 22 + return http.ErrUseLastResponse 23 + }, 24 + Timeout: 30 * time.Second, 25 + } 26 + 27 + defer func() { 28 + defaultClientOptions.OverrideInTest = originalOverride 29 + // Restore to TrustedClient for other tests 30 + Client = TrustedClient 31 + }() 32 + 33 + redirectCount := 0 34 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 + if r.URL.Path == "/start" { 36 + redirectCount++ 37 + http.Redirect(w, r, "/end", http.StatusTemporaryRedirect) 38 + return 39 + } 40 + w.WriteHeader(http.StatusOK) 41 + })) 42 + defer server.Close() 43 + 44 + req, err := http.NewRequest("GET", server.URL+"/start", nil) 45 + require.NoError(t, err) 46 + 47 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 48 + defer cancel() 49 + 50 + // The Client should block localhost/127.0.0.1 due to SSRF protection 51 + _, err = Do(ctx, req) 52 + require.Error(t, err, "Client should block requests to localhost") 53 + require.Contains(t, err.Error(), "private/invalid", "Error should mention private/invalid IPs") 54 + } 55 + 56 + func TestClientCanAccessExternal(t *testing.T) { 57 + // Temporarily disable the test override to test actual Client behavior 58 + originalOverride := defaultClientOptions.OverrideInTest 59 + defaultClientOptions.OverrideInTest = false 60 + 61 + // Reinitialize the Client with SSRF protection 62 + Client = http.Client{ 63 + Transport: NewUntrustedTransport(), 64 + CheckRedirect: func(req *http.Request, via []*http.Request) error { 65 + return http.ErrUseLastResponse 66 + }, 67 + Timeout: 30 * time.Second, 68 + } 69 + 70 + defer func() { 71 + defaultClientOptions.OverrideInTest = originalOverride 72 + // Restore to TrustedClient for other tests 73 + Client = TrustedClient 74 + }() 75 + 76 + req, err := http.NewRequest("GET", "https://plc.directory", nil) 77 + require.NoError(t, err) 78 + 79 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 80 + defer cancel() 81 + 82 + _, err = Do(ctx, req) 83 + require.NoError(t, err, "Client shouldn't block requests to plc.directory") 84 + //require.Contains(t, err.Error(), "private/invalid", "Error should mention private/invalid IPs") 85 + } 86 + 87 + func TestTrustedClientFollowsRedirects(t *testing.T) { 88 + redirectCount := 0 89 + finalCount := 0 90 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 91 + if r.URL.Path == "/start" { 92 + redirectCount++ 93 + http.Redirect(w, r, "/end", http.StatusTemporaryRedirect) 94 + return 95 + } 96 + finalCount++ 97 + w.WriteHeader(http.StatusOK) 98 + })) 99 + defer server.Close() 100 + 101 + req, err := http.NewRequest("GET", server.URL+"/start", nil) 102 + require.NoError(t, err) 103 + 104 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 105 + defer cancel() 106 + 107 + resp, err := DoTrusted(ctx, req) 108 + require.NoError(t, err) 109 + require.NotNil(t, resp) 110 + defer resp.Body.Close() 111 + 112 + require.Equal(t, http.StatusOK, resp.StatusCode, "TrustedClient should follow redirects") 113 + require.Equal(t, 1, redirectCount, "Redirect handler should have been called once") 114 + require.Equal(t, 1, finalCount, "Final handler should have been called once") 115 + } 116 + 117 + func TestClientTimeout(t *testing.T) { 118 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 119 + time.Sleep(2 * time.Second) 120 + w.WriteHeader(http.StatusOK) 121 + })) 122 + defer server.Close() 123 + 124 + req, err := http.NewRequest("GET", server.URL, nil) 125 + require.NoError(t, err) 126 + 127 + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 128 + defer cancel() 129 + 130 + _, err = Do(ctx, req) 131 + require.Error(t, err, "Request should timeout") 132 + } 133 + 134 + func TestTrustedClientTimeout(t *testing.T) { 135 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 136 + time.Sleep(2 * time.Second) 137 + w.WriteHeader(http.StatusOK) 138 + })) 139 + defer server.Close() 140 + 141 + req, err := http.NewRequest("GET", server.URL, nil) 142 + require.NoError(t, err) 143 + 144 + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 145 + defer cancel() 146 + 147 + _, err = DoTrusted(ctx, req) 148 + require.Error(t, err, "Request should timeout") 149 + } 150 + 151 + func TestSuccessfulRequest(t *testing.T) { 152 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 153 + w.WriteHeader(http.StatusOK) 154 + _, err := w.Write([]byte("success")) 155 + if err != nil { 156 + http.Error(w, "failed to write response", http.StatusInternalServerError) 157 + } 158 + })) 159 + defer server.Close() 160 + 161 + req, err := http.NewRequest("GET", server.URL, nil) 162 + require.NoError(t, err) 163 + 164 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 165 + defer cancel() 166 + 167 + resp, err := Do(ctx, req) 168 + require.NoError(t, err) 169 + require.NotNil(t, resp) 170 + defer resp.Body.Close() 171 + 172 + require.Equal(t, http.StatusOK, resp.StatusCode) 173 + }
-1
pkg/aqhttp/resolv.go
··· 43 43 ipv6Bogons := []string{ 44 44 "::/128", // Unspecified 45 45 "::1/128", // Loopback 46 - "::ffff:0:0/96", // IPv4-mapped addresses 47 46 "100::/64", // Discard prefix 48 47 "2001::/32", // TEREDO 49 48 "2001:10::/28", // Deprecated (ORCHID)