""" Image embedding service for generating vector embeddings from images """ import io import os import base64 from typing import List, Tuple import open_clip import torch from fastapi import UploadFile, HTTPException from PIL import Image import torch class ImageEmbeddingModel: """Class for handling image embedding using CLIP model""" def __init__(self, model_name: str): self.model_name = model_name self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model, self.preprocess_train, self.preprocess_val = self._initialize_model() def _initialize_model(self) -> Tuple: """Initialize the CLIP model for image embeddings""" model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(self.model_name) tokenizer = open_clip.get_tokenizer(self.model_name) model.to(self.device) model.eval() return model, preprocess_train, preprocess_val def get_embedding_from_pil(self, image: Image.Image) -> List[float]: """Get embedding from PIL image""" processed_image = self.preprocess_val(image).unsqueeze(0).to(self.device) # with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu') if self.device == 'cuda': autocast_context = torch.amp.autocast(device_type='cuda') else: # On CPU, autocast should either be skipped or forced to float32 autocast_context = torch.amp.autocast(device_type='cpu', dtype=torch.float32) with torch.no_grad(), autocast_context: image_features = self.model.encode_image(processed_image, normalize=True) return image_features.cpu().numpy()[0].tolist() async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]: """Get embedding from uploaded image file""" try: contents = await image_file.read() img = Image.open(io.BytesIO(contents)).convert("RGB") return self.get_embedding_from_pil(img) except Exception as e: raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") def get_embedding_from_base64(self, base64_data: str) -> List[float]: """Get embedding from base64 encoded image""" try: # Handle data URI format if ',' in base64_data: base64_data = base64_data.split(',')[1] image_bytes = base64.b64decode(base64_data) image = Image.open(io.BytesIO(image_bytes)).convert("RGB") return self.get_embedding_from_pil(image) except Exception as e: raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}") def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]: """Get embeddings from all images in a folder""" embeddings = [] if not os.path.exists(image_folder): raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}") for image_name in os.listdir(image_folder): if image_name.lower().endswith(('.png', '.jpg', '.jpeg')): try: image_path = os.path.join(image_folder, image_name) img = Image.open(image_path).convert("RGB") embeddings.append(self.get_embedding_from_pil(img)) except Exception as e: print(f"Error processing {image_name}: {str(e)}") return embeddings