from typing import List, Tuple import chromadb from chromadb.utils import embedding_functions import os from .dto.chunk_dto import ChunkDTO from lpm_kernel.common.llm import LLMClient from lpm_kernel.file_data.document_dto import DocumentDTO from typing import List, Dict, Optional from lpm_kernel.configs.logging import get_train_process_logger logger = get_train_process_logger() class EmbeddingService: def __init__(self): from lpm_kernel.file_data.chroma_utils import detect_embedding_model_dimension from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService chroma_path = os.getenv("CHROMA_PERSIST_DIRECTORY", "./data/chroma_db") self.client = chromadb.PersistentClient(path=chroma_path) self.llm_client = LLMClient() # Get embedding model dimension from user config try: user_llm_config_service = UserLLMConfigService() user_llm_config = user_llm_config_service.get_available_llm() if user_llm_config and user_llm_config.embedding_model_name: # Detect dimension based on model name self.dimension = detect_embedding_model_dimension(user_llm_config.embedding_model_name) logger.info(f"Detected embedding dimension: {self.dimension} for model: {user_llm_config.embedding_model_name}") else: # Default to OpenAI dimension if no config found self.dimension = 1536 logger.info(f"No embedding model configured, using default dimension: {self.dimension}") except Exception as e: # Default to OpenAI dimension if error occurs self.dimension = 1536 logger.error(f"Error detecting embedding dimension, using default: {self.dimension}. Error: {str(e)}", exc_info=True) # Check for dimension mismatches in all collections first collections_to_init = ["documents", "document_chunks"] dimension_mismatch_detected = False # First pass: check all collections for dimension mismatches for collection_name in collections_to_init: try: collection = self.client.get_collection(name=collection_name) if collection.metadata.get("dimension") != self.dimension: logger.warning(f"Dimension mismatch in '{collection_name}' collection: {collection.metadata.get('dimension')} vs {self.dimension}") dimension_mismatch_detected = True except ValueError: # Collection doesn't exist yet, will be created later pass # Handle dimension mismatch if detected in any collection if dimension_mismatch_detected: self._handle_dimension_mismatch() # Second pass: create or get collections with the correct dimension try: self.document_collection = self.client.get_collection(name="documents") # Verify dimension after possible reinitialization doc_dimension = self.document_collection.metadata.get("dimension") if doc_dimension != self.dimension: logger.error(f"Collection 'documents' still has incorrect dimension after reinitialization: {doc_dimension} vs {self.dimension}") # Try to reinitialize again if dimension is still incorrect raise RuntimeError(f"Failed to set correct dimension for 'documents' collection: {doc_dimension} vs {self.dimension}") except ValueError: # Collection doesn't exist, create it with the correct dimension try: self.document_collection = self.client.create_collection( name="documents", metadata={"hnsw:space": "cosine", "dimension": self.dimension} ) logger.info(f"Created 'documents' collection with dimension {self.dimension}") except Exception as e: logger.error(f"Failed to create 'documents' collection: {str(e)}", exc_info=True) raise RuntimeError(f"Failed to create 'documents' collection: {str(e)}") try: self.chunk_collection = self.client.get_collection(name="document_chunks") # Verify dimension after possible reinitialization chunk_dimension = self.chunk_collection.metadata.get("dimension") if chunk_dimension != self.dimension: logger.error(f"Collection 'document_chunks' still has incorrect dimension after reinitialization: {chunk_dimension} vs {self.dimension}") # Try to reinitialize again if dimension is still incorrect raise RuntimeError(f"Failed to set correct dimension for 'document_chunks' collection: {chunk_dimension} vs {self.dimension}") except ValueError: # Collection doesn't exist, create it with the correct dimension try: self.chunk_collection = self.client.create_collection( name="document_chunks", metadata={"hnsw:space": "cosine", "dimension": self.dimension} ) logger.info(f"Created 'document_chunks' collection with dimension {self.dimension}") except Exception as e: logger.error(f"Failed to create 'document_chunks' collection: {str(e)}", exc_info=True) raise RuntimeError(f"Failed to create 'document_chunks' collection: {str(e)}") def generate_document_embedding(self, document: DocumentDTO) -> List[float]: """Process document level embedding and store in ChromaDB""" try: if not document.raw_content: logger.warning( f"Document {document.id} has no content to process embedding" ) return None # get embedding logger.info(f"Generating embedding for document {document.id}") embeddings = self.llm_client.get_embedding([document.raw_content]) if embeddings is None or len(embeddings) == 0: logger.error(f"Failed to get embedding for document {document.id}") return None embedding = embeddings[0] logger.info(f"Successfully got embedding for document {document.id}") # store to ChromaDB try: logger.info(f"Storing embedding for document {document.id} in ChromaDB") self.document_collection.add( documents=[document.raw_content], ids=[str(document.id)], embeddings=[embedding.tolist()], metadatas=[ { "title": document.title or document.name, "mime_type": document.mime_type, "create_time": document.create_time.isoformat() if document.create_time else None, "document_size": document.document_size, "url": document.url, } ], ) logger.info(f"Successfully stored embedding for document {document.id}") # verify embedding storage result = self.document_collection.get( ids=[str(document.id)], include=["embeddings"] ) if not result or not result["embeddings"]: logger.error( f"Failed to verify embedding storage for document {document.id}" ) return None logger.info(f"Verified embedding storage for document {document.id}") return embedding except Exception as e: logger.error(f"Error storing document embedding in ChromaDB: {str(e)}", exc_info=True) return None except Exception as e: logger.error(f"Error processing document embedding: {str(e)}", exc_info=True) raise def generate_chunk_embeddings(self, chunks: List[ChunkDTO]) -> List[ChunkDTO]: """Process chunk level embeddings""" """ Store in ChromaDB, the structure is as follows: documents=[c.content for c in unprocessed_chunks], ids=[str(c.id) for c in unprocessed_chunks], embeddings=embeddings.tolist(), metadatas=[ { "document_id": str(c.document_id), "topic": c.topic or "", "tags": ",".join(c.tags) if c.tags else "", } for c in unprocessed_chunks ], """ try: unprocessed_chunks = [c for c in chunks if not c.has_embedding] if not unprocessed_chunks: logger.info("No unprocessed chunks found") return chunks logger.info(f"Processing embeddings for {len(unprocessed_chunks)} chunks") contents = [c.content for c in unprocessed_chunks] logger.info("Getting embeddings from LLM service... {}".format(contents)) embeddings = self.llm_client.get_embedding(contents) if embeddings is None or len(embeddings) == 0: logger.error("Failed to get embeddings from LLM service") return chunks logger.info(f"Successfully got embeddings with shape: {embeddings.shape}") try: logger.info("Adding embeddings to ChromaDB...") self.chunk_collection.add( documents=[c.content for c in unprocessed_chunks], ids=[str(c.id) for c in unprocessed_chunks], embeddings=embeddings.tolist(), metadatas=[ { "document_id": str(c.document_id), "topic": c.topic or "", "tags": ",".join(c.tags) if c.tags else "", } for c in unprocessed_chunks ], ) logger.info("Successfully added embeddings to ChromaDB") # verify embeddings storage for chunk in unprocessed_chunks: result = self.chunk_collection.get( ids=[str(chunk.id)], include=["embeddings"] ) if result and result["embeddings"]: chunk.has_embedding = True logger.info(f"Verified embedding for chunk {chunk.id}") else: logger.warning( f"Failed to verify embedding for chunk {chunk.id}" ) chunk.has_embedding = False except Exception as e: logger.error(f"Error storing embeddings in ChromaDB: {str(e)}", exc_info=True) for chunk in unprocessed_chunks: chunk.has_embedding = False raise return chunks except Exception as e: logger.error(f"Error processing chunk embeddings: {str(e)}", exc_info=True) raise def get_chunk_embedding_by_chunk_id(self, chunk_id: int) -> Optional[List[float]]: """Get the corresponding embedding vector by chunk_id Args: chunk_id (int): chunk ID Returns: List[float]: embedding vector, return None if not found Raises: ValueError: when chunk_id is invalid Exception: other errors """ try: if not isinstance(chunk_id, int) or chunk_id < 0: raise ValueError("Invalid chunk_id") # query from ChromaDB result = self.chunk_collection.get( ids=[str(chunk_id)], include=["embeddings"] ) if not result or not result["embeddings"]: logger.warning(f"No embedding found for chunk {chunk_id}") return None return result["embeddings"][0] except Exception as e: logger.error(f"Error getting embedding for chunk {chunk_id}: {str(e)}") raise def get_document_embedding_by_document_id( self, document_id: int ) -> Optional[List[float]]: """Get the corresponding embedding vector by document_id Args: document_id (int): document ID Returns: List[float]: embedding vector, return None if not found Raises: ValueError: when document_id is invalid Exception: other errors """ try: if not isinstance(document_id, int) or document_id < 0: raise ValueError("Invalid document_id") # query from ChromaDB result = self.document_collection.get( ids=[str(document_id)], include=["embeddings"] ) if not result or not result["embeddings"]: logger.warning(f"No embedding found for document {document_id}") return None return result["embeddings"][0] except Exception as e: logger.error( f"Error getting embedding for document {document_id}: {str(e)}" ) raise def _handle_dimension_mismatch(self): """ Handle dimension mismatch between current embedding model and ChromaDB collections This method will reinitialize ChromaDB collections with the new dimension """ from lpm_kernel.file_data.chroma_utils import reinitialize_chroma_collections logger.warning(f"Detected dimension mismatch in ChromaDB collections. Reinitializing with dimension {self.dimension}") # Log the operation for better debugging logger.info(f"Calling reinitialize_chroma_collections with dimension {self.dimension}") try: success = reinitialize_chroma_collections(self.dimension) if success: logger.info(f"Successfully reinitialized ChromaDB collections with dimension {self.dimension}") # Refresh collection references try: self.document_collection = self.client.get_collection(name="documents") self.chunk_collection = self.client.get_collection(name="document_chunks") # Double-check dimensions after refresh doc_dimension = self.document_collection.metadata.get("dimension") chunk_dimension = self.chunk_collection.metadata.get("dimension") if doc_dimension != self.dimension or chunk_dimension != self.dimension: logger.error(f"Dimension mismatch after refresh: documents={doc_dimension}, chunks={chunk_dimension}, expected={self.dimension}") raise RuntimeError(f"Failed to handle dimension mismatch: collections have incorrect dimensions after reinitialization") except Exception as e: logger.error(f"Error refreshing collection references: {str(e)}", exc_info=True) raise RuntimeError(f"Failed to refresh ChromaDB collections after reinitialization: {str(e)}") else: logger.error("Failed to reinitialize ChromaDB collections") raise RuntimeError("Failed to handle dimension mismatch in ChromaDB collections") except Exception as e: logger.error(f"Error during dimension mismatch handling: {str(e)}", exc_info=True) raise RuntimeError(f"Failed to handle dimension mismatch in ChromaDB collections: {str(e)}") def search_similar_chunks( self, query: str, limit: int = 5 ) -> List[Tuple[ChunkDTO, float]]: """Search similar chunks, return list of ChunkDTO objects and their similarity scores Args: query (str): query text limit (int, optional): return result limit. Defaults to 5. Returns: List[Tuple[ChunkDTO, float]]: return list of (ChunkDTO, similarity score), sorted by similarity score in descending order Raises: ValueError: when query parameters are invalid Exception: other errors """ try: if not query or not query.strip(): raise ValueError("Query string cannot be empty") if limit < 1: raise ValueError("Limit must be positive") # calculate query text embedding query_embedding = self.llm_client.get_embedding([query]) if query_embedding is None or len(query_embedding) == 0: raise Exception("Failed to generate embedding for query") # query ChromaDB results = self.chunk_collection.query( query_embeddings=[query_embedding[0].tolist()], n_results=limit, include=["documents", "metadatas", "distances"], ) if not results or not results["ids"]: return [] # convert results to ChunkDTO objects similar_chunks = [] for i in range(len(results["ids"])): chunk_id = results["ids"][0][i] # ChromaDB returns nested lists document_id = results["metadatas"][0][i]["document_id"] content = results["documents"][0][i] topic = results["metadatas"][0][i].get("topic", "") tags = ( results["metadatas"][0][i].get("tags", "").split(",") if results["metadatas"][0][i].get("tags") else [] ) # calculate similarity score (ChromaDB returns distances, need to convert to similarity) similarity_score = ( 1 - results["distances"][0][i] ) # assume using Euclidean distance or cosine distance chunk = ChunkDTO( id=int(chunk_id), document_id=int(document_id), content=content, topic=topic, tags=tags, has_embedding=True, ) similar_chunks.append((chunk, similarity_score)) # sort by similarity score in descending order similar_chunks.sort(key=lambda x: x[1], reverse=True) return similar_chunks except ValueError as ve: logger.error(f"Invalid input parameters: {str(ve)}") raise except Exception as e: logger.error(f"Error searching similar chunks: {str(e)}") raise