package main import ( "context" "crypto/tls" "encoding/json" "errors" "flag" "fmt" "log/slog" "net" "net/http" "net/http/httputil" "net/url" "os" "path/filepath" "sort" "strconv" "strings" "syscall" "github.com/oklog/run" "github.com/prometheus/client_golang/prometheus" versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/common/version" "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/tsnet" tslogger "tailscale.com/types/logger" ) var ( requestsInFlight = promauto.NewGaugeVec( prometheus.GaugeOpts{ Namespace: "tsproxy", Name: "requests_in_flight", Help: "Number of requests currently being served by the server.", }, []string{"upstream"}, ) requests = promauto.NewCounterVec( prometheus.CounterOpts{ Namespace: "tsproxy", Name: "requests_total", Help: "Number of requests received by the server.", }, []string{"upstream", "code", "method"}, ) duration = promauto.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tsproxy", Name: "request_duration_seconds", Help: "A histogram of latencies for requests handled by the server.", NativeHistogramBucketFactor: 1.1, }, []string{"upstream", "code", "method"}, ) ) type upstreamFlag []upstream func (f *upstreamFlag) String() string { return fmt.Sprintf("%+v", *f) } func (f *upstreamFlag) Set(val string) error { up, err := parseUpstreamFlag(val) if err != nil { return err } *f = append(*f, up) return nil } type upstream struct { name string backend *url.URL prometheus bool funnel bool } type target struct { name string magicDNS string prometheus bool } func parseUpstreamFlag(fval string) (upstream, error) { k, v, ok := strings.Cut(fval, "=") if !ok { return upstream{}, errors.New("format: name=http://backend") } val := strings.Split(v, ";") be, err := url.Parse(val[0]) if err != nil { return upstream{}, err } up := upstream{name: k, backend: be} if len(val) > 1 { for _, opt := range val[1:] { switch opt { case "prometheus": up.prometheus = true case "funnel": up.funnel = true default: return upstream{}, fmt.Errorf("unsupported option: %v", opt) } } } return up, nil } func main() { if err := tsproxy(context.Background()); err != nil { fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err) os.Exit(1) } } func tsproxy(ctx context.Context) error { var ( state = flag.String("state", "", "Optional directory for storing Tailscale state.") tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") ver = flag.Bool("version", false, "print the version and exit") ) var upstreams upstreamFlag flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000") flag.Parse() if *ver { fmt.Fprintln(os.Stdout, version.Print("tsproxy")) os.Exit(0) } if len(upstreams) == 0 { return fmt.Errorf("required flag missing: upstream") } if *state == "" { v, err := os.UserCacheDir() if err != nil { return err } dir := filepath.Join(v, "tsproxy") if err := os.MkdirAll(dir, 0o700); err != nil { return err } state = &dir } prometheus.MustRegister(versioncollector.NewCollector("tsproxy")) logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) slog.SetDefault(logger) // If tailscaled isn't ready yet, just crash. st, err := (&local.Client{}).Status(ctx) if err != nil { return fmt.Errorf("tailscale: get node status: %w", err) } if v := len(st.Self.TailscaleIPs); v != 2 { return fmt.Errorf("want 2 tailscale IPs, got %d", v) } // service discovery targets (self + all upstreams) targets := make([]target, len(upstreams)+1) var g run.Group ctx, cancel := context.WithCancel(ctx) defer cancel() g.Add(run.SignalHandler(ctx, os.Interrupt, syscall.SIGTERM)) { p := strconv.Itoa(*port) var listeners []net.Listener for _, ip := range st.Self.TailscaleIPs { ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), p)) if err != nil { return fmt.Errorf("listen on %s:%d: %w", ip, *port, err) } listeners = append(listeners, ln) } http.Handle("/metrics", promhttp.Handler()) http.Handle("/sd", serveDiscovery(net.JoinHostPort(st.Self.DNSName, p), targets)) http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte(` tsproxy

tsproxy

Metrics

Discovery

`)) }) srv := &http.Server{} for _, ln := range listeners { ln := ln g.Add(func() error { logger.Info("server ready", slog.String("addr", ln.Addr().String())) return srv.Serve(ln) }, func(_ error) { if err := srv.Close(); err != nil { logger.Error("shutdown server", lerr(err)) } cancel() }) } } for i, upstream := range upstreams { // https://go.dev/doc/faq#closures_and_goroutines i := i upstream := upstream log := logger.With(slog.String("upstream", upstream.name)) ts := &tsnet.Server{ Hostname: upstream.name, Dir: filepath.Join(*state, "tailscale-"+upstream.name), RunWebClient: true, } defer ts.Close() if *tslog { ts.Logf = func(format string, args ...any) { //nolint: sloglint log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale")) } } else { ts.Logf = tslogger.Discard } if err := os.MkdirAll(ts.Dir, 0o700); err != nil { return err } lc, err := ts.LocalClient() if err != nil { return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err) } srv := &http.Server{ TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.name}), promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), newReverseProxy(log, lc, upstream.backend)))), } g.Add(func() error { st, err := ts.Up(ctx) if err != nil { return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) } // register in service discovery when we're ready. targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName} ln, err := ts.Listen("tcp", ":80") if err != nil { return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err) } return srv.Serve(ln) }, func(_ error) { if err := srv.Close(); err != nil { log.Error("server shutdown", lerr(err)) } cancel() }) g.Add(func() error { _, err := ts.Up(ctx) if err != nil { return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) } if upstream.funnel { ln, err := ts.ListenFunnel("tcp", ":443") if err != nil { return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.name, err) } return srv.Serve(ln) } ln, err := ts.Listen("tcp", ":443") if err != nil { return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err) } return srv.ServeTLS(ln, "", "") }, func(_ error) { if err := srv.Close(); err != nil { log.Error("TLS server shutdown", lerr(err)) } cancel() }) } return g.Run() } type tailscaleLocalClient interface { WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) } func newReverseProxy(logger *slog.Logger, lc tailscaleLocalClient, url *url.URL) http.HandlerFunc { // TODO(sr) Instrument proxy.Transport rproxy := &httputil.ReverseProxy{ Rewrite: func(req *httputil.ProxyRequest) { req.SetURL(url) req.SetXForwarded() req.Out.Host = req.In.Host }, } rproxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) logger.Error("upstream error", lerr(err)) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { whois, err := lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) logger.Error("tailscale whois", lerr(err)) return } if whois.Node == nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) logger.Error("tailscale whois", slog.String("err", "node missing")) return } if whois.UserProfile == nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) logger.Error("tailscale whois", slog.String("err", "user profile missing")) return } // Proxy requests from tagged nodes as is. if whois.Node.IsTagged() { rproxy.ServeHTTP(w, r) return } req := r.Clone(r.Context()) req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName) req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName) rproxy.ServeHTTP(w, req) }) } func serveDiscovery(self string, targets []target) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { var tgs []string tgs = append(tgs, self) for _, t := range targets { if t.magicDNS == "" { continue } if !t.prometheus { continue } tgs = append(tgs, t.magicDNS) } sort.Strings(tgs) buf, err := json.Marshal([]struct { Targets []string `json:"targets"` }{ {Targets: tgs}, }) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") _, _ = w.Write(buf) }) } func lerr(err error) slog.Attr { return slog.String("err", err.Error()) }