# personal memory agent
1# SPDX-License-Identifier: AGPL-3.0-only
2# Copyright (c) 2026 sol pbc
3
4"""Integration tests for Batch with real LLM APIs."""
5
6import os
7import tempfile
8import time
9from pathlib import Path
10
11import pytest
12
13from think.batch import Batch
14from think.models import GEMINI_FLASH, GEMINI_LITE
15
# Lite model for timing-sensitive tests (faster responses, less variance)
_TIMING_MODEL = GEMINI_LITE

# Default context for integration tests - uses Google provider
TEST_CONTEXT = "test.batch.integration"
21
22
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_basic_execution():
    """Test basic batch execution with real API.

    Submits two arithmetic prompts and verifies both complete, IDs are
    preserved, and each response answers its own question.
    """
    batch = Batch(max_concurrent=3)

    # Map request IDs to the digit each prompt should produce, so each
    # response is checked against its own question. The previous shared
    # or-chain ("2"/"3"/"4"/"6") would pass even with wrong or swapped
    # answers.
    expected = {"calc1": "4", "calc2": "6"}

    # Add simple requests
    req1 = batch.create(
        contents="What is 2+2? Reply with just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req1.id = "calc1"
    batch.add(req1)

    req2 = batch.create(
        contents="What is 3+3? Reply with just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req2.id = "calc2"
    batch.add(req2)

    # Collect results
    results = []
    async for req in batch.drain_batch():
        results.append(req)

    # Verify both completed
    assert len(results) == 2

    # Check IDs are preserved
    result_ids = {r.id for r in results}
    assert result_ids == {"calc1", "calc2"}

    # Check each response against its own expected answer
    for r in results:
        assert r.response is not None
        assert r.error is None
        assert r.duration > 0
        assert expected[r.id] in r.response
70
71
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_concurrent_timing():
    """Test that concurrent execution is actually faster than sequential."""

    async def _run_batch(concurrency):
        # Build, drain, and time a two-request batch at the given
        # concurrency level. Batch construction is inside the timed
        # window, matching the original measurement.
        started = time.time()
        timed_batch = Batch(max_concurrent=concurrency)
        for idx in range(2):
            timed_batch.add(
                timed_batch.create(
                    contents=f"Count to {idx + 1}. Reply with just the number.",
                    context=TEST_CONTEXT,
                    model=_TIMING_MODEL,
                )
            )
        completed = [r async for r in timed_batch.drain_batch()]
        return completed, time.time() - started

    # Sequential baseline, then concurrent execution
    seq_results, seq_duration = await _run_batch(1)
    conc_results, conc_duration = await _run_batch(2)

    # Both should complete successfully
    assert len(seq_results) == 2
    assert len(conc_results) == 2

    # Concurrent should not be dramatically slower than sequential.
    # We use a lenient threshold (1.5x) because API latency varies significantly,
    # making precise timing comparisons unreliable. This catches actual concurrency
    # bugs (like requests running sequentially) while tolerating normal variance.
    assert conc_duration < seq_duration * 1.5, (
        f"Concurrent ({conc_duration:.2f}s) should not be much slower than "
        f"sequential ({seq_duration:.2f}s)"
    )
121
122
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_json_output():
    """Test batch with JSON output mode."""
    batch = Batch(max_concurrent=2)

    request = batch.create(
        contents='Return a JSON object with "result": 10',
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
        json_output=True,
    )
    request.id = "json_test"
    batch.add(request)

    results = [r async for r in batch.drain_batch()]

    assert len(results) == 1
    response = results[0].response
    assert response is not None
    # A JSON object response must at minimum contain braces.
    assert "{" in response
    assert "}" in response
147
148
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_different_models():
    """Test batch with different models."""
    batch = Batch(max_concurrent=2)

    # One request per model, tagged so results can be told apart.
    specs = [
        ("flash", GEMINI_FLASH, "Say 'flash'"),
        ("lite", GEMINI_LITE, "Say 'lite'"),
    ]
    for label, model, prompt in specs:
        request = batch.create(
            contents=prompt,
            context=TEST_CONTEXT,
            model=model,
        )
        request.model_type = label
        batch.add(request)

    results = [r async for r in batch.drain_batch()]

    assert len(results) == 2

    # Both should succeed
    for result in results:
        assert result.response is not None
        assert result.error is None
182
183
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_dynamic_adding():
    """Test multi-stage pattern - add stage 2 based on stage 1 results.

    Verifies that a request added to the batch *while draining* is also
    drained before the iterator completes.
    """
    batch = Batch(max_concurrent=3)

    # Stage 1: Initial requests
    req1 = batch.create(
        contents="What is 5+5? Just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    req1.stage = "stage1"
    req1.value = 5
    batch.add(req1)

    stage1_count = 0
    stage2_count = 0
    stage2_added = False

    async for req in batch.drain_batch():
        if req.stage == "stage1":
            stage1_count += 1

            # Add stage 2 request based on result
            if not stage2_added:
                req2 = batch.create(
                    contents=f"Previous answer was {req.response}. Double it. Just the number.",
                    context=TEST_CONTEXT,
                    model=GEMINI_FLASH,
                )
                req2.stage = "stage2"
                batch.add(req2)
                stage2_added = True
        elif req.stage == "stage2":
            stage2_count += 1

    # Should have processed both stages. The original only asserted on
    # stage 1, so a dynamically-added request that was silently dropped
    # would have gone unnoticed.
    assert stage1_count == 1
    assert stage2_count == 1
221
222
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_token_logging():
    """Test that token logging works with batch execution."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Redirect journal output to the temp dir, remembering any prior
        # value; the original set this and never restored it, leaking the
        # override into every later test in the same process.
        previous = os.environ.get("_SOLSTONE_JOURNAL_OVERRIDE")
        os.environ["_SOLSTONE_JOURNAL_OVERRIDE"] = tmpdir
        try:
            batch = Batch(max_concurrent=2)

            req = batch.create(
                contents="Say hello",
                context=TEST_CONTEXT,
                model=GEMINI_FLASH,
            )
            batch.add(req)

            results = []
            async for r in batch.drain_batch():
                results.append(r)

            # The request itself must have been processed.
            assert len(results) == 1

            # Check that token logs were created (best-effort: logging may
            # be disabled in some configurations, so only assert when the
            # tokens dir exists).
            tokens_dir = Path(tmpdir) / "tokens"
            if tokens_dir.exists():
                log_files = list(tokens_dir.glob("*.jsonl"))
                # Should have at least one log file
                assert len(log_files) >= 1
        finally:
            # Restore the environment exactly as it was.
            if previous is None:
                os.environ.pop("_SOLSTONE_JOURNAL_OVERRIDE", None)
            else:
                os.environ["_SOLSTONE_JOURNAL_OVERRIDE"] = previous
250
251
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_error_recovery():
    """Test retry pattern with real API (simulate by using invalid then valid)."""
    batch = Batch(max_concurrent=2)

    # This might error or succeed depending on model - just test the pattern
    first = batch.create(
        contents="What is 1+1? Reply with just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
        max_output_tokens=5,  # Very small, might cause issues
    )
    first.attempt = 1
    batch.add(first)

    retried = False
    async for finished in batch.drain_batch():
        # Only the first attempt triggers a retry, and only once.
        if finished.attempt != 1 or retried:
            continue
        retry = batch.create(
            contents="What is 1+1? Reply with just the number.",
            context=TEST_CONTEXT,
            model=GEMINI_FLASH,
            max_output_tokens=100,  # Normal size
        )
        retry.attempt = 2
        batch.add(retry)
        retried = True

    # Pattern should complete successfully
    assert retried
286
287
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_client_reuse():
    """Test that client is reused across requests in batch (Google-specific)."""
    from think.providers.google import get_or_create_client

    # Create shared client and hand it to the batch; it is passed
    # through to the Google backend for every request.
    shared_client = get_or_create_client()
    batch = Batch(max_concurrent=2, client=shared_client)

    for prompt in ("Say 'first'", "Say 'second'"):
        batch.add(
            batch.create(
                contents=prompt,
                context=TEST_CONTEXT,
                model=GEMINI_FLASH,
            )
        )

    results = [r async for r in batch.drain_batch()]

    # Both should succeed with shared client
    assert len(results) == 2
    for result in results:
        assert result.response is not None
        assert result.error is None
324
325
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.requires_api
async def test_batch_custom_attributes_preserved():
    """Test that custom attributes added to requests are preserved."""
    batch = Batch(max_concurrent=2)

    request = batch.create(
        contents="What is 10+10? Just the number.",
        context=TEST_CONTEXT,
        model=GEMINI_FLASH,
    )
    # Arbitrary caller-attached attributes, including nested data.
    request.frame_id = 42
    request.monitor = "DP-3"
    request.metadata = {"foo": "bar", "nested": {"baz": 123}}
    batch.add(request)

    results = [r async for r in batch.drain_batch()]
    assert len(results) == 1
    result = results[0]

    # Custom attributes should be preserved
    assert result.frame_id == 42
    assert result.monitor == "DP-3"
    assert result.metadata == {"foo": "bar", "nested": {"baz": 123}}

    # Result attributes should be populated
    assert result.response is not None
    assert result.error is None
    assert result.duration > 0