1import logging
2from typing import List
3
4from sentence_transformers import SentenceTransformer
5import torch
6
7from config import CONFIG
8
9
# Per-module logger so messages are attributed to this module's dotted name.
logger: logging.Logger = logging.getLogger(name=__name__)
11
12
class EmbeddingService:
    """Wrapper around a SentenceTransformer model that is loaded on demand.

    Construction does not load anything; call :meth:`initialize` before
    encoding non-blank text. The model name, target device, and the length
    of the zero vector returned for blank input are read from ``CONFIG``.
    """

    def __init__(self) -> None:
        # Set by initialize(); None means the model has not been loaded yet.
        self.model = None

    def initialized(self) -> bool:
        """Return True once :meth:`initialize` has loaded the model."""
        return self.model is not None

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Return *device*, or ``"cpu"`` if CUDA was requested but is unavailable."""
        # Prefix match so indexed devices such as "cuda:1" also fall back
        # instead of failing later inside SentenceTransformer.
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            logger.warning("CUDA requested but not available, falling back to CPU")
            return "cpu"
        return device

    def initialize(self) -> None:
        """Load ``CONFIG.embedding_model`` onto ``CONFIG.embedding_device``.

        If a CUDA device is configured but ``torch.cuda.is_available()`` is
        False, a warning is logged and the model is loaded on CPU instead.
        """
        device = self._resolve_device(CONFIG.embedding_device)

        device_name = str(device)
        if device_name.startswith("cuda"):
            logger.info("Using CUDA")
        elif device_name == "cpu":
            logger.info("Using CPU")
        else:
            # Other backends (e.g. "mps") used to be mis-reported as CPU.
            logger.info("Using device %s", device_name)

        self.model = SentenceTransformer(CONFIG.embedding_model, device=device)

    def encode(self, text: str) -> List[float]:
        """Embed *text* and return the vector as a list of floats.

        Empty or whitespace-only text short-circuits to a zero vector of
        length ``CONFIG.embedding_size`` without touching the model.

        Raises:
            RuntimeError: if non-blank text is encoded before initialize().
        """
        if not text or not text.strip():
            return [0.0] * CONFIG.embedding_size

        if not self.initialized():
            # Fail with an actionable message rather than an AttributeError on None.
            raise RuntimeError("EmbeddingService.encode() called before initialize()")

        vector = self.model.encode(text, convert_to_numpy=True)
        return vector.tolist()
40
41
# Shared module-level instance. Constructing it does not load a model;
# the SentenceTransformer is only created when initialize() is called.
EMBEDDING_SERVICE: EmbeddingService = EmbeddingService()