personal memory agent
at main 460 lines 15 kB view raw
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Tests for provider health checks and primary-to-backup provider fallback."""

import asyncio
import json
from datetime import datetime, timedelta, timezone
from io import StringIO
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

from think.agents import _is_retryable_error
from think.models import (
    TYPE_DEFAULTS,
    get_backup_provider,
    is_provider_healthy,
    should_recheck_health,
)


# ---------------------------------------------------------------------------
# Shared helpers (underscore-prefixed so pytest does not collect them)
# ---------------------------------------------------------------------------


def _mock_base_agent_config() -> dict:
    """Return the minimal agent definition served by the patched ``get_agent``."""
    return {
        "type": "cogitate",
        "path": None,
        "sources": {},
        "system_instruction": "",
        "user_instruction": "",
        "prompt": "",
        "disabled": False,
    }


def _patch_key_to_context(monkeypatch) -> None:
    """Make ``think.talent.key_to_context`` map every name to one fixed context."""
    monkeypatch.setattr(
        "think.talent.key_to_context", lambda _name: "talent.system.default"
    )


def _patch_prepare_config_dependencies(monkeypatch):
    """Stub the lookups ``prepare_config`` is exercised against in these tests.

    The agent definition, the name-to-context mapping and the primary
    provider resolution (always ``google`` / ``gemini-3-flash-preview``)
    are all replaced with fixed values.
    """
    monkeypatch.setattr(
        "think.talent.get_agent", lambda *args, **kwargs: _mock_base_agent_config()
    )
    _patch_key_to_context(monkeypatch)
    monkeypatch.setattr(
        "think.models.resolve_provider",
        lambda _context, _type: ("google", "gemini-3-flash-preview"),
    )


def _patch_health_status(monkeypatch, *, google_ok: bool) -> None:
    """Serve a single health result for ``google`` and report it as fresh."""
    monkeypatch.setattr(
        "think.models.load_health_status",
        lambda: {"results": [{"provider": "google", "ok": google_ok}]},
    )
    monkeypatch.setattr("think.models.should_recheck_health", lambda _h: False)


def _patch_backup_provider(monkeypatch, agent_type: str, resolve_model=None) -> None:
    """Configure ``anthropic`` as an available backup provider.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.
        agent_type: Default for the stub resolver's ``_type`` parameter
            (``"generate"`` or ``"cogitate"`` in this module).
        resolve_model: Optional replacement for
            ``think.models.resolve_model_for_provider``; when omitted a stub
            that always returns ``"claude-sonnet-4-5"`` is installed.
    """

    def default_resolver(_context, _provider, _type=agent_type):
        return "claude-sonnet-4-5"

    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    monkeypatch.setattr(
        "think.models.resolve_model_for_provider",
        resolve_model if resolve_model is not None else default_resolver,
    )
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")


def _primary_config(**overrides) -> dict:
    """Return a run config targeting the ``google`` primary, plus *overrides*."""
    config = {
        "provider": "google",
        "model": "gemini-3-flash-preview",
        "health_stale": False,
    }
    config.update(overrides)
    return config


def _finishing_cogitate(result: str):
    """Build a ``run_cogitate`` stub that emits a finish event and returns *result*."""

    async def run_cogitate(*_args, **kwargs):
        on_event = kwargs.get("on_event")
        if on_event:
            on_event({"event": "finish", "result": result})
        return result

    return run_cogitate


def _patch_failover_cogitate(monkeypatch) -> dict:
    """Register a failing ``google`` and a succeeding ``anthropic`` provider.

    Returns:
        A dict with ``"primary"`` and ``"backup"`` call counters that the
        stubs increment each time they are invoked.
    """
    attempts = {"primary": 0, "backup": 0}
    finish_with_backup = _finishing_cogitate("backup result")

    async def fail_cogitate(*_args, **_kwargs):
        attempts["primary"] += 1
        raise RuntimeError("primary down")

    async def pass_cogitate(*args, **kwargs):
        attempts["backup"] += 1
        return await finish_with_backup(*args, **kwargs)

    monkeypatch.setattr(
        "think.providers.PROVIDER_REGISTRY", {"google": "x", "anthropic": "y"}
    )
    monkeypatch.setattr(
        "think.providers.get_provider_module",
        lambda provider: SimpleNamespace(
            run_cogitate=fail_cogitate if provider == "google" else pass_cogitate
        ),
    )
    return attempts


# ---------------------------------------------------------------------------
# is_provider_healthy / should_recheck_health / get_backup_provider
# ---------------------------------------------------------------------------


def test_is_provider_healthy_all_failed():
    """A provider whose every health result failed is reported unhealthy."""
    health_data = {
        "results": [
            {"provider": "google", "ok": False},
            {"provider": "google", "ok": False},
        ]
    }
    assert is_provider_healthy("google", health_data) is False


def test_is_provider_healthy_some_passed():
    """One passing result among failures is enough to count as healthy."""
    health_data = {
        "results": [
            {"provider": "google", "ok": False},
            {"provider": "google", "ok": True},
        ]
    }
    assert is_provider_healthy("google", health_data) is True


def test_is_provider_healthy_no_data():
    """With no health data at all the provider is treated as healthy."""
    assert is_provider_healthy("google", None) is True


def test_is_provider_healthy_no_results_for_provider():
    """Results for other providers do not mark this provider unhealthy."""
    health_data = {"results": [{"provider": "anthropic", "ok": False}]}
    assert is_provider_healthy("google", health_data) is True


def test_should_recheck_health_stale():
    """A check recorded two hours ago warrants a recheck."""
    checked_at = (datetime.now(timezone.utc) - timedelta(hours=2)).isoformat()
    health_data = {"checked_at": checked_at}
    assert should_recheck_health(health_data) is True


def test_should_recheck_health_fresh():
    """A check recorded ten minutes ago does not warrant a recheck."""
    checked_at = (datetime.now(timezone.utc) - timedelta(minutes=10)).isoformat()
    health_data = {"checked_at": checked_at}
    assert should_recheck_health(health_data) is False


def test_get_backup_provider_from_config(monkeypatch):
    """An explicit ``backup`` entry in the config is returned."""
    monkeypatch.setattr(
        "think.models.get_config",
        lambda: {"providers": {"generate": {"provider": "google", "backup": "openai"}}},
    )
    assert get_backup_provider("generate") == "openai"


def test_get_backup_provider_fallback_constant(monkeypatch):
    """With an empty config the backup comes from ``TYPE_DEFAULTS``."""
    monkeypatch.setattr("think.models.get_config", lambda: {})
    assert get_backup_provider("generate") == TYPE_DEFAULTS["generate"]["backup"]
    assert get_backup_provider("cogitate") == TYPE_DEFAULTS["cogitate"]["backup"]


def test_get_backup_provider_none_when_same_as_primary(monkeypatch):
    """A backup identical to the primary provider yields ``None``."""
    monkeypatch.setattr(
        "think.models.get_config",
        lambda: {
            "providers": {
                "generate": {"provider": "openai", "backup": "openai"},
            }
        },
    )
    assert get_backup_provider("generate") is None


# ---------------------------------------------------------------------------
# Preflight swap in prepare_config
# ---------------------------------------------------------------------------


def test_preflight_swap_unhealthy_primary(monkeypatch):
    """An unhealthy primary is swapped for the backup before the run starts."""
    from think.agents import prepare_config

    _patch_prepare_config_dependencies(monkeypatch)
    _patch_health_status(monkeypatch, google_ok=False)
    _patch_backup_provider(monkeypatch, "generate")

    config = prepare_config({"name": "unified", "prompt": "hello"})

    assert config["provider"] == "anthropic"
    assert config["model"] == "claude-sonnet-4-5"
    assert config["fallback_from"] == "google"


def test_preflight_no_swap_healthy_primary(monkeypatch):
    """A healthy primary is kept and no fallback marker is recorded."""
    from think.agents import prepare_config

    _patch_prepare_config_dependencies(monkeypatch)
    _patch_health_status(monkeypatch, google_ok=True)

    config = prepare_config({"name": "unified", "prompt": "hello"})

    assert config["provider"] == "google"
    assert "fallback_from" not in config


def test_preflight_no_swap_no_backup_key(monkeypatch):
    """Without the backup's API key the unhealthy primary is still used."""
    from think.agents import prepare_config

    _patch_prepare_config_dependencies(monkeypatch)
    _patch_health_status(monkeypatch, google_ok=False)
    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    config = prepare_config({"name": "unified", "prompt": "hello"})

    assert config["provider"] == "google"
    assert "fallback_from" not in config


