personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4"""Tests for the Batch async batch processor."""
5
6import asyncio
7from unittest.mock import AsyncMock, patch
8
9import pytest
10
11from think.batch import Batch, BatchRequest
12from think.models import GEMINI_FLASH, GEMINI_LITE
13
14
15def test_batch_request_creation():
16 """Test BatchRequest can be created with required and custom params."""
17 # Required params only
18 req = BatchRequest(contents="Test prompt", context="test.context")
19 assert req.contents == "Test prompt"
20 assert req.context == "test.context"
21 assert req.model is None
22 assert req.temperature == 0.3
23 assert req.response is None
24 assert req.error is None
25
26 # With model override
27 req2 = BatchRequest(
28 contents=["Part 1", "Part 2"],
29 context="test.context",
30 model=GEMINI_LITE,
31 temperature=0.7,
32 json_output=True,
33 )
34 assert req2.contents == ["Part 1", "Part 2"]
35 assert req2.model == GEMINI_LITE
36 assert req2.temperature == 0.7
37 assert req2.json_output is True
38
39
40def test_batch_request_custom_attributes():
41 """Test that arbitrary attributes can be added to BatchRequest."""
42 req = BatchRequest(contents="Test", context="test.context")
43 req.frame_id = 123
44 req.stage = "initial"
45 req.custom_data = {"foo": "bar"}
46
47 assert req.frame_id == 123
48 assert req.stage == "initial"
49 assert req.custom_data == {"foo": "bar"}
50
51
52@pytest.mark.asyncio
53@patch("think.batch.agenerate", new_callable=AsyncMock)
54async def test_batch_basic(mock_agenerate):
55 """Test basic batch execution with single request."""
56 mock_agenerate.return_value = "Response 1"
57
58 # Create batch and add request
59 batch = Batch(max_concurrent=5)
60 req = batch.create(contents="What is 2+2?", context="test.calc")
61 req.task_id = "calc1"
62 batch.add(req)
63
64 # Iterate and verify
65 results = []
66 async for completed_req in batch.drain_batch():
67 results.append(completed_req)
68
69 assert len(results) == 1
70 assert results[0].task_id == "calc1"
71 assert results[0].response == "Response 1"
72 assert results[0].error is None
73 assert results[0].duration > 0
74
75 # Verify API was called correctly
76 mock_agenerate.assert_called_once()
77 call_kwargs = mock_agenerate.call_args[1]
78 assert call_kwargs["contents"] == "What is 2+2?"
79 assert call_kwargs["context"] == "test.calc"
80
81
82@pytest.mark.asyncio
83@patch("think.batch.agenerate", new_callable=AsyncMock)
84async def test_batch_with_model_override(mock_agenerate):
85 """Test batch with explicit model override."""
86 mock_agenerate.return_value = "Response"
87
88 batch = Batch(max_concurrent=5)
89 req = batch.create(contents="Test", context="test.context", model=GEMINI_FLASH)
90 batch.add(req)
91
92 results = []
93 async for completed_req in batch.drain_batch():
94 results.append(completed_req)
95
96 assert len(results) == 1
97 assert results[0].model_used == GEMINI_FLASH
98
99 # Verify model was passed through
100 call_kwargs = mock_agenerate.call_args[1]
101 assert call_kwargs["model"] == GEMINI_FLASH
102
103
104@pytest.mark.asyncio
105@patch("think.batch.agenerate", new_callable=AsyncMock)
106async def test_batch_multiple_requests(mock_agenerate):
107 """Test batch with multiple requests."""
108 mock_agenerate.side_effect = ["Response 1", "Response 2", "Response 3"]
109
110 batch = Batch(max_concurrent=2)
111
112 # Add multiple requests
113 req1 = batch.create(contents="Prompt 1", context="test.context")
114 req1.id = 1
115 batch.add(req1)
116
117 req2 = batch.create(contents="Prompt 2", context="test.context")
118 req2.id = 2
119 batch.add(req2)
120
121 req3 = batch.create(contents="Prompt 3", context="test.context")
122 req3.id = 3
123 batch.add(req3)
124
125 # Collect results
126 results = []
127 async for req in batch.drain_batch():
128 results.append(req)
129
130 # Should have all 3 results
131 assert len(results) == 3
132
133 # Results may come in any order (concurrent execution)
134 result_ids = {r.id for r in results}
135 assert result_ids == {1, 2, 3}
136
137 # All should have responses
138 for r in results:
139 assert r.response is not None
140 assert r.error is None
141
142
143@pytest.mark.asyncio
144@patch("think.batch.agenerate", new_callable=AsyncMock)
145async def test_batch_error_handling(mock_agenerate):
146 """Test that errors are captured in request.error."""
147 mock_agenerate.side_effect = ValueError("API error")
148
149 batch = Batch(max_concurrent=5)
150 req = batch.create(contents="Bad prompt", context="test.context")
151 req.id = "error_test"
152 batch.add(req)
153
154 # Get result
155 results = []
156 async for r in batch.drain_batch():
157 results.append(r)
158
159 assert len(results) == 1
160 assert results[0].id == "error_test"
161 assert results[0].response is None
162 assert results[0].error == "API error"
163 assert results[0].duration > 0
164
165
166@pytest.mark.asyncio
167@patch("think.batch.agenerate", new_callable=AsyncMock)
168async def test_batch_dynamic_adding(mock_agenerate):
169 """Test adding requests dynamically during iteration."""
170 mock_agenerate.return_value = "Response"
171
172 batch = Batch(max_concurrent=5)
173
174 # Add initial request
175 req1 = batch.create(contents="Initial", context="test.context")
176 req1.stage = "initial"
177 batch.add(req1)
178
179 # Process and add more during iteration
180 results = []
181 added_followup = False
182
183 async for req in batch.drain_batch():
184 results.append(req)
185
186 # After first result, add a follow-up
187 if req.stage == "initial" and not added_followup:
188 req2 = batch.create(contents="Follow-up", context="test.context")
189 req2.stage = "followup"
190 batch.add(req2)
191 added_followup = True
192
193 # Should have both results
194 assert len(results) == 2
195 stages = {r.stage for r in results}
196 assert stages == {"initial", "followup"}
197
198
199@pytest.mark.asyncio
200@patch("think.batch.agenerate", new_callable=AsyncMock)
201async def test_batch_retry_pattern(mock_agenerate):
202 """Test retry pattern - add failed request back with different model."""
203 # First call fails, second succeeds
204 call_count = 0
205
206 async def mock_response(*args, **kwargs):
207 nonlocal call_count
208 call_count += 1
209 if call_count == 1:
210 raise ValueError("Transient error")
211 return "Success on retry"
212
213 mock_agenerate.side_effect = mock_response
214
215 batch = Batch(max_concurrent=5)
216
217 # Add initial request
218 req1 = batch.create(contents="Test", context="test.context", model=GEMINI_FLASH)
219 req1.attempt = 1
220 batch.add(req1)
221
222 results = []
223 async for req in batch.drain_batch():
224 results.append(req)
225
226 # If error, retry with different model
227 if req.error and req.attempt == 1:
228 retry = batch.create(
229 contents=req.contents, context="test.context", model=GEMINI_LITE
230 )
231 retry.attempt = 2
232 batch.add(retry)
233
234 # Should have both attempts
235 assert len(results) == 2
236 assert results[0].error is not None
237 assert results[0].attempt == 1
238 assert results[1].response == "Success on retry"
239 assert results[1].attempt == 2
240
241
242@pytest.mark.asyncio
243@patch("think.batch.agenerate", new_callable=AsyncMock)
244async def test_batch_factory_method(mock_agenerate):
245 """Test that batch.create() factory method works correctly."""
246 mock_agenerate.return_value = "Response"
247
248 batch = Batch()
249
250 # Use factory method
251 req = batch.create(
252 contents="Test",
253 context="test.context",
254 model=GEMINI_LITE,
255 temperature=0.8,
256 json_output=True,
257 )
258
259 assert isinstance(req, BatchRequest)
260 assert req.contents == "Test"
261 assert req.context == "test.context"
262 assert req.model == GEMINI_LITE
263 assert req.temperature == 0.8
264 assert req.json_output is True
265
266
267@pytest.mark.asyncio
268@patch("think.batch.agenerate", new_callable=AsyncMock)
269async def test_batch_can_add_after_draining(mock_agenerate):
270 """Test that adding after draining works (reusable batch)."""
271 mock_agenerate.side_effect = ["Response 1", "Response 2"]
272
273 batch = Batch()
274
275 # First batch
276 req1 = batch.create(contents="First", context="test.context")
277 req1.id = 1
278 batch.add(req1)
279
280 results = []
281 async for req in batch.drain_batch():
282 results.append(req)
283
284 assert len(results) == 1
285 assert results[0].id == 1
286
287 # Add more work after draining
288 req2 = batch.create(contents="Second", context="test.context")
289 req2.id = 2
290 batch.add(req2)
291
292 async for req in batch.drain_batch():
293 results.append(req)
294
295 # Should have both results
296 assert len(results) == 2
297 assert {r.id for r in results} == {1, 2}
298
299
300@pytest.mark.asyncio
301@patch("think.batch.agenerate", new_callable=AsyncMock)
302async def test_batch_empty_batch(mock_agenerate):
303 """Test that empty batch (no requests) completes immediately."""
304 batch = Batch()
305
306 results = []
307 async for req in batch.drain_batch():
308 results.append(req)
309
310 assert len(results) == 0
311
312
313@pytest.mark.asyncio
314@patch("think.batch.agenerate", new_callable=AsyncMock)
315async def test_batch_concurrency_limit(mock_agenerate):
316 """Test that semaphore limits concurrent requests."""
317 # Track concurrent calls
318 concurrent_calls = 0
319 max_concurrent_seen = 0
320 lock = asyncio.Lock()
321
322 async def mock_with_tracking(*args, **kwargs):
323 nonlocal concurrent_calls, max_concurrent_seen
324 async with lock:
325 concurrent_calls += 1
326 max_concurrent_seen = max(max_concurrent_seen, concurrent_calls)
327
328 await asyncio.sleep(0.1) # Simulate API call
329
330 async with lock:
331 concurrent_calls -= 1
332
333 return "Response"
334
335 mock_agenerate.side_effect = mock_with_tracking
336
337 # Create batch with max_concurrent=2
338 batch = Batch(max_concurrent=2)
339
340 # Add 5 requests
341 for i in range(5):
342 req = batch.create(contents=f"Request {i}", context="test.context")
343 batch.add(req)
344
345 results = []
346 async for req in batch.drain_batch():
347 results.append(req)
348
349 assert len(results) == 5
350 # Should never exceed max_concurrent=2
351 assert max_concurrent_seen <= 2
352
353
354@pytest.mark.asyncio
355@patch("think.batch.agenerate", new_callable=AsyncMock)
356async def test_batch_update_method(mock_agenerate):
357 """Test batch.update() method for modifying and re-adding requests."""
358 # Track which model was used in each call
359 call_models = []
360
361 async def mock_track_model(*args, **kwargs):
362 call_models.append(kwargs.get("model", "unknown"))
363 return f"Response from {kwargs.get('model', 'unknown')}"
364
365 mock_agenerate.side_effect = mock_track_model
366
367 batch = Batch(max_concurrent=5)
368
369 # Create initial request
370 req = batch.create(
371 contents="Initial prompt", context="test.context", model=GEMINI_FLASH
372 )
373 req.stage = "initial"
374 batch.add(req)
375
376 results = []
377 result_count = 0
378 async for completed_req in batch.drain_batch():
379 result_count += 1
380 # Capture the response at this point
381 results.append((result_count, completed_req.response, completed_req.stage))
382
383 # After first result, update and re-add with different model
384 if result_count == 1:
385 batch.update(
386 completed_req,
387 contents="Updated prompt",
388 model=GEMINI_LITE,
389 stage="updated", # Update custom attribute too
390 custom_field="test_value", # Add new custom attribute
391 )
392
393 # Should have both results
394 assert len(results) == 2
395 assert results[0][2] == "initial" # First result was initial stage
396 assert results[1][2] == "updated" # Second result was updated stage
397
398 # Verify models used
399 assert call_models == [GEMINI_FLASH, GEMINI_LITE]
400
401 # Verify correct responses at each stage
402 assert results[0][1] == f"Response from {GEMINI_FLASH}"
403 assert results[1][1] == f"Response from {GEMINI_LITE}"
404
405 # Verify custom attribute was set
406 assert req.custom_field == "test_value"
407
408
409def test_batch_request_with_timeout():
410 """Test BatchRequest can be created with timeout_s parameter."""
411 req = BatchRequest(contents="Test prompt", context="test.context", timeout_s=30)
412 assert req.timeout_s == 30
413
414 req2 = BatchRequest(contents="Test prompt", context="test.context", timeout_s=60.5)
415 assert req2.timeout_s == 60.5
416
417 # Default is None
418 req3 = BatchRequest(contents="Test prompt", context="test.context")
419 assert req3.timeout_s is None
420
421
422@pytest.mark.asyncio
423@patch("think.batch.agenerate", new_callable=AsyncMock)
424async def test_batch_timeout_passthrough(mock_agenerate):
425 """Test that timeout_s is passed through to agenerate."""
426 mock_agenerate.return_value = "Response"
427
428 batch = Batch(max_concurrent=5)
429
430 # Create request with timeout_s
431 req = batch.create(contents="Test prompt", context="test.context", timeout_s=45)
432 batch.add(req)
433
434 results = []
435 async for completed_req in batch.drain_batch():
436 results.append(completed_req)
437
438 assert len(results) == 1
439
440 # Verify timeout_s was passed to agenerate
441 mock_agenerate.assert_called_once()
442 call_kwargs = mock_agenerate.call_args[1]
443 assert call_kwargs["timeout_s"] == 45
444
445
446@pytest.mark.asyncio
447@patch("think.batch.agenerate", new_callable=AsyncMock)
448async def test_batch_client_passthrough(mock_agenerate):
449 """Test that client is passed through to agenerate for Google connection reuse."""
450 mock_agenerate.return_value = "Response"
451
452 # Create a mock client (would be genai.Client for Google)
453 mock_client = object()
454
455 batch = Batch(max_concurrent=5, client=mock_client)
456 req = batch.create(contents="Test", context="test.context")
457 batch.add(req)
458
459 results = []
460 async for completed_req in batch.drain_batch():
461 results.append(completed_req)
462
463 assert len(results) == 1
464
465 # Verify client was passed through
466 call_kwargs = mock_agenerate.call_args[1]
467 assert call_kwargs["client"] is mock_client