Deploy FastAPI to HF Space
Browse files- .gitignore +4 -0
- New folder/.env1 +2 -0
- New folder/embedding_service2.py +97 -0
- New folder/vector_db_service2.py +154 -0
- README.md +97 -12
- __pycache__/config.cpython-312.pyc +0 -0
- api/__init__.py +0 -0
- api/__pycache__/__init__.cpython-312.pyc +0 -0
- api/__pycache__/routes.cpython-312.pyc +0 -0
- api/routes.py +109 -0
- api/routes1.py +98 -0
- app.py +125 -0
- config.py +21 -0
- models/__init__.py +5 -0
- models/__pycache__/__init__.cpython-312.pyc +0 -0
- models/__pycache__/schemas.cpython-312.pyc +0 -0
- models/schemas.py +22 -0
- requirements.txt +13 -0
- services/__init__.py +5 -0
- services/__pycache__/__init__.cpython-312.pyc +0 -0
- services/__pycache__/embedding_service.cpython-312.pyc +0 -0
- services/__pycache__/security_service.cpython-312.pyc +0 -0
- services/__pycache__/vector_db_service.cpython-312.pyc +0 -0
- services/embedding_service.py +88 -0
- services/security_service.py +20 -0
- services/vector_db_service.py +70 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
venv/
|
| 2 |
+
env/
|
| 3 |
+
ENV/
|
| 4 |
+
.venv/
|
New folder/.env1
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
QDRANT_URL='https://b6138c60-0a19-4ba7-b6a5-f70a7d653b57.us-west-1-0.aws.cloud.qdrant.io'
|
| 2 |
+
QDRANT_API_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.XQrkVFAz02zgcvVYbmoneq36biKdbP6491n5I-RrCpQ"
|
New folder/embedding_service2.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding Service for generating image embeddings
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import List, Dict, Any
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import io
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ImageEmbeddingModel:
|
| 15 |
+
"""Class for generating embeddings from images using CLIP"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
|
| 18 |
+
"""Initialize the CLIP model
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
model_name: Name of the CLIP model to use
|
| 22 |
+
"""
|
| 23 |
+
self.model_name = model_name
|
| 24 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
+
self.model = CLIPModel.from_pretrained(model_name).to(self.device)
|
| 26 |
+
self.processor = CLIPProcessor.from_pretrained(model_name)
|
| 27 |
+
|
| 28 |
+
def generate_embedding(self, image_data: bytes) -> List[float]:
|
| 29 |
+
"""Generate embedding for an image from binary data
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
image_data: Binary image data
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
Image embedding as a list of floats
|
| 36 |
+
"""
|
| 37 |
+
# Load image from binary data
|
| 38 |
+
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 39 |
+
return self.generate_embedding_from_pil(image)
|
| 40 |
+
|
| 41 |
+
def generate_embedding_from_pil(self, image: Image.Image) -> List[float]:
|
| 42 |
+
"""Generate embedding for a PIL Image
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
image: PIL Image object
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Image embedding as a list of floats
|
| 49 |
+
"""
|
| 50 |
+
# Process image for CLIP
|
| 51 |
+
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
|
| 52 |
+
|
| 53 |
+
# Generate embedding
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
image_features = self.model.get_image_features(**inputs)
|
| 56 |
+
|
| 57 |
+
# Normalize embedding and convert to list
|
| 58 |
+
image_embedding = image_features.cpu().numpy()[0]
|
| 59 |
+
normalized_embedding = image_embedding / np.linalg.norm(image_embedding)
|
| 60 |
+
return normalized_embedding.tolist()
|
| 61 |
+
|
| 62 |
+
def get_embeddings_from_folder(self, folder_path: str) -> Dict[str, Any]:
|
| 63 |
+
"""Generate embeddings for all images in a folder
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
folder_path: Path to folder containing images
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Dictionary mapping filenames to embeddings
|
| 70 |
+
"""
|
| 71 |
+
results = {}
|
| 72 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
|
| 73 |
+
|
| 74 |
+
# Check if folder exists
|
| 75 |
+
if not os.path.exists(folder_path):
|
| 76 |
+
return {"error": f"Folder {folder_path} does not exist"}
|
| 77 |
+
|
| 78 |
+
# Process each image file
|
| 79 |
+
for filename in os.listdir(folder_path):
|
| 80 |
+
if os.path.splitext(filename)[1].lower() in image_extensions:
|
| 81 |
+
try:
|
| 82 |
+
file_path = os.path.join(folder_path, filename)
|
| 83 |
+
with open(file_path, 'rb') as f:
|
| 84 |
+
image_data = f.read()
|
| 85 |
+
|
| 86 |
+
embedding = self.generate_embedding(image_data)
|
| 87 |
+
results[filename] = {
|
| 88 |
+
"embedding": embedding,
|
| 89 |
+
"status": "success"
|
| 90 |
+
}
|
| 91 |
+
except Exception as e:
|
| 92 |
+
results[filename] = {
|
| 93 |
+
"error": str(e),
|
| 94 |
+
"status": "failed"
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
return results
|
New folder/vector_db_service2.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vector Database Service implementation for Qdrant
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Dict, Any, Optional
|
| 6 |
+
from qdrant_client import QdrantClient
|
| 7 |
+
from qdrant_client.models import PointStruct, VectorParams, Distance, Record
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VectorDatabaseClient:
|
| 12 |
+
"""Client for interacting with Qdrant vector database"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
|
| 15 |
+
"""Initialize Qdrant client and collection settings
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
url: Qdrant server URL
|
| 19 |
+
api_key: API key for Qdrant
|
| 20 |
+
collection_name: Name of the collection to use
|
| 21 |
+
embedding_size: Size of embedding vectors
|
| 22 |
+
"""
|
| 23 |
+
self.client = QdrantClient(url=url, api_key=api_key)
|
| 24 |
+
self.collection_name = collection_name
|
| 25 |
+
self.embedding_size = embedding_size
|
| 26 |
+
|
| 27 |
+
def ensure_collection_exists(self):
|
| 28 |
+
"""Ensure the collection exists, create it if it doesn't"""
|
| 29 |
+
collections = [c.name for c in self.client.get_collections().collections]
|
| 30 |
+
|
| 31 |
+
if self.collection_name not in collections:
|
| 32 |
+
self.client.create_collection(
|
| 33 |
+
collection_name=self.collection_name,
|
| 34 |
+
vectors_config=VectorParams(
|
| 35 |
+
size=self.embedding_size,
|
| 36 |
+
distance=Distance.COSINE
|
| 37 |
+
)
|
| 38 |
+
)
|
| 39 |
+
print(f"✅ Collection '{self.collection_name}' created.")
|
| 40 |
+
else:
|
| 41 |
+
print(f"ℹ️ Collection '{self.collection_name}' already exists.")
|
| 42 |
+
|
| 43 |
+
def add_embedding(self, id: str, embedding: List[float], filename: str, metadata: Optional[str] = None) -> str:
|
| 44 |
+
"""Add an embedding to the collection
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
id: Unique ID for the point
|
| 48 |
+
embedding: Vector embedding
|
| 49 |
+
filename: Original filename
|
| 50 |
+
metadata: Optional metadata as JSON string
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
ID of the added point
|
| 54 |
+
"""
|
| 55 |
+
payload = {"filename": filename}
|
| 56 |
+
if metadata:
|
| 57 |
+
payload["metadata"] = metadata
|
| 58 |
+
|
| 59 |
+
self.client.upsert(
|
| 60 |
+
collection_name=self.collection_name,
|
| 61 |
+
points=[
|
| 62 |
+
PointStruct(
|
| 63 |
+
id=id,
|
| 64 |
+
vector=embedding,
|
| 65 |
+
payload=payload
|
| 66 |
+
)
|
| 67 |
+
]
|
| 68 |
+
)
|
| 69 |
+
return id
|
| 70 |
+
|
| 71 |
+
def add_embedding_with_payload(self, id: str, embedding: List[float], payload: Dict[str, Any]) -> str:
|
| 72 |
+
"""Add an embedding with a custom payload
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
id: Unique ID for the point
|
| 76 |
+
embedding: Vector embedding
|
| 77 |
+
payload: Dictionary of metadata to store
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
ID of the added point
|
| 81 |
+
"""
|
| 82 |
+
self.client.upsert(
|
| 83 |
+
collection_name=self.collection_name,
|
| 84 |
+
points=[
|
| 85 |
+
PointStruct(
|
| 86 |
+
id=id,
|
| 87 |
+
vector=embedding,
|
| 88 |
+
payload=payload
|
| 89 |
+
)
|
| 90 |
+
]
|
| 91 |
+
)
|
| 92 |
+
return id
|
| 93 |
+
|
| 94 |
+
def search_by_embedding(self, embedding: List[float], limit: int = 5) -> List[Record]:
|
| 95 |
+
"""Search for similar vectors
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
embedding: Query vector
|
| 99 |
+
limit: Maximum number of results
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
List of search results
|
| 103 |
+
"""
|
| 104 |
+
results = self.client.search(
|
| 105 |
+
collection_name=self.collection_name,
|
| 106 |
+
query_vector=embedding,
|
| 107 |
+
limit=limit
|
| 108 |
+
)
|
| 109 |
+
return results
|
| 110 |
+
|
| 111 |
+
def search_by_id(self, id: str, limit: int = 1) -> List[Record]:
|
| 112 |
+
"""Search for similar vectors using an existing vector as query
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
id: ID of the existing vector to use as query
|
| 116 |
+
limit: Maximum number of results
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
List of search results
|
| 120 |
+
"""
|
| 121 |
+
# Get the vector by ID
|
| 122 |
+
vector = self.client.retrieve(
|
| 123 |
+
collection_name=self.collection_name,
|
| 124 |
+
ids=[id]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
if not vector or len(vector) == 0:
|
| 128 |
+
return []
|
| 129 |
+
|
| 130 |
+
# Use the vector to search
|
| 131 |
+
return self.search_by_embedding(vector[0].vector, limit)
|
| 132 |
+
|
| 133 |
+
def delete_embedding(self, id: str) -> bool:
|
| 134 |
+
"""Delete an embedding from the collection
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
id: ID of the embedding to delete
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
True if deleted, False if not found
|
| 141 |
+
"""
|
| 142 |
+
self.client.delete(
|
| 143 |
+
collection_name=self.collection_name,
|
| 144 |
+
points_selector=[id]
|
| 145 |
+
)
|
| 146 |
+
return True
|
| 147 |
+
|
| 148 |
+
def list_collections(self) -> List[str]:
|
| 149 |
+
"""List all collections in the database
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
List of collection names
|
| 153 |
+
"""
|
| 154 |
+
return [c.name for c in self.client.get_collections().collections]
|
README.md
CHANGED
|
@@ -1,12 +1,97 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image Similarity Search API
|
| 2 |
+
|
| 3 |
+
A FastAPI application for image similarity search using CLIP embeddings and Qdrant vector database.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- Upload images and store their vector embeddings
|
| 8 |
+
- Search for similar images using an uploaded image or base64 encoded image
|
| 9 |
+
- Secure API with API key authentication
|
| 10 |
+
- Well-organized, modular codebase following OOP principles
|
| 11 |
+
|
| 12 |
+
## Installation
|
| 13 |
+
|
| 14 |
+
1. Clone this repository
|
| 15 |
+
2. Install dependencies:
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
pip install -r requirements.txt
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
3. Set up environment variables (optional, defaults are provided):
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
export QDRANT_URL="your-qdrant-url"
|
| 25 |
+
export QDRANT_API_KEY="your-qdrant-api-key"
|
| 26 |
+
export COLLECTION_NAME="your-collection-name"
|
| 27 |
+
export API_KEY="your-api-key"
|
| 28 |
+
export PORT=8000
|
| 29 |
+
export ENVIRONMENT="production" # Or "development" for debug mode with auto-reload
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
## Usage
|
| 33 |
+
|
| 34 |
+
Run the application:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
python app.py
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
The API will be available at http://localhost:8000 (or the port specified in environment variables).
|
| 41 |
+
|
| 42 |
+
### API Documentation
|
| 43 |
+
|
| 44 |
+
Once running, API documentation is available at:
|
| 45 |
+
- Swagger UI: http://localhost:8000/docs
|
| 46 |
+
- ReDoc: http://localhost:8000/redoc
|
| 47 |
+
|
| 48 |
+
## API Endpoints
|
| 49 |
+
|
| 50 |
+
- `POST /add-image/`: Add an image to the database
|
| 51 |
+
- `POST /add-images-from-folder/`: Add all images from a folder to the database
|
| 52 |
+
- `POST /search-by-image/`: Search for similar images using an uploaded image
|
| 53 |
+
- `POST /search-by-image-scan/`: Search for similar images using a base64 encoded image
|
| 54 |
+
- `GET /collections`: List all collections in the database
|
| 55 |
+
- `GET /health`: Health check endpoint
|
| 56 |
+
|
| 57 |
+
## Project Structure
|
| 58 |
+
|
| 59 |
+
```
|
| 60 |
+
image_similarity_api/
|
| 61 |
+
│
|
| 62 |
+
├── app.py # Main application entry point
|
| 63 |
+
├── config.py # Configuration settings
|
| 64 |
+
├── models/
|
| 65 |
+
│ ├── __init__.py
|
| 66 |
+
│ └── schemas.py # Pydantic models
|
| 67 |
+
├── services/
|
| 68 |
+
│ ├── __init__.py
|
| 69 |
+
│ ├── embedding.py # Image embedding service
|
| 70 |
+
│ ├── security.py # Security service
|
| 71 |
+
│ └── vector_db.py # Vector database service
|
| 72 |
+
├── api/
|
| 73 |
+
│ ├── __init__.py
|
| 74 |
+
│ └── routes.py # API routes
|
| 75 |
+
├── requirements.txt # Project dependencies
|
| 76 |
+
└── README.md # Project documentation
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Development
|
| 80 |
+
|
| 81 |
+
For development, set the ENVIRONMENT variable to "development" for auto-reload:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
export ENVIRONMENT="development"
|
| 85 |
+
python app.py
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## Deployment
|
| 89 |
+
|
| 90 |
+
This application can be deployed to any platform that supports Python applications:
|
| 91 |
+
|
| 92 |
+
1. Docker
|
| 93 |
+
2. Kubernetes
|
| 94 |
+
3. Cloud platforms (AWS, GCP, Azure, etc.)
|
| 95 |
+
4. Serverless platforms (with appropriate adapters)
|
| 96 |
+
|
| 97 |
+
Remember to set all required environment variables in your production environment.
|
__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
api/__init__.py
ADDED
|
File without changes
|
api/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (127 Bytes). View file
|
|
|
api/__pycache__/routes.cpython-312.pyc
ADDED
|
Binary file (5.01 kB). View file
|
|
|
api/routes.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Routes for the Image Similarity Search API
|
| 3 |
+
Contains all endpoints for the application using your original route implementation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import uuid
|
| 7 |
+
import base64
|
| 8 |
+
import io
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
from fastapi import APIRouter, FastAPI, File, UploadFile, Form, Query, Path # type: ignore
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
from services.embedding_service import ImageEmbeddingModel
|
| 15 |
+
from services.vector_db_service import VectorDatabaseClient
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Base64ImageRequest(BaseModel):
|
| 19 |
+
"""Request model for base64 encoded images"""
|
| 20 |
+
image_data: str
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def register_routes(
|
| 24 |
+
app: FastAPI,
|
| 25 |
+
embedding_model: ImageEmbeddingModel,
|
| 26 |
+
vector_db: VectorDatabaseClient,
|
| 27 |
+
):
|
| 28 |
+
"""Register all routes with the FastAPI app"""
|
| 29 |
+
|
| 30 |
+
@app.api_route("/", methods=["GET", "HEAD"])
|
| 31 |
+
async def read_root():
|
| 32 |
+
return {"status": "API running"}
|
| 33 |
+
|
| 34 |
+
@app.post("/add-image/")
|
| 35 |
+
async def add_image(
|
| 36 |
+
file: UploadFile = File(...),
|
| 37 |
+
item_name: str = Form(...),
|
| 38 |
+
design_name: str = Form(...),
|
| 39 |
+
item_price: float = Form(...)
|
| 40 |
+
):
|
| 41 |
+
"""Upload an image with product details and store its embedding"""
|
| 42 |
+
# Process the image to get embedding
|
| 43 |
+
# image_data = await file.read()
|
| 44 |
+
embedding = await embedding_model.get_embedding_from_upload(file)
|
| 45 |
+
|
| 46 |
+
# Generate a unique ID
|
| 47 |
+
image_id = str(uuid.uuid4())
|
| 48 |
+
|
| 49 |
+
# Store additional metadata in payload
|
| 50 |
+
payload = {
|
| 51 |
+
"filename": file.filename,
|
| 52 |
+
"item_name": item_name,
|
| 53 |
+
"design_name": design_name,
|
| 54 |
+
"item_price": item_price
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Store in vector database
|
| 58 |
+
vector_db.add_image(image_id, embedding, payload)
|
| 59 |
+
|
| 60 |
+
return {"message": "Image added successfully", "id": image_id}
|
| 61 |
+
|
| 62 |
+
@app.post("/add-images-from-folder/")
|
| 63 |
+
async def add_images_from_folder(folder_path: str):
|
| 64 |
+
"""Process and add all images from a specified folder"""
|
| 65 |
+
embeddings = embedding_model.get_embeddings_from_folder(folder_path)
|
| 66 |
+
return {"embeddings": embeddings}
|
| 67 |
+
|
| 68 |
+
@app.post("/search-by-image/")
|
| 69 |
+
async def search_by_image(file: UploadFile = File(...)):
|
| 70 |
+
"""Search for similar images by uploading a file"""
|
| 71 |
+
# Process the image to get embedding
|
| 72 |
+
# image_data = await file.read()
|
| 73 |
+
embedding = await embedding_model.get_embedding_from_upload(file)
|
| 74 |
+
|
| 75 |
+
# Search using the embedding
|
| 76 |
+
results = vector_db.search_by_vector(embedding, limit=1)
|
| 77 |
+
|
| 78 |
+
# return [
|
| 79 |
+
# {
|
| 80 |
+
# "id": r.id,
|
| 81 |
+
# "score": r.score,
|
| 82 |
+
# "payload": r.payload
|
| 83 |
+
# }
|
| 84 |
+
# for r in results
|
| 85 |
+
# ]
|
| 86 |
+
return results
|
| 87 |
+
|
| 88 |
+
@app.post("/search-by-image-scan/")
|
| 89 |
+
async def search_by_image_scan(request: Base64ImageRequest):
|
| 90 |
+
"""Search for similar images using a base64 encoded image"""
|
| 91 |
+
# Decode base64 image
|
| 92 |
+
image_data = request.image_data
|
| 93 |
+
image_bytes = base64.b64decode(image_data.split(',')[1] if ',' in image_data else image_data)
|
| 94 |
+
|
| 95 |
+
# Convert to PIL Image
|
| 96 |
+
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 97 |
+
|
| 98 |
+
# Process image to get embedding
|
| 99 |
+
embedding = embedding_model.get_embedding_from_pil(image)
|
| 100 |
+
|
| 101 |
+
# Search using the embedding
|
| 102 |
+
results = vector_db.search_by_vector(embedding, limit=1)
|
| 103 |
+
|
| 104 |
+
return results
|
| 105 |
+
|
| 106 |
+
@app.get("/collections")
|
| 107 |
+
def list_collections():
|
| 108 |
+
"""List all available collections in the vector database"""
|
| 109 |
+
return vector_db.list_collections()
|
api/routes1.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Routes for the Image Similarity Search API
|
| 3 |
+
Contains all endpoints for the application
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, FastAPI, File, UploadFile, Form, Query, Path
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from services.embedding_service import ImageEmbeddingModel
|
| 11 |
+
from services.vector_db_service import VectorDatabaseClient
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SearchResponse(BaseModel):
|
| 15 |
+
"""Response model for search results"""
|
| 16 |
+
image_id: str
|
| 17 |
+
similarity: float
|
| 18 |
+
metadata: Optional[dict] = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def register_routes(
|
| 22 |
+
app: FastAPI,
|
| 23 |
+
embedding_model: ImageEmbeddingModel,
|
| 24 |
+
vector_db: VectorDatabaseClient,
|
| 25 |
+
# Remove security_service parameter
|
| 26 |
+
):
|
| 27 |
+
"""Register all routes with the FastAPI app"""
|
| 28 |
+
router = APIRouter()
|
| 29 |
+
|
| 30 |
+
@router.post("/upload", response_model=dict)
|
| 31 |
+
async def upload_image(
|
| 32 |
+
file: UploadFile = File(...),
|
| 33 |
+
metadata: Optional[str] = Form(None),
|
| 34 |
+
# Remove security dependency: api_key: str = Depends(security_service.verify_api_key)
|
| 35 |
+
):
|
| 36 |
+
"""Upload an image and store its embedding"""
|
| 37 |
+
# Process the image and generate embedding
|
| 38 |
+
image_data = await file.read()
|
| 39 |
+
embedding = embedding_model.generate_embedding(image_data)
|
| 40 |
+
|
| 41 |
+
# Store in vector database with optional metadata
|
| 42 |
+
image_id = vector_db.add_embedding(embedding, file.filename, metadata)
|
| 43 |
+
|
| 44 |
+
return {"image_id": image_id, "message": "Image uploaded successfully"}
|
| 45 |
+
|
| 46 |
+
@router.get("/search/by-id/{image_id}", response_model=List[SearchResponse])
|
| 47 |
+
async def search_by_id(
|
| 48 |
+
image_id: str = Path(..., description="ID of the uploaded image to use as query"),
|
| 49 |
+
limit: int = Query(5, description="Maximum number of results to return"),
|
| 50 |
+
# Remove security dependency: api_key: str = Depends(security_service.verify_api_key)
|
| 51 |
+
):
|
| 52 |
+
"""Search for similar images using an existing image ID as the query"""
|
| 53 |
+
results = vector_db.search_by_id(image_id, limit)
|
| 54 |
+
return [
|
| 55 |
+
SearchResponse(
|
| 56 |
+
image_id=result.id,
|
| 57 |
+
similarity=result.score,
|
| 58 |
+
metadata=result.metadata
|
| 59 |
+
)
|
| 60 |
+
for result in results
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
@router.post("/search/by-image", response_model=List[SearchResponse])
|
| 64 |
+
async def search_by_image(
|
| 65 |
+
file: UploadFile = File(...),
|
| 66 |
+
limit: int = Query(5, description="Maximum number of results to return"),
|
| 67 |
+
# Remove security dependency: api_key: str = Depends(security_service.verify_api_key)
|
| 68 |
+
):
|
| 69 |
+
"""Search for similar images by uploading a new image"""
|
| 70 |
+
# Process the image and generate embedding
|
| 71 |
+
image_data = await file.read()
|
| 72 |
+
embedding = embedding_model.generate_embedding(image_data)
|
| 73 |
+
|
| 74 |
+
# Search using the embedding
|
| 75 |
+
results = vector_db.search_by_embedding(embedding, limit)
|
| 76 |
+
return [
|
| 77 |
+
SearchResponse(
|
| 78 |
+
image_id=result.id,
|
| 79 |
+
similarity=result.score,
|
| 80 |
+
metadata=result.metadata
|
| 81 |
+
)
|
| 82 |
+
for result in results
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
@router.delete("/images/{image_id}")
|
| 86 |
+
async def delete_image(
|
| 87 |
+
image_id: str = Path(..., description="ID of the image to delete"),
|
| 88 |
+
# Remove security dependency: api_key: str = Depends(security_service.verify_api_key)
|
| 89 |
+
):
|
| 90 |
+
"""Delete an image from the database"""
|
| 91 |
+
success = vector_db.delete_embedding(image_id)
|
| 92 |
+
if success:
|
| 93 |
+
return {"message": f"Image {image_id} deleted successfully"}
|
| 94 |
+
return {"message": f"Image {image_id} not found"}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Add the router to the app
|
| 98 |
+
app.include_router(router, prefix="/api/v1")
|
app.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image Similarity Search API with FastAPI and Qdrant - Fixed Access
|
| 3 |
+
This application provides endpoints for uploading images and searching for similar images
|
| 4 |
+
using vector embeddings from the CLIP model. Implemented using OOP principles.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import uvicorn # type: ignore
|
| 8 |
+
from fastapi import FastAPI # type: ignore
|
| 9 |
+
from contextlib import asynccontextmanager
|
| 10 |
+
import os
|
| 11 |
+
import ssl
|
| 12 |
+
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
| 13 |
+
|
| 14 |
+
from config import Config
|
| 15 |
+
from services.embedding_service import ImageEmbeddingModel
|
| 16 |
+
from services.vector_db_service import VectorDatabaseClient
|
| 17 |
+
from api.routes import register_routes
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@asynccontextmanager
|
| 21 |
+
async def lifespan(app: FastAPI):
|
| 22 |
+
"""Lifespan context manager for FastAPI application startup and shutdown events"""
|
| 23 |
+
# This runs before the application starts
|
| 24 |
+
vector_db = app.state.vector_db
|
| 25 |
+
vector_db.ensure_collection_exists()
|
| 26 |
+
|
| 27 |
+
yield # This yields control back to FastAPI
|
| 28 |
+
|
| 29 |
+
# This runs when the application is shutting down
|
| 30 |
+
# Cleanup code can go here if needed
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ImageSimilarityAPI:
|
| 34 |
+
"""Main application class that orchestrates all components"""
|
| 35 |
+
|
| 36 |
+
def __init__(self):
|
| 37 |
+
# Initialize config
|
| 38 |
+
self.config = Config()
|
| 39 |
+
|
| 40 |
+
# Initialize components
|
| 41 |
+
self.embedding_model = ImageEmbeddingModel(self.config.model_name)
|
| 42 |
+
self.vector_db = VectorDatabaseClient(
|
| 43 |
+
self.config.qdrant_url,
|
| 44 |
+
self.config.qdrant_api_key,
|
| 45 |
+
self.config.collection_name,
|
| 46 |
+
self.config.embedding_size
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Initialize FastAPI app with lifespan handler
|
| 50 |
+
self.app = FastAPI(
|
| 51 |
+
title="Image Similarity Search API",
|
| 52 |
+
description="API for uploading images and searching for similar images using CLIP embeddings",
|
| 53 |
+
version="1.0.0",
|
| 54 |
+
lifespan=lifespan
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# ✅ Enable CORS to allow mobile access
|
| 58 |
+
self.app.add_middleware(
|
| 59 |
+
CORSMiddleware,
|
| 60 |
+
allow_origins=["*"], # Or set to ["http://192.168.1.42"] for better security
|
| 61 |
+
allow_credentials=True,
|
| 62 |
+
allow_methods=["*"],
|
| 63 |
+
allow_headers=["*"],
|
| 64 |
+
)
|
| 65 |
+
# Store vector_db in app state for use in lifespan
|
| 66 |
+
self.app.state.vector_db = self.vector_db
|
| 67 |
+
|
| 68 |
+
# Register routes
|
| 69 |
+
register_routes(self.app, self.embedding_model, self.vector_db)
|
| 70 |
+
|
| 71 |
+
def run(self, use_https=False, cert_file="./certs/cert.pem", key_file="./certs/key.pem"):
|
| 72 |
+
"""Run the FastAPI application with optional HTTPS support
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
use_https: Whether to use HTTPS or plain HTTP
|
| 76 |
+
cert_file: Path to SSL certificate file
|
| 77 |
+
key_file: Path to SSL private key file
|
| 78 |
+
"""
|
| 79 |
+
host = "0.0.0.0" # Use localhost instead of 0.0.0.0 for better access
|
| 80 |
+
port = 8000 if not use_https else 8443
|
| 81 |
+
|
| 82 |
+
ssl_context = None
|
| 83 |
+
if use_https:
|
| 84 |
+
# Check if certificate files exist
|
| 85 |
+
if not os.path.exists(cert_file) or not os.path.exists(key_file):
|
| 86 |
+
print(f"ERROR: SSL certificate files not found at {cert_file} and/or {key_file}")
|
| 87 |
+
print("Falling back to HTTP. To use HTTPS, please provide valid certificate files.")
|
| 88 |
+
use_https = False
|
| 89 |
+
else:
|
| 90 |
+
# Create SSL context for HTTPS
|
| 91 |
+
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
| 92 |
+
ssl_context.load_cert_chain(cert_file, key_file)
|
| 93 |
+
|
| 94 |
+
# Print access URLs for convenience
|
| 95 |
+
protocol = "https" if use_https else "http"
|
| 96 |
+
print(f"\n{'='*50}")
|
| 97 |
+
print(f"Access the API at: {protocol}://{host}:{port}")
|
| 98 |
+
print(f"Swagger UI available at: {protocol}://{host}:{port}/docs")
|
| 99 |
+
print(f"ReDoc UI available at: {protocol}://{host}:{port}/redoc")
|
| 100 |
+
print(f"{'='*50}\n")
|
| 101 |
+
|
| 102 |
+
uvicorn.run(
|
| 103 |
+
self.app,
|
| 104 |
+
host=host,
|
| 105 |
+
port=port,
|
| 106 |
+
reload=self.config.environment == "development",
|
| 107 |
+
ssl_certfile=cert_file if use_https else None,
|
| 108 |
+
ssl_keyfile=key_file if use_https else None
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_app() -> FastAPI:
|
| 113 |
+
"""Create and return the FastAPI application"""
|
| 114 |
+
api = ImageSimilarityAPI()
|
| 115 |
+
return api.app
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
api = ImageSimilarityAPI()
|
| 120 |
+
# Set to False for now until certificates are properly set up
|
| 121 |
+
api.run(
|
| 122 |
+
use_https=False, # Change to True when certificates are ready
|
| 123 |
+
cert_file="./certs/cert.pem",
|
| 124 |
+
key_file="./certs/key.pem"
|
| 125 |
+
)
|
config.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the Image Similarity API
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Config:
|
| 9 |
+
"""Configuration class for the application"""
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.qdrant_url = os.getenv("QDRANT_URL",
|
| 13 |
+
"https://b6138c60-0a19-4ba7-b6a5-f70a7d653b57.us-west-1-0.aws.cloud.qdrant.io")
|
| 14 |
+
self.qdrant_api_key = os.getenv("QDRANT_API_KEY",
|
| 15 |
+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.XQrkVFAz02zgcvVYbmoneq36biKdbP6491n5I-RrCpQ")
|
| 16 |
+
self.collection_name = os.getenv("COLLECTION_NAME", "marqe_embedings")
|
| 17 |
+
# self.api_key = os.getenv("API_KEY", "your-api-key-here")
|
| 18 |
+
self.model_name = os.getenv("MODEL_NAME", "hf-hub:Marqo/marqo-ecommerce-embeddings-L")
|
| 19 |
+
self.embedding_size = 768
|
| 20 |
+
self.port = int(os.getenv("PORT", 8000))
|
| 21 |
+
self.environment = os.getenv("ENVIRONMENT", "production")
|
models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Empty __init__.py file to make the models directory a proper Python package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# This file is intentionally left empty
|
models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
models/__pycache__/schemas.cpython-312.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
models/schemas.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic model schemas for API request and response types
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Base64ImageRequest(BaseModel):
|
| 9 |
+
"""Model for accepting base64 encoded images"""
|
| 10 |
+
image_data: str
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SearchResult(BaseModel):
|
| 14 |
+
"""Model for search results"""
|
| 15 |
+
id: str
|
| 16 |
+
score: float
|
| 17 |
+
payload: dict
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ErrorResponse(BaseModel):
|
| 21 |
+
"""Model for error responses"""
|
| 22 |
+
detail: str
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements.txt
|
| 2 |
+
torch
|
| 3 |
+
open_clip_torch
|
| 4 |
+
fastapi
|
| 5 |
+
uvicorn
|
| 6 |
+
qdrant-client
|
| 7 |
+
nest_asyncio
|
| 8 |
+
python-multipart
|
| 9 |
+
pillow
|
| 10 |
+
numpy
|
| 11 |
+
pydantic
|
| 12 |
+
python-dotenv
|
| 13 |
+
|
services/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Empty __init__.py file to make the api directory a proper Python package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# This file is intentionally left empty
|
services/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (223 Bytes). View file
|
|
|
services/__pycache__/embedding_service.cpython-312.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
services/__pycache__/security_service.cpython-312.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
services/__pycache__/vector_db_service.cpython-312.pyc
ADDED
|
Binary file (3.83 kB). View file
|
|
|
services/embedding_service.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image embedding service for generating vector embeddings from images
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import base64
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
+
|
| 10 |
+
import open_clip
|
| 11 |
+
import torch
|
| 12 |
+
from fastapi import UploadFile, HTTPException
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ImageEmbeddingModel:
|
| 18 |
+
"""Class for handling image embedding using CLIP model"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, model_name: str):
|
| 21 |
+
self.model_name = model_name
|
| 22 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
+
self.model, self.preprocess_train, self.preprocess_val = self._initialize_model()
|
| 24 |
+
|
| 25 |
+
def _initialize_model(self) -> Tuple:
|
| 26 |
+
"""Initialize the CLIP model for image embeddings"""
|
| 27 |
+
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(self.model_name)
|
| 28 |
+
tokenizer = open_clip.get_tokenizer(self.model_name)
|
| 29 |
+
model.to(self.device)
|
| 30 |
+
model.eval()
|
| 31 |
+
return model, preprocess_train, preprocess_val
|
| 32 |
+
|
| 33 |
+
def get_embedding_from_pil(self, image: Image.Image) -> List[float]:
|
| 34 |
+
"""Get embedding from PIL image"""
|
| 35 |
+
processed_image = self.preprocess_val(image).unsqueeze(0).to(self.device)
|
| 36 |
+
|
| 37 |
+
# with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu')
|
| 38 |
+
if self.device == 'cuda':
|
| 39 |
+
autocast_context = torch.amp.autocast(device_type='cuda')
|
| 40 |
+
else:
|
| 41 |
+
# On CPU, autocast should either be skipped or forced to float32
|
| 42 |
+
autocast_context = torch.amp.autocast(device_type='cpu', dtype=torch.float32)
|
| 43 |
+
with torch.no_grad(), autocast_context:
|
| 44 |
+
image_features = self.model.encode_image(processed_image, normalize=True)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
return image_features.cpu().numpy()[0].tolist()
|
| 48 |
+
|
| 49 |
+
async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]:
|
| 50 |
+
"""Get embedding from uploaded image file"""
|
| 51 |
+
try:
|
| 52 |
+
contents = await image_file.read()
|
| 53 |
+
img = Image.open(io.BytesIO(contents)).convert("RGB")
|
| 54 |
+
return self.get_embedding_from_pil(img)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}")
|
| 57 |
+
|
| 58 |
+
def get_embedding_from_base64(self, base64_data: str) -> List[float]:
|
| 59 |
+
"""Get embedding from base64 encoded image"""
|
| 60 |
+
try:
|
| 61 |
+
# Handle data URI format
|
| 62 |
+
if ',' in base64_data:
|
| 63 |
+
base64_data = base64_data.split(',')[1]
|
| 64 |
+
|
| 65 |
+
image_bytes = base64.b64decode(base64_data)
|
| 66 |
+
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 67 |
+
|
| 68 |
+
return self.get_embedding_from_pil(image)
|
| 69 |
+
except Exception as e:
|
| 70 |
+
raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}")
|
| 71 |
+
|
| 72 |
+
def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]:
|
| 73 |
+
"""Get embeddings from all images in a folder"""
|
| 74 |
+
embeddings = []
|
| 75 |
+
|
| 76 |
+
if not os.path.exists(image_folder):
|
| 77 |
+
raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}")
|
| 78 |
+
|
| 79 |
+
for image_name in os.listdir(image_folder):
|
| 80 |
+
if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
|
| 81 |
+
try:
|
| 82 |
+
image_path = os.path.join(image_folder, image_name)
|
| 83 |
+
img = Image.open(image_path).convert("RGB")
|
| 84 |
+
embeddings.append(self.get_embedding_from_pil(img))
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error processing {image_name}: {str(e)}")
|
| 87 |
+
|
| 88 |
+
return embeddings
|
services/security_service.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Security service for API authentication and authorization
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import Depends, HTTPException
|
| 6 |
+
from fastapi.security import APIKeyHeader
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SecurityService:
|
| 10 |
+
"""Class for handling API security"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, api_key: str):
|
| 13 |
+
self.api_key = api_key
|
| 14 |
+
self.api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
| 15 |
+
|
| 16 |
+
async def verify_api_key(self, api_key: str = Depends(APIKeyHeader(name="X-API-Key", auto_error=False))):
|
| 17 |
+
"""Verify API key dependency"""
|
| 18 |
+
if self.api_key != "your-api-key-here" and api_key != self.api_key: # Skip check if using default key
|
| 19 |
+
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 20 |
+
return api_key
|
services/vector_db_service.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vector database service for interacting with Qdrant
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
|
| 7 |
+
from fastapi import HTTPException # type: ignore
|
| 8 |
+
from qdrant_client import QdrantClient # type: ignore
|
| 9 |
+
from qdrant_client.models import Distance, PointStruct, VectorParams # type: ignore
|
| 10 |
+
|
| 11 |
+
class VectorDatabaseClient:
|
| 12 |
+
"""Class for interacting with Qdrant vector database"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
|
| 15 |
+
self.url = url
|
| 16 |
+
self.api_key = api_key
|
| 17 |
+
self.collection_name = collection_name
|
| 18 |
+
self.embedding_size = embedding_size
|
| 19 |
+
self.client = QdrantClient(url=url, api_key=api_key)
|
| 20 |
+
|
| 21 |
+
def ensure_collection_exists(self) -> None:
|
| 22 |
+
"""Ensure the Qdrant collection exists"""
|
| 23 |
+
collections = self.client.get_collections()
|
| 24 |
+
collection_names = [c.name for c in collections.collections]
|
| 25 |
+
|
| 26 |
+
if self.collection_name not in collection_names:
|
| 27 |
+
self.client.create_collection(
|
| 28 |
+
collection_name=self.collection_name,
|
| 29 |
+
vectors_config=VectorParams(
|
| 30 |
+
size=self.embedding_size,
|
| 31 |
+
distance=Distance.COSINE
|
| 32 |
+
)
|
| 33 |
+
)
|
| 34 |
+
print(f"✅ Collection '{self.collection_name}' created.")
|
| 35 |
+
else:
|
| 36 |
+
print(f"ℹ️ Collection '{self.collection_name}' already exists.")
|
| 37 |
+
|
| 38 |
+
def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
|
| 39 |
+
"""Add an image embedding to the database"""
|
| 40 |
+
self.client.upsert(
|
| 41 |
+
collection_name=self.collection_name,
|
| 42 |
+
points=[
|
| 43 |
+
PointStruct(
|
| 44 |
+
id=image_id,
|
| 45 |
+
vector=embedding,
|
| 46 |
+
payload=payload
|
| 47 |
+
)
|
| 48 |
+
]
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
|
| 52 |
+
"""Search for similar images using an embedding vector"""
|
| 53 |
+
results = self.client.search(
|
| 54 |
+
collection_name=self.collection_name,
|
| 55 |
+
query_vector=embedding,
|
| 56 |
+
limit=limit
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return [
|
| 60 |
+
{
|
| 61 |
+
"id": r.id,
|
| 62 |
+
"score": r.score,
|
| 63 |
+
"payload": r.payload
|
| 64 |
+
}
|
| 65 |
+
for r in results
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
def list_collections(self) -> List[str]:
|
| 69 |
+
"""List all collections in the database"""
|
| 70 |
+
return [c.name for c in self.client.get_collections().collections]
|