File size: 2,526 Bytes
b36cb8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Vector database service for interacting with Qdrant
"""

from typing import List, Dict, Any

from fastapi import HTTPException # type: ignore
from qdrant_client import QdrantClient # type: ignore
from qdrant_client.models import Distance, PointStruct, VectorParams # type: ignore

class VectorDatabaseClient:
    """Thin wrapper around a ``QdrantClient`` bound to one collection.

    The constructor stores the connection settings and builds the client.
    The methods create the collection on demand, upsert image embeddings,
    run similarity searches and list the collections the client reports.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Store the settings and build the underlying ``QdrantClient``.

        Args:
            url: Forwarded to ``QdrantClient`` as ``url``.
            api_key: Forwarded to ``QdrantClient`` as ``api_key``.
            collection_name: Collection that every method of this instance
                targets.
            embedding_size: Vector size used when the collection is created.
        """
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Create ``collection_name`` unless the client already lists it.

        A new collection is configured with ``embedding_size``-dimensional
        vectors and cosine distance. A status line is printed in both cases.
        """
        # Reuse list_collections() so the "which collections exist" logic
        # lives in exactly one place.
        if self.collection_name in self.list_collections():
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.embedding_size,
                distance=Distance.COSINE,
            ),
        )
        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Upsert a single image embedding into the collection.

        Args:
            image_id: Used as the point id.
                NOTE(review): Qdrant point ids are expected to be unsigned
                integers or UUIDs -- confirm callers pass a UUID string.
            embedding: Vector stored for the point.
            payload: Payload stored alongside the vector.
        """
        point = PointStruct(id=image_id, vector=embedding, payload=payload)
        self.client.upsert(collection_name=self.collection_name, points=[point])

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Return the points most similar to ``embedding``.

        ``QdrantClient.search`` is deprecated in favour of ``query_points``,
        so ``query_points`` is used whenever the client exposes it; ``search``
        is kept as a fallback for older qdrant-client releases.

        Args:
            embedding: Query vector.
            limit: Maximum number of hits requested from the client.

        Returns:
            One dict per hit with the keys ``"id"``, ``"score"`` and
            ``"payload"``, in the order returned by the client.
        """
        query_points = getattr(self.client, "query_points", None)
        if query_points is not None:
            # query_points wraps its hits in a response object; the scored
            # points themselves are on ``.points``.
            hits = query_points(
                collection_name=self.collection_name,
                query=embedding,
                limit=limit,
            ).points
        else:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=embedding,
                limit=limit,
            )

        return [
            {"id": hit.id, "score": hit.score, "payload": hit.payload}
            for hit in hits
        ]

    def list_collections(self) -> List[str]:
        """Return the names of all collections reported by the client."""
        response = self.client.get_collections()
        return [collection.name for collection in response.collections]