personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4import asyncio
5import json
6from datetime import datetime, timedelta, timezone
7from io import StringIO
8from types import SimpleNamespace
9from unittest.mock import MagicMock
10
11import pytest
12
13from think.agents import _is_retryable_error
14from think.models import (
15 TYPE_DEFAULTS,
16 get_backup_provider,
17 is_provider_healthy,
18 should_recheck_health,
19)
20
21
def test_is_provider_healthy_all_failed():
    """A provider whose every health probe failed is reported unhealthy."""
    failed_probe = {"provider": "google", "ok": False}
    health_data = {"results": [dict(failed_probe), dict(failed_probe)]}
    assert is_provider_healthy("google", health_data) is False
30
31
def test_is_provider_healthy_some_passed():
    """A single passing probe is enough for the provider to count as healthy."""
    probes = [{"provider": "google", "ok": outcome} for outcome in (False, True)]
    assert is_provider_healthy("google", {"results": probes}) is True
40
41
def test_is_provider_healthy_no_data():
    """Without any health data the provider is assumed healthy."""
    missing_health_data = None
    assert is_provider_healthy("google", missing_health_data) is True
44
45
def test_is_provider_healthy_no_results_for_provider():
    """Failures recorded for a different provider do not affect this one."""
    other_provider_failure = {"provider": "anthropic", "ok": False}
    verdict = is_provider_healthy("google", {"results": [other_provider_failure]})
    assert verdict is True
49
50
def test_should_recheck_health_stale():
    """Health data last checked two hours ago is stale and needs a recheck."""
    two_hours_ago = datetime.now(timezone.utc) - timedelta(hours=2)
    assert should_recheck_health({"checked_at": two_hours_ago.isoformat()}) is True
55
56
def test_should_recheck_health_fresh():
    """Health data last checked ten minutes ago is fresh; no recheck needed."""
    ten_minutes_ago = datetime.now(timezone.utc) - timedelta(minutes=10)
    assert should_recheck_health({"checked_at": ten_minutes_ago.isoformat()}) is False
61
62
def test_get_backup_provider_from_config(monkeypatch):
    """An explicit ``backup`` entry in the providers config is returned."""

    def fake_get_config():
        # Build a fresh dict per call, like the inline lambda it replaces.
        return {"providers": {"generate": {"provider": "google", "backup": "openai"}}}

    monkeypatch.setattr("think.models.get_config", fake_get_config)
    assert get_backup_provider("generate") == "openai"
69
70
def test_get_backup_provider_fallback_constant(monkeypatch):
    """With no providers config, each type's backup comes from ``TYPE_DEFAULTS``."""
    monkeypatch.setattr("think.models.get_config", lambda: {})
    for agent_type in ("generate", "cogitate"):
        assert get_backup_provider(agent_type) == TYPE_DEFAULTS[agent_type]["backup"]
75
76
def test_get_backup_provider_none_when_same_as_primary(monkeypatch):
    """A backup identical to the primary provider is treated as no backup."""

    def config_with_self_backup():
        return {"providers": {"generate": {"provider": "openai", "backup": "openai"}}}

    monkeypatch.setattr("think.models.get_config", config_with_self_backup)
    assert get_backup_provider("generate") is None
87
88
89def _mock_base_agent_config() -> dict:
90 return {
91 "type": "cogitate",
92 "path": None,
93 "sources": {},
94 "system_instruction": "",
95 "user_instruction": "",
96 "prompt": "",
97 "disabled": False,
98 }
99
100
101def _patch_prepare_config_dependencies(monkeypatch):
102 monkeypatch.setattr(
103 "think.talent.get_agent", lambda *args, **kwargs: _mock_base_agent_config()
104 )
105 monkeypatch.setattr(
106 "think.talent.key_to_context", lambda _name: "talent.system.default"
107 )
108 monkeypatch.setattr(
109 "think.models.resolve_provider",
110 lambda _context, _type: ("google", "gemini-3-flash-preview"),
111 )
112
113
def test_preflight_swap_unhealthy_primary(monkeypatch):
    """An unhealthy primary is swapped at preflight for a backup that has a key.

    The resulting config must name the backup provider and model, and record
    the original provider in ``fallback_from``.
    """
    from think.agents import prepare_config

    _patch_prepare_config_dependencies(monkeypatch)
    patches = {
        "think.models.load_health_status": lambda: {
            "results": [{"provider": "google", "ok": False}]
        },
        # Pretend the cached health data is fresh so no recheck is attempted.
        "think.models.should_recheck_health": lambda _h: False,
        "think.models.get_backup_provider": lambda _type: "anthropic",
        "think.models.resolve_model_for_provider": (
            lambda _context, _provider, _type="generate": "claude-sonnet-4-5"
        ),
    }
    for target, replacement in patches.items():
        monkeypatch.setattr(target, replacement)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")

    prepared = prepare_config({"name": "unified", "prompt": "hello"})

    assert (prepared["provider"], prepared["model"], prepared["fallback_from"]) == (
        "anthropic",
        "claude-sonnet-4-5",
        "google",
    )
