# Personal memory agent — model routing, context tiers, and token/cost accounting.
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
import fnmatch
import inspect
import json
import logging
import os
import subprocess
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

import frontmatter

from think.utils import get_config, get_journal
18
# ---------------------------------------------------------------------------
# Tier constants
#
# Lower number == more capable: tier 1 (pro) is the strongest model,
# tier 3 (lite) the cheapest. Fallback logic walks toward tier 1.
# ---------------------------------------------------------------------------

TIER_PRO = 1
TIER_FLASH = 2
TIER_LITE = 3

# ---------------------------------------------------------------------------
# Model constants
#
# IMPORTANT: When updating these models, verify pricing support:
# 1. Run: make test-only TEST=tests/test_models.py::test_all_default_models_have_pricing
# 2. If test fails, update genai-prices: make update-prices
# 3. If still failing, the model may be too new for genai-prices
#
# The genai-prices library provides token cost data. New models may not have
# pricing immediately after release. See: https://pypi.org/project/genai-prices/
# ---------------------------------------------------------------------------

# Valid OpenAI reasoning effort suffixes appended to model names.
# E.g., "gpt-5.2-high" → reasoning_effort="high", "gpt-5.2" → omitted.
# calc_token_cost strips these suffixes before price lookup.
OPENAI_EFFORT_SUFFIXES = ("-none", "-low", "-medium", "-high", "-xhigh")

# Map model names that genai-prices doesn't recognize yet to a known equivalent.
MODEL_PRICE_ALIASES: Dict[str, str] = {
    "gpt-5.4": "gpt-5.2",
    "gpt-5.4-mini": "gpt-5-mini",
}

# Default model identifiers, one trio per provider family.
GEMINI_PRO = "gemini-3.1-pro-preview"
GEMINI_FLASH = "gemini-3-flash-preview"
GEMINI_LITE = "gemini-2.5-flash-lite"

# NOTE(review): GPT_5_MINI maps to "gpt-5.4-low" (full gpt-5.4 with a low
# reasoning-effort suffix) while GPT_5_NANO is the actual "-mini" model —
# confirm this flash/lite assignment is intentional.
GPT_5 = "gpt-5.4"
GPT_5_MINI = "gpt-5.4-low"
GPT_5_NANO = "gpt-5.4-mini"

CLAUDE_OPUS_4 = "claude-opus-4-5"
CLAUDE_SONNET_4 = "claude-sonnet-4-5"
CLAUDE_HAIKU_4 = "claude-haiku-4-5"

# ---------------------------------------------------------------------------
# System defaults: provider -> tier -> model
# ---------------------------------------------------------------------------

PROVIDER_DEFAULTS: Dict[str, Dict[int, str]] = {
    "google": {
        TIER_PRO: GEMINI_PRO,
        TIER_FLASH: GEMINI_FLASH,
        TIER_LITE: GEMINI_LITE,
    },
    "openai": {
        TIER_PRO: GPT_5,
        TIER_FLASH: GPT_5_MINI,
        TIER_LITE: GPT_5_NANO,
    },
    "anthropic": {
        TIER_PRO: CLAUDE_OPUS_4,
        TIER_FLASH: CLAUDE_SONNET_4,
        TIER_LITE: CLAUDE_HAIKU_4,
    },
}

# Per-agent-type defaults: primary provider, default tier, and backup provider
# used when the primary is unhealthy (see get_backup_provider).
TYPE_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "generate": {"provider": "google", "tier": TIER_FLASH, "backup": "anthropic"},
    "cogitate": {"provider": "openai", "tier": TIER_FLASH, "backup": "anthropic"},
}
87
88
89# ---------------------------------------------------------------------------
90# Exceptions
91# ---------------------------------------------------------------------------
92
93
class IncompleteJSONError(ValueError):
    """Raised when a JSON response comes back truncated.

    Truncation is typically caused by hitting the output-token limit,
    as indicated by the API's finish/stop reason.

    Attributes:
        reason: The finish/stop reason from the API (e.g., "MAX_TOKENS", "length").
        partial_text: The truncated response text, kept for debugging.
    """

    def __init__(self, reason: str, partial_text: str):
        super().__init__(f"JSON response incomplete (reason: {reason})")
        self.reason = reason
        self.partial_text = partial_text
