ravi19 committed on
Commit
b36cb8b
·
1 Parent(s): 678a6fd

Deploy FastAPI to HF Space

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ venv/
2
+ env/
3
+ ENV/
4
+ .venv/
New folder/.env1 ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ QDRANT_URL='https://b6138c60-0a19-4ba7-b6a5-f70a7d653b57.us-west-1-0.aws.cloud.qdrant.io'
2
+ QDRANT_API_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.XQrkVFAz02zgcvVYbmoneq36biKdbP6491n5I-RrCpQ"
New folder/embedding_service2.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Embedding Service for generating image embeddings
3
+ """
4
+
5
+ import os
6
+ from typing import List, Dict, Any
7
+ from PIL import Image
8
+ import io
9
+ import numpy as np
10
+ import torch
11
+ from transformers import CLIPProcessor, CLIPModel
12
+
13
+
14
class ImageEmbeddingModel:
    """Generate unit-length image embeddings with a Hugging Face CLIP model."""

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """Load the CLIP model and its processor.

        Args:
            model_name: Name of the CLIP checkpoint passed to ``from_pretrained``.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def generate_embedding(self, image_data: bytes) -> List[float]:
        """Generate an embedding for an image given as raw bytes.

        Args:
            image_data: Binary image data in any format PIL can open.

        Returns:
            Image embedding as a list of floats.
        """
        # Decode the bytes and force three channels before embedding.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return self.generate_embedding_from_pil(image)

    def generate_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Generate an embedding for a PIL image.

        Args:
            image: PIL Image object.

        Returns:
            Image embedding as a list of floats, scaled to unit L2 norm
            (returned unscaled if the norm is exactly zero).
        """
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        image_embedding = image_features.cpu().numpy()[0]
        norm = np.linalg.norm(image_embedding)
        if norm == 0:
            # Dividing by zero would fill the vector with NaN.
            return image_embedding.tolist()
        return (image_embedding / norm).tolist()

    def get_embeddings_from_folder(self, folder_path: str) -> Dict[str, Any]:
        """Generate embeddings for every image file directly inside a folder.

        Args:
            folder_path: Path to the folder containing images.

        Returns:
            Dictionary mapping each image filename to either
            ``{"embedding": [...], "status": "success"}`` or
            ``{"error": "...", "status": "failed"}``. If ``folder_path`` is not
            an existing directory, ``{"error": "Folder ... does not exist"}``
            is returned instead.
        """
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}

        # isdir (rather than exists) also rejects a path that points at a
        # regular file, which would otherwise make os.listdir raise.
        if not os.path.isdir(folder_path):
            return {"error": f"Folder {folder_path} does not exist"}

        results: Dict[str, Any] = {}
        for filename in os.listdir(folder_path):
            if os.path.splitext(filename)[1].lower() not in image_extensions:
                continue
            try:
                file_path = os.path.join(folder_path, filename)
                with open(file_path, 'rb') as f:
                    image_data = f.read()

                results[filename] = {
                    "embedding": self.generate_embedding(image_data),
                    "status": "success"
                }
            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                results[filename] = {
                    "error": str(e),
                    "status": "failed"
                }

        return results
New folder/vector_db_service2.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vector Database Service implementation for Qdrant
3
+ """
4
+
5
+ from typing import List, Dict, Any, Optional
6
+ from qdrant_client import QdrantClient
7
+ from qdrant_client.models import PointStruct, VectorParams, Distance, Record
8
+
9
+
10
+
11
class VectorDatabaseClient:
    """Client for interacting with a Qdrant vector database collection."""

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Initialize the Qdrant client and collection settings.

        Args:
            url: Qdrant server URL
            api_key: API key for Qdrant
            collection_name: Name of the collection to use
            embedding_size: Size of embedding vectors
        """
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection_name = collection_name
        self.embedding_size = embedding_size

    def ensure_collection_exists(self):
        """Create the collection (cosine distance) if it does not exist yet."""
        collections = [c.name for c in self.client.get_collections().collections]

        if self.collection_name not in collections:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_size,
                    distance=Distance.COSINE
                )
            )
            print(f"✅ Collection '{self.collection_name}' created.")
        else:
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")

    def add_embedding(self, id: str, embedding: List[float], filename: str, metadata: Optional[str] = None) -> str:
        """Add an embedding to the collection.

        Args:
            id: Unique ID for the point
            embedding: Vector embedding
            filename: Original filename, stored under the ``filename`` payload key
            metadata: Optional metadata as JSON string, stored under ``metadata``

        Returns:
            ID of the added point
        """
        payload = {"filename": filename}
        if metadata:
            payload["metadata"] = metadata

        # Same upsert path as the custom-payload variant.
        return self.add_embedding_with_payload(id, embedding, payload)

    def add_embedding_with_payload(self, id: str, embedding: List[float], payload: Dict[str, Any]) -> str:
        """Add an embedding with a custom payload.

        Args:
            id: Unique ID for the point
            embedding: Vector embedding
            payload: Dictionary of metadata to store

        Returns:
            ID of the added point
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                PointStruct(
                    id=id,
                    vector=embedding,
                    payload=payload
                )
            ]
        )
        return id

    def search_by_embedding(self, embedding: List[float], limit: int = 5) -> List[Record]:
        """Search for similar vectors.

        Args:
            embedding: Query vector
            limit: Maximum number of results

        Returns:
            List of search results as returned by ``QdrantClient.search``.
            NOTE(review): the annotation says ``Record``; confirm the actual
            element type against the installed qdrant-client version.
        """
        return self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit
        )

    def search_by_id(self, id: str, limit: int = 1) -> List[Record]:
        """Search for similar vectors using an existing point as the query.

        Args:
            id: ID of the existing vector to use as query
            limit: Maximum number of results

        Returns:
            List of search results, or an empty list when the ID is unknown.
        """
        # retrieve() omits vectors unless asked; without with_vectors=True the
        # stored vector comes back as None and the follow-up search fails.
        points = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[id],
            with_vectors=True
        )

        if not points:
            return []

        return self.search_by_embedding(points[0].vector, limit)

    def delete_embedding(self, id: str) -> bool:
        """Delete an embedding from the collection.

        Args:
            id: ID of the embedding to delete

        Returns:
            True if deleted, False if not found
        """
        # Check existence first so the documented "False if not found"
        # contract is honoured instead of always reporting success.
        existing = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[id],
            with_payload=False,
            with_vectors=False
        )
        if not existing:
            return False

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=[id]
        )
        return True

    def list_collections(self) -> List[str]:
        """List all collections in the database.

        Returns:
            List of collection names
        """
        return [c.name for c in self.client.get_collections().collections]
