"""Embedding Service for generating image embeddings."""
import io
import os
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class ImageEmbeddingModel:
    """Generate L2-normalized image embeddings using a CLIP model."""

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
    ):
        """Load the CLIP model and its processor.

        Args:
            model_name: Name of the CLIP model to use; passed to
                ``from_pretrained`` for both the model and the processor.
            device: Torch device to run the model on (e.g. "cpu", "cuda").
                When None, uses "cuda" if available, otherwise "cpu".
        """
        self.model_name = model_name
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def generate_embedding(self, image_data: bytes) -> List[float]:
        """Generate an embedding for an image given as encoded bytes.

        Args:
            image_data: Encoded image bytes in a format PIL can open.

        Returns:
            The image embedding, scaled to unit L2 norm, as a list of floats.

        Raises:
            PIL.UnidentifiedImageError: If the bytes cannot be decoded as an image.
        """
        # Close the decoded source image promptly; convert() returns a new,
        # fully loaded image that remains usable after the source is closed.
        with Image.open(io.BytesIO(image_data)) as source:
            image = source.convert("RGB")
        return self.generate_embedding_from_pil(image)

    def generate_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Generate an embedding for a PIL image.

        Args:
            image: PIL Image object (converted by the CLIP processor).

        Returns:
            The image embedding, scaled to unit L2 norm, as a list of floats.
            A zero-norm embedding is returned unscaled rather than as NaNs.
        """
        # Preprocess the image into model inputs on the target device.
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # Single image in the batch -> take the first row.
        image_embedding = image_features.cpu().numpy()[0]
        norm = np.linalg.norm(image_embedding)
        # Guard against division by zero, which would fill the vector with NaNs.
        if norm > 0:
            image_embedding = image_embedding / norm
        return image_embedding.tolist()

    def get_embeddings_from_folder(self, folder_path: str) -> Dict[str, Any]:
        """Generate embeddings for every image file directly inside a folder.

        Only regular files whose extension (case-insensitive) is one of
        .jpg, .jpeg, .png, .bmp, .gif, .webp are processed; sub-directories are
        not recursed into. Files are processed in sorted filename order.

        Args:
            folder_path: Path to the folder containing images.

        Returns:
            A dict mapping each image filename to either
            ``{"embedding": [...], "status": "success"}`` or
            ``{"error": "<message>", "status": "failed"}``.
            If the folder does not exist or cannot be listed, a single
            ``{"error": "<message>"}`` dict is returned instead.
        """
        results: Dict[str, Any] = {}
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}

        if not os.path.exists(folder_path):
            return {"error": f"Folder {folder_path} does not exist"}

        # A regular file or an unreadable directory passes exists() but makes
        # os.listdir raise; report it through the same error-dict convention.
        try:
            filenames = sorted(os.listdir(folder_path))
        except OSError as e:
            return {"error": f"Cannot read folder {folder_path}: {e}"}

        for filename in filenames:
            if os.path.splitext(filename)[1].lower() not in image_extensions:
                continue
            file_path = os.path.join(folder_path, filename)
            # Skip directories or other non-regular entries that merely carry
            # an image-like extension.
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, 'rb') as f:
                    image_data = f.read()
                embedding = self.generate_embedding(image_data)
            except Exception as e:
                # Best-effort batch: record the per-file failure and continue.
                results[filename] = {
                    "error": str(e),
                    "status": "failed"
                }
            else:
                results[filename] = {
                    "embedding": embedding,
                    "status": "success"
                }

        return results