|
|
""" |
|
|
Vector Database Service implementation for Qdrant |
|
|
""" |
|
|
|
|
|
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, Record, ScoredPoint, VectorParams
|
|
|
|
|
|
|
|
|
|
|
class VectorDatabaseClient:
    """Client for interacting with a single Qdrant collection.

    Every method operates on ``self.collection_name``. Vectors are expected to
    have ``self.embedding_size`` dimensions, which is the size used when
    :meth:`ensure_collection_exists` creates the collection.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Initialize the Qdrant client and remember the collection settings.

        Args:
            url: Qdrant server URL.
            api_key: API key for Qdrant.
            collection_name: Name of the collection every method operates on.
            embedding_size: Dimension of the embedding vectors; used as the
                vector size when the collection is created.
        """
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection_name = collection_name
        self.embedding_size = embedding_size

    def ensure_collection_exists(self) -> None:
        """Create the collection (cosine distance) if it does not exist yet.

        Prints a status line in both cases. Only the collection *name* is
        checked: an existing collection is left untouched even if its vector
        size differs from ``embedding_size``.
        """
        if self.collection_name not in self.list_collections():
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_size,
                    distance=Distance.COSINE,
                ),
            )
            print(f"✅ Collection '{self.collection_name}' created.")
        else:
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")

    def add_embedding(
        self,
        id: str,
        embedding: List[float],
        filename: str,
        metadata: Optional[str] = None,
    ) -> str:
        """Upsert a vector whose payload records the file it came from.

        Args:
            id: Unique ID for the point. Qdrant accepts unsigned integers or
                UUID strings as point IDs — presumably callers pass a UUID;
                verify.
            embedding: Vector embedding.
            filename: Original filename, stored under the ``"filename"`` key.
            metadata: Optional metadata as a JSON string, stored verbatim under
                the ``"metadata"`` key. Falsy values (``None`` or ``""``) are
                omitted from the payload.

        Returns:
            The ``id`` that was passed in.
        """
        payload: Dict[str, Any] = {"filename": filename}
        if metadata:
            payload["metadata"] = metadata
        return self.add_embedding_with_payload(id, embedding, payload)

    def add_embedding_with_payload(self, id: str, embedding: List[float], payload: Dict[str, Any]) -> str:
        """Upsert a vector with a caller-supplied payload.

        An upsert replaces any existing point that has the same ``id``.

        Args:
            id: Unique ID for the point.
            embedding: Vector embedding.
            payload: Metadata to store with the point (Qdrant stores payloads
                as JSON, so values should be JSON-serializable).

        Returns:
            The ``id`` that was passed in.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(id=id, vector=embedding, payload=payload)],
        )
        return id

    def search_by_embedding(self, embedding: List[float], limit: int = 5) -> List[ScoredPoint]:
        """Search the collection for the vectors nearest to ``embedding``.

        Args:
            embedding: Query vector.
            limit: Maximum number of results.

        Returns:
            Up to ``limit`` scored points as returned by Qdrant's search.
        """
        # NOTE(review): newer qdrant-client releases steer users from
        # ``search`` towards ``query_points``; revisit when upgrading.
        return self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
        )

    def search_by_id(self, id: str, limit: int = 1) -> List[ScoredPoint]:
        """Search for similar vectors, using a stored point's vector as the query.

        The query point is itself in the collection and is not filtered out,
        so it will normally be the top hit. With the default ``limit=1`` the
        result is therefore usually just that point.

        Args:
            id: ID of the existing point whose vector is used as the query.
            limit: Maximum number of results.

        Returns:
            The search results, or an empty list if no point has this ``id``.
        """
        # retrieve() omits vectors unless asked for them explicitly; without
        # with_vectors=True the record's ``.vector`` would be None.
        records = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[id],
            with_payload=False,
            with_vectors=True,
        )
        if not records:
            return []
        # Assumes the collection uses a single unnamed vector (as created by
        # ensure_collection_exists); named vectors would yield a dict here.
        return self.search_by_embedding(records[0].vector, limit)

    def delete_embedding(self, id: str) -> bool:
        """Delete an embedding from the collection.

        The existence check and the delete are two separate requests, so a
        concurrent writer can change the outcome between them.

        Args:
            id: ID of the embedding to delete.

        Returns:
            True if the point existed and a delete was issued, False if no
            point with this ``id`` was found.
        """
        existing = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[id],
            with_payload=False,
            with_vectors=False,
        )
        if not existing:
            return False
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=[id],
        )
        return True

    def list_collections(self) -> List[str]:
        """Return the names of all collections on the Qdrant server."""
        response = self.client.get_collections()
        return [collection.name for collection in response.collections]