# personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4"""Tests for think.providers.cli — CLI subprocess runner infrastructure."""
5
import asyncio
import json
import os
from unittest.mock import AsyncMock, patch

import pytest

from think.providers.cli import (
    CLIRunner,
    ThinkingAggregator,
    assemble_prompt,
    build_cogitate_env,
)
from think.providers.shared import JSONEventCallback, safe_raw
19
20# ---------------------------------------------------------------------------
21# assemble_prompt
22# ---------------------------------------------------------------------------
23
24
class TestAssemblePrompt:
    def test_all_fields(self):
        cfg = {
            "transcript": "Speaker A: hello",
            "extra_context": "Today is Monday",
            "user_instruction": "Summarize the transcript",
            "prompt": "What happened?",
            "system_instruction": "You are a helpful assistant",
        }
        body, system = assemble_prompt(cfg)
        # Every configured part must appear in the assembled body.
        for fragment in (
            "Speaker A: hello",
            "Today is Monday",
            "Summarize the transcript",
            "What happened?",
        ):
            assert fragment in body
        assert system == "You are a helpful assistant"
        # Four parts joined with double newlines → exactly three separators.
        assert body.count("\n\n") == 3

    def test_prompt_only(self):
        body, system = assemble_prompt({"prompt": "hello"})
        assert (body, system) == ("hello", None)

    def test_empty_config(self):
        body, system = assemble_prompt({})
        assert (body, system) == ("", None)

    def test_skips_empty_values(self):
        cfg = {
            "transcript": "",
            "extra_context": None,
            "user_instruction": "Do something",
            "prompt": "Go",
        }
        body, system = assemble_prompt(cfg)
        # Empty/None parts are dropped; the remaining two are joined.
        assert body == "Do something\n\nGo"
        assert system is None

    def test_system_instruction_empty_string(self):
        _, system = assemble_prompt({"prompt": "test", "system_instruction": ""})
        assert system is None
69
70
71# ---------------------------------------------------------------------------
72# ThinkingAggregator
73# ---------------------------------------------------------------------------
74
75
class TestThinkingAggregator:
    def _setup(self):
        """Build an aggregator wired to a list that captures emitted events."""
        captured = []
        agg = ThinkingAggregator(JSONEventCallback(captured.append), model="test-model")
        return agg, captured

    def test_accumulate_and_flush_as_thinking(self):
        agg, captured = self._setup()
        agg.accumulate("hello ")
        agg.accumulate("world")
        agg.flush_as_thinking(raw_events=[{"type": "message"}])

        assert len(captured) == 1
        event = captured[0]
        assert event["event"] == "thinking"
        assert event["summary"] == "hello world"
        assert event["model"] == "test-model"
        assert event["raw"] == [{"type": "message"}]

    def test_flush_thinking_empty_buffer_is_noop(self):
        agg, captured = self._setup()
        agg.flush_as_thinking()
        assert captured == []

    def test_flush_thinking_whitespace_only_is_noop(self):
        agg, captured = self._setup()
        agg.accumulate(" ")
        agg.flush_as_thinking()
        assert captured == []

    def test_flush_as_result(self):
        agg, captured = self._setup()
        agg.accumulate("final answer")
        assert agg.flush_as_result() == "final answer"
        # A result flush emits no events...
        assert captured == []
        # ...and clears the buffer for the next round.
        assert agg.flush_as_result() == ""

    def test_multiple_thinking_flushes(self):
        """Simulate text -> tool -> text -> tool -> text pattern."""
        agg, captured = self._setup()

        # Two intermediate text chunks, each flushed as "thinking".
        for chunk in ("Let me check...", "Now let me verify..."):
            agg.accumulate(chunk)
            agg.flush_as_thinking()

        # Final text is the result.
        agg.accumulate("The answer is 42")
        final = agg.flush_as_result()

        assert [e["summary"] for e in captured] == [
            "Let me check...",
            "Now let me verify...",
        ]
        assert final == "The answer is 42"

    def test_has_content(self):
        agg, _ = self._setup()
        assert not agg.has_content
        agg.accumulate("x")
        assert agg.has_content
        agg.flush_as_result()
        assert not agg.has_content

    def test_no_raw_events(self):
        agg, captured = self._setup()
        agg.accumulate("thinking")
        agg.flush_as_thinking()
        assert "raw" not in captured[0]

    def test_strips_whitespace(self):
        agg, captured = self._setup()
        agg.accumulate(" padded ")
        agg.flush_as_thinking()
        assert captured[0]["summary"] == "padded"
