An AI agent built to run Ralph loops: plan mode for planning and Ralph mode for implementing.
Branch `new-directions` · 715 lines · 22 kB · view raw
1use rustagent::config::SecurityConfig; 2use rustagent::db::Database; 3use rustagent::graph::store::{GraphStore, SqliteGraphStore}; 4use rustagent::graph::{EdgeType, NodeStatus, NodeType}; 5use rustagent::security::SecurityValidator; 6use rustagent::security::permission::AutoApproveHandler; 7use rustagent::tools::Tool; 8use rustagent::tools::factory::create_v2_registry; 9use rustagent::tools::graph_tools::*; 10use serde_json::{Value, json}; 11use std::sync::Arc; 12 13mod common; 14use common::MockGraphStore; 15 16/// Create a test database in memory with a test project 17async fn setup_test_db() -> anyhow::Result<(Database, Arc<SqliteGraphStore>)> { 18 let db = Database::open_in_memory().await?; 19 20 // Insert a test project 21 db.connection() 22 .call(|conn| { 23 let now = chrono::Utc::now().to_rfc3339(); 24 conn.execute( 25 "INSERT INTO projects (id, name, path, registered_at, config_overrides, metadata) 26 VALUES (?, ?, ?, ?, ?, ?)", 27 rusqlite::params![ 28 "proj-1", 29 "proj-1", 30 "/tmp/proj-1", 31 &now, 32 None::<String>, 33 "{}" 34 ], 35 )?; 36 Ok(()) 37 }) 38 .await?; 39 40 let store = Arc::new(SqliteGraphStore::new(db.clone())); 41 Ok((db, store)) 42} 43 44#[tokio::test] 45async fn test_create_node_tool() -> anyhow::Result<()> { 46 let (_db, store) = setup_test_db().await?; 47 let tool = CreateNodeTool::new(store.clone()); 48 49 let params = json!({ 50 "node_type": "task", 51 "title": "Test Task", 52 "description": "A test task", 53 "project_id": "proj-1" 54 }); 55 56 let result = tool.execute(params).await?; 57 let parsed: Value = serde_json::from_str(&result)?; 58 59 assert!(parsed["id"].as_str().is_some()); 60 assert!(parsed["id"].as_str().unwrap().starts_with("ra-")); 61 assert_eq!(parsed["message"], "Node created successfully"); 62 63 // Verify node was created 64 let node_id = parsed["id"].as_str().unwrap(); 65 let node = store.get_node(node_id).await?.expect("Node not found"); 66 67 assert_eq!(node.title, "Test Task"); 68 
assert_eq!(node.node_type, NodeType::Task); 69 assert_eq!(node.status, NodeStatus::Pending); 70 71 Ok(()) 72} 73 74#[tokio::test] 75async fn test_create_child_node_tool() -> anyhow::Result<()> { 76 let (_db, store) = setup_test_db().await?; 77 let tool = CreateNodeTool::new(store.clone()); 78 79 // Create parent 80 let parent_params = json!({ 81 "node_type": "goal", 82 "title": "Parent Goal", 83 "description": "A parent goal", 84 "project_id": "proj-1" 85 }); 86 87 let parent_result = tool.execute(parent_params).await?; 88 let parent_parsed: Value = serde_json::from_str(&parent_result)?; 89 let parent_id = parent_parsed["id"].as_str().unwrap(); 90 91 // Create child with parent_id 92 let child_params = json!({ 93 "node_type": "task", 94 "title": "Child Task", 95 "description": "A child task", 96 "project_id": "proj-1", 97 "parent_id": parent_id 98 }); 99 100 let child_result = tool.execute(child_params).await?; 101 let child_parsed: Value = serde_json::from_str(&child_result)?; 102 let child_id = child_parsed["id"].as_str().unwrap(); 103 104 // Verify child ID has parent prefix 105 assert!(child_id.starts_with(parent_id)); 106 assert!(child_id.contains(".")); 107 108 // Verify Contains edge was created 109 let edges = store 110 .get_edges(parent_id, rustagent::graph::store::EdgeDirection::Outgoing) 111 .await?; 112 113 assert!(!edges.is_empty()); 114 assert_eq!(edges[0].0.edge_type, EdgeType::Contains); 115 116 Ok(()) 117} 118 119#[tokio::test] 120async fn test_update_node_tool() -> anyhow::Result<()> { 121 let (_db, store) = setup_test_db().await?; 122 let create_tool = CreateNodeTool::new(store.clone()); 123 let update_tool = UpdateNodeTool::new(store.clone()); 124 125 // Create a node 126 let create_params = json!({ 127 "node_type": "task", 128 "title": "Original Title", 129 "description": "Original description", 130 "project_id": "proj-1" 131 }); 132 133 let create_result = create_tool.execute(create_params).await?; 134 let parsed: Value = 
serde_json::from_str(&create_result)?; 135 let node_id = parsed["id"].as_str().unwrap(); 136 137 // Update the node 138 let update_params = json!({ 139 "node_id": node_id, 140 "title": "Updated Title", 141 "description": "Updated description", 142 "status": "ready" 143 }); 144 145 update_tool.execute(update_params).await?; 146 147 // Verify update 148 let node = store.get_node(node_id).await?.expect("Node not found"); 149 150 assert_eq!(node.title, "Updated Title"); 151 assert_eq!(node.description, "Updated description"); 152 assert_eq!(node.status, NodeStatus::Ready); 153 154 Ok(()) 155} 156 157#[tokio::test] 158async fn test_add_edge_tool() -> anyhow::Result<()> { 159 let (_db, store) = setup_test_db().await?; 160 let create_tool = CreateNodeTool::new(store.clone()); 161 let edge_tool = AddEdgeTool::new(store.clone()); 162 163 // Create two nodes 164 let params1 = json!({ 165 "node_type": "task", 166 "title": "Task 1", 167 "description": "First task", 168 "project_id": "proj-1" 169 }); 170 171 let result1 = create_tool.execute(params1).await?; 172 let parsed1: Value = serde_json::from_str(&result1)?; 173 let node1_id = parsed1["id"].as_str().unwrap(); 174 175 let params2 = json!({ 176 "node_type": "task", 177 "title": "Task 2", 178 "description": "Second task", 179 "project_id": "proj-1" 180 }); 181 182 let result2 = create_tool.execute(params2).await?; 183 let parsed2: Value = serde_json::from_str(&result2)?; 184 let node2_id = parsed2["id"].as_str().unwrap(); 185 186 // Add DependsOn edge 187 let edge_params = json!({ 188 "edge_type": "depends_on", 189 "from_node": node2_id, 190 "to_node": node1_id, 191 "label": "blocks" 192 }); 193 194 let edge_result = edge_tool.execute(edge_params).await?; 195 let edge_parsed: Value = serde_json::from_str(&edge_result)?; 196 197 assert!(edge_parsed["id"].as_str().is_some()); 198 assert_eq!(edge_parsed["message"], "Edge created successfully"); 199 200 // Verify edge 201 let edges = store 202 .get_edges(node2_id, 
rustagent::graph::store::EdgeDirection::Outgoing) 203 .await?; 204 205 assert_eq!(edges.len(), 1); 206 assert_eq!(edges[0].0.edge_type, EdgeType::DependsOn); 207 208 Ok(()) 209} 210 211#[tokio::test] 212async fn test_query_nodes_tool() -> anyhow::Result<()> { 213 let (_db, store) = setup_test_db().await?; 214 let create_tool = CreateNodeTool::new(store.clone()); 215 let query_tool = QueryNodesTool::new(store.clone()); 216 217 // Create a few nodes 218 for i in 1..=3 { 219 let params = json!({ 220 "node_type": "task", 221 "title": format!("Task {}", i), 222 "description": "Test task", 223 "project_id": "proj-1" 224 }); 225 create_tool.execute(params).await?; 226 } 227 228 // Query all tasks 229 let query_params = json!({ 230 "node_type": "task", 231 "project_id": "proj-1" 232 }); 233 234 let result = query_tool.execute(query_params).await?; 235 let parsed: Vec<Value> = serde_json::from_str(&result)?; 236 237 assert_eq!(parsed.len(), 3); 238 assert_eq!(parsed[0]["node_type"], "task"); 239 240 Ok(()) 241} 242 243#[tokio::test] 244async fn test_search_nodes_tool() -> anyhow::Result<()> { 245 let (_db, store) = setup_test_db().await?; 246 let create_tool = CreateNodeTool::new(store.clone()); 247 let search_tool = SearchNodesTool::new(store.clone()); 248 249 // Create nodes with specific titles 250 let params = json!({ 251 "node_type": "task", 252 "title": "Authentication Task", 253 "description": "Handle user authentication", 254 "project_id": "proj-1" 255 }); 256 create_tool.execute(params).await?; 257 258 // Search for "authentication" 259 let search_params = json!({ 260 "query": "authentication" 261 }); 262 263 let result = search_tool.execute(search_params).await?; 264 let parsed: Vec<Value> = serde_json::from_str(&result)?; 265 266 assert!(!parsed.is_empty()); 267 assert_eq!(parsed[0]["title"], "Authentication Task"); 268 269 Ok(()) 270} 271 272#[tokio::test] 273async fn test_claim_task_tool() -> anyhow::Result<()> { 274 let (_db, store) = setup_test_db().await?; 
275 let create_tool = CreateNodeTool::new(store.clone()); 276 let update_tool = UpdateNodeTool::new(store.clone()); 277 let claim_tool = ClaimTaskTool::new(store.clone()); 278 279 // Create a task 280 let create_params = json!({ 281 "node_type": "task", 282 "title": "Test Task", 283 "description": "To be claimed", 284 "project_id": "proj-1" 285 }); 286 287 let create_result = create_tool.execute(create_params).await?; 288 let parsed: Value = serde_json::from_str(&create_result)?; 289 let task_id = parsed["id"].as_str().unwrap(); 290 291 // Update status to Ready 292 let update_params = json!({ 293 "node_id": task_id, 294 "status": "ready" 295 }); 296 update_tool.execute(update_params).await?; 297 298 // Claim the task 299 let claim_params = json!({ 300 "node_id": task_id, 301 "agent_id": "agent-1" 302 }); 303 304 let claim_result = claim_tool.execute(claim_params).await?; 305 let claim_parsed: Value = serde_json::from_str(&claim_result)?; 306 307 assert_eq!(claim_parsed["claimed"], true); 308 309 // Verify node status 310 let node = store.get_node(task_id).await?.expect("Node not found"); 311 312 assert_eq!(node.status, NodeStatus::Claimed); 313 assert_eq!(node.assigned_to, Some("agent-1".to_string())); 314 315 Ok(()) 316} 317 318#[tokio::test] 319async fn test_log_decision_tool() -> anyhow::Result<()> { 320 let (_db, store) = setup_test_db().await?; 321 let tool = LogDecisionTool::new(store.clone()); 322 323 let params = json!({ 324 "title": "Architecture Decision", 325 "description": "Choose between microservices or monolith", 326 "project_id": "proj-1", 327 "options": [ 328 { 329 "title": "Microservices", 330 "description": "Multiple independent services", 331 "pros": "Scalability, independence", 332 "cons": "Complexity, latency" 333 }, 334 { 335 "title": "Monolith", 336 "description": "Single unified application", 337 "pros": "Simplicity, performance", 338 "cons": "Scalability limitations" 339 } 340 ] 341 }); 342 343 let result = tool.execute(params).await?; 
344 let parsed: Value = serde_json::from_str(&result)?; 345 346 assert!(parsed["decision_id"].as_str().is_some()); 347 assert!(parsed["option_ids"].is_array()); 348 assert_eq!(parsed["option_ids"].as_array().unwrap().len(), 2); 349 350 // Verify decision node was created 351 let decision_id = parsed["decision_id"].as_str().unwrap(); 352 let decision = store 353 .get_node(decision_id) 354 .await? 355 .expect("Decision not found"); 356 357 assert_eq!(decision.node_type, NodeType::Decision); 358 assert_eq!(decision.status, NodeStatus::Active); 359 360 // Verify option nodes were created 361 let option_ids = parsed["option_ids"].as_array().unwrap(); 362 for option_id_val in option_ids { 363 let option_id = option_id_val.as_str().unwrap(); 364 let option = store.get_node(option_id).await?.expect("Option not found"); 365 366 assert_eq!(option.node_type, NodeType::Option); 367 assert_eq!(option.status, NodeStatus::Active); 368 } 369 370 Ok(()) 371} 372 373#[tokio::test] 374async fn test_choose_option_tool() -> anyhow::Result<()> { 375 let (_db, store) = setup_test_db().await?; 376 let decision_tool = LogDecisionTool::new(store.clone()); 377 let choose_tool = ChooseOptionTool::new(store.clone()); 378 379 // Create a decision with options 380 let decision_params = json!({ 381 "title": "Test Decision", 382 "description": "Test", 383 "project_id": "proj-1", 384 "options": [ 385 { 386 "title": "Option A", 387 "description": "First option" 388 }, 389 { 390 "title": "Option B", 391 "description": "Second option" 392 } 393 ] 394 }); 395 396 let decision_result = decision_tool.execute(decision_params).await?; 397 let decision_parsed: Value = serde_json::from_str(&decision_result)?; 398 399 let decision_id = decision_parsed["decision_id"].as_str().unwrap(); 400 let option_ids = decision_parsed["option_ids"].as_array().unwrap(); 401 let chosen_option_id = option_ids[0].as_str().unwrap(); 402 403 // Choose an option 404 let choose_params = json!({ 405 "decision_id": decision_id, 406 
"option_id": chosen_option_id, 407 "rationale": "Best fit for our needs" 408 }); 409 410 choose_tool.execute(choose_params).await?; 411 412 // Verify decision status changed to Decided 413 let decision = store 414 .get_node(decision_id) 415 .await? 416 .expect("Decision not found"); 417 418 assert_eq!(decision.status, NodeStatus::Decided); 419 420 // Verify chosen option has Chosen status 421 let chosen_option = store 422 .get_node(chosen_option_id) 423 .await? 424 .expect("Option not found"); 425 426 assert_eq!(chosen_option.status, NodeStatus::Chosen); 427 428 // Verify other options are Rejected 429 let other_option_id = option_ids[1].as_str().unwrap(); 430 let other_option = store 431 .get_node(other_option_id) 432 .await? 433 .expect("Option not found"); 434 435 assert_eq!(other_option.status, NodeStatus::Rejected); 436 437 Ok(()) 438} 439 440#[tokio::test] 441async fn test_record_outcome_tool() -> anyhow::Result<()> { 442 let (_db, store) = setup_test_db().await?; 443 let create_tool = CreateNodeTool::new(store.clone()); 444 let outcome_tool = RecordOutcomeTool::new(store.clone()); 445 446 // Create a task 447 let task_params = json!({ 448 "node_type": "task", 449 "title": "Test Task", 450 "description": "Task to record outcome for", 451 "project_id": "proj-1" 452 }); 453 454 let task_result = create_tool.execute(task_params).await?; 455 let task_parsed: Value = serde_json::from_str(&task_result)?; 456 let task_id = task_parsed["id"].as_str().unwrap(); 457 458 // Record outcome 459 let outcome_params = json!({ 460 "parent_id": task_id, 461 "title": "Task Completed", 462 "description": "Successfully completed the task", 463 "project_id": "proj-1", 464 "success": true 465 }); 466 467 let outcome_result = outcome_tool.execute(outcome_params).await?; 468 let outcome_parsed: Value = serde_json::from_str(&outcome_result)?; 469 470 assert!(outcome_parsed["outcome_id"].as_str().is_some()); 471 472 // Verify outcome node 473 let outcome_id = 
outcome_parsed["outcome_id"].as_str().unwrap(); 474 let outcome = store 475 .get_node(outcome_id) 476 .await? 477 .expect("Outcome not found"); 478 479 assert_eq!(outcome.node_type, NodeType::Outcome); 480 assert_eq!(outcome.status, NodeStatus::Completed); 481 assert_eq!(outcome.metadata.get("success"), Some(&"true".to_string())); 482 483 Ok(()) 484} 485 486#[tokio::test] 487async fn test_record_observation_tool() -> anyhow::Result<()> { 488 let (_db, store) = setup_test_db().await?; 489 let create_tool = CreateNodeTool::new(store.clone()); 490 let obs_tool = RecordObservationTool::new(store.clone()); 491 492 // Create a task to observe 493 let task_params = json!({ 494 "node_type": "task", 495 "title": "Test Task", 496 "description": "Task to observe", 497 "project_id": "proj-1" 498 }); 499 500 let task_result = create_tool.execute(task_params).await?; 501 let task_parsed: Value = serde_json::from_str(&task_result)?; 502 let task_id = task_parsed["id"].as_str().unwrap(); 503 504 // Record observation related to task 505 let obs_params = json!({ 506 "title": "Performance Issue Observed", 507 "description": "Task took longer than expected", 508 "project_id": "proj-1", 509 "related_node_id": task_id 510 }); 511 512 let obs_result = obs_tool.execute(obs_params).await?; 513 let obs_parsed: Value = serde_json::from_str(&obs_result)?; 514 515 assert!(obs_parsed["observation_id"].as_str().is_some()); 516 517 // Verify observation node 518 let obs_id = obs_parsed["observation_id"].as_str().unwrap(); 519 let obs = store 520 .get_node(obs_id) 521 .await? 
522 .expect("Observation not found"); 523 524 assert_eq!(obs.node_type, NodeType::Observation); 525 assert_eq!(obs.status, NodeStatus::Active); 526 527 // Verify Informs edge 528 let edges = store 529 .get_edges(obs_id, rustagent::graph::store::EdgeDirection::Outgoing) 530 .await?; 531 532 assert_eq!(edges.len(), 1); 533 assert_eq!(edges[0].0.edge_type, EdgeType::Informs); 534 535 Ok(()) 536} 537 538#[tokio::test] 539async fn test_revisit_tool() -> anyhow::Result<()> { 540 let (_db, store) = setup_test_db().await?; 541 let create_tool = CreateNodeTool::new(store.clone()); 542 let outcome_tool = RecordOutcomeTool::new(store.clone()); 543 let revisit_tool = RevisitTool::new(store.clone()); 544 545 // Create a task and outcome 546 let task_params = json!({ 547 "node_type": "task", 548 "title": "Test Task", 549 "description": "Task", 550 "project_id": "proj-1" 551 }); 552 553 let task_result = create_tool.execute(task_params).await?; 554 let task_parsed: Value = serde_json::from_str(&task_result)?; 555 let task_id = task_parsed["id"].as_str().unwrap(); 556 557 let outcome_params = json!({ 558 "parent_id": task_id, 559 "title": "Outcome", 560 "description": "Task outcome", 561 "project_id": "proj-1", 562 "success": true 563 }); 564 565 let outcome_result = outcome_tool.execute(outcome_params).await?; 566 let outcome_parsed: Value = serde_json::from_str(&outcome_result)?; 567 let outcome_id = outcome_parsed["outcome_id"].as_str().unwrap(); 568 569 // Revisit with new decision 570 let revisit_params = json!({ 571 "outcome_id": outcome_id, 572 "project_id": "proj-1", 573 "reason": "Results not as expected", 574 "new_decision_title": "Reconsider approach" 575 }); 576 577 let revisit_result = revisit_tool.execute(revisit_params).await?; 578 let revisit_parsed: Value = serde_json::from_str(&revisit_result)?; 579 580 assert!(revisit_parsed["revisit_id"].as_str().is_some()); 581 assert!(revisit_parsed["decision_id"].as_str().is_some()); 582 583 // Verify revisit node 584 let 
revisit_id = revisit_parsed["revisit_id"].as_str().unwrap(); 585 let revisit = store 586 .get_node(revisit_id) 587 .await? 588 .expect("Revisit not found"); 589 590 assert_eq!(revisit.node_type, NodeType::Revisit); 591 assert_eq!(revisit.status, NodeStatus::Active); 592 593 // Verify new decision was created 594 let decision_id = revisit_parsed["decision_id"].as_str().unwrap(); 595 let decision = store 596 .get_node(decision_id) 597 .await? 598 .expect("Decision not found"); 599 600 assert_eq!(decision.node_type, NodeType::Decision); 601 602 Ok(()) 603} 604 605#[tokio::test] 606async fn test_tool_name_and_description() -> anyhow::Result<()> { 607 let (_db, store) = setup_test_db().await?; 608 609 let tools: Vec<(Box<dyn Tool + Send + Sync>, &str)> = vec![ 610 (Box::new(CreateNodeTool::new(store.clone())), "create_node"), 611 (Box::new(UpdateNodeTool::new(store.clone())), "update_node"), 612 (Box::new(AddEdgeTool::new(store.clone())), "add_edge"), 613 (Box::new(QueryNodesTool::new(store.clone())), "query_nodes"), 614 ( 615 Box::new(SearchNodesTool::new(store.clone())), 616 "search_nodes", 617 ), 618 (Box::new(ClaimTaskTool::new(store.clone())), "claim_task"), 619 ( 620 Box::new(LogDecisionTool::new(store.clone())), 621 "log_decision", 622 ), 623 ( 624 Box::new(ChooseOptionTool::new(store.clone())), 625 "choose_option", 626 ), 627 ( 628 Box::new(RecordOutcomeTool::new(store.clone())), 629 "record_outcome", 630 ), 631 ( 632 Box::new(RecordObservationTool::new(store.clone())), 633 "record_observation", 634 ), 635 (Box::new(RevisitTool::new(store.clone())), "revisit"), 636 ]; 637 638 for (tool, expected_name) in tools { 639 assert_eq!(tool.name(), expected_name); 640 assert!(!tool.description().is_empty()); 641 let params = tool.parameters(); 642 assert!(params.is_object()); 643 } 644 645 Ok(()) 646} 647 648#[test] 649fn test_v2_registry_includes_all_tools() { 650 // Create a mock graph store 651 let graph_store = Arc::new(MockGraphStore); 652 653 // Create security 
config and validator 654 let security_config = SecurityConfig { 655 shell_policy: rustagent::config::ShellPolicy::Blocklist, 656 allowed_commands: vec![], 657 blocked_patterns: vec![], 658 max_file_size_mb: 100, 659 allowed_paths: vec![], 660 }; 661 let validator = 662 Arc::new(SecurityValidator::new(security_config).expect("Failed to create validator")); 663 let permission_handler = Arc::new(AutoApproveHandler); 664 665 let project_root = std::path::PathBuf::from("/tmp"); 666 667 // Create the v2 registry 668 let registry = create_v2_registry( 669 validator, 670 permission_handler, 671 graph_store, 672 None, 673 None, 674 project_root, 675 ); 676 677 // Expected tool names: graph tools + legacy tools + context tools + search tools 678 let expected_tools = vec![ 679 // Graph tools 680 "create_node", 681 "update_node", 682 "add_edge", 683 "query_nodes", 684 "search_nodes", 685 "claim_task", 686 "log_decision", 687 "choose_option", 688 "record_outcome", 689 "record_observation", 690 "revisit", 691 // Legacy tools 692 "read_file", 693 "write_file", 694 "list_files", 695 "run_command", 696 "signal_completion", 697 // Context tools 698 "read_agents_md", 699 // Search tools 700 "code_search", 701 ]; 702 703 // Get all registered tool names 704 let registered_names = registry.list(); 705 706 // Verify each expected tool is registered 707 for expected in expected_tools { 708 assert!( 709 registered_names.contains(&expected.to_string()), 710 "Tool '{}' not found in v2 registry. Registered tools: {:?}", 711 expected, 712 registered_names 713 ); 714 } 715}