|
|
""" |
|
|
Vector database service for interacting with Qdrant |
|
|
""" |
|
|
|
|
|
from typing import List, Dict, Any |
|
|
|
|
|
from fastapi import HTTPException |
|
|
from qdrant_client import QdrantClient |
|
|
from qdrant_client.models import Distance, PointStruct, VectorParams |
|
|
|
|
|
class VectorDatabaseClient:
    """Class for interacting with Qdrant vector database.

    Wraps a ``QdrantClient`` bound to a single collection and exposes helpers
    to create that collection, upsert image embeddings, run similarity
    searches against it, and list the collections present on the server.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Create a Qdrant client for ``url`` targeting ``collection_name``.

        Args:
            url: Base URL of the Qdrant server.
            api_key: API key forwarded to ``QdrantClient``.
            collection_name: Collection that every operation of this wrapper
                (except ``list_collections``) targets.
            embedding_size: Vector dimensionality used when
                ``ensure_collection_exists`` has to create the collection.
        """
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Ensure the Qdrant collection exists, creating it if necessary.

        A missing collection is created with ``embedding_size``-dimensional
        vectors and cosine distance. An already existing collection is left
        untouched; its vector configuration is NOT compared against
        ``embedding_size``.

        If ``create_collection`` fails but the collection is present
        afterwards (e.g. another worker created it between the existence
        check and the create call), the failure is treated as "already
        exists". Any other failure is re-raised unchanged.
        """
        if self.collection_name in self.list_collections():
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_size,
                    distance=Distance.COSINE,
                ),
            )
        except Exception:
            # Check-then-create is racy when several processes start at once:
            # the one that loses the race gets an error from create_collection.
            # Only swallow the error when the collection really exists now.
            if self.collection_name not in self.list_collections():
                raise
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Add an image embedding to the database.

        Uses ``upsert``, so a point that already has ``image_id`` is replaced.

        Args:
            image_id: Point id. NOTE: Qdrant point ids must be unsigned
                integers or UUIDs, so a string id is expected to be
                UUID-formatted; other strings are rejected by the server.
            embedding: Vector stored for the point. Presumably it must have
                ``embedding_size`` elements; this is not validated here.
            payload: Metadata stored alongside the point.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                PointStruct(
                    id=image_id,
                    vector=embedding,
                    payload=payload,
                )
            ],
        )

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Search for similar images using an embedding vector.

        Args:
            embedding: Query vector.
            limit: Maximum number of hits to return.

        Returns:
            One dict per hit with keys ``"id"``, ``"score"`` and ``"payload"``,
            in the order Qdrant returned them.
        """
        # NOTE(review): ``QdrantClient.search`` is deprecated in recent
        # qdrant-client releases in favour of ``query_points``; migrate once
        # the pinned client version is confirmed.
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
        )
        return [
            {"id": r.id, "score": r.score, "payload": r.payload}
            for r in results
        ]

    def list_collections(self) -> List[str]:
        """List the names of all collections in the database."""
        return [c.name for c in self.client.get_collections().collections]
|
|
|