A trust and safety agent that interacts with Osprey for investigation, real-time analysis, and implementation of prevention measures
1import asyncio
2import json
3import logging
4from abc import ABC, abstractmethod
5from dataclasses import dataclass
6from typing import Any, Literal
7
8import anthropic
9import httpx
10from anthropic.types import TextBlock, ToolUseBlock
11
12from src.agent.prompt import build_system_prompt
13from src.tools.executor import ToolExecutor
14
15logger = logging.getLogger(__name__)
16
17
@dataclass
class AgentTextBlock:
    """A plain-text block in a model response, normalized across providers."""

    # the text content emitted by the model
    text: str
21
22
@dataclass
class AgentToolUseBlock:
    """A tool-invocation request emitted by the model."""

    # provider-assigned call id; echoed back in the matching tool_result
    id: str
    # tool name, e.g. "execute_code"
    name: str
    # parsed tool arguments
    input: dict[str, Any]
28
29
@dataclass
class AgentResponse:
    """Provider-agnostic model response shared by all AgentClient implementations."""

    # ordered text and tool-use blocks from the model
    content: list[AgentTextBlock | AgentToolUseBlock]
    # "tool_use" when the model requested tool execution; "end_turn" otherwise
    stop_reason: Literal["end_turn", "tool_use"]
    # provider "reasoning" text when exposed (only set by OpenAI-compatible APIs here)
    reasoning_content: str | None = None
