# FastAPI / services/vector_db_service.py
# Author: ravi19
# Commit: "Deploy FastAPI to HF Space" (b36cb8b)
"""
Vector database service for interacting with Qdrant
"""
from typing import List, Dict, Any
from fastapi import HTTPException # type: ignore
from qdrant_client import QdrantClient # type: ignore
from qdrant_client.models import Distance, PointStruct, VectorParams # type: ignore
class VectorDatabaseClient:
    """Thin wrapper around a ``QdrantClient`` bound to a single collection.

    Args:
        url: Qdrant server URL, passed straight to ``QdrantClient``.
        api_key: API key, passed straight to ``QdrantClient``.
        collection_name: Collection that every method of this class targets.
        embedding_size: Vector dimensionality used when the collection has
            to be created by ``ensure_collection_exists``.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Create the configured collection if the server does not list it.

        A newly created collection stores ``embedding_size``-dimensional
        vectors compared with cosine distance. A status line is printed in
        both the "created" and the "already exists" case.
        """
        # Reuse list_collections() so the "which collections exist" lookup
        # lives in exactly one place.
        if self.collection_name in self.list_collections():
            # NOTE(review): an existing collection is not checked against
            # embedding_size; a dimension mismatch would presumably only
            # surface later on upsert/search — consider validating here.
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        # NOTE(review): check-then-create is not atomic; two workers starting
        # at the same time could both reach create_collection.
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.embedding_size,
                distance=Distance.COSINE,
            ),
        )
        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Store one image embedding and its payload via ``client.upsert``.

        Args:
            image_id: Point ID for the image. NOTE(review): Qdrant only
                accepts unsigned-integer or UUID point IDs, so an arbitrary
                string may be rejected — confirm callers pass a UUID.
            embedding: Vector to store; presumably must have
                ``embedding_size`` entries to match the collection (not
                validated here).
            payload: Metadata stored alongside the vector.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                PointStruct(
                    id=image_id,
                    vector=embedding,
                    payload=payload,
                )
            ],
        )

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Search the collection for the vectors closest to ``embedding``.

        Args:
            embedding: Query vector.
            limit: Maximum number of hits to request (default 1).

        Returns:
            One dict per hit with keys ``"id"``, ``"score"`` and
            ``"payload"``, in the order the client returned them.
        """
        # NOTE(review): newer qdrant-client releases deprecate ``search`` in
        # favour of ``query_points``; confirm against the pinned client version.
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
        )
        return [{"id": r.id, "score": r.score, "payload": r.payload} for r in results]

    def list_collections(self) -> List[str]:
        """Return the names of all collections reported by the server."""
        return [c.name for c in self.client.get_collections().collections]