personal memory agent
at main 259 lines 8.3 kB view raw
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

import asyncio
import importlib
import io
import json
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock

from tests.conftest import setup_google_genai_stub
from think.models import GEMINI_FLASH
from think.providers.google import (
    _extract_finish_reason,
    _format_completion_message,
)


async def run_main(mod, argv, stdin_data=None):
    """Run ``mod.main_async()`` with a temporary ``sys.argv`` and stdin.

    Args:
        mod: Module exposing an async ``main_async`` entry point.
        argv: List installed as ``sys.argv`` for the duration of the call.
        stdin_data: Optional text served as ``sys.stdin``; when ``None`` the
            current ``sys.stdin`` is left in place.

    The previous ``sys.argv`` and ``sys.stdin`` are restored afterwards so a
    test's fake CLI input cannot leak into tests that run after it.
    """
    saved_argv, saved_stdin = sys.argv, sys.stdin
    sys.argv = argv
    if stdin_data is not None:
        sys.stdin = io.StringIO(stdin_data)
    try:
        await mod.main_async()
    finally:
        sys.argv, sys.stdin = saved_argv, saved_stdin


def make_mock_process(stdout_lines, return_code=0):
    """Create a mock asyncio subprocess for CLI tests.

    Args:
        stdout_lines: Text lines the fake process emits on stdout; each one
            is UTF-8 encoded and newline-terminated.
        return_code: Value returned by ``await process.wait()``.

    Returns:
        An ``AsyncMock`` with a ``stdout`` that yields ``stdout_lines`` via
        ``readline()`` or ``async for``, an empty ``stderr``, and a ``stdin``
        whose ``write``/``close`` are no-ops.
    """

    class _MockStream:
        """Minimal async byte stream supporting ``readline`` and ``async for``."""

        def __init__(self, lines):
            self._lines = [line.encode() + b"\n" for line in lines]
            self._index = 0

        async def readline(self):
            # b"" signals EOF, mirroring asyncio.StreamReader.readline().
            if self._index >= len(self._lines):
                return b""
            line = self._lines[self._index]
            self._index += 1
            return line

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Every stored line ends in b"\n", so an empty read is always EOF.
            line = await self.readline()
            if not line:
                raise StopAsyncIteration
            return line

    process = AsyncMock()
    process.stdout = _MockStream(stdout_lines)
    process.stderr = _MockStream([])
    process.stdin = AsyncMock()
    process.stdin.write = lambda data: None
    process.stdin.close = lambda: None
    process.wait = AsyncMock(return_value=return_code)
    return process


def _reload_agents_module(monkeypatch):
    """Install the google-genai stub and return a freshly reloaded ``think.agents``.

    ``think.providers.google`` is evicted from ``sys.modules`` and re-imported
    after the stub is installed, then ``think.agents`` is reloaded on top of it.
    """
    setup_google_genai_stub(monkeypatch, with_thinking=False)
    sys.modules.pop("think.providers.google", None)
    importlib.reload(importlib.import_module("think.providers.google"))
    return importlib.reload(importlib.import_module("think.agents"))


def _configure_journal_env(monkeypatch, tmp_path):
    """Point the journal override at a fresh temp dir and set a dummy API key."""
    journal = tmp_path / "journal"
    journal.mkdir()
    monkeypatch.setenv("_SOLSTONE_JOURNAL_OVERRIDE", str(journal))
    monkeypatch.setenv("GOOGLE_API_KEY", "x")


def _hello_request():
    """Return the NDJSON request line both CLI tests feed on stdin."""
    return json.dumps(
        {
            "prompt": "hello",
            "provider": "google",
            "model": GEMINI_FLASH,
            "tools": ["search_insights"],
        }
    )


def _read_events(capsys):
    """Parse every captured stdout line as one JSON event."""
    out_lines = capsys.readouterr().out.strip().splitlines()
    return [json.loads(line) for line in out_lines]