README.md CHANGED
@@ -1,12 +1,97 @@
1
- ---
2
- title: FastAPI
3
- emoji: 📚
4
- colorFrom: indigo
5
- colorTo: purple
6
- sdk: static
7
- pinned: false
8
- license: apache-2.0
9
- short_description: Fast API with Marqo model
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image Similarity Search API
2
+
3
+ A FastAPI application for image similarity search using CLIP embeddings and Qdrant vector database.
4
+
5
+ ## Features
6
+
7
+ - Upload images and store their vector embeddings
8
+ - Search for similar images using an uploaded image or base64 encoded image
9
+ - API key authentication is implemented in `services/security_service.py`, but it is not currently wired into the routes
10
+ - Well-organized, modular codebase following OOP principles
11
+
12
+ ## Installation
13
+
14
+ 1. Clone this repository
15
+ 2. Install dependencies:
16
+
17
+ ```bash
18
+ pip install -r requirements.txt
19
+ ```
20
+
21
+ 3. Set up environment variables (optional, defaults are provided):
22
+
23
+ ```bash
24
+ export QDRANT_URL="your-qdrant-url"
25
+ export QDRANT_API_KEY="your-qdrant-api-key"
26
+ export COLLECTION_NAME="your-collection-name"
27
+ export API_KEY="your-api-key"
28
+ export PORT=8000
29
+ export ENVIRONMENT="production" # Or "development" for debug mode with auto-reload
30
+ ```
31
+
32
+ ## Usage
33
+
34
+ Run the application:
35
+
36
+ ```bash
37
+ python app.py
38
+ ```
39
+
40
+ The API will be available at http://localhost:8000 (or the port specified in environment variables).
41
+
42
+ ### API Documentation
43
+
44
+ Once running, API documentation is available at:
45
+ - Swagger UI: http://localhost:8000/docs
46
+ - ReDoc: http://localhost:8000/redoc
47
+
48
+ ## API Endpoints
49
+
50
+ - `POST /add-image/`: Add an image to the database
51
+ - `POST /add-images-from-folder/`: Add all images from a folder to the database
52
+ - `POST /search-by-image/`: Search for similar images using an uploaded image
53
+ - `POST /search-by-image-scan/`: Search for similar images using a base64 encoded image
54
+ - `GET /collections`: List all collections in the database
55
+ - `GET /health`: Health check endpoint (not yet implemented in `api/routes.py`; `GET /` currently returns a status message)
56
+
57
+ ## Project Structure
58
+
59
+ ```
60
+ image_similarity_api/
61
+
62
+ ├── app.py # Main application entry point
63
+ ├── config.py # Configuration settings
64
+ ├── models/
65
+ │ ├── __init__.py
66
+ │ └── schemas.py # Pydantic models
67
+ ├── services/
68
+ │ ├── __init__.py
69
+ │ ├── embedding_service.py # Image embedding service
70
+ │ ├── security_service.py # Security service
71
+ │ └── vector_db_service.py # Vector database service
72
+ ├── api/
73
+ │ ├── __init__.py
74
+ │ └── routes.py # API routes
75
+ ├── requirements.txt # Project dependencies
76
+ └── README.md # Project documentation
77
+ ```
78
+
79
+ ## Development
80
+
81
+ For development, set the ENVIRONMENT variable to "development" for auto-reload:
82
+
83
+ ```bash
84
+ export ENVIRONMENT="development"
85
+ python app.py
86
+ ```
87
+
88
+ ## Deployment
89
+
90
+ This application can be deployed to any platform that supports Python applications:
91
+
92
+ 1. Docker
93
+ 2. Kubernetes
94
+ 3. Cloud platforms (AWS, GCP, Azure, etc.)
95
+ 4. Serverless platforms (with appropriate adapters)
96
+
97
+ Remember to set all required environment variables in your production environment.
__pycache__/config.cpython-312.pyc ADDED
Binary file (1.49 kB). View file
 
