# personal memory agent
# (source listing: main, 246 lines, 8.1 kB)
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Integration test for Google provider with real API calls."""

import json
import os
import subprocess
from pathlib import Path

import pytest
from dotenv import load_dotenv

from tests.integration.conftest import require_cli_tool
from think.models import GEMINI_FLASH, GEMINI_PRO


def get_fixtures_env():
    """Load tests/fixtures/.env and return (env_path, api_key, journal_path).

    Returns:
        Tuple of (Path to the .env file, GOOGLE_API_KEY value,
        _SOLSTONE_JOURNAL_OVERRIDE value). All three are None when the
        fixtures .env file does not exist; the key/path entries may be
        None individually when the corresponding variable is unset.
    """
    fixtures_env = Path(__file__).parent.parent / "fixtures" / ".env"
    if not fixtures_env.exists():
        return None, None, None

    # override=True so fixture values win over anything already exported
    # in the caller's shell environment.
    load_dotenv(fixtures_env, override=True)

    api_key = os.getenv("GOOGLE_API_KEY")
    journal_path = os.getenv("_SOLSTONE_JOURNAL_OVERRIDE")

    return fixtures_env, api_key, journal_path


def _prepare_env_or_skip():
    """Resolve fixture credentials, skipping the current test when missing.

    Consolidates the skip/env-setup boilerplate shared by the CLI tests.

    Returns:
        A copy of os.environ with GOOGLE_API_KEY and
        _SOLSTONE_JOURNAL_OVERRIDE set from tests/fixtures/.env.
    """
    fixtures_env, api_key, journal_path = get_fixtures_env()

    if not fixtures_env:
        pytest.skip("tests/fixtures/.env not found")

    if not api_key:
        pytest.skip("GOOGLE_API_KEY not found in tests/fixtures/.env file")

    if not journal_path:
        pytest.skip("_SOLSTONE_JOURNAL_OVERRIDE not found in tests/fixtures/.env file")

    env = os.environ.copy()
    env["_SOLSTONE_JOURNAL_OVERRIDE"] = journal_path
    env["GOOGLE_API_KEY"] = api_key
    return env


def _parse_jsonl_events(stdout: str) -> list[dict]:
    """Parse CLI stdout as JSONL, failing the test on malformed lines."""
    events = []
    for line in stdout.strip().split("\n"):
        if line:
            try:
                events.append(json.loads(line))
            except json.JSONDecodeError as e:
                pytest.fail(f"Failed to parse JSON line: {line}\nError: {e}")
    return events


@pytest.mark.integration
@pytest.mark.requires_api
def test_google_provider_basic():
    """Test Google provider with basic prompt via CLI."""
    require_cli_tool("Google", "gemini")
    env = _prepare_env_or_skip()

    # Create NDJSON input (no tool config)
    ndjson_input = json.dumps(
        {
            "prompt": "what is 1+1? Just give me the number.",
            "provider": "google",
            "name": "default",
            "model": GEMINI_FLASH,
            "max_output_tokens": 100,
        }
    )

    # Run the sol agents command
    result = subprocess.run(
        ["sol", "agents"],
        env=env,
        input=ndjson_input,
        capture_output=True,
        text=True,
        timeout=30,
    )

    # Check that the command succeeded
    assert result.returncode == 0, f"Command failed with stderr: {result.stderr}"

    # Parse stdout events (should be JSONL format)
    events = _parse_jsonl_events(result.stdout)

    # Verify we have events
    assert len(events) >= 2, (
        f"Expected at least start and finish events, got {len(events)}"
    )

    # Check start event
    start_event = events[0]
    assert start_event["event"] == "start"
    assert start_event["prompt"] == "what is 1+1? Just give me the number."
    assert start_event["model"] == GEMINI_FLASH
    assert start_event["name"] == "default"
    assert isinstance(start_event["ts"], int)

    # Check finish event
    finish_event = events[-1]
    assert finish_event["event"] == "finish"
    assert isinstance(finish_event["ts"], int)
    assert "result" in finish_event

    # The result should contain "2"
    result_text = finish_event["result"].lower()
    assert "2" in result_text or "two" in result_text, (
        f"Expected '2' in response, got: {finish_event['result']}"
    )

    # Check for no errors
    error_events = [e for e in events if e.get("event") == "error"]
    assert len(error_events) == 0, f"Found error events: {error_events}"

    # Verify stderr has no errors (warnings about thought_signature are OK)
    if result.stderr:
        assert (
            "error" not in result.stderr.lower() or "thought_signature" in result.stderr
        ), f"Unexpected stderr content: {result.stderr}"


@pytest.mark.integration
@pytest.mark.requires_api
def test_google_provider_with_thinking():
    """Test Google provider with thinking enabled."""
    require_cli_tool("Google", "gemini")
    env = _prepare_env_or_skip()

    # Create NDJSON input with thinking model (if available)
    ndjson_input = json.dumps(
        {
            "prompt": "What is the square root of 16? Just the number please.",
            "provider": "google",
            "name": "default",
            "model": GEMINI_PRO,  # Pro model for thinking
            "max_output_tokens": 2000,
        }
    )

    # Run the sol agents command
    result = subprocess.run(
        ["sol", "agents"],
        env=env,
        input=ndjson_input,
        capture_output=True,
        text=True,
        timeout=30,
    )

    # Allow for model unavailability
    if result.returncode != 0:
        if (
            "model not found" in result.stderr.lower()
            or "invalid model" in result.stderr.lower()
        ):
            pytest.skip("Thinking model not available")
        # pytest.fail rather than `assert False`: asserts are stripped
        # under `python -O`, which would silently pass this branch.
        pytest.fail(f"Command failed with stderr: {result.stderr}")

    # Parse events
    events = _parse_jsonl_events(result.stdout)

    # Check for thinking events (may be present with thinking models)
    # thinking_events may be present with thinking models (not asserted)

    # Verify the answer is correct
    finish_event = events[-1]

    # Check if this was an API error (intermittent failures)
    if finish_event.get("event") == "error":
        error_msg = finish_event.get("error", "Unknown error")
        trace = finish_event.get("trace", "")
        if (
            "quota" in error_msg.lower()
            or "rate" in error_msg.lower()
            or "retry" in error_msg.lower()
        ):
            pytest.skip(f"Intermittent Google API error: {error_msg}")
        else:
            pytest.fail(f"Unexpected error: {error_msg}\nTrace: {trace}")

    assert finish_event["event"] == "finish", (
        f"Expected finish event, got: {finish_event}"
    )
    assert "result" in finish_event, f"No result in finish event: {finish_event}"
    if finish_event["result"]:
        result_text = finish_event["result"].lower()
        assert "4" in result_text or "four" in result_text, (
            f"Expected '4' in response, got: {finish_event['result']}"
        )


@pytest.mark.integration
@pytest.mark.requires_api
def test_google_json_truncation_detection():
    """Test that Google provider detects JSON response truncation via finish_reason.

    Uses a very small max_output_tokens to force truncation, verifying that
    the provider returns finish_reason='max_tokens' which callers can use
    to detect incomplete responses.
    """
    # Journal path is not needed here, so skip only on missing file/key.
    fixtures_env, api_key, _ = get_fixtures_env()

    if not fixtures_env:
        pytest.skip("tests/fixtures/.env not found")

    if not api_key:
        pytest.skip("GOOGLE_API_KEY not found in tests/fixtures/.env file")

    # Import provider directly for this test
    from think.providers import google as google_provider

    # Request JSON output with very small token limit to force truncation
    # Use run_generate which returns GenerateResult, then check finish_reason
    result = google_provider.run_generate(
        contents="Return a JSON array of the first 50 prime numbers.",
        model=GEMINI_FLASH,
        json_output=True,
        max_output_tokens=10,  # Too small to complete the response
    )

    # Verify truncation was detected via finish_reason
    assert result["finish_reason"] == "max_tokens", (
        f"Expected max_tokens finish_reason, got: {result['finish_reason']}"
    )
    # Partial text should be present
    assert isinstance(result["text"], str)