# personal memory agent — batch processor tests
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Tests for the Batch async batch processor."""

import asyncio
from unittest.mock import AsyncMock, patch

import pytest

from think.batch import Batch, BatchRequest
from think.models import GEMINI_FLASH, GEMINI_LITE


def test_batch_request_creation():
    """Test BatchRequest can be created with required and custom params."""
    # Required params only
    req = BatchRequest(contents="Test prompt", context="test.context")
    assert req.contents == "Test prompt"
    assert req.context == "test.context"
    assert req.model is None
    assert req.temperature == 0.3
    assert req.response is None
    assert req.error is None

    # With model override
    req2 = BatchRequest(
        contents=["Part 1", "Part 2"],
        context="test.context",
        model=GEMINI_LITE,
        temperature=0.7,
        json_output=True,
    )
    assert req2.contents == ["Part 1", "Part 2"]
    assert req2.model == GEMINI_LITE
    assert req2.temperature == 0.7
    assert req2.json_output is True


def test_batch_request_custom_attributes():
    """Test that arbitrary attributes can be added to BatchRequest."""
    req = BatchRequest(contents="Test", context="test.context")
    req.frame_id = 123
    req.stage = "initial"
    req.custom_data = {"foo": "bar"}

    assert req.frame_id == 123
    assert req.stage == "initial"
    assert req.custom_data == {"foo": "bar"}


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_basic(mock_agenerate):
    """Test basic batch execution with single request."""
    mock_agenerate.return_value = "Response 1"

    # Create batch and add request
    batch = Batch(max_concurrent=5)
    req = batch.create(contents="What is 2+2?", context="test.calc")
    req.task_id = "calc1"
    batch.add(req)

    # Iterate and verify
    results = []
    async for completed_req in batch.drain_batch():
        results.append(completed_req)

    assert len(results) == 1
    assert results[0].task_id == "calc1"
    assert results[0].response == "Response 1"
    assert results[0].error is None
    assert results[0].duration > 0

    # Verify API was called correctly
    mock_agenerate.assert_called_once()
    call_kwargs = mock_agenerate.call_args[1]
    assert call_kwargs["contents"] == "What is 2+2?"
    assert call_kwargs["context"] == "test.calc"


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_with_model_override(mock_agenerate):
    """Test batch with explicit model override."""
    mock_agenerate.return_value = "Response"

    batch = Batch(max_concurrent=5)
    req = batch.create(contents="Test", context="test.context", model=GEMINI_FLASH)
    batch.add(req)

    results = []
    async for completed_req in batch.drain_batch():
        results.append(completed_req)

    assert len(results) == 1
    assert results[0].model_used == GEMINI_FLASH

    # Verify model was passed through
    call_kwargs = mock_agenerate.call_args[1]
    assert call_kwargs["model"] == GEMINI_FLASH


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_multiple_requests(mock_agenerate):
    """Test batch with multiple requests."""
    mock_agenerate.side_effect = ["Response 1", "Response 2", "Response 3"]

    batch = Batch(max_concurrent=2)

    # Add multiple requests
    req1 = batch.create(contents="Prompt 1", context="test.context")
    req1.id = 1
    batch.add(req1)

    req2 = batch.create(contents="Prompt 2", context="test.context")
    req2.id = 2
    batch.add(req2)

    req3 = batch.create(contents="Prompt 3", context="test.context")
    req3.id = 3
    batch.add(req3)

    # Collect results
    results = []
    async for req in batch.drain_batch():
        results.append(req)

    # Should have all 3 results
    assert len(results) == 3

    # Results may come in any order (concurrent execution)
    result_ids = {r.id for r in results}
    assert result_ids == {1, 2, 3}

    # All should have responses
    for r in results:
        assert r.response is not None
        assert r.error is None


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_error_handling(mock_agenerate):
    """Test that errors are captured in request.error."""
    mock_agenerate.side_effect = ValueError("API error")

    batch = Batch(max_concurrent=5)
    req = batch.create(contents="Bad prompt", context="test.context")
    req.id = "error_test"
    batch.add(req)

    # Get result
    results = []
    async for r in batch.drain_batch():
        results.append(r)

    assert len(results) == 1
    assert results[0].id == "error_test"
    assert results[0].response is None
    assert results[0].error == "API error"
    assert results[0].duration > 0


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_dynamic_adding(mock_agenerate):
    """Test adding requests dynamically during iteration."""
    mock_agenerate.return_value = "Response"

    batch = Batch(max_concurrent=5)

    # Add initial request
    req1 = batch.create(contents="Initial", context="test.context")
    req1.stage = "initial"
    batch.add(req1)

    # Process and add more during iteration
    results = []
    added_followup = False

    async for req in batch.drain_batch():
        results.append(req)

        # After first result, add a follow-up
        if req.stage == "initial" and not added_followup:
            req2 = batch.create(contents="Follow-up", context="test.context")
            req2.stage = "followup"
            batch.add(req2)
            added_followup = True

    # Should have both results
    assert len(results) == 2
    stages = {r.stage for r in results}
    assert stages == {"initial", "followup"}


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_retry_pattern(mock_agenerate):
    """Test retry pattern - add failed request back with different model."""
    # First call fails, second succeeds
    call_count = 0

    async def mock_response(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise ValueError("Transient error")
        return "Success on retry"

    mock_agenerate.side_effect = mock_response

    batch = Batch(max_concurrent=5)

    # Add initial request
    req1 = batch.create(contents="Test", context="test.context", model=GEMINI_FLASH)
    req1.attempt = 1
    batch.add(req1)

    results = []
    async for req in batch.drain_batch():
        results.append(req)

        # If error, retry with different model
        if req.error and req.attempt == 1:
            retry = batch.create(
                contents=req.contents, context="test.context", model=GEMINI_LITE
            )
            retry.attempt = 2
            batch.add(retry)

    # Should have both attempts
    assert len(results) == 2
    assert results[0].error is not None
    assert results[0].attempt == 1
    assert results[1].response == "Success on retry"
    assert results[1].attempt == 2


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_factory_method(mock_agenerate):
    """Test that batch.create() factory method works correctly."""
    mock_agenerate.return_value = "Response"

    batch = Batch()

    # Use factory method
    req = batch.create(
        contents="Test",
        context="test.context",
        model=GEMINI_LITE,
        temperature=0.8,
        json_output=True,
    )

    assert isinstance(req, BatchRequest)
    assert req.contents == "Test"
    assert req.context == "test.context"
    assert req.model == GEMINI_LITE
    assert req.temperature == 0.8
    assert req.json_output is True


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_can_add_after_draining(mock_agenerate):
    """Test that adding after draining works (reusable batch)."""
    mock_agenerate.side_effect = ["Response 1", "Response 2"]

    batch = Batch()

    # First batch
    req1 = batch.create(contents="First", context="test.context")
    req1.id = 1
    batch.add(req1)

    results = []
    async for req in batch.drain_batch():
        results.append(req)

    assert len(results) == 1
    assert results[0].id == 1

    # Add more work after draining
    req2 = batch.create(contents="Second", context="test.context")
    req2.id = 2
    batch.add(req2)

    async for req in batch.drain_batch():
        results.append(req)

    # Should have both results
    assert len(results) == 2
    assert {r.id for r in results} == {1, 2}


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_empty_batch(mock_agenerate):
    """Test that empty batch (no requests) completes immediately."""
    batch = Batch()

    results = []
    async for req in batch.drain_batch():
        results.append(req)

    assert len(results) == 0


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_concurrency_limit(mock_agenerate):
    """Test that semaphore limits concurrent requests."""
    # Track concurrent calls
    concurrent_calls = 0
    max_concurrent_seen = 0
    lock = asyncio.Lock()

    async def mock_with_tracking(*args, **kwargs):
        nonlocal concurrent_calls, max_concurrent_seen
        async with lock:
            concurrent_calls += 1
            max_concurrent_seen = max(max_concurrent_seen, concurrent_calls)

        await asyncio.sleep(0.1)  # Simulate API call

        async with lock:
            concurrent_calls -= 1

        return "Response"

    mock_agenerate.side_effect = mock_with_tracking

    # Create batch with max_concurrent=2
    batch = Batch(max_concurrent=2)

    # Add 5 requests
    for i in range(5):
        req = batch.create(contents=f"Request {i}", context="test.context")
        batch.add(req)

    results = []
    async for req in batch.drain_batch():
        results.append(req)

    assert len(results) == 5
    # Should never exceed max_concurrent=2
    assert max_concurrent_seen <= 2


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_update_method(mock_agenerate):
    """Test batch.update() method for modifying and re-adding requests."""
    # Track which model was used in each call
    call_models = []

    async def mock_track_model(*args, **kwargs):
        call_models.append(kwargs.get("model", "unknown"))
        return f"Response from {kwargs.get('model', 'unknown')}"

    mock_agenerate.side_effect = mock_track_model

    batch = Batch(max_concurrent=5)

    # Create initial request
    req = batch.create(
        contents="Initial prompt", context="test.context", model=GEMINI_FLASH
    )
    req.stage = "initial"
    batch.add(req)

    results = []
    result_count = 0
    async for completed_req in batch.drain_batch():
        result_count += 1
        # Capture the response at this point
        results.append((result_count, completed_req.response, completed_req.stage))

        # After first result, update and re-add with different model
        if result_count == 1:
            batch.update(
                completed_req,
                contents="Updated prompt",
                model=GEMINI_LITE,
                stage="updated",  # Update custom attribute too
                custom_field="test_value",  # Add new custom attribute
            )

    # Should have both results
    assert len(results) == 2
    assert results[0][2] == "initial"  # First result was initial stage
    assert results[1][2] == "updated"  # Second result was updated stage

    # Verify models used
    assert call_models == [GEMINI_FLASH, GEMINI_LITE]

    # Verify correct responses at each stage
    assert results[0][1] == f"Response from {GEMINI_FLASH}"
    assert results[1][1] == f"Response from {GEMINI_LITE}"

    # Verify custom attribute was set
    assert req.custom_field == "test_value"


def test_batch_request_with_timeout():
    """Test BatchRequest can be created with timeout_s parameter."""
    req = BatchRequest(contents="Test prompt", context="test.context", timeout_s=30)
    assert req.timeout_s == 30

    req2 = BatchRequest(contents="Test prompt", context="test.context", timeout_s=60.5)
    assert req2.timeout_s == 60.5

    # Default is None
    req3 = BatchRequest(contents="Test prompt", context="test.context")
    assert req3.timeout_s is None


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_timeout_passthrough(mock_agenerate):
    """Test that timeout_s is passed through to agenerate."""
    mock_agenerate.return_value = "Response"

    batch = Batch(max_concurrent=5)

    # Create request with timeout_s
    req = batch.create(contents="Test prompt", context="test.context", timeout_s=45)
    batch.add(req)

    results = []
    async for completed_req in batch.drain_batch():
        results.append(completed_req)

    assert len(results) == 1

    # Verify timeout_s was passed to agenerate
    mock_agenerate.assert_called_once()
    call_kwargs = mock_agenerate.call_args[1]
    assert call_kwargs["timeout_s"] == 45


@pytest.mark.asyncio
@patch("think.batch.agenerate", new_callable=AsyncMock)
async def test_batch_client_passthrough(mock_agenerate):
    """Test that client is passed through to agenerate for Google connection reuse."""
    mock_agenerate.return_value = "Response"

    # Create a mock client (would be genai.Client for Google)
    mock_client = object()

    batch = Batch(max_concurrent=5, client=mock_client)
    req = batch.create(contents="Test", context="test.context")
    batch.add(req)

    results = []
    async for completed_req in batch.drain_batch():
        results.append(completed_req)

    assert len(results) == 1

    # Verify client was passed through
    call_kwargs = mock_agenerate.call_args[1]
    assert call_kwargs["client"] is mock_client