api/__init__.py ADDED
File without changes
api/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (127 Bytes). View file
 
api/__pycache__/routes.cpython-312.pyc ADDED
Binary file (5.01 kB). View file
 
api/routes.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Routes for the Image Similarity Search API
3
+ Contains all endpoints for the application using your original route implementation
4
+ """
5
+
6
+ import uuid
7
+ import base64
8
+ import io
9
+ from typing import List, Optional
10
+ from fastapi import APIRouter, FastAPI, File, UploadFile, Form, Query, Path # type: ignore
11
+ from pydantic import BaseModel
12
+ from PIL import Image
13
+
14
+ from services.embedding_service import ImageEmbeddingModel
15
+ from services.vector_db_service import VectorDatabaseClient
16
+
17
+
18
class Base64ImageRequest(BaseModel):
    """Request model for base64 encoded images"""
    # Base64 text of the image. The /search-by-image-scan/ handler also accepts
    # a "<prefix>,<base64>" form and decodes only the part after the comma.
    image_data: str
21
+
22
+
23
def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
):
    """Register all routes with the FastAPI app.

    Args:
        app: Application the endpoints are attached to.
        embedding_model: Service that turns images into embedding vectors.
        vector_db: Client used to store and query embeddings.
    """

    @app.api_route("/", methods=["GET", "HEAD"])
    async def read_root():
        return {"status": "API running"}

    @app.post("/add-image/")
    async def add_image(
        file: UploadFile = File(...),
        item_name: str = Form(...),
        design_name: str = Form(...),
        item_price: float = Form(...)
    ):
        """Upload an image with product details and store its embedding"""
        embedding = await embedding_model.get_embedding_from_upload(file)

        # Every upload gets a fresh random ID.
        image_id = str(uuid.uuid4())

        # Product details travel with the vector as its payload.
        payload = {
            "filename": file.filename,
            "item_name": item_name,
            "design_name": design_name,
            "item_price": item_price
        }

        vector_db.add_image(image_id, embedding, payload)

        return {"message": "Image added successfully", "id": image_id}

    @app.post("/add-images-from-folder/")
    async def add_images_from_folder(folder_path: str):
        """Compute embeddings for all images in a server-side folder and return them (nothing is stored)"""
        # SECURITY NOTE(review): folder_path comes straight from the client and
        # is read on the server's filesystem; restrict it to an allowed base
        # directory before exposing this endpoint publicly.
        embeddings = embedding_model.get_embeddings_from_folder(folder_path)
        return {"embeddings": embeddings}

    @app.post("/search-by-image/")
    async def search_by_image(file: UploadFile = File(...)):
        """Search for similar images by uploading a file"""
        embedding = await embedding_model.get_embedding_from_upload(file)

        # search_by_vector already returns plain id/score/payload dicts.
        return vector_db.search_by_vector(embedding, limit=1)

    @app.post("/search-by-image-scan/")
    async def search_by_image_scan(request: Base64ImageRequest):
        """Search for similar images using a base64 encoded image"""
        # Use the embedding service's own base64 helper: it strips the part
        # before a comma the same way and turns undecodable input into an
        # HTTP 400 instead of an unhandled 500.
        embedding = embedding_model.get_embedding_from_base64(request.image_data)

        return vector_db.search_by_vector(embedding, limit=1)

    @app.get("/collections")
    def list_collections():
        """List all available collections in the vector database"""
        return vector_db.list_collections()
api/routes1.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Routes for the Image Similarity Search API
3
+ Contains all endpoints for the application
4
+ """
5
+
6
+ from fastapi import APIRouter, FastAPI, File, UploadFile, Form, Query, Path
7
+ from typing import List, Optional
8
+ from pydantic import BaseModel
9
+
10
+ from services.embedding_service import ImageEmbeddingModel
11
+ from services.vector_db_service import VectorDatabaseClient
12
+
13
+
14
class SearchResponse(BaseModel):
    """Response model for search results"""
    # ID of the matched point (filled from ``result.id`` by the search routes).
    image_id: str
    # Similarity score (filled from ``result.score`` by the search routes).
    similarity: float
    # Optional metadata dictionary; None when nothing is attached.
    metadata: Optional[dict] = None
