personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4"""Shared utilities and types for AI providers.
5
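This module contains:
- Event TypedDicts emitted by providers during agent execution
- USAGE_KEYS, the canonical key set for normalized token-usage dicts
- GenerateResult TypedDict returned by run_generate/run_agenerate
- JSONEventCallback for event emission
- Utility functions for common provider operations, such as safe_raw for
  trimming oversized raw provider events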
11"""
12
13from __future__ import annotations
14
15import json
16from typing import Any, Callable, Literal, Optional, Union
17
18from typing_extensions import Required, TypedDict
19
20from think.utils import now_ms
21
22# ---------------------------------------------------------------------------
23# Event Types
24# ---------------------------------------------------------------------------
25
26
27class ToolStartEvent(TypedDict, total=False):
28 """Event emitted when a tool starts."""
29
30 event: Literal["tool_start"]
31 ts: int
32 tool: str
33 args: Optional[dict[str, Any]]
34 call_id: Optional[str] # Unique ID to pair with tool_end event
35 raw: Optional[list[dict[str, Any]]] # Original provider JSON event(s)
36
37
38class ToolEndEvent(TypedDict, total=False):
39 """Event emitted when a tool finishes."""
40
41 event: Literal["tool_end"]
42 ts: int
43 tool: str
44 args: Optional[dict[str, Any]]
45 result: Any
46 call_id: Optional[str] # Matches the call_id from tool_start
47 raw: Optional[list[dict[str, Any]]] # Original provider JSON event(s)
48
49
class StartEvent(TypedDict, total=False):
    """Event emitted when an agent run begins.

    Keys wrapped in ``Required`` must be present; ``session_id``,
    ``chat_id`` and ``raw`` may be omitted.
    """

    event: Required[Literal["start"]]  # Literal tag identifying this event type
    ts: Required[int]  # Timestamp; same clock as now_ms() used by JSONEventCallback
    prompt: Required[str]
    name: Required[str]  # presumably the agent name — confirm against callers
    model: Required[str]
    provider: Required[str]
    session_id: Optional[str]  # CLI session ID for continuation
    chat_id: Optional[str]  # Chat ID for reverse lookup
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)
62
63
class FinishEvent(TypedDict, total=False):
    """Event emitted when an agent run finishes successfully.

    ``event``, ``ts`` and ``result`` are required; the remaining keys may
    be omitted.
    """

    event: Required[Literal["finish"]]  # Literal tag identifying this event type
    ts: Required[int]  # Timestamp; same clock as now_ms() used by JSONEventCallback
    result: Required[str]  # Final result text of the run
    usage: Optional[dict[str, Any]]  # presumably keyed by USAGE_KEYS — confirm
    cli_session_id: Optional[str]  # Provider CLI session/thread ID for resume
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)
73
74
75class ErrorEvent(TypedDict, total=False):
76 """Event emitted when an error occurs."""
77
78 event: Literal["error"]
79 ts: int
80 error: str
81 trace: Optional[str]
82 raw: Optional[list[dict[str, Any]]] # Original provider JSON event(s)
83
84
class AgentUpdatedEvent(TypedDict, total=False):
    """Event emitted when the agent context changes.

    ``event``, ``ts`` and ``agent`` are required; ``raw`` may be omitted.
    """

    event: Required[Literal["agent_updated"]]  # Literal tag identifying this event type
    ts: Required[int]  # Timestamp; same clock as now_ms() used by JSONEventCallback
    agent: Required[str]  # presumably the name of the now-active agent — confirm
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)
92
93
class ThinkingEvent(TypedDict, total=False):
    """Event emitted when thinking/reasoning summaries are available.

    For Anthropic models, may include a signature for verification when
    passing thinking blocks back during tool use continuations.
    For redacted thinking, summary will contain "[redacted]" and
    redacted_data will contain the encrypted content.

    ``event``, ``ts`` and ``summary`` are required; the remaining keys may
    be omitted.
    """

    event: Required[Literal["thinking"]]  # Literal tag identifying this event type
    ts: Required[int]  # Timestamp; same clock as now_ms() used by JSONEventCallback
    summary: Required[str]  # Reasoning summary text ("[redacted]" when redacted)
    model: Optional[str]  # presumably the model that produced the thinking — confirm
    signature: Optional[str]  # Anthropic thinking block signature
    redacted_data: Optional[str]  # Encrypted data for redacted thinking
    raw: Optional[list[dict[str, Any]]]  # Original provider JSON event(s)
110
111
class FallbackEvent(TypedDict, total=False):
    """Event emitted when provider fallback occurs.

    All keys except ``error`` are required. Unlike the other event types in
    this module, this one declares no ``raw`` field.
    """

    event: Required[Literal["fallback"]]  # Literal tag identifying this event type
    ts: Required[int]  # Timestamp; same clock as now_ms() used by JSONEventCallback
    original_provider: Required[str]  # presumably the provider first selected — confirm
    backup_provider: Required[str]  # presumably the provider switched to — confirm
    reason: Required[str]  # "preflight" or "on_failure"
    error: Optional[str]  # Error message for on_failure case
121
122
# Union of every event TypedDict emitted through JSONEventCallback. Each member
# declares an ``event`` key typed as a distinct Literal, which is the tag
# consumers can switch on to tell the variants apart.
Event = Union[
    ToolStartEvent,
    ToolEndEvent,
    StartEvent,
    FinishEvent,
    ErrorEvent,
    ThinkingEvent,
    AgentUpdatedEvent,
    FallbackEvent,
]
133
134
135# ---------------------------------------------------------------------------
136# Usage Schema
137# ---------------------------------------------------------------------------
138
# The fixed vocabulary of keys a provider may put in its normalized usage dict.
# log_token_usage() forwards exactly these keys (only when present and non-zero);
# any other key is not part of the shared schema.
USAGE_KEYS = frozenset(
    (
        "cache_creation_tokens",
        "cached_tokens",
        "input_tokens",
        "output_tokens",
        "reasoning_tokens",
        "requests",
        "total_tokens",
    )
)
152
153# ---------------------------------------------------------------------------
154# GenerateResult
155# ---------------------------------------------------------------------------
156
157
158class GenerateResult(TypedDict, total=False):
159 """Result from provider run_generate/run_agenerate functions.
160
161 Structured result that allows the wrapper to handle cross-cutting concerns
162 like token logging and JSON validation centrally.
163
164 The thinking field contains dicts with: summary (str), signature (optional str),
165 redacted_data (optional str for Anthropic redacted thinking).
166 """
167
168 text: Required[str] # Response text
169 usage: Optional[dict] # Normalized usage dict (input_tokens, output_tokens, etc.)
170 finish_reason: Optional[str] # Normalized: "stop", "max_tokens", "safety", etc.
171 thinking: Optional[list] # List of thinking block dicts
172
173
174# ---------------------------------------------------------------------------
175# JSONEventCallback
176# ---------------------------------------------------------------------------
177
178
class JSONEventCallback:
    """Forward emitted events to an optional callback.

    An event without a ``ts`` key is stamped with ``now_ms()`` before it is
    forwarded. Stamping builds a new dict, so the caller's event is never
    modified. An event that already has ``ts`` is forwarded as the same
    object.
    """

    def __init__(self, callback: Optional[Callable[[Event], None]] = None) -> None:
        # With no (or a falsy) callback, emit() forwards nothing.
        self.callback = callback

    def emit(self, data: Event) -> None:
        """Timestamp *data* if needed, then hand it to the callback, if any."""
        stamped = data if "ts" in data else {**data, "ts": now_ms()}
        sink = self.callback
        if sink:
            sink(stamped)

    def close(self) -> None:
        """Do nothing; this emitter holds no resources to release."""
193
194
195# ---------------------------------------------------------------------------
196# Raw Event Trimming
197# ---------------------------------------------------------------------------
198
# Structural keys preserved when trimming oversized raw events.
_RAW_STRUCTURAL_KEYS = frozenset(
    {
        "type",
        "id",
        "tool_id",
        "tool_name",
        "role",
        "event_type",
        "timestamp",
    }
)

_RAW_BYTE_LIMIT = 16_384  # 16 KB


def safe_raw(
    events: list[dict[str, Any]],
    limit: int = _RAW_BYTE_LIMIT,
) -> list[dict[str, Any]]:
    """Return *events* as-is if small enough, otherwise a trimmed version.

    Size is the UTF-8 byte length of ``json.dumps(events, ensure_ascii=False)``.
    When that exceeds *limit* bytes, each event is reduced to the keys in
    ``_RAW_STRUCTURAL_KEYS``. A ``{"_raw_trimmed": {"original_bytes": ...,
    "limit": ...}}`` dict is then appended, recording the original byte count
    and the limit that was applied. The trimmed result is not re-measured, so
    it can still exceed *limit* if the structural values themselves are large.

    Args:
        events: Raw provider event dicts.
        limit: Maximum serialized size, in bytes, before trimming kicks in.

    Returns:
        The same list object when within *limit*. Otherwise, a new list of
        trimmed copies followed by the marker dict. *events* is never mutated.

    Raises:
        TypeError: propagated from ``json.dumps`` for non-serializable values.
    """
    serialized = json.dumps(events, ensure_ascii=False)
    # Measure bytes, not characters: with ensure_ascii=False, non-ASCII text
    # takes more than one byte per character. Reuse the same figure below so
    # the reported size matches the one that was compared against the limit.
    original_bytes = len(serialized.encode("utf-8"))
    if original_bytes <= limit:
        return events

    trimmed: list[dict[str, Any]] = [
        {key: value for key, value in event.items() if key in _RAW_STRUCTURAL_KEYS}
        for event in events
    ]
    trimmed.append(
        {"_raw_trimmed": {"original_bytes": original_bytes, "limit": limit}}
    )
    return trimmed
236
237
# Public API of this module. Every event TypedDict is listed alongside the
# ``Event`` union so that star-imports expose each variant, not only
# ThinkingEvent.
__all__ = [
    "AgentUpdatedEvent",
    "ErrorEvent",
    "Event",
    "FallbackEvent",
    "FinishEvent",
    "GenerateResult",
    "JSONEventCallback",
    "StartEvent",
    "ThinkingEvent",
    "ToolEndEvent",
    "ToolStartEvent",
    "USAGE_KEYS",
    "safe_raw",
]
245]