157
158
159class _MockStderr:
160 """Async iterator yielding pre-set stderr lines."""
161
162 def __init__(self, lines: list[bytes]):
163 self._lines = lines
164 self._index = 0
165
166 def __aiter__(self):
167 return self
168
169 async def __anext__(self):
170 if self._index >= len(self._lines):
171 raise StopAsyncIteration
172 line = self._lines[self._index]
173 self._index += 1
174 return line
175
176
177class _MockStdout:
178 """Async iterator yielding pre-set stdout lines, with readline support."""
179
180 def __init__(self, lines: list[bytes]):
181 self._lines = lines
182 self._index = 0
183
184 async def readline(self):
185 if self._index >= len(self._lines):
186 return b""
187 line = self._lines[self._index]
188 self._index += 1
189 return line
190
191 def __aiter__(self):
192 return self
193
194 async def __anext__(self):
195 if self._index >= len(self._lines):
196 raise StopAsyncIteration
197 line = self._lines[self._index]
198 self._index += 1
199 return line
200
201
def _make_process(stdout_lines, stderr_lines, return_code):
    """Build a mocked subprocess exposing stdout/stderr iterators and exit code."""
    proc = AsyncMock()
    proc.stdout = _MockStdout(stdout_lines)
    proc.stderr = _MockStderr(stderr_lines)
    # stdin write/close and kill are synchronous no-ops; wait() resolves to the
    # requested exit code.
    stdin = AsyncMock()
    stdin.write = lambda _data: None
    stdin.close = lambda: None
    proc.stdin = stdin
    proc.kill = lambda: None
    proc.wait = AsyncMock(return_value=return_code)
    return proc
213
214
class TestCLIRunnerExitCode:
    """Tests for CLIRunner handling of non-zero exit codes."""

    def test_nonzero_exit_no_output_raises(self):
        """CLI exits with error and no result → RuntimeError with stderr."""
        captured = []
        cb = JSONEventCallback(captured.append)
        agg = ThinkingAggregator(cb, model="test-model")

        proc = _make_process(
            stdout_lines=[],
            stderr_lines=[b"TerminalQuotaError: quota exhausted\n"],
            return_code=1,
        )

        runner = CLIRunner(
            cmd=["fakecli", "--json"],
            prompt_text="test",
            translate=lambda _e, _a, _c: None,
            callback=cb,
            aggregator=agg,
        )

        with (
            patch(
                "think.providers.cli.asyncio.create_subprocess_exec",
                AsyncMock(return_value=proc),
            ),
            patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"),
            pytest.raises(RuntimeError, match="quota exhausted"),
        ):
            asyncio.run(runner.run())

        # CLIRunner should NOT emit error events — that's the caller's job
        assert [e for e in captured if e.get("event") == "error"] == []

    def test_nonzero_exit_with_output_returns_result(self):
        """CLI exits with error but produced output → return result + warning."""
        captured = []
        cb = JSONEventCallback(captured.append)
        agg = ThinkingAggregator(cb, model="test-model")

        # translate accumulates text from stdout events
        def translate(event, aggregator, callback):
            if event.get("type") == "text":
                aggregator.accumulate(event["content"])
            return None

        proc = _make_process(
            stdout_lines=[b'{"type": "text", "content": "The answer is 42"}\n'],
            stderr_lines=[b"Warning: something went wrong\n"],
            return_code=1,
        )

        runner = CLIRunner(
            cmd=["fakecli", "--json"],
            prompt_text="test",
            translate=translate,
            callback=cb,
            aggregator=agg,
        )

        with (
            patch(
                "think.providers.cli.asyncio.create_subprocess_exec",
                AsyncMock(return_value=proc),
            ),
            patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"),
        ):
            outcome = asyncio.run(runner.run())

        assert outcome == "The answer is 42"
        warnings = [e for e in captured if e.get("event") == "warning"]
        assert len(warnings) == 1
        assert "code 1" in warnings[0]["message"]
        assert "something went wrong" in warnings[0]["stderr"]

    def test_zero_exit_empty_result_ok(self):
        """CLI exits 0 with no output → return empty string, no error."""
        captured = []
        cb = JSONEventCallback(captured.append)
        agg = ThinkingAggregator(cb, model="test-model")

        proc = _make_process(
            stdout_lines=[],
            stderr_lines=[],
            return_code=0,
        )

        runner = CLIRunner(
            cmd=["fakecli", "--json"],
            prompt_text="test",
            translate=lambda _e, _a, _c: None,
            callback=cb,
            aggregator=agg,
        )

        with (
            patch(
                "think.providers.cli.asyncio.create_subprocess_exec",
                AsyncMock(return_value=proc),
            ),
            patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"),
        ):
            outcome = asyncio.run(runner.run())

        assert outcome == ""
        assert not [e for e in captured if e.get("event") in ("error", "warning")]