19
+
20
+
21
def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
):
    """Register all routes with the FastAPI app under the ``/api/v1`` prefix.

    API-key protection is intentionally not applied to these routes.

    NOTE(review): the handlers call ``generate_embedding``, ``add_embedding``,
    ``search_by_id``, ``search_by_embedding`` and ``delete_embedding``; in this
    commit those methods exist on the classes under "New folder/", not on the
    ``services.*`` classes this module imports. Confirm which services are
    meant to be used.

    Args:
        app: Application the router is mounted on.
        embedding_model: Service that turns image bytes into embeddings.
        vector_db: Client used to store, search and delete embeddings.
    """
    import uuid  # local import: this module does not import uuid at file level

    router = APIRouter()

    @router.post("/upload", response_model=dict)
    async def upload_image(
        file: UploadFile = File(...),
        metadata: Optional[str] = Form(None),
    ):
        """Upload an image and store its embedding"""
        image_data = await file.read()
        embedding = embedding_model.generate_embedding(image_data)

        # add_embedding takes the point ID as its first argument; the original
        # call omitted it, shifting every argument by one position.
        image_id = vector_db.add_embedding(
            str(uuid.uuid4()), embedding, file.filename, metadata
        )

        return {"image_id": image_id, "message": "Image uploaded successfully"}

    @router.get("/search/by-id/{image_id}", response_model=List[SearchResponse])
    async def search_by_id(
        image_id: str = Path(..., description="ID of the uploaded image to use as query"),
        limit: int = Query(5, description="Maximum number of results to return"),
    ):
        """Search for similar images using an existing image ID as the query"""
        results = vector_db.search_by_id(image_id, limit)
        return [
            SearchResponse(
                image_id=str(result.id),
                similarity=result.score,
                # Stored data is exposed as ``payload`` on search results
                # elsewhere in this codebase, not ``metadata``.
                metadata=result.payload
            )
            for result in results
        ]

    @router.post("/search/by-image", response_model=List[SearchResponse])
    async def search_by_image(
        file: UploadFile = File(...),
        limit: int = Query(5, description="Maximum number of results to return"),
    ):
        """Search for similar images by uploading a new image"""
        image_data = await file.read()
        embedding = embedding_model.generate_embedding(image_data)

        results = vector_db.search_by_embedding(embedding, limit)
        return [
            SearchResponse(
                image_id=str(result.id),
                similarity=result.score,
                metadata=result.payload
            )
            for result in results
        ]

    @router.delete("/images/{image_id}")
    async def delete_image(
        image_id: str = Path(..., description="ID of the image to delete"),
    ):
        """Delete an image from the database"""
        success = vector_db.delete_embedding(image_id)
        if success:
            return {"message": f"Image {image_id} deleted successfully"}
        return {"message": f"Image {image_id} not found"}

    # Add the router to the app
    app.include_router(router, prefix="/api/v1")
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image Similarity Search API with FastAPI and Qdrant - Fixed Access
3
+ This application provides endpoints for uploading images and searching for similar images
4
+ using vector embeddings from the CLIP model. Implemented using OOP principles.
5
+ """
6
+
7
+ import uvicorn # type: ignore
8
+ from fastapi import FastAPI # type: ignore
9
+ from contextlib import asynccontextmanager
10
+ import os
11
+ import ssl
12
+ from fastapi.middleware.cors import CORSMiddleware # type: ignore
13
+
14
+ from config import Config
15
+ from services.embedding_service import ImageEmbeddingModel
16
+ from services.vector_db_service import VectorDatabaseClient
17
+ from api.routes import register_routes
18
+
19
+
20
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work for the FastAPI app, then hand control to the server.

    Startup: makes sure the vector-database collection stored on
    ``app.state.vector_db`` exists. Shutdown: nothing to clean up at present.
    """
    # Startup phase: executed once before the application begins serving.
    app.state.vector_db.ensure_collection_exists()

    # Control returns to FastAPI here; code placed after the yield would run
    # when the application shuts down.
    yield