106
107
# ---------------------------------------------------------------------------
# Prompt context discovery
#
# Context metadata (tier, label, group) is defined in prompt .md files via
# YAML frontmatter. This eliminates duplication between code and config.
#
# NAMING CONVENTION:
#   {module}.{feature}[.{operation}]
#
# Examples:
#   - observe.describe.frame -> observe module, describe feature, frame operation
#   - observe.enrich -> observe module, enrich feature (no sub-operation)
#   - talent.system.meetings -> talent module, system source, meetings config
#   - talent.entities.observer -> talent module, entities app, observer config
#   - app.chat.title -> apps module, chat app, title operation
#
# DISCOVERY SOURCES:
#   1. Prompt files listed in PROMPT_PATHS (with context in frontmatter)
#   2. Categories from observe/categories/*.md (tier/label/group in frontmatter)
#   3. Talent configs from talent/*.md and apps/*/talent/*.md
#
# When adding new contexts:
#   1. Create a .md prompt file with YAML frontmatter containing:
#      context, tier, label, group
#   2. Add the path to PROMPT_PATHS
#   3. If not listed, context falls back to the type's default tier
# ---------------------------------------------------------------------------

# Flat list of prompt files that define context metadata in frontmatter.
# Each must have: context, tier, label, group in YAML frontmatter.
# Paths are relative to the project root (the parent of this package).
PROMPT_PATHS: List[str] = [
    "observe/describe.md",
    "observe/enrich.md",
    "observe/extract.md",
    "observe/transcribe/gemini.md",
    "think/detect_created.md",
    "think/detect_transcript_segment.md",
    "think/detect_transcript_json.md",
    "think/planner.md",
]


# ---------------------------------------------------------------------------
# Dynamic context discovery
# ---------------------------------------------------------------------------

# Cached context registry, built lazily on first call to get_context_registry
# and never invalidated within the process.
_context_registry: Optional[Dict[str, Dict[str, Any]]] = None
# Older token logs used the "muse." namespace; it is rewritten to "talent."
# on read (see _normalize_legacy_context).
_LEGACY_CONTEXT_PREFIX = "muse."
_TALENT_CONTEXT_PREFIX = "talent."
158
159
def _discover_prompt_contexts() -> Dict[str, Dict[str, Any]]:
    """Load context metadata from prompt files listed in PROMPT_PATHS.

    Each file must have YAML frontmatter with:
    - context: The context string (e.g., "observe.enrich")
    - tier: Tier number (1=pro, 2=flash, 3=lite)
    - label: Human-readable name
    - group: Settings UI category

    Files that are missing, unreadable, or lack a ``context`` key are
    skipped with a warning rather than aborting discovery.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Mapping of context patterns to {tier, label, group} dicts.
    """
    # Hoist the logger lookup out of the loop and use lazy %-style args
    # so messages are only formatted when the warning is actually emitted.
    logger = logging.getLogger(__name__)
    contexts: Dict[str, Dict[str, Any]] = {}
    base_dir = Path(__file__).parent.parent  # Project root

    for rel_path in PROMPT_PATHS:
        path = base_dir / rel_path
        if not path.exists():
            logger.warning("Prompt file not found: %s", path)
            continue

        try:
            post = frontmatter.load(path)
            meta = post.metadata or {}

            context = meta.get("context")
            if not context:
                logger.warning("No context in %s", path)
                continue

            contexts[context] = {
                "tier": meta.get("tier", TIER_FLASH),
                "label": meta.get("label", context),
                "group": meta.get("group", "Other"),
            }
        except Exception as e:
            # Best-effort discovery: a single malformed file must not break startup.
            logger.warning("Failed to load %s: %s", path, e)

    return contexts
201
202
def _discover_talent_contexts() -> Dict[str, Dict[str, Any]]:
    """Discover talent context defaults from talent/*.md config files.

    Uses get_talent_configs() from think.talent to load all talent configurations
    and converts them to context patterns with tier/label/group metadata.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Mapping of context patterns to {tier, label, group, type} dicts.
        Context patterns are: talent.system.{name} or talent.{app}.{name}
    """
    from think.talent import get_talent_configs, key_to_context

    # Include disabled talents so every known context gets an entry.
    configs = get_talent_configs(include_disabled=True)

    return {
        key_to_context(key): {
            "tier": cfg.get("tier", TIER_FLASH),
            "label": cfg.get("label", cfg.get("title", key)),
            "group": cfg.get("group", "Think"),
            "type": cfg.get("type"),
        }
        for key, cfg in configs.items()
    }
232
233
def _build_context_registry() -> Dict[str, Dict[str, Any]]:
    """Build complete context registry from discovered configs.

    Merges, in order (later sources win on key collisions):
    1. Prompt contexts from _discover_prompt_contexts()
    2. Category contexts from observe/describe.py CATEGORIES
    3. Talent contexts from _discover_talent_contexts()

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Complete context registry mapping patterns to {tier, label, group}.
    """
    registry = _discover_prompt_contexts()

    # Categories are imported lazily to avoid a circular dependency; the
    # observe module may be absent entirely, in which case we merge nothing.
    try:
        from observe.describe import CATEGORIES
    except ImportError:
        CATEGORIES = {}

    for category, metadata in CATEGORIES.items():
        key = metadata.get("context", f"observe.describe.{category}")
        registry[key] = {
            "tier": metadata.get("tier", TIER_FLASH),
            "label": metadata.get("label", category.replace("_", " ").title()),
            "group": metadata.get("group", "Screen Analysis"),
        }

    # Talent contexts (agents + generators) are merged last.
    registry.update(_discover_talent_contexts())

    return registry
