package main
import (
"context"
"crypto/subtle"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"github.com/lstoll/oidc"
"github.com/lstoll/oidc/middleware"
"github.com/oklog/run"
"github.com/prometheus/client_golang/prometheus"
versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/prometheus/common/version"
"github.com/tailscale/hujson"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
"tailscale.com/tsnet"
tslogger "tailscale.com/types/logger"
)
// ctxConn is the context key under which the funnel server's ConnContext hook
// stores the net.Conn an HTTP request arrived on. restrictNetworks reads the
// value back to recover the client's source address.
type ctxConn struct{}
// Prometheus metrics describing proxied requests. Every metric carries an
// "upstream" label; tsproxy fills it in with the upstream's name when it wraps
// that upstream's handlers.
var (
// requestsInFlight is a gauge of requests currently being served, per upstream.
requestsInFlight = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: "tsproxy",
Name: "requests_in_flight",
Help: "Number of requests currently being served by the server.",
},
[]string{"upstream"},
)
// requests counts received requests, partitioned by upstream, response status
// code and HTTP method.
requests = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: "tsproxy",
Name: "requests_total",
Help: "Number of requests received by the server.",
},
[]string{"upstream", "code", "method"},
)
// duration is a histogram of request latencies in seconds, partitioned by
// upstream, response status code and HTTP method.
duration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "tsproxy",
Name: "request_duration_seconds",
Help: "A histogram of latencies for requests handled by the server.",
NativeHistogramBucketFactor: 1.1,
},
[]string{"upstream", "code", "method"},
)
)
// upstream is one entry of the upstreams config file: a backend that tsproxy
// exposes through a dedicated tsnet node.
type upstream struct {
// Name is used as the tsnet hostname, as the suffix of the node's state
// directory, and as the "upstream" label on metrics and log lines.
Name string
// Backend is the URL that requests are reverse-proxied to.
Backend string
// Prometheus marks the upstream for inclusion in the /sd discovery response.
Prometheus bool
// Funnel, when non-nil, additionally exposes the upstream on a Funnel-only
// listener on port 443.
Funnel *funnelConfig
}
// funnelConfig selects how requests arriving on an upstream's Funnel listener
// are authenticated. tsproxy uses the first mode that is set, checked in this
// order: Insecure, Issuer (OIDC), User (HTTP basic auth). Setting none of them
// is a configuration error.
type funnelConfig struct {
// Insecure forwards funnel requests without any user authentication.
Insecure bool
// Issuer, ClientID and ClientSecret configure the OIDC middleware; Issuer is
// the value passed to OIDC discovery.
Issuer string
ClientID string
ClientSecret string
// User and Password are the credentials required by HTTP basic auth.
User string
Password string
// IP optionally restricts funnel access to clients inside the listed
// networks. Each entry is parsed as a netip prefix (CIDR notation).
IP []string
}
// target is the service discovery record of one upstream. tsproxy fills it in
// once the upstream's tsnet node is up; serveDiscovery reads it.
type target struct {
// name is the upstream's configured name.
name string
// magicDNS is the DNS name reported for the upstream's tsnet node. An empty
// value makes serveDiscovery skip the record.
magicDNS string
// prometheus mirrors upstream.Prometheus; only records with it set are
// listed by serveDiscovery.
prometheus bool
}
// main runs tsproxy with a background context. If tsproxy returns an error it
// is printed to stderr and the process exits with status 1.
func main() {
	err := tsproxy(context.Background())
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err)
	os.Exit(1)
}
// tsproxy parses the command-line flags, loads the upstreams config file and
// then serves until one of its servers stops or the process receives
// SIGINT/SIGTERM.
//
// It runs, all inside a single run.Group:
//   - a plain HTTP server on every Tailscale IP of the local node, exposing
//     /metrics (Prometheus) and /sd (service discovery);
//   - for every upstream, a tsnet node named after the upstream with a reverse
//     proxy to the upstream's backend on ports 80 and 443 (TLS);
//   - for every upstream with a funnel section, an additional Funnel-only
//     listener on port 443 guarded by the configured authentication mode and,
//     optionally, an IP allow-list.
//
// Configuration problems (missing flag, unreadable or empty config, malformed
// backend URL or funnel settings) and listener failures are returned as
// errors. With -version it prints the version and returns nil.
func tsproxy(ctx context.Context) error {
	var (
		state  = flag.String("state", "", "Optional directory for storing Tailscale state.")
		tslog  = flag.Bool("tslog", false, "If true, log Tailscale output.")
		port   = flag.Int("port", 32019, "HTTP port for metrics and service discovery.")
		ver    = flag.Bool("version", false, "print the version and exit")
		upfile = flag.String("upstream", "", "path to upstreams config file")
	)
	flag.Parse()

	if *ver {
		fmt.Fprintln(os.Stdout, version.Print("tsproxy"))
		// Returning nil makes main exit with status 0, like the explicit exit did.
		return nil
	}
	if *upfile == "" {
		return fmt.Errorf("required flag missing: upstream")
	}

	// The config file is HuJSON; standardize it before handing it to encoding/json.
	in, err := os.ReadFile(*upfile)
	if err != nil {
		return err
	}
	inJSON, err := hujson.Standardize(in)
	if err != nil {
		return fmt.Errorf("hujson: %w", err)
	}
	var upstreams []upstream
	if err := json.Unmarshal(inJSON, &upstreams); err != nil {
		return fmt.Errorf("json: %w", err)
	}
	if len(upstreams) == 0 {
		return fmt.Errorf("file does not contain any upstreams: %s", *upfile)
	}

	// Default the state directory to <user cache dir>/tsproxy.
	if *state == "" {
		v, err := os.UserCacheDir()
		if err != nil {
			return err
		}
		dir := filepath.Join(v, "tsproxy")
		if err := os.MkdirAll(dir, 0o700); err != nil {
			return err
		}
		state = &dir
	}

	prometheus.MustRegister(versioncollector.NewCollector("tsproxy"))
	logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
	slog.SetDefault(logger)

	// If tailscaled isn't ready yet, just crash.
	st, err := (&local.Client{}).Status(ctx)
	if err != nil {
		return fmt.Errorf("tailscale: get node status: %w", err)
	}
	if v := len(st.Self.TailscaleIPs); v != 2 {
		return fmt.Errorf("want 2 tailscale IPs, got %d", v)
	}

	// Service discovery records, indexed like upstreams. A record is filled in
	// once its upstream's tsnet node is up; until then it is empty and skipped by
	// serveDiscovery. The local node itself is passed to serveDiscovery
	// separately, so the extra trailing slot is never written here.
	//
	// NOTE(review): the per-upstream goroutines below write into this slice
	// while the /sd handler reads it, with no synchronization — this looks like
	// a data race; confirm and guard it together with serveDiscovery.
	targets := make([]target, len(upstreams)+1)

	var g run.Group
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	g.Add(run.SignalHandler(ctx, os.Interrupt, syscall.SIGTERM))

	// Metrics and service discovery, served on the local node's Tailscale IPs.
	{
		p := strconv.Itoa(*port)
		var listeners []net.Listener
		for _, ip := range st.Self.TailscaleIPs {
			ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), p))
			if err != nil {
				return fmt.Errorf("listen on %s:%d: %w", ip, *port, err)
			}
			listeners = append(listeners, ln)
		}

		http.Handle("/metrics", promhttp.Handler())
		http.Handle("/sd", serveDiscovery(net.JoinHostPort(st.Self.DNSName, p), targets))
		http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
			_, _ = w.Write([]byte(`
tsproxy
tsproxy
Metrics
Discovery
`))
		})

		// No Handler is set, so the server uses the default mux populated above.
		srv := &http.Server{}
		for _, ln := range listeners {
			ln := ln // per-iteration copy for the closure (needed before Go 1.22)
			g.Add(func() error {
				logger.Info("server ready", slog.String("addr", ln.Addr().String()))
				return srv.Serve(ln)
			}, func(_ error) {
				if err := srv.Close(); err != nil {
					logger.Error("shutdown server", lerr(err))
				}
				cancel()
			})
		}
	}

	for i, upstream := range upstreams {
		log := logger.With(slog.String("upstream", upstream.Name))

		ts := &tsnet.Server{
			Hostname:     upstream.Name,
			Dir:          filepath.Join(*state, "tailscale-"+upstream.Name),
			RunWebClient: true,
		}
		// Deliberately deferred to function return: every node must stay up until
		// the run group has finished.
		defer ts.Close()

		if *tslog {
			ts.Logf = func(format string, args ...any) {
				//nolint: sloglint
				log.Info(fmt.Sprintf(format, args...), slog.String("logger", "tailscale"))
			}
		} else {
			ts.Logf = tslogger.Discard
		}
		if err := os.MkdirAll(ts.Dir, 0o700); err != nil {
			return err
		}

		lc, err := ts.LocalClient()
		if err != nil {
			return fmt.Errorf("tailscale: get local client for %s: %w", upstream.Name, err)
		}

		backendURL, err := url.Parse(upstream.Backend)
		if err != nil {
			return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err)
		}

		// TODO(sr) Instrument proxy.Transport
		proxy := &httputil.ReverseProxy{
			Rewrite: func(req *httputil.ProxyRequest) {
				req.SetURL(backendURL)
				req.SetXForwarded()
				// Preserve the Host header the client sent.
				req.Out.Host = req.In.Host
			},
		}
		proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) {
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
			log.Error("upstream error", lerr(err))
		}

		// instrument wraps h with the in-flight, duration and request-count
		// metrics, all labelled with this upstream's name.
		instrument := func(h http.Handler) http.Handler {
			return promhttp.InstrumentHandlerInFlight(
				requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}),
				promhttp.InstrumentHandlerDuration(
					duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}),
					promhttp.InstrumentHandlerCounter(
						requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}),
						h,
					),
				),
			)
		}

		// Tailnet listener on port 80 (no TLS enforcement).
		{
			var srv *http.Server
			g.Add(func() error {
				st, err := ts.Up(ctx)
				if err != nil {
					return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
				}

				srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))}
				ln, err := ts.Listen("tcp", ":80")
				if err != nil {
					return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err)
				}

				// register in service discovery when we're ready.
				targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName}

				return srv.Serve(ln)
			}, func(_ error) {
				if srv != nil {
					if err := srv.Close(); err != nil {
						log.Error("server shutdown", lerr(err))
					}
				}
				cancel()
			})
		}

		// Tailnet listener on port 443; certificates come from the node's local client.
		{
			var srv *http.Server
			g.Add(func() error {
				st, err := ts.Up(ctx)
				if err != nil {
					return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
				}

				srv = &http.Server{
					TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
					Handler:   instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))),
				}

				ln, err := ts.Listen("tcp", ":443")
				if err != nil {
					return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err)
				}
				return srv.ServeTLS(ln, "", "")
			}, func(_ error) {
				if srv != nil {
					if err := srv.Close(); err != nil {
						log.Error("server shutdown", lerr(err))
					}
				}
				cancel()
			})
		}

		// Optional Funnel-only listener on port 443.
		if funnel := upstream.Funnel; funnel != nil {
			// Parse the allow-list up front so that a malformed entry is reported
			// as a configuration error before any server starts, rather than
			// panicking inside a running goroutine.
			allow := make([]netip.Prefix, 0, len(funnel.IP))
			for _, ip := range funnel.IP {
				prefix, err := netip.ParsePrefix(ip)
				if err != nil {
					return fmt.Errorf("upstream %s: parse funnel.ip %q: %w", upstream.Name, ip, err)
				}
				allow = append(allow, prefix)
			}

			{
				var srv *http.Server
				g.Add(func() error {
					st, err := ts.Up(ctx)
					if err != nil {
						return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
					}

					// Authentication mode, first match wins.
					var handler http.Handler
					switch {
					case funnel.Insecure:
						handler = insecureFunnel(log, lc, proxy)
					case funnel.Issuer != "":
						redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"}
						wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String())
						if err != nil {
							return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err)
						}
						wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)

						handler = wrapper.Wrap(oidcFunnel(log, lc, proxy))
					case funnel.User != "":
						// basicAuth panics on an empty password; turn that into a
						// regular configuration error.
						if funnel.Password == "" {
							return fmt.Errorf("upstream %s: funnel.password is required when funnel.user is set", upstream.Name)
						}
						handler = insecureFunnel(log, lc, basicAuth(log, funnel.User, funnel.Password, proxy))
					default:
						return fmt.Errorf("upstream %s must set funnel.insecure, funnel.issuer or funnel.user", upstream.Name)
					}

					handler = redirect(st.Self.DNSName, true, handler)
					if len(allow) > 0 {
						handler = restrictNetworks(log, allow, handler)
					}

					srv = &http.Server{
						Handler: instrument(handler),
						// Expose the connection to handlers so restrictNetworks can
						// recover the client's source address.
						ConnContext: func(ctx context.Context, c net.Conn) context.Context {
							return context.WithValue(ctx, ctxConn{}, c)
						},
					}

					ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly())
					if err != nil {
						return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err)
					}
					return srv.Serve(ln)
				}, func(_ error) {
					if srv != nil {
						if err := srv.Close(); err != nil {
							log.Error("server shutdown", lerr(err))
						}
					}
					cancel()
				})
			}
		}
	}

	return g.Run()
}
// redirect wraps next so that requests are steered to the canonical HTTPS URL
// of fqdn (a trailing dot on fqdn is ignored).
//
// A request is answered with a 308 redirect to https://<fqdn><path?query> when
//   - forceSSL is true and the request did not arrive over TLS, or
//   - the request arrived over TLS but its Host header (minus a trailing dot)
//     is not fqdn.
//
// All other requests are passed to next unchanged. redirect panics if fqdn is
// empty, since that indicates a programming error in the caller.
func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler {
	if fqdn == "" {
		panic("redirect: fqdn cannot be empty")
	}
	fqdn = strings.TrimSuffix(fqdn, ".")

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		plaintext := r.TLS == nil
		wrongHost := !plaintext && strings.TrimSuffix(r.Host, ".") != fqdn

		if (forceSSL && plaintext) || wrongHost {
			// Use the parsed URL rather than r.RequestURI: for absolute-form
			// request lines ("GET http://host/path") RequestURI is a full URL and
			// would produce a malformed Location such as
			// "https://fqdnhttp://host/path".
			http.Redirect(w, r, "https://"+fqdn+r.URL.RequestURI(), http.StatusPermanentRedirect)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// basicAuth protects next with HTTP basic authentication against a single
// user/password pair. Requests carrying matching credentials are forwarded;
// all others are logged and answered with 401 plus a WWW-Authenticate
// challenge. It panics if user or password is empty.
func basicAuth(logger *slog.Logger, user, password string, next http.Handler) http.Handler {
	if user == "" || password == "" {
		panic("user and password are required")
	}
	wantUser, wantPassword := []byte(user), []byte(password)

	// reject logs the failed attempt and sends the basic auth challenge.
	reject := func(w http.ResponseWriter, r *http.Request, attempted string) {
		logger.ErrorContext(r.Context(), "authentication failed", slog.String("user", attempted))
		w.Header().Set("WWW-Authenticate", "Basic realm=\"protected\", charset=\"UTF-8\"")
		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotUser, gotPassword, present := r.BasicAuth()
		if !present {
			reject(w, r, gotUser)
			return
		}

		// Both comparisons are evaluated before branching so that a wrong user
		// name does not skip the password comparison.
		userOK := subtle.ConstantTimeCompare(wantUser, []byte(gotUser)) == 1
		passwordOK := subtle.ConstantTimeCompare(wantPassword, []byte(gotPassword)) == 1
		if !userOK || !passwordOK {
			reject(w, r, gotUser)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// tailnet is the handler for requests arriving over the tailnet. It looks up
// the peer with tsWhoIs and then:
//   - answers 500 (and logs) if the lookup fails;
//   - forwards requests from tagged nodes to next untouched;
//   - for all other nodes, forwards a clone of the request with the
//     X-Webauth-User and X-Webauth-Name headers set from the peer's user
//     profile (login name and display name).
func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		who, err := tsWhoIs(lc, r)
		switch {
		case err != nil:
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))

		case who.Node.IsTagged():
			// Proxy requests from tagged nodes as is.
			next.ServeHTTP(w, r)

		default:
			// Work on a clone so the caller's request is left unmodified.
			out := r.Clone(r.Context())
			out.Header.Set("X-Webauth-User", who.UserProfile.LoginName)
			out.Header.Set("X-Webauth-Name", who.UserProfile.DisplayName)
			next.ServeHTTP(w, out)
		}
	})
}
// insecureFunnel is the funnel handler that performs no user authentication of
// its own. It only checks, via tsWhoIs, that the peer is a tagged node and then
// forwards the request to next unchanged.
//
// A failed lookup yields 500; a peer that is not a tagged node yields 401.
// NOTE(review): presumably Funnel traffic always shows up as coming from a
// tagged node — confirm against the Tailscale Funnel documentation.
func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		who, err := tsWhoIs(lc, r)
		if err != nil {
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			logger.ErrorContext(ctx, "tailscale whois", lerr(err))
			return
		}

		if who.Node.IsTagged() {
			next.ServeHTTP(w, r)
			return
		}

		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		logger.ErrorContext(ctx, "funnel handler got request from non-tagged node")
	})
}
// oidcFunnel is the funnel handler used behind the OIDC middleware. After
// checking (via tsWhoIs) that the peer is a tagged node, it takes the ID token
// the middleware stored in the request context and forwards a clone of the
// request to next with
//
//	X-Webauth-User = the token's "email" claim
//	X-Webauth-Name = the token's "name" claim
//
// Responses on failure: 500 if the whois lookup fails or a claim cannot be
// read, 401 if the peer is not tagged or no token is present.
func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
	// Claims to forward, in the order they are read, and the header each one
	// is copied into.
	forwarded := [...]struct{ claim, header string }{
		{claim: "email", header: "X-Webauth-User"},
		{claim: "name", header: "X-Webauth-Name"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		who, err := tsWhoIs(lc, r)
		if err != nil {
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
			return
		}
		if !who.Node.IsTagged() {
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
			return
		}

		token := middleware.IDJWTFromContext(r.Context())
		if token == nil {
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			logger.ErrorContext(r.Context(), "jwt token missing")
			return
		}

		// Read every claim before touching the request, so a missing claim
		// never results in a partially decorated request being forwarded.
		var values [len(forwarded)]string
		for i, f := range forwarded {
			v, err := token.StringClaim(f.claim)
			if err != nil {
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", f.claim))
				return
			}
			values[i] = v
		}

		out := r.Clone(r.Context())
		for i, f := range forwarded {
			out.Header.Set(f.header, values[i])
		}
		next.ServeHTTP(w, out)
	})
}
// restrictNetworks only lets clients whose source address lies inside one of
// allowedNetworks reach next. Every other client receives a 403 response with
// the badNetwork body.
//
// The source address is taken from the connection stored in the request
// context under ctxConn{} when that connection is an *ipn.FunnelConn (looking
// through a *tls.Conn wrapper if present); otherwise it is parsed from
// r.RemoteAddr. If neither yields an address the request is answered with 401
// and the failure is logged.
//
// An IPv4-mapped IPv6 source address (::ffff:a.b.c.d) is additionally matched
// in its plain IPv4 form, because netip.Prefix.Contains never matches a mapped
// address against an IPv4 prefix.
func restrictNetworks(logger *slog.Logger, allowedNetworks []netip.Prefix, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// If the funneled connection is from tsnet, then the net.Conn will be of
		// type ipn.FunnelConn.
		netConn := r.Context().Value(ctxConn{})
		// if the conn is wrapped inside TLS, unwrap it
		if tlsConn, ok := netConn.(*tls.Conn); ok {
			netConn = tlsConn.NetConn()
		}

		var remote netip.AddrPort
		if fconn, ok := netConn.(*ipn.FunnelConn); ok {
			remote = fconn.Src
		} else if v, err := netip.ParseAddrPort(r.RemoteAddr); err == nil {
			remote = v
		} else {
			logger.ErrorContext(r.Context(), "restrictNetworks: cannot parse client IP:port", lerr(err), slog.String("remote", r.RemoteAddr))
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		addr := remote.Addr()
		unmapped := addr.Unmap() // equals addr unless addr is IPv4-mapped IPv6
		for _, wl := range allowedNetworks {
			if wl.Contains(addr) || wl.Contains(unmapped) {
				next.ServeHTTP(w, r)
				return
			}
		}

		w.WriteHeader(http.StatusForbidden)
		_, _ = fmt.Fprint(w, badNetwork)
	})
}
// badNetwork is the response body restrictNetworks writes, together with a 403
// status, when the client's address is outside the allowed networks.
const badNetwork = `
Untrusted network
Access from untrusted networks not permitted
`
// tailscaleLocalClient is the part of the Tailscale local client API that the
// request handlers in this file depend on.
type tailscaleLocalClient interface {
// WhoIs resolves a remote address to the Tailscale node and user behind
// it; tsWhoIs calls it with the request's RemoteAddr.
WhoIs(context.Context, string) (*apitype.WhoIsResponse, error)
}
// tsWhoIs asks lc who is behind r.RemoteAddr and returns the response only if
// it is complete. It returns an error when the lookup itself fails or when the
// response lacks the Node or the UserProfile, so callers may dereference both
// fields without further nil checks.
func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) {
	resp, err := lc.WhoIs(r.Context(), r.RemoteAddr)
	switch {
	case err != nil:
		return nil, fmt.Errorf("tailscale whois: %w", err)
	case resp.Node == nil:
		return nil, errors.New("tailscale whois: node missing")
	case resp.UserProfile == nil:
		return nil, errors.New("tailscale whois: user profile missing")
	}
	return resp, nil
}
// serveDiscovery returns the handler behind the service discovery endpoint. On
// every request it responds with a JSON array holding a single object whose
// "targets" field is the sorted list of self plus the magicDNS name of every
// target that has a non-empty magicDNS and has prometheus set.
//
// targets is read at request time, so records filled in after the handler was
// created show up in later responses.
func serveDiscovery(self string, targets []target) http.Handler {
	type group struct {
		Targets []string `json:"targets"`
	}

	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hosts := []string{self}
		for i := range targets {
			if tg := targets[i]; tg.prometheus && tg.magicDNS != "" {
				hosts = append(hosts, tg.magicDNS)
			}
		}
		sort.Strings(hosts)

		payload, err := json.Marshal([]group{{Targets: hosts}})
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		_, _ = w.Write(payload)
	})
}
// lerr turns err into a slog attribute with the key "err", so that every log
// line in this file reports errors under the same key.
func lerr(err error) slog.Attr {
	return slog.Attr{Key: "err", Value: slog.AnyValue(err)}
}