+11
-13
src/atkafka_consumer/consumer.py
+11
-13
src/atkafka_consumer/consumer.py
···
1
1
import asyncio
2
2
import json
3
3
import logging
4
-
from typing import Any, Callable, List, Optional
5
-
4
+
from typing import Any, Callable, List, Optional, Union, Awaitable
6
5
from aiokafka import AIOKafkaConsumer, ConsumerRecord
7
-
8
6
from .models import AtKafkaEvent
9
7
10
8
logger = logging.getLogger(__name__)
···
16
14
bootstrap_servers: List[str],
17
15
input_topic: str,
18
16
group_id: str,
19
-
on_event: Callable[[AtKafkaEvent], None],
17
+
on_event: Union[
18
+
Callable[[AtKafkaEvent], None], Callable[[AtKafkaEvent], Awaitable[None]]
19
+
],
20
20
offset: str = "earliest",
21
21
max_concurrent_tasks: int = 100,
22
22
):
···
26
26
self._offset = offset
27
27
self._max_concurrent_tasks = max_concurrent_tasks
28
28
self._on_event = on_event
29
-
30
29
self._consumer: Optional[AIOKafkaConsumer] = None
31
-
32
30
self._semaphore: Optional[asyncio.Semaphore] = None
33
31
self._shutdown_event: Optional[asyncio.Event] = None
34
32
35
33
async def stop(self):
36
34
assert self._consumer is not None
37
-
38
35
if self._shutdown_event:
39
36
self._shutdown_event.set()
40
-
41
37
await self._consumer.stop()
42
38
logger.info("stopped kafka consumer")
43
39
44
40
async def _handle_event(self, message: ConsumerRecord[Any, Any]):
45
41
assert self._semaphore is not None
46
-
47
42
async with self._semaphore:
48
43
try:
49
44
evt = AtKafkaEvent.model_validate(message.value)
···
51
46
logger.error(f"Failed to handle event: {e}")
52
47
raise e
53
48
54
-
self._on_event(evt)
49
+
try:
50
+
result = self._on_event(evt)
51
+
if asyncio.iscoroutine(result):
52
+
await result
53
+
except Exception as e:
54
+
logger.error(f"Error in on_event callback: {e}")
55
+
raise
55
56
56
57
async def run(self):
57
58
self._semaphore = asyncio.Semaphore(self._max_concurrent_tasks)
58
59
self._shutdown_event = asyncio.Event()
59
-
60
60
self._consumer = AIOKafkaConsumer(
61
61
self._input_topic,
62
62
bootstrap_servers=",".join(self._bootstrap_servers),
···
68
68
max_poll_interval_ms=300000,
69
69
value_deserializer=lambda m: json.loads(m.decode("utf-8")),
70
70
)
71
-
72
71
await self._consumer.start()
73
72
logger.info("started kafka consumer")
74
73
···
88
87
for t in done:
89
88
if t.exception():
90
89
logger.error(f"Task failed with exception: {t.exception()}")
91
-
92
90
except Exception as e:
93
91
logger.error(f"Error consuming messages: {e}")
94
92
raise