HTTP reverse proxy for Tailscale
1package main
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13
14 "github.com/google/go-cmp/cmp"
15 "github.com/prometheus/client_golang/prometheus"
16 "github.com/prometheus/client_golang/prometheus/testutil"
17 "tailscale.com/client/tailscale/apitype"
18 "tailscale.com/tailcfg"
19)
20
21type fakeLocalClient struct {
22 whois func(context.Context, string) (*apitype.WhoIsResponse, error)
23}
24
25func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
26 if c.whois == nil {
27 return nil, errors.New("not implemented")
28 }
29 return c.whois(ctx, remoteAddr)
30}
31
32func TestTSHandlers(t *testing.T) {
33 t.Parallel()
34
35 logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
36
37 for _, tc := range []struct {
38 name string
39 whois func(context.Context, string) (*apitype.WhoIsResponse, error)
40 handler func(*slog.Logger, tailscaleLocalClient, http.Handler) http.Handler
41 wantNext bool
42 wantStatus int
43 wantHeaders map[string]string
44 wantBody string
45 }{
46 {
47 name: "tailnet: tailscale whois error",
48 handler: tailnet,
49 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
50 return nil, errors.New("whois error")
51 },
52 wantStatus: http.StatusInternalServerError,
53 wantBody: "Internal Server Error",
54 },
55 {
56 name: "tailnet: tailscale whois no profile",
57 handler: tailnet,
58 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
59 return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
60 },
61 wantStatus: http.StatusInternalServerError,
62 wantBody: "Internal Server Error",
63 },
64 {
65 name: "tailnet: tailscale whois no node",
66 handler: tailnet,
67 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
68 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
69 },
70 wantStatus: http.StatusInternalServerError,
71 wantBody: "Internal Server Error",
72 },
73 {
74 name: "tailnet: tailscale whois ok (tagged node)",
75 handler: tailnet,
76 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
77 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
78 },
79 wantNext: true,
80 wantStatus: http.StatusOK,
81 wantBody: "OK",
82 wantHeaders: map[string]string{
83 "X-Webauth-User": "",
84 "X-Webauth-Name": "",
85 },
86 },
87 {
88 name: "tailnet: tailscale whois ok (user)",
89 handler: tailnet,
90 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
91 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
92 },
93 wantNext: true,
94 wantStatus: http.StatusOK,
95 wantBody: "OK",
96 wantHeaders: map[string]string{
97 "X-Webauth-User": "login",
98 "X-Webauth-Name": "name",
99 },
100 },
101 {
102 name: "insecure: tailscale whois error",
103 handler: insecureFunnel,
104 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
105 return nil, errors.New("whois error")
106 },
107 wantStatus: http.StatusInternalServerError,
108 wantBody: "Internal Server Error",
109 },
110 {
111 name: "insecure: tailscale whois no profile",
112 handler: insecureFunnel,
113 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
114 return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
115 },
116 wantStatus: http.StatusInternalServerError,
117 wantBody: "Internal Server Error",
118 },
119 {
120 name: "insure: tailscale whois no node",
121 handler: insecureFunnel,
122 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
123 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
124 },
125 wantStatus: http.StatusInternalServerError,
126 wantBody: "Internal Server Error",
127 },
128 {
129 name: "insecure: tagged node",
130 handler: insecureFunnel,
131 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
132 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
133 },
134 wantNext: true,
135 wantStatus: http.StatusOK,
136 wantBody: "OK",
137 wantHeaders: map[string]string{
138 "X-Webauth-User": "",
139 "X-Webauth-Name": "",
140 },
141 },
142 {
143 name: "insecure: user node",
144 handler: insecureFunnel,
145 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
146 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
147 },
148 wantStatus: http.StatusUnauthorized,
149 wantBody: "Unauthorized",
150 },
151 {
152 name: "oidc: tailscale whois error",
153 handler: oidcFunnel,
154 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
155 return nil, errors.New("whois error")
156 },
157 wantStatus: http.StatusInternalServerError,
158 wantBody: "Internal Server Error",
159 },
160 {
161 name: "oidc: tailscale whois no profile",
162 handler: oidcFunnel,
163 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
164 return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
165 },
166 wantStatus: http.StatusInternalServerError,
167 wantBody: "Internal Server Error",
168 },
169 {
170 name: "oidc: tailscale whois no node",
171 handler: oidcFunnel,
172 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
173 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
174 },
175 wantStatus: http.StatusInternalServerError,
176 wantBody: "Internal Server Error",
177 },
178 {
179 name: "oidc: user node",
180 handler: oidcFunnel,
181 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
182 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
183 },
184 wantStatus: http.StatusUnauthorized,
185 wantBody: "Unauthorized",
186 },
187 {
188 name: "oidc: tagged node",
189 handler: oidcFunnel,
190 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
191 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"tag:ingress"}}}, nil
192 },
193 wantStatus: http.StatusUnauthorized,
194 wantBody: "Unauthorized",
195 },
196 } {
197 t.Run(tc.name, func(t *testing.T) {
198 t.Parallel()
199
200 var nextReq *http.Request
201 h := tc.handler(logger, &fakeLocalClient{whois: tc.whois}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
202 nextReq = r
203 fmt.Fprintf(w, "OK")
204 }))
205 w := httptest.NewRecorder()
206 h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "http://example.com/path", nil))
207 resp := w.Result()
208
209 if want, got := tc.wantStatus, resp.StatusCode; want != got {
210 t.Errorf("want status %d, got: %d", want, got)
211 }
212
213 body, err := io.ReadAll(resp.Body)
214 if err != nil {
215 t.Fatal(err)
216 }
217 if !strings.Contains(string(body), tc.wantBody) {
218 t.Errorf("want body %q, got: %q", tc.wantBody, string(body))
219 }
220 if tc.wantNext && nextReq == nil {
221 t.Fatalf("next handler not called")
222 }
223 for k, want := range tc.wantHeaders {
224 if got := nextReq.Header.Get(k); got != want {
225 t.Errorf("want header %s = %s, got: %s", k, want, got)
226 }
227 }
228 })
229 }
230}
231
232func TestRedirectHandler(t *testing.T) {
233 t.Parallel()
234
235 for _, tc := range []struct {
236 name string
237 forceSSL bool
238 fqdn string
239 request *http.Request
240 wantNext bool
241 wantStatus int
242 wantLocation string
243 }{
244 {
245 name: "forceSSL: redirect",
246 forceSSL: true,
247 fqdn: "http://example.com",
248 request: httptest.NewRequest("", "/path", nil),
249 wantStatus: http.StatusPermanentRedirect,
250 wantLocation: "https://example.com/path",
251 },
252 {
253 name: "forceSSL: ok",
254 forceSSL: true,
255 fqdn: "example.com",
256 request: httptest.NewRequest("", "https://example.com/path", nil),
257 wantNext: true,
258 wantStatus: http.StatusOK,
259 },
260 {
261 name: "fqdn: redirect",
262 fqdn: "example.ts.net",
263 request: httptest.NewRequest("", "https://example/path", nil),
264 wantStatus: http.StatusPermanentRedirect,
265 wantLocation: "https://example.ts.net/path",
266 },
267 {
268 name: "fqdn: ok",
269 fqdn: "example.ts.net",
270 request: httptest.NewRequest("", "https://example.ts.net/path", nil),
271 wantNext: true,
272 wantStatus: http.StatusOK,
273 },
274 {
275 name: "fqdn: ok (not tls)",
276 fqdn: "example.ts.net",
277 request: httptest.NewRequest("", "/path", nil),
278 wantNext: true,
279 wantStatus: http.StatusOK,
280 },
281 } {
282 t.Run(tc.name, func(t *testing.T) {
283 t.Parallel()
284
285 var nextReq *http.Request
286 h := redirect(tc.fqdn, tc.forceSSL, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
287 nextReq = r
288 fmt.Fprintf(w, "OK")
289 }))
290 w := httptest.NewRecorder()
291 h.ServeHTTP(w, tc.request)
292 resp := w.Result()
293
294 if want, got := tc.wantStatus, resp.StatusCode; want != got {
295 t.Errorf("want status %d, got: %d", want, got)
296 }
297
298 if tc.wantNext && nextReq == nil {
299 t.Fatalf("next handler not called")
300 }
301 if !tc.wantNext && nextReq != nil {
302 t.Fatalf("next handler was called")
303 }
304 if nextReq != nil {
305 if want, got := tc.wantLocation, nextReq.Header.Get("Location"); got != want {
306 t.Errorf("want Location header %s, got: %s", want, got)
307 }
308 }
309 })
310 }
311}
312
313func TestBasicAuthHandler(t *testing.T) {
314 t.Parallel()
315
316 logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
317
318 for _, tc := range []struct {
319 name string
320 user string
321 password string
322 request func(*http.Request)
323 wantNext bool
324 wantStatus int
325 }{
326 {
327 name: "no basic auth provided",
328 user: "admin",
329 password: "secret",
330 request: func(_ *http.Request) {},
331 wantStatus: http.StatusUnauthorized,
332 },
333 {
334 name: "wrong user",
335 user: "admin",
336 password: "secret",
337 request: func(r *http.Request) { r.SetBasicAuth("bad", "secret") },
338 wantStatus: http.StatusUnauthorized,
339 },
340 {
341 name: "wrong password",
342 user: "admin",
343 password: "secret",
344 request: func(r *http.Request) { r.SetBasicAuth("admin", "bad") },
345 wantStatus: http.StatusUnauthorized,
346 },
347 {
348 name: "ok",
349 user: "admin",
350 password: "secret",
351 request: func(r *http.Request) { r.SetBasicAuth("admin", "secret") },
352 wantNext: true,
353 wantStatus: http.StatusOK,
354 },
355 } {
356 t.Run(tc.name, func(t *testing.T) {
357 t.Parallel()
358
359 var nextReq *http.Request
360 h := basicAuth(logger, tc.user, tc.password, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
361 nextReq = r
362 fmt.Fprintf(w, "OK")
363 }))
364 w := httptest.NewRecorder()
365 req := httptest.NewRequest("", "/", nil)
366 tc.request(req)
367 h.ServeHTTP(w, req)
368 resp := w.Result()
369
370 if want, got := tc.wantStatus, resp.StatusCode; want != got {
371 t.Errorf("want status %d, got: %d", want, got)
372 }
373
374 if tc.wantNext && nextReq == nil {
375 t.Fatalf("next handler not called")
376 }
377 if !tc.wantNext && nextReq != nil {
378 t.Fatalf("next handler should not have been called")
379 }
380 })
381 }
382}
383
384func TestServeDiscovery(t *testing.T) {
385 t.Parallel()
386
387 ts := httptest.NewServer(serveDiscovery("self", []target{
388 {magicDNS: "b", prometheus: true},
389 {magicDNS: "x"},
390 {},
391 {magicDNS: "a", prometheus: true},
392 }))
393 defer ts.Close()
394
395 resp, err := http.Get(ts.URL)
396 if err != nil {
397 t.Fatal(err)
398 }
399 defer resp.Body.Close()
400 if want, got := http.StatusOK, resp.StatusCode; want != got {
401 t.Errorf("want status %d, got: %d", want, got)
402 }
403 b, err := io.ReadAll(resp.Body)
404 if err != nil {
405 t.Fatal(err)
406 }
407 if diff := cmp.Diff(`[{"targets":["a","b","self"]}]`, string(b)); diff != "" {
408 t.Errorf("body mismatch (-want +got):\n%s", diff)
409 }
410}
411
412func TestMetrics(t *testing.T) {
413 t.Parallel()
414
415 c, err := testutil.GatherAndCount(prometheus.DefaultGatherer)
416 if err != nil {
417 t.Fatalf("GatherAndCount: %v", err)
418 }
419 if c == 0 {
420 t.Fatalf("no metrics collected")
421 }
422
423 lint, err := testutil.GatherAndLint(prometheus.DefaultGatherer)
424 if err != nil {
425 t.Fatalf("CollectAndLint: %v", err)
426 }
427 if len(lint) > 0 {
428 t.Error("lint problems detected")
429 }
430 for _, prob := range lint {
431 t.Errorf("lint: %s: %s", prob.Metric, prob.Text)
432 }
433}