Tailscale-native MCP gateway with identity-based access control, audit logging, and session recording
at main 575 lines 15 kB view raw
1package ui 2 3import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "io" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 13 "github.com/slanos/turnscale/internal/audit" 14 "github.com/slanos/turnscale/internal/config" 15 "github.com/slanos/turnscale/internal/identity" 16 "github.com/slanos/turnscale/internal/policy" 17) 18 19type mockIdentifier struct { 20 caller *identity.Caller 21 err error 22} 23 24func (m *mockIdentifier) Identify(r *http.Request) (*identity.Caller, error) { 25 return m.caller, m.err 26} 27 28func testConfig() *config.Config { 29 return &config.Config{ 30 Hostname: "mcp", 31 Tailnet: "example.ts.net", 32 Servers: map[string]config.Server{ 33 "gitea": {URL: "http://localhost:8091/mcp", Transport: "streamable-http"}, 34 "nomad": {URL: "http://localhost:8090/mcp", Transport: "streamable-http"}, 35 }, 36 Policies: []config.Policy{ 37 { 38 Name: "admin-full-access", 39 Match: config.Match{Identity: []string{"scott@github"}}, 40 Allow: []string{"*"}, 41 }, 42 { 43 Name: "ai-agents", 44 Match: config.Match{Tags: []string{"tag:ai-agent"}}, 45 Allow: []string{"gitea", "nomad"}, 46 DenyTools: []string{"mcp__gitea__delete_*"}, 47 }, 48 { 49 Name: "default-deny", 50 Match: config.Match{Identity: []string{"*"}}, 51 Deny: []string{"*"}, 52 }, 53 }, 54 } 55} 56 57func setupTestUI(t *testing.T, caller *identity.Caller, identErr error) *UI { 58 t.Helper() 59 cfg := testConfig() 60 pol := policy.NewEngine(cfg.Policies) 61 aud, err := audit.NewLogger(t.TempDir()) 62 if err != nil { 63 t.Fatalf("audit logger: %v", err) 64 } 65 t.Cleanup(func() { aud.Close() }) 66 adminIDs := map[string]bool{"scott@github": true} 67 ident := &mockIdentifier{caller: caller, err: identErr} 68 return New(cfg, ident, pol, aud, adminIDs) 69} 70 71func TestDashboardAdmin(t *testing.T) { 72 caller := &identity.Caller{ 73 UserLogin: "scott@github", 74 DisplayName: "Test Admin", 75 Node: "little-mac", 76 TailscaleIP: "100.64.0.1", 77 } 78 u := setupTestUI(t, caller, nil) 79 80 // Seed some audit entries 81 u.audit.Log(audit.Entry{ 82 Caller: "scott@github", Server: "gitea", Method: "tools/call", 83 Tool: "mcp__gitea__list_repos", Status: "ok", LatencyMs: 42, 84 }) 85 86 req := httptest.NewRequest("GET", "/ui/", nil) 87 rec := httptest.NewRecorder() 88 u.HandleDashboard(rec, req) 89 90 if rec.Code != http.StatusOK { 91 t.Fatalf("status = %d, want 200", rec.Code) 92 } 93 if ct := rec.Header().Get("Content-Type"); !strings.Contains(ct, "text/html") { 94 t.Errorf("Content-Type = %q, want text/html", ct) 95 } 96 97 body := rec.Body.String() 98 99 for _, want := range []string{ 100 "scott@github", 101 "little-mac", 102 "100.64.0.1", 103 "admin", 104 "gitea", 105 "nomad", 106 "admin-full-access", 107 "ai-agents", 108 "default-deny", 109 "Recent Activity", 110 "mcp__gitea__list_repos", 111 "Turnscale v", 112 } { 113 if !strings.Contains(body, want) { 114 t.Errorf("body missing %q", want) 115 } 116 } 117} 118 119func TestDashboardNonAdmin(t *testing.T) { 120 caller := &identity.Caller{ 121 UserLogin: "stranger@github", 122 Node: "other-mac", 123 } 124 u := setupTestUI(t, caller, nil) 125 126 req := httptest.NewRequest("GET", "/ui/", nil) 127 rec := httptest.NewRecorder() 128 u.HandleDashboard(rec, req) 129 130 if rec.Code != http.StatusOK { 131 t.Fatalf("status = %d, want 200", rec.Code) 132 } 133 134 body := rec.Body.String() 135 136 if !strings.Contains(body, "stranger@github") { 137 t.Error("body missing caller identity") 138 } 139 if strings.Contains(body, "Recent Activity") { 140 t.Error("non-admin should NOT see Recent Activity") 141 } 142 if strings.Contains(body, ">admin<") { 143 t.Error("non-admin should NOT have admin badge") 144 } 145} 146 147func TestDashboardTaggedNode(t *testing.T) { 148 caller := &identity.Caller{ 149 Node: "owl", 150 TailscaleIP: "100.1.2.3", 151 Tags: []string{"tag:ai-agent"}, 152 IsTagged: true, 153 } 154 u := setupTestUI(t, caller, nil) 155 156 req := httptest.NewRequest("GET", "/ui/", nil) 157 rec := httptest.NewRecorder() 158 u.HandleDashboard(rec, req) 159 160 if rec.Code != http.StatusOK { 161 t.Fatalf("status = %d, want 200", rec.Code) 162 } 163 164 body := rec.Body.String() 165 if !strings.Contains(body, "owl") { 166 t.Error("body missing node name") 167 } 168 if !strings.Contains(body, "tag:ai-agent") { 169 t.Error("body missing tag") 170 } 171} 172 173func TestDashboardUnauthorized(t *testing.T) { 174 u := setupTestUI(t, nil, errors.New("no identity")) 175 176 req := httptest.NewRequest("GET", "/ui/", nil) 177 rec := httptest.NewRecorder() 178 u.HandleDashboard(rec, req) 179 180 if rec.Code != http.StatusUnauthorized { 181 t.Fatalf("status = %d, want 401", rec.Code) 182 } 183} 184 185// fakeMCPServer returns an httptest.Server that responds to MCP initialize and tools/list. 186func fakeMCPServer(tools []mcpTool) *httptest.Server { 187 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 188 body, _ := io.ReadAll(r.Body) 189 var req jsonRPCRequest 190 json.Unmarshal(body, &req) 191 192 w.Header().Set("Content-Type", "application/json") 193 switch req.Method { 194 case "initialize": 195 w.Header().Set("Mcp-Session-Id", "test-session") 196 json.NewEncoder(w).Encode(map[string]any{ 197 "jsonrpc": "2.0", 198 "id": req.ID, 199 "result": map[string]any{ 200 "protocolVersion": "2025-03-26", 201 "capabilities": map[string]any{"tools": map[string]any{}}, 202 "serverInfo": map[string]string{"name": "test", "version": "1.0"}, 203 }, 204 }) 205 case "notifications/initialized": 206 w.WriteHeader(http.StatusNoContent) 207 case "tools/list": 208 json.NewEncoder(w).Encode(map[string]any{ 209 "jsonrpc": "2.0", 210 "id": req.ID, 211 "result": map[string]any{"tools": tools}, 212 }) 213 default: 214 w.WriteHeader(http.StatusBadRequest) 215 } 216 })) 217} 218 219func TestProbeBackendWithTools(t *testing.T) { 220 tools := []mcpTool{ 221 {Name: "list_repos", Description: "List all repositories"}, 222 {Name: "create_issue", Description: "Create a new issue"}, 223 {Name: "delete_branch", Description: "Delete a branch"}, 224 } 225 srv := fakeMCPServer(tools) 226 defer srv.Close() 227 228 result := probeBackend(t.Context(), srv.URL) 229 230 if !result.healthy { 231 t.Fatalf("expected healthy, got err: %s", result.err) 232 } 233 if len(result.tools) != 3 { 234 t.Fatalf("expected 3 tools, got %d", len(result.tools)) 235 } 236 // Tools should be sorted by name 237 if result.tools[0].Name != "create_issue" { 238 t.Errorf("first tool = %q, want create_issue (sorted)", result.tools[0].Name) 239 } 240} 241 242func TestProbeBackendUnreachable(t *testing.T) { 243 result := probeBackend(t.Context(), "http://127.0.0.1:1") 244 245 if result.healthy { 246 t.Error("unreachable server should not be healthy") 247 } 248 if result.err == "" { 249 t.Error("expected error message for unreachable server") 250 } 251} 252 253func TestToolDiscoveryInDashboard(t *testing.T) { 254 tools := []mcpTool{ 255 {Name: "mcp__gitea__list_repos", Description: "List repos"}, 256 {Name: "mcp__gitea__delete_branch", Description: "Delete a branch"}, 257 } 258 srv := fakeMCPServer(tools) 259 defer srv.Close() 260 261 caller := &identity.Caller{ 262 Node: "owl", Tags: []string{"tag:ai-agent"}, IsTagged: true, 263 } 264 cfg := &config.Config{ 265 Hostname: "mcp", 266 Tailnet: "example.ts.net", 267 Servers: map[string]config.Server{ 268 "gitea": {URL: srv.URL, Transport: "streamable-http"}, 269 }, 270 Policies: []config.Policy{ 271 { 272 Name: "ai-agents", 273 Match: config.Match{Tags: []string{"tag:ai-agent"}}, 274 Allow: []string{"gitea"}, 275 DenyTools: []string{"mcp__gitea__delete_*"}, 276 }, 277 }, 278 } 279 pol := policy.NewEngine(cfg.Policies) 280 aud, err := audit.NewLogger(t.TempDir()) 281 if err != nil { 282 t.Fatal(err) 283 } 284 defer aud.Close() 285 286 u := New(cfg, &mockIdentifier{caller: caller}, pol, aud, nil) 287 288 req := httptest.NewRequest("GET", "/ui/", nil) 289 rec := httptest.NewRecorder() 290 u.HandleDashboard(rec, req) 291 292 if rec.Code != http.StatusOK { 293 t.Fatalf("status = %d, want 200", rec.Code) 294 } 295 296 body := rec.Body.String() 297 298 // Should show tool names 299 if !strings.Contains(body, "mcp__gitea__list_repos") { 300 t.Error("body missing allowed tool name") 301 } 302 303 // Denied tool should have strikethrough class 304 if !strings.Contains(body, `class="denied"`) { 305 t.Error("body missing denied tool styling") 306 } 307 308 // Should show tool count 309 if !strings.Contains(body, "2 tools") { 310 t.Error("body missing tool count") 311 } 312} 313 314func TestHandleAddServer(t *testing.T) { 315 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 316 u := setupTestUI(t, caller, nil) 317 318 form := strings.NewReader("name=jira&url=http://localhost:9090/mcp&transport=streamable-http") 319 req := httptest.NewRequest("POST", "/ui/servers", form) 320 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 321 rec := httptest.NewRecorder() 322 u.HandleAddServer(rec, req) 323 324 if rec.Code != http.StatusSeeOther { 325 t.Fatalf("status = %d, want 303", rec.Code) 326 } 327} 328 329func TestHandleAddServerForbidden(t *testing.T) { 330 caller := &identity.Caller{UserLogin: "stranger@github", Node: "test"} 331 u := setupTestUI(t, caller, nil) 332 333 form := strings.NewReader("name=jira&url=http://localhost:9090/mcp") 334 req := httptest.NewRequest("POST", "/ui/servers", form) 335 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 336 rec := httptest.NewRecorder() 337 u.HandleAddServer(rec, req) 338 339 if rec.Code != http.StatusForbidden { 340 t.Fatalf("status = %d, want 403", rec.Code) 341 } 342} 343 344func TestHandleDeleteServer(t *testing.T) { 345 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 346 u := setupTestUI(t, caller, nil) 347 348 form := strings.NewReader("name=nomad") 349 req := httptest.NewRequest("POST", "/ui/servers/delete", form) 350 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 351 rec := httptest.NewRecorder() 352 u.HandleDeleteServer(rec, req) 353 354 if rec.Code != http.StatusSeeOther { 355 t.Fatalf("status = %d, want 303", rec.Code) 356 } 357} 358 359func TestHandleEditServer(t *testing.T) { 360 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 361 u := setupTestUI(t, caller, nil) 362 363 form := strings.NewReader("name=gitea&url=http://localhost:9999/mcp&transport=streamable-http") 364 req := httptest.NewRequest("POST", "/ui/servers/edit", form) 365 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 366 rec := httptest.NewRecorder() 367 u.HandleEditServer(rec, req) 368 369 if rec.Code != http.StatusSeeOther { 370 t.Fatalf("status = %d, want 303", rec.Code) 371 } 372} 373 374func TestHandleSessionNotFound(t *testing.T) { 375 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 376 u := setupTestUI(t, caller, nil) 377 378 req := httptest.NewRequest("GET", "/ui/session/99999", nil) 379 req.SetPathValue("id", "99999") 380 rec := httptest.NewRecorder() 381 u.HandleSession(rec, req) 382 383 if rec.Code != http.StatusNotFound { 384 t.Fatalf("status = %d, want 404", rec.Code) 385 } 386} 387 388func TestHandleSessionForbidden(t *testing.T) { 389 caller := &identity.Caller{UserLogin: "stranger@github", Node: "test"} 390 u := setupTestUI(t, caller, nil) 391 392 req := httptest.NewRequest("GET", "/ui/session/1", nil) 393 req.SetPathValue("id", "1") 394 rec := httptest.NewRecorder() 395 u.HandleSession(rec, req) 396 397 if rec.Code != http.StatusForbidden { 398 t.Fatalf("status = %d, want 403", rec.Code) 399 } 400} 401 402func TestHandleSessionWithRecording(t *testing.T) { 403 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 404 u := setupTestUI(t, caller, nil) 405 406 // Create a recording 407 id := u.audit.LogWithRecording( 408 audit.Entry{Caller: "test", Server: "gitea", Method: "tools/call", Tool: "list", Status: "ok"}, 409 []byte(`{"method":"tools/call"}`), []byte(`{"result":"ok"}`), 410 ) 411 412 req := httptest.NewRequest("GET", "/ui/session/"+fmt.Sprint(id), nil) 413 req.SetPathValue("id", fmt.Sprint(id)) 414 rec := httptest.NewRecorder() 415 u.HandleSession(rec, req) 416 417 if rec.Code != http.StatusOK { 418 t.Fatalf("status = %d, want 200", rec.Code) 419 } 420 if !strings.Contains(rec.Body.String(), "tools/call") { 421 t.Error("body missing method") 422 } 423} 424 425func TestPrettyJSON(t *testing.T) { 426 out := prettyJSON(`{"a":1,"b":2}`) 427 if !strings.Contains(out, "\n") { 428 t.Error("expected indented output") 429 } 430 // Invalid JSON returns as-is 431 out = prettyJSON("not json") 432 if out != "not json" { 433 t.Error("invalid JSON should pass through") 434 } 435} 436 437func TestBuildChartEmpty(t *testing.T) { 438 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 439 u := setupTestUI(t, caller, nil) 440 441 chart := u.buildChart(24) 442 if chart != nil { 443 t.Error("expected nil chart with no data") 444 } 445} 446 447func TestBuildChartWithData(t *testing.T) { 448 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 449 u := setupTestUI(t, caller, nil) 450 451 u.audit.Log(audit.Entry{Caller: "a", Server: "gitea", Method: "tools/call", Status: "ok"}) 452 u.audit.Log(audit.Entry{Caller: "b", Server: "gitea", Method: "tools/call", Status: "error", Error: "fail"}) 453 454 cr := u.buildChart(24) 455 if cr == nil { 456 t.Fatal("expected chart data") 457 } 458 if len(cr.Bars) != 24 { 459 t.Errorf("expected 24 bars, got %d", len(cr.Bars)) 460 } 461 if cr.Total != 2 { 462 t.Errorf("expected total 2, got %d", cr.Total) 463 } 464 if cr.Errors != 1 { 465 t.Errorf("expected 1 error, got %d", cr.Errors) 466 } 467 468 // At least one bar should have data 469 hasData := false 470 for _, bar := range cr.Bars { 471 if bar.Total > 0 { 472 hasData = true 473 break 474 } 475 } 476 if !hasData { 477 t.Error("expected at least one bar with data") 478 } 479} 480 481func TestBuildChartDenied(t *testing.T) { 482 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 483 u := setupTestUI(t, caller, nil) 484 485 u.audit.Log(audit.Entry{Caller: "a", Server: "gitea", Method: "tools/call", Status: "ok"}) 486 u.audit.Log(audit.Entry{Caller: "b", Server: "gitea", Method: "tools/call", Status: "denied"}) 487 488 cr := u.buildChart(24) 489 if cr == nil { 490 t.Fatal("expected chart data") 491 } 492 if cr.Denied != 1 { 493 t.Errorf("expected 1 denied, got %d", cr.Denied) 494 } 495 496 // Find bar with data and check denied height 497 for _, bar := range cr.Bars { 498 if bar.Denied > 0 { 499 if bar.DenyH == 0 { 500 t.Error("expected non-zero DenyH for bar with denied requests") 501 } 502 return 503 } 504 } 505 t.Error("no bar with denied data found") 506} 507 508func TestInitial(t *testing.T) { 509 if v := initial("scott"); v != "S" { 510 t.Errorf("initial(scott) = %q", v) 511 } 512 if v := initial(""); v != "?" { 513 t.Errorf("initial('') = %q", v) 514 } 515} 516 517func TestParseToolsList(t *testing.T) { 518 body := `{"jsonrpc":"2.0","id":2,"result":{"tools":[ 519 {"name":"foo","description":"Do foo"}, 520 {"name":"bar","description":"Do bar"} 521 ]}}` 522 tools := parseToolsList([]byte(body)) 523 if len(tools) != 2 { 524 t.Fatalf("expected 2 tools, got %d", len(tools)) 525 } 526 if tools[0].Name != "foo" || tools[1].Name != "bar" { 527 t.Errorf("unexpected tools: %v", tools) 528 } 529} 530 531func TestBuildChartCallerLimit(t *testing.T) { 532 caller := &identity.Caller{UserLogin: "scott@github", Node: "test"} 533 u := setupTestUI(t, caller, nil) 534 535 // Insert entries from many different callers to the same server 536 for i := 0; i < 10; i++ { 537 u.audit.Log(audit.Entry{ 538 Caller: fmt.Sprintf("caller%d", i), Server: "gitea", 539 Method: "tools/call", Status: "ok", 540 }) 541 } 542 543 cr := u.buildChart(24) 544 if cr == nil { 545 t.Fatal("expected chart data") 546 } 547 548 // Find the bar with data 549 for _, bar := range cr.Bars { 550 if bar.Total > 0 { 551 if len(bar.Callers) > 3 { 552 t.Errorf("callers should be limited to 3, got %d", len(bar.Callers)) 553 } 554 return 555 } 556 } 557 t.Error("no bar with data found") 558} 559 560func TestMatchGlob(t *testing.T) { 561 tests := []struct { 562 pattern, name string 563 want bool 564 }{ 565 {"mcp__gitea__delete_*", "mcp__gitea__delete_branch", true}, 566 {"mcp__gitea__delete_*", "mcp__gitea__list_repos", false}, 567 {"exact_match", "exact_match", true}, 568 {"exact_match", "not_match", false}, 569 } 570 for _, tt := range tests { 571 if got := matchGlob(tt.pattern, tt.name); got != tt.want { 572 t.Errorf("matchGlob(%q, %q) = %v, want %v", tt.pattern, tt.name, got, tt.want) 573 } 574 } 575}