# ---------------------------------------------------------------------------
# On-failure retry against the backup provider
# ---------------------------------------------------------------------------


def test_on_failure_retry_cogitate(monkeypatch):
    """A failed cogitate run is retried once on the backup provider."""
    from think.agents import _execute_with_tools

    events = []
    attempts = _patch_failover_cogitate(monkeypatch)
    _patch_backup_provider(monkeypatch, "cogitate")

    config = _primary_config(context="talent.system.default")

    asyncio.run(_execute_with_tools(config, events.append))

    assert attempts["primary"] == 1
    assert attempts["backup"] == 1
    assert config["provider"] == "anthropic"
    assert config["model"] == "claude-sonnet-4-5"
    assert config["fallback_from"] == "google"
    assert any(e.get("event") == "fallback" for e in events)


def test_on_failure_retry_cogitate_uses_context_from_name(monkeypatch):
    """Without ``context`` in the config, the backup model is resolved
    using the context derived from the agent name."""
    from think.agents import _execute_with_tools

    events = []
    seen = {}

    def resolve_model(context, _provider, _type="cogitate"):
        seen["context"] = context
        return "claude-sonnet-4-5"

    _patch_failover_cogitate(monkeypatch)
    _patch_key_to_context(monkeypatch)
    _patch_backup_provider(monkeypatch, "cogitate", resolve_model=resolve_model)

    config = _primary_config(name="unified")

    asyncio.run(_execute_with_tools(config, events.append))

    assert seen["context"] == "talent.system.default"