135
136
def test_preflight_no_swap_healthy_primary(monkeypatch):
    """A healthy primary provider is kept and no fallback is recorded."""
    from think.agents import prepare_config

    _patch_prepare_config_dependencies(monkeypatch)

    def healthy_status():
        return {"results": [{"provider": "google", "ok": True}]}

    monkeypatch.setattr("think.models.load_health_status", healthy_status)
    monkeypatch.setattr("think.models.should_recheck_health", lambda _h: False)

    prepared = prepare_config({"name": "unified", "prompt": "hello"})

    assert prepared["provider"] == "google"
    assert "fallback_from" not in prepared
151
152
def test_preflight_no_swap_no_backup_key(monkeypatch):
    """An unhealthy primary is kept when the backup provider has no API key."""
    from think.agents import prepare_config

    _patch_prepare_config_dependencies(monkeypatch)

    def unhealthy_status():
        return {"results": [{"provider": "google", "ok": False}]}

    monkeypatch.setattr("think.models.load_health_status", unhealthy_status)
    monkeypatch.setattr("think.models.should_recheck_health", lambda _h: False)
    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    # Without a key for the backup, the swap must not happen.
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    prepared = prepare_config({"name": "unified", "prompt": "hello"})

    assert prepared["provider"] == "google"
    assert "fallback_from" not in prepared
169
170
def test_on_failure_retry_cogitate(monkeypatch):
    """A failing cogitate primary is retried exactly once on the backup provider.

    The config must be rewritten to the backup provider and model, record the
    original provider in ``fallback_from``, and a ``fallback`` event must be
    emitted.
    """
    from think.agents import _execute_with_tools

    events = []
    attempts = {"primary": 0, "backup": 0}

    async def fail_cogitate(*_args, **_kwargs):
        attempts["primary"] += 1
        raise RuntimeError("primary down")

    async def pass_cogitate(*_args, **kwargs):
        attempts["backup"] += 1
        emit = kwargs.get("on_event")
        if emit:
            emit({"event": "finish", "result": "backup result"})
        return "backup result"

    def fake_provider_module(provider):
        # Google is the failing primary; every other provider succeeds.
        runner = fail_cogitate if provider == "google" else pass_cogitate
        return SimpleNamespace(run_cogitate=runner)

    monkeypatch.setattr(
        "think.providers.PROVIDER_REGISTRY", {"google": "x", "anthropic": "y"}
    )
    monkeypatch.setattr("think.providers.get_provider_module", fake_provider_module)
    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    monkeypatch.setattr(
        "think.models.resolve_model_for_provider",
        lambda _context, _provider, _type="cogitate": "claude-sonnet-4-5",
    )
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")

    config = dict(
        provider="google",
        model="gemini-3-flash-preview",
        health_stale=False,
        context="talent.system.default",
    )

    asyncio.run(_execute_with_tools(config, events.append))

    assert attempts == {"primary": 1, "backup": 1}
    assert (config["provider"], config["model"], config["fallback_from"]) == (
        "anthropic",
        "claude-sonnet-4-5",
        "google",
    )
    assert any(event.get("event") == "fallback" for event in events)