31
+
32
+
33
class ImageSimilarityAPI:
    """Main application class that orchestrates all components"""

    def __init__(self):
        """Build the config, embedding model, vector DB client and FastAPI app."""
        self.config = Config()

        self.embedding_model = ImageEmbeddingModel(self.config.model_name)
        self.vector_db = VectorDatabaseClient(
            self.config.qdrant_url,
            self.config.qdrant_api_key,
            self.config.collection_name,
            self.config.embedding_size
        )

        # The lifespan handler runs the collection check at startup.
        self.app = FastAPI(
            title="Image Similarity Search API",
            description="API for uploading images and searching for similar images using CLIP embeddings",
            version="1.0.0",
            lifespan=lifespan
        )

        # Enable CORS to allow mobile access.
        # SECURITY NOTE(review): wildcard origins combined with
        # allow_credentials=True is very permissive; list the real client
        # origins (e.g. ["http://192.168.1.42"]) for production.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # The lifespan handler reads the client back from app.state.
        self.app.state.vector_db = self.vector_db

        register_routes(self.app, self.embedding_model, self.vector_db)

    def run(self, use_https=False, cert_file="./certs/cert.pem", key_file="./certs/key.pem"):
        """Run the FastAPI application with optional HTTPS support

        Args:
            use_https: Whether to use HTTPS or plain HTTP
            cert_file: Path to SSL certificate file
            key_file: Path to SSL private key file
        """
        # Bind to every interface so other devices can reach the server.
        host = "0.0.0.0"

        if use_https and not (os.path.exists(cert_file) and os.path.exists(key_file)):
            print(f"ERROR: SSL certificate files not found at {cert_file} and/or {key_file}")
            print("Falling back to HTTP. To use HTTPS, please provide valid certificate files.")
            use_https = False

        # Choose the port only after the fallback decision, so an HTTP fallback
        # does not end up on the HTTPS port. Plain HTTP honours the configured
        # PORT (default 8000).
        port = 8443 if use_https else self.config.port

        # Print access URLs for convenience
        protocol = "https" if use_https else "http"
        print(f"\n{'='*50}")
        print(f"Access the API at: {protocol}://{host}:{port}")
        print(f"Swagger UI available at: {protocol}://{host}:{port}/docs")
        print(f"ReDoc UI available at: {protocol}://{host}:{port}/redoc")
        print(f"{'='*50}\n")

        # NOTE(review): uvicorn documents that reload needs the app passed as
        # an import string; with an app object the "development" reload flag
        # may not take effect - confirm against the installed uvicorn version.
        uvicorn.run(
            self.app,
            host=host,
            port=port,
            reload=self.config.environment == "development",
            ssl_certfile=cert_file if use_https else None,
            ssl_keyfile=key_file if use_https else None
        )
