personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4import asyncio
5import importlib
6import json
7import sys
8from types import SimpleNamespace
9from unittest.mock import AsyncMock
10
11from tests.conftest import setup_google_genai_stub
12from think.models import GEMINI_FLASH
13from think.providers.google import (
14 _extract_finish_reason,
15 _format_completion_message,
16)
17
18
async def run_main(mod, argv, stdin_data=None):
    """Run ``mod.main_async()`` with a patched ``sys.argv`` and optional stdin.

    Args:
        mod: Object exposing an async ``main_async()`` callable (e.g. a module).
        argv: Value installed as ``sys.argv`` while ``main_async`` runs.
        stdin_data: When not ``None``, text served to the callee through
            ``sys.stdin`` (an empty string means "empty stdin").

    The previous ``sys.argv`` and ``sys.stdin`` are restored afterwards, even
    if ``main_async`` raises, so one test's fake command line or input stream
    cannot leak into later tests.
    """
    import io

    saved_argv, saved_stdin = sys.argv, sys.stdin
    sys.argv = argv
    if stdin_data is not None:
        sys.stdin = io.StringIO(stdin_data)
    try:
        await mod.main_async()
    finally:
        sys.argv = saved_argv
        sys.stdin = saved_stdin
26
27
def make_mock_process(stdout_lines, return_code=0):
    """Create a mock asyncio subprocess for CLI tests.

    Args:
        stdout_lines: Text lines the fake process writes to stdout. Each one is
            encoded and newline-terminated, the way ``readline()`` on a real
            stream would yield it.
        return_code: Exit status reported by ``await process.wait()`` and by
            ``process.returncode``.

    Returns:
        An ``AsyncMock`` standing in for ``asyncio.subprocess.Process``.
        ``stdout`` supports ``readline()`` and ``async for`` over one shared
        cursor. ``stderr`` is always empty. ``stdin.write``/``stdin.close``
        are synchronous no-ops.
    """

    class MockStdout:
        """Replays pre-encoded lines; ``b""`` signals EOF."""

        def __init__(self, lines):
            self._lines = [line.encode() + b"\n" for line in lines]
            self._index = 0

        async def readline(self):
            if self._index >= len(self._lines):
                return b""
            line = self._lines[self._index]
            self._index += 1
            return line

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Delegate to readline() so iteration and readline() share one
            # cursor; every real line ends in b"\n", so only EOF is falsy.
            line = await self.readline()
            if not line:
                raise StopAsyncIteration
            return line

    class MockStderr:
        """A stderr stream that is empty from the start."""

        async def readline(self):
            return b""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    process = AsyncMock()
    process.stdout = MockStdout(stdout_lines)
    process.stderr = MockStderr()
    process.stdin = AsyncMock()
    # Plain (non-async) no-ops, matching the synchronous StreamWriter API.
    process.stdin.write = lambda data: None
    process.stdin.close = lambda: None
    process.wait = AsyncMock(return_value=return_code)
    # Left alone, ``returncode`` would be an auto-created (truthy) mock rather
    # than an int; expose the same exit status that wait() reports.
    process.returncode = return_code
    return process
68
69
def test_google_main(monkeypatch, tmp_path, capsys):
    """Happy path: a stubbed Gemini CLI run yields start and finish events.

    The CLI subprocess is replaced by ``make_mock_process``, which replays an
    init / assistant-message / result stream. The NDJSON events the agents
    module prints to stdout are then checked.
    """
    setup_google_genai_stub(monkeypatch, with_thinking=False)
    # Force fresh imports of the provider and agents modules (presumably so
    # they bind to the genai stub installed above).
    sys.modules.pop("think.providers.google", None)
    importlib.reload(importlib.import_module("think.providers.google"))
    agents = importlib.reload(importlib.import_module("think.agents"))

    journal_dir = tmp_path / "journal"
    journal_dir.mkdir()

    monkeypatch.setenv("_SOLSTONE_JOURNAL_OVERRIDE", str(journal_dir))
    monkeypatch.setenv("GOOGLE_API_KEY", "x")

    def fake_which(name):
        # Only the gemini binary is reported as installed.
        return "/usr/bin/gemini" if name == "gemini" else None

    monkeypatch.setattr("think.providers.cli.shutil.which", fake_which)

    cli_stream = [
        {
            "type": "init",
            "timestamp": 100,
            "session_id": "sess-test",
            "model": "gemini-2.5-flash",
        },
        {"type": "message", "role": "assistant", "delta": True, "content": "ok"},
        {
            "type": "result",
            "timestamp": 200,
            "status": "success",
            "stats": {"total_tokens": 10, "input_tokens": 5, "output_tokens": 5},
        },
    ]
    fake_process = make_mock_process([json.dumps(item) for item in cli_stream])
    monkeypatch.setattr(
        "think.providers.cli.asyncio.create_subprocess_exec",
        AsyncMock(return_value=fake_process),
    )

    request = {
        "prompt": "hello",
        "provider": "google",
        "model": GEMINI_FLASH,
        "tools": ["search_insights"],
    }
    asyncio.run(run_main(agents, ["sol agents"], stdin_data=json.dumps(request)))

    stdout_text = capsys.readouterr().out
    events = [json.loads(raw) for raw in stdout_text.strip().splitlines()]
    start, finish = events[0], events[-1]

    assert start["event"] == "start"
    assert isinstance(start["ts"], int)
    assert "hello" in start["prompt"]
    assert start["name"] == "unified"
    assert start["model"] == GEMINI_FLASH

    assert finish["event"] == "finish"
    assert isinstance(finish["ts"], int)
    assert finish["result"] == "ok"

    # Journal logging is handled by cortex rather than by the agents module,
    # so no journal files are checked here.
