personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4"""Tests for Gemini CLI subprocess provider (_translate_gemini)."""
5
6import asyncio
7import importlib
8from unittest.mock import AsyncMock, patch
9
10from think.providers.cli import ThinkingAggregator
11from think.providers.google import _translate_gemini
12from think.providers.shared import JSONEventCallback
13
14
def _google_provider():
    """Import ``think.providers.google`` and return the reloaded module.

    The module is reloaded on every call; the object handed back is whatever
    ``importlib.reload`` returns.
    """
    module = importlib.import_module("think.providers.google")
    return importlib.reload(module)
17
18
def _assert_write_mode_removes_allowed_tools(make_runner):
    """Check the command ``run_cogitate`` builds when ``write`` is true.

    ``make_runner`` is called with no arguments and must return a class that
    is patched in for ``think.providers.google.CLIRunner``. That class is
    read back through its ``last_instance.cmd`` attribute, and the captured
    command must contain ``--yolo`` and must not contain ``--allowed-tools``.
    """
    google = _google_provider()
    runner_cls = make_runner()
    config = {"prompt": "hello", "model": "gemini-2.5-flash", "write": True}

    def _ignore_event(e):
        # Event callback that discards everything it receives.
        return None

    with patch("think.providers.google.CLIRunner", runner_cls):
        asyncio.run(google.run_cogitate(config, _ignore_event))

    captured = runner_cls.last_instance.cmd
    assert "--yolo" in captured
    assert "--allowed-tools" not in captured
