# personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4"""
5Async batch processing for LLM API requests.
6
7Provides Batch for concurrent execution of multiple LLM API calls
8with dynamic request queuing and result streaming via async iterator.
9Routes requests to providers based on context via the unified agenerate() API.
10
11Example:
12 batch = Batch(max_concurrent=5)
13
14 req = batch.create(contents="What is 2+2?", context="myapp.calc")
15 req.my_id = "calc1"
16 batch.add(req)
17
18 async for req in batch.drain_batch():
19 print(f"{req.my_id}: {req.response}")
20
21Provider-specific features:
22 - client: Optional client for connection reuse (Google only, others use singletons)
23"""
24
25import asyncio
26import time
27from typing import Any, List, Optional, Union
28
29from .models import agenerate
30
31
class BatchRequest:
    """
    Mutable request object for a single LLM API call.

    Core attributes are passed to agenerate(). Callers can add
    arbitrary attributes for tracking (e.g., frame_id, stage, etc).

    After execution, these attributes are populated:
    - response: Optional[str] - Response text (None if error)
    - error: Optional[str] - Error message (None if success)
    - duration: float - Execution time in seconds
    - model_used: str - Model that was used
    """

    def __init__(
        self,
        contents: Union[str, List[Any]],
        context: str,
        model: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: int = 8192 * 2,
        system_instruction: Optional[str] = None,
        json_output: bool = False,
        thinking_budget: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ):
        # Core parameters forwarded verbatim to agenerate()
        self.contents = contents
        self.context = context
        self.model = model
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens
        self.system_instruction = system_instruction
        self.json_output = json_output
        self.thinking_budget = thinking_budget
        self.timeout_s = timeout_s

        # Populated after execution
        self.response: Optional[str] = None
        self.error: Optional[str] = None
        self.duration: float = 0.0
        # Defaults to the explicit model override; empty string when the
        # model is resolved from context (resolution result isn't exposed).
        self.model_used: str = model or ""

    def __repr__(self) -> str:
        """Concise debug representation: context, model, and outcome state."""
        if self.error is not None:
            status = "error"
        elif self.response is not None:
            status = "done"
        else:
            status = "pending"
        return (
            f"{type(self).__name__}(context={self.context!r}, "
            f"model={self.model!r}, status={status})"
        )
73
74
75class Batch:
76 """
77 Async batch processor for LLM API requests.
78
79 Manages concurrent execution with dynamic request queuing and result
80 streaming via async iterator pattern. Routes to providers via agenerate().
81
82 Example:
83 batch = Batch(max_concurrent=5)
84
85 # Add requests
86 req1 = batch.create(contents="What is 2+2?", context="myapp.calc")
87 req1.task_id = "calc1"
88 batch.add(req1)
89
90 req2 = batch.create(contents="What is 3+3?", context="myapp.calc")
91 req2.task_id = "calc2"
92 batch.add(req2)
93
94 # Process results as they complete
95 async for req in batch.drain_batch():
96 print(f"{req.task_id}: {req.response}")
97 """
98
99 def __init__(self, max_concurrent: int = 5, client: Any = None):
100 """
101 Initialize batch processor.
102
103 Parameters
104 ----------
105 max_concurrent : int
106 Maximum number of concurrent API requests (default: 5)
107 client : Any, optional
108 Provider client for connection reuse. Passed through to backend.
109 Google: genai.Client instance for connection pooling
110 Other providers: Ignored (they use internal singletons)
111 """
112 self.max_concurrent = max_concurrent
113 self.client = client
114 self.semaphore = asyncio.Semaphore(max_concurrent)
115 self.result_queue: asyncio.Queue = asyncio.Queue()
116 self.pending_tasks: set = set()
117
118 def create(
119 self,
120 contents: Union[str, List[Any]],
121 context: str,
122 model: Optional[str] = None,
123 temperature: float = 0.3,
124 max_output_tokens: int = 8192 * 2,
125 system_instruction: Optional[str] = None,
126 json_output: bool = False,
127 thinking_budget: Optional[int] = None,
128 timeout_s: Optional[float] = None,
129 ) -> BatchRequest:
130 """
131 Create a new BatchRequest.
132
133 Convenience factory method. Caller can add arbitrary attributes
134 to the returned request before calling add().
135
136 Parameters
137 ----------
138 contents : str or List
139 The content to send to the model
140 context : str
141 Context string for provider routing (e.g., "observe.describe.frame")
142 model : str, optional
143 Model override. If not provided, resolved from context.
144
145 Returns
146 -------
147 BatchRequest
148 New request object ready to be customized and added
149 """
150 return BatchRequest(
151 contents=contents,
152 context=context,
153 model=model,
154 temperature=temperature,
155 max_output_tokens=max_output_tokens,
156 system_instruction=system_instruction,
157 json_output=json_output,
158 thinking_budget=thinking_budget,
159 timeout_s=timeout_s,
160 )
161
162 def add(self, request: BatchRequest) -> None:
163 """
164 Add request to batch for execution.
165
166 Request will be executed concurrently (up to max_concurrent limit).
167 Non-blocking - returns immediately. Can be called at any time, even
168 during iteration or after draining.
169
170 Parameters
171 ----------
172 request : BatchRequest
173 Request to execute
174 """
175 task = asyncio.create_task(self._execute_request(request))
176 self.pending_tasks.add(task)
177
178 def update(self, request: BatchRequest, **kwargs) -> None:
179 """
180 Update request attributes and re-add to batch for execution.
181
182 This is useful for retries or multi-stage processing where you want
183 to reuse the same request object with different parameters.
184
185 Parameters
186 ----------
187 request : BatchRequest
188 Request to update and re-execute
189 **kwargs
190 Any attributes to update on the request object
191
192 Example
193 -------
194 >>> batch.update(
195 ... req,
196 ... contents="New prompt",
197 ... temperature=0.8,
198 ... custom_attr="foo"
199 ... )
200 """
201 # Update any provided attributes
202 for key, value in kwargs.items():
203 setattr(request, key, value)
204
205 # Clear previous execution results
206 request.response = None
207 request.error = None
208 request.duration = 0.0
209
210 # Re-add to batch
211 self.add(request)
212
213 def is_drained(self) -> bool:
214 """
215 Check if batch is fully drained.
216
217 Returns True when there are no pending tasks and no results waiting
218 in the queue.
219
220 Returns
221 -------
222 bool
223 True if batch is drained, False otherwise
224 """
225 # Clean up completed tasks
226 self.pending_tasks = {t for t in self.pending_tasks if not t.done()}
227 return len(self.pending_tasks) == 0 and self.result_queue.empty()
228
229 async def wait_until_drained(self) -> None:
230 """
231 Wait until all pending work completes and queue is empty.
232
233 Blocks until is_drained() returns True.
234 """
235 while not self.is_drained():
236 await asyncio.sleep(0.1)
237
238 async def _execute_request(self, request: BatchRequest) -> None:
239 """
240 Execute a single request and put result in queue.
241
242 Parameters
243 ----------
244 request : BatchRequest
245 Request to execute (will be modified in place)
246 """
247 start_time = time.time()
248 try:
249 async with self.semaphore:
250 # Build kwargs for provider-specific options
251 kwargs: dict = {}
252 if self.client is not None:
253 kwargs["client"] = self.client
254 if request.model is not None:
255 kwargs["model"] = request.model
256
257 response = await agenerate(
258 contents=request.contents,
259 context=request.context,
260 temperature=request.temperature,
261 max_output_tokens=request.max_output_tokens,
262 system_instruction=request.system_instruction,
263 json_output=request.json_output,
264 thinking_budget=request.thinking_budget,
265 timeout_s=request.timeout_s,
266 **kwargs,
267 )
268 request.duration = time.time() - start_time
269 request.response = response
270 request.error = None
271
272 # Track which model was actually used
273 if request.model:
274 request.model_used = request.model
275 else:
276 # Model was resolved from context - we don't have easy access
277 # to what was resolved, so leave as empty string
278 pass
279 except Exception as e:
280 request.duration = time.time() - start_time
281 request.response = None
282 request.error = str(e)
283
284 # Put completed request in result queue
285 await self.result_queue.put(request)
286
287 async def drain_batch(self):
288 """
289 Async generator that yields completed requests until batch is drained.
290
291 Yields results from the queue while there's still pending work OR
292 results waiting. Stops when both pending_tasks is empty AND queue
293 is empty.
294
295 This can be called multiple times - each call will drain whatever
296 work is currently in the batch.
297
298 Yields
299 ------
300 BatchRequest
301 Completed request with response/error populated
302
303 Example
304 -------
305 >>> async for req in batch.drain_batch():
306 ... print(req.response)
307 ... if req.error:
308 ... batch.add(req) # Retry on error
309 """
310 while True:
311 # Check if we're truly drained
312 self.pending_tasks = {t for t in self.pending_tasks if not t.done()}
313
314 # If drained, stop iteration
315 if len(self.pending_tasks) == 0 and self.result_queue.empty():
316 break
317
318 # Try to get a result (with timeout to avoid blocking forever)
319 try:
320 result = await asyncio.wait_for(self.result_queue.get(), timeout=0.1)
321 yield result
322 except asyncio.TimeoutError:
323 # No result ready yet, but might have pending work
324 continue
325
326
327__all__ = ["BatchRequest", "Batch"]