HTTP reverse proxy for Tailscale

add back TLS listener

This is still useful for services that work better with TLS such as
https://github.com/distribution/distribution.

+36 -32
+36 -32
main.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "crypto/tls" 5 6 "errors" 6 7 "flag" 7 8 "fmt" ··· 14 15 "strconv" 15 16 "strings" 16 17 "syscall" 17 - "time" 18 18 19 19 "github.com/oklog/run" 20 20 "github.com/prometheus/client_golang/prometheus" ··· 24 24 "tailscale.com/client/tailscale" 25 25 "tailscale.com/tsnet" 26 26 tslogger "tailscale.com/types/logger" 27 - ) 28 - 29 - const ( 30 - // keep this below systemd's DefaultTimeoutStopSec (90 seconds) 31 - stopTimeout = 80 * time.Second 32 27 ) 33 28 34 29 var ( ··· 120 115 var ( 121 116 state = flag.String("state", "", "Optional directory for storing Tailscale state.") 122 117 tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 123 - port = flag.Int("port", 32019, "Port of the proxy's own HTTP server.") 118 + port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 124 119 ) 125 120 var upstreams upstreamFlag 126 121 flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000") ··· 142 137 } 143 138 144 139 logger := slog.New(slog.NewJSONHandler(os.Stderr)) 140 + // see ReverseProxy.ErrorLog; this ensures these logs go to our logger. 141 + slog.SetDefault(logger) 145 142 146 143 st, err := tsWaitStatusReady(ctx, &tailscale.LocalClient{}) 147 144 if err != nil { ··· 201 198 for i, upstream := range upstreams { 202 199 // https://go.dev/doc/faq#closures_and_goroutines 203 200 i := i 204 - up := upstream 201 + upstream := upstream 205 202 206 - log := logger.With(slog.String("upstream", up.name)) 203 + log := logger.With(slog.String("upstream", upstream.name)) 207 204 208 205 ts := &tsnet.Server{ 209 - Hostname: up.name, 210 - Dir: filepath.Join(*state, "tailscale-"+up.name), 206 + Hostname: upstream.name, 207 + Dir: filepath.Join(*state, "tailscale-"+upstream.name), 211 208 } 209 + defer ts.Close() 210 + 212 211 if *tslog { 213 212 ts.Logf = func(format string, args ...any) { 214 - log.LogAttrs(slog.InfoLevel, fmt.Sprintf(format, args...), slog.String("logger", "tailscale")) 213 + log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale")) 215 214 } 216 215 } else { 217 216 ts.Logf = tslogger.Discard ··· 222 221 223 222 lc, err := ts.LocalClient() 224 223 if err != nil { 225 - return fmt.Errorf("tailscale: get local client for %s: %w", up.name, err) 224 + return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err) 226 225 } 227 226 228 227 srv := &http.Server{ 228 + TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 229 229 Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight, 230 230 promhttp.InstrumentHandlerDuration(duration, 231 231 promhttp.InstrumentHandlerCounter(requests, 232 - newReverseProxy(log, lc, up.backend)))), 232 + newReverseProxy(log, lc, upstream.backend)))), 233 233 } 234 234 235 235 g.Add(func() error { 236 - defer ts.Close() 237 - 238 - ln, err := ts.Listen("tcp", ":80") 239 - if err != nil { 240 - return fmt.Errorf("tailscale: listen for %s on port 80: %w", up.name, err) 241 - } 242 - defer ln.Close() 243 - 244 236 st, err := tsWaitStatusReady(ctx, lc) 245 237 if err != nil { 246 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", up.name, err) 238 + return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 247 239 } 248 240 249 241 // register in service discovery when we're ready. 250 - targets[i] = target{name: up.name, prometheus: up.prometheus, magicDNS: st.Self.DNSName} 242 + targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName} 251 243 252 - log.Info("server ready", slog.String("addr", ln.Addr().String())) 253 - 244 + ln, err := ts.Listen("tcp", ":80") 245 + if err != nil { 246 + return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err) 247 + } 254 248 return srv.Serve(ln) 255 249 }, func(err error) { 256 - log.Info("shutting down server") 257 - 258 - sctx, sc := context.WithTimeout(ctx, stopTimeout) 259 - defer sc() 260 - if err := srv.Shutdown(sctx); err != nil && err != http.ErrServerClosed { 250 + if err := srv.Close(); err != nil { 261 251 log.Error("server shutdown", err) 262 252 } 263 253 }) 264 - 254 + g.Add(func() error { 255 + _, err := tsWaitStatusReady(ctx, lc) 256 + if err != nil { 257 + return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 258 + } 259 + ln, err := ts.Listen("tcp", ":443") 260 + if err != nil { 261 + return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err) 262 + } 263 + return srv.ServeTLS(ln, "", "") 264 + }, func(err error) { 265 + if err := srv.Close(); err != nil { 266 + log.Error("TLS server shutdown", err) 267 + } 268 + }) 265 269 } 266 270 267 271 return g.Run()