324
325
class TestCLIRunnerFirstEventTimeout:
    def test_first_event_timeout_includes_stderr(self):
        captured = []
        cb = JSONEventCallback(captured.append)
        agg = ThinkingAggregator(cb, model="test-model")

        class _NeverReadyStdout:
            async def readline(self):
                # Await a future that is never resolved → readline hangs forever,
                # forcing the first-event timeout to fire.
                return await asyncio.get_running_loop().create_future()

        proc = _make_process([], [b"Please authenticate first\n"], 0)
        proc.stdout = _NeverReadyStdout()  # Override with hanging version

        runner = CLIRunner(
            cmd=["fakecli", "--json"],
            prompt_text="test prompt",
            translate=lambda _event, _agg, _cb: None,
            callback=cb,
            aggregator=agg,
            timeout=5,
            first_event_timeout=0.1,
        )

        with (
            patch(
                "think.providers.cli.asyncio.create_subprocess_exec",
                AsyncMock(return_value=proc),
            ),
            patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"),
            pytest.raises(RuntimeError) as exc_info,
        ):
            asyncio.run(runner.run())

        text = str(exc_info.value)
        assert "authenticate" in text.lower()
        assert "Check that the CLI tool is installed and authenticated." in text

        errors = [e for e in captured if e.get("event") == "error"]
        assert len(errors) == 1
        assert "Please authenticate first" in errors[0]["error"]
