|
|
""" |
|
|
Image embedding service for generating vector embeddings from images |
|
|
""" |
|
|
|
|
|
import base64
import io
import logging
import os
from typing import List, Optional, Tuple
|
|
|
|
|
import open_clip |
|
|
import torch |
|
|
from fastapi import UploadFile, HTTPException |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
|
|
|
class ImageEmbeddingModel:
    """Generate image embeddings with an open_clip model.

    The model is created once at construction time, moved to CUDA when it is
    available (CPU otherwise) and switched to eval mode. Every public
    ``get_embedding*`` method returns the output of
    ``model.encode_image(..., normalize=True)`` for a single image as a plain
    Python list of floats.
    """

    def __init__(self, model_name: str, pretrained: Optional[str] = None):
        """Load the model and its preprocessing transforms.

        Args:
            model_name: Model identifier passed to
                ``open_clip.create_model_and_transforms``.
            pretrained: Optional pretrained-weights tag forwarded to
                ``open_clip.create_model_and_transforms``. The default
                ``None`` is the same value open_clip uses when the argument
                is omitted, so existing callers see unchanged behaviour.
        """
        self.model_name = model_name
        self.pretrained = pretrained
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess_train, self.preprocess_val = self._initialize_model()

    def _initialize_model(self) -> Tuple:
        """Create the model on ``self.device`` and put it in eval mode.

        Returns:
            ``(model, preprocess_train, preprocess_val)`` as produced by
            ``open_clip.create_model_and_transforms``.
        """
        model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
            self.model_name,
            # getattr keeps subclasses that bypass __init__ working.
            pretrained=getattr(self, "pretrained", None),
        )
        model.to(self.device)
        model.eval()
        return model, preprocess_train, preprocess_val

    def get_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Embed a single PIL image.

        Args:
            image: Image to embed. It is passed to ``self.preprocess_val``
                unchanged; the other methods of this class convert to RGB
                before calling this one.

        Returns:
            The normalised embedding as a list of floats.
        """
        batch = self.preprocess_val(image).unsqueeze(0).to(self.device)

        # Mixed precision only on CUDA. CPU autocast supports only bf16/fp16;
        # requesting float32 makes torch warn and disable it on every call,
        # so autocast is switched off explicitly on CPU instead.
        autocast_context = torch.amp.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        )
        with torch.no_grad(), autocast_context:
            image_features = self.model.encode_image(batch, normalize=True)

        # .float() only widens a possible fp16 autocast result; values are unchanged.
        return image_features[0].float().cpu().tolist()

    async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]:
        """Embed an uploaded image file.

        Raises:
            HTTPException: 400 if the upload cannot be read or decoded as an
                image. Failures during inference are NOT converted to 400;
                they propagate as server errors.
        """
        try:
            contents = await image_file.read()
            with Image.open(io.BytesIO(contents)) as raw:
                image = raw.convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e

        return self.get_embedding_from_pil(image)

    def get_embedding_from_base64(self, base64_data: str) -> List[float]:
        """Embed a base64-encoded image, optionally given as a data URL.

        Args:
            base64_data: Raw base64 payload, or a data URL such as
                ``"data:image/png;base64,<payload>"``; everything up to and
                including the first comma is discarded.

        Raises:
            HTTPException: 400 if the payload cannot be decoded as an image.
                Failures during inference propagate as server errors.
        """
        try:
            if "," in base64_data:
                # Split only once so a stray comma in the payload is not lost.
                base64_data = base64_data.split(",", 1)[1]
            # Non-strict decoding: characters outside the base64 alphabet
            # (e.g. line breaks) are ignored, as before.
            image_bytes = base64.b64decode(base64_data)
            with Image.open(io.BytesIO(image_bytes)) as raw:
                image = raw.convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}") from e

        return self.get_embedding_from_pil(image)

    def get_embeddings_from_folder(
        self,
        image_folder: str,
        extensions: Tuple[str, ...] = (".png", ".jpg", ".jpeg"),
    ) -> List[List[float]]:
        """Embed every matching image file directly inside ``image_folder``.

        Files are visited in sorted filename order so the output order is
        deterministic. Files that fail to load or embed are logged and
        skipped, so the result can be shorter than the number of matches.

        NOTE(review): if ``image_folder`` can come from a request, this reads
        arbitrary server directories - confirm callers restrict it.

        Args:
            image_folder: Directory to scan (not recursive).
            extensions: Filename suffixes to accept, matched case-insensitively.

        Returns:
            One embedding per successfully processed image.

        Raises:
            HTTPException: 404 if ``image_folder`` is not an existing directory.
        """
        if not os.path.isdir(image_folder):
            raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}")

        logger = logging.getLogger(__name__)
        suffixes = tuple(ext.lower() for ext in extensions)
        embeddings: List[List[float]] = []

        for image_name in sorted(os.listdir(image_folder)):
            image_path = os.path.join(image_folder, image_name)
            if not image_name.lower().endswith(suffixes) or not os.path.isfile(image_path):
                continue
            try:
                # Context manager closes the underlying file handle promptly.
                with Image.open(image_path) as raw:
                    image = raw.convert("RGB")
                embeddings.append(self.get_embedding_from_pil(image))
            except Exception as e:
                # Best-effort batch: skip unreadable files but record why.
                logger.warning("Error processing %s: %s", image_name, e)

        return embeddings
|
|
|