// HTTP reverse proxy for Tailscale — tests.
// Source: "oidc" branch (433 lines, 13 kB).

1package main 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "log/slog" 9 "net/http" 10 "net/http/httptest" 11 "strings" 12 "testing" 13 14 "github.com/google/go-cmp/cmp" 15 "github.com/prometheus/client_golang/prometheus" 16 "github.com/prometheus/client_golang/prometheus/testutil" 17 "tailscale.com/client/tailscale/apitype" 18 "tailscale.com/tailcfg" 19) 20 21type fakeLocalClient struct { 22 whois func(context.Context, string) (*apitype.WhoIsResponse, error) 23} 24 25func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { 26 if c.whois == nil { 27 return nil, errors.New("not implemented") 28 } 29 return c.whois(ctx, remoteAddr) 30} 31 32func TestTSHandlers(t *testing.T) { 33 t.Parallel() 34 35 logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) 36 37 for _, tc := range []struct { 38 name string 39 whois func(context.Context, string) (*apitype.WhoIsResponse, error) 40 handler func(*slog.Logger, tailscaleLocalClient, http.Handler) http.Handler 41 wantNext bool 42 wantStatus int 43 wantHeaders map[string]string 44 wantBody string 45 }{ 46 { 47 name: "tailnet: tailscale whois error", 48 handler: tailnet, 49 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 50 return nil, errors.New("whois error") 51 }, 52 wantStatus: http.StatusInternalServerError, 53 wantBody: "Internal Server Error", 54 }, 55 { 56 name: "tailnet: tailscale whois no profile", 57 handler: tailnet, 58 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 59 return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 60 }, 61 wantStatus: http.StatusInternalServerError, 62 wantBody: "Internal Server Error", 63 }, 64 { 65 name: "tailnet: tailscale whois no node", 66 handler: tailnet, 67 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 68 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 69 }, 70 
wantStatus: http.StatusInternalServerError, 71 wantBody: "Internal Server Error", 72 }, 73 { 74 name: "tailnet: tailscale whois ok (tagged node)", 75 handler: tailnet, 76 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 77 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 78 }, 79 wantNext: true, 80 wantStatus: http.StatusOK, 81 wantBody: "OK", 82 wantHeaders: map[string]string{ 83 "X-Webauth-User": "", 84 "X-Webauth-Name": "", 85 }, 86 }, 87 { 88 name: "tailnet: tailscale whois ok (user)", 89 handler: tailnet, 90 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 91 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 92 }, 93 wantNext: true, 94 wantStatus: http.StatusOK, 95 wantBody: "OK", 96 wantHeaders: map[string]string{ 97 "X-Webauth-User": "login", 98 "X-Webauth-Name": "name", 99 }, 100 }, 101 { 102 name: "insecure: tailscale whois error", 103 handler: insecureFunnel, 104 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 105 return nil, errors.New("whois error") 106 }, 107 wantStatus: http.StatusInternalServerError, 108 wantBody: "Internal Server Error", 109 }, 110 { 111 name: "insecure: tailscale whois no profile", 112 handler: insecureFunnel, 113 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 114 return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 115 }, 116 wantStatus: http.StatusInternalServerError, 117 wantBody: "Internal Server Error", 118 }, 119 { 120 name: "insure: tailscale whois no node", 121 handler: insecureFunnel, 122 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 123 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 124 }, 125 wantStatus: 
http.StatusInternalServerError, 126 wantBody: "Internal Server Error", 127 }, 128 { 129 name: "insecure: tagged node", 130 handler: insecureFunnel, 131 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 132 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 133 }, 134 wantNext: true, 135 wantStatus: http.StatusOK, 136 wantBody: "OK", 137 wantHeaders: map[string]string{ 138 "X-Webauth-User": "", 139 "X-Webauth-Name": "", 140 }, 141 }, 142 { 143 name: "insecure: user node", 144 handler: insecureFunnel, 145 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 146 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 147 }, 148 wantStatus: http.StatusUnauthorized, 149 wantBody: "Unauthorized", 150 }, 151 { 152 name: "oidc: tailscale whois error", 153 handler: oidcFunnel, 154 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 155 return nil, errors.New("whois error") 156 }, 157 wantStatus: http.StatusInternalServerError, 158 wantBody: "Internal Server Error", 159 }, 160 { 161 name: "oidc: tailscale whois no profile", 162 handler: oidcFunnel, 163 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 164 return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 165 }, 166 wantStatus: http.StatusInternalServerError, 167 wantBody: "Internal Server Error", 168 }, 169 { 170 name: "oidc: tailscale whois no node", 171 handler: oidcFunnel, 172 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 173 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 174 }, 175 wantStatus: http.StatusInternalServerError, 176 wantBody: "Internal Server Error", 177 }, 178 { 179 name: "oidc: user node", 180 handler: oidcFunnel, 181 whois: 
func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 182 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 183 }, 184 wantStatus: http.StatusUnauthorized, 185 wantBody: "Unauthorized", 186 }, 187 { 188 name: "oidc: tagged node", 189 handler: oidcFunnel, 190 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 191 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"tag:ingress"}}}, nil 192 }, 193 wantStatus: http.StatusUnauthorized, 194 wantBody: "Unauthorized", 195 }, 196 } { 197 t.Run(tc.name, func(t *testing.T) { 198 t.Parallel() 199 200 var nextReq *http.Request 201 h := tc.handler(logger, &fakeLocalClient{whois: tc.whois}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 202 nextReq = r 203 fmt.Fprintf(w, "OK") 204 })) 205 w := httptest.NewRecorder() 206 h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)) 207 resp := w.Result() 208 209 if want, got := tc.wantStatus, resp.StatusCode; want != got { 210 t.Errorf("want status %d, got: %d", want, got) 211 } 212 213 body, err := io.ReadAll(resp.Body) 214 if err != nil { 215 t.Fatal(err) 216 } 217 if !strings.Contains(string(body), tc.wantBody) { 218 t.Errorf("want body %q, got: %q", tc.wantBody, string(body)) 219 } 220 if tc.wantNext && nextReq == nil { 221 t.Fatalf("next handler not called") 222 } 223 for k, want := range tc.wantHeaders { 224 if got := nextReq.Header.Get(k); got != want { 225 t.Errorf("want header %s = %s, got: %s", k, want, got) 226 } 227 } 228 }) 229 } 230} 231 232func TestRedirectHandler(t *testing.T) { 233 t.Parallel() 234 235 for _, tc := range []struct { 236 name string 237 forceSSL bool 238 fqdn string 239 request *http.Request 240 wantNext bool 241 wantStatus int 242 wantLocation string 243 }{ 244 { 245 name: "forceSSL: redirect", 
246 forceSSL: true, 247 fqdn: "http://example.com", 248 request: httptest.NewRequest("", "/path", nil), 249 wantStatus: http.StatusPermanentRedirect, 250 wantLocation: "https://example.com/path", 251 }, 252 { 253 name: "forceSSL: ok", 254 forceSSL: true, 255 fqdn: "example.com", 256 request: httptest.NewRequest("", "https://example.com/path", nil), 257 wantNext: true, 258 wantStatus: http.StatusOK, 259 }, 260 { 261 name: "fqdn: redirect", 262 fqdn: "example.ts.net", 263 request: httptest.NewRequest("", "https://example/path", nil), 264 wantStatus: http.StatusPermanentRedirect, 265 wantLocation: "https://example.ts.net/path", 266 }, 267 { 268 name: "fqdn: ok", 269 fqdn: "example.ts.net", 270 request: httptest.NewRequest("", "https://example.ts.net/path", nil), 271 wantNext: true, 272 wantStatus: http.StatusOK, 273 }, 274 { 275 name: "fqdn: ok (not tls)", 276 fqdn: "example.ts.net", 277 request: httptest.NewRequest("", "/path", nil), 278 wantNext: true, 279 wantStatus: http.StatusOK, 280 }, 281 } { 282 t.Run(tc.name, func(t *testing.T) { 283 t.Parallel() 284 285 var nextReq *http.Request 286 h := redirect(tc.fqdn, tc.forceSSL, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 287 nextReq = r 288 fmt.Fprintf(w, "OK") 289 })) 290 w := httptest.NewRecorder() 291 h.ServeHTTP(w, tc.request) 292 resp := w.Result() 293 294 if want, got := tc.wantStatus, resp.StatusCode; want != got { 295 t.Errorf("want status %d, got: %d", want, got) 296 } 297 298 if tc.wantNext && nextReq == nil { 299 t.Fatalf("next handler not called") 300 } 301 if !tc.wantNext && nextReq != nil { 302 t.Fatalf("next handler was called") 303 } 304 if nextReq != nil { 305 if want, got := tc.wantLocation, nextReq.Header.Get("Location"); got != want { 306 t.Errorf("want Location header %s, got: %s", want, got) 307 } 308 } 309 }) 310 } 311} 312 313func TestBasicAuthHandler(t *testing.T) { 314 t.Parallel() 315 316 logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) 317 318 
for _, tc := range []struct { 319 name string 320 user string 321 password string 322 request func(*http.Request) 323 wantNext bool 324 wantStatus int 325 }{ 326 { 327 name: "no basic auth provided", 328 user: "admin", 329 password: "secret", 330 request: func(_ *http.Request) {}, 331 wantStatus: http.StatusUnauthorized, 332 }, 333 { 334 name: "wrong user", 335 user: "admin", 336 password: "secret", 337 request: func(r *http.Request) { r.SetBasicAuth("bad", "secret") }, 338 wantStatus: http.StatusUnauthorized, 339 }, 340 { 341 name: "wrong password", 342 user: "admin", 343 password: "secret", 344 request: func(r *http.Request) { r.SetBasicAuth("admin", "bad") }, 345 wantStatus: http.StatusUnauthorized, 346 }, 347 { 348 name: "ok", 349 user: "admin", 350 password: "secret", 351 request: func(r *http.Request) { r.SetBasicAuth("admin", "secret") }, 352 wantNext: true, 353 wantStatus: http.StatusOK, 354 }, 355 } { 356 t.Run(tc.name, func(t *testing.T) { 357 t.Parallel() 358 359 var nextReq *http.Request 360 h := basicAuth(logger, tc.user, tc.password, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 361 nextReq = r 362 fmt.Fprintf(w, "OK") 363 })) 364 w := httptest.NewRecorder() 365 req := httptest.NewRequest("", "/", nil) 366 tc.request(req) 367 h.ServeHTTP(w, req) 368 resp := w.Result() 369 370 if want, got := tc.wantStatus, resp.StatusCode; want != got { 371 t.Errorf("want status %d, got: %d", want, got) 372 } 373 374 if tc.wantNext && nextReq == nil { 375 t.Fatalf("next handler not called") 376 } 377 if !tc.wantNext && nextReq != nil { 378 t.Fatalf("next handler should not have been called") 379 } 380 }) 381 } 382} 383 384func TestServeDiscovery(t *testing.T) { 385 t.Parallel() 386 387 ts := httptest.NewServer(serveDiscovery("self", []target{ 388 {magicDNS: "b", prometheus: true}, 389 {magicDNS: "x"}, 390 {}, 391 {magicDNS: "a", prometheus: true}, 392 })) 393 defer ts.Close() 394 395 resp, err := http.Get(ts.URL) 396 if err != nil { 397 t.Fatal(err) 
398 } 399 defer resp.Body.Close() 400 if want, got := http.StatusOK, resp.StatusCode; want != got { 401 t.Errorf("want status %d, got: %d", want, got) 402 } 403 b, err := io.ReadAll(resp.Body) 404 if err != nil { 405 t.Fatal(err) 406 } 407 if diff := cmp.Diff(`[{"targets":["a","b","self"]}]`, string(b)); diff != "" { 408 t.Errorf("body mismatch (-want +got):\n%s", diff) 409 } 410} 411 412func TestMetrics(t *testing.T) { 413 t.Parallel() 414 415 c, err := testutil.GatherAndCount(prometheus.DefaultGatherer) 416 if err != nil { 417 t.Fatalf("GatherAndCount: %v", err) 418 } 419 if c == 0 { 420 t.Fatalf("no metrics collected") 421 } 422 423 lint, err := testutil.GatherAndLint(prometheus.DefaultGatherer) 424 if err != nil { 425 t.Fatalf("CollectAndLint: %v", err) 426 } 427 if len(lint) > 0 { 428 t.Error("lint problems detected") 429 } 430 for _, prob := range lint { 431 t.Errorf("lint: %s: %s", prob.Metric, prob.Text) 432 } 433}