# personal memory agent
# (stray source-viewer header: "at main 358 lines 9.9 kB view raw")
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Integration tests for Batch with real LLM APIs.

These tests make real network calls (marked ``integration`` / ``requires_api``)
and therefore assert loosely on model output; they verify the Batch execution
machinery (concurrency, draining, dynamic adds, attribute passthrough), not
model quality.
"""

import os
import tempfile
import time
from pathlib import Path

import pytest

from think.batch import Batch
from think.models import GEMINI_FLASH, GEMINI_LITE

# Lite model for timing-sensitive tests (faster responses, less variance)
_TIMING_MODEL = GEMINI_LITE

# Default context for integration tests - uses Google provider
TEST_CONTEXT = "test.batch.integration"


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_basic_execution():
    """Test basic batch execution with real API."""
    batch = Batch(max_concurrent=3)

    # Add simple requests; ids let us match results back to requests.
    req1 = batch.create(
        contents="What is 2+2? Reply with just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req1.id = "calc1"
    batch.add(req1)

    req2 = batch.create(
        contents="What is 3+3? Reply with just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req2.id = "calc2"
    batch.add(req2)

    # Collect results
    results = []
    async for req in batch.drain_batch():
        results.append(req)

    # Verify both completed
    assert len(results) == 2

    # Check IDs are preserved
    result_ids = {r.id for r in results}
    assert result_ids == {"calc1", "calc2"}

    # Check responses. The digit check is deliberately loose: models may echo
    # the question ("2+2") or answer ("4"), so accept any plausible digit.
    for r in results:
        assert r.response is not None
        assert r.error is None
        assert r.duration > 0
        assert (
            "2" in r.response
            or "3" in r.response
            or "4" in r.response
            or "6" in r.response
        )


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_concurrent_timing():
    """Test that concurrent execution is actually faster than sequential."""
    # Sequential baseline: max_concurrent=1 forces one-at-a-time execution.
    start = time.time()
    batch_seq = Batch(max_concurrent=1)
    for i in range(2):
        req = batch_seq.create(
            contents=f"Count to {i + 1}. Reply with just the number.",
            context=TEST_CONTEXT,
            model=_TIMING_MODEL,
        )
        batch_seq.add(req)

    seq_results = []
    async for req in batch_seq.drain_batch():
        seq_results.append(req)
    seq_duration = time.time() - start

    # Concurrent execution: both requests may be in flight simultaneously.
    start = time.time()
    batch_conc = Batch(max_concurrent=2)
    for i in range(2):
        req = batch_conc.create(
            contents=f"Count to {i + 1}. Reply with just the number.",
            context=TEST_CONTEXT,
            model=_TIMING_MODEL,
        )
        batch_conc.add(req)

    conc_results = []
    async for req in batch_conc.drain_batch():
        conc_results.append(req)
    conc_duration = time.time() - start

    # Both should complete successfully
    assert len(seq_results) == 2
    assert len(conc_results) == 2

    # Concurrent should not be dramatically slower than sequential.
    # We use a lenient threshold (1.5x) because API latency varies significantly,
    # making precise timing comparisons unreliable. This catches actual concurrency
    # bugs (like requests running sequentially) while tolerating normal variance.
    assert conc_duration < seq_duration * 1.5, (
        f"Concurrent ({conc_duration:.2f}s) should not be much slower than "
        f"sequential ({seq_duration:.2f}s)"
    )


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_json_output():
    """Test batch with JSON output mode."""
    batch = Batch(max_concurrent=2)

    req = batch.create(
        contents='Return a JSON object with "result": 10',
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
        json_output=True,
    )
    req.id = "json_test"
    batch.add(req)

    results = []
    async for r in batch.drain_batch():
        results.append(r)

    assert len(results) == 1
    assert results[0].response is not None
    # Loose structural check: json_output mode should yield a JSON object.
    assert "{" in results[0].response
    assert "}" in results[0].response


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_different_models():
    """Test batch with different models."""
    batch = Batch(max_concurrent=2)

    req1 = batch.create(
        contents="Say 'flash'",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req1.model_type = "flash"
    batch.add(req1)

    req2 = batch.create(
        contents="Say 'lite'",
        context=TEST_CONTEXT,
        model=GEMINI_LITE,
    )
    req2.model_type = "lite"
    batch.add(req2)

    results = []
    async for r in batch.drain_batch():
        results.append(r)

    assert len(results) == 2

    # Both should succeed
    for r in results:
        assert r.response is not None
        assert r.error is None


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_dynamic_adding():
    """Test multi-stage pattern - add stage 2 based on stage 1 results."""
    batch = Batch(max_concurrent=3)

    # Stage 1: Initial requests
    req1 = batch.create(
        contents="What is 5+5? Just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req1.stage = "stage1"
    req1.value = 5
    batch.add(req1)

    stage1_count = 0
    stage2_count = 0
    stage2_added = False

    async for req in batch.drain_batch():
        if req.stage == "stage1":
            stage1_count += 1

            # Add stage 2 request based on result
            if not stage2_added:
                req2 = batch.create(
                    contents=f"Previous answer was {req.response}. Double it. Just the number.",
                    context=TEST_CONTEXT,
                    model=GEMINI_FLASH,
                )
                req2.stage = "stage2"
                batch.add(req2)
                stage2_added = True
        elif req.stage == "stage2":
            stage2_count += 1

    # Should have processed both stages. Checking stage2_count catches the
    # bug where a dynamically-added request is never drained.
    assert stage1_count == 1
    assert stage2_count == 1


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_token_logging():
    """Test that token logging works with batch execution."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save and restore the journal override so this test does not leak
        # the temp path into other tests that run in the same session.
        prev_override = os.environ.get("_SOLSTONE_JOURNAL_OVERRIDE")
        os.environ["_SOLSTONE_JOURNAL_OVERRIDE"] = tmpdir
        try:
            batch = Batch(max_concurrent=2)

            req = batch.create(
                contents="Say hello",
                context=TEST_CONTEXT,
                model=GEMINI_FLASH,
            )
            batch.add(req)

            results = []
            async for r in batch.drain_batch():
                results.append(r)

            # Check that token logs were created (best-effort: logging may be
            # disabled in some configurations, hence the existence guard).
            tokens_dir = Path(tmpdir) / "tokens"
            if tokens_dir.exists():
                log_files = list(tokens_dir.glob("*.jsonl"))
                # Should have at least one log file
                assert len(log_files) >= 1
        finally:
            if prev_override is None:
                os.environ.pop("_SOLSTONE_JOURNAL_OVERRIDE", None)
            else:
                os.environ["_SOLSTONE_JOURNAL_OVERRIDE"] = prev_override


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_error_recovery():
    """Test retry pattern with real API (simulate by using invalid then valid)."""
    batch = Batch(max_concurrent=2)

    # This might error or succeed depending on model - just test the pattern
    req1 = batch.create(
        contents="What is 1+1? Reply with just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
        max_output_tokens=5,  # Very small, might cause issues
    )
    req1.attempt = 1
    batch.add(req1)

    retried = False
    async for req in batch.drain_batch():
        if req.attempt == 1:
            # Always add a retry request to test the pattern
            if not retried:
                req2 = batch.create(
                    contents="What is 1+1? Reply with just the number.",
                    context=TEST_CONTEXT,
                    model=GEMINI_FLASH,
                    max_output_tokens=100,  # Normal size
                )
                req2.attempt = 2
                batch.add(req2)
                retried = True

    # Pattern should complete successfully
    assert retried


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_client_reuse():
    """Test that client is reused across requests in batch (Google-specific)."""
    from think.providers.google import get_or_create_client

    # Create shared client
    client = get_or_create_client()

    # Use it in batch - client is passed through to Google backend
    batch = Batch(max_concurrent=2, client=client)

    req1 = batch.create(
        contents="Say 'first'",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    batch.add(req1)

    req2 = batch.create(
        contents="Say 'second'",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    batch.add(req2)

    results = []
    async for req in batch.drain_batch():
        results.append(req)

    # Both should succeed with shared client
    assert len(results) == 2
    for r in results:
        assert r.response is not None
        assert r.error is None


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_custom_attributes_preserved():
    """Test that custom attributes added to requests are preserved."""
    batch = Batch(max_concurrent=2)

    req = batch.create(
        contents="What is 10+10? Just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req.frame_id = 42
    req.monitor = "DP-3"
    req.metadata = {"foo": "bar", "nested": {"baz": 123}}
    batch.add(req)

    results = []
    async for r in batch.drain_batch():
        results.append(r)

    assert len(results) == 1
    result = results[0]

    # Custom attributes should be preserved
    assert result.frame_id == 42
    assert result.monitor == "DP-3"
    assert result.metadata == {"foo": "bar", "nested": {"baz": 123}}

    # Result attributes should be populated
    assert result.response is not None
    assert result.error is None
    assert result.duration > 0