32
33
class TestTranslateGemini:
    """Exercise ``_translate_gemini`` per event type and over a full run."""

    def _make_callback(self):
        """Return a ``JSONEventCallback`` and the list it appends events to."""
        recorded = []
        return JSONEventCallback(recorded.append), recorded

    def _make_aggregator(self, cb):
        """Build a ``ThinkingAggregator`` around ``cb`` for gemini-2.5-flash."""
        return ThinkingAggregator(cb, model="gemini-2.5-flash")

    def test_init_returns_session_id(self):
        """An ``init`` event returns its session id and emits no events."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        init_event = {
            "type": "init",
            "timestamp": 1000,
            "session_id": "sess-abc",
            "model": "gemini-2.5-flash",
        }

        session_id = _translate_gemini(init_event, aggregator, callback)

        assert session_id == "sess-abc"
        assert emitted == []

    def test_user_message_ignored(self):
        """A user-role message returns ``None`` and emits no events."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        user_message = {"type": "message", "role": "user", "content": "Hello"}

        outcome = _translate_gemini(user_message, aggregator, callback)

        assert outcome is None
        assert emitted == []

    def test_assistant_delta_accumulates(self):
        """Assistant deltas are buffered in the aggregator, not emitted."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)

        for chunk in ("Hello ", "world"):
            delta_event = {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": chunk,
            }
            _translate_gemini(delta_event, aggregator, callback)

        assert aggregator.flush_as_result() == "Hello world"
        assert emitted == []

    def test_tool_use_flushes_thinking_and_emits_start(self):
        """Buffered text is emitted as ``thinking`` ahead of ``tool_start``."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        aggregator.accumulate("I'll use a tool now.")
        tool_use = {
            "type": "tool_use",
            "timestamp": 2000,
            "tool_name": "read_file",
            "tool_id": "tool-1",
            "parameters": {"path": "/tmp/test.txt"},
        }

        _translate_gemini(tool_use, aggregator, callback)

        assert len(emitted) == 2
        thinking, tool_start = emitted
        assert thinking["event"] == "thinking"
        assert thinking["summary"] == "I'll use a tool now."
        assert tool_start["event"] == "tool_start"
        assert tool_start["tool"] == "read_file"
        assert tool_start["call_id"] == "tool-1"
        assert tool_start["args"] == {"path": "/tmp/test.txt"}
        assert tool_start["raw"] == [tool_use]

    def test_tool_use_no_thinking_if_buffer_empty(self):
        """With nothing buffered, a ``tool_use`` emits only ``tool_start``."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        tool_use = {
            "type": "tool_use",
            "timestamp": 2000,
            "tool_name": "read_file",
            "tool_id": "tool-1",
            "parameters": {},
        }

        _translate_gemini(tool_use, aggregator, callback)

        assert len(emitted) == 1
        (only_event,) = emitted
        assert only_event["event"] == "tool_start"

    def test_tool_result_includes_tool_name(self):
        """tool_end should carry the tool name and args of the prior tool_use."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        pending = {}
        tool_use = {
            "type": "tool_use",
            "tool_name": "read_file",
            "tool_id": "t1",
            "parameters": {"path": "x.py"},
        }
        tool_result = {
            "type": "tool_result",
            "tool_id": "t1",
            "status": "success",
            "output": "contents",
        }

        _translate_gemini(tool_use, aggregator, callback, pending_tools=pending)
        emitted.clear()  # drop the tool_start; only tool_end is under test
        _translate_gemini(tool_result, aggregator, callback, pending_tools=pending)

        assert len(emitted) == 1
        (tool_end,) = emitted
        assert tool_end["event"] == "tool_end"
        assert tool_end["tool"] == "read_file"
        assert tool_end["args"] == {"path": "x.py"}
        assert tool_end["result"] == "contents"
        assert tool_end["call_id"] == "t1"

    def test_tool_result_without_pending(self):
        """tool_end is still emitted without pending_tools; tool is empty."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        tool_result = {
            "type": "tool_result",
            "tool_id": "t1",
            "status": "success",
            "output": "data",
        }

        _translate_gemini(tool_result, aggregator, callback)

        assert len(emitted) == 1
        (tool_end,) = emitted
        assert tool_end["event"] == "tool_end"
        assert tool_end["tool"] == ""

    def test_result_stores_usage(self):
        """A ``result`` event copies token stats into the usage dict."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        usage = {}
        result_event = {
            "type": "result",
            "status": "success",
            "stats": {
                "total_tokens": 1500,
                "input_tokens": 1000,
                "output_tokens": 500,
                "cached": 200,
                "duration_ms": 3000,
                "tool_calls": 2,
            },
        }

        _translate_gemini(result_event, aggregator, callback, usage)

        assert emitted == []
        expected_usage = {
            "input_tokens": 1000,
            "output_tokens": 500,
            "total_tokens": 1500,
            "cached_tokens": 200,
        }
        for key, value in expected_usage.items():
            assert usage[key] == value
        assert "duration_ms" not in usage

    def test_result_no_stats(self):
        """A ``result`` event without ``stats`` leaves the usage dict empty."""
        callback, _emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        usage = {}
        result_event = {"type": "result", "status": "success"}

        _translate_gemini(result_event, aggregator, callback, usage)

        assert usage == {}

    def test_unknown_event_type_ignored(self):
        """An unrecognised event type returns ``None`` and emits nothing."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        unknown_event = {"type": "unknown_type", "data": "whatever"}

        outcome = _translate_gemini(unknown_event, aggregator, callback)

        assert outcome is None
        assert emitted == []

    def test_full_sequence(self):
        """Process a full sequence of Gemini JSONL events."""
        callback, emitted = self._make_callback()
        aggregator = self._make_aggregator(callback)
        usage = {}
        pending = {}

        def assistant_delta(text):
            # Shape shared by every streamed assistant chunk in the sequence.
            return {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": text,
            }

        sequence = [
            {
                "type": "init",
                "session_id": "sess-42",
                "model": "gemini-2.5-flash",
            },
            {"type": "message", "role": "user", "content": "Analyze this file"},
            assistant_delta("I'll read the file. "),
            {
                "type": "tool_use",
                "tool_name": "read_file",
                "tool_id": "t1",
                "parameters": {"path": "test.py"},
            },
            {
                "type": "tool_result",
                "tool_id": "t1",
                "status": "success",
                "output": "print('hello')",
            },
            assistant_delta("The file contains "),
            assistant_delta("a print statement."),
            {
                "type": "result",
                "status": "success",
                "stats": {
                    "total_tokens": 100,
                    "input_tokens": 60,
                    "output_tokens": 40,
                },
            },
        ]

        returned = [
            _translate_gemini(item, aggregator, callback, usage, pending)
            for item in sequence
        ]
        session_ids = [value for value in returned if value]

        assert session_ids == ["sess-42"]

        # Emitted events, in order: thinking, tool_start, tool_end
        assert len(emitted) == 3
        thinking, tool_start, tool_end = emitted
        assert thinking["event"] == "thinking"
        assert "read the file" in thinking["summary"]
        assert tool_start["event"] == "tool_start"
        assert tool_start["tool"] == "read_file"
        assert tool_end["event"] == "tool_end"
        assert tool_end["tool"] == "read_file"
        assert tool_end["call_id"] == "t1"

        # Text streamed after the tool call is still held by the aggregator
        final_text = aggregator.flush_as_result()
        assert final_text == "The file contains a print statement."

        assert usage["total_tokens"] == 100
312
313
class TestRunCogitateCommand:
    """Check the CLI command that ``run_cogitate`` hands to ``CLIRunner``."""

    def _mock_runner(self):
        """Return a fresh stand-in class to patch over ``CLIRunner``.

        Each instance keeps the ``cmd`` and ``prompt_text`` keyword arguments,
        exposes an ``AsyncMock`` ``run`` that resolves to ``"result"``, and
        records itself on the class as ``last_instance``.
        """

        class MockCLIRunner:
            # Most recently constructed instance; read back by the tests.
            last_instance = None

            def __init__(self, **options):
                self.cmd = options["cmd"]
                self.prompt_text = options["prompt_text"]
                self.cli_session_id = "test-session"
                self.run = AsyncMock(return_value="result")
                MockCLIRunner.last_instance = self

        return MockCLIRunner

    def test_yolo_mode_with_sol_allowed(self):
        """Without ``write``, the command has ``--yolo`` and allows only sol."""
        google = _google_provider()
        runner_cls = self._mock_runner()
        config = {"prompt": "hello", "model": "gemini-2.5-flash"}

        with patch("think.providers.google.CLIRunner", runner_cls):
            asyncio.run(google.run_cogitate(config, lambda e: None))

        command = runner_cls.last_instance.cmd
        assert "--yolo" in command
        flag_position = command.index("--allowed-tools")
        assert command[flag_position + 1] == "run_shell_command(sol)"

    def test_write_mode_removes_allowed_tools(self):
        """With ``write`` set, defer to the shared module-level assertion."""
        runner_factory = self._mock_runner
        _assert_write_mode_removes_allowed_tools(runner_factory)