# personal memory agent — batch.py
# (source listing: 327 lines, 11 kB)
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""
Async batch processing for LLM API requests.

Provides Batch for concurrent execution of multiple LLM API calls
with dynamic request queuing and result streaming via async iterator.
Routes requests to providers based on context via the unified agenerate() API.

Example:
    batch = Batch(max_concurrent=5)

    req = batch.create(contents="What is 2+2?", context="myapp.calc")
    req.my_id = "calc1"
    batch.add(req)

    async for req in batch.drain_batch():
        print(f"{req.my_id}: {req.response}")

Provider-specific features:
    - client: Optional client for connection reuse (Google only, others use singletons)
"""

import asyncio
import time
from typing import Any, List, Optional, Union

from .models import agenerate


class BatchRequest:
    """
    Mutable request object for a single LLM API call.

    Core attributes are passed to agenerate(). Callers can add
    arbitrary attributes for tracking (e.g., frame_id, stage, etc).

    After execution, these attributes are populated:
    - response: Optional[str] - Response text (None if error)
    - error: Optional[str] - Error message (None if success)
    - duration: float - Execution time in seconds
    - model_used: str - Model that was used ("" when resolved from context)
    """

    def __init__(
        self,
        contents: Union[str, List[Any]],
        context: str,
        model: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: int = 8192 * 2,
        system_instruction: Optional[str] = None,
        json_output: bool = False,
        thinking_budget: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ):
        # Parameters forwarded verbatim to agenerate() at execution time.
        self.contents = contents
        self.context = context
        self.model = model
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens
        self.system_instruction = system_instruction
        self.json_output = json_output
        self.thinking_budget = thinking_budget
        self.timeout_s = timeout_s

        # Populated after execution (see class docstring).
        self.response: Optional[str] = None
        self.error: Optional[str] = None
        self.duration: float = 0.0
        self.model_used: str = model or ""


class Batch:
    """
    Async batch processor for LLM API requests.

    Manages concurrent execution with dynamic request queuing and result
    streaming via async iterator pattern. Routes to providers via agenerate().

    Example:
        batch = Batch(max_concurrent=5)

        # Add requests
        req1 = batch.create(contents="What is 2+2?", context="myapp.calc")
        req1.task_id = "calc1"
        batch.add(req1)

        req2 = batch.create(contents="What is 3+3?", context="myapp.calc")
        req2.task_id = "calc2"
        batch.add(req2)

        # Process results as they complete
        async for req in batch.drain_batch():
            print(f"{req.task_id}: {req.response}")
    """

    def __init__(self, max_concurrent: int = 5, client: Any = None):
        """
        Initialize batch processor.

        Parameters
        ----------
        max_concurrent : int
            Maximum number of concurrent API requests (default: 5)
        client : Any, optional
            Provider client for connection reuse. Passed through to backend.
            Google: genai.Client instance for connection pooling
            Other providers: Ignored (they use internal singletons)
        """
        self.max_concurrent = max_concurrent
        self.client = client
        # Semaphore caps in-flight agenerate() calls; the queue buffers
        # completed requests until the consumer drains them.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.result_queue: asyncio.Queue = asyncio.Queue()
        self.pending_tasks: set = set()

    def create(
        self,
        contents: Union[str, List[Any]],
        context: str,
        model: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: int = 8192 * 2,
        system_instruction: Optional[str] = None,
        json_output: bool = False,
        thinking_budget: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> BatchRequest:
        """
        Create a new BatchRequest.

        Convenience factory method. Caller can add arbitrary attributes
        to the returned request before calling add().

        Parameters
        ----------
        contents : str or List
            The content to send to the model
        context : str
            Context string for provider routing (e.g., "observe.describe.frame")
        model : str, optional
            Model override. If not provided, resolved from context.

        Returns
        -------
        BatchRequest
            New request object ready to be customized and added
        """
        return BatchRequest(
            contents=contents,
            context=context,
            model=model,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            system_instruction=system_instruction,
            json_output=json_output,
            thinking_budget=thinking_budget,
            timeout_s=timeout_s,
        )

    def add(self, request: BatchRequest) -> None:
        """
        Add request to batch for execution.

        Request will be executed concurrently (up to max_concurrent limit).
        Non-blocking - returns immediately. Can be called at any time, even
        during iteration or after draining. Requires a running event loop.

        Parameters
        ----------
        request : BatchRequest
            Request to execute
        """
        task = asyncio.create_task(self._execute_request(request))
        # Hold a strong reference so the task is not garbage-collected
        # mid-flight, then drop it as soon as the task completes so the
        # set does not grow unboundedly between drains. (is_drained() and
        # drain_batch() also prune, but only when they are called.)
        self.pending_tasks.add(task)
        task.add_done_callback(self.pending_tasks.discard)

    def update(self, request: BatchRequest, **kwargs) -> None:
        """
        Update request attributes and re-add to batch for execution.

        This is useful for retries or multi-stage processing where you want
        to reuse the same request object with different parameters.

        Parameters
        ----------
        request : BatchRequest
            Request to update and re-execute
        **kwargs
            Any attributes to update on the request object

        Example
        -------
        >>> batch.update(
        ...     req,
        ...     contents="New prompt",
        ...     temperature=0.8,
        ...     custom_attr="foo"
        ... )
        """
        # Update any provided attributes
        for key, value in kwargs.items():
            setattr(request, key, value)

        # Clear ALL previous execution results, including model_used —
        # otherwise a stale model name from the prior run would survive
        # a retry that changed (or removed) the model override.
        request.response = None
        request.error = None
        request.duration = 0.0
        request.model_used = request.model or ""

        # Re-add to batch
        self.add(request)

    def is_drained(self) -> bool:
        """
        Check if batch is fully drained.

        Returns True when there are no pending tasks and no results waiting
        in the queue.

        Returns
        -------
        bool
            True if batch is drained, False otherwise
        """
        # Clean up completed tasks (defensive: the done-callback in add()
        # normally handles this already).
        self.pending_tasks = {t for t in self.pending_tasks if not t.done()}
        return len(self.pending_tasks) == 0 and self.result_queue.empty()

    async def wait_until_drained(self) -> None:
        """
        Wait until all pending work completes and queue is empty.

        Blocks until is_drained() returns True, polling every 100 ms.
        """
        while not self.is_drained():
            await asyncio.sleep(0.1)

    async def _execute_request(self, request: BatchRequest) -> None:
        """
        Execute a single request and put result in queue.

        Never raises: any exception from agenerate() is captured on the
        request's .error attribute so the batch keeps flowing.

        Parameters
        ----------
        request : BatchRequest
            Request to execute (will be modified in place)
        """
        start_time = time.time()
        try:
            async with self.semaphore:
                # Build kwargs for provider-specific options; omit keys
                # entirely when unset so agenerate() applies its defaults.
                kwargs: dict = {}
                if self.client is not None:
                    kwargs["client"] = self.client
                if request.model is not None:
                    kwargs["model"] = request.model

                response = await agenerate(
                    contents=request.contents,
                    context=request.context,
                    temperature=request.temperature,
                    max_output_tokens=request.max_output_tokens,
                    system_instruction=request.system_instruction,
                    json_output=request.json_output,
                    thinking_budget=request.thinking_budget,
                    timeout_s=request.timeout_s,
                    **kwargs,
                )
                request.duration = time.time() - start_time
                request.response = response
                request.error = None

                # Track which model was actually used. When the model was
                # resolved from context inside agenerate(), the resolved
                # name is not surfaced to us, so model_used stays "".
                if request.model:
                    request.model_used = request.model
        except Exception as e:
            # Broad catch is deliberate: a single failing request must not
            # kill the batch; the caller inspects .error per request.
            request.duration = time.time() - start_time
            request.response = None
            request.error = str(e)

        # Put completed request in result queue (unbounded queue: never blocks)
        await self.result_queue.put(request)

    async def drain_batch(self):
        """
        Async generator that yields completed requests until batch is drained.

        Yields results from the queue while there's still pending work OR
        results waiting. Stops when both pending_tasks is empty AND queue
        is empty.

        This can be called multiple times - each call will drain whatever
        work is currently in the batch.

        Yields
        ------
        BatchRequest
            Completed request with response/error populated

        Example
        -------
        >>> async for req in batch.drain_batch():
        ...     print(req.response)
        ...     if req.error:
        ...         batch.add(req)  # Retry on error
        """
        while True:
            # Prune finished tasks, then check if we're truly drained.
            self.pending_tasks = {t for t in self.pending_tasks if not t.done()}
            if len(self.pending_tasks) == 0 and self.result_queue.empty():
                break

            # Try to get a result. The short timeout lets us re-check the
            # drained condition periodically instead of blocking forever
            # on a queue that will never receive another item.
            try:
                result = await asyncio.wait_for(self.result_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                # No result ready yet, but might have pending work
                continue


__all__ = ["BatchRequest", "Batch"]