HTTP reverse proxy for Tailscale
1package main
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log"
9 "log/slog"
10 "net/http"
11 "net/http/httptest"
12 "net/url"
13 "reflect"
14 "strings"
15 "testing"
16
17 "github.com/google/go-cmp/cmp"
18 "github.com/prometheus/client_golang/prometheus"
19 "github.com/prometheus/client_golang/prometheus/testutil"
20 "tailscale.com/client/tailscale/apitype"
21 "tailscale.com/tailcfg"
22)
23
24type fakeLocalClient struct {
25 whois func(context.Context, string) (*apitype.WhoIsResponse, error)
26}
27
28func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
29 return c.whois(ctx, remoteAddr)
30}
31
32func TestParseUpstream(t *testing.T) {
33 t.Parallel()
34
35 for _, tc := range []struct {
36 upstream string
37 want upstream
38 err error
39 }{
40 {
41 upstream: "test=http://example.com:-80/",
42 want: upstream{},
43 err: errors.New(`parse "http://`),
44 },
45 {
46 upstream: "test=http://localhost",
47 want: upstream{name: "test", backend: mustParseURL("http://localhost")},
48 },
49 {
50 upstream: "test=http://localhost;prometheus",
51 want: upstream{name: "test", backend: mustParseURL("http://localhost"), prometheus: true},
52 },
53 {
54 upstream: "test=http://localhost;funnel;prometheus",
55 want: upstream{name: "test", backend: mustParseURL("http://localhost"), prometheus: true, funnel: true},
56 },
57 {
58 upstream: "test=http://localhost;foo",
59 want: upstream{},
60 err: errors.New("unsupported option: foo"),
61 },
62 } {
63 tc := tc
64 t.Run(tc.upstream, func(t *testing.T) {
65 t.Parallel()
66 up, err := parseUpstreamFlag(tc.upstream)
67 if tc.err != nil {
68 if err == nil {
69 t.Fatalf("want err %v, got nil", tc.err)
70 }
71 if !strings.Contains(err.Error(), tc.err.Error()) {
72 t.Fatalf("want err %v, got %v", tc.err, err)
73 }
74 }
75 if tc.err == nil && err != nil {
76 t.Fatalf("want no err, got %v", err)
77 }
78 if diff := cmp.Diff(tc.want, up, cmp.Exporter(func(_ reflect.Type) bool { return true })); diff != "" {
79 t.Errorf("mismatch (-want +got):\n%s", diff)
80 }
81 })
82 }
83}
84
// mustParseURL parses s with url.Parse and returns the resulting URL.
// It panics with the parse error when s is not a valid URL, which makes it
// suitable only for fixed inputs in tests.
func mustParseURL(s string) *url.URL {
	parsed, err := url.Parse(s)
	if err == nil {
		return parsed
	}
	panic(err)
}
92
93func TestReverseProxy(t *testing.T) {
94 t.Parallel()
95
96 for _, tc := range []struct {
97 name string
98 whois func(context.Context, string) (*apitype.WhoIsResponse, error)
99 want int
100 wantHeaders map[string]string
101 }{
102 {
103 name: "tailscale whois error",
104 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
105 return nil, errors.New("whois error")
106 },
107 want: http.StatusInternalServerError,
108 },
109 {
110 name: "tailscale whois no profile",
111 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
112 return &apitype.WhoIsResponse{}, nil
113 },
114 want: http.StatusInternalServerError,
115 },
116 {
117 name: "tailscale whois no node",
118 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
119 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
120 },
121 want: http.StatusInternalServerError,
122 },
123 {
124 name: "tailscale whois ok (tagged node)",
125 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
126 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
127 },
128 want: http.StatusOK,
129 },
130 {
131 name: "tailscale whois ok (user)",
132 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
133 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
134 },
135 want: http.StatusOK,
136 wantHeaders: map[string]string{
137 "X-Webauth-User": "login",
138 "X-Webauth-Name": "name",
139 },
140 },
141 } {
142 tc := tc
143 t.Run(tc.name, func(t *testing.T) {
144 t.Parallel()
145 lc := &fakeLocalClient{whois: tc.whois}
146 be := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
147 for k, v := range r.Header {
148 w.Header().Set(k, v[0])
149 }
150 fmt.Fprintln(w, "Hi from the backend.")
151 }))
152 defer be.Close()
153 beURL, err := url.Parse(be.URL)
154 if err != nil {
155 log.Fatal(err)
156 }
157 px := httptest.NewServer(newReverseProxy(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, beURL))
158 defer px.Close()
159
160 resp, err := http.Get(px.URL)
161 if err != nil {
162 t.Fatal(err)
163 }
164 defer resp.Body.Close()
165
166 if want, got := tc.want, resp.StatusCode; want != got {
167 t.Errorf("want status %d, got: %d", want, got)
168 }
169 if tc.wantHeaders == nil {
170 tc.wantHeaders = map[string]string{
171 "X-Webauth-User": "",
172 "X-Webauth-Name": "",
173 }
174 }
175 for k, want := range tc.wantHeaders {
176 if got := resp.Header.Get(k); got != want {
177 t.Errorf("want header %s %s, got: %s", k, want, got)
178 }
179 }
180 })
181 }
182}
183
184func TestServeDiscovery(t *testing.T) {
185 t.Parallel()
186
187 ts := httptest.NewServer(serveDiscovery("self", []target{
188 {magicDNS: "b", prometheus: true},
189 {magicDNS: "x"},
190 {},
191 {magicDNS: "a", prometheus: true},
192 }))
193 defer ts.Close()
194
195 resp, err := http.Get(ts.URL)
196 if err != nil {
197 t.Fatal(err)
198 }
199 defer resp.Body.Close()
200 if want, got := http.StatusOK, resp.StatusCode; want != got {
201 t.Errorf("want status %d, got: %d", want, got)
202 }
203 b, err := io.ReadAll(resp.Body)
204 if err != nil {
205 t.Fatal(err)
206 }
207 if diff := cmp.Diff(`[{"targets":["a","b","self"]}]`, string(b)); diff != "" {
208 t.Errorf("body mismatch (-want +got):\n%s", diff)
209 }
210}
211
212func TestMetrics(t *testing.T) {
213 t.Parallel()
214
215 c, err := testutil.GatherAndCount(prometheus.DefaultGatherer)
216 if err != nil {
217 t.Fatalf("GatherAndCount: %v", err)
218 }
219 if c == 0 {
220 t.Fatalf("no metrics collected")
221 }
222
223 lint, err := testutil.GatherAndLint(prometheus.DefaultGatherer)
224 if err != nil {
225 t.Fatalf("CollectAndLint: %v", err)
226 }
227 if len(lint) > 0 {
228 t.Error("lint problems detected")
229 }
230 for _, prob := range lint {
231 t.Errorf("lint: %s: %s", prob.Metric, prob.Text)
232 }
233}