import logging
from typing import List, Optional
from time import time

from sentence_transformers import SentenceTransformer
import torch

from config import CONFIG
from metrics import prom_metrics


logger = logging.getLogger(__name__)


class EmbeddingService:
    """Lazily loads a SentenceTransformer model and encodes text into vectors."""

    def __init__(self) -> None:
        self.model: Optional[SentenceTransformer] = None

    def initialized(self) -> bool:
        return self.model is not None

    def initialize(self) -> None:
        device = CONFIG.embedding_device

        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
            logger.warning("CUDA requested but not available, falling back to CPU")

        logger.info("Using %s", "CUDA" if device == "cuda" else "CPU")

        self.model = SentenceTransformer(CONFIG.embedding_model, device=device)

    def encode(self, text: str) -> List[float]:
        # Empty or whitespace-only input gets a zero vector instead of hitting the model.
        if not text or not text.strip():
            return [0.0] * CONFIG.embedding_size

        status = "error"
        start_time = time()
        try:
            vector = self.model.encode(text, convert_to_numpy=True)
            status = "ok"
            return vector.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}")
            # Re-raise without re-wrapping so the original traceback is preserved.
            raise
        finally:
            prom_metrics.embedding_performed.labels(status=status).inc()
            prom_metrics.embedding_duration.labels(status=status).observe(
                time() - start_time
            )


EMBEDDING_SERVICE = EmbeddingService()
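
# Minimal usage sketch (illustrative only, not part of the module above): it assumes
# config.py provides embedding_model, embedding_device, and embedding_size, and that
# metrics.prom_metrics exposes the counter and histogram used in encode().
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    EMBEDDING_SERVICE.initialize()
    vector = EMBEDDING_SERVICE.encode("hello world")
    print(f"encoded a {len(vector)}-dimensional vector")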