35
36
class AgentClient(ABC):
    """Abstract chat-completion backend used by Agent."""

    @abstractmethod
    async def complete(
        self,
        messages: list[dict[str, Any]],
        system: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> AgentResponse:
        """Run one model completion over the conversation.

        Args:
            messages: Conversation as Anthropic-style message dicts.
            system: Optional system prompt; implementations fall back to
                build_system_prompt() when omitted.
            tools: Optional Anthropic-style tool definitions.

        Returns:
            A normalized AgentResponse.
        """
        pass
46
47
class AnthropicClient(AgentClient):
    """AgentClient backed by the Anthropic Messages API (streaming).

    Places prompt-caching breakpoints on the system prompt, the last tool
    definition, and the last conversation message so the conversation
    prefix is reused across successive calls.
    """

    def __init__(
        self, api_key: str, model_name: str = "claude-sonnet-4-5-20250929"
    ) -> None:
        self._client = anthropic.AsyncAnthropic(api_key=api_key)
        self._model_name = model_name

    async def complete(
        self,
        messages: list[dict[str, Any]],
        system: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> AgentResponse:
        """Stream one completion and convert it to an AgentResponse.

        Args:
            messages: Anthropic-style conversation messages.
            system: System prompt; defaults to build_system_prompt().
            tools: Anthropic tool definitions; the last one receives a cache
                breakpoint (the caller's list/dicts are not mutated).
        """
        system_text = system or build_system_prompt()
        kwargs: dict[str, Any] = {
            "model": self._model_name,
            "max_tokens": 16_000,
            "system": [
                {
                    "type": "text",
                    "text": system_text,
                    "cache_control": {"type": "ephemeral"},
                }
            ],
            "messages": self._inject_cache_breakpoints(messages),
        }

        if tools:
            # copy the tool dicts so the caller's definitions are not mutated
            tools = [dict(t) for t in tools]
            tools[-1]["cache_control"] = {"type": "ephemeral"}
            kwargs["tools"] = tools

        async with self._client.messages.stream(**kwargs) as stream:  # type: ignore
            msg = await stream.get_final_message()

        content: list[AgentTextBlock | AgentToolUseBlock] = []
        for block in msg.content:
            if isinstance(block, TextBlock):
                content.append(AgentTextBlock(text=block.text))
            elif isinstance(block, ToolUseBlock):
                content.append(
                    AgentToolUseBlock(
                        id=block.id,
                        name=block.name,
                        input=block.input,  # type: ignore
                    )
                )

        # The API can also report "max_tokens", "stop_sequence", etc. (or
        # None mid-stream). The agent loop only distinguishes tool_use from
        # everything else, so collapse all other reasons to "end_turn"
        # instead of forcing them through the Literal type unchecked.
        stop_reason: Literal["end_turn", "tool_use"] = (
            "tool_use" if msg.stop_reason == "tool_use" else "end_turn"
        )
        return AgentResponse(content=content, stop_reason=stop_reason)

    @staticmethod
    def _inject_cache_breakpoints(
        messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Add a cache_control breakpoint to the last message's last block.

        Combined with the system-prompt and tool-definition breakpoints,
        this keeps us within Anthropic's 4-breakpoint limit while letting
        the conversation prefix be cached across successive calls. The
        caller's message objects are never mutated; only the pieces we
        touch are copied.
        """
        if not messages:
            return messages

        # shallow-copy the list so we don't mutate the caller's conversation
        messages = list(messages)
        last_msg = dict(messages[-1])
        content = last_msg["content"]

        if isinstance(content, str):
            last_msg["content"] = [
                {
                    "type": "text",
                    "text": content,
                    "cache_control": {"type": "ephemeral"},
                }
            ]
        elif isinstance(content, list) and content:
            # copy the list and only the final block (the one we modify),
            # so the breakpoint never leaks into the caller's data
            content = list(content)
            content[-1] = dict(content[-1])
            content[-1]["cache_control"] = {"type": "ephemeral"}
            last_msg["content"] = content

        messages[-1] = last_msg
        return messages
136
137
class OpenAICompatibleClient(AgentClient):
    """Client for OpenAI-compatible chat-completions APIs (OpenAI, Moonshot, etc.)."""

    def __init__(self, api_key: str, model_name: str, endpoint: str) -> None:
        self._api_key = api_key
        self._model_name = model_name
        # normalize so we can safely append "/chat/completions" below
        self._endpoint = endpoint.rstrip("/")
        # NOTE(review): this AsyncClient is never closed explicitly; the long
        # timeout covers slow completions
        self._http = httpx.AsyncClient(timeout=300.0)

    async def complete(
        self,
        messages: list[dict[str, Any]],
        system: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> AgentResponse:
        """POST one chat completion and normalize the result.

        Args:
            messages: Anthropic-style messages; converted to the OpenAI
                schema before sending.
            system: System prompt; defaults to build_system_prompt().
            tools: Anthropic-style tool definitions, converted to OpenAI
                function-calling format.

        Raises:
            httpx.HTTPStatusError: on a non-success response (after logging
                a truncated body).
        """
        oai_messages = self._convert_messages(messages, system or build_system_prompt())

        payload: dict[str, Any] = {
            "model": self._model_name,
            "messages": oai_messages,
            "max_tokens": 16_000,
        }

        if tools:
            payload["tools"] = self._convert_tools(tools)

        resp = await self._http.post(
            f"{self._endpoint}/chat/completions",
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
        )
        if not resp.is_success:
            # log a truncated body for debugging before raising
            logger.error("API error %d: %s", resp.status_code, resp.text[:1000])
        resp.raise_for_status()
        data = resp.json()

        return self._parse_response(data)

    def _convert_messages(
        self, messages: list[dict[str, Any]], system: str
    ) -> list[dict[str, Any]]:
        """Convert Anthropic-style messages into the OpenAI chat schema.

        The system prompt becomes a leading "system" message; assistant
        tool_use blocks become tool_calls; user tool_result blocks become
        "tool" role messages.
        """
        result: list[dict[str, Any]] = [{"role": "system", "content": system}]

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if isinstance(content, str):
                result.append({"role": role, "content": content})
            elif isinstance(content, list):
                if role == "assistant":
                    text_parts = []
                    tool_calls = []
                    for block in content:
                        if block.get("type") == "text":
                            text_parts.append(block["text"])
                        elif block.get("type") == "tool_use":
                            tool_calls.append(
                                {
                                    "id": block["id"],
                                    "type": "function",
                                    "function": {
                                        "name": block["name"],
                                        "arguments": json.dumps(block["input"]),
                                    },
                                }
                            )
                    oai_msg: dict[str, Any] = {"role": "assistant"}
                    if msg.get("reasoning_content"):
                        # round-trip provider reasoning so follow-up calls keep it
                        oai_msg["reasoning_content"] = msg["reasoning_content"]
                    # some openai-compatible apis reject content: null on
                    # assistant messages with tool_calls, so omit it when empty
                    if text_parts:
                        oai_msg["content"] = "\n".join(text_parts)
                    else:
                        oai_msg["content"] = ""
                    if tool_calls:
                        oai_msg["tool_calls"] = tool_calls
                    result.append(oai_msg)
                elif role == "user":
                    if content and content[0].get("type") == "tool_result":
                        # assumes a tool_result list holds only tool_result
                        # blocks (true for Agent.chat); a mixed list would
                        # KeyError on "tool_use_id" — TODO confirm no other
                        # callers build mixed content
                        for block in content:
                            result.append(
                                {
                                    "role": "tool",
                                    "tool_call_id": block["tool_use_id"],
                                    "content": block.get("content", ""),
                                }
                            )
                    else:
                        # flatten arbitrary user blocks into one text string
                        text = " ".join(b.get("text", str(b)) for b in content)
                        result.append({"role": "user", "content": text})

        return result

    def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert Anthropic tool definitions to OpenAI function-calling format."""
        result = []
        for t in tools:
            func: dict[str, Any] = {
                "name": t["name"],
                "description": t.get("description", ""),
            }
            if "input_schema" in t:
                func["parameters"] = t["input_schema"]
            result.append({"type": "function", "function": func})
        return result

    def _parse_response(self, data: dict[str, Any]) -> AgentResponse:
        """Convert an OpenAI chat-completion response body to an AgentResponse."""
        choice = data["choices"][0]
        message = choice["message"]
        finish_reason = choice.get("finish_reason", "stop")

        content: list[AgentTextBlock | AgentToolUseBlock] = []

        if message.get("content"):
            content.append(AgentTextBlock(text=message["content"]))

        if message.get("tool_calls"):
            for tc in message["tool_calls"]:
                try:
                    args = json.loads(tc["function"]["arguments"])
                except (json.JSONDecodeError, KeyError):
                    # malformed or missing arguments degrade to empty input
                    args = {}
                content.append(
                    AgentToolUseBlock(
                        id=tc["id"],
                        name=tc["function"]["name"],
                        input=args,
                    )
                )

        # any finish_reason other than "tool_calls" (stop, length, ...) maps
        # to end_turn
        stop_reason = "tool_use" if finish_reason == "tool_calls" else "end_turn"
        reasoning_content = message.get("reasoning_content")
        return AgentResponse(
            content=content,
            stop_reason=stop_reason,
            reasoning_content=reasoning_content,
        )
282
283
# Tool output beyond this many characters is truncated before being fed back
# to the model, keeping context growth bounded across tool-call loops.
MAX_TOOL_RESULT_LENGTH = 10_000
285
286
287class Agent:
288 def __init__(
289 self,
290 model_api: Literal["anthropic", "openai", "openapi"],
291 model_name: str,
292 model_api_key: str | None,
293 model_endpoint: str | None = None,
294 tool_executor: ToolExecutor | None = None,
295 ) -> None:
296 match model_api:
297 case "anthropic":
298 assert model_api_key
299 self._client: AgentClient = AnthropicClient(
300 api_key=model_api_key, model_name=model_name
301 )
302 case "openai":
303 assert model_api_key
304 self._client = OpenAICompatibleClient(
305 api_key=model_api_key,
306 model_name=model_name,
307 endpoint="https://api.openai.com/v1",
308 )
309 case "openapi":
310 assert model_api_key
311 assert model_endpoint, "model_endpoint is required for openapi"
312 self._client = OpenAICompatibleClient(
313 api_key=model_api_key,
314 model_name=model_name,
315 endpoint=model_endpoint,
316 )
317
318 self._tool_executor = tool_executor
319 self._conversation: list[dict[str, Any]] = []
320
321 def _get_tools(self) -> list[dict[str, Any]] | None:
322 """get tool definitions for the agent"""
323
324 if self._tool_executor is None:
325 return None
326 return [self._tool_executor.get_execute_code_tool_definition()]
327
328 async def _handle_tool_call(self, tool_use: AgentToolUseBlock) -> dict[str, Any]:
329 """handle a tool call from the model"""
330 if tool_use.name == "execute_code" and self._tool_executor:
331 code = tool_use.input.get("code", "")
332 result = await self._tool_executor.execute_code(code)
333 return result
334 else:
335 return {"error": f"Unknown tool: {tool_use.name}"}
336
337 async def chat(self, user_message: str) -> str:
338 """send a message and get a response, handling tool calls"""
339 self._conversation.append({"role": "user", "content": user_message})
340
341 while True:
342 resp = await self._client.complete(
343 messages=self._conversation,
344 tools=self._get_tools(),
345 )
346
347 assistant_content: list[dict[str, Any]] = []
348 text_response = ""
349
350 for block in resp.content:
351 if isinstance(block, AgentTextBlock):
352 assistant_content.append({"type": "text", "text": block.text})
353 text_response += block.text
354 elif isinstance(block, AgentToolUseBlock): # type: ignore TODO: for now this errors because there are no other types, but ignore for now
355 assistant_content.append(
356 {
357 "type": "tool_use",
358 "id": block.id,
359 "name": block.name,
360 "input": block.input,
361 }
362 )
363
364 assistant_msg: dict[str, Any] = {
365 "role": "assistant",
366 "content": assistant_content,
367 }
368 if resp.reasoning_content:
369 assistant_msg["reasoning_content"] = resp.reasoning_content
370 self._conversation.append(assistant_msg)
371
372 # find any tool calls that we need to handle
373 if resp.stop_reason == "tool_use":
374 tool_results: list[dict[str, Any]] = []
375 for block in resp.content:
376 if isinstance(block, AgentToolUseBlock):
377 code = block.input.get("code", "")
378 logger.info("Tool call: %s\n%s", block.name, code)
379 result = await self._handle_tool_call(block)
380 is_error = "error" in result
381 summary = str(result)[:500]
382 logger.info(
383 "Tool result (%s): %s",
384 "error" if is_error else "ok",
385 summary,
386 )
387 content_str = str(result)
388 if len(content_str) > MAX_TOOL_RESULT_LENGTH:
389 content_str = (
390 content_str[:MAX_TOOL_RESULT_LENGTH]
391 + "\n... (truncated)"
392 )
393
394 tool_results.append(
395 {
396 "type": "tool_result",
397 "tool_use_id": block.id,
398 "content": content_str,
399 }
400 )
401
402 self._conversation.append({"role": "user", "content": tool_results})
403 else:
404 # once there are no more tool calls, we proceed to the text response
405 return text_response
406
407 async def run(self):
408 while True:
409 logger.info("running tasks...")
410 await asyncio.sleep(30)