269
270
def get_context_registry() -> Dict[str, Dict[str, Any]]:
    """Get the complete context registry, building it lazily on first use.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Complete context registry mapping patterns to {tier, label, group}.
    """
    global _context_registry
    registry = _context_registry
    if registry is None:
        registry = _build_context_registry()
        _context_registry = registry
    return registry
283
284
def _resolve_tier(context: str, agent_type: str) -> int:
    """Resolve context to tier number.

    Lookup order:
    1. Journal config ``providers.contexts`` (exact match)
    2. Dynamic context registry (exact match)
    3. Journal config contexts (glob patterns, first match wins)
    4. Dynamic context registry (glob patterns, first match wins)
    5. The agent type's default tier

    Parameters
    ----------
    context
        Context string (e.g., "talent.system.default", "observe.describe.frame").
    agent_type
        Agent type ("generate" or "cogitate").

    Returns
    -------
    int
        Tier number (1=pro, 2=flash, 3=lite).
    """
    # get_config is imported at module level; the previous function-local
    # re-import was redundant and shadowed it.
    default_tier = TYPE_DEFAULTS[agent_type]["tier"]

    journal_config = get_config()
    providers_config = journal_config.get("providers", {})
    contexts = providers_config.get("contexts", {})

    # Dynamic context registry (discovered prompts, categories, talent configs)
    registry = get_context_registry()

    # Exact matches: journal config takes precedence over registry defaults.
    if context in contexts:
        return contexts[context].get("tier", default_tier)
    if context in registry:
        return registry[context]["tier"]

    # Glob patterns, same precedence order.
    for pattern, ctx_config in contexts.items():
        if fnmatch.fnmatch(context, pattern):
            return ctx_config.get("tier", default_tier)

    for pattern, ctx_default in registry.items():
        if fnmatch.fnmatch(context, pattern):
            return ctx_default["tier"]

    return default_tier
331
332
def _resolve_model(provider: str, tier: int, config_models: Dict[str, Any]) -> str:
    """Resolve tier to model string for a given provider.

    Config overrides win over system defaults. When the requested tier has
    no model, up to two more capable tiers are tried (3→2→1), and as a last
    resort the provider's flash-tier default is returned.

    Parameters
    ----------
    provider
        Provider name ("google", "openai", "anthropic").
    tier
        Tier number (1=pro, 2=flash, 3=lite).
    config_models
        The "models" section from providers config, mapping provider to tier overrides.

    Returns
    -------
    str
        Model identifier string.
    """
    overrides = config_models.get(provider, {})
    defaults = PROVIDER_DEFAULTS.get(provider, {})

    # At most three candidates: the requested tier and the two tiers above it.
    candidates = [tier] if tier <= 1 else [tier, tier - 1, tier - 2]

    for candidate in candidates:
        if candidate < 1:
            continue

        # Config override first (tier keys are strings in JSON).
        key = str(candidate)
        if key in overrides:
            return overrides[key]

        # Then the built-in default for this provider/tier.
        if candidate in defaults:
            return defaults[candidate]

    # Ultimate fallback: the provider's flash-tier system default.
    fallback = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["google"])
    return fallback.get(TIER_FLASH, GEMINI_FLASH)
375
376
def resolve_model_for_provider(
    context: str, provider: str, agent_type: str = "generate"
) -> str:
    """Resolve model for a specific provider based on context tier.

    Use this when provider is overridden from the default - resolves the
    appropriate model for the given provider at the context's tier.

    Parameters
    ----------
    context
        Context string (e.g., "talent.system.default").
    provider
        Provider name ("google", "openai", "anthropic").
    agent_type
        Agent type ("generate" or "cogitate").

    Returns
    -------
    str
        Model identifier string for the provider at the context's tier.
    """
    # get_config is imported at module level; the previous function-local
    # re-import was redundant and shadowed it.
    tier = _resolve_tier(context, agent_type)
    journal_config = get_config()
    config_models = journal_config.get("providers", {}).get("models", {})

    return _resolve_model(provider, tier, config_models)