def test_on_failure_retry_generate(monkeypatch):
    """A failed generate call is retried on the backup and its text returned."""
    from think.agents import _execute_generate

    events = []
    calls = []

    def mock_generate_with_result(**kwargs):
        # Only record here; the kwargs are checked after the run. An assert
        # raised from inside this stub would travel through the fallback error
        # handling under test and could be reported as the primary's failure.
        calls.append(kwargs)
        if len(calls) == 1:
            raise RuntimeError("primary generate failed")
        return {"text": "backup text", "usage": {"input_tokens": 1, "output_tokens": 1}}

    _patch_key_to_context(monkeypatch)
    monkeypatch.setattr("think.models.generate_with_result", mock_generate_with_result)
    _patch_backup_provider(monkeypatch, "generate")

    config = _primary_config(name="unified", prompt="hello")

    asyncio.run(_execute_generate(config, events.append))

    assert len(calls) == 2
    assert calls[1].get("provider") == "anthropic"
    assert calls[1].get("model") == "claude-sonnet-4-5"
    assert config["provider"] == "anthropic"
    assert config["fallback_from"] == "google"
    assert any(e.get("event") == "fallback" for e in events)
    assert events[-1]["event"] == "finish"
    assert events[-1]["result"] == "backup text"


def test_on_failure_no_retry_value_error(monkeypatch):
    """A ``ValueError`` is not retryable: it propagates with no fallback event."""
    from think.agents import _execute_generate

    events = []
    assert _is_retryable_error(ValueError("bad input")) is False

    def bad_generate(**_kwargs):
        raise ValueError("bad input")

    _patch_key_to_context(monkeypatch)
    monkeypatch.setattr("think.models.generate_with_result", bad_generate)

    config = _primary_config(name="unified", prompt="hello")

    with pytest.raises(ValueError, match="bad input"):
        asyncio.run(_execute_generate(config, events.append))

    assert not any(e.get("event") == "fallback" for e in events)