219
220
def test_on_failure_retry_cogitate_uses_context_from_name(monkeypatch):
    """Without an explicit ``context``, the fallback derives it from ``name``.

    The stubbed ``key_to_context`` maps any name to ``talent.system.default``;
    the backup model must be resolved against that context.
    """
    from think.agents import _execute_with_tools

    events = []
    seen = {}

    async def fail_cogitate(*_args, **_kwargs):
        raise RuntimeError("primary down")

    async def pass_cogitate(*_args, **kwargs):
        emit = kwargs.get("on_event")
        if emit:
            emit({"event": "finish", "result": "backup result"})
        return "backup result"

    def resolve_model(context, _provider, _type="cogitate"):
        seen["context"] = context
        return "claude-sonnet-4-5"

    def fake_provider_module(provider):
        runner = fail_cogitate if provider == "google" else pass_cogitate
        return SimpleNamespace(run_cogitate=runner)

    monkeypatch.setattr(
        "think.providers.PROVIDER_REGISTRY", {"google": "x", "anthropic": "y"}
    )
    monkeypatch.setattr("think.providers.get_provider_module", fake_provider_module)
    monkeypatch.setattr(
        "think.talent.key_to_context", lambda _name: "talent.system.default"
    )
    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    monkeypatch.setattr("think.models.resolve_model_for_provider", resolve_model)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")

    # Note: no "context" key here, unlike test_on_failure_retry_cogitate.
    config = dict(
        name="unified",
        provider="google",
        model="gemini-3-flash-preview",
        health_stale=False,
    )

    asyncio.run(_execute_with_tools(config, events.append))

    assert seen["context"] == "talent.system.default"
267
268
def test_on_failure_retry_generate(monkeypatch):
    """A failing generate primary is retried on the backup provider and model.

    The backup call's kwargs are recorded and checked after the run, not
    inside the mock. An ``AssertionError`` raised inside the mock would occur
    during the backup call. When the backup fails, the fallback path re-raises
    the *primary* error (see test_on_failure_both_fail_raises_original), so a
    kwargs mismatch would surface misleadingly as "primary generate failed".
    """
    from think.agents import _execute_generate

    events = []
    generate_calls = []  # kwargs of every generate_with_result call, in order

    def mock_generate_with_result(**kwargs):
        generate_calls.append(kwargs)
        if len(generate_calls) == 1:
            raise RuntimeError("primary generate failed")
        return {"text": "backup text", "usage": {"input_tokens": 1, "output_tokens": 1}}

    monkeypatch.setattr(
        "think.talent.key_to_context", lambda _name: "talent.system.default"
    )
    monkeypatch.setattr("think.models.generate_with_result", mock_generate_with_result)
    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    monkeypatch.setattr(
        "think.models.resolve_model_for_provider",
        lambda _context, _provider, _type="generate": "claude-sonnet-4-5",
    )
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")

    config = {
        "name": "unified",
        "provider": "google",
        "model": "gemini-3-flash-preview",
        "prompt": "hello",
        "health_stale": False,
    }

    asyncio.run(_execute_generate(config, events.append))

    assert len(generate_calls) == 2
    backup_kwargs = generate_calls[1]
    assert backup_kwargs.get("provider") == "anthropic"
    assert backup_kwargs.get("model") == "claude-sonnet-4-5"
    assert config["provider"] == "anthropic"
    assert config["fallback_from"] == "google"
    assert any(e.get("event") == "fallback" for e in events)
    assert events[-1]["event"] == "finish"
    assert events[-1]["result"] == "backup text"
310
311
def test_on_failure_no_retry_value_error(monkeypatch):
    """A non-retryable ``ValueError`` propagates unchanged and triggers no fallback."""
    from think.agents import _execute_generate

    # Precondition: the classifier must treat ValueError as non-retryable.
    assert _is_retryable_error(ValueError("bad input")) is False

    emitted = []

    def bad_generate(**_kwargs):
        raise ValueError("bad input")

    monkeypatch.setattr(
        "think.talent.key_to_context", lambda _name: "talent.system.default"
    )
    monkeypatch.setattr("think.models.generate_with_result", bad_generate)

    config = dict(
        name="unified",
        provider="google",
        model="gemini-3-flash-preview",
        prompt="hello",
        health_stale=False,
    )

    with pytest.raises(ValueError, match="bad input"):
        asyncio.run(_execute_generate(config, emitted.append))

    assert all(event.get("event") != "fallback" for event in emitted)