407
408
def resolve_provider(context: str, agent_type: str) -> tuple[str, str]:
    """Resolve context to provider and model based on configuration.

    Matches context against configured contexts using exact match first,
    then glob patterns (via fnmatch), falling back to type-specific defaults.

    Supports both explicit model strings and tier-based routing:
    - {"provider": "google", "model": "gemini-3-flash-preview"} - explicit model
    - {"provider": "google", "tier": 2} - tier-based (2=flash)
    - {"tier": 1} - tier only, inherits provider from type default

    The "models" section in providers config allows overriding which model
    is used for each tier per provider.

    Parameters
    ----------
    context
        Context string (e.g., "observe.describe.frame", "talent.system.meetings").
    agent_type
        Agent type ("generate" or "cogitate").

    Returns
    -------
    tuple[str, str]
        (provider_name, model) tuple. Provider is one of "google", "openai",
        "anthropic". Model is the full model identifier string.
    """
    config = get_config()
    providers = config.get("providers", {})
    config_models = providers.get("models", {})

    # Get type-specific defaults from config, falling back to system constants
    type_defaults = TYPE_DEFAULTS[agent_type]
    type_config = providers.get(agent_type, {})
    default_provider = type_config.get("provider", type_defaults["provider"])
    default_tier = type_config.get("tier", type_defaults["tier"])

    # Handle explicit "model" key in type config (overrides tier-based resolution).
    # If both "model" and "tier" are present, tier wins.
    if "model" in type_config and "tier" not in type_config:
        default_model = type_config["model"]
    else:
        default_model = _resolve_model(default_provider, default_tier, config_models)

    contexts = providers.get("contexts", {})

    # Find matching context config from journal config (user overrides).
    match_config: Optional[Dict[str, Any]] = None

    if context and contexts:
        # Check for exact match first
        if context in contexts:
            match_config = contexts[context]
        else:
            # Check glob patterns - most specific (longest non-wildcard prefix) wins
            matches = []
            for pattern, ctx_config in contexts.items():
                if fnmatch.fnmatch(context, pattern):
                    specificity = len(pattern.split("*")[0])
                    matches.append((specificity, pattern, ctx_config))

            if matches:
                matches.sort(key=lambda x: x[0], reverse=True)
                _, _, match_config = matches[0]

    # No context match - check dynamic context registry for this context.
    # The registry only contributes a tier; the provider stays the default.
    if match_config is None:
        # Get dynamic context registry (discovered prompts, categories, talent configs)
        registry = get_context_registry()

        # Check for matching context default (exact match first, then glob)
        context_tier = None
        if context:
            if context in registry:
                context_tier = registry[context]["tier"]
            else:
                # Check glob patterns, again preferring the most specific pattern
                matches = []
                for pattern, ctx_default in registry.items():
                    if fnmatch.fnmatch(context, pattern):
                        specificity = len(pattern.split("*")[0])
                        matches.append((specificity, ctx_default["tier"]))
                if matches:
                    matches.sort(key=lambda x: x[0], reverse=True)
                    context_tier = matches[0][1]

        if context_tier is not None:
            model = _resolve_model(default_provider, context_tier, config_models)
            return (default_provider, model)

        return (default_provider, default_model)

    # Resolve provider (from match or default)
    provider = match_config.get("provider", default_provider)

    # Resolve model: explicit model takes precedence over tier
    if "model" in match_config:
        model = match_config["model"]
    elif "tier" in match_config:
        tier = match_config["tier"]
        # Validate tier; fall back to the type default on out-of-range values
        if not isinstance(tier, int) or tier < 1 or tier > 3:
            logging.getLogger(__name__).warning(
                "Invalid tier %r in context %r, using default", tier, context
            )
            tier = default_tier
        model = _resolve_model(provider, tier, config_models)
    else:
        # No model or tier specified - use default tier
        model = _resolve_model(provider, default_tier, config_models)

    return (provider, model)
