support coroutine callbacks

Changed files: +11 -13

src/atkafka_consumer/consumer.py (+11 -13)
 import asyncio
 import json
 import logging
-from typing import Any, Callable, List, Optional
-
+from typing import Any, Callable, List, Optional, Union, Awaitable
 from aiokafka import AIOKafkaConsumer, ConsumerRecord
-
 from .models import AtKafkaEvent

 logger = logging.getLogger(__name__)
···
         bootstrap_servers: List[str],
         input_topic: str,
         group_id: str,
-        on_event: Callable[[AtKafkaEvent], None],
+        on_event: Union[
+            Callable[[AtKafkaEvent], None], Callable[[AtKafkaEvent], Awaitable[None]]
+        ],
         offset: str = "earliest",
         max_concurrent_tasks: int = 100,
     ):
···
         self._offset = offset
         self._max_concurrent_tasks = max_concurrent_tasks
         self._on_event = on_event
-
         self._consumer: Optional[AIOKafkaConsumer] = None
-
         self._semaphore: Optional[asyncio.Semaphore] = None
         self._shutdown_event: Optional[asyncio.Event] = None

     async def stop(self):
         assert self._consumer is not None
-
         if self._shutdown_event:
             self._shutdown_event.set()
-
         await self._consumer.stop()
         logger.info("stopped kafka consumer")

     async def _handle_event(self, message: ConsumerRecord[Any, Any]):
         assert self._semaphore is not None
-
         async with self._semaphore:
             try:
                 evt = AtKafkaEvent.model_validate(message.value)
···
                 logger.error(f"Failed to handle event: {e}")
                 raise e

-            self._on_event(evt)
+            try:
+                result = self._on_event(evt)
+                if asyncio.iscoroutine(result):
+                    await result
+            except Exception as e:
+                logger.error(f"Error in on_event callback: {e}")
+                raise

     async def run(self):
         self._semaphore = asyncio.Semaphore(self._max_concurrent_tasks)
         self._shutdown_event = asyncio.Event()
-
         self._consumer = AIOKafkaConsumer(
             self._input_topic,
             bootstrap_servers=",".join(self._bootstrap_servers),
···
             max_poll_interval_ms=300000,
             value_deserializer=lambda m: json.loads(m.decode("utf-8")),
         )
-
         await self._consumer.start()
         logger.info("started kafka consumer")

···
                 for t in done:
                     if t.exception():
                         logger.error(f"Task failed with exception: {t.exception()}")
-
         except Exception as e:
             logger.error(f"Error consuming messages: {e}")
             raise
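
For reference, a minimal usage sketch of the two callback shapes on_event accepts after this change. The class name defined in consumer.py is not visible in the hunks above, so KafkaConsumerService, the atkafka_consumer.consumer import path, and the broker/topic/group values below are placeholders rather than confirmed API.

import asyncio

# Placeholder import: the actual class name in consumer.py is not shown in the diff.
from atkafka_consumer.consumer import KafkaConsumerService
from atkafka_consumer.models import AtKafkaEvent


def handle_sync(evt: AtKafkaEvent) -> None:
    # Plain callable: _handle_event calls it, asyncio.iscoroutine(None) is False,
    # so nothing is awaited.
    print("sync handler received", evt)


async def handle_async(evt: AtKafkaEvent) -> None:
    # Coroutine function: calling it returns a coroutine object, which
    # _handle_event detects via asyncio.iscoroutine() and awaits.
    await asyncio.sleep(0)
    print("async handler received", evt)


async def main() -> None:
    consumer = KafkaConsumerService(
        bootstrap_servers=["localhost:9092"],  # placeholder broker
        input_topic="events",                  # placeholder topic
        group_id="example-group",              # placeholder group id
        on_event=handle_async,                 # handle_sync works the same way
    )
    try:
        await consumer.run()
    finally:
        await consumer.stop()


if __name__ == "__main__":
    asyncio.run(main())

Because the dispatch checks the returned value with asyncio.iscoroutine() rather than inspecting the callable itself, a callable object whose __call__ is defined with async def is handled as well: the coroutine it returns is what gets awaited.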