367
368
# Sentinel placed in a _MockStdoutWithOversize line list to mark the position
# where readline() should raise asyncio.LimitOverrunError instead of a line.
_OVERSIZE = object()  # sentinel for oversize line in _MockStdoutWithOversize
370
371
372class _MockStdoutWithOversize:
373 """Stdout mock that raises LimitOverrunError on a specific readline() call."""
374
375 def __init__(self, lines: list):
376 # lines entries are either bytes or the sentinel OVERSIZE
377 self._lines = lines
378 self._index = 0
379 self._draining_oversize = False
380
381 async def readline(self):
382 if self._draining_oversize:
383 self._draining_oversize = False
384 return b"x" * 1024 * 1024 + b"\n"
385 if self._index >= len(self._lines):
386 return b""
387 entry = self._lines[self._index]
388 self._index += 1
389 if entry is _OVERSIZE:
390 self._draining_oversize = True
391 raise asyncio.LimitOverrunError(
392 "Separator is not found, and chunk exceed the limit", 1024 * 1024
393 )
394 return entry
395
396 async def readexactly(self, n: int) -> bytes:
397 return b"x" * n
398
399 def __aiter__(self):
400 return self
401
402 async def __anext__(self):
403 val = await self.readline()
404 if val == b"":
405 raise StopAsyncIteration
406 return val
407
408
class TestCLIRunnerOversizedOutput:
    """CLIRunner recovers from LimitOverrunError in the stdout loop."""

    def test_oversized_line_emits_tool_end_and_continues(self):
        """Oversize line → synthetic tool_end emitted + subsequent line processed."""
        normal_line_1 = json.dumps({"event": "text", "text": "hello"}).encode() + b"\n"
        normal_line_2 = json.dumps({"event": "text", "text": "world"}).encode() + b"\n"

        events = []
        callback = JSONEventCallback(events.append)
        aggregator = ThinkingAggregator(callback, model="test-model")

        # Reuse the shared factory (like TestCLIRunnerFirstEventTimeout does),
        # then swap in the stdout mock that raises LimitOverrunError between
        # the two normal lines.
        process = _make_process(stdout_lines=[], stderr_lines=[], return_code=0)
        process.stdout = _MockStdoutWithOversize(
            [
                normal_line_1,
                _OVERSIZE,
                normal_line_2,
            ]
        )

        # translate just forwards text events as-is
        def translate(event_data, agg, cb):
            if event_data.get("event") == "text":
                cb.emit({"event": "text", "text": event_data["text"]})
            return None

        runner = CLIRunner(
            cmd=["fakecli", "--json"],
            prompt_text="test",
            translate=translate,
            callback=callback,
            aggregator=aggregator,
        )

        with (
            patch(
                "think.providers.cli.asyncio.create_subprocess_exec",
                AsyncMock(return_value=process),
            ),
            patch("think.providers.cli.shutil.which", return_value="/usr/bin/fakecli"),
        ):
            asyncio.run(runner.run())

        event_types = [e["event"] for e in events]
        # tool_end should be emitted
        assert "tool_end" in event_types, f"Expected tool_end in events: {events}"

        # the tool_end result should indicate truncation
        tool_end_events = [e for e in events if e["event"] == "tool_end"]
        assert len(tool_end_events) == 1
        assert "truncated" in tool_end_events[0]["result"]

        # the normal line after the oversize error should also be processed
        text_events = [e for e in events if e["event"] == "text"]
        texts = [e["text"] for e in text_events]
        assert "world" in texts, f"Expected 'world' in text events: {texts}"
474
475
476# ---------------------------------------------------------------------------
477# safe_raw
478# ---------------------------------------------------------------------------
479
480
class TestSafeRaw:
    def test_small_event_returned_unchanged(self):
        raw = [{"type": "tool_use", "tool_name": "read_file", "tool_id": "t1"}]
        assert safe_raw(raw) is raw

    def test_large_event_trimmed(self):
        raw = [
            {
                "type": "tool_result",
                "tool_id": "t1",
                "output": "x" * 20_000,
                "extra_field": "value",
            }
        ]
        trimmed = safe_raw(raw)
        assert trimmed is not raw
        # Only structural keys survive trimming.
        assert trimmed[0] == {"type": "tool_result", "tool_id": "t1"}
        # Trailing element carries the trim metadata.
        meta = trimmed[-1]["_raw_trimmed"]
        assert meta["limit"] == 16_384
        assert meta["original_bytes"] > 16_384

    def test_custom_limit(self):
        raw = [{"type": "message", "content": "a" * 200}]
        # Under the custom limit → same object back, untouched.
        assert safe_raw(raw, limit=1024) is raw
        # Over the custom limit → trimmed copy with the limit recorded.
        trimmed = safe_raw(raw, limit=50)
        assert trimmed is not raw
        assert trimmed[-1]["_raw_trimmed"]["limit"] == 50

    def test_structural_keys_preserved(self):
        raw = [
            {
                "type": "tool_use",
                "id": "abc",
                "tool_id": "t1",
                "tool_name": "search",
                "role": "assistant",
                "event_type": "message",
                "timestamp": 12345,
                "big_content": "z" * 20_000,
            }
        ]
        assert safe_raw(raw)[0] == {
            "type": "tool_use",
            "id": "abc",
            "tool_id": "t1",
            "tool_name": "search",
            "role": "assistant",
            "event_type": "message",
            "timestamp": 12345,
        }

    def test_multiple_events(self):
        raw = [
            {"type": "msg", "data": "a" * 10_000},
            {"type": "msg", "data": "b" * 10_000},
        ]
        trimmed = safe_raw(raw)
        assert len(trimmed) == 3  # 2 trimmed events + 1 metadata
        assert trimmed[0] == {"type": "msg"}
        assert trimmed[1] == {"type": "msg"}
        assert "_raw_trimmed" in trimmed[2]