520
521
def log_token_usage(
    model: str,
    usage: Union[Dict[str, Any], Any],
    context: Optional[str] = None,
    segment: Optional[str] = None,
    type: Optional[str] = None,
) -> None:
    """Log token usage to journal with unified schema.

    Providers normalize usage into the unified schema (see USAGE_KEYS in
    shared.py) before returning GenerateResult. This function passes
    through those known keys, computes total_tokens when missing, and
    handles a few legacy field aliases from CLI backends.

    Best-effort: any exception is swallowed so logging never breaks the
    main generation flow.

    Parameters
    ----------
    model : str
        Model name (e.g., "gpt-5", "gemini-2.5-flash")
    usage : dict
        Normalized usage dict with keys from USAGE_KEYS.
    context : str, optional
        Context string (e.g., "module.function:123" or "talent.system.default").
        If None, auto-detects from call stack.
    segment : str, optional
        Segment key (e.g., "143022_300") for attribution.
        If None, falls back to SOL_SEGMENT environment variable.
    type : str, optional
        Token entry type (e.g., "generate", "cogitate").
        NOTE: the parameter name shadows the builtin ``type``; kept for
        API compatibility with existing callers.
    """
    from think.providers.shared import USAGE_KEYS

    try:
        journal = get_journal()

        # Auto-detect calling context if not provided.
        # NOTE: this is frame-depth sensitive — f_back must be the direct
        # caller of log_token_usage. Do not wrap this logic in a helper.
        if context is None:
            frame = inspect.currentframe()
            caller_frame = frame.f_back if frame else None

            # Skip frames that contain "gemini" in function name
            # (provider wrappers that call this on behalf of their caller)
            while caller_frame and "gemini" in caller_frame.f_code.co_name.lower():
                caller_frame = caller_frame.f_back

            if caller_frame:
                module_name = caller_frame.f_globals.get("__name__", "unknown")
                func_name = caller_frame.f_code.co_name
                line_num = caller_frame.f_lineno

                # Clean up module name: drop the top-level package prefix
                for prefix in ["think.", "observe.", "convey."]:
                    if module_name.startswith(prefix):
                        module_name = module_name[len(prefix) :]
                        break

                context = f"{module_name}.{func_name}:{line_num}"

        # Pass through known keys from the already-normalized usage dict.
        # Falsy values (0, None) are intentionally dropped.
        normalized_usage: Dict[str, int] = {}
        for key in USAGE_KEYS:
            val = usage.get(key)
            if val:
                normalized_usage[key] = val

        # Legacy alias: some CLI backends emit cached_input_tokens
        if not normalized_usage.get("cached_tokens") and usage.get(
            "cached_input_tokens"
        ):
            normalized_usage["cached_tokens"] = usage["cached_input_tokens"]

        # Compute total_tokens from parts when missing (e.g. Codex CLI omits it)
        if not normalized_usage.get("total_tokens"):
            inp = normalized_usage.get("input_tokens", 0)
            out = normalized_usage.get("output_tokens", 0)
            if inp or out:
                normalized_usage["total_tokens"] = inp + out

        # Build token log entry
        token_data = {
            "timestamp": time.time(),
            "model": model,
            "context": context,
            "usage": normalized_usage,
        }

        # Add segment: prefer parameter, fallback to env (set by think/insight, observe handlers)
        segment_key = segment or os.getenv("SOL_SEGMENT")
        if segment_key:
            token_data["segment"] = segment_key
        if type:
            token_data["type"] = type

        # Save to journal/tokens/<YYYYMMDD>.jsonl (one file per day)
        tokens_dir = Path(journal) / "tokens"
        tokens_dir.mkdir(exist_ok=True)

        filename = time.strftime("%Y%m%d.jsonl")
        filepath = tokens_dir / filename

        # Atomic append - safe for parallel writers
        # (single write of one line in append mode)
        with open(filepath, "a") as f:
            f.write(json.dumps(token_data) + "\n")

    except Exception:
        # Silently fail - logging shouldn't break the main flow
        pass
627
628
def get_model_provider(model: str) -> str:
    """Get the provider name from a model identifier.

    Parameters
    ----------
    model : str
        Model name (e.g., "gpt-5", "gemini-2.5-flash", "claude-sonnet-4-5")

    Returns
    -------
    str
        Provider name: "openai", "google", "anthropic", or "unknown"
    """
    name = model.lower()

    # Provider is inferred from the model-name prefix.
    prefix_to_provider = (
        ("gpt", "openai"),
        ("gemini", "google"),
        ("claude", "anthropic"),
    )
    for prefix, provider in prefix_to_provider:
        if name.startswith(prefix):
            return provider
    return "unknown"
