HTTP reverse proxy for Tailscale
1package main
2
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"syscall"

	"github.com/oklog/run"
	"github.com/prometheus/client_golang/prometheus"
	versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/prometheus/common/version"
	"tailscale.com/client/local"
	"tailscale.com/client/tailscale/apitype"
	"tailscale.com/tsnet"
	tslogger "tailscale.com/types/logger"
)
33
34var (
35 requestsInFlight = promauto.NewGaugeVec(
36 prometheus.GaugeOpts{
37 Namespace: "tsproxy",
38 Name: "requests_in_flight",
39 Help: "Number of requests currently being served by the server.",
40 },
41 []string{"upstream"},
42 )
43
44 requests = promauto.NewCounterVec(
45 prometheus.CounterOpts{
46 Namespace: "tsproxy",
47 Name: "requests_total",
48 Help: "Number of requests received by the server.",
49 },
50 []string{"upstream", "code", "method"},
51 )
52
53 duration = promauto.NewHistogramVec(
54 prometheus.HistogramOpts{
55 Namespace: "tsproxy",
56 Name: "request_duration_seconds",
57 Help: "A histogram of latencies for requests handled by the server.",
58 NativeHistogramBucketFactor: 1.1,
59 },
60 []string{"upstream", "code", "method"},
61 )
62)
63
64type upstreamFlag []upstream
65
66func (f *upstreamFlag) String() string {
67 return fmt.Sprintf("%+v", *f)
68}
69
70func (f *upstreamFlag) Set(val string) error {
71 up, err := parseUpstreamFlag(val)
72 if err != nil {
73 return err
74 }
75 *f = append(*f, up)
76 return nil
77}
78
// upstream describes one proxied backend as configured on the command line.
type upstream struct {
	// name is the part of the flag value before "="; it is used as the
	// Tailscale hostname and as the "upstream" metric label.
	name string
	// backend is the URL requests are forwarded to.
	backend *url.URL
	// prometheus and funnel record whether the "prometheus" and "funnel"
	// options were given for this upstream.
	prometheus, funnel bool
}
85
// target is one service discovery entry for an upstream.
type target struct {
	// name is the upstream's name; magicDNS is the DNS name of its Tailscale
	// node and stays empty until the node is up.
	name, magicDNS string
	// prometheus reports whether the upstream is listed by the discovery
	// endpoint.
	prometheus bool
}
91
92func parseUpstreamFlag(fval string) (upstream, error) {
93 k, v, ok := strings.Cut(fval, "=")
94 if !ok {
95 return upstream{}, errors.New("format: name=http://backend")
96 }
97 val := strings.Split(v, ";")
98 be, err := url.Parse(val[0])
99 if err != nil {
100 return upstream{}, err
101 }
102 up := upstream{name: k, backend: be}
103 if len(val) > 1 {
104 for _, opt := range val[1:] {
105 switch opt {
106 case "prometheus":
107 up.prometheus = true
108 case "funnel":
109 up.funnel = true
110 default:
111 return upstream{}, fmt.Errorf("unsupported option: %v", opt)
112 }
113 }
114 }
115 return up, nil
116}
117
118func main() {
119 if err := tsproxy(context.Background()); err != nil {
120 fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err)
121 os.Exit(1)
122 }
123}
124
125func tsproxy(ctx context.Context) error {
126 var (
127 state = flag.String("state", "", "Optional directory for storing Tailscale state.")
128 tslog = flag.Bool("tslog", false, "If true, log Tailscale output.")
129 port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.")
130 ver = flag.Bool("version", false, "print the version and exit")
131 )
132 var upstreams upstreamFlag
133 flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000")
134 flag.Parse()
135
136 if *ver {
137 fmt.Fprintln(os.Stdout, version.Print("tsproxy"))
138 os.Exit(0)
139 }
140
141 if len(upstreams) == 0 {
142 return fmt.Errorf("required flag missing: upstream")
143 }
144 if *state == "" {
145 v, err := os.UserCacheDir()
146 if err != nil {
147 return err
148 }
149 dir := filepath.Join(v, "tsproxy")
150 if err := os.MkdirAll(dir, 0o700); err != nil {
151 return err
152 }
153 state = &dir
154 }
155 prometheus.MustRegister(versioncollector.NewCollector("tsproxy"))
156
157 logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
158 slog.SetDefault(logger)
159
160 // If tailscaled isn't ready yet, just crash.
161 st, err := (&local.Client{}).Status(ctx)
162 if err != nil {
163 return fmt.Errorf("tailscale: get node status: %w", err)
164 }
165 if v := len(st.Self.TailscaleIPs); v != 2 {
166 return fmt.Errorf("want 2 tailscale IPs, got %d", v)
167 }
168
169 // service discovery targets (self + all upstreams)
170 targets := make([]target, len(upstreams)+1)
171
172 var g run.Group
173 ctx, cancel := context.WithCancel(ctx)
174 defer cancel()
175 g.Add(run.SignalHandler(ctx, os.Interrupt, syscall.SIGTERM))
176
177 {
178 p := strconv.Itoa(*port)
179
180 var listeners []net.Listener
181 for _, ip := range st.Self.TailscaleIPs {
182 ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), p))
183 if err != nil {
184 return fmt.Errorf("listen on %s:%d: %w", ip, *port, err)
185 }
186 listeners = append(listeners, ln)
187 }
188
189 http.Handle("/metrics", promhttp.Handler())
190 http.Handle("/sd", serveDiscovery(net.JoinHostPort(st.Self.DNSName, p), targets))
191 http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
192 _, _ = w.Write([]byte(`<html>
193 <head><title>tsproxy</title></head>
194 <body>
195 <h1>tsproxy</h1>
196 <p><a href="/metrics">Metrics</a></p>
197 <p><a href="/sd">Discovery</a></p>
198 </body>
199 </html>`))
200 })
201
202 srv := &http.Server{}
203 for _, ln := range listeners {
204 ln := ln
205 g.Add(func() error {
206 logger.Info("server ready", slog.String("addr", ln.Addr().String()))
207 return srv.Serve(ln)
208 }, func(_ error) {
209 if err := srv.Close(); err != nil {
210 logger.Error("shutdown server", lerr(err))
211 }
212 cancel()
213 })
214 }
215 }
216
217 for i, upstream := range upstreams {
218 // https://go.dev/doc/faq#closures_and_goroutines
219 i := i
220 upstream := upstream
221
222 log := logger.With(slog.String("upstream", upstream.name))
223
224 ts := &tsnet.Server{
225 Hostname: upstream.name,
226 Dir: filepath.Join(*state, "tailscale-"+upstream.name),
227 RunWebClient: true,
228 }
229 defer ts.Close()
230
231 if *tslog {
232 ts.Logf = func(format string, args ...any) {
233 //nolint: sloglint
234 log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale"))
235 }
236 } else {
237 ts.Logf = tslogger.Discard
238 }
239 if err := os.MkdirAll(ts.Dir, 0o700); err != nil {
240 return err
241 }
242
243 lc, err := ts.LocalClient()
244 if err != nil {
245 return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err)
246 }
247
248 srv := &http.Server{
249 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
250 Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.name}),
251 promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.name}),
252 promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.name}),
253 newReverseProxy(log, lc, upstream.backend)))),
254 }
255
256 g.Add(func() error {
257 st, err := ts.Up(ctx)
258 if err != nil {
259 return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err)
260 }
261
262 // register in service discovery when we're ready.
263 targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName}
264
265 ln, err := ts.Listen("tcp", ":80")
266 if err != nil {
267 return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err)
268 }
269 return srv.Serve(ln)
270 }, func(_ error) {
271 if err := srv.Close(); err != nil {
272 log.Error("server shutdown", lerr(err))
273 }
274 cancel()
275 })
276 g.Add(func() error {
277 _, err := ts.Up(ctx)
278 if err != nil {
279 return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err)
280 }
281
282 if upstream.funnel {
283 ln, err := ts.ListenFunnel("tcp", ":443")
284 if err != nil {
285 return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.name, err)
286 }
287 return srv.Serve(ln)
288 }
289
290 ln, err := ts.Listen("tcp", ":443")
291 if err != nil {
292 return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err)
293 }
294 return srv.ServeTLS(ln, "", "")
295 }, func(_ error) {
296 if err := srv.Close(); err != nil {
297 log.Error("TLS server shutdown", lerr(err))
298 }
299 cancel()
300 })
301 }
302
303 return g.Run()
304}
305
306type tailscaleLocalClient interface {
307 WhoIs(context.Context, string) (*apitype.WhoIsResponse, error)
308}
309
310func newReverseProxy(logger *slog.Logger, lc tailscaleLocalClient, url *url.URL) http.HandlerFunc {
311 // TODO(sr) Instrument proxy.Transport
312 rproxy := &httputil.ReverseProxy{
313 Rewrite: func(req *httputil.ProxyRequest) {
314 req.SetURL(url)
315 req.SetXForwarded()
316 req.Out.Host = req.In.Host
317 },
318 }
319 rproxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) {
320 http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
321 logger.Error("upstream error", lerr(err))
322 }
323
324 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
325 whois, err := lc.WhoIs(r.Context(), r.RemoteAddr)
326 if err != nil {
327 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
328 logger.Error("tailscale whois", lerr(err))
329 return
330 }
331
332 if whois.Node == nil {
333 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
334 logger.Error("tailscale whois", slog.String("err", "node missing"))
335 return
336 }
337
338 if whois.UserProfile == nil {
339 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
340 logger.Error("tailscale whois", slog.String("err", "user profile missing"))
341 return
342 }
343
344 // Proxy requests from tagged nodes as is.
345 if whois.Node.IsTagged() {
346 rproxy.ServeHTTP(w, r)
347 return
348 }
349
350 req := r.Clone(r.Context())
351 req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName)
352 req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName)
353 rproxy.ServeHTTP(w, req)
354 })
355}
356
357func serveDiscovery(self string, targets []target) http.Handler {
358 return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
359 var tgs []string
360 tgs = append(tgs, self)
361 for _, t := range targets {
362 if t.magicDNS == "" {
363 continue
364 }
365 if !t.prometheus {
366 continue
367 }
368 tgs = append(tgs, t.magicDNS)
369 }
370 sort.Strings(tgs)
371 buf, err := json.Marshal([]struct {
372 Targets []string `json:"targets"`
373 }{
374 {Targets: tgs},
375 })
376 if err != nil {
377 http.Error(w, err.Error(), http.StatusInternalServerError)
378 return
379 }
380 w.Header().Set("Content-Type", "application/json; charset=utf-8")
381 _, _ = w.Write(buf)
382 })
383}
384
// lerr wraps err's message in a string log attribute under the key "err".
// err must not be nil.
func lerr(err error) slog.Attr {
	msg := err.Error()
	return slog.Attr{Key: "err", Value: slog.StringValue(msg)}
}