HTTP reverse proxy for Tailscale
at oidc 599 lines 17 kB view raw
1package main 2 3import ( 4 "context" 5 "crypto/subtle" 6 "crypto/tls" 7 "encoding/json" 8 "errors" 9 "flag" 10 "fmt" 11 "log/slog" 12 "net" 13 "net/http" 14 "net/http/httputil" 15 "net/netip" 16 "net/url" 17 "os" 18 "path/filepath" 19 "sort" 20 "strconv" 21 "strings" 22 "syscall" 23 24 "github.com/lstoll/oidc" 25 "github.com/lstoll/oidc/middleware" 26 "github.com/oklog/run" 27 "github.com/prometheus/client_golang/prometheus" 28 versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version" 29 "github.com/prometheus/client_golang/prometheus/promauto" 30 "github.com/prometheus/client_golang/prometheus/promhttp" 31 "github.com/prometheus/common/version" 32 "github.com/tailscale/hujson" 33 "tailscale.com/client/local" 34 "tailscale.com/client/tailscale/apitype" 35 "tailscale.com/ipn" 36 "tailscale.com/tsnet" 37 tslogger "tailscale.com/types/logger" 38) 39 40// ctxConn is a key to look up a net.Conn stored in an HTTP request's context. 41type ctxConn struct{} 42 43var ( 44 requestsInFlight = promauto.NewGaugeVec( 45 prometheus.GaugeOpts{ 46 Namespace: "tsproxy", 47 Name: "requests_in_flight", 48 Help: "Number of requests currently being served by the server.", 49 }, 50 []string{"upstream"}, 51 ) 52 53 requests = promauto.NewCounterVec( 54 prometheus.CounterOpts{ 55 Namespace: "tsproxy", 56 Name: "requests_total", 57 Help: "Number of requests received by the server.", 58 }, 59 []string{"upstream", "code", "method"}, 60 ) 61 62 duration = promauto.NewHistogramVec( 63 prometheus.HistogramOpts{ 64 Namespace: "tsproxy", 65 Name: "request_duration_seconds", 66 Help: "A histogram of latencies for requests handled by the server.", 67 NativeHistogramBucketFactor: 1.1, 68 }, 69 []string{"upstream", "code", "method"}, 70 ) 71) 72 73type upstream struct { 74 Name string 75 Backend string 76 Prometheus bool 77 Funnel *funnelConfig 78} 79 80type funnelConfig struct { 81 Insecure bool 82 Issuer string 83 ClientID string 84 ClientSecret string 85 User string 86 Password string 87 IP []string 88} 89 90type target struct { 91 name string 92 magicDNS string 93 prometheus bool 94} 95 96func main() { 97 if err := tsproxy(context.Background()); err != nil { 98 fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err) 99 os.Exit(1) 100 } 101} 102 103func tsproxy(ctx context.Context) error { 104 var ( 105 state = flag.String("state", "", "Optional directory for storing Tailscale state.") 106 tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 107 port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 108 ver = flag.Bool("version", false, "print the version and exit") 109 upfile = flag.String("upstream", "", "path to upstreams config file") 110 ) 111 flag.Parse() 112 113 if *ver { 114 fmt.Fprintln(os.Stdout, version.Print("tsproxy")) 115 os.Exit(0) 116 } 117 118 if *upfile == "" { 119 return fmt.Errorf("required flag missing: upstream") 120 } 121 122 in, err := os.ReadFile(*upfile) 123 if err != nil { 124 return err 125 } 126 inJSON, err := hujson.Standardize(in) 127 if err != nil { 128 return fmt.Errorf("hujson: %w", err) 129 } 130 var upstreams []upstream 131 if err := json.Unmarshal(inJSON, &upstreams); err != nil { 132 return fmt.Errorf("json: %w", err) 133 } 134 if len(upstreams) == 0 { 135 return fmt.Errorf("file does not contain any upstreams: %s", *upfile) 136 } 137 138 if *state == "" { 139 v, err := os.UserCacheDir() 140 if err != nil { 141 return err 142 } 143 dir := filepath.Join(v, "tsproxy") 144 if err := os.MkdirAll(dir, 0o700); err != nil { 145 return err 146 } 147 state = &dir 148 } 149 prometheus.MustRegister(versioncollector.NewCollector("tsproxy")) 150 151 logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) 152 slog.SetDefault(logger) 153 154 // If tailscaled isn't ready yet, just crash. 155 st, err := (&local.Client{}).Status(ctx) 156 if err != nil { 157 return fmt.Errorf("tailscale: get node status: %w", err) 158 } 159 if v := len(st.Self.TailscaleIPs); v != 2 { 160 return fmt.Errorf("want 2 tailscale IPs, got %d", v) 161 } 162 163 // service discovery targets (self + all upstreams) 164 targets := make([]target, len(upstreams)+1) 165 166 var g run.Group 167 ctx, cancel := context.WithCancel(ctx) 168 defer cancel() 169 g.Add(run.SignalHandler(ctx, os.Interrupt, syscall.SIGTERM)) 170 171 { 172 p := strconv.Itoa(*port) 173 174 var listeners []net.Listener 175 for _, ip := range st.Self.TailscaleIPs { 176 ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), p)) 177 if err != nil { 178 return fmt.Errorf("listen on %s:%d: %w", ip, *port, err) 179 } 180 listeners = append(listeners, ln) 181 } 182 183 http.Handle("/metrics", promhttp.Handler()) 184 http.Handle("/sd", serveDiscovery(net.JoinHostPort(st.Self.DNSName, p), targets)) 185 http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { 186 _, _ = w.Write([]byte(`<html> 187 <head><title>tsproxy</title></head> 188 <body> 189 <h1>tsproxy</h1> 190 <p><a href="/metrics">Metrics</a></p> 191 <p><a href="/sd">Discovery</a></p> 192 </body> 193 </html>`)) 194 }) 195 196 srv := &http.Server{} 197 for _, ln := range listeners { 198 ln := ln 199 g.Add(func() error { 200 logger.Info("server ready", slog.String("addr", ln.Addr().String())) 201 return srv.Serve(ln) 202 }, func(_ error) { 203 if err := srv.Close(); err != nil { 204 logger.Error("shutdown server", lerr(err)) 205 } 206 cancel() 207 }) 208 } 209 } 210 211 for i, upstream := range upstreams { 212 log := logger.With(slog.String("upstream", upstream.Name)) 213 214 ts := &tsnet.Server{ 215 Hostname: upstream.Name, 216 Dir: filepath.Join(*state, "tailscale-"+upstream.Name), 217 RunWebClient: true, 218 } 219 defer ts.Close() 220 221 if *tslog { 222 ts.Logf = func(format string, args ...any) { 223 //nolint: sloglint 224 log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale")) 225 } 226 } else { 227 ts.Logf = tslogger.Discard 228 } 229 if err := os.MkdirAll(ts.Dir, 0o700); err != nil { 230 return err 231 } 232 233 lc, err := ts.LocalClient() 234 if err != nil { 235 return fmt.Errorf("tailscale: get local client for %s: %w", upstream.Name, err) 236 } 237 238 backendURL, err := url.Parse(upstream.Backend) 239 if err != nil { 240 return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err) 241 } 242 // TODO(sr) Instrument proxy.Transport 243 proxy := &httputil.ReverseProxy{ 244 Rewrite: func(req *httputil.ProxyRequest) { 245 req.SetURL(backendURL) 246 req.SetXForwarded() 247 req.Out.Host = req.In.Host 248 }, 249 } 250 proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { 251 http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) 252 log.Error("upstream error", lerr(err)) 253 } 254 255 instrument := func(h http.Handler) http.Handler { 256 return promhttp.InstrumentHandlerInFlight( 257 requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}), 258 promhttp.InstrumentHandlerDuration( 259 duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 260 promhttp.InstrumentHandlerCounter( 261 requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 262 h, 263 ), 264 ), 265 ) 266 } 267 268 { 269 var srv *http.Server 270 g.Add(func() error { 271 st, err := ts.Up(ctx) 272 if err != nil { 273 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 274 } 275 276 srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))} 277 ln, err := ts.Listen("tcp", ":80") 278 if err != nil { 279 return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) 280 } 281 282 // register in service discovery when we're ready. 283 targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName} 284 285 return srv.Serve(ln) 286 }, func(_ error) { 287 if srv != nil { 288 if err := srv.Close(); err != nil { 289 log.Error("server shutdown", lerr(err)) 290 } 291 } 292 cancel() 293 }) 294 } 295 { 296 var srv *http.Server 297 g.Add(func() error { 298 st, err := ts.Up(ctx) 299 if err != nil { 300 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 301 } 302 303 srv = &http.Server{ 304 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 305 Handler: instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))), 306 } 307 308 ln, err := ts.Listen("tcp", ":443") 309 if err != nil { 310 return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err) 311 } 312 return srv.ServeTLS(ln, "", "") 313 }, func(_ error) { 314 if srv != nil { 315 if err := srv.Close(); err != nil { 316 log.Error("server shutdown", lerr(err)) 317 } 318 } 319 cancel() 320 }) 321 } 322 if funnel := upstream.Funnel; funnel != nil { 323 { 324 var srv *http.Server 325 g.Add(func() error { 326 st, err := ts.Up(ctx) 327 if err != nil { 328 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 329 } 330 331 var handler http.Handler 332 switch { 333 case funnel.Insecure: 334 handler = insecureFunnel(log, lc, proxy) 335 case funnel.Issuer != "": 336 redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"} 337 wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String()) 338 if err != nil { 339 return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err) 340 } 341 wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile) 342 343 handler = wrapper.Wrap(oidcFunnel(log, lc, proxy)) 344 case funnel.User != "": 345 handler = insecureFunnel(log, lc, basicAuth(log, funnel.User, funnel.Password, proxy)) 346 default: 347 return fmt.Errorf("upstream %s must set funnel.insecure or funnel.issuer", upstream.Name) 348 } 349 350 handler = redirect(st.Self.DNSName, true, handler) 351 352 if len(funnel.IP) > 0 { 353 var allow []netip.Prefix 354 for _, ip := range funnel.IP { 355 allow = append(allow, netip.MustParsePrefix(ip)) 356 } 357 handler = restrictNetworks(log, allow, handler) 358 } 359 360 srv = &http.Server{ 361 Handler: instrument(handler), 362 ConnContext: func(ctx context.Context, c net.Conn) context.Context { 363 return context.WithValue(ctx, ctxConn{}, c) 364 }, 365 } 366 367 ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) 368 if err != nil { 369 return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err) 370 } 371 return srv.Serve(ln) 372 }, func(_ error) { 373 if srv != nil { 374 if err := srv.Close(); err != nil { 375 log.Error("server shutdown", lerr(err)) 376 } 377 } 378 cancel() 379 }) 380 } 381 } 382 } 383 384 return g.Run() 385} 386 387func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler { 388 if fqdn == "" { 389 panic("redirect: fqdn cannot be empty") 390 } 391 fqdn = strings.TrimSuffix(fqdn, ".") 392 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 393 if forceSSL && r.TLS == nil { 394 http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect) 395 return 396 } 397 398 if r.TLS != nil && strings.TrimSuffix(r.Host, ".") != fqdn { 399 http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect) 400 return 401 } 402 next.ServeHTTP(w, r) 403 }) 404} 405 406func basicAuth(logger *slog.Logger, user, password string, next http.Handler) http.Handler { 407 if user == "" || password == "" { 408 panic("user and password are required") 409 } 410 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 411 u, p, ok := r.BasicAuth() 412 if ok { 413 userCheck := subtle.ConstantTimeCompare([]byte(user), []byte(u)) 414 passwordCheck := subtle.ConstantTimeCompare([]byte(password), []byte(p)) 415 if userCheck == 1 && passwordCheck == 1 { 416 next.ServeHTTP(w, r) 417 return 418 } 419 } 420 logger.ErrorContext(r.Context(), "authentication failed", slog.String("user", u)) 421 w.Header().Set("WWW-Authenticate", "Basic realm=\"protected\", charset=\"UTF-8\"") 422 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 423 }) 424} 425 426func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 427 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 428 whois, err := tsWhoIs(lc, r) 429 if err != nil { 430 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 431 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 432 return 433 } 434 435 // Proxy requests from tagged nodes as is. 436 if whois.Node.IsTagged() { 437 next.ServeHTTP(w, r) 438 return 439 } 440 441 req := r.Clone(r.Context()) 442 req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName) 443 req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName) 444 next.ServeHTTP(w, req) 445 }) 446} 447 448func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 449 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 450 whois, err := tsWhoIs(lc, r) 451 if err != nil { 452 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 453 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 454 return 455 } 456 if !whois.Node.IsTagged() { 457 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 458 logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node") 459 return 460 } 461 462 next.ServeHTTP(w, r) 463 }) 464} 465 466func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 467 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 468 whois, err := tsWhoIs(lc, r) 469 if err != nil { 470 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 471 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 472 return 473 } 474 if !whois.Node.IsTagged() { 475 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 476 logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node") 477 return 478 } 479 480 tok := middleware.IDJWTFromContext(r.Context()) 481 if tok == nil { 482 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 483 logger.ErrorContext(r.Context(), "jwt token missing") 484 return 485 } 486 email, err := tok.StringClaim("email") 487 if err != nil { 488 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 489 logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "email")) 490 return 491 } 492 name, err := tok.StringClaim("name") 493 if err != nil { 494 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 495 logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "name")) 496 return 497 } 498 499 req := r.Clone(r.Context()) 500 req.Header.Set("X-Webauth-User", email) 501 req.Header.Set("X-Webauth-Name", name) 502 503 next.ServeHTTP(w, req) 504 }) 505} 506 507// restrictNetworks will only allow clients from the provided IP networks to 508// access the given handler. If skip prefixes are set, paths that match any 509// of the regular expressions will not have restrictions applied. 510func restrictNetworks(logger *slog.Logger, allowedNetworks []netip.Prefix, next http.Handler) http.Handler { 511 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 512 // If the funneled connection is from tsnet, then the net.Conn will be of 513 // type ipn.FunnelConn. 514 netConn := r.Context().Value(ctxConn{}) 515 // if the conn is wrapped inside TLS, unwrap it 516 if tlsConn, ok := netConn.(*tls.Conn); ok { 517 netConn = tlsConn.NetConn() 518 } 519 var remote netip.AddrPort 520 if fconn, ok := netConn.(*ipn.FunnelConn); ok { 521 remote = fconn.Src 522 } else if v, err := netip.ParseAddrPort(r.RemoteAddr); err == nil { 523 remote = v 524 } else { 525 logger.Error("restrictNetworks: cannot parse client IP:port", lerr(err), slog.String("remote", r.RemoteAddr)) 526 w.WriteHeader(http.StatusUnauthorized) 527 return 528 } 529 530 for _, wl := range allowedNetworks { 531 if wl.Contains(remote.Addr()) { 532 next.ServeHTTP(w, r) 533 return 534 } 535 } 536 537 w.WriteHeader(http.StatusForbidden) 538 _, _ = fmt.Fprint(w, badNetwork) 539 }) 540} 541 542const badNetwork = ` 543<html> 544<head><title>Untrusted network</title></head> 545<body><h1>Access from untrusted networks not permitted</h1></body> 546</html> 547` 548 549type tailscaleLocalClient interface { 550 WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 551} 552 553func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) { 554 whois, err := lc.WhoIs(r.Context(), r.RemoteAddr) 555 if err != nil { 556 return nil, fmt.Errorf("tailscale whois: %w", err) 557 } 558 559 if whois.Node == nil { 560 return nil, errors.New("tailscale whois: node missing") 561 } 562 563 if whois.UserProfile == nil { 564 return nil, errors.New("tailscale whois: user profile missing") 565 } 566 return whois, nil 567} 568 569func serveDiscovery(self string, targets []target) http.Handler { 570 return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 571 var tgs []string 572 tgs = append(tgs, self) 573 for _, t := range targets { 574 if t.magicDNS == "" { 575 continue 576 } 577 if !t.prometheus { 578 continue 579 } 580 tgs = append(tgs, t.magicDNS) 581 } 582 sort.Strings(tgs) 583 buf, err := json.Marshal([]struct { 584 Targets []string `json:"targets"` 585 }{ 586 {Targets: tgs}, 587 }) 588 if err != nil { 589 http.Error(w, err.Error(), http.StatusInternalServerError) 590 return 591 } 592 w.Header().Set("Content-Type", "application/json; charset=utf-8") 593 _, _ = w.Write(buf) 594 }) 595} 596 597func lerr(err error) slog.Attr { 598 return slog.Any("err", err) 599}