personal memory agent
at main 1252 lines 41 kB view raw
1# SPDX-License-Identifier: AGPL-3.0-only 2# Copyright (c) 2026 sol pbc 3 4import fnmatch 5import inspect 6import json 7import logging 8import os 9import subprocess 10import time 11from datetime import datetime, timezone 12from pathlib import Path 13from typing import Any, Dict, List, Optional, Union 14 15import frontmatter 16 17from think.utils import get_config, get_journal 18 19# --------------------------------------------------------------------------- 20# Tier constants 21# --------------------------------------------------------------------------- 22 23TIER_PRO = 1 24TIER_FLASH = 2 25TIER_LITE = 3 26 27# --------------------------------------------------------------------------- 28# Model constants 29# 30# IMPORTANT: When updating these models, verify pricing support: 31# 1. Run: make test-only TEST=tests/test_models.py::test_all_default_models_have_pricing 32# 2. If test fails, update genai-prices: make update-prices 33# 3. If still failing, the model may be too new for genai-prices 34# 35# The genai-prices library provides token cost data. New models may not have 36# pricing immediately after release. See: https://pypi.org/project/genai-prices/ 37# --------------------------------------------------------------------------- 38 39# Valid OpenAI reasoning effort suffixes appended to model names. 40# E.g., "gpt-5.2-high" → reasoning_effort="high", "gpt-5.2" → omitted. 41OPENAI_EFFORT_SUFFIXES = ("-none", "-low", "-medium", "-high", "-xhigh") 42 43# Map model names that genai-prices doesn't recognize yet to a known equivalent. 44MODEL_PRICE_ALIASES: Dict[str, str] = { 45 "gpt-5.4": "gpt-5.2", 46 "gpt-5.4-mini": "gpt-5-mini", 47} 48 49GEMINI_PRO = "gemini-3.1-pro-preview" 50GEMINI_FLASH = "gemini-3-flash-preview" 51GEMINI_LITE = "gemini-2.5-flash-lite" 52 53GPT_5 = "gpt-5.4" 54GPT_5_MINI = "gpt-5.4-low" 55GPT_5_NANO = "gpt-5.4-mini" 56 57CLAUDE_OPUS_4 = "claude-opus-4-5" 58CLAUDE_SONNET_4 = "claude-sonnet-4-5" 59CLAUDE_HAIKU_4 = "claude-haiku-4-5" 60 61# --------------------------------------------------------------------------- 62# System defaults: provider -> tier -> model 63# --------------------------------------------------------------------------- 64 65PROVIDER_DEFAULTS: Dict[str, Dict[int, str]] = { 66 "google": { 67 TIER_PRO: GEMINI_PRO, 68 TIER_FLASH: GEMINI_FLASH, 69 TIER_LITE: GEMINI_LITE, 70 }, 71 "openai": { 72 TIER_PRO: GPT_5, 73 TIER_FLASH: GPT_5_MINI, 74 TIER_LITE: GPT_5_NANO, 75 }, 76 "anthropic": { 77 TIER_PRO: CLAUDE_OPUS_4, 78 TIER_FLASH: CLAUDE_SONNET_4, 79 TIER_LITE: CLAUDE_HAIKU_4, 80 }, 81} 82 83TYPE_DEFAULTS: Dict[str, Dict[str, Any]] = { 84 "generate": {"provider": "google", "tier": TIER_FLASH, "backup": "anthropic"}, 85 "cogitate": {"provider": "openai", "tier": TIER_FLASH, "backup": "anthropic"}, 86} 87 88 89# --------------------------------------------------------------------------- 90# Exceptions 91# --------------------------------------------------------------------------- 92 93 94class IncompleteJSONError(ValueError): 95 """Raised when JSON response is truncated due to token limits or other reasons. 96 97 Attributes: 98 reason: The finish/stop reason from the API (e.g., "MAX_TOKENS", "length"). 99 partial_text: The truncated response text, useful for debugging. 100 """ 101 102 def __init__(self, reason: str, partial_text: str): 103 self.reason = reason 104 self.partial_text = partial_text 105 super().__init__(f"JSON response incomplete (reason: {reason})") 106 107 108# --------------------------------------------------------------------------- 109# Prompt context discovery 110# 111# Context metadata (tier, label, group) is defined in prompt .md files via 112# YAML frontmatter. This eliminates duplication between code and config. 113# 114# NAMING CONVENTION: 115# {module}.{feature}[.{operation}] 116# 117# Examples: 118# - observe.describe.frame -> observe module, describe feature, frame operation 119# - observe.enrich -> observe module, enrich feature (no sub-operation) 120# - talent.system.meetings -> talent module, system source, meetings config 121# - talent.entities.observer -> talent module, entities app, observer config 122# - app.chat.title -> apps module, chat app, title operation 123# 124# DISCOVERY SOURCES: 125# 1. Prompt files listed in PROMPT_PATHS (with context in frontmatter) 126# 2. Categories from observe/categories/*.md (tier/label/group in frontmatter) 127# 3. Talent configs from talent/*.md and apps/*/talent/*.md 128# 129# When adding new contexts: 130# 1. Create a .md prompt file with YAML frontmatter containing: 131# context, tier, label, group 132# 2. Add the path to PROMPT_PATHS 133# 3. If not listed, context falls back to the type's default tier 134# --------------------------------------------------------------------------- 135 136# Flat list of prompt files that define context metadata in frontmatter. 137# Each must have: context, tier, label, group in YAML frontmatter. 138PROMPT_PATHS: List[str] = [ 139 "observe/describe.md", 140 "observe/enrich.md", 141 "observe/extract.md", 142 "observe/transcribe/gemini.md", 143 "think/detect_created.md", 144 "think/detect_transcript_segment.md", 145 "think/detect_transcript_json.md", 146 "think/planner.md", 147] 148 149 150# --------------------------------------------------------------------------- 151# Dynamic context discovery 152# --------------------------------------------------------------------------- 153 154# Cached context registry (built lazily on first use) 155_context_registry: Optional[Dict[str, Dict[str, Any]]] = None 156_LEGACY_CONTEXT_PREFIX = "muse." 157_TALENT_CONTEXT_PREFIX = "talent." 158 159 160def _discover_prompt_contexts() -> Dict[str, Dict[str, Any]]: 161 """Load context metadata from prompt files listed in PROMPT_PATHS. 162 163 Each file must have YAML frontmatter with: 164 - context: The context string (e.g., "observe.enrich") 165 - tier: Tier number (1=pro, 2=flash, 3=lite) 166 - label: Human-readable name 167 - group: Settings UI category 168 169 Returns 170 ------- 171 Dict[str, Dict[str, Any]] 172 Mapping of context patterns to {tier, label, group} dicts. 173 """ 174 contexts = {} 175 base_dir = Path(__file__).parent.parent # Project root 176 177 for rel_path in PROMPT_PATHS: 178 path = base_dir / rel_path 179 if not path.exists(): 180 logging.getLogger(__name__).warning(f"Prompt file not found: {path}") 181 continue 182 183 try: 184 post = frontmatter.load(path) 185 meta = post.metadata or {} 186 187 context = meta.get("context") 188 if not context: 189 logging.getLogger(__name__).warning(f"No context in {path}") 190 continue 191 192 contexts[context] = { 193 "tier": meta.get("tier", TIER_FLASH), 194 "label": meta.get("label", context), 195 "group": meta.get("group", "Other"), 196 } 197 except Exception as e: 198 logging.getLogger(__name__).warning(f"Failed to load {path}: {e}") 199 200 return contexts 201 202 203def _discover_talent_contexts() -> Dict[str, Dict[str, Any]]: 204 """Discover talent context defaults from talent/*.md config files. 205 206 Uses get_talent_configs() from think.talent to load all talent configurations 207 and converts them to context patterns with tier/label/group metadata. 208 209 Returns 210 ------- 211 Dict[str, Dict[str, Any]] 212 Mapping of context patterns to {tier, label, group, type} dicts. 213 Context patterns are: talent.system.{name} or talent.{app}.{name} 214 """ 215 from think.talent import get_talent_configs, key_to_context 216 217 contexts = {} 218 219 # Load all talent configs (including disabled for completeness) 220 all_configs = get_talent_configs(include_disabled=True) 221 222 for key, config in all_configs.items(): 223 context = key_to_context(key) 224 contexts[context] = { 225 "tier": config.get("tier", TIER_FLASH), 226 "label": config.get("label", config.get("title", key)), 227 "group": config.get("group", "Think"), 228 "type": config.get("type"), 229 } 230 231 return contexts 232 233 234def _build_context_registry() -> Dict[str, Dict[str, Any]]: 235 """Build complete context registry from discovered configs. 236 237 Merges: 238 1. Prompt contexts from _discover_prompt_contexts() 239 2. Category contexts from observe/describe.py CATEGORIES 240 3. Talent contexts from _discover_talent_contexts() 241 242 Returns 243 ------- 244 Dict[str, Dict[str, Any]] 245 Complete context registry mapping patterns to {tier, label, group}. 246 """ 247 # Start with prompt contexts (from PROMPT_PATHS) 248 registry = _discover_prompt_contexts() 249 250 # Merge category contexts (lazy import to avoid circular dependency) 251 try: 252 from observe.describe import CATEGORIES 253 254 for category, metadata in CATEGORIES.items(): 255 context = metadata.get("context", f"observe.describe.{category}") 256 registry[context] = { 257 "tier": metadata.get("tier", TIER_FLASH), 258 "label": metadata.get("label", category.replace("_", " ").title()), 259 "group": metadata.get("group", "Screen Analysis"), 260 } 261 except ImportError: 262 pass # observe module not available 263 264 # Merge talent contexts (agents + generators) 265 talent_contexts = _discover_talent_contexts() 266 registry.update(talent_contexts) 267 268 return registry 269 270 271def get_context_registry() -> Dict[str, Dict[str, Any]]: 272 """Get the complete context registry, building it lazily on first use. 273 274 Returns 275 ------- 276 Dict[str, Dict[str, Any]] 277 Complete context registry mapping patterns to {tier, label, group}. 278 """ 279 global _context_registry 280 if _context_registry is None: 281 _context_registry = _build_context_registry() 282 return _context_registry 283 284 285def _resolve_tier(context: str, agent_type: str) -> int: 286 """Resolve context to tier number. 287 288 Checks journal config contexts first, then dynamic context registry with glob matching. 289 290 Parameters 291 ---------- 292 context 293 Context string (e.g., "talent.system.default", "observe.describe.frame"). 294 agent_type 295 Agent type ("generate" or "cogitate"). 296 297 Returns 298 ------- 299 int 300 Tier number (1=pro, 2=flash, 3=lite). 301 """ 302 from think.utils import get_config 303 304 default_tier = TYPE_DEFAULTS[agent_type]["tier"] 305 306 journal_config = get_config() 307 providers_config = journal_config.get("providers", {}) 308 contexts = providers_config.get("contexts", {}) 309 310 # Get dynamic context registry (discovered prompts, categories, talent configs) 311 registry = get_context_registry() 312 313 # Check journal config contexts first (exact match) 314 if context in contexts: 315 return contexts[context].get("tier", default_tier) 316 317 # Check context registry (exact match) 318 if context in registry: 319 return registry[context]["tier"] 320 321 # Check glob patterns in both 322 for pattern, ctx_config in contexts.items(): 323 if fnmatch.fnmatch(context, pattern): 324 return ctx_config.get("tier", default_tier) 325 326 for pattern, ctx_default in registry.items(): 327 if fnmatch.fnmatch(context, pattern): 328 return ctx_default["tier"] 329 330 return default_tier 331 332 333def _resolve_model(provider: str, tier: int, config_models: Dict[str, Any]) -> str: 334 """Resolve tier to model string for a given provider. 335 336 Checks config overrides first, then falls back to system defaults. 337 If requested tier is unavailable, falls back to more capable tiers 338 (3→2→1, i.e., lite→flash→pro). 339 340 Parameters 341 ---------- 342 provider 343 Provider name ("google", "openai", "anthropic"). 344 tier 345 Tier number (1=pro, 2=flash, 3=lite). 346 config_models 347 The "models" section from providers config, mapping provider to tier overrides. 348 349 Returns 350 ------- 351 str 352 Model identifier string. 353 """ 354 # Check config overrides first 355 provider_overrides = config_models.get(provider, {}) 356 357 # Try requested tier, then fall back to more capable tiers (lower numbers) 358 for t in [tier, tier - 1, tier - 2] if tier > 1 else [tier]: 359 if t < 1: 360 continue 361 362 # Check config override (tier as string key in JSON) 363 tier_key = str(t) 364 if tier_key in provider_overrides: 365 return provider_overrides[tier_key] 366 367 # Check system defaults 368 provider_defaults = PROVIDER_DEFAULTS.get(provider, {}) 369 if t in provider_defaults: 370 return provider_defaults[t] 371 372 # Ultimate fallback: system default for provider at TIER_FLASH 373 provider_defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["google"]) 374 return provider_defaults.get(TIER_FLASH, GEMINI_FLASH) 375 376 377def resolve_model_for_provider( 378 context: str, provider: str, agent_type: str = "generate" 379) -> str: 380 """Resolve model for a specific provider based on context tier. 381 382 Use this when provider is overridden from the default - resolves the 383 appropriate model for the given provider at the context's tier. 384 385 Parameters 386 ---------- 387 context 388 Context string (e.g., "talent.system.default"). 389 provider 390 Provider name ("google", "openai", "anthropic"). 391 agent_type 392 Agent type ("generate" or "cogitate"). 393 394 Returns 395 ------- 396 str 397 Model identifier string for the provider at the context's tier. 398 """ 399 from think.utils import get_config 400 401 tier = _resolve_tier(context, agent_type) 402 journal_config = get_config() 403 providers_config = journal_config.get("providers", {}) 404 config_models = providers_config.get("models", {}) 405 406 return _resolve_model(provider, tier, config_models) 407 408 409def resolve_provider(context: str, agent_type: str) -> tuple[str, str]: 410 """Resolve context to provider and model based on configuration. 411 412 Matches context against configured contexts using exact match first, 413 then glob patterns (via fnmatch), falling back to type-specific defaults. 414 415 Supports both explicit model strings and tier-based routing: 416 - {"provider": "google", "model": "gemini-3-flash-preview"} - explicit model 417 - {"provider": "google", "tier": 2} - tier-based (2=flash) 418 - {"tier": 1} - tier only, inherits provider from type default 419 420 The "models" section in providers config allows overriding which model 421 is used for each tier per provider. 422 423 Parameters 424 ---------- 425 context 426 Context string (e.g., "observe.describe.frame", "talent.system.meetings"). 427 agent_type 428 Agent type ("generate" or "cogitate"). 429 430 Returns 431 ------- 432 tuple[str, str] 433 (provider_name, model) tuple. Provider is one of "google", "openai", 434 "anthropic". Model is the full model identifier string. 435 """ 436 config = get_config() 437 providers = config.get("providers", {}) 438 config_models = providers.get("models", {}) 439 440 # Get type-specific defaults from config, falling back to system constants 441 type_defaults = TYPE_DEFAULTS[agent_type] 442 type_config = providers.get(agent_type, {}) 443 default_provider = type_config.get("provider", type_defaults["provider"]) 444 default_tier = type_config.get("tier", type_defaults["tier"]) 445 446 # Handle explicit "model" key in type config (overrides tier-based resolution) 447 if "model" in type_config and "tier" not in type_config: 448 default_model = type_config["model"] 449 else: 450 default_model = _resolve_model(default_provider, default_tier, config_models) 451 452 contexts = providers.get("contexts", {}) 453 454 # Find matching context config 455 match_config: Optional[Dict[str, Any]] = None 456 457 if context and contexts: 458 # Check for exact match first 459 if context in contexts: 460 match_config = contexts[context] 461 else: 462 # Check glob patterns - most specific (longest non-wildcard prefix) wins 463 matches = [] 464 for pattern, ctx_config in contexts.items(): 465 if fnmatch.fnmatch(context, pattern): 466 specificity = len(pattern.split("*")[0]) 467 matches.append((specificity, pattern, ctx_config)) 468 469 if matches: 470 matches.sort(key=lambda x: x[0], reverse=True) 471 _, _, match_config = matches[0] 472 473 # No context match - check dynamic context registry for this context 474 if match_config is None: 475 # Get dynamic context registry (discovered prompts, categories, talent configs) 476 registry = get_context_registry() 477 478 # Check for matching context default (exact match first, then glob) 479 context_tier = None 480 if context: 481 if context in registry: 482 context_tier = registry[context]["tier"] 483 else: 484 # Check glob patterns 485 matches = [] 486 for pattern, ctx_default in registry.items(): 487 if fnmatch.fnmatch(context, pattern): 488 specificity = len(pattern.split("*")[0]) 489 matches.append((specificity, ctx_default["tier"])) 490 if matches: 491 matches.sort(key=lambda x: x[0], reverse=True) 492 context_tier = matches[0][1] 493 494 if context_tier is not None: 495 model = _resolve_model(default_provider, context_tier, config_models) 496 return (default_provider, model) 497 498 return (default_provider, default_model) 499 500 # Resolve provider (from match or default) 501 provider = match_config.get("provider", default_provider) 502 503 # Resolve model: explicit model takes precedence over tier 504 if "model" in match_config: 505 model = match_config["model"] 506 elif "tier" in match_config: 507 tier = match_config["tier"] 508 # Validate tier 509 if not isinstance(tier, int) or tier < 1 or tier > 3: 510 logging.getLogger(__name__).warning( 511 "Invalid tier %r in context %r, using default", tier, context 512 ) 513 tier = default_tier 514 model = _resolve_model(provider, tier, config_models) 515 else: 516 # No model or tier specified - use default tier 517 model = _resolve_model(provider, default_tier, config_models) 518 519 return (provider, model) 520 521 522def log_token_usage( 523 model: str, 524 usage: Union[Dict[str, Any], Any], 525 context: Optional[str] = None, 526 segment: Optional[str] = None, 527 type: Optional[str] = None, 528) -> None: 529 """Log token usage to journal with unified schema. 530 531 Providers normalize usage into the unified schema (see USAGE_KEYS in 532 shared.py) before returning GenerateResult. This function passes 533 through those known keys, computes total_tokens when missing, and 534 handles a few legacy field aliases from CLI backends. 535 536 Parameters 537 ---------- 538 model : str 539 Model name (e.g., "gpt-5", "gemini-2.5-flash") 540 usage : dict 541 Normalized usage dict with keys from USAGE_KEYS. 542 context : str, optional 543 Context string (e.g., "module.function:123" or "talent.system.default"). 544 If None, auto-detects from call stack. 545 segment : str, optional 546 Segment key (e.g., "143022_300") for attribution. 547 If None, falls back to SOL_SEGMENT environment variable. 548 type : str, optional 549 Token entry type (e.g., "generate", "cogitate"). 550 """ 551 from think.providers.shared import USAGE_KEYS 552 553 try: 554 journal = get_journal() 555 556 # Auto-detect calling context if not provided 557 if context is None: 558 frame = inspect.currentframe() 559 caller_frame = frame.f_back if frame else None 560 561 # Skip frames that contain "gemini" in function name 562 while caller_frame and "gemini" in caller_frame.f_code.co_name.lower(): 563 caller_frame = caller_frame.f_back 564 565 if caller_frame: 566 module_name = caller_frame.f_globals.get("__name__", "unknown") 567 func_name = caller_frame.f_code.co_name 568 line_num = caller_frame.f_lineno 569 570 # Clean up module name 571 for prefix in ["think.", "observe.", "convey."]: 572 if module_name.startswith(prefix): 573 module_name = module_name[len(prefix) :] 574 break 575 576 context = f"{module_name}.{func_name}:{line_num}" 577 578 # Pass through known keys from the already-normalized usage dict. 579 normalized_usage: Dict[str, int] = {} 580 for key in USAGE_KEYS: 581 val = usage.get(key) 582 if val: 583 normalized_usage[key] = val 584 585 # Legacy alias: some CLI backends emit cached_input_tokens 586 if not normalized_usage.get("cached_tokens") and usage.get( 587 "cached_input_tokens" 588 ): 589 normalized_usage["cached_tokens"] = usage["cached_input_tokens"] 590 591 # Compute total_tokens from parts when missing (e.g. Codex CLI omits it) 592 if not normalized_usage.get("total_tokens"): 593 inp = normalized_usage.get("input_tokens", 0) 594 out = normalized_usage.get("output_tokens", 0) 595 if inp or out: 596 normalized_usage["total_tokens"] = inp + out 597 598 # Build token log entry 599 token_data = { 600 "timestamp": time.time(), 601 "model": model, 602 "context": context, 603 "usage": normalized_usage, 604 } 605 606 # Add segment: prefer parameter, fallback to env (set by think/insight, observe handlers) 607 segment_key = segment or os.getenv("SOL_SEGMENT") 608 if segment_key: 609 token_data["segment"] = segment_key 610 if type: 611 token_data["type"] = type 612 613 # Save to journal/tokens/<YYYYMMDD>.jsonl (one file per day) 614 tokens_dir = Path(journal) / "tokens" 615 tokens_dir.mkdir(exist_ok=True) 616 617 filename = time.strftime("%Y%m%d.jsonl") 618 filepath = tokens_dir / filename 619 620 # Atomic append - safe for parallel writers 621 with open(filepath, "a") as f: 622 f.write(json.dumps(token_data) + "\n") 623 624 except Exception: 625 # Silently fail - logging shouldn't break the main flow 626 pass 627 628 629def get_model_provider(model: str) -> str: 630 """Get the provider name from a model identifier. 631 632 Parameters 633 ---------- 634 model : str 635 Model name (e.g., "gpt-5", "gemini-2.5-flash", "claude-sonnet-4-5") 636 637 Returns 638 ------- 639 str 640 Provider name: "openai", "google", "anthropic", or "unknown" 641 """ 642 model_lower = model.lower() 643 644 if model_lower.startswith("gpt"): 645 return "openai" 646 elif model_lower.startswith("gemini"): 647 return "google" 648 elif model_lower.startswith("claude"): 649 return "anthropic" 650 else: 651 return "unknown" 652 653 654def calc_token_cost(token_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: 655 """Calculate cost for a token usage record. 656 657 Parameters 658 ---------- 659 token_data : dict 660 Token usage record from journal logs with structure: 661 { 662 "model": "gemini-2.5-flash", 663 "usage": { 664 "input_tokens": 1500, 665 "output_tokens": 500, 666 "cached_tokens": 800, 667 "reasoning_tokens": 200, 668 ... 669 } 670 } 671 672 Returns 673 ------- 674 dict or None 675 Cost breakdown: 676 { 677 "total_cost": 0.00123, 678 "input_cost": 0.00075, 679 "output_cost": 0.00048, 680 "currency": "USD" 681 } 682 Returns None if pricing unavailable or calculation fails. 683 """ 684 try: 685 from genai_prices import Usage, calc_price 686 687 model = token_data.get("model") 688 usage_data = token_data.get("usage", {}) 689 690 if not model or not usage_data: 691 return None 692 693 # Strip OpenAI reasoning effort suffixes for price lookup 694 for suffix in OPENAI_EFFORT_SUFFIXES: 695 if model.endswith(suffix): 696 model = model[: -len(suffix)] 697 break 698 699 # Get provider ID before aliasing (alias may change the model family) 700 provider_id = get_model_provider(model) 701 if provider_id == "unknown": 702 return None 703 704 # Apply price aliases for models genai-prices doesn't recognize yet 705 model = MODEL_PRICE_ALIASES.get(model, model) 706 707 # Map our token fields to genai_prices Usage format 708 # Note: Gemini reports reasoning_tokens separately, but they're billed at 709 # output token rates. genai-prices doesn't have a separate field for reasoning, 710 # so we add them to output_tokens for correct pricing. 711 input_tokens = usage_data.get("input_tokens", 0) 712 output_tokens = usage_data.get("output_tokens", 0) 713 cached_tokens = usage_data.get("cached_tokens", 0) 714 reasoning_tokens = usage_data.get("reasoning_tokens", 0) 715 716 # Add reasoning tokens to output for pricing (Gemini bills them as output) 717 total_output_tokens = output_tokens + reasoning_tokens 718 719 # Create Usage object 720 usage = Usage( 721 input_tokens=input_tokens, 722 output_tokens=total_output_tokens, 723 cache_read_tokens=cached_tokens if cached_tokens > 0 else None, 724 ) 725 726 # Calculate price 727 result = calc_price( 728 usage=usage, 729 model_ref=model, 730 provider_id=provider_id, 731 ) 732 733 # Return simplified cost breakdown 734 return { 735 "total_cost": float(result.total_price), 736 "input_cost": float(result.input_price), 737 "output_cost": float(result.output_price), 738 "currency": "USD", 739 } 740 741 except Exception: 742 # Silently fail if pricing unavailable 743 return None 744 745 746def calc_agent_cost( 747 model: Optional[str], usage: Optional[Dict[str, Any]] 748) -> Optional[float]: 749 """Calculate total cost for an agent run from model and usage data. 750 751 Convenience wrapper around calc_token_cost for agent cost lookups. 752 753 Returns total cost in USD, or None if data is missing or pricing unavailable. 754 """ 755 if not model or not usage: 756 return None 757 try: 758 cost_data = calc_token_cost({"model": model, "usage": usage}) 759 if cost_data: 760 return cost_data["total_cost"] 761 except Exception: 762 return None 763 return None 764 765 766def _normalize_legacy_context(ctx: str) -> str: 767 """Normalize legacy token-log context strings to the talent namespace.""" 768 if ctx.startswith(_LEGACY_CONTEXT_PREFIX): 769 return _TALENT_CONTEXT_PREFIX + ctx[len(_LEGACY_CONTEXT_PREFIX) :] 770 return ctx 771 772 773def iter_token_log(day: str) -> Any: 774 """Iterate over token log entries for a given day. 775 776 Yields parsed JSON entries from the token log file, skipping empty lines 777 and invalid JSON. This is a shared utility for code that processes token logs. 778 779 Parameters 780 ---------- 781 day : str 782 Day in YYYYMMDD format. 783 784 Yields 785 ------ 786 dict 787 Parsed token log entry with fields: timestamp, model, context, usage, 788 and optionally segment. 789 """ 790 journal = get_journal() 791 log_path = Path(journal) / "tokens" / f"{day}.jsonl" 792 793 if not log_path.exists(): 794 return 795 796 with open(log_path, "r") as f: 797 for line in f: 798 line = line.strip() 799 if not line: 800 continue 801 try: 802 entry = json.loads(line) 803 ctx = entry.get("context") 804 if isinstance(ctx, str): 805 entry["context"] = _normalize_legacy_context(ctx) 806 yield entry 807 except json.JSONDecodeError: 808 continue 809 810 811def get_usage_cost( 812 day: str, 813 segment: Optional[str] = None, 814 context: Optional[str] = None, 815) -> Dict[str, Any]: 816 """Get aggregated token usage and cost for a day, optionally filtered. 817 818 This is a shared utility for apps that want to display cost information 819 for segments, agent runs, or other contexts. 820 821 Parameters 822 ---------- 823 day : str 824 Day in YYYYMMDD format. 825 segment : str, optional 826 Filter to entries with this exact segment key. 827 context : str, optional 828 Filter to entries where context starts with this prefix. 829 For example, "talent.system" matches "talent.system.default". 830 831 Returns 832 ------- 833 dict 834 Aggregated usage data: 835 { 836 "requests": int, 837 "tokens": int, 838 "cost": float, # USD 839 } 840 Returns zeros if no matching entries or day file doesn't exist. 841 """ 842 result = {"requests": 0, "tokens": 0, "cost": 0.0} 843 844 for entry in iter_token_log(day): 845 # Apply filters 846 if segment is not None and entry.get("segment") != segment: 847 continue 848 if context is not None: 849 entry_context = entry.get("context", "") 850 if not entry_context.startswith(context): 851 continue 852 853 # Skip unknown providers (can't calculate cost) 854 model = entry.get("model", "unknown") 855 if get_model_provider(model) == "unknown": 856 continue 857 858 # Accumulate 859 usage = entry.get("usage", {}) 860 result["requests"] += 1 861 result["tokens"] += usage.get("total_tokens", 0) or 0 862 863 cost_data = calc_token_cost(entry) 864 if cost_data: 865 result["cost"] += cost_data["total_cost"] 866 867 return result 868 869 870# --------------------------------------------------------------------------- 871# Unified generate/agenerate with provider routing 872# --------------------------------------------------------------------------- 873 874 875def _validate_json_response(result: Dict[str, Any], json_output: bool) -> None: 876 """Validate response for JSON output mode. 877 878 Raises IncompleteJSONError if finish_reason indicates truncation. 879 """ 880 if not json_output: 881 return 882 883 finish_reason = result.get("finish_reason") 884 if finish_reason and finish_reason != "stop": 885 raise IncompleteJSONError( 886 reason=finish_reason, 887 partial_text=result.get("text", ""), 888 ) 889 890 891def generate( 892 contents: Union[str, List[Any]], 893 context: str, 894 temperature: float = 0.3, 895 max_output_tokens: int = 8192 * 2, 896 system_instruction: Optional[str] = None, 897 json_output: bool = False, 898 thinking_budget: Optional[int] = None, 899 timeout_s: Optional[float] = None, 900 **kwargs: Any, 901) -> str: 902 """Generate text using the configured provider for the given context. 903 904 Routes the request to the appropriate backend (Google, OpenAI, or Anthropic) 905 based on the providers configuration in journal.json. 906 907 Parameters 908 ---------- 909 contents : str or List 910 The content to send to the model. 911 context : str 912 Context string for routing and token logging (e.g., "talent.system.meetings"). 913 This is required and determines which provider/model to use. 914 temperature : float 915 Temperature for generation (default: 0.3). 916 max_output_tokens : int 917 Maximum tokens for the model's response output. 918 system_instruction : str, optional 919 System instruction for the model. 920 json_output : bool 921 Whether to request JSON response format. 922 thinking_budget : int, optional 923 Token budget for model thinking (ignored by providers that don't support it). 924 timeout_s : float, optional 925 Request timeout in seconds. 926 **kwargs 927 Additional provider-specific options passed through to the backend. 928 929 Returns 930 ------- 931 str 932 Response text from the model. 933 934 Raises 935 ------ 936 ValueError 937 If the resolved provider is not supported. 938 IncompleteJSONError 939 If json_output=True and response was truncated. 940 """ 941 from think.providers import get_provider_module 942 943 # Allow model override via kwargs (used by callers with explicit model selection) 944 model_override = kwargs.pop("model", None) 945 946 provider, model = resolve_provider(context, "generate") 947 if model_override: 948 model = model_override 949 950 # Get provider module via registry (raises ValueError for unknown providers) 951 provider_mod = get_provider_module(provider) 952 953 # Call provider's run_generate (returns GenerateResult) 954 result = provider_mod.run_generate( 955 contents=contents, 956 model=model, 957 temperature=temperature, 958 max_output_tokens=max_output_tokens, 959 system_instruction=system_instruction, 960 json_output=json_output, 961 thinking_budget=thinking_budget, 962 timeout_s=timeout_s, 963 **kwargs, 964 ) 965 966 # Log token usage centrally (before validation so truncated responses 967 # still get their usage recorded) 968 if result.get("usage"): 969 log_token_usage( 970 model=model, 971 usage=result["usage"], 972 context=context, 973 type="generate", 974 ) 975 976 # Validate JSON output if requested 977 _validate_json_response(result, json_output) 978 979 return result["text"] 980 981 982# --------------------------------------------------------------------------- 983# Provider Health & Fallback Helpers 984# --------------------------------------------------------------------------- 985 986 987def get_backup_provider(agent_type: str) -> Optional[str]: 988 """Get the backup provider for the given agent type. 989 990 Reads from the type-specific section in journal config, falling back 991 to TYPE_DEFAULTS. 992 993 Returns None if backup would be the same as the primary provider. 994 """ 995 type_defaults = TYPE_DEFAULTS[agent_type] 996 config = get_config() 997 providers_config = config.get("providers", {}) 998 type_config = providers_config.get(agent_type, {}) 999 primary_provider = type_config.get("provider", type_defaults["provider"]) 1000 backup = type_config.get("backup", type_defaults["backup"]) 1001 if backup == primary_provider: 1002 return None 1003 return backup 1004 1005 1006def load_health_status() -> Optional[dict]: 1007 """Load health status from journal/health/agents.json. 1008 1009 Returns parsed dict or None if file is missing/unreadable. 1010 """ 1011 try: 1012 health_path = Path(get_journal()) / "health" / "agents.json" 1013 with open(health_path) as f: 1014 return json.load(f) 1015 except (FileNotFoundError, json.JSONDecodeError, OSError): 1016 return None 1017 1018 1019def is_provider_healthy(provider: str, health_data: Optional[dict]) -> bool: 1020 """Check if a provider is healthy based on health data. 1021 1022 Returns True (assume healthy) when: 1023 - health_data is None (no data available) 1024 - No results exist for the provider 1025 - Any result for the provider has ok=True 1026 1027 Returns False only when all results for the provider have ok=False. 1028 """ 1029 if health_data is None: 1030 return True 1031 results = health_data.get("results", []) 1032 provider_results = [r for r in results if r.get("provider") == provider] 1033 if not provider_results: 1034 return True 1035 return any(r.get("ok") for r in provider_results) 1036 1037 1038def should_recheck_health(health_data: Optional[dict]) -> bool: 1039 """Check if health data is stale (>1 hour old). 1040 1041 Returns False when health_data is None or on parse errors. 1042 """ 1043 if health_data is None: 1044 return False 1045 checked_at = health_data.get("checked_at") 1046 if not checked_at: 1047 return False 1048 try: 1049 checked_time = datetime.fromisoformat(checked_at) 1050 if checked_time.tzinfo is None: 1051 checked_time = checked_time.replace(tzinfo=timezone.utc) 1052 age = datetime.now(timezone.utc) - checked_time 1053 return age.total_seconds() > 3600 1054 except (ValueError, TypeError): 1055 return False 1056 1057 1058def request_health_recheck() -> None: 1059 """Request a health re-check by spawning a background process. 1060 1061 Fire-and-forget; errors are logged but never propagated. 1062 """ 1063 try: 1064 subprocess.Popen( 1065 ["sol", "agents", "check", "--targeted"], 1066 stdout=subprocess.DEVNULL, 1067 stderr=subprocess.DEVNULL, 1068 ) 1069 except Exception: 1070 logging.getLogger(__name__).debug( 1071 "Failed to request health recheck", exc_info=True 1072 ) 1073 1074 1075def generate_with_result( 1076 contents: Union[str, List[Any]], 1077 context: str, 1078 temperature: float = 0.3, 1079 max_output_tokens: int = 8192 * 2, 1080 system_instruction: Optional[str] = None, 1081 json_output: bool = False, 1082 thinking_budget: Optional[int] = None, 1083 timeout_s: Optional[float] = None, 1084 **kwargs: Any, 1085) -> dict: 1086 """Generate text and return full result with usage data. 1087 1088 Same as generate() but returns the full GenerateResult dict instead of 1089 just the text. Used by cortex-managed generators that need usage data 1090 for event emission. 1091 1092 Returns 1093 ------- 1094 dict 1095 GenerateResult with: text, usage, finish_reason, thinking. 1096 """ 1097 from think.providers import get_provider_module 1098 1099 model_override = kwargs.pop("model", None) 1100 provider_override = kwargs.pop("provider", None) 1101 1102 provider, model = resolve_provider(context, "generate") 1103 if provider_override: 1104 provider = provider_override 1105 if not model_override: 1106 model = resolve_model_for_provider(context, provider, "generate") 1107 if model_override: 1108 model = model_override 1109 1110 provider_mod = get_provider_module(provider) 1111 1112 result = provider_mod.run_generate( 1113 contents=contents, 1114 model=model, 1115 temperature=temperature, 1116 max_output_tokens=max_output_tokens, 1117 system_instruction=system_instruction, 1118 json_output=json_output, 1119 thinking_budget=thinking_budget, 1120 timeout_s=timeout_s, 1121 **kwargs, 1122 ) 1123 1124 # Log token usage centrally (before validation so truncated responses 1125 # still get their usage recorded) 1126 if result.get("usage"): 1127 log_token_usage( 1128 model=model, 1129 usage=result["usage"], 1130 context=context, 1131 type="generate", 1132 ) 1133 1134 # Validate JSON output if requested 1135 _validate_json_response(result, json_output) 1136 1137 return result 1138 1139 1140async def agenerate( 1141 contents: Union[str, List[Any]], 1142 context: str, 1143 temperature: float = 0.3, 1144 max_output_tokens: int = 8192 * 2, 1145 system_instruction: Optional[str] = None, 1146 json_output: bool = False, 1147 thinking_budget: Optional[int] = None, 1148 timeout_s: Optional[float] = None, 1149 **kwargs: Any, 1150) -> str: 1151 """Async generate text using the configured provider for the given context. 1152 1153 Routes the request to the appropriate backend (Google, OpenAI, or Anthropic) 1154 based on the providers configuration in journal.json. 1155 1156 Parameters 1157 ---------- 1158 contents : str or List 1159 The content to send to the model. 1160 context : str 1161 Context string for routing and token logging (e.g., "talent.system.meetings"). 1162 This is required and determines which provider/model to use. 1163 temperature : float 1164 Temperature for generation (default: 0.3). 1165 max_output_tokens : int 1166 Maximum tokens for the model's response output. 1167 system_instruction : str, optional 1168 System instruction for the model. 1169 json_output : bool 1170 Whether to request JSON response format. 1171 thinking_budget : int, optional 1172 Token budget for model thinking (ignored by providers that don't support it). 1173 timeout_s : float, optional 1174 Request timeout in seconds. 1175 **kwargs 1176 Additional provider-specific options passed through to the backend. 1177 1178 Returns 1179 ------- 1180 str 1181 Response text from the model. 1182 1183 Raises 1184 ------ 1185 ValueError 1186 If the resolved provider is not supported. 1187 IncompleteJSONError 1188 If json_output=True and response was truncated. 1189 """ 1190 from think.providers import get_provider_module 1191 1192 # Allow model override via kwargs (used by Batch for explicit model selection) 1193 model_override = kwargs.pop("model", None) 1194 1195 provider, model = resolve_provider(context, "generate") 1196 if model_override: 1197 model = model_override 1198 1199 # Get provider module via registry (raises ValueError for unknown providers) 1200 provider_mod = get_provider_module(provider) 1201 1202 # Call provider's run_agenerate (returns GenerateResult) 1203 result = await provider_mod.run_agenerate( 1204 contents=contents, 1205 model=model, 1206 temperature=temperature, 1207 max_output_tokens=max_output_tokens, 1208 system_instruction=system_instruction, 1209 json_output=json_output, 1210 thinking_budget=thinking_budget, 1211 timeout_s=timeout_s, 1212 **kwargs, 1213 ) 1214 1215 # Log token usage centrally (before validation so truncated responses 1216 # still get their usage recorded) 1217 if result.get("usage"): 1218 log_token_usage( 1219 model=model, 1220 usage=result["usage"], 1221 context=context, 1222 type="generate", 1223 ) 1224 1225 # Validate JSON output if requested 1226 _validate_json_response(result, json_output) 1227 1228 return result["text"] 1229 1230 1231__all__ = [ 1232 # Provider configuration 1233 "TYPE_DEFAULTS", 1234 "PROMPT_PATHS", 1235 "get_context_registry", 1236 # Model constants (used by provider backends for defaults) 1237 "GEMINI_FLASH", 1238 "GPT_5", 1239 "CLAUDE_SONNET_4", 1240 # Unified API 1241 "generate", 1242 "generate_with_result", 1243 "agenerate", 1244 "resolve_provider", 1245 # Utilities 1246 "log_token_usage", 1247 "calc_token_cost", 1248 "calc_agent_cost", 1249 "get_usage_cost", 1250 "iter_token_log", 1251 "get_model_provider", 1252]