A trust and safety agent that interacts with Osprey to run investigations, perform real-time analysis, and implement preventions
at main 410 lines 15 kB view raw
1import asyncio 2import json 3import logging 4from abc import ABC, abstractmethod 5from dataclasses import dataclass 6from typing import Any, Literal 7 8import anthropic 9import httpx 10from anthropic.types import TextBlock, ToolUseBlock 11 12from src.agent.prompt import build_system_prompt 13from src.tools.executor import ToolExecutor 14 15logger = logging.getLogger(__name__) 16 17 18@dataclass 19class AgentTextBlock: 20 text: str 21 22 23@dataclass 24class AgentToolUseBlock: 25 id: str 26 name: str 27 input: dict[str, Any] 28 29 30@dataclass 31class AgentResponse: 32 content: list[AgentTextBlock | AgentToolUseBlock] 33 stop_reason: Literal["end_turn", "tool_use"] 34 reasoning_content: str | None = None 35 36 37class AgentClient(ABC): 38 @abstractmethod 39 async def complete( 40 self, 41 messages: list[dict[str, Any]], 42 system: str | None = None, 43 tools: list[dict[str, Any]] | None = None, 44 ) -> AgentResponse: 45 pass 46 47 48class AnthropicClient(AgentClient): 49 def __init__( 50 self, api_key: str, model_name: str = "claude-sonnet-4-5-20250929" 51 ) -> None: 52 self._client = anthropic.AsyncAnthropic(api_key=api_key) 53 self._model_name = model_name 54 55 async def complete( 56 self, 57 messages: list[dict[str, Any]], 58 system: str | None = None, 59 tools: list[dict[str, Any]] | None = None, 60 ) -> AgentResponse: 61 system_text = system or build_system_prompt() 62 kwargs: dict[str, Any] = { 63 "model": self._model_name, 64 "max_tokens": 16_000, 65 "system": [ 66 { 67 "type": "text", 68 "text": system_text, 69 "cache_control": {"type": "ephemeral"}, 70 } 71 ], 72 "messages": self._inject_cache_breakpoints(messages), 73 } 74 75 if tools: 76 tools = [dict(t) for t in tools] 77 tools[-1]["cache_control"] = {"type": "ephemeral"} 78 kwargs["tools"] = tools 79 80 async with self._client.messages.stream(**kwargs) as stream: # type: ignore 81 msg = await stream.get_final_message() 82 83 content: list[AgentTextBlock | AgentToolUseBlock] = [] 84 for block in 
msg.content: 85 if isinstance(block, TextBlock): 86 content.append(AgentTextBlock(text=block.text)) 87 elif isinstance(block, ToolUseBlock): 88 content.append( 89 AgentToolUseBlock( 90 id=block.id, 91 name=block.name, 92 input=block.input, # type: ignore 93 ) 94 ) 95 96 return AgentResponse( 97 content=content, 98 stop_reason=msg.stop_reason or "end_turn", # type: ignore TODO: fix this 99 ) 100 101 @staticmethod 102 def _inject_cache_breakpoints( 103 messages: list[dict[str, Any]], 104 ) -> list[dict[str, Any]]: 105 """ 106 a helper that adds cache_control breakpoints to the conversation so that 107 the conversation prefix is cached across successive calls. we place a single 108 breakpoint in th last message's content block, combined with the sys-prompt 109 and tool defs breakpoints. ensures that we stay in the 4-breakpoint limit 110 that ant requires 111 """ 112 if not messages: 113 return messages 114 115 # shallow-copy the list so we don't mutate the caller's conversation 116 messages = list(messages) 117 last_msg = dict(messages[-1]) 118 content = last_msg["content"] 119 120 if isinstance(content, str): 121 last_msg["content"] = [ 122 { 123 "type": "text", 124 "text": content, 125 "cache_control": {"type": "ephemeral"}, 126 } 127 ] 128 elif isinstance(content, list) and content: 129 content = [dict(b) for b in content] 130 content[-1] = dict(content[-1]) 131 content[-1]["cache_control"] = {"type": "ephemeral"} 132 last_msg["content"] = content 133 134 messages[-1] = last_msg 135 return messages 136 137 138class OpenAICompatibleClient(AgentClient): 139 """client for openapi compatible apis like openai, moonshot, etc""" 140 141 def __init__(self, api_key: str, model_name: str, endpoint: str) -> None: 142 self._api_key = api_key 143 self._model_name = model_name 144 self._endpoint = endpoint.rstrip("/") 145 self._http = httpx.AsyncClient(timeout=300.0) 146 147 async def complete( 148 self, 149 messages: list[dict[str, Any]], 150 system: str | None = None, 151 
tools: list[dict[str, Any]] | None = None, 152 ) -> AgentResponse: 153 oai_messages = self._convert_messages(messages, system or build_system_prompt()) 154 155 payload: dict[str, Any] = { 156 "model": self._model_name, 157 "messages": oai_messages, 158 "max_tokens": 16_000, 159 } 160 161 if tools: 162 payload["tools"] = self._convert_tools(tools) 163 164 resp = await self._http.post( 165 f"{self._endpoint}/chat/completions", 166 headers={ 167 "Authorization": f"Bearer {self._api_key}", 168 "Content-Type": "application/json", 169 }, 170 json=payload, 171 ) 172 if not resp.is_success: 173 logger.error("API error %d: %s", resp.status_code, resp.text[:1000]) 174 resp.raise_for_status() 175 data = resp.json() 176 177 return self._parse_response(data) 178 179 def _convert_messages( 180 self, messages: list[dict[str, Any]], system: str 181 ) -> list[dict[str, Any]]: 182 """for anthropic chats, we'll convert the outputs into a similar format""" 183 result: list[dict[str, Any]] = [{"role": "system", "content": system}] 184 185 for msg in messages: 186 role = msg["role"] 187 content = msg["content"] 188 189 if isinstance(content, str): 190 result.append({"role": role, "content": content}) 191 elif isinstance(content, list): 192 if role == "assistant": 193 text_parts = [] 194 tool_calls = [] 195 for block in content: 196 if block.get("type") == "text": 197 text_parts.append(block["text"]) 198 elif block.get("type") == "tool_use": 199 tool_calls.append( 200 { 201 "id": block["id"], 202 "type": "function", 203 "function": { 204 "name": block["name"], 205 "arguments": json.dumps(block["input"]), 206 }, 207 } 208 ) 209 oai_msg: dict[str, Any] = {"role": "assistant"} 210 if msg.get("reasoning_content"): 211 oai_msg["reasoning_content"] = msg["reasoning_content"] 212 # some openai-compatible apis reject content: null on 213 # assistant messages with tool_calls, so omit it when empty 214 if text_parts: 215 oai_msg["content"] = "\n".join(text_parts) 216 else: 217 oai_msg["content"] = 
"" 218 if tool_calls: 219 oai_msg["tool_calls"] = tool_calls 220 result.append(oai_msg) 221 elif role == "user": 222 if content and content[0].get("type") == "tool_result": 223 for block in content: 224 result.append( 225 { 226 "role": "tool", 227 "tool_call_id": block["tool_use_id"], 228 "content": block.get("content", ""), 229 } 230 ) 231 else: 232 text = " ".join(b.get("text", str(b)) for b in content) 233 result.append({"role": "user", "content": text}) 234 235 return result 236 237 def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: 238 """convert anthropic tool defs to oai function calling format""" 239 result = [] 240 for t in tools: 241 func: dict[str, Any] = { 242 "name": t["name"], 243 "description": t.get("description", ""), 244 } 245 if "input_schema" in t: 246 func["parameters"] = t["input_schema"] 247 result.append({"type": "function", "function": func}) 248 return result 249 250 def _parse_response(self, data: dict[str, Any]) -> AgentResponse: 251 """convert an oai chat completion resp to agentresponse""" 252 choice = data["choices"][0] 253 message = choice["message"] 254 finish_reason = choice.get("finish_reason", "stop") 255 256 content: list[AgentTextBlock | AgentToolUseBlock] = [] 257 258 if message.get("content"): 259 content.append(AgentTextBlock(text=message["content"])) 260 261 if message.get("tool_calls"): 262 for tc in message["tool_calls"]: 263 try: 264 args = json.loads(tc["function"]["arguments"]) 265 except (json.JSONDecodeError, KeyError): 266 args = {} 267 content.append( 268 AgentToolUseBlock( 269 id=tc["id"], 270 name=tc["function"]["name"], 271 input=args, 272 ) 273 ) 274 275 stop_reason = "tool_use" if finish_reason == "tool_calls" else "end_turn" 276 reasoning_content = message.get("reasoning_content") 277 return AgentResponse( 278 content=content, 279 stop_reason=stop_reason, 280 reasoning_content=reasoning_content, 281 ) 282 283 284MAX_TOOL_RESULT_LENGTH = 10_000 285 286 287class Agent: 288 def 
__init__( 289 self, 290 model_api: Literal["anthropic", "openai", "openapi"], 291 model_name: str, 292 model_api_key: str | None, 293 model_endpoint: str | None = None, 294 tool_executor: ToolExecutor | None = None, 295 ) -> None: 296 match model_api: 297 case "anthropic": 298 assert model_api_key 299 self._client: AgentClient = AnthropicClient( 300 api_key=model_api_key, model_name=model_name 301 ) 302 case "openai": 303 assert model_api_key 304 self._client = OpenAICompatibleClient( 305 api_key=model_api_key, 306 model_name=model_name, 307 endpoint="https://api.openai.com/v1", 308 ) 309 case "openapi": 310 assert model_api_key 311 assert model_endpoint, "model_endpoint is required for openapi" 312 self._client = OpenAICompatibleClient( 313 api_key=model_api_key, 314 model_name=model_name, 315 endpoint=model_endpoint, 316 ) 317 318 self._tool_executor = tool_executor 319 self._conversation: list[dict[str, Any]] = [] 320 321 def _get_tools(self) -> list[dict[str, Any]] | None: 322 """get tool definitions for the agent""" 323 324 if self._tool_executor is None: 325 return None 326 return [self._tool_executor.get_execute_code_tool_definition()] 327 328 async def _handle_tool_call(self, tool_use: AgentToolUseBlock) -> dict[str, Any]: 329 """handle a tool call from the model""" 330 if tool_use.name == "execute_code" and self._tool_executor: 331 code = tool_use.input.get("code", "") 332 result = await self._tool_executor.execute_code(code) 333 return result 334 else: 335 return {"error": f"Unknown tool: {tool_use.name}"} 336 337 async def chat(self, user_message: str) -> str: 338 """send a message and get a response, handling tool calls""" 339 self._conversation.append({"role": "user", "content": user_message}) 340 341 while True: 342 resp = await self._client.complete( 343 messages=self._conversation, 344 tools=self._get_tools(), 345 ) 346 347 assistant_content: list[dict[str, Any]] = [] 348 text_response = "" 349 350 for block in resp.content: 351 if isinstance(block, 
AgentTextBlock): 352 assistant_content.append({"type": "text", "text": block.text}) 353 text_response += block.text 354 elif isinstance(block, AgentToolUseBlock): # type: ignore TODO: for now this errors because there are no other types, but ignore for now 355 assistant_content.append( 356 { 357 "type": "tool_use", 358 "id": block.id, 359 "name": block.name, 360 "input": block.input, 361 } 362 ) 363 364 assistant_msg: dict[str, Any] = { 365 "role": "assistant", 366 "content": assistant_content, 367 } 368 if resp.reasoning_content: 369 assistant_msg["reasoning_content"] = resp.reasoning_content 370 self._conversation.append(assistant_msg) 371 372 # find any tool calls that we need to handle 373 if resp.stop_reason == "tool_use": 374 tool_results: list[dict[str, Any]] = [] 375 for block in resp.content: 376 if isinstance(block, AgentToolUseBlock): 377 code = block.input.get("code", "") 378 logger.info("Tool call: %s\n%s", block.name, code) 379 result = await self._handle_tool_call(block) 380 is_error = "error" in result 381 summary = str(result)[:500] 382 logger.info( 383 "Tool result (%s): %s", 384 "error" if is_error else "ok", 385 summary, 386 ) 387 content_str = str(result) 388 if len(content_str) > MAX_TOOL_RESULT_LENGTH: 389 content_str = ( 390 content_str[:MAX_TOOL_RESULT_LENGTH] 391 + "\n... (truncated)" 392 ) 393 394 tool_results.append( 395 { 396 "type": "tool_result", 397 "tool_use_id": block.id, 398 "content": content_str, 399 } 400 ) 401 402 self._conversation.append({"role": "user", "content": tool_results}) 403 else: 404 # once there are no more tool calls, we proceed to the text response 405 return text_response 406 407 async def run(self): 408 while True: 409 logger.info("running tasks...") 410 await asyncio.sleep(30)