def test_on_failure_both_fail_raises_original(monkeypatch):
    """When primary and backup both fail, the primary's error is the one raised."""
    from think.agents import _execute_generate

    events = []
    calls = {"count": 0}

    def always_fail(**kwargs):
        calls["count"] += 1
        if kwargs.get("provider") == "anthropic":
            raise RuntimeError("backup failed")
        raise RuntimeError("primary failed")

    _patch_key_to_context(monkeypatch)
    monkeypatch.setattr("think.models.generate_with_result", always_fail)
    _patch_backup_provider(monkeypatch, "generate")

    config = _primary_config(name="unified", prompt="hello")

    with pytest.raises(RuntimeError, match="primary failed"):
        asyncio.run(_execute_generate(config, events.append))

    assert calls["count"] == 2


# ---------------------------------------------------------------------------
# Event emission, health recheck and CLI error reporting
# ---------------------------------------------------------------------------


def test_fallback_event_emitted():
    """A config carrying ``fallback_from`` yields exactly one fallback event,
    with reason ``preflight``, on a dry run."""
    from think.agents import _run_agent

    events = []
    config = {
        "type": "cogitate",
        "name": "unified",
        "provider": "anthropic",
        "model": "claude-sonnet-4-5",
        "prompt": "hello",
        "fallback_from": "google",
    }

    asyncio.run(_run_agent(config, events.append, dry_run=True))

    fallback_events = [e for e in events if e.get("event") == "fallback"]
    assert len(fallback_events) == 1
    assert fallback_events[0]["reason"] == "preflight"


def test_recheck_requested_on_stale(monkeypatch):
    """A stale health flag triggers one recheck request and is then cleared."""
    from think.agents import _execute_with_tools

    recheck_mock = MagicMock()

    monkeypatch.setattr("think.providers.PROVIDER_REGISTRY", {"google": "x"})
    monkeypatch.setattr(
        "think.providers.get_provider_module",
        lambda _provider: SimpleNamespace(run_cogitate=_finishing_cogitate("ok")),
    )
    monkeypatch.setattr("think.models.request_health_recheck", recheck_mock)

    config = _primary_config(health_stale=True)

    asyncio.run(_execute_with_tools(config, lambda _e: None))

    recheck_mock.assert_called_once()
    assert config["health_stale"] is False


def test_main_async_no_duplicate_error_when_evented(monkeypatch, capsys):
    """An exception flagged ``_evented`` is not reported a second time:
    stdout carries exactly one error event."""
    from think.agents import main_async

    ndjson_input = json.dumps({"name": "unified", "prompt": "hello"})
    monkeypatch.setattr("sys.stdin", StringIO(ndjson_input))

    async def fake_run_agent(_config, emit_event, dry_run=False):
        # Emit the error event ourselves, then raise an exception marked as
        # already reported.
        emit_event({"event": "error", "error": "provider failed"})
        exc = RuntimeError("provider failed")
        setattr(exc, "_evented", True)
        raise exc

    mock_args = MagicMock()
    mock_args.verbose = False
    mock_args.dry_run = False
    mock_args.subcommand = None

    monkeypatch.setattr("think.agents.setup_cli", lambda _parser: mock_args)
    monkeypatch.setattr(
        "think.agents.setup_logging",
        lambda _verbose=False: MagicMock(),
    )
    monkeypatch.setattr(
        "think.agents.prepare_config", lambda _request: {"type": "cogitate"}
    )
    monkeypatch.setattr("think.agents.validate_config", lambda _config: None)
    monkeypatch.setattr("think.agents._run_agent", fake_run_agent)

    asyncio.run(main_async())

    lines = [line for line in capsys.readouterr().out.splitlines() if line.strip()]
    events = [json.loads(line) for line in lines]
    error_events = [event for event in events if event.get("event") == "error"]
    assert len(error_events) == 1