110
+
111
+
112
def create_app() -> FastAPI:
    """Build a fully wired ``ImageSimilarityAPI`` and return its FastAPI app."""
    return ImageSimilarityAPI().app
116
+
117
+
118
if __name__ == "__main__":
    # HTTPS stays off until certificates are properly set up under ./certs;
    # switch "use_https" to True once they are ready.
    run_options = {
        "use_https": False,
        "cert_file": "./certs/cert.pem",
        "key_file": "./certs/key.pem",
    }
    ImageSimilarityAPI().run(**run_options)
config.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the Image Similarity API
3
+ """
4
+
5
+ import os
6
+
7
+
8
class Config:
    """Configuration class for the application.

    Every setting is read from an environment variable, falling back to a
    built-in default when the variable is unset.
    """

    def __init__(self):
        """Read settings from the environment.

        Raises:
            ValueError: If ``EMBEDDING_SIZE`` or ``PORT`` is set to a
                non-integer value.
        """
        # SECURITY NOTE(review): the default URL and API key below are live
        # credentials committed to source control. Rotate the key and provide
        # both values only through the environment; the defaults are kept here
        # solely so existing deployments keep starting.
        self.qdrant_url = os.getenv("QDRANT_URL",
                                    "https://b6138c60-0a19-4ba7-b6a5-f70a7d653b57.us-west-1-0.aws.cloud.qdrant.io")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY",
                                        "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.XQrkVFAz02zgcvVYbmoneq36biKdbP6491n5I-RrCpQ")
        self.collection_name = os.getenv("COLLECTION_NAME", "marqe_embedings")
        self.model_name = os.getenv("MODEL_NAME", "hf-hub:Marqo/marqo-ecommerce-embeddings-L")
        # The vector size must match the model's output, so it is overridable
        # together with MODEL_NAME.
        # NOTE(review): confirm that 768 is the output dimension of the
        # default model above.
        self.embedding_size = int(os.getenv("EMBEDDING_SIZE", "768"))
        self.port = int(os.getenv("PORT", "8000"))
        self.environment = os.getenv("ENVIRONMENT", "production")
models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ Empty __init__.py file to make the models directory a proper Python package
3
+ """
4
+
5
+ # This file is intentionally left empty
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (224 Bytes). View file
 
models/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (1.04 kB). View file
 
