# FastAPI / New folder / embedding_service2.py
# ravi19 — Deploy FastAPI to HF Space (b36cb8b)
"""
Embedding Service for generating image embeddings
"""
import os
from typing import List, Dict, Any
from PIL import Image
import io
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
class ImageEmbeddingModel:
    """Generate L2-normalized image embeddings with a CLIP model.

    The model and processor are loaded once in ``__init__``; the model is
    moved to CUDA when ``torch.cuda.is_available()``, otherwise the CPU.
    """

    # File extensions (compared case-insensitively) that are treated as
    # images when scanning a folder.
    _IMAGE_EXTENSIONS = frozenset(
        {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
    )

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """Load the CLIP model and its processor.

        Args:
            model_name: Name of the CLIP checkpoint passed to
                ``from_pretrained`` for both the model and the processor.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def generate_embedding(self, image_data: bytes) -> List[float]:
        """Generate an embedding for an encoded image.

        Args:
            image_data: Raw bytes of an image file in a format PIL can decode.

        Returns:
            The L2-normalized image embedding as a list of floats.

        Raises:
            PIL.UnidentifiedImageError: If the bytes are not a decodable image.
        """
        # Close the opened image (and its buffer) once it has been converted;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(io.BytesIO(image_data)) as img:
            image = img.convert("RGB")
        return self.generate_embedding_from_pil(image)

    def generate_embedding_from_pil(self, image: "Image.Image") -> List[float]:
        """Generate an embedding for a PIL image.

        Args:
            image: Image to embed; it is passed to the CLIP processor as-is.

        Returns:
            The image embedding scaled to unit L2 norm, as a list of floats.
            An all-zero feature vector is returned unchanged rather than
            being divided by zero (which would produce NaNs).
        """
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        # Exactly one image was processed, so take the first (only) row.
        embedding = image_features.cpu().numpy()[0]
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding.tolist()

    def get_embeddings_from_folder(self, folder_path: str) -> Dict[str, Any]:
        """Generate embeddings for every image file directly inside a folder.

        Only regular files whose extension (case-insensitive) is in
        ``_IMAGE_EXTENSIONS`` are processed; subdirectories are not descended
        into. Files are visited in sorted filename order so that the result
        is deterministic.

        Args:
            folder_path: Path to the folder containing images.

        Returns:
            ``{"error": message}`` if ``folder_path`` does not exist or is not
            a directory. Otherwise a dict mapping each image filename to
            ``{"embedding": [...], "status": "success"}``, or to
            ``{"error": message, "status": "failed"}`` when that file could
            not be read or embedded.
        """
        if not os.path.exists(folder_path):
            return {"error": f"Folder {folder_path} does not exist"}
        if not os.path.isdir(folder_path):
            # os.listdir would raise NotADirectoryError on a plain file.
            return {"error": f"Folder {folder_path} is not a directory"}

        results: Dict[str, Any] = {}
        for filename in sorted(os.listdir(folder_path)):
            if os.path.splitext(filename)[1].lower() not in self._IMAGE_EXTENSIONS:
                continue
            file_path = os.path.join(folder_path, filename)
            # Skip non-regular entries such as a directory named "x.jpg".
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, "rb") as f:
                    image_data = f.read()
                embedding = self.generate_embedding(image_data)
            except Exception as e:
                # Best effort per file: record the failure and keep going.
                results[filename] = {"error": str(e), "status": "failed"}
            else:
                results[filename] = {"embedding": embedding, "status": "success"}
        return results