652
653
def calc_token_cost(token_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Calculate cost for a token usage record.

    Parameters
    ----------
    token_data : dict
        Token usage record from journal logs with structure:
        {
            "model": "gemini-2.5-flash",
            "usage": {
                "input_tokens": 1500,
                "output_tokens": 500,
                "cached_tokens": 800,
                "reasoning_tokens": 200,
                ...
            }
        }

    Returns
    -------
    dict or None
        Cost breakdown:
        {
            "total_cost": 0.00123,
            "input_cost": 0.00075,
            "output_cost": 0.00048,
            "currency": "USD"
        }
        Returns None if pricing unavailable or calculation fails.
    """
    try:
        from genai_prices import Usage, calc_price

        model = token_data.get("model")
        usage_data = token_data.get("usage", {})
        if not model or not usage_data:
            return None

        # Strip any OpenAI reasoning-effort suffix before the price lookup.
        for suffix in OPENAI_EFFORT_SUFFIXES:
            if model.endswith(suffix):
                model = model[: -len(suffix)]
                break

        # Determine provider before aliasing (an alias may change model family).
        provider_id = get_model_provider(model)
        if provider_id == "unknown":
            return None

        # Models too new for genai-prices are priced via a known equivalent.
        lookup_model = MODEL_PRICE_ALIASES.get(model, model)

        # Reasoning tokens are billed at output rates (Gemini reports them
        # separately), so fold them into output_tokens for pricing.
        prompt_tokens = usage_data.get("input_tokens", 0)
        billed_output = usage_data.get("output_tokens", 0) + usage_data.get(
            "reasoning_tokens", 0
        )
        cached = usage_data.get("cached_tokens", 0)

        price = calc_price(
            usage=Usage(
                input_tokens=prompt_tokens,
                output_tokens=billed_output,
                cache_read_tokens=cached if cached > 0 else None,
            ),
            model_ref=lookup_model,
            provider_id=provider_id,
        )

        return {
            "total_cost": float(price.total_price),
            "input_cost": float(price.input_price),
            "output_cost": float(price.output_price),
            "currency": "USD",
        }

    except Exception:
        # Best-effort: pricing data may simply be unavailable.
        return None
744
745
def calc_agent_cost(
    model: Optional[str], usage: Optional[Dict[str, Any]]
) -> Optional[float]:
    """Calculate total cost for an agent run from model and usage data.

    Convenience wrapper around calc_token_cost for agent cost lookups.

    Returns total cost in USD, or None if data is missing or pricing unavailable.
    """
    if not model or not usage:
        return None
    try:
        cost_data = calc_token_cost({"model": model, "usage": usage})
        if cost_data:
            return cost_data["total_cost"]
    except Exception:
        pass
    return None
764
765
def _normalize_legacy_context(ctx: str) -> str:
    """Normalize legacy token-log context strings to the talent namespace."""
    if not ctx.startswith(_LEGACY_CONTEXT_PREFIX):
        return ctx
    return _TALENT_CONTEXT_PREFIX + ctx.removeprefix(_LEGACY_CONTEXT_PREFIX)
771
772
def iter_token_log(day: str) -> Iterator[Dict[str, Any]]:
    """Iterate over token log entries for a given day.

    Yields parsed JSON entries from the token log file, skipping empty lines
    and invalid JSON. This is a shared utility for code that processes token logs.
    Legacy "muse." contexts are rewritten to the "talent." namespace on read.

    Parameters
    ----------
    day : str
        Day in YYYYMMDD format.

    Yields
    ------
    dict
        Parsed token log entry with fields: timestamp, model, context, usage,
        and optionally segment.
    """
    journal = get_journal()
    log_path = Path(journal) / "tokens" / f"{day}.jsonl"

    if not log_path.exists():
        return

    # JSONL files are written as JSON text; read them as UTF-8 explicitly
    # rather than relying on the platform's locale encoding.
    with open(log_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue
            ctx = entry.get("context")
            if isinstance(ctx, str):
                entry["context"] = _normalize_legacy_context(ctx)
            yield entry
809
810
def get_usage_cost(
    day: str,
    segment: Optional[str] = None,
    context: Optional[str] = None,
) -> Dict[str, Any]:
    """Get aggregated token usage and cost for a day, optionally filtered.

    This is a shared utility for apps that want to display cost information
    for segments, agent runs, or other contexts.

    Parameters
    ----------
    day : str
        Day in YYYYMMDD format.
    segment : str, optional
        Filter to entries with this exact segment key.
    context : str, optional
        Filter to entries where context starts with this prefix.
        For example, "talent.system" matches "talent.system.default".

    Returns
    -------
    dict
        Aggregated usage data:
        {
            "requests": int,
            "tokens": int,
            "cost": float,  # USD
        }
        Returns zeros if no matching entries or day file doesn't exist.
    """
    totals = {"requests": 0, "tokens": 0, "cost": 0.0}

    for entry in iter_token_log(day):
        # Apply segment/context filters.
        if segment is not None and entry.get("segment") != segment:
            continue
        if context is not None and not entry.get("context", "").startswith(context):
            continue

        # Entries from unrecognized providers can't be priced; skip them.
        if get_model_provider(entry.get("model", "unknown")) == "unknown":
            continue

        totals["requests"] += 1
        totals["tokens"] += entry.get("usage", {}).get("total_tokens", 0) or 0

        cost_data = calc_token_cost(entry)
        if cost_data:
            totals["cost"] += cost_data["total_cost"]

    return totals
868
869
870# ---------------------------------------------------------------------------
871# Unified generate/agenerate with provider routing
872# ---------------------------------------------------------------------------
873
874
875def _validate_json_response(result: Dict[str, Any], json_output: bool) -> None:
876 """Validate response for JSON output mode.
877
878 Raises IncompleteJSONError if finish_reason indicates truncation.
879 """
880 if not json_output:
881 return
882
883 finish_reason = result.get("finish_reason")
884 if finish_reason and finish_reason != "stop":
885 raise IncompleteJSONError(
886 reason=finish_reason,
887 partial_text=result.get("text", ""),
888 )
889
890
def generate(
    contents: Union[str, List[Any]],
    context: str,
    temperature: float = 0.3,
    max_output_tokens: int = 8192 * 2,
    system_instruction: Optional[str] = None,
    json_output: bool = False,
    thinking_budget: Optional[int] = None,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> str:
    """Generate text using the configured provider for the given context.

    Routes the request to the appropriate backend (Google, OpenAI, or Anthropic)
    based on the providers configuration in journal.json.

    Parameters
    ----------
    contents : str or List
        The content to send to the model.
    context : str
        Context string for routing and token logging (e.g., "talent.system.meetings").
        This is required and determines which provider/model to use.
    temperature : float
        Temperature for generation (default: 0.3).
    max_output_tokens : int
        Maximum tokens for the model's response output.
    system_instruction : str, optional
        System instruction for the model.
    json_output : bool
        Whether to request JSON response format.
    thinking_budget : int, optional
        Token budget for model thinking (ignored by providers that don't support it).
    timeout_s : float, optional
        Request timeout in seconds.
    **kwargs
        Additional provider-specific options passed through to the backend.
        A "model" kwarg, if present, overrides context-based model selection.

    Returns
    -------
    str
        Response text from the model.

    Raises
    ------
    ValueError
        If the resolved provider is not supported.
    IncompleteJSONError
        If json_output=True and response was truncated.
    """
    from think.providers import get_provider_module

    # An explicit "model" kwarg bypasses context-based model resolution.
    model_override = kwargs.pop("model", None)

    provider, model = resolve_provider(context, "generate")
    model = model_override or model

    # Registry lookup raises ValueError for unknown providers.
    backend = get_provider_module(provider)

    # Dispatch to the backend; run_generate returns a GenerateResult dict.
    result = backend.run_generate(
        contents=contents,
        model=model,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        system_instruction=system_instruction,
        json_output=json_output,
        thinking_budget=thinking_budget,
        timeout_s=timeout_s,
        **kwargs,
    )

    # Record usage before validation so truncated responses still get billed.
    usage = result.get("usage")
    if usage:
        log_token_usage(
            model=model,
            usage=usage,
            context=context,
            type="generate",
        )

    # Raises IncompleteJSONError on truncated JSON responses.
    _validate_json_response(result, json_output)

    return result["text"]
980
981
982# ---------------------------------------------------------------------------
983# Provider Health & Fallback Helpers
984# ---------------------------------------------------------------------------
985
986
def get_backup_provider(agent_type: str) -> Optional[str]:
    """Get the backup provider for the given agent type.

    Reads from the type-specific section in journal config, falling back
    to TYPE_DEFAULTS.

    Returns None if backup would be the same as the primary provider.
    """
    defaults = TYPE_DEFAULTS[agent_type]
    type_config = get_config().get("providers", {}).get(agent_type, {})

    primary = type_config.get("provider", defaults["provider"])
    backup = type_config.get("backup", defaults["backup"])

    return None if backup == primary else backup
1004
1005
def load_health_status() -> Optional[dict]:
    """Load health status from journal/health/agents.json.

    Returns the parsed dict, or None when the file is missing,
    unreadable, or contains invalid JSON.
    """
    try:
        health_path = Path(get_journal()) / "health" / "agents.json"
        return json.loads(health_path.read_text())
    except (OSError, json.JSONDecodeError):
        # FileNotFoundError is a subclass of OSError, so a missing file
        # lands here too.
        return None
1017
1018
def is_provider_healthy(provider: str, health_data: Optional[dict]) -> bool:
    """Report whether *provider* should be treated as healthy.

    Optimistic by design: missing health data, or no recorded results
    for this provider, count as healthy. A provider is unhealthy only
    when every one of its recorded results has ok falsy.
    """
    if health_data is None:
        return True
    saw_result = False
    for entry in health_data.get("results", []):
        if entry.get("provider") != provider:
            continue
        if entry.get("ok"):
            # One passing check is enough to call the provider healthy.
            return True
        saw_result = True
    # True when no results mentioned this provider; False when all failed.
    return not saw_result
1036
1037
def should_recheck_health(health_data: Optional[dict]) -> bool:
    """Return True when the health snapshot is more than one hour old.

    Missing data, an absent/empty checked_at field, or an unparseable
    timestamp all yield False (no recheck is triggered).
    """
    stamp = (health_data or {}).get("checked_at")
    if not stamp:
        return False
    try:
        checked = datetime.fromisoformat(stamp)
    except (ValueError, TypeError):
        # Malformed or non-string timestamp: treat as not stale.
        return False
    if checked.tzinfo is None:
        # Naive timestamps are assumed to be UTC.
        checked = checked.replace(tzinfo=timezone.utc)
    elapsed = datetime.now(timezone.utc) - checked
    return elapsed.total_seconds() > 3600
1056
1057
def request_health_recheck() -> None:
    """Spawn a detached ``sol agents check --targeted`` process.

    Best-effort and fire-and-forget: spawn failures are logged at debug
    level and never propagated to the caller.
    """
    cmd = ["sol", "agents", "check", "--targeted"]
    try:
        # Discard output; the health checker writes its own results file.
        subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except Exception:
        logger = logging.getLogger(__name__)
        logger.debug("Failed to request health recheck", exc_info=True)
1073
1074
def generate_with_result(
    contents: Union[str, List[Any]],
    context: str,
    temperature: float = 0.3,
    max_output_tokens: int = 8192 * 2,
    system_instruction: Optional[str] = None,
    json_output: bool = False,
    thinking_budget: Optional[int] = None,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> dict:
    """Generate text and return the full GenerateResult dict.

    Identical to generate() except the complete result (text plus usage
    metadata) is returned rather than just the text. Used by
    cortex-managed generators that need usage data for event emission.

    Returns
    -------
    dict
        GenerateResult with: text, usage, finish_reason, thinking.
    """
    from think.providers import get_provider_module

    # Explicit overrides in kwargs take precedence over context routing.
    forced_model = kwargs.pop("model", None)
    forced_provider = kwargs.pop("provider", None)

    provider, model = resolve_provider(context, "generate")
    if forced_provider:
        provider = forced_provider
        if not forced_model:
            # Re-resolve so the model matches the overridden provider.
            model = resolve_model_for_provider(context, provider, "generate")
    if forced_model:
        model = forced_model

    # Registry lookup raises ValueError for unknown providers.
    backend = get_provider_module(provider)

    result = backend.run_generate(
        contents=contents,
        model=model,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        system_instruction=system_instruction,
        json_output=json_output,
        thinking_budget=thinking_budget,
        timeout_s=timeout_s,
        **kwargs,
    )

    # Record usage before JSON validation so truncated responses still
    # get their token counts logged.
    usage = result.get("usage")
    if usage:
        log_token_usage(model=model, usage=usage, context=context, type="generate")

    _validate_json_response(result, json_output)
    return result
1138
1139
async def agenerate(
    contents: Union[str, List[Any]],
    context: str,
    temperature: float = 0.3,
    max_output_tokens: int = 8192 * 2,
    system_instruction: Optional[str] = None,
    json_output: bool = False,
    thinking_budget: Optional[int] = None,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> str:
    """Async variant of generate(): route text generation to the configured provider.

    The providers configuration in journal.json determines which backend
    (Google, OpenAI, or Anthropic) services the request for *context*.

    Parameters
    ----------
    contents : str or List
        Content to send to the model.
    context : str
        Routing and token-logging context (e.g., "talent.system.meetings").
        Required; selects the provider/model pair.
    temperature : float
        Sampling temperature (default: 0.3).
    max_output_tokens : int
        Cap on tokens in the model's response output.
    system_instruction : str, optional
        System instruction for the model.
    json_output : bool
        Request a JSON response format when True.
    thinking_budget : int, optional
        Thinking-token budget; providers without support ignore it.
    timeout_s : float, optional
        Request timeout in seconds.
    **kwargs
        Extra provider-specific options forwarded to the backend.

    Returns
    -------
    str
        Response text from the model.

    Raises
    ------
    ValueError
        If the resolved provider is not supported.
    IncompleteJSONError
        If json_output=True and response was truncated.
    """
    from think.providers import get_provider_module

    # A "model" kwarg lets callers (e.g. Batch) force a specific model.
    forced_model = kwargs.pop("model", None)

    provider, model = resolve_provider(context, "generate")
    model = forced_model or model

    # Registry lookup raises ValueError for unknown providers.
    backend = get_provider_module(provider)

    result = await backend.run_agenerate(
        contents=contents,
        model=model,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        system_instruction=system_instruction,
        json_output=json_output,
        thinking_budget=thinking_budget,
        timeout_s=timeout_s,
        **kwargs,
    )

    # Record usage before JSON validation so truncated responses still
    # get their token counts logged.
    usage = result.get("usage")
    if usage:
        log_token_usage(model=model, usage=usage, context=context, type="generate")

    _validate_json_response(result, json_output)
    return result["text"]
1229
1230
# Public API of this module: names exported by `from <module> import *`
# and the documented surface for external callers. Anything not listed
# here is considered internal even without a leading underscore.
__all__ = [
    # Provider configuration
    "TYPE_DEFAULTS",
    "PROMPT_PATHS",
    "get_context_registry",
    # Model constants (used by provider backends for defaults)
    "GEMINI_FLASH",
    "GPT_5",
    "CLAUDE_SONNET_4",
    # Unified API
    "generate",
    "generate_with_result",
    "agenerate",
    "resolve_provider",
    # Utilities
    "log_token_usage",
    "calc_token_cost",
    "calc_agent_cost",
    "get_usage_cost",
    "iter_token_log",
    "get_model_provider",
]