549
550
551# ---------------------------------------------------------------------------
552# build_cogitate_env
553# ---------------------------------------------------------------------------
554
555
class TestBuildCogitateEnv:
    """Tests for build_cogitate_env — API key stripping for CLI subprocesses."""

    @staticmethod
    def _env_for(key_name, config, environ):
        """Call build_cogitate_env with patched os.environ and config."""
        with (
            patch.dict(os.environ, environ, clear=False),
            patch("think.utils.get_config", return_value=config),
        ):
            return build_cogitate_env(key_name)

    def test_default_strips_key(self):
        """No auth config → default platform mode → key removed."""
        env = self._env_for(
            "ANTHROPIC_API_KEY",
            {"providers": {}},
            {"ANTHROPIC_API_KEY": "sk-secret"},
        )
        assert "ANTHROPIC_API_KEY" not in env

    def test_explicit_platform_strips_key(self):
        """auth.anthropic = "platform" → key removed."""
        env = self._env_for(
            "ANTHROPIC_API_KEY",
            {"providers": {"auth": {"anthropic": "platform"}}},
            {"ANTHROPIC_API_KEY": "sk-secret"},
        )
        assert "ANTHROPIC_API_KEY" not in env

    def test_api_key_mode_preserves_key(self):
        """auth.anthropic = "api_key" → key preserved."""
        env = self._env_for(
            "ANTHROPIC_API_KEY",
            {"providers": {"auth": {"anthropic": "api_key"}}},
            {"ANTHROPIC_API_KEY": "sk-secret"},
        )
        assert env["ANTHROPIC_API_KEY"] == "sk-secret"

    def test_missing_auth_section_strips_key(self):
        """No providers section at all → safe default, key removed."""
        env = self._env_for("OPENAI_API_KEY", {}, {"OPENAI_API_KEY": "sk-openai"})
        assert "OPENAI_API_KEY" not in env

    def test_other_env_vars_preserved(self):
        """Non-API-key vars are never stripped."""
        env = self._env_for(
            "ANTHROPIC_API_KEY",
            {"providers": {}},
            {"ANTHROPIC_API_KEY": "sk-secret", "HOME": "/home/test"},
        )
        assert env["HOME"] == "/home/test"

    def test_key_not_in_env_is_harmless(self):
        """Stripping a key that doesn't exist doesn't error."""
        env = self._env_for("GOOGLE_API_KEY", {"providers": {}}, {})
        assert "GOOGLE_API_KEY" not in env

    def test_per_provider_independence(self):
        """Each provider's auth mode is independent."""
        cfg = {
            "providers": {
                "auth": {
                    "anthropic": "api_key",
                    "openai": "platform",
                }
            }
        }
        environ = {"ANTHROPIC_API_KEY": "sk-ant", "OPENAI_API_KEY": "sk-oai"}
        ant_env = self._env_for("ANTHROPIC_API_KEY", cfg, environ)
        oai_env = self._env_for("OPENAI_API_KEY", cfg, environ)
        assert ant_env["ANTHROPIC_API_KEY"] == "sk-ant"
        assert "OPENAI_API_KEY" not in oai_env