personal memory agent
at main 245 lines 7.3 kB view raw
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Shared utilities and types for AI providers.

This module contains:
- Event TypedDicts emitted by providers during agent execution
- GenerateResult TypedDict returned by run_generate/run_agenerate
- JSONEventCallback for event emission
- Utility functions for common provider operations
"""

from __future__ import annotations

import json
from typing import Any, Callable, Literal, Optional, Union

from typing_extensions import Required, TypedDict

from think.utils import now_ms

# ---------------------------------------------------------------------------
# Event Types
# ---------------------------------------------------------------------------


class ToolStartEvent(TypedDict, total=False):
    """Event emitted when a tool starts."""

    event: Literal["tool_start"]
    ts: int
    tool: str
    args: Optional[dict[str, Any]]
    call_id: Optional[str]  # Unique ID to pair with tool_end event
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class ToolEndEvent(TypedDict, total=False):
    """Event emitted when a tool finishes."""

    event: Literal["tool_end"]
    ts: int
    tool: str
    args: Optional[dict[str, Any]]
    result: Any
    call_id: Optional[str]  # Matches the call_id from tool_start
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class StartEvent(TypedDict, total=False):
    """Event emitted when an agent run begins."""

    event: Required[Literal["start"]]
    ts: Required[int]
    prompt: Required[str]
    name: Required[str]
    model: Required[str]
    provider: Required[str]
    session_id: Optional[str]  # CLI session ID for continuation
    chat_id: Optional[str]  # Chat ID for reverse lookup
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class FinishEvent(TypedDict, total=False):
    """Event emitted when an agent run finishes successfully."""

    event: Required[Literal["finish"]]
    ts: Required[int]
    result: Required[str]
    usage: Optional[dict[str, Any]]
    cli_session_id: Optional[str]  # Provider CLI session/thread ID for resume
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class ErrorEvent(TypedDict, total=False):
    """Event emitted when an error occurs."""

    event: Literal["error"]
    ts: int
    error: str
    trace: Optional[str]
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class AgentUpdatedEvent(TypedDict, total=False):
    """Event emitted when the agent context changes."""

    event: Required[Literal["agent_updated"]]
    ts: Required[int]
    agent: Required[str]
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class ThinkingEvent(TypedDict, total=False):
    """Event emitted when thinking/reasoning summaries are available.

    For Anthropic models, may include a signature for verification when
    passing thinking blocks back during tool use continuations.
    For redacted thinking, summary will contain "[redacted]" and
    redacted_data will contain the encrypted content.
    """

    event: Required[Literal["thinking"]]
    ts: Required[int]
    summary: Required[str]
    model: Optional[str]
    signature: Optional[str]  # Anthropic thinking block signature
    redacted_data: Optional[str]  # Encrypted data for redacted thinking
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)


class FallbackEvent(TypedDict, total=False):
    """Event emitted when provider fallback occurs."""

    event: Required[Literal["fallback"]]
    ts: Required[int]
    original_provider: Required[str]
    backup_provider: Required[str]
    reason: Required[str]  # "preflight" or "on_failure"
    error: Optional[str]  # Error message for on_failure case


Event = Union[
    ToolStartEvent,
    ToolEndEvent,
    StartEvent,
    FinishEvent,
    ErrorEvent,
    ThinkingEvent,
    AgentUpdatedEvent,
    FallbackEvent,
]


# ---------------------------------------------------------------------------
# Usage Schema
# ---------------------------------------------------------------------------

# Canonical keys for the normalized usage dict returned by all providers.
# log_token_usage() passes through exactly these keys (when present and non-zero).
USAGE_KEYS = frozenset(
    {
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "cached_tokens",
        "reasoning_tokens",
        "cache_creation_tokens",
        "requests",
    }
)

# ---------------------------------------------------------------------------
# GenerateResult
# ---------------------------------------------------------------------------


class GenerateResult(TypedDict, total=False):
    """Result from provider run_generate/run_agenerate functions.

    Structured result that allows the wrapper to handle cross-cutting concerns
    like token logging and JSON validation centrally.

    The thinking field contains dicts with: summary (str), signature (optional str),
    redacted_data (optional str for Anthropic redacted thinking).
    """

    text: Required[str]  # Response text
    usage: Optional[dict]  # Normalized usage dict (input_tokens, output_tokens, etc.)
    finish_reason: Optional[str]  # Normalized: "stop", "max_tokens", "safety", etc.
    thinking: Optional[list]  # List of thinking block dicts


# ---------------------------------------------------------------------------
# JSONEventCallback
# ---------------------------------------------------------------------------


class JSONEventCallback:
    """Emit JSON events via a callback.

    The callback is optional; when none is configured, emitted events are
    silently dropped.
    """

    def __init__(self, callback: Optional[Callable[[Event], None]] = None) -> None:
        self.callback = callback

    def emit(self, data: Event) -> None:
        """Forward *data* to the callback, stamping ``ts`` when it is absent.

        The timestamp is added to a shallow copy, so the caller's dict is
        never mutated. An event that already carries ``ts`` is passed through
        as the same object.
        """
        if "ts" not in data:
            data = {**data, "ts": now_ms()}
        if self.callback:
            self.callback(data)

    def close(self) -> None:
        """Do nothing; there are no resources to release."""
        pass


# ---------------------------------------------------------------------------
# Raw Event Trimming
# ---------------------------------------------------------------------------

# Structural keys preserved when trimming oversized raw events.
_RAW_STRUCTURAL_KEYS = frozenset(
    {
        "type",
        "id",
        "tool_id",
        "tool_name",
        "role",
        "event_type",
        "timestamp",
    }
)

_RAW_BYTE_LIMIT = 16_384  # 16 KB


def safe_raw(
    events: list[dict[str, Any]],
    limit: int = _RAW_BYTE_LIMIT,
) -> list[dict[str, Any]]:
    """Return *events* as-is if small enough, otherwise a trimmed version.

    When the JSON-serialized size exceeds *limit* bytes, each event is reduced
    to its structural keys and a ``_raw_trimmed`` dict is appended with the
    original byte count and the limit that was applied.

    Args:
        events: Raw provider events; must be JSON-serializable.
        limit: Maximum allowed size, in UTF-8 encoded bytes, of the serialized
            event list.

    Returns:
        The very same *events* list (not a copy) when it fits within *limit*;
        otherwise a new list of reduced events followed by the
        ``_raw_trimmed`` marker. The trimmed list is not re-measured against
        *limit*.
    """
    serialized = json.dumps(events, ensure_ascii=False)
    # Measure once, in encoded bytes. With ensure_ascii=False the character
    # count of ``serialized`` understates the size of any non-ASCII payload,
    # so the same byte figure must drive both the check and the report.
    original_bytes = len(serialized.encode("utf-8"))
    if original_bytes <= limit:
        return events

    trimmed = [
        {k: v for k, v in e.items() if k in _RAW_STRUCTURAL_KEYS} for e in events
    ]
    trimmed.append(
        {"_raw_trimmed": {"original_bytes": original_bytes, "limit": limit}}
    )
    return trimmed


__all__ = [
    "Event",
    "GenerateResult",
    "JSONEventCallback",
    "ThinkingEvent",
    "USAGE_KEYS",
    "safe_raw",
]