HTTP reverse proxy for Tailscale
at main 233 lines 6.0 kB view raw
1package main 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "log" 9 "log/slog" 10 "net/http" 11 "net/http/httptest" 12 "net/url" 13 "reflect" 14 "strings" 15 "testing" 16 17 "github.com/google/go-cmp/cmp" 18 "github.com/prometheus/client_golang/prometheus" 19 "github.com/prometheus/client_golang/prometheus/testutil" 20 "tailscale.com/client/tailscale/apitype" 21 "tailscale.com/tailcfg" 22) 23 24type fakeLocalClient struct { 25 whois func(context.Context, string) (*apitype.WhoIsResponse, error) 26} 27 28func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { 29 return c.whois(ctx, remoteAddr) 30} 31 32func TestParseUpstream(t *testing.T) { 33 t.Parallel() 34 35 for _, tc := range []struct { 36 upstream string 37 want upstream 38 err error 39 }{ 40 { 41 upstream: "test=http://example.com:-80/", 42 want: upstream{}, 43 err: errors.New(`parse "http://`), 44 }, 45 { 46 upstream: "test=http://localhost", 47 want: upstream{name: "test", backend: mustParseURL("http://localhost")}, 48 }, 49 { 50 upstream: "test=http://localhost;prometheus", 51 want: upstream{name: "test", backend: mustParseURL("http://localhost"), prometheus: true}, 52 }, 53 { 54 upstream: "test=http://localhost;funnel;prometheus", 55 want: upstream{name: "test", backend: mustParseURL("http://localhost"), prometheus: true, funnel: true}, 56 }, 57 { 58 upstream: "test=http://localhost;foo", 59 want: upstream{}, 60 err: errors.New("unsupported option: foo"), 61 }, 62 } { 63 tc := tc 64 t.Run(tc.upstream, func(t *testing.T) { 65 t.Parallel() 66 up, err := parseUpstreamFlag(tc.upstream) 67 if tc.err != nil { 68 if err == nil { 69 t.Fatalf("want err %v, got nil", tc.err) 70 } 71 if !strings.Contains(err.Error(), tc.err.Error()) { 72 t.Fatalf("want err %v, got %v", tc.err, err) 73 } 74 } 75 if tc.err == nil && err != nil { 76 t.Fatalf("want no err, got %v", err) 77 } 78 if diff := cmp.Diff(tc.want, up, cmp.Exporter(func(_ reflect.Type) bool { return true })); diff != "" { 79 t.Errorf("mismatch (-want +got):\n%s", diff) 80 } 81 }) 82 } 83} 84 85func mustParseURL(s string) *url.URL { 86 v, err := url.Parse(s) 87 if err != nil { 88 panic(err) 89 } 90 return v 91} 92 93func TestReverseProxy(t *testing.T) { 94 t.Parallel() 95 96 for _, tc := range []struct { 97 name string 98 whois func(context.Context, string) (*apitype.WhoIsResponse, error) 99 want int 100 wantHeaders map[string]string 101 }{ 102 { 103 name: "tailscale whois error", 104 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 105 return nil, errors.New("whois error") 106 }, 107 want: http.StatusInternalServerError, 108 }, 109 { 110 name: "tailscale whois no profile", 111 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 112 return &apitype.WhoIsResponse{}, nil 113 }, 114 want: http.StatusInternalServerError, 115 }, 116 { 117 name: "tailscale whois no node", 118 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 119 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 120 }, 121 want: http.StatusInternalServerError, 122 }, 123 { 124 name: "tailscale whois ok (tagged node)", 125 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 126 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 127 }, 128 want: http.StatusOK, 129 }, 130 { 131 name: "tailscale whois ok (user)", 132 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 133 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 134 }, 135 want: http.StatusOK, 136 wantHeaders: map[string]string{ 137 "X-Webauth-User": "login", 138 "X-Webauth-Name": "name", 139 }, 140 }, 141 } { 142 tc := tc 143 t.Run(tc.name, func(t *testing.T) { 144 t.Parallel() 145 lc := &fakeLocalClient{whois: tc.whois} 146 be := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 147 for k, v := range r.Header { 148 w.Header().Set(k, v[0]) 149 } 150 fmt.Fprintln(w, "Hi from the backend.") 151 })) 152 defer be.Close() 153 beURL, err := url.Parse(be.URL) 154 if err != nil { 155 log.Fatal(err) 156 } 157 px := httptest.NewServer(newReverseProxy(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, beURL)) 158 defer px.Close() 159 160 resp, err := http.Get(px.URL) 161 if err != nil { 162 t.Fatal(err) 163 } 164 defer resp.Body.Close() 165 166 if want, got := tc.want, resp.StatusCode; want != got { 167 t.Errorf("want status %d, got: %d", want, got) 168 } 169 if tc.wantHeaders == nil { 170 tc.wantHeaders = map[string]string{ 171 "X-Webauth-User": "", 172 "X-Webauth-Name": "", 173 } 174 } 175 for k, want := range tc.wantHeaders { 176 if got := resp.Header.Get(k); got != want { 177 t.Errorf("want header %s %s, got: %s", k, want, got) 178 } 179 } 180 }) 181 } 182} 183 184func TestServeDiscovery(t *testing.T) { 185 t.Parallel() 186 187 ts := httptest.NewServer(serveDiscovery("self", []target{ 188 {magicDNS: "b", prometheus: true}, 189 {magicDNS: "x"}, 190 {}, 191 {magicDNS: "a", prometheus: true}, 192 })) 193 defer ts.Close() 194 195 resp, err := http.Get(ts.URL) 196 if err != nil { 197 t.Fatal(err) 198 } 199 defer resp.Body.Close() 200 if want, got := http.StatusOK, resp.StatusCode; want != got { 201 t.Errorf("want status %d, got: %d", want, got) 202 } 203 b, err := io.ReadAll(resp.Body) 204 if err != nil { 205 t.Fatal(err) 206 } 207 if diff := cmp.Diff(`[{"targets":["a","b","self"]}]`, string(b)); diff != "" { 208 t.Errorf("body mismatch (-want +got):\n%s", diff) 209 } 210} 211 212func TestMetrics(t *testing.T) { 213 t.Parallel() 214 215 c, err := testutil.GatherAndCount(prometheus.DefaultGatherer) 216 if err != nil { 217 t.Fatalf("GatherAndCount: %v", err) 218 } 219 if c == 0 { 220 t.Fatalf("no metrics collected") 221 } 222 223 lint, err := testutil.GatherAndLint(prometheus.DefaultGatherer) 224 if err != nil { 225 t.Fatalf("CollectAndLint: %v", err) 226 } 227 if len(lint) > 0 { 228 t.Error("lint problems detected") 229 } 230 for _, prob := range lint { 231 t.Errorf("lint: %s: %s", prob.Metric, prob.Text) 232 } 233}