api-embedding/core/embedding.py
fahmiaziz98
init
fea62df
raw
history blame
5.39 kB
from loguru import logger
from typing import Dict, List, Optional, Any
from sentence_transformers import SentenceTransformer
from sentence_transformers import SparseEncoder
class ModelConfig:
    """
    Typed view over one entry of the raw model-configuration mapping.

    Attributes:
        id: Identifier the API uses to address this model.
        name: Model name handed to the loader.
        type: Either "embedding" (dense) or "sparse".
        dimension: Output vector width, coerced to int.
        max_tokens: Input length cap, coerced to int.
        description: Human-readable description.
        language: Language tag from the config.
        repository: Source repository string.
    """

    def __init__(self, model_id: str, config: Dict[str, Any]):
        self.id = model_id
        # Required keys, in the order they are validated; a missing key
        # raises KeyError, a non-numeric dimension/max_tokens raises ValueError.
        for field, numeric in (
            ("name", False),
            ("type", False),        # "embedding" or "sparse"
            ("dimension", True),
            ("max_tokens", True),
            ("description", False),
            ("language", False),
            ("repository", False),
        ):
            raw = config[field]
            setattr(self, field, int(raw) if numeric else raw)
class EmbeddingModel:
    """
    Dense embedding model wrapper, loaded lazily on first use.

    Attributes:
        config: ModelConfig instance describing the model.
        model: SentenceTransformer instance, or None until load() succeeds.
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SentenceTransformer] = None
        self._loaded = False

    def load(self) -> None:
        """Load the embedding model; a no-op if it is already loaded."""
        if self._loaded:
            return
        logger.info(f"Loading embedding model: {self.config.name}")
        try:
            # NOTE(review): trust_remote_code=True executes repo-provided code;
            # make sure configured model sources are trusted.
            self.model = SentenceTransformer(self.config.name, device="cpu", trust_remote_code=True)
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise
        else:
            self._loaded = True
            logger.success(f"Loaded embedding model: {self.config.id}")

    def embed(self, texts: List[str], prompt: Optional[str] = None) -> List[List[float]]:
        """
        Generate dense embeddings for a list of texts.

        Args:
            texts: List of input texts
            prompt: Optional prompt for instruction-based models

        Returns:
            One embedding vector (list of floats) per input text.

        Raises:
            Exception: re-raised after logging if encoding fails.
        """
        if not self._loaded:
            self.load()
        try:
            vectors = self.model.encode(texts, prompt=prompt)
            return [vector.tolist() for vector in vectors]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
class SparseEmbeddingModel:
    """
    Sparse embedding model wrapper, loaded lazily on first use.

    Attributes:
        config: ModelConfig instance describing the model.
        model: SparseEncoder instance, or None until load() succeeds.
        _loaded: Flag indicating if the model is loaded.
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SparseEncoder] = None
        self._loaded = False

    def _format_values(self, values: List[float]) -> List[float]:
        """Round values to 7 decimal places to keep response payloads compact."""
        return [round(float(v), 7) for v in values]

    def load(self) -> None:
        """Load the sparse embedding model; a no-op if already loaded."""
        if self._loaded:
            return
        logger.info(f"Loading sparse model: {self.config.name}")
        try:
            self.model = SparseEncoder(self.config.name)
            self._loaded = True
            logger.success(f"Loaded sparse model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load sparse model {self.config.id}: {e}")
            raise

    def embed(self, text: str, prompt: Optional[str] = None) -> Dict[Any, Any]:
        """
        Generate a sparse embedding for a single text.

        Args:
            text: Input text
            prompt: Optional prompt for instruction-based models.
                NOTE(review): currently ignored — neither this method nor
                embed_batch() forwards it to encode(); confirm the deployed
                SparseEncoder supports prompts before wiring it through.

        Returns:
            Sparse embedding as a dictionary with 'indices' and 'values' keys.

        Raises:
            Exception: re-raised after logging if encoding fails.
        """
        # Bug fix: unlike embed_batch(), this method never lazy-loaded the
        # model, so calling embed() before load() crashed with AttributeError.
        if not self._loaded:
            self.load()
        try:
            # Coalesce once and reuse (previously coalesced twice per text).
            sparse = self.model.encode([text])[0].coalesce()
            return {
                "indices": sparse.indices()[0].tolist(),
                "values": self._format_values(sparse.values().tolist())
            }
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            raise

    def embed_batch(self, texts: List[str], prompt: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Generate sparse embeddings for a batch of texts.

        Args:
            texts: List of input texts
            prompt: Optional prompt for instruction-based models
                (currently ignored — see embed()).

        Returns:
            List of dicts with 'text' and 'sparse_embedding' keys; each
            'sparse_embedding' holds 'indices' and 'values'.

        Raises:
            Exception: re-raised after logging if encoding fails.
        """
        if not self._loaded:
            self.load()
        try:
            tensors = self.model.encode(texts)
            results = []
            for text, tensor in zip(texts, tensors):
                # Coalesce once per tensor (previously coalesced twice).
                sparse = tensor.coalesce()
                results.append({
                    "text": text,
                    "sparse_embedding": {
                        "indices": sparse.indices()[0].tolist(),
                        "values": self._format_values(sparse.values().tolist())
                    }
                })
            return results
        except Exception as e:
            logger.error(f"Sparse embedding generation failed: {e}")
            raise