personal memory agent
at main 346 lines 11 kB view raw
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Tests for Gemini CLI subprocess provider (_translate_gemini)."""

import asyncio
import importlib
from unittest.mock import AsyncMock, patch

from think.providers.cli import ThinkingAggregator
from think.providers.google import _translate_gemini
from think.providers.shared import JSONEventCallback


def _google_provider():
    """Return the ``think.providers.google`` module after reloading it.

    ``importlib.reload`` re-executes the module's top-level code in place, so
    each caller works against freshly executed module-level definitions.
    """
    return importlib.reload(importlib.import_module("think.providers.google"))


def _run_cogitate_cmd(make_runner, config):
    """Run ``run_cogitate`` against a mock ``CLIRunner`` and return its command.

    Args:
        make_runner: Zero-argument factory returning a mock ``CLIRunner`` class
            that records its most recent instance as ``last_instance``.
        config: Request dict passed as the first argument to ``run_cogitate``.

    Returns:
        The ``cmd`` value the provider passed when constructing ``CLIRunner``.
    """
    provider = _google_provider()
    MockCLIRunner = make_runner()
    with patch("think.providers.google.CLIRunner", MockCLIRunner):
        asyncio.run(provider.run_cogitate(config, lambda e: None))
    return MockCLIRunner.last_instance.cmd


def _assert_write_mode_removes_allowed_tools(make_runner):
    """Assert that ``write`` mode keeps ``--yolo`` but drops ``--allowed-tools``.

    Args:
        make_runner: Factory returning a mock ``CLIRunner`` class (see
            ``_run_cogitate_cmd``).
    """
    cmd = _run_cogitate_cmd(
        make_runner,
        {"prompt": "hello", "model": "gemini-2.5-flash", "write": True},
    )
    assert "--yolo" in cmd
    assert "--allowed-tools" not in cmd


class TestTranslateGemini:
    """Tests for _translate_gemini event translation."""

    def _make_callback(self):
        """Create a callback that records emitted events.

        Returns:
            ``(cb, events)``: the ``JSONEventCallback`` and the list it appends to.
        """
        events = []
        cb = JSONEventCallback(events.append)
        return cb, events

    def _make_aggregator(self, cb):
        """Create a ThinkingAggregator with the given callback."""
        return ThinkingAggregator(cb, model="gemini-2.5-flash")

    def test_init_returns_session_id(self):
        """An ``init`` event yields its session id and emits nothing."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        result = _translate_gemini(
            {
                "type": "init",
                "timestamp": 1000,
                "session_id": "sess-abc",
                "model": "gemini-2.5-flash",
            },
            agg,
            cb,
        )
        assert result == "sess-abc"
        assert events == []

    def test_user_message_ignored(self):
        """User-role messages return None and emit nothing."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        result = _translate_gemini(
            {"type": "message", "role": "user", "content": "Hello"},
            agg,
            cb,
        )
        assert result is None
        assert events == []

    def test_assistant_delta_accumulates(self):
        """Assistant delta messages are buffered in the aggregator, not emitted."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        _translate_gemini(
            {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": "Hello ",
            },
            agg,
            cb,
        )
        _translate_gemini(
            {"type": "message", "role": "assistant", "delta": True, "content": "world"},
            agg,
            cb,
        )
        assert agg.flush_as_result() == "Hello world"
        assert events == []

    def test_tool_use_flushes_thinking_and_emits_start(self):
        """``tool_use`` flushes buffered text as ``thinking``, then ``tool_start``."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        agg.accumulate("I'll use a tool now.")
        event = {
            "type": "tool_use",
            "timestamp": 2000,
            "tool_name": "read_file",
            "tool_id": "tool-1",
            "parameters": {"path": "/tmp/test.txt"},
        }
        _translate_gemini(event, agg, cb)
        assert len(events) == 2
        assert events[0]["event"] == "thinking"
        assert events[0]["summary"] == "I'll use a tool now."
        assert events[1]["event"] == "tool_start"
        assert events[1]["tool"] == "read_file"
        assert events[1]["call_id"] == "tool-1"
        assert events[1]["args"] == {"path": "/tmp/test.txt"}
        assert events[1]["raw"] == [event]

    def test_tool_use_no_thinking_if_buffer_empty(self):
        """With an empty buffer, ``tool_use`` emits only ``tool_start``."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        event = {
            "type": "tool_use",
            "timestamp": 2000,
            "tool_name": "read_file",
            "tool_id": "tool-1",
            "parameters": {},
        }
        _translate_gemini(event, agg, cb)
        assert len(events) == 1
        assert events[0]["event"] == "tool_start"

    def test_tool_result_includes_tool_name(self):
        """tool_end should include tool name from preceding tool_use."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        pending = {}
        _translate_gemini(
            {
                "type": "tool_use",
                "tool_name": "read_file",
                "tool_id": "t1",
                "parameters": {"path": "x.py"},
            },
            agg,
            cb,
            pending_tools=pending,
        )
        events.clear()  # ignore tool_start
        _translate_gemini(
            {
                "type": "tool_result",
                "tool_id": "t1",
                "status": "success",
                "output": "contents",
            },
            agg,
            cb,
            pending_tools=pending,
        )
        assert len(events) == 1
        assert events[0]["event"] == "tool_end"
        assert events[0]["tool"] == "read_file"
        assert events[0]["args"] == {"path": "x.py"}
        assert events[0]["result"] == "contents"
        assert events[0]["call_id"] == "t1"

    def test_tool_result_without_pending(self):
        """tool_end without pending_tools still works, tool is empty."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        _translate_gemini(
            {
                "type": "tool_result",
                "tool_id": "t1",
                "status": "success",
                "output": "data",
            },
            agg,
            cb,
        )
        assert len(events) == 1
        assert events[0]["event"] == "tool_end"
        assert events[0]["tool"] == ""

    def test_result_stores_usage(self):
        """``result`` stats are copied into ``usage`` (minus timing) silently."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        usage = {}
        _translate_gemini(
            {
                "type": "result",
                "status": "success",
                "stats": {
                    "total_tokens": 1500,
                    "input_tokens": 1000,
                    "output_tokens": 500,
                    "cached": 200,
                    "duration_ms": 3000,
                    "tool_calls": 2,
                },
            },
            agg,
            cb,
            usage,
        )
        assert events == []
        assert usage["input_tokens"] == 1000
        assert usage["output_tokens"] == 500
        assert usage["total_tokens"] == 1500
        assert usage["cached_tokens"] == 200
        assert "duration_ms" not in usage

    def test_result_no_stats(self):
        """A ``result`` without stats leaves ``usage`` empty and emits nothing."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        usage = {}
        _translate_gemini(
            {"type": "result", "status": "success"},
            agg,
            cb,
            usage,
        )
        assert usage == {}
        # Same contract as test_result_stores_usage: result events are silent.
        assert events == []

    def test_unknown_event_type_ignored(self):
        """Unrecognised event types return None and emit nothing."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        result = _translate_gemini(
            {"type": "unknown_type", "data": "whatever"},
            agg,
            cb,
        )
        assert result is None
        assert events == []

    def test_full_sequence(self):
        """Process a full sequence of Gemini JSONL events."""
        cb, events = self._make_callback()
        agg = self._make_aggregator(cb)
        usage = {}
        pending = {}

        sequence = [
            {
                "type": "init",
                "session_id": "sess-42",
                "model": "gemini-2.5-flash",
            },
            {"type": "message", "role": "user", "content": "Analyze this file"},
            {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": "I'll read the file. ",
            },
            {
                "type": "tool_use",
                "tool_name": "read_file",
                "tool_id": "t1",
                "parameters": {"path": "test.py"},
            },
            {
                "type": "tool_result",
                "tool_id": "t1",
                "status": "success",
                "output": "print('hello')",
            },
            {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": "The file contains ",
            },
            {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": "a print statement.",
            },
            {
                "type": "result",
                "status": "success",
                "stats": {
                    "total_tokens": 100,
                    "input_tokens": 60,
                    "output_tokens": 40,
                },
            },
        ]

        session_ids = []
        for ev in sequence:
            sid = _translate_gemini(ev, agg, cb, usage, pending)
            if sid:
                session_ids.append(sid)

        assert session_ids == ["sess-42"]

        # Events: thinking, tool_start, tool_end
        assert len(events) == 3
        assert events[0]["event"] == "thinking"
        assert "read the file" in events[0]["summary"]
        assert events[1]["event"] == "tool_start"
        assert events[1]["tool"] == "read_file"
        assert events[2]["event"] == "tool_end"
        assert events[2]["tool"] == "read_file"
        assert events[2]["call_id"] == "t1"

        # Final result text in aggregator
        result = agg.flush_as_result()
        assert result == "The file contains a print statement."

        assert usage["total_tokens"] == 100


class TestRunCogitateCommand:
    """Tests for run_cogitate command construction."""

    def _mock_runner(self):
        """Create a MockCLIRunner that captures the command.

        The returned class stores its constructor's ``cmd`` and ``prompt_text``
        kwargs, exposes a fixed ``cli_session_id`` and an ``AsyncMock`` ``run``
        returning ``"result"``, and remembers the latest instance as the class
        attribute ``last_instance``.
        """

        class MockCLIRunner:
            last_instance = None

            def __init__(self, **kwargs):
                self.cmd = kwargs["cmd"]
                self.prompt_text = kwargs["prompt_text"]
                self.cli_session_id = "test-session"
                self.run = AsyncMock(return_value="result")
                MockCLIRunner.last_instance = self

        return MockCLIRunner

    def test_yolo_mode_with_sol_allowed(self):
        """Default mode passes --yolo and ``run_shell_command(sol)`` as allowed tool."""
        cmd = _run_cogitate_cmd(
            self._mock_runner, {"prompt": "hello", "model": "gemini-2.5-flash"}
        )
        assert "--yolo" in cmd
        # Check presence first so a missing flag fails as an assertion rather
        # than as a ValueError raised by cmd.index().
        assert "--allowed-tools" in cmd
        assert cmd[cmd.index("--allowed-tools") + 1] == "run_shell_command(sol)"

    def test_write_mode_removes_allowed_tools(self):
        """``write`` mode keeps --yolo but omits --allowed-tools."""
        _assert_write_mode_removes_allowed_tools(self._mock_runner)