145
146
def test_google_cli_not_found_error(monkeypatch, tmp_path, capsys):
    """A missing ``gemini`` binary ends the run with an ``error`` event and a trace."""
    setup_google_genai_stub(monkeypatch, with_thinking=False)

    # Force fresh imports of the provider and agents modules (presumably so
    # they bind to the genai stub installed above).
    sys.modules.pop("think.providers.google", None)
    importlib.reload(importlib.import_module("think.providers.google"))
    agents = importlib.reload(importlib.import_module("think.agents"))

    journal_dir = tmp_path / "journal"
    journal_dir.mkdir()

    monkeypatch.setenv("_SOLSTONE_JOURNAL_OVERRIDE", str(journal_dir))
    monkeypatch.setenv("GOOGLE_API_KEY", "x")

    def which_nothing(_name):
        # No executable can be located on PATH.
        return None

    monkeypatch.setattr("think.providers.cli.shutil.which", which_nothing)

    request = {
        "prompt": "hello",
        "provider": "google",
        "model": GEMINI_FLASH,
        "tools": ["search_insights"],
    }
    asyncio.run(run_main(agents, ["sol agents"], stdin_data=json.dumps(request)))

    # The error event is expected as the final NDJSON line on stdout.
    stdout_text = capsys.readouterr().out
    last = [json.loads(raw) for raw in stdout_text.strip().splitlines()][-1]
    assert last["event"] == "error"
    assert isinstance(last["ts"], int)
    error_message = last["error"].lower()
    for fragment in ("gemini", "not found"):
        assert fragment in error_message
    assert "trace" in last
180
181
182# ---------------------------------------------------------------------------
183# Tests for finish reason extraction and formatting
184# ---------------------------------------------------------------------------
185
186
def test_extract_finish_reason_with_enum():
    """An enum-like finish_reason is reported by its ``name`` attribute."""

    class _FakeFinishReason:
        name = "STOP"

    response = SimpleNamespace(
        candidates=[SimpleNamespace(finish_reason=_FakeFinishReason())]
    )
    assert _extract_finish_reason(response) == "STOP"
196
197
def test_extract_finish_reason_with_string():
    """A finish_reason that is already a string is returned unchanged."""
    only_candidate = SimpleNamespace(finish_reason="MAX_TOKENS")
    extracted = _extract_finish_reason(SimpleNamespace(candidates=[only_candidate]))
    assert extracted == "MAX_TOKENS"
203
204
def test_extract_finish_reason_no_candidates():
    """Both an empty ``candidates`` list and a missing attribute yield None."""
    empty_list = SimpleNamespace(candidates=[])
    no_attribute = SimpleNamespace()
    for response in (empty_list, no_attribute):
        assert _extract_finish_reason(response) is None
212
213
def test_format_completion_message_stop_with_tools():
    """STOP after at least one tool call is reported as a tool-driven finish."""
    assert (
        _format_completion_message("STOP", had_tool_calls=True)
        == "Completed via tools."
    )
218
219
def test_format_completion_message_stop_no_tools():
    """STOP with no tool calls produces the plain completion message."""
    expected = "Completed."
    assert _format_completion_message("STOP", had_tool_calls=False) == expected
224
225
def test_format_completion_message_max_tokens():
    """MAX_TOKENS is reported as hitting the token limit."""
    summary = _format_completion_message("MAX_TOKENS", had_tool_calls=False)
    assert summary == "Reached token limit."
230
231
def test_format_completion_message_safety():
    """Every safety-related finish reason maps to the same blocked message."""
    for reason in ("SAFETY", "PROHIBITED_SAFETY"):
        summary = _format_completion_message(reason, had_tool_calls=False)
        assert summary == "Blocked by safety filters."
239
240
def test_format_completion_message_tool_errors():
    """Tool-related error reasons report incomplete tool execution.

    Checked with and without prior tool calls.
    """
    cases = [
        ("UNEXPECTED_TOOL_CALL", True),
        ("MALFORMED_FUNCTION_CALL", False),
    ]
    for reason, used_tools in cases:
        summary = _format_completion_message(reason, had_tool_calls=used_tools)
        assert summary == "Tool execution incomplete."
248
249
def test_format_completion_message_unknown():
    """An unrecognized reason is echoed back in lowercase inside the message."""
    summary = _format_completion_message("SOME_NEW_REASON", had_tool_calls=False)
    assert summary == "Completed (some_new_reason)."
254
255
def test_format_completion_message_none():
    """A missing (None) finish reason is reported as unknown."""
    summary = _format_completion_message(None, had_tool_calls=False)
    assert summary == "Completed (unknown)."