# personal memory agent
# (source: branch `main`, 456 lines, 15 kB)
1# SPDX-License-Identifier: AGPL-3.0-only 2# Copyright (c) 2026 sol pbc 3 4"""CLI subprocess runner for AI provider tool agents. 5 6Spawns provider CLI tools (claude, codex, gemini) in JSON streaming mode 7and translates their JSONL output into our standard Event format. 8 9Each provider module implements a translate() function that converts 10provider-specific JSONL events into our Event TypedDicts. The CLIRunner 11handles subprocess lifecycle, stdin piping, and event emission. 12""" 13 14from __future__ import annotations 15 16import asyncio 17import json 18import logging 19import os 20import shutil 21from pathlib import Path 22from typing import Any, Callable 23 24from think.providers.shared import JSONEventCallback, safe_raw 25from think.utils import now_ms 26 27LOG = logging.getLogger("think.providers.cli") 28 29_PROJECT_ROOT = Path(__file__).parent.parent.parent 30 31 32async def _drain_line(stream: asyncio.StreamReader) -> None: 33 """Drain a single overlong line from the stream by consuming it in chunks.""" 34 while True: 35 try: 36 await stream.readline() 37 return 38 except asyncio.LimitOverrunError as exc: 39 await stream.readexactly(exc.consumed) 40 41 42# --------------------------------------------------------------------------- 43# Prompt Assembly 44# --------------------------------------------------------------------------- 45 46 47def assemble_prompt(config: dict[str, Any]) -> tuple[str, str | None]: 48 """Combine config fields into a single prompt string and system instruction. 49 50 Joins transcript, extra_context, user_instruction, and prompt with 51 double newlines. Returns the system_instruction separately for CLIs 52 that support --system-prompt (Claude); callers for other CLIs should 53 prepend it to the prompt body. 54 55 Args: 56 config: Agent config dict with prompt, transcript, etc. 57 58 Returns: 59 Tuple of (prompt_body, system_instruction). 60 system_instruction may be None. 
61 """ 62 parts = [] 63 for key in ("transcript", "extra_context", "user_instruction", "prompt"): 64 value = config.get(key) 65 if value: 66 parts.append(value) 67 68 prompt_body = "\n\n".join(parts) if parts else "" 69 system_instruction = config.get("system_instruction") or None 70 return prompt_body, system_instruction 71 72 73# --------------------------------------------------------------------------- 74# Thinking Aggregator 75# --------------------------------------------------------------------------- 76 77 78class ThinkingAggregator: 79 """Buffers assistant text between tool calls for thinking/result classification. 80 81 All assistant text that arrives between tool calls is treated as "thinking". 82 Only the final text after all tool activity completes is the "result". 83 84 Usage: 85 agg = ThinkingAggregator(callback, model) 86 # As text arrives: 87 agg.accumulate("some text") 88 # When a tool_start arrives, flush buffered text as thinking: 89 agg.flush_as_thinking(raw_events=[...]) 90 # When done (no more tool calls), get the final result: 91 result = agg.flush_as_result() 92 """ 93 94 def __init__( 95 self, 96 callback: JSONEventCallback, 97 model: str | None = None, 98 ) -> None: 99 self._buffer: list[str] = [] 100 self._callback = callback 101 self._model = model 102 103 def accumulate(self, text: str) -> None: 104 """Add text to the buffer.""" 105 if text: 106 self._buffer.append(text) 107 108 def flush_as_thinking(self, raw_events: list[dict[str, Any]] | None = None) -> None: 109 """Emit buffered text as a thinking event and clear the buffer. 110 111 Does nothing if the buffer is empty. 
112 """ 113 text = "".join(self._buffer).strip() 114 self._buffer.clear() 115 if not text: 116 return 117 118 event: dict[str, Any] = { 119 "event": "thinking", 120 "summary": text, 121 "ts": now_ms(), 122 } 123 if self._model: 124 event["model"] = self._model 125 if raw_events: 126 event["raw"] = safe_raw(raw_events) 127 self._callback.emit(event) 128 129 def flush_as_result(self) -> str: 130 """Return buffered text as the final result and clear the buffer.""" 131 text = "".join(self._buffer).strip() 132 self._buffer.clear() 133 return text 134 135 @property 136 def has_content(self) -> bool: 137 """Whether the buffer has any content.""" 138 return bool(self._buffer) 139 140 141# --------------------------------------------------------------------------- 142# CLI Runner 143# --------------------------------------------------------------------------- 144 145 146class CLIRunner: 147 """Spawn a CLI subprocess and translate its JSONL output to our events. 148 149 The runner pipes a prompt to stdin, reads JSONL from stdout line by line, 150 and calls a provider-specific translate function for each line. 151 152 Args: 153 cmd: Command to run (e.g., ["claude", "-p", "-", ...]). 154 prompt_text: Text to pipe to stdin. 155 translate: Function that receives (raw_event_dict, aggregator, callback) 156 and emits our Event types. Must return the cli_session_id from the 157 init event (or None for non-init events). 158 callback: JSONEventCallback for emitting events. 159 aggregator: ThinkingAggregator for text buffering. 160 cwd: Working directory for the subprocess. Defaults to project root. 161 env: Optional environment overrides (merged with os.environ). 162 timeout: Subprocess timeout in seconds. Default 600. 163 first_event_timeout: Timeout for first stdout line in seconds. Default 30. 
164 """ 165 166 def __init__( 167 self, 168 cmd: list[str], 169 prompt_text: str, 170 translate: Callable[ 171 [dict[str, Any], ThinkingAggregator, JSONEventCallback], 172 str | None, 173 ], 174 callback: JSONEventCallback, 175 aggregator: ThinkingAggregator, 176 cwd: Path | None = None, 177 env: dict[str, str] | None = None, 178 timeout: int = 600, 179 first_event_timeout: int = 30, 180 ) -> None: 181 self.cmd = cmd 182 self.prompt_text = prompt_text 183 self.translate = translate 184 self.callback = callback 185 self.aggregator = aggregator 186 self.cwd = cwd or _PROJECT_ROOT 187 self.env = env 188 self.timeout = timeout 189 self.first_event_timeout = first_event_timeout 190 self._timed_out_waiting_for_first_event = False 191 self.cli_session_id: str | None = None 192 193 async def run(self) -> str: 194 """Spawn the CLI process, stream events, and return the final result. 195 196 Returns: 197 The final result text from the agent. 198 199 Raises: 200 RuntimeError: If the CLI binary is not found or process fails. 201 """ 202 binary = self.cmd[0] 203 if not shutil.which(binary): 204 raise RuntimeError( 205 f"CLI tool '{binary}' not found. Install it and ensure it's on PATH." 
206 ) 207 208 import os 209 210 proc_env = os.environ.copy() 211 if self.env: 212 proc_env.update(self.env) 213 214 LOG.info("Spawning CLI: %s (cwd=%s)", " ".join(self.cmd), self.cwd) 215 216 process = await asyncio.create_subprocess_exec( 217 *self.cmd, 218 stdin=asyncio.subprocess.PIPE, 219 stdout=asyncio.subprocess.PIPE, 220 stderr=asyncio.subprocess.PIPE, 221 limit=1024 * 1024, # 1 MB – tool results can exceed the 64 KB default 222 cwd=str(self.cwd), 223 env=proc_env, 224 ) 225 226 # Pipe prompt to stdin and close 227 if process.stdin: 228 process.stdin.write(self.prompt_text.encode("utf-8")) 229 process.stdin.close() 230 231 # Read stdout line by line, translate each JSONL event 232 stderr_lines: list[str] = [] 233 234 async def _read_stderr() -> None: 235 if not process.stderr: 236 return 237 async for raw_line in process.stderr: 238 line = raw_line.decode("utf-8", errors="replace").rstrip() 239 if line: 240 stderr_lines.append(line) 241 LOG.debug("[%s stderr] %s", binary, line) 242 243 stderr_task = asyncio.create_task(_read_stderr()) 244 self._timed_out_waiting_for_first_event = False 245 246 try: 247 await asyncio.wait_for( 248 self._process_stdout(process), 249 timeout=self.timeout, 250 ) 251 except asyncio.TimeoutError: 252 timeout_seconds = ( 253 self.first_event_timeout 254 if self._timed_out_waiting_for_first_event 255 else self.timeout 256 ) 257 LOG.error("CLI process timed out after %ss, killing", timeout_seconds) 258 process.kill() 259 await stderr_task 260 stderr_tail = "\n".join(stderr_lines[-20:]) 261 error_message = ( 262 f"CLI process timed out after {timeout_seconds}s. " 263 f"Stderr tail:\n{stderr_tail}\n" 264 "Check that the CLI tool is installed and authenticated." 
265 ) 266 self.callback.emit( 267 { 268 "event": "error", 269 "error": error_message, 270 "ts": now_ms(), 271 } 272 ) 273 raise RuntimeError(error_message) 274 finally: 275 # Wait for stderr reader to finish 276 if not stderr_task.done(): 277 await stderr_task 278 279 # Wait for process to exit 280 return_code = await process.wait() 281 result = self.aggregator.flush_as_result() 282 283 if return_code != 0: 284 stderr_text = "\n".join(stderr_lines[-20:]) # Last 20 lines 285 if result: 286 # CLI failed but produced output — warn and return what we got 287 LOG.warning( 288 "CLI process exited with code %d but produced output. Stderr: %s", 289 return_code, 290 stderr_text, 291 ) 292 self.callback.emit( 293 { 294 "event": "warning", 295 "message": f"CLI exited with code {return_code}", 296 "stderr": stderr_text, 297 "ts": now_ms(), 298 } 299 ) 300 else: 301 # CLI failed with no output — this is an error. 302 # Don't emit error event here; the caller's exception 303 # handler is responsible for error event emission. 304 LOG.error( 305 "CLI process exited with code %d: %s", 306 return_code, 307 stderr_text, 308 ) 309 raise RuntimeError( 310 f"CLI process exited with code {return_code}. 
Stderr: {stderr_text}" 311 ) 312 313 return result 314 315 async def _process_stdout(self, process: asyncio.subprocess.Process) -> None: 316 """Read and translate JSONL lines from stdout.""" 317 if not process.stdout: 318 return 319 320 def _process_line(raw_line: bytes) -> None: 321 line = raw_line.decode("utf-8", errors="replace").strip() 322 if not line: 323 return 324 325 try: 326 event_data = json.loads(line) 327 except json.JSONDecodeError: 328 LOG.warning("Non-JSON stdout line: %s", line[:200]) 329 return 330 331 try: 332 session_id = self.translate(event_data, self.aggregator, self.callback) 333 if session_id: 334 self.cli_session_id = session_id 335 except Exception: 336 LOG.exception("Error translating CLI event: %s", line[:200]) 337 338 try: 339 first_line = await asyncio.wait_for( 340 process.stdout.readline(), 341 timeout=self.first_event_timeout, 342 ) 343 except asyncio.TimeoutError: 344 self._timed_out_waiting_for_first_event = True 345 raise 346 if not first_line: 347 return 348 _process_line(first_line) 349 350 while True: 351 try: 352 raw_line = await process.stdout.readline() 353 except asyncio.LimitOverrunError as exc: 354 LOG.warning( 355 "CLI stdout line exceeded buffer limit (%d bytes consumed before limit); " 356 "draining and emitting truncated tool_end", 357 exc.consumed, 358 ) 359 await _drain_line(process.stdout) 360 self.callback.emit( 361 { 362 "event": "tool_end", 363 "tool": "bash", 364 "result": "[output truncated: too large to process — try a more targeted query]", 365 "ts": now_ms(), 366 } 367 ) 368 continue 369 if not raw_line: 370 break 371 _process_line(raw_line) 372 373 374# --------------------------------------------------------------------------- 375# CLI Binary Check 376# --------------------------------------------------------------------------- 377 378 379def check_cli_binary(name: str) -> str: 380 """Check that a CLI binary is available on PATH. 381 382 Args: 383 name: Binary name (e.g., "claude", "codex", "gemini"). 
384 385 Returns: 386 The full path to the binary. 387 388 Raises: 389 RuntimeError: If the binary is not found. 390 """ 391 path = shutil.which(name) 392 if not path: 393 raise RuntimeError( 394 f"CLI tool '{name}' not found on PATH. " 395 f"Install it and ensure it's accessible." 396 ) 397 return path 398 399 400# --------------------------------------------------------------------------- 401# Cogitate Environment 402# --------------------------------------------------------------------------- 403 404 405def build_cogitate_env(env_key: str) -> dict[str, str]: 406 """Build environment dict for a cogitate CLI subprocess. 407 408 By default, strips the provider's API key so the CLI uses its own 409 platform/account-based auth. Controlled by the ``providers.auth`` 410 section in journal config: 411 412 "providers": { 413 "auth": { 414 "anthropic": "platform" // default — strip key 415 } 416 } 417 418 Values: ``"platform"`` (default) strips the key; ``"api_key"`` preserves it. 419 420 Args: 421 env_key: Environment variable name to consider stripping 422 (e.g., ``"ANTHROPIC_API_KEY"``). 423 424 Returns: 425 Copy of ``os.environ`` with the key removed when auth mode is platform. 
426 """ 427 from think.utils import get_config 428 429 config = get_config() 430 auth_config = config.get("providers", {}).get("auth", {}) 431 432 # Determine provider name from env_key for config lookup 433 # e.g., "ANTHROPIC_API_KEY" -> lookup auth_config for matching provider 434 # We check all auth entries; default is "platform" for any missing provider 435 auth_mode = "platform" 436 for provider, mode in auth_config.items(): 437 from think.providers import PROVIDER_METADATA 438 439 meta = PROVIDER_METADATA.get(provider, {}) 440 if meta.get("env_key") == env_key: 441 auth_mode = mode 442 break 443 444 env = os.environ.copy() 445 if auth_mode == "platform": 446 env.pop(env_key, None) 447 return env 448 449 450__all__ = [ 451 "CLIRunner", 452 "ThinkingAggregator", 453 "assemble_prompt", 454 "build_cogitate_env", 455 "check_cli_binary", 456]