personal memory agent
at main 644 lines 22 kB view raw
1# SPDX-License-Identifier: AGPL-3.0-only 2# Copyright (c) 2026 sol pbc 3 4"""Tests for think.providers.cli — CLI subprocess runner infrastructure.""" 5 6import asyncio 7import os 8from unittest.mock import AsyncMock, patch 9 10import pytest 11 12from think.providers.cli import ( 13 CLIRunner, 14 ThinkingAggregator, 15 assemble_prompt, 16 build_cogitate_env, 17) 18from think.providers.shared import JSONEventCallback, safe_raw 19 20# --------------------------------------------------------------------------- 21# assemble_prompt 22# --------------------------------------------------------------------------- 23 24 25class TestAssemblePrompt: 26 def test_all_fields(self): 27 config = { 28 "transcript": "Speaker A: hello", 29 "extra_context": "Today is Monday", 30 "user_instruction": "Summarize the transcript", 31 "prompt": "What happened?", 32 "system_instruction": "You are a helpful assistant", 33 } 34 body, system = assemble_prompt(config) 35 assert "Speaker A: hello" in body 36 assert "Today is Monday" in body 37 assert "Summarize the transcript" in body 38 assert "What happened?" in body 39 assert system == "You are a helpful assistant" 40 # Parts joined with double newlines 41 assert body.count("\n\n") == 3 42 43 def test_prompt_only(self): 44 config = {"prompt": "hello"} 45 body, system = assemble_prompt(config) 46 assert body == "hello" 47 assert system is None 48 49 def test_empty_config(self): 50 body, system = assemble_prompt({}) 51 assert body == "" 52 assert system is None 53 54 def test_skips_empty_values(self): 55 config = { 56 "transcript": "", 57 "extra_context": None, 58 "user_instruction": "Do something", 59 "prompt": "Go", 60 } 61 body, system = assemble_prompt(config) 62 assert body == "Do something\n\nGo" 63 assert system is None 64 65 def test_system_instruction_empty_string(self): 66 config = {"prompt": "test", "system_instruction": ""} 67 _, system = assemble_prompt(config) 68 assert system is None 69 70 71# --------------------------------------------------------------------------- 72# ThinkingAggregator 73# --------------------------------------------------------------------------- 74 75 76class TestThinkingAggregator: 77 def _make_aggregator(self): 78 """Create aggregator with event capture.""" 79 events = [] 80 cb = JSONEventCallback(events.append) 81 agg = ThinkingAggregator(cb, model="test-model") 82 return agg, events 83 84 def test_accumulate_and_flush_as_thinking(self): 85 agg, events = self._make_aggregator() 86 agg.accumulate("hello ") 87 agg.accumulate("world") 88 agg.flush_as_thinking(raw_events=[{"type": "message"}]) 89 90 assert len(events) == 1 91 assert events[0]["event"] == "thinking" 92 assert events[0]["summary"] == "hello world" 93 assert events[0]["model"] == "test-model" 94 assert events[0]["raw"] == [{"type": "message"}] 95 96 def test_flush_thinking_empty_buffer_is_noop(self): 97 agg, events = self._make_aggregator() 98 agg.flush_as_thinking() 99 assert len(events) == 0 100 101 def test_flush_thinking_whitespace_only_is_noop(self): 102 agg, events = self._make_aggregator() 103 agg.accumulate(" ") 104 agg.flush_as_thinking() 105 assert len(events) == 0 106 107 def test_flush_as_result(self): 108 agg, events = self._make_aggregator() 109 agg.accumulate("final answer") 110 result = agg.flush_as_result() 111 assert result == "final answer" 112 # No events emitted for result flush 113 assert len(events) == 0 114 # Buffer is cleared 115 assert agg.flush_as_result() == "" 116 117 def test_multiple_thinking_flushes(self): 118 """Simulate text -> tool -> text -> tool -> text pattern.""" 119 agg, events = self._make_aggregator() 120 121 # First text chunk (before first tool call) 122 agg.accumulate("Let me check...") 123 agg.flush_as_thinking() 124 125 # Second text chunk (between tool calls) 126 agg.accumulate("Now let me verify...") 127 agg.flush_as_thinking() 128 129 # Final text (the result) 130 agg.accumulate("The answer is 42") 131 result = agg.flush_as_result() 132 133 assert len(events) == 2 134 assert events[0]["summary"] == "Let me check..." 135 assert events[1]["summary"] == "Now let me verify..." 136 assert result == "The answer is 42" 137 138 def test_has_content(self): 139 agg, _ = self._make_aggregator() 140 assert not agg.has_content 141 agg.accumulate("x") 142 assert agg.has_content 143 agg.flush_as_result() 144 assert not agg.has_content 145 146 def test_no_raw_events(self): 147 agg, events = self._make_aggregator() 148 agg.accumulate("thinking") 149 agg.flush_as_thinking() 150 assert "raw" not in events[0] 151 152 def test_strips_whitespace(self): 153 agg, events = self._make_aggregator() 154 agg.accumulate(" padded ") 155 agg.flush_as_thinking() 156 assert events[0]["summary"] == "padded" 157 158 159class _MockStderr: 160 """Async iterator yielding pre-set stderr lines.""" 161 162 def __init__(self, lines: list[bytes]): 163 self._lines = lines 164 self._index = 0 165 166 def __aiter__(self): 167 return self 168 169 async def __anext__(self): 170 if self._index >= len(self._lines): 171 raise StopAsyncIteration 172 line = self._lines[self._index] 173 self._index += 1 174 return line 175 176 177class _MockStdout: 178 """Async iterator yielding pre-set stdout lines, with readline support.""" 179 180 def __init__(self, lines: list[bytes]): 181 self._lines = lines 182 self._index = 0 183 184 async def readline(self): 185 if self._index >= len(self._lines): 186 return b"" 187 line = self._lines[self._index] 188 self._index += 1 189 return line 190 191 def __aiter__(self): 192 return self 193 194 async def __anext__(self): 195 if self._index >= len(self._lines): 196 raise StopAsyncIteration 197 line = self._lines[self._index] 198 self._index += 1 199 return line 200 201 202def _make_process(stdout_lines, stderr_lines, return_code): 203 """Create a mock process with given stdout/stderr/exit code.""" 204 process = AsyncMock() 205 process.stdout = _MockStdout(stdout_lines) 206 process.stderr = _MockStderr(stderr_lines) 207 process.stdin = AsyncMock() 208 process.stdin.write = lambda _data: None 209 process.stdin.close = lambda: None 210 process.kill = lambda: None 211 process.wait = AsyncMock(return_value=return_code) 212 return process 213 214 215class TestCLIRunnerExitCode: 216 """Tests for CLIRunner handling of non-zero exit codes.""" 217 218 def test_nonzero_exit_no_output_raises(self): 219 """CLI exits with error and no result → RuntimeError with stderr.""" 220 events = [] 221 callback = JSONEventCallback(events.append) 222 aggregator = ThinkingAggregator(callback, model="test-model") 223 224 process = _make_process( 225 stdout_lines=[], 226 stderr_lines=[b"TerminalQuotaError: quota exhausted\n"], 227 return_code=1, 228 ) 229 230 runner = CLIRunner( 231 cmd=["fakecli", "--json"], 232 prompt_text="test", 233 translate=lambda _e, _a, _c: None, 234 callback=callback, 235 aggregator=aggregator, 236 ) 237 238 with ( 239 patch( 240 "think.providers.cli.asyncio.create_subprocess_exec", 241 AsyncMock(return_value=process), 242 ), 243 patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"), 244 pytest.raises(RuntimeError, match="quota exhausted"), 245 ): 246 asyncio.run(runner.run()) 247 248 # CLIRunner should NOT emit error events — that's the caller's job 249 error_events = [e for e in events if e.get("event") == "error"] 250 assert len(error_events) == 0 251 252 def test_nonzero_exit_with_output_returns_result(self): 253 """CLI exits with error but produced output → return result + warning.""" 254 events = [] 255 callback = JSONEventCallback(events.append) 256 aggregator = ThinkingAggregator(callback, model="test-model") 257 258 # translate accumulates text from stdout events 259 def translate(event, agg, cb): 260 if event.get("type") == "text": 261 agg.accumulate(event["content"]) 262 return None 263 264 process = _make_process( 265 stdout_lines=[b'{"type": "text", "content": "The answer is 42"}\n'], 266 stderr_lines=[b"Warning: something went wrong\n"], 267 return_code=1, 268 ) 269 270 runner = CLIRunner( 271 cmd=["fakecli", "--json"], 272 prompt_text="test", 273 translate=translate, 274 callback=callback, 275 aggregator=aggregator, 276 ) 277 278 with ( 279 patch( 280 "think.providers.cli.asyncio.create_subprocess_exec", 281 AsyncMock(return_value=process), 282 ), 283 patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"), 284 ): 285 result = asyncio.run(runner.run()) 286 287 assert result == "The answer is 42" 288 warning_events = [e for e in events if e.get("event") == "warning"] 289 assert len(warning_events) == 1 290 assert "code 1" in warning_events[0]["message"] 291 assert "something went wrong" in warning_events[0]["stderr"] 292 293 def test_zero_exit_empty_result_ok(self): 294 """CLI exits 0 with no output → return empty string, no error.""" 295 events = [] 296 callback = JSONEventCallback(events.append) 297 aggregator = ThinkingAggregator(callback, model="test-model") 298 299 process = _make_process( 300 stdout_lines=[], 301 stderr_lines=[], 302 return_code=0, 303 ) 304 305 runner = CLIRunner( 306 cmd=["fakecli", "--json"], 307 prompt_text="test", 308 translate=lambda _e, _a, _c: None, 309 callback=callback, 310 aggregator=aggregator, 311 ) 312 313 with ( 314 patch( 315 "think.providers.cli.asyncio.create_subprocess_exec", 316 AsyncMock(return_value=process), 317 ), 318 patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"), 319 ): 320 result = asyncio.run(runner.run()) 321 322 assert result == "" 323 assert not [e for e in events if e.get("event") in ("error", "warning")] 324 325 326class TestCLIRunnerFirstEventTimeout: 327 def test_first_event_timeout_includes_stderr(self): 328 events = [] 329 callback = JSONEventCallback(events.append) 330 aggregator = ThinkingAggregator(callback, model="test-model") 331 332 class HangingStdout: 333 async def readline(self): 334 future = asyncio.get_running_loop().create_future() 335 return await future 336 337 process = _make_process([], [b"Please authenticate first\n"], 0) 338 process.stdout = HangingStdout() # Override with hanging version 339 340 runner = CLIRunner( 341 cmd=["fakecli", "--json"], 342 prompt_text="test prompt", 343 translate=lambda _event, _agg, _cb: None, 344 callback=callback, 345 aggregator=aggregator, 346 timeout=5, 347 first_event_timeout=0.1, 348 ) 349 350 with ( 351 patch( 352 "think.providers.cli.asyncio.create_subprocess_exec", 353 AsyncMock(return_value=process), 354 ), 355 patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"), 356 pytest.raises(RuntimeError) as exc_info, 357 ): 358 asyncio.run(runner.run()) 359 360 message = str(exc_info.value) 361 assert "authenticate" in message.lower() 362 assert "Check that the CLI tool is installed and authenticated." in message 363 364 error_events = [event for event in events if event.get("event") == "error"] 365 assert len(error_events) == 1 366 assert "Please authenticate first" in error_events[0]["error"] 367 368 369_OVERSIZE = object() # sentinel for oversize line in _MockStdoutWithOversize 370 371 372class _MockStdoutWithOversize: 373 """Stdout mock that raises LimitOverrunError on a specific readline() call.""" 374 375 def __init__(self, lines: list): 376 # lines entries are either bytes or the sentinel OVERSIZE 377 self._lines = lines 378 self._index = 0 379 self._draining_oversize = False 380 381 async def readline(self): 382 if self._draining_oversize: 383 self._draining_oversize = False 384 return b"x" * 1024 * 1024 + b"\n" 385 if self._index >= len(self._lines): 386 return b"" 387 entry = self._lines[self._index] 388 self._index += 1 389 if entry is _OVERSIZE: 390 self._draining_oversize = True 391 raise asyncio.LimitOverrunError( 392 "Separator is not found, and chunk exceed the limit", 1024 * 1024 393 ) 394 return entry 395 396 async def readexactly(self, n: int) -> bytes: 397 return b"x" * n 398 399 def __aiter__(self): 400 return self 401 402 async def __anext__(self): 403 val = await self.readline() 404 if val == b"": 405 raise StopAsyncIteration 406 return val 407 408 409class TestCLIRunnerOversizedOutput: 410 """CLIRunner recovers from LimitOverrunError in the stdout loop.""" 411 412 def test_oversized_line_emits_tool_end_and_continues(self): 413 """Oversize line → synthetic tool_end emitted + subsequent line processed.""" 414 import json 415 416 normal_line_1 = json.dumps({"event": "text", "text": "hello"}).encode() + b"\n" 417 normal_line_2 = json.dumps({"event": "text", "text": "world"}).encode() + b"\n" 418 419 events = [] 420 callback = JSONEventCallback(events.append) 421 aggregator = ThinkingAggregator(callback, model="test-model") 422 423 process = AsyncMock() 424 process.stdout = _MockStdoutWithOversize( 425 [ 426 normal_line_1, 427 _OVERSIZE, 428 normal_line_2, 429 ] 430 ) 431 process.stderr = _MockStderr([]) 432 process.stdin = AsyncMock() 433 process.stdin.write = lambda _data: None 434 process.stdin.close = lambda: None 435 process.kill = lambda: None 436 process.wait = AsyncMock(return_value=0) 437 438 # translate just forwards text events as-is 439 def translate(event_data, agg, cb): 440 if event_data.get("event") == "text": 441 cb.emit({"event": "text", "text": event_data["text"]}) 442 return None 443 444 runner = CLIRunner( 445 cmd=["fakecli", "--json"], 446 prompt_text="test", 447 translate=translate, 448 callback=callback, 449 aggregator=aggregator, 450 ) 451 452 with ( 453 patch( 454 "think.providers.cli.asyncio.create_subprocess_exec", 455 AsyncMock(return_value=process), 456 ), 457 patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"), 458 ): 459 asyncio.run(runner.run()) 460 461 event_types = [e["event"] for e in events] 462 # tool_end should be emitted 463 assert "tool_end" in event_types, f"Expected tool_end in events: {events}" 464 465 # the tool_end result should indicate truncation 466 tool_end_events = [e for e in events if e["event"] == "tool_end"] 467 assert len(tool_end_events) == 1 468 assert "truncated" in tool_end_events[0]["result"] 469 470 # the normal line after the oversize error should also be processed 471 text_events = [e for e in events if e["event"] == "text"] 472 texts = [e["text"] for e in text_events] 473 assert "world" in texts, f"Expected 'world' in text events: {texts}" 474 475 476# --------------------------------------------------------------------------- 477# safe_raw 478# --------------------------------------------------------------------------- 479 480 481class TestSafeRaw: 482 def test_small_event_returned_unchanged(self): 483 events = [{"type": "tool_use", "tool_name": "read_file", "tool_id": "t1"}] 484 assert safe_raw(events) is events 485 486 def test_large_event_trimmed(self): 487 big_output = "x" * 20_000 488 events = [ 489 { 490 "type": "tool_result", 491 "tool_id": "t1", 492 "output": big_output, 493 "extra_field": "value", 494 } 495 ] 496 result = safe_raw(events) 497 assert result is not events 498 # Should keep only structural keys 499 assert result[0] == {"type": "tool_result", "tool_id": "t1"} 500 # Last element is the trimmed metadata 501 meta = result[-1]["_raw_trimmed"] 502 assert meta["limit"] == 16_384 503 assert meta["original_bytes"] > 16_384 504 505 def test_custom_limit(self): 506 events = [{"type": "message", "content": "a" * 200}] 507 # Under custom limit 508 assert safe_raw(events, limit=1024) is events 509 # Over custom limit 510 result = safe_raw(events, limit=50) 511 assert result is not events 512 assert result[-1]["_raw_trimmed"]["limit"] == 50 513 514 def test_structural_keys_preserved(self): 515 events = [ 516 { 517 "type": "tool_use", 518 "id": "abc", 519 "tool_id": "t1", 520 "tool_name": "search", 521 "role": "assistant", 522 "event_type": "message", 523 "timestamp": 12345, 524 "big_content": "z" * 20_000, 525 } 526 ] 527 result = safe_raw(events) 528 kept = result[0] 529 assert kept == { 530 "type": "tool_use", 531 "id": "abc", 532 "tool_id": "t1", 533 "tool_name": "search", 534 "role": "assistant", 535 "event_type": "message", 536 "timestamp": 12345, 537 } 538 539 def test_multiple_events(self): 540 events = [ 541 {"type": "msg", "data": "a" * 10_000}, 542 {"type": "msg", "data": "b" * 10_000}, 543 ] 544 result = safe_raw(events) 545 assert len(result) == 3 # 2 trimmed events + 1 metadata 546 assert result[0] == {"type": "msg"} 547 assert result[1] == {"type": "msg"} 548 assert "_raw_trimmed" in result[2] 549 550 551# --------------------------------------------------------------------------- 552# build_cogitate_env 553# --------------------------------------------------------------------------- 554 555 556class TestBuildCogitateEnv: 557 """Tests for build_cogitate_env — API key stripping for CLI subprocesses.""" 558 559 def test_default_strips_key(self): 560 """No auth config → default platform mode → key removed.""" 561 config = {"providers": {}} 562 with ( 563 patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-secret"}, clear=False), 564 patch("think.utils.get_config", return_value=config), 565 ): 566 env = build_cogitate_env("ANTHROPIC_API_KEY") 567 assert "ANTHROPIC_API_KEY" not in env 568 569 def test_explicit_platform_strips_key(self): 570 """auth.anthropic = "platform" → key removed.""" 571 config = {"providers": {"auth": {"anthropic": "platform"}}} 572 with ( 573 patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-secret"}, clear=False), 574 patch("think.utils.get_config", return_value=config), 575 ): 576 env = build_cogitate_env("ANTHROPIC_API_KEY") 577 assert "ANTHROPIC_API_KEY" not in env 578 579 def test_api_key_mode_preserves_key(self): 580 """auth.anthropic = "api_key" → key preserved.""" 581 config = {"providers": {"auth": {"anthropic": "api_key"}}} 582 with ( 583 patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-secret"}, clear=False), 584 patch("think.utils.get_config", return_value=config), 585 ): 586 env = build_cogitate_env("ANTHROPIC_API_KEY") 587 assert env["ANTHROPIC_API_KEY"] == "sk-secret" 588 589 def test_missing_auth_section_strips_key(self): 590 """No providers section at all → safe default, key removed.""" 591 config = {} 592 with ( 593 patch.dict(os.environ, {"OPENAI_API_KEY": "sk-openai"}, clear=False), 594 patch("think.utils.get_config", return_value=config), 595 ): 596 env = build_cogitate_env("OPENAI_API_KEY") 597 assert "OPENAI_API_KEY" not in env 598 599 def test_other_env_vars_preserved(self): 600 """Non-API-key vars are never stripped.""" 601 config = {"providers": {}} 602 with ( 603 patch.dict( 604 os.environ, 605 {"ANTHROPIC_API_KEY": "sk-secret", "HOME": "/home/test"}, 606 clear=False, 607 ), 608 patch("think.utils.get_config", return_value=config), 609 ): 610 env = build_cogitate_env("ANTHROPIC_API_KEY") 611 assert env["HOME"] == "/home/test" 612 613 def test_key_not_in_env_is_harmless(self): 614 """Stripping a key that doesn't exist doesn't error.""" 615 config = {"providers": {}} 616 with ( 617 patch.dict(os.environ, {}, clear=False), 618 patch("think.utils.get_config", return_value=config), 619 ): 620 env = build_cogitate_env("GOOGLE_API_KEY") 621 assert "GOOGLE_API_KEY" not in env 622 623 def test_per_provider_independence(self): 624 """Each provider's auth mode is independent.""" 625 config = { 626 "providers": { 627 "auth": { 628 "anthropic": "api_key", 629 "openai": "platform", 630 } 631 } 632 } 633 with ( 634 patch.dict( 635 os.environ, 636 {"ANTHROPIC_API_KEY": "sk-ant", "OPENAI_API_KEY": "sk-oai"}, 637 clear=False, 638 ), 639 patch("think.utils.get_config", return_value=config), 640 ): 641 ant_env = build_cogitate_env("ANTHROPIC_API_KEY") 642 oai_env = build_cogitate_env("OPENAI_API_KEY") 643 assert ant_env["ANTHROPIC_API_KEY"] == "sk-ant" 644 assert "OPENAI_API_KEY" not in oai_env