HTTP reverse proxy for Tailscale
1package main
2
import (
	"context"
	"crypto/subtle"
	"crypto/tls"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/http/httputil"
	"net/netip"
	"net/url"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"syscall"

	"github.com/lstoll/oidc"
	"github.com/lstoll/oidc/middleware"
	"github.com/oklog/run"
	"github.com/prometheus/client_golang/prometheus"
	versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/prometheus/common/version"
	"github.com/tailscale/hujson"
	"tailscale.com/client/local"
	"tailscale.com/client/tailscale/apitype"
	"tailscale.com/ipn"
	"tailscale.com/tsnet"
	tslogger "tailscale.com/types/logger"
)
39
// ctxConn is the context key under which the Funnel server's ConnContext hook
// stores the accepted net.Conn, so that handlers (restrictNetworks) can
// inspect the underlying connection of an HTTP request. Being an unexported
// empty struct type, it cannot collide with context keys of other packages.
type ctxConn struct{}
42
// Prometheus metrics for proxied traffic, registered with the default
// registry at package initialisation via promauto. The "upstream" label is
// curried per upstream by tsproxy; "code" and "method" are filled in by the
// promhttp instrumentation middleware.
var (
	// requestsInFlight is the number of requests currently being served, per
	// upstream.
	requestsInFlight = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: "tsproxy",
			Name:      "requests_in_flight",
			Help:      "Number of requests currently being served by the server.",
		},
		[]string{"upstream"},
	)

	// requests counts handled requests by upstream, status code and method.
	requests = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "tsproxy",
			Name:      "requests_total",
			Help:      "Number of requests received by the server.",
		},
		[]string{"upstream", "code", "method"},
	)

	// duration records request latency by upstream, status code and method.
	// NativeHistogramBucketFactor configures a Prometheus native histogram
	// whose bucket boundaries grow by a factor of at most 1.1.
	duration = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace:                   "tsproxy",
			Name:                        "request_duration_seconds",
			Help:                        "A histogram of latencies for requests handled by the server.",
			NativeHistogramBucketFactor: 1.1,
		},
		[]string{"upstream", "code", "method"},
	)
)
72
73type upstream struct {
74 Name string
75 Backend string
76 Prometheus bool
77 Funnel *funnelConfig
78}
79
// funnelConfig selects how public Funnel traffic to an upstream is
// authenticated. The modes are checked in the order Insecure, Issuer, User.
type funnelConfig struct {
	// Insecure serves the upstream over Funnel without end-user auth.
	Insecure bool
	// Issuer, ClientID and ClientSecret configure OIDC login.
	Issuer, ClientID, ClientSecret string
	// User and Password enable HTTP basic authentication.
	User, Password string
	// IP optionally restricts access to these CIDR prefixes.
	IP []string
}
89
// target is a service discovery entry for one upstream. It is filled in once
// the upstream's tsnet node is up.
type target struct {
	// name is the upstream name; magicDNS is the node's MagicDNS name.
	name, magicDNS string
	// prometheus marks the target for inclusion in /sd.
	prometheus bool
}
95
96func main() {
97 if err := tsproxy(context.Background()); err != nil {
98 fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err)
99 os.Exit(1)
100 }
101}
102
103func tsproxy(ctx context.Context) error {
104 var (
105 state = flag.String("state", "", "Optional directory for storing Tailscale state.")
106 tslog = flag.Bool("tslog", false, "If true, log Tailscale output.")
107 port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.")
108 ver = flag.Bool("version", false, "print the version and exit")
109 upfile = flag.String("upstream", "", "path to upstreams config file")
110 )
111 flag.Parse()
112
113 if *ver {
114 fmt.Fprintln(os.Stdout, version.Print("tsproxy"))
115 os.Exit(0)
116 }
117
118 if *upfile == "" {
119 return fmt.Errorf("required flag missing: upstream")
120 }
121
122 in, err := os.ReadFile(*upfile)
123 if err != nil {
124 return err
125 }
126 inJSON, err := hujson.Standardize(in)
127 if err != nil {
128 return fmt.Errorf("hujson: %w", err)
129 }
130 var upstreams []upstream
131 if err := json.Unmarshal(inJSON, &upstreams); err != nil {
132 return fmt.Errorf("json: %w", err)
133 }
134 if len(upstreams) == 0 {
135 return fmt.Errorf("file does not contain any upstreams: %s", *upfile)
136 }
137
138 if *state == "" {
139 v, err := os.UserCacheDir()
140 if err != nil {
141 return err
142 }
143 dir := filepath.Join(v, "tsproxy")
144 if err := os.MkdirAll(dir, 0o700); err != nil {
145 return err
146 }
147 state = &dir
148 }
149 prometheus.MustRegister(versioncollector.NewCollector("tsproxy"))
150
151 logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
152 slog.SetDefault(logger)
153
154 // If tailscaled isn't ready yet, just crash.
155 st, err := (&local.Client{}).Status(ctx)
156 if err != nil {
157 return fmt.Errorf("tailscale: get node status: %w", err)
158 }
159 if v := len(st.Self.TailscaleIPs); v != 2 {
160 return fmt.Errorf("want 2 tailscale IPs, got %d", v)
161 }
162
163 // service discovery targets (self + all upstreams)
164 targets := make([]target, len(upstreams)+1)
165
166 var g run.Group
167 ctx, cancel := context.WithCancel(ctx)
168 defer cancel()
169 g.Add(run.SignalHandler(ctx, os.Interrupt, syscall.SIGTERM))
170
171 {
172 p := strconv.Itoa(*port)
173
174 var listeners []net.Listener
175 for _, ip := range st.Self.TailscaleIPs {
176 ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), p))
177 if err != nil {
178 return fmt.Errorf("listen on %s:%d: %w", ip, *port, err)
179 }
180 listeners = append(listeners, ln)
181 }
182
183 http.Handle("/metrics", promhttp.Handler())
184 http.Handle("/sd", serveDiscovery(net.JoinHostPort(st.Self.DNSName, p), targets))
185 http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
186 _, _ = w.Write([]byte(`<html>
187 <head><title>tsproxy</title></head>
188 <body>
189 <h1>tsproxy</h1>
190 <p><a href="/metrics">Metrics</a></p>
191 <p><a href="/sd">Discovery</a></p>
192 </body>
193 </html>`))
194 })
195
196 srv := &http.Server{}
197 for _, ln := range listeners {
198 ln := ln
199 g.Add(func() error {
200 logger.Info("server ready", slog.String("addr", ln.Addr().String()))
201 return srv.Serve(ln)
202 }, func(_ error) {
203 if err := srv.Close(); err != nil {
204 logger.Error("shutdown server", lerr(err))
205 }
206 cancel()
207 })
208 }
209 }
210
211 for i, upstream := range upstreams {
212 log := logger.With(slog.String("upstream", upstream.Name))
213
214 ts := &tsnet.Server{
215 Hostname: upstream.Name,
216 Dir: filepath.Join(*state, "tailscale-"+upstream.Name),
217 RunWebClient: true,
218 }
219 defer ts.Close()
220
221 if *tslog {
222 ts.Logf = func(format string, args ...any) {
223 //nolint: sloglint
224 log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale"))
225 }
226 } else {
227 ts.Logf = tslogger.Discard
228 }
229 if err := os.MkdirAll(ts.Dir, 0o700); err != nil {
230 return err
231 }
232
233 lc, err := ts.LocalClient()
234 if err != nil {
235 return fmt.Errorf("tailscale: get local client for %s: %w", upstream.Name, err)
236 }
237
238 backendURL, err := url.Parse(upstream.Backend)
239 if err != nil {
240 return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err)
241 }
242 // TODO(sr) Instrument proxy.Transport
243 proxy := &httputil.ReverseProxy{
244 Rewrite: func(req *httputil.ProxyRequest) {
245 req.SetURL(backendURL)
246 req.SetXForwarded()
247 req.Out.Host = req.In.Host
248 },
249 }
250 proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) {
251 http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
252 log.Error("upstream error", lerr(err))
253 }
254
255 instrument := func(h http.Handler) http.Handler {
256 return promhttp.InstrumentHandlerInFlight(
257 requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}),
258 promhttp.InstrumentHandlerDuration(
259 duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}),
260 promhttp.InstrumentHandlerCounter(
261 requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}),
262 h,
263 ),
264 ),
265 )
266 }
267
268 {
269 var srv *http.Server
270 g.Add(func() error {
271 st, err := ts.Up(ctx)
272 if err != nil {
273 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
274 }
275
276 srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))}
277 ln, err := ts.Listen("tcp", ":80")
278 if err != nil {
279 return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err)
280 }
281
282 // register in service discovery when we're ready.
283 targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName}
284
285 return srv.Serve(ln)
286 }, func(_ error) {
287 if srv != nil {
288 if err := srv.Close(); err != nil {
289 log.Error("server shutdown", lerr(err))
290 }
291 }
292 cancel()
293 })
294 }
295 {
296 var srv *http.Server
297 g.Add(func() error {
298 st, err := ts.Up(ctx)
299 if err != nil {
300 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
301 }
302
303 srv = &http.Server{
304 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
305 Handler: instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))),
306 }
307
308 ln, err := ts.Listen("tcp", ":443")
309 if err != nil {
310 return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err)
311 }
312 return srv.ServeTLS(ln, "", "")
313 }, func(_ error) {
314 if srv != nil {
315 if err := srv.Close(); err != nil {
316 log.Error("server shutdown", lerr(err))
317 }
318 }
319 cancel()
320 })
321 }
322 if funnel := upstream.Funnel; funnel != nil {
323 {
324 var srv *http.Server
325 g.Add(func() error {
326 st, err := ts.Up(ctx)
327 if err != nil {
328 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
329 }
330
331 var handler http.Handler
332 switch {
333 case funnel.Insecure:
334 handler = insecureFunnel(log, lc, proxy)
335 case funnel.Issuer != "":
336 redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"}
337 wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String())
338 if err != nil {
339 return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err)
340 }
341 wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)
342
343 handler = wrapper.Wrap(oidcFunnel(log, lc, proxy))
344 case funnel.User != "":
345 handler = insecureFunnel(log, lc, basicAuth(log, funnel.User, funnel.Password, proxy))
346 default:
347 return fmt.Errorf("upstream %s must set funnel.insecure or funnel.issuer", upstream.Name)
348 }
349
350 handler = redirect(st.Self.DNSName, true, handler)
351
352 if len(funnel.IP) > 0 {
353 var allow []netip.Prefix
354 for _, ip := range funnel.IP {
355 allow = append(allow, netip.MustParsePrefix(ip))
356 }
357 handler = restrictNetworks(log, allow, handler)
358 }
359
360 srv = &http.Server{
361 Handler: instrument(handler),
362 ConnContext: func(ctx context.Context, c net.Conn) context.Context {
363 return context.WithValue(ctx, ctxConn{}, c)
364 },
365 }
366
367 ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly())
368 if err != nil {
369 return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err)
370 }
371 return srv.Serve(ln)
372 }, func(_ error) {
373 if srv != nil {
374 if err := srv.Close(); err != nil {
375 log.Error("server shutdown", lerr(err))
376 }
377 }
378 cancel()
379 })
380 }
381 }
382 }
383
384 return g.Run()
385}
386
// redirect canonicalises requests onto https://fqdn before passing them to
// next. A trailing dot on fqdn is ignored.
//
// It answers with 308 Permanent Redirect to the same request URI on
// https://fqdn when:
//   - forceSSL is set and the request arrived over plain HTTP, or
//   - the request arrived over TLS for a Host other than fqdn.
//
// Every other request is handed to next unchanged. It panics if fqdn is
// empty.
func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler {
	if fqdn == "" {
		panic("redirect: fqdn cannot be empty")
	}
	canonical := strings.TrimSuffix(fqdn, ".")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		overTLS := r.TLS != nil
		needsUpgrade := forceSSL && !overTLS
		wrongHost := overTLS && strings.TrimSuffix(r.Host, ".") != canonical
		if !needsUpgrade && !wrongHost {
			next.ServeHTTP(w, r)
			return
		}
		location := "https://" + canonical + r.RequestURI
		http.Redirect(w, r, location, http.StatusPermanentRedirect)
	})
}
405
// basicAuth requires HTTP basic credentials equal to user/password before
// calling next. Both values are compared with subtle.ConstantTimeCompare.
//
// Requests without matching credentials get 401 Unauthorized plus a
// WWW-Authenticate challenge, and the failed attempt is logged with the
// supplied username. It panics if user or password is empty.
func basicAuth(logger *slog.Logger, user, password string, next http.Handler) http.Handler {
	if user == "" || password == "" {
		panic("user and password are required")
	}
	wantUser, wantPassword := []byte(user), []byte(password)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotUser, gotPassword, present := r.BasicAuth()
		if present {
			// Evaluate both comparisons unconditionally so the response time
			// does not reveal which one failed.
			userOK := subtle.ConstantTimeCompare(wantUser, []byte(gotUser))
			passwordOK := subtle.ConstantTimeCompare(wantPassword, []byte(gotPassword))
			if userOK&passwordOK == 1 {
				next.ServeHTTP(w, r)
				return
			}
		}
		logger.ErrorContext(r.Context(), "authentication failed", slog.String("user", gotUser))
		w.Header().Set("WWW-Authenticate", "Basic realm=\"protected\", charset=\"UTF-8\"")
		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
	})
}
425
426func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
427 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
428 whois, err := tsWhoIs(lc, r)
429 if err != nil {
430 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
431 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
432 return
433 }
434
435 // Proxy requests from tagged nodes as is.
436 if whois.Node.IsTagged() {
437 next.ServeHTTP(w, r)
438 return
439 }
440
441 req := r.Clone(r.Context())
442 req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName)
443 req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName)
444 next.ServeHTTP(w, req)
445 })
446}
447
448func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
449 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
450 whois, err := tsWhoIs(lc, r)
451 if err != nil {
452 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
453 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
454 return
455 }
456 if !whois.Node.IsTagged() {
457 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
458 logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
459 return
460 }
461
462 next.ServeHTTP(w, r)
463 })
464}
465
466func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
467 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
468 whois, err := tsWhoIs(lc, r)
469 if err != nil {
470 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
471 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
472 return
473 }
474 if !whois.Node.IsTagged() {
475 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
476 logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
477 return
478 }
479
480 tok := middleware.IDJWTFromContext(r.Context())
481 if tok == nil {
482 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
483 logger.ErrorContext(r.Context(), "jwt token missing")
484 return
485 }
486 email, err := tok.StringClaim("email")
487 if err != nil {
488 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
489 logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "email"))
490 return
491 }
492 name, err := tok.StringClaim("name")
493 if err != nil {
494 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
495 logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "name"))
496 return
497 }
498
499 req := r.Clone(r.Context())
500 req.Header.Set("X-Webauth-User", email)
501 req.Header.Set("X-Webauth-Name", name)
502
503 next.ServeHTTP(w, req)
504 })
505}
506
507// restrictNetworks will only allow clients from the provided IP networks to
508// access the given handler. If skip prefixes are set, paths that match any
509// of the regular expressions will not have restrictions applied.
510func restrictNetworks(logger *slog.Logger, allowedNetworks []netip.Prefix, next http.Handler) http.Handler {
511 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
512 // If the funneled connection is from tsnet, then the net.Conn will be of
513 // type ipn.FunnelConn.
514 netConn := r.Context().Value(ctxConn{})
515 // if the conn is wrapped inside TLS, unwrap it
516 if tlsConn, ok := netConn.(*tls.Conn); ok {
517 netConn = tlsConn.NetConn()
518 }
519 var remote netip.AddrPort
520 if fconn, ok := netConn.(*ipn.FunnelConn); ok {
521 remote = fconn.Src
522 } else if v, err := netip.ParseAddrPort(r.RemoteAddr); err == nil {
523 remote = v
524 } else {
525 logger.Error("restrictNetworks: cannot parse client IP:port", lerr(err), slog.String("remote", r.RemoteAddr))
526 w.WriteHeader(http.StatusUnauthorized)
527 return
528 }
529
530 for _, wl := range allowedNetworks {
531 if wl.Contains(remote.Addr()) {
532 next.ServeHTTP(w, r)
533 return
534 }
535 }
536
537 w.WriteHeader(http.StatusForbidden)
538 _, _ = fmt.Fprint(w, badNetwork)
539 })
540}
541
// badNetwork is the HTML body restrictNetworks writes, with status 403, to
// clients whose address is outside every allowed network.
const badNetwork = `
<html>
<head><title>Untrusted network</title></head>
<body><h1>Access from untrusted networks not permitted</h1></body>
</html>
`
548
549type tailscaleLocalClient interface {
550 WhoIs(context.Context, string) (*apitype.WhoIsResponse, error)
551}
552
553func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) {
554 whois, err := lc.WhoIs(r.Context(), r.RemoteAddr)
555 if err != nil {
556 return nil, fmt.Errorf("tailscale whois: %w", err)
557 }
558
559 if whois.Node == nil {
560 return nil, errors.New("tailscale whois: node missing")
561 }
562
563 if whois.UserProfile == nil {
564 return nil, errors.New("tailscale whois: user profile missing")
565 }
566 return whois, nil
567}
568
569func serveDiscovery(self string, targets []target) http.Handler {
570 return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
571 var tgs []string
572 tgs = append(tgs, self)
573 for _, t := range targets {
574 if t.magicDNS == "" {
575 continue
576 }
577 if !t.prometheus {
578 continue
579 }
580 tgs = append(tgs, t.magicDNS)
581 }
582 sort.Strings(tgs)
583 buf, err := json.Marshal([]struct {
584 Targets []string `json:"targets"`
585 }{
586 {Targets: tgs},
587 })
588 if err != nil {
589 http.Error(w, err.Error(), http.StatusInternalServerError)
590 return
591 }
592 w.Header().Set("Content-Type", "application/json; charset=utf-8")
593 _, _ = w.Write(buf)
594 })
595}
596
// lerr wraps err as the structured-logging attribute "err", so every error
// log line uses the same key.
func lerr(err error) slog.Attr {
	return slog.Attr{Key: "err", Value: slog.AnyValue(err)}
}