// Command tsproxy is an HTTP reverse proxy for Tailscale.
//
// Listing metadata: branch main, 387 lines, 10 kB.
1package main 2 3import ( 4 "context" 5 "crypto/tls" 6 "encoding/json" 7 "errors" 8 "flag" 9 "fmt" 10 "log/slog" 11 "net" 12 "net/http" 13 "net/http/httputil" 14 "net/url" 15 "os" 16 "path/filepath" 17 "sort" 18 "strconv" 19 "strings" 20 "syscall" 21 22 "github.com/oklog/run" 23 "github.com/prometheus/client_golang/prometheus" 24 versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version" 25 "github.com/prometheus/client_golang/prometheus/promauto" 26 "github.com/prometheus/client_golang/prometheus/promhttp" 27 "github.com/prometheus/common/version" 28 "tailscale.com/client/local" 29 "tailscale.com/client/tailscale/apitype" 30 "tailscale.com/tsnet" 31 tslogger "tailscale.com/types/logger" 32) 33 34var ( 35 requestsInFlight = promauto.NewGaugeVec( 36 prometheus.GaugeOpts{ 37 Namespace: "tsproxy", 38 Name: "requests_in_flight", 39 Help: "Number of requests currently being served by the server.", 40 }, 41 []string{"upstream"}, 42 ) 43 44 requests = promauto.NewCounterVec( 45 prometheus.CounterOpts{ 46 Namespace: "tsproxy", 47 Name: "requests_total", 48 Help: "Number of requests received by the server.", 49 }, 50 []string{"upstream", "code", "method"}, 51 ) 52 53 duration = promauto.NewHistogramVec( 54 prometheus.HistogramOpts{ 55 Namespace: "tsproxy", 56 Name: "request_duration_seconds", 57 Help: "A histogram of latencies for requests handled by the server.", 58 NativeHistogramBucketFactor: 1.1, 59 }, 60 []string{"upstream", "code", "method"}, 61 ) 62) 63 64type upstreamFlag []upstream 65 66func (f *upstreamFlag) String() string { 67 return fmt.Sprintf("%+v", *f) 68} 69 70func (f *upstreamFlag) Set(val string) error { 71 up, err := parseUpstreamFlag(val) 72 if err != nil { 73 return err 74 } 75 *f = append(*f, up) 76 return nil 77} 78 79type upstream struct { 80 name string 81 backend *url.URL 82 prometheus bool 83 funnel bool 84} 85 86type target struct { 87 name string 88 magicDNS string 89 prometheus bool 90} 91 92func 
parseUpstreamFlag(fval string) (upstream, error) { 93 k, v, ok := strings.Cut(fval, "=") 94 if !ok { 95 return upstream{}, errors.New("format: name=http://backend") 96 } 97 val := strings.Split(v, ";") 98 be, err := url.Parse(val[0]) 99 if err != nil { 100 return upstream{}, err 101 } 102 up := upstream{name: k, backend: be} 103 if len(val) > 1 { 104 for _, opt := range val[1:] { 105 switch opt { 106 case "prometheus": 107 up.prometheus = true 108 case "funnel": 109 up.funnel = true 110 default: 111 return upstream{}, fmt.Errorf("unsupported option: %v", opt) 112 } 113 } 114 } 115 return up, nil 116} 117 118func main() { 119 if err := tsproxy(context.Background()); err != nil { 120 fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err) 121 os.Exit(1) 122 } 123} 124 125func tsproxy(ctx context.Context) error { 126 var ( 127 state = flag.String("state", "", "Optional directory for storing Tailscale state.") 128 tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 129 port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 130 ver = flag.Bool("version", false, "print the version and exit") 131 ) 132 var upstreams upstreamFlag 133 flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000") 134 flag.Parse() 135 136 if *ver { 137 fmt.Fprintln(os.Stdout, version.Print("tsproxy")) 138 os.Exit(0) 139 } 140 141 if len(upstreams) == 0 { 142 return fmt.Errorf("required flag missing: upstream") 143 } 144 if *state == "" { 145 v, err := os.UserCacheDir() 146 if err != nil { 147 return err 148 } 149 dir := filepath.Join(v, "tsproxy") 150 if err := os.MkdirAll(dir, 0o700); err != nil { 151 return err 152 } 153 state = &dir 154 } 155 prometheus.MustRegister(versioncollector.NewCollector("tsproxy")) 156 157 logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) 158 slog.SetDefault(logger) 159 160 // If tailscaled isn't ready yet, just crash. 
161 st, err := (&local.Client{}).Status(ctx) 162 if err != nil { 163 return fmt.Errorf("tailscale: get node status: %w", err) 164 } 165 if v := len(st.Self.TailscaleIPs); v != 2 { 166 return fmt.Errorf("want 2 tailscale IPs, got %d", v) 167 } 168 169 // service discovery targets (self + all upstreams) 170 targets := make([]target, len(upstreams)+1) 171 172 var g run.Group 173 ctx, cancel := context.WithCancel(ctx) 174 defer cancel() 175 g.Add(run.SignalHandler(ctx, os.Interrupt, syscall.SIGTERM)) 176 177 { 178 p := strconv.Itoa(*port) 179 180 var listeners []net.Listener 181 for _, ip := range st.Self.TailscaleIPs { 182 ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), p)) 183 if err != nil { 184 return fmt.Errorf("listen on %s:%d: %w", ip, *port, err) 185 } 186 listeners = append(listeners, ln) 187 } 188 189 http.Handle("/metrics", promhttp.Handler()) 190 http.Handle("/sd", serveDiscovery(net.JoinHostPort(st.Self.DNSName, p), targets)) 191 http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { 192 _, _ = w.Write([]byte(`<html> 193 <head><title>tsproxy</title></head> 194 <body> 195 <h1>tsproxy</h1> 196 <p><a href="/metrics">Metrics</a></p> 197 <p><a href="/sd">Discovery</a></p> 198 </body> 199 </html>`)) 200 }) 201 202 srv := &http.Server{} 203 for _, ln := range listeners { 204 ln := ln 205 g.Add(func() error { 206 logger.Info("server ready", slog.String("addr", ln.Addr().String())) 207 return srv.Serve(ln) 208 }, func(_ error) { 209 if err := srv.Close(); err != nil { 210 logger.Error("shutdown server", lerr(err)) 211 } 212 cancel() 213 }) 214 } 215 } 216 217 for i, upstream := range upstreams { 218 // https://go.dev/doc/faq#closures_and_goroutines 219 i := i 220 upstream := upstream 221 222 log := logger.With(slog.String("upstream", upstream.name)) 223 224 ts := &tsnet.Server{ 225 Hostname: upstream.name, 226 Dir: filepath.Join(*state, "tailscale-"+upstream.name), 227 RunWebClient: true, 228 } 229 defer ts.Close() 230 231 if *tslog { 232 
ts.Logf = func(format string, args ...any) { 233 //nolint: sloglint 234 log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale")) 235 } 236 } else { 237 ts.Logf = tslogger.Discard 238 } 239 if err := os.MkdirAll(ts.Dir, 0o700); err != nil { 240 return err 241 } 242 243 lc, err := ts.LocalClient() 244 if err != nil { 245 return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err) 246 } 247 248 srv := &http.Server{ 249 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 250 Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.name}), 251 promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), 252 promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), 253 newReverseProxy(log, lc, upstream.backend)))), 254 } 255 256 g.Add(func() error { 257 st, err := ts.Up(ctx) 258 if err != nil { 259 return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 260 } 261 262 // register in service discovery when we're ready. 
263 targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName} 264 265 ln, err := ts.Listen("tcp", ":80") 266 if err != nil { 267 return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err) 268 } 269 return srv.Serve(ln) 270 }, func(_ error) { 271 if err := srv.Close(); err != nil { 272 log.Error("server shutdown", lerr(err)) 273 } 274 cancel() 275 }) 276 g.Add(func() error { 277 _, err := ts.Up(ctx) 278 if err != nil { 279 return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 280 } 281 282 if upstream.funnel { 283 ln, err := ts.ListenFunnel("tcp", ":443") 284 if err != nil { 285 return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.name, err) 286 } 287 return srv.Serve(ln) 288 } 289 290 ln, err := ts.Listen("tcp", ":443") 291 if err != nil { 292 return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err) 293 } 294 return srv.ServeTLS(ln, "", "") 295 }, func(_ error) { 296 if err := srv.Close(); err != nil { 297 log.Error("TLS server shutdown", lerr(err)) 298 } 299 cancel() 300 }) 301 } 302 303 return g.Run() 304} 305 306type tailscaleLocalClient interface { 307 WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 308} 309 310func newReverseProxy(logger *slog.Logger, lc tailscaleLocalClient, url *url.URL) http.HandlerFunc { 311 // TODO(sr) Instrument proxy.Transport 312 rproxy := &httputil.ReverseProxy{ 313 Rewrite: func(req *httputil.ProxyRequest) { 314 req.SetURL(url) 315 req.SetXForwarded() 316 req.Out.Host = req.In.Host 317 }, 318 } 319 rproxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { 320 http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) 321 logger.Error("upstream error", lerr(err)) 322 } 323 324 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 325 whois, err := lc.WhoIs(r.Context(), r.RemoteAddr) 326 if err != nil { 327 http.Error(w, 
http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 328 logger.Error("tailscale whois", lerr(err)) 329 return 330 } 331 332 if whois.Node == nil { 333 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 334 logger.Error("tailscale whois", slog.String("err", "node missing")) 335 return 336 } 337 338 if whois.UserProfile == nil { 339 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 340 logger.Error("tailscale whois", slog.String("err", "user profile missing")) 341 return 342 } 343 344 // Proxy requests from tagged nodes as is. 345 if whois.Node.IsTagged() { 346 rproxy.ServeHTTP(w, r) 347 return 348 } 349 350 req := r.Clone(r.Context()) 351 req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName) 352 req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName) 353 rproxy.ServeHTTP(w, req) 354 }) 355} 356 357func serveDiscovery(self string, targets []target) http.Handler { 358 return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 359 var tgs []string 360 tgs = append(tgs, self) 361 for _, t := range targets { 362 if t.magicDNS == "" { 363 continue 364 } 365 if !t.prometheus { 366 continue 367 } 368 tgs = append(tgs, t.magicDNS) 369 } 370 sort.Strings(tgs) 371 buf, err := json.Marshal([]struct { 372 Targets []string `json:"targets"` 373 }{ 374 {Targets: tgs}, 375 }) 376 if err != nil { 377 http.Error(w, err.Error(), http.StatusInternalServerError) 378 return 379 } 380 w.Header().Set("Content-Type", "application/json; charset=utf-8") 381 _, _ = w.Write(buf) 382 }) 383} 384 385func lerr(err error) slog.Attr { 386 return slog.String("err", err.Error()) 387}