def test_google_main(monkeypatch, tmp_path, capsys):
    mod = _reload_agents_module(monkeypatch)
    _configure_journal_env(monkeypatch, tmp_path)
    monkeypatch.setattr(
        "think.providers.cli.shutil.which",
        lambda name: "/usr/bin/gemini" if name == "gemini" else None,
    )

    stdout_lines = [
        json.dumps(
            {
                "type": "init",
                "timestamp": 100,
                "session_id": "sess-test",
                "model": "gemini-2.5-flash",
            }
        ),
        json.dumps(
            {
                "type": "message",
                "role": "assistant",
                "delta": True,
                "content": "ok",
            }
        ),
        json.dumps(
            {
                "type": "result",
                "timestamp": 200,
                "status": "success",
                "stats": {
                    "total_tokens": 10,
                    "input_tokens": 5,
                    "output_tokens": 5,
                },
            }
        ),
    ]
    process = make_mock_process(stdout_lines)
    monkeypatch.setattr(
        "think.providers.cli.asyncio.create_subprocess_exec",
        AsyncMock(return_value=process),
    )

    asyncio.run(run_main(mod, ["sol agents"], stdin_data=_hello_request()))

    events = _read_events(capsys)
    assert events[0]["event"] == "start"
    assert isinstance(events[0]["ts"], int)
    assert "hello" in events[0]["prompt"]
    assert events[0]["name"] == "unified"
    assert events[0]["model"] == GEMINI_FLASH
    assert events[-1]["event"] == "finish"
    assert isinstance(events[-1]["ts"], int)
    assert events[-1]["result"] == "ok"

    # Journal logging is now handled by cortex, not by agents directly
    # So we don't check for journal files here


def test_google_cli_not_found_error(monkeypatch, tmp_path, capsys):
    mod = _reload_agents_module(monkeypatch)
    _configure_journal_env(monkeypatch, tmp_path)
    # Simulate the gemini CLI binary being absent from PATH.
    monkeypatch.setattr("think.providers.cli.shutil.which", lambda _name: None)

    asyncio.run(run_main(mod, ["sol agents"], stdin_data=_hello_request()))

    # Check stdout for error event
    events = _read_events(capsys)
    assert events[-1]["event"] == "error"
    assert isinstance(events[-1]["ts"], int)
    error_message = events[-1]["error"].lower()
    assert "gemini" in error_message
    assert "not found" in error_message
    assert "trace" in events[-1]


# ---------------------------------------------------------------------------
# Tests for finish reason extraction and formatting
# ---------------------------------------------------------------------------


def test_extract_finish_reason_with_enum():
    """Test extracting finish_reason from enum-style response."""

    class MockEnum:
        name = "STOP"

    candidate = SimpleNamespace(finish_reason=MockEnum())
    response = SimpleNamespace(candidates=[candidate])
    assert _extract_finish_reason(response) == "STOP"


def test_extract_finish_reason_with_string():
    """Test extracting finish_reason when it's already a string."""
    candidate = SimpleNamespace(finish_reason="MAX_TOKENS")
    response = SimpleNamespace(candidates=[candidate])
    assert _extract_finish_reason(response) == "MAX_TOKENS"


def test_extract_finish_reason_no_candidates():
    """Test extracting finish_reason when no candidates exist."""
    response = SimpleNamespace(candidates=[])
    assert _extract_finish_reason(response) is None

    response = SimpleNamespace()
    assert _extract_finish_reason(response) is None


def test_format_completion_message_stop_with_tools():
    """Test message for STOP with tool calls."""
    msg = _format_completion_message("STOP", had_tool_calls=True)
    assert msg == "Completed via tools."


def test_format_completion_message_stop_no_tools():
    """Test message for STOP without tool calls."""
    msg = _format_completion_message("STOP", had_tool_calls=False)
    assert msg == "Completed."


def test_format_completion_message_max_tokens():
    """Test message for MAX_TOKENS finish reason."""
    msg = _format_completion_message("MAX_TOKENS", had_tool_calls=False)
    assert msg == "Reached token limit."


def test_format_completion_message_safety():
    """Test message for safety-related finish reasons."""
    msg = _format_completion_message("SAFETY", had_tool_calls=False)
    assert msg == "Blocked by safety filters."

    msg = _format_completion_message("PROHIBITED_SAFETY", had_tool_calls=False)
    assert msg == "Blocked by safety filters."


def test_format_completion_message_tool_errors():
    """Test message for tool-related error finish reasons."""
    msg = _format_completion_message("UNEXPECTED_TOOL_CALL", had_tool_calls=True)
    assert msg == "Tool execution incomplete."

    msg = _format_completion_message("MALFORMED_FUNCTION_CALL", had_tool_calls=False)
    assert msg == "Tool execution incomplete."


def test_format_completion_message_unknown():
    """Test message for unknown finish reasons."""
    msg = _format_completion_message("SOME_NEW_REASON", had_tool_calls=False)
    assert msg == "Completed (some_new_reason)."


def test_format_completion_message_none():
    """Test message when finish_reason is None."""
    msg = _format_completion_message(None, had_tool_calls=False)
    assert msg == "Completed (unknown)."