models/schemas.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic model schemas for API request and response types
3
+ """
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class Base64ImageRequest(BaseModel):
    """Model for accepting base64 encoded images"""
    # The base64-encoded image, carried as text.
    image_data: str
11
+
12
+
13
class SearchResult(BaseModel):
    """Model for search results"""
    # Field names match the keys of the dicts built by
    # VectorDatabaseClient.search_by_vector ("id", "score", "payload").
    id: str
    score: float
    payload: dict
18
+
19
+
20
class ErrorResponse(BaseModel):
    """Model for error responses"""
    # Error message text; presumably mirrors the ``detail`` passed to
    # HTTPException in the services - TODO confirm.
    detail: str
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+ torch
3
+ open_clip_torch
4
+ fastapi
5
+ uvicorn
6
+ qdrant-client
7
+ nest_asyncio
8
+ python-multipart
9
+ pillow
10
+ numpy
11
+ pydantic
12
+ python-dotenv
13
+
services/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ Empty __init__.py file to make the services directory a proper Python package
3
+ """
4
+
5
+ # This file is intentionally left empty
services/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (223 Bytes). View file
 
services/__pycache__/embedding_service.cpython-312.pyc ADDED
Binary file (6.21 kB). View file
 
services/__pycache__/security_service.cpython-312.pyc ADDED
Binary file (1.31 kB). View file
 
services/__pycache__/vector_db_service.cpython-312.pyc ADDED
Binary file (3.83 kB). View file
 
services/embedding_service.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image embedding service for generating vector embeddings from images
3
+ """
4
+
5
+ import io
6
+ import os
7
+ import base64
8
+ from typing import List, Tuple
9
+
10
+ import open_clip
11
+ import torch
12
+ from fastapi import UploadFile, HTTPException
13
+ from PIL import Image
14
+ import torch
15
+
16
+
17
class ImageEmbeddingModel:
    """Class for handling image embedding using an open_clip model"""

    def __init__(self, model_name: str):
        """Load the model named ``model_name`` onto CUDA when available, else CPU."""
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess_train, self.preprocess_val = self._initialize_model()

    def _initialize_model(self) -> Tuple:
        """Create the model and its train/val preprocessing transforms.

        Returns:
            ``(model, preprocess_train, preprocess_val)`` with the model moved
            to ``self.device`` and switched to eval mode.
        """
        # Only image features are used, so no tokenizer is loaded.
        model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(self.model_name)
        model.to(self.device)
        model.eval()
        return model, preprocess_train, preprocess_val

    def get_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Get embedding from PIL image

        Returns:
            The embedding as a list of floats, produced by ``encode_image``
            with ``normalize=True``.
        """
        processed_image = self.preprocess_val(image).unsqueeze(0).to(self.device)

        if self.device == 'cuda':
            autocast_context = torch.amp.autocast(device_type='cuda')
        else:
            # Mixed precision is only wanted on GPU. Disabling autocast
            # explicitly avoids requesting a float32 CPU autocast, which torch
            # rejects with a warning before disabling it anyway.
            autocast_context = torch.amp.autocast(device_type='cpu', enabled=False)

        with torch.no_grad(), autocast_context:
            image_features = self.model.encode_image(processed_image, normalize=True)

        return image_features.cpu().numpy()[0].tolist()

    async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]:
        """Get embedding from uploaded image file

        Raises:
            HTTPException: 400 when the upload cannot be read or decoded as an
                image.
        """
        # Only reading/decoding is a client error; a failure inside the model
        # is a server fault and must not be reported as HTTP 400.
        try:
            contents = await image_file.read()
            img = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e
        return self.get_embedding_from_pil(img)

    def get_embedding_from_base64(self, base64_data: str) -> List[float]:
        """Get embedding from base64 encoded image

        Raises:
            HTTPException: 400 when the data is not valid base64 or does not
                decode to an image.
        """
        try:
            # Handle data URI format: keep only the part after the comma.
            if ',' in base64_data:
                base64_data = base64_data.split(',')[1]

            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}") from e
        return self.get_embedding_from_pil(image)

    def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]:
        """Get embeddings from all images in a folder

        Only ``.png``, ``.jpg`` and ``.jpeg`` files directly inside the folder
        are processed. Files that fail are reported on stdout and skipped.

        Raises:
            HTTPException: 404 when ``image_folder`` is not an existing
                directory.
        """
        # isdir also rejects a path to a regular file, which os.listdir
        # could not handle.
        if not os.path.isdir(image_folder):
            raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}")

        embeddings = []
        for image_name in os.listdir(image_folder):
            if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                image_path = os.path.join(image_folder, image_name)
                # The context manager closes the file handle once the pixels
                # have been converted into an in-memory RGB copy.
                with Image.open(image_path) as handle:
                    img = handle.convert("RGB")
                embeddings.append(self.get_embedding_from_pil(img))
            except Exception as e:
                # Best effort: one bad file must not abort the whole folder.
                print(f"Error processing {image_name}: {str(e)}")

        return embeddings
