"""SAM v3 connection handler -- one per client TCP connection. Ported from net.i2p.sam.SAMv3Handler. """ from __future__ import annotations import asyncio import logging from i2p_sam.protocol import SAMCommand, SAMReply, SUPPORTED_VERSIONS from i2p_sam.sessions_db import SessionsDB, SessionRecord from i2p_sam.stream_session import SAMStreamSession from i2p_sam.datagram_session import SAMDatagramSession from i2p_sam.raw_session import SAMRawSession from i2p_sam.primary_session import PrimarySession from i2p_sam.utils import generate_transient_destination, negotiate_version logger = logging.getLogger(__name__) class SAMHandler: """Handles one SAM client connection. Lifecycle: 1. HELLO negotiation (mandatory first message) 2. SESSION CREATE (one session per handler, or PRIMARY for multi) 3. Command loop: STREAM, DATAGRAM, RAW, DEST, NAMING, PING 4. Socket stealing on STREAM CONNECT/ACCEPT """ # Timeout for initial HELLO negotiation HELLO_TIMEOUT = 60.0 # Timeout for subsequent commands (keepalive) COMMAND_TIMEOUT = 180.0 def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, sessions_db: SessionsDB, ) -> None: self._reader = reader self._writer = writer self._sessions_db = sessions_db self._negotiated_version: str | None = None self._session_nickname: str | None = None self._session_style: str | None = None self._stream_session: SAMStreamSession | None = None self._datagram_session: SAMDatagramSession | None = None self._raw_session: SAMRawSession | None = None self._primary_session: PrimarySession | None = None self._running = True async def run(self) -> None: """Main handler loop. First negotiates protocol version via HELLO, then enters command dispatch loop until the client disconnects or sends QUIT. """ try: # Step 1: HELLO negotiation if not await self._handle_hello(): return # Step 2: Command loop while self._running: line = await asyncio.wait_for( self._reader.readline(), timeout=self.COMMAND_TIMEOUT ) if not line: break text = line.decode("utf-8", errors="replace").strip() if text: await self._dispatch(text) except asyncio.TimeoutError: logger.debug("SAM client timed out") except ConnectionResetError: logger.debug("SAM client connection reset") finally: # Clean up session on disconnect if self._session_nickname: await self._sessions_db.remove(self._session_nickname) try: self._writer.close() await self._writer.wait_closed() except Exception: pass async def _handle_hello(self) -> bool: """Negotiate protocol version. Returns True on success. Reads the HELLO VERSION line, negotiates version, and sends reply. The client must send HELLO as the first command within HELLO_TIMEOUT. """ try: line = await asyncio.wait_for( self._reader.readline(), timeout=self.HELLO_TIMEOUT ) except asyncio.TimeoutError: logger.debug("SAM client did not send HELLO within timeout") return False if not line: return False text = line.decode("utf-8", errors="replace").strip() if not text: return False try: cmd = SAMCommand.parse(text) except ValueError: await self._write(SAMReply.hello_noversion()) return False if cmd.verb != "HELLO" or cmd.opcode != "VERSION": await self._write(SAMReply.hello_noversion()) return False client_min = cmd.params.get("MIN", "3.0") client_max = cmd.params.get("MAX", "3.3") version = negotiate_version(client_min, client_max, SUPPORTED_VERSIONS) if version is None: await self._write(SAMReply.hello_noversion()) return False self._negotiated_version = version await self._write(SAMReply.hello_ok(version)) return True async def _dispatch(self, line: str) -> None: """Route command to handler method. Args: line: The raw text line from the client (already stripped). """ try: cmd = SAMCommand.parse(line) except ValueError: logger.warning("Failed to parse SAM command: %s", line) return match cmd.verb: case "SESSION": await self._handle_session(cmd) case "STREAM": await self._handle_stream(cmd) case "DATAGRAM": await self._handle_datagram(cmd) case "RAW": await self._handle_raw(cmd) case "DEST": await self._handle_dest(cmd) case "NAMING": await self._handle_naming(cmd) case "PING": await self._handle_ping(cmd) case "QUIT" | "STOP" | "EXIT": self._running = False case _: logger.warning("Unknown SAM verb: %s", cmd.verb) async def _handle_session(self, cmd: SAMCommand) -> None: """SESSION CREATE/ADD/REMOVE. CREATE: generate or load destination, register in SessionsDB. ADD: add subsession to existing PRIMARY session. REMOVE: remove subsession from PRIMARY session. """ opcode = cmd.opcode.upper() if opcode == "CREATE": await self._session_create(cmd) elif opcode == "ADD": await self._session_add(cmd) elif opcode == "REMOVE": await self._session_remove(cmd) else: await self._write(SAMReply.session_error( "I2P_ERROR", f"Unknown SESSION opcode: {opcode}")) async def _session_create(self, cmd: SAMCommand) -> None: """Handle SESSION CREATE command.""" nickname = cmd.params.get("ID", "") style = cmd.params.get("STYLE", "STREAM").upper() dest_param = cmd.params.get("DESTINATION", "TRANSIENT") if not nickname: await self._write(SAMReply.session_error( "I2P_ERROR", "Missing session ID")) return # Check for duplicate nickname if await self._sessions_db.has(nickname): await self._write(SAMReply.session_error("DUPLICATED_ID")) return # Generate or load destination if dest_param.upper() == "TRANSIENT": raw_dest, dest_b64 = generate_transient_destination() else: # Use provided destination key (I2P base64) from i2p_data.data_helper import from_base64 try: raw_dest = from_base64(dest_param) dest_b64 = dest_param except Exception: await self._write(SAMReply.session_error( "INVALID_KEY", "Cannot decode destination")) return # Create session record record = SessionRecord( nickname=nickname, style=style, destination=raw_dest, destination_b64=dest_b64, handler=self, ) if not await self._sessions_db.add(record): await self._write(SAMReply.session_error("DUPLICATED_ID")) return self._session_nickname = nickname self._session_style = style # Create style-specific session object if style == "STREAM": self._stream_session = SAMStreamSession(nickname, dest_b64) elif style == "DATAGRAM": listen_port = int(cmd.params.get("PORT", "0")) listen_host = cmd.params.get("HOST", "127.0.0.1") self._datagram_session = SAMDatagramSession( nickname, dest_b64, listen_port, listen_host) elif style == "RAW": protocol = int(cmd.params.get("PROTOCOL", "18")) self._raw_session = SAMRawSession(nickname, dest_b64, protocol) elif style == "PRIMARY": self._primary_session = PrimarySession(nickname, dest_b64) await self._write(SAMReply.session_ok(dest_b64)) async def _session_add(self, cmd: SAMCommand) -> None: """Handle SESSION ADD (for PRIMARY sessions).""" nickname = cmd.params.get("ID", "") style = cmd.params.get("STYLE", "STREAM").upper() from_port = cmd.params.get("FROM_PORT", "0") if not self._primary_session or self._session_nickname != nickname: await self._write(SAMReply.session_error( "I2P_ERROR", "No PRIMARY session with that ID")) return try: record = await self._primary_session.add_subsession( from_port, style, self) await self._write(SAMReply.session_ok(record.destination_b64)) except ValueError: await self._write(SAMReply.session_error("DUPLICATED_ID")) async def _session_remove(self, cmd: SAMCommand) -> None: """Handle SESSION REMOVE (for PRIMARY sessions).""" nickname = cmd.params.get("ID", "") from_port = cmd.params.get("FROM_PORT", "0") if not self._primary_session or self._session_nickname != nickname: await self._write(SAMReply.session_error( "I2P_ERROR", "No PRIMARY session with that ID")) return if await self._primary_session.remove_subsession(from_port): await self._write(SAMReply.session_ok( self._primary_session._destination_b64)) else: await self._write(SAMReply.session_error( "I2P_ERROR", "Subsession not found")) async def _handle_stream(self, cmd: SAMCommand) -> None: """STREAM CONNECT/ACCEPT/FORWARD. CONNECT: connect to remote destination, steal socket for raw tunnel. ACCEPT: wait for incoming connection, steal socket. FORWARD: listen on local port, forward connections. """ opcode = cmd.opcode.upper() nickname = cmd.params.get("ID", "") if not self._stream_session: await self._write(SAMReply.stream_error( "I2P_ERROR", "No STREAM session active")) return if opcode == "CONNECT": target = cmd.params.get("DESTINATION", "") if not target: await self._write(SAMReply.stream_error( "I2P_ERROR", "Missing DESTINATION")) return silent = cmd.params.get("SILENT", "false").lower() == "true" from_port = int(cmd.params.get("FROM_PORT", "0")) to_port = int(cmd.params.get("TO_PORT", "0")) await self._stream_session.connect( target, self._reader, self._writer, silent=silent, from_port=from_port, to_port=to_port) self._running = False # Socket is stolen elif opcode == "ACCEPT": silent = cmd.params.get("SILENT", "false").lower() == "true" await self._stream_session.accept( self._reader, self._writer, silent=silent) self._running = False # Socket is stolen elif opcode == "FORWARD": port = int(cmd.params.get("PORT", "0")) host = cmd.params.get("HOST", "127.0.0.1") silent = cmd.params.get("SILENT", "false").lower() == "true" await self._stream_session.forward(port, host, silent) await self._write(SAMReply.stream_ok()) else: await self._write(SAMReply.stream_error( "I2P_ERROR", f"Unknown STREAM opcode: {opcode}")) async def _handle_datagram(self, cmd: SAMCommand) -> None: """DATAGRAM SEND.""" if not self._datagram_session: return opcode = cmd.opcode.upper() if opcode == "SEND": target = cmd.params.get("DESTINATION", "") size = int(cmd.params.get("SIZE", "0")) from_port = int(cmd.params.get("FROM_PORT", "0")) to_port = int(cmd.params.get("TO_PORT", "0")) # Read the payload if size > 0: data = await self._reader.readexactly(size) else: data = b"" await self._datagram_session.send( target, data, from_port=from_port, to_port=to_port) async def _handle_raw(self, cmd: SAMCommand) -> None: """RAW SEND.""" if not self._raw_session: return opcode = cmd.opcode.upper() if opcode == "SEND": target = cmd.params.get("DESTINATION", "") size = int(cmd.params.get("SIZE", "0")) protocol = int(cmd.params.get("PROTOCOL", "18")) from_port = int(cmd.params.get("FROM_PORT", "0")) to_port = int(cmd.params.get("TO_PORT", "0")) if size > 0: data = await self._reader.readexactly(size) else: data = b"" await self._raw_session.send( target, data, protocol=protocol, from_port=from_port, to_port=to_port) async def _handle_dest(self, cmd: SAMCommand) -> None: """DEST LOOKUP — resolve a hostname to a destination. In production, this would query the I2P network database. """ opcode = cmd.opcode.upper() if opcode == "GENERATE": # Generate a new transient destination _, dest_b64 = generate_transient_destination() await self._write(SAMReply.dest_reply("TRANSIENT", dest_b64)) elif opcode == "LOOKUP": name = cmd.params.get("NAME", "") if not name: await self._write(SAMReply.dest_not_found("")) return # Without a real naming service, we can only resolve "ME" if name == "ME" and self._session_nickname: session = await self._sessions_db.get(self._session_nickname) if session: await self._write(SAMReply.dest_reply(name, session.destination_b64)) return await self._write(SAMReply.dest_not_found(name)) async def _handle_naming(self, cmd: SAMCommand) -> None: """NAMING LOOKUP — resolve a name via the naming service.""" name = cmd.params.get("NAME", "") if not name: await self._write(SAMReply.naming_not_found("")) return # "ME" resolves to this session's destination if name == "ME" and self._session_nickname: session = await self._sessions_db.get(self._session_nickname) if session: await self._write(SAMReply.naming_reply(name, session.destination_b64)) return # Without a real naming service, all other lookups fail await self._write(SAMReply.naming_not_found(name)) async def _handle_ping(self, cmd: SAMCommand) -> None: """PING -> PONG. The opcode field contains the ping data to echo back. """ data = cmd.opcode if cmd.opcode else "" await self._write(SAMReply.pong(data)) async def _write(self, message: str) -> None: """Write message to client. Args: message: The SAM protocol message to send. """ self._writer.write(message.encode("utf-8")) await self._writer.drain()