# Source: FastAPI/services/embedding_service.py
# Uploaded by ravi19 - commit b36cb8b ("Deploy FastAPI to HF Space")
"""
Image embedding service for generating vector embeddings from images
"""
import io
import os
import base64
from typing import List, Tuple
import open_clip
import torch
from fastapi import UploadFile, HTTPException
from PIL import Image
import torch
class ImageEmbeddingModel:
    """Generate image embedding vectors with an OpenCLIP model.

    The model is created once in ``__init__``, moved to CUDA when
    ``torch.cuda.is_available()`` (CPU otherwise) and switched to eval mode.
    Images can then be embedded from a PIL image, an uploaded file, a
    base64 string or every image file in a folder.
    """

    def __init__(self, model_name: str):
        """Load the model identified by *model_name* and its transforms."""
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        (
            self.model,
            self.preprocess_train,
            self.preprocess_val,
        ) = self._initialize_model()

    def _initialize_model(self) -> Tuple:
        """Create the CLIP model and its preprocessing transforms.

        Returns:
            ``(model, preprocess_train, preprocess_val)`` exactly as returned
            by ``open_clip.create_model_and_transforms``, with the model
            already on ``self.device`` and in eval mode.
        """
        # NOTE(review): no ``pretrained`` argument is passed here - confirm
        # that ``model_name`` itself identifies the weights to load.
        model, preprocess_train, preprocess_val = (
            open_clip.create_model_and_transforms(self.model_name)
        )
        # Only images are embedded by this class, so no tokenizer is loaded.
        model.to(self.device)
        model.eval()
        return model, preprocess_train, preprocess_val

    def get_embedding_from_pil(self, image: "Image.Image") -> List[float]:
        """Return the embedding of a single PIL image as a list of floats.

        The image goes through ``self.preprocess_val``, is given a leading
        batch dimension of 1 and is encoded with ``normalize=True``.
        """
        batch = self.preprocess_val(image).unsqueeze(0).to(self.device)
        # Mixed precision is only used on CUDA.  On CPU autocast is switched
        # off explicitly: asking CPU autocast for float32 is unsupported and
        # merely emits a warning on every call before disabling itself.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.amp.autocast(
            device_type=self.device, enabled=use_amp
        ):
            image_features = self.model.encode_image(batch, normalize=True)
        # First (only) row of the batch, converted straight to Python floats.
        return image_features[0].cpu().tolist()

    async def get_embedding_from_upload(self, image_file: "UploadFile") -> List[float]:
        """Return the embedding of an uploaded image file.

        Raises:
            HTTPException: status 400 if reading, decoding or embedding the
                upload fails; the original error is chained as the cause.
        """
        try:
            contents = await image_file.read()
            img = Image.open(io.BytesIO(contents)).convert("RGB")
            return self.get_embedding_from_pil(img)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid image: {str(e)}"
            ) from e

    def get_embedding_from_base64(self, base64_data: str) -> List[float]:
        """Return the embedding of a base64-encoded image.

        Accepts either a bare base64 payload or a data URI
        (``data:<mime>;base64,<payload>``).

        Raises:
            HTTPException: status 400 if decoding or embedding fails; the
                original error is chained as the cause.
        """
        try:
            # Data URI: keep everything after the first comma only.
            if ',' in base64_data:
                base64_data = base64_data.split(',', 1)[1]
            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            return self.get_embedding_from_pil(image)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid base64 image: {str(e)}"
            ) from e

    def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]:
        """Return embeddings for every ``.png``/``.jpg``/``.jpeg`` in a folder.

        Files are visited in ``os.listdir`` order.  A file that cannot be
        processed is reported on stdout and skipped, so the result may hold
        fewer entries than there are image files.

        Raises:
            HTTPException: status 404 if *image_folder* is not an existing
                directory.
        """
        # isdir (not exists): a path to a regular file must also yield the
        # 404 instead of crashing inside os.listdir.
        if not os.path.isdir(image_folder):
            raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}")

        embeddings = []
        for image_name in os.listdir(image_folder):
            if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                image_path = os.path.join(image_folder, image_name)
                # Context manager closes the file handle that Image.open
                # keeps open; convert() returns an independent image.
                with Image.open(image_path) as raw:
                    img = raw.convert("RGB")
                embeddings.append(self.get_embedding_from_pil(img))
            except Exception as e:
                # Best effort: one bad file must not abort the whole folder.
                print(f"Error processing {image_name}: {str(e)}")
        return embeddings