338
339
def test_on_failure_both_fail_raises_original(monkeypatch):
    """When primary and backup both fail, the primary's error is re-raised.

    The providers of both attempts are recorded. Retrying the primary twice
    would also produce two calls, so the test pins that the second attempt
    really targeted the backup provider.
    """
    from think.agents import _execute_generate

    events = []
    providers_seen = []  # value of the "provider" kwarg for each call, in order

    def always_fail(**kwargs):
        provider = kwargs.get("provider")
        providers_seen.append(provider)
        if provider == "anthropic":
            raise RuntimeError("backup failed")
        raise RuntimeError("primary failed")

    monkeypatch.setattr(
        "think.talent.key_to_context", lambda _name: "talent.system.default"
    )
    monkeypatch.setattr("think.models.generate_with_result", always_fail)
    monkeypatch.setattr("think.models.get_backup_provider", lambda _type: "anthropic")
    monkeypatch.setattr(
        "think.models.resolve_model_for_provider",
        lambda _context, _provider, _type="generate": "claude-sonnet-4-5",
    )
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")

    config = {
        "name": "unified",
        "provider": "google",
        "model": "gemini-3-flash-preview",
        "prompt": "hello",
        "health_stale": False,
    }

    with pytest.raises(RuntimeError, match="primary failed"):
        asyncio.run(_execute_generate(config, events.append))

    assert len(providers_seen) == 2
    assert providers_seen[1] == "anthropic"
375
376
def test_fallback_event_emitted():
    """A config already swapped at preflight yields one ``preflight`` fallback event."""
    from think.agents import _run_agent

    collected = []
    swapped_config = dict(
        type="cogitate",
        name="unified",
        provider="anthropic",
        model="claude-sonnet-4-5",
        prompt="hello",
        fallback_from="google",  # marks the config as already swapped
    )

    asyncio.run(_run_agent(swapped_config, collected.append, dry_run=True))

    fallback_events = list(
        filter(lambda event: event.get("event") == "fallback", collected)
    )
    assert len(fallback_events) == 1
    assert fallback_events[0]["reason"] == "preflight"
395
396
def test_recheck_requested_on_stale(monkeypatch):
    """Stale health data triggers one recheck request and clears ``health_stale``."""
    from think.agents import _execute_with_tools

    async def pass_cogitate(*_args, **kwargs):
        emit = kwargs.get("on_event")
        if emit:
            emit({"event": "finish", "result": "ok"})
        return "ok"

    def fake_provider_module(_provider):
        return SimpleNamespace(run_cogitate=pass_cogitate)

    recheck_mock = MagicMock()
    monkeypatch.setattr("think.providers.PROVIDER_REGISTRY", {"google": "x"})
    monkeypatch.setattr("think.providers.get_provider_module", fake_provider_module)
    monkeypatch.setattr("think.models.request_health_recheck", recheck_mock)

    config = dict(provider="google", model="gemini-3-flash-preview", health_stale=True)

    asyncio.run(_execute_with_tools(config, lambda _e: None))

    recheck_mock.assert_called_once()
    assert config["health_stale"] is False
425
426
def test_main_async_no_duplicate_error_when_evented(monkeypatch, capsys):
    """An exception already reported as an ``error`` event is not reported twice.

    The fake agent emits an ``error`` event and then raises an exception with
    ``_evented`` set. Exactly one ``error`` line must reach stdout.
    """
    from think.agents import main_async

    request_line = json.dumps({"name": "unified", "prompt": "hello"})
    monkeypatch.setattr("sys.stdin", StringIO(request_line))

    async def fake_run_agent(_config, emit_event, dry_run=False):
        emit_event({"event": "error", "error": "provider failed"})
        failure = RuntimeError("provider failed")
        # Marks the error as already emitted; presumably checked by main_async.
        failure._evented = True
        raise failure

    cli_args = MagicMock(verbose=False, dry_run=False, subcommand=None)

    monkeypatch.setattr("think.agents.setup_cli", lambda _parser: cli_args)
    monkeypatch.setattr(
        "think.agents.setup_logging",
        lambda _verbose=False: MagicMock(),
    )
    monkeypatch.setattr(
        "think.agents.prepare_config", lambda _request: {"type": "cogitate"}
    )
    monkeypatch.setattr("think.agents.validate_config", lambda _config: None)
    monkeypatch.setattr("think.agents._run_agent", fake_run_agent)

    asyncio.run(main_async())

    stdout_lines = capsys.readouterr().out.splitlines()
    emitted = [json.loads(line) for line in stdout_lines if line.strip()]
    error_count = sum(1 for event in emitted if event.get("event") == "error")
    assert error_count == 1
460 assert len(error_events) == 1