services/security_service.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Security service for API authentication and authorization
3
+ """
4
+
5
+ from fastapi import Depends, HTTPException
6
+ from fastapi.security import APIKeyHeader
7
+
8
+
9
class SecurityService:
    """Class for handling API security"""

    def __init__(self, api_key: str):
        """Store the expected API key and the ``X-API-Key`` header scheme."""
        self.api_key = api_key
        self.api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

    async def verify_api_key(self, api_key: str = Depends(APIKeyHeader(name="X-API-Key", auto_error=False))):
        """Verify API key dependency

        Args:
            api_key: Value of the ``X-API-Key`` request header (None when the
                header is absent, because ``auto_error`` is False).

        Returns:
            The supplied ``api_key`` when the check passes or is skipped.

        Raises:
            HTTPException: 401 when a real key is configured and the supplied
                key is missing or different.
        """
        import secrets  # local import: this module has no file-level secrets import

        # The check is deliberately skipped while the placeholder key is
        # still configured.
        if self.api_key == "your-api-key-here":
            return api_key

        # compare_digest takes time independent of where the strings differ,
        # so response timing does not leak the key. Both sides are encoded
        # because it rejects non-ASCII str arguments.
        supplied_ok = isinstance(api_key, str) and secrets.compare_digest(
            api_key.encode("utf-8"), self.api_key.encode("utf-8")
        )
        if not supplied_ok:
            raise HTTPException(status_code=401, detail="Invalid API key")
        return api_key
services/vector_db_service.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vector database service for interacting with Qdrant
3
+ """
4
+
5
+ from typing import List, Dict, Any
6
+
7
+ from fastapi import HTTPException # type: ignore
8
+ from qdrant_client import QdrantClient # type: ignore
9
+ from qdrant_client.models import Distance, PointStruct, VectorParams # type: ignore
10
+
11
class VectorDatabaseClient:
    """Thin wrapper around a Qdrant client bound to a single collection."""

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Remember the connection settings and open the Qdrant client."""
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Create the collection with cosine distance unless it is already there."""
        known_names = [entry.name for entry in self.client.get_collections().collections]

        # Guard clause: nothing to create when the collection is present.
        if self.collection_name in known_names:
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        vector_settings = VectorParams(
            size=self.embedding_size,
            distance=Distance.COSINE
        )
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=vector_settings
        )
        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Upsert one point holding the image's embedding and its payload."""
        point = PointStruct(
            id=image_id,
            vector=embedding,
            payload=payload
        )
        self.client.upsert(
            collection_name=self.collection_name,
            points=[point]
        )

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Return up to ``limit`` nearest matches as id/score/payload dicts."""
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit
        )

        # Flatten each hit into a plain dict.
        matches = []
        for hit in hits:
            matches.append({
                "id": hit.id,
                "score": hit.score,
                "payload": hit.payload
            })
        return matches

    def list_collections(self) -> List[str]:
        """Return the names of all collections on the server."""
        response = self.client.get_collections()
        return [entry.name for entry in response.collections]