File size: 3,600 Bytes
b36cb8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
"""
Image embedding service for generating vector embeddings from images
"""
import io
import os
import base64
from typing import List, Tuple
import open_clip
import torch
from fastapi import UploadFile, HTTPException
from PIL import Image
import torch
class ImageEmbeddingModel:
    """Generate image embeddings with an OpenCLIP model.

    The model named by ``model_name`` is created once in the constructor,
    moved to CUDA when available (CPU otherwise) and put into eval mode.
    The ``get_embedding*`` methods accept an image in several forms (PIL
    image, uploaded file, base64 string, folder of files) and return each
    embedding as a plain list of floats.

    Attributes:
        model_name: Name passed to ``open_clip.create_model_and_transforms``.
        device: ``"cuda"`` or ``"cpu"``, chosen at construction time.
        model: The open_clip model, on ``device`` and in eval mode.
        preprocess_train: First transform returned by open_clip; stored but
            not used by this class.
        preprocess_val: Second transform returned by open_clip; applied to
            every image before it is encoded.
    """

    def __init__(self, model_name: str):
        """Load the model identified by ``model_name`` onto the best device."""
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        (
            self.model,
            self.preprocess_train,
            self.preprocess_val,
        ) = self._initialize_model()

    def _initialize_model(self) -> Tuple:
        """Create the model and its transforms, ready for inference.

        Returns:
            ``(model, preprocess_train, preprocess_val)`` with the model
            already moved to ``self.device`` and switched to eval mode.
        """
        # NOTE(review): no ``pretrained=`` argument is passed, so whether
        # trained weights are loaded presumably depends on the form of
        # ``model_name`` -- confirm against the open_clip documentation.
        model, preprocess_train, preprocess_val = (
            open_clip.create_model_and_transforms(self.model_name)
        )
        model.to(self.device)
        model.eval()
        return model, preprocess_train, preprocess_val

    @staticmethod
    def _load_rgb(source) -> Image.Image:
        """Open ``source`` (path or binary file object) as an RGB image.

        The image is opened in a ``with`` block so that a file opened from a
        path is closed as soon as the converted image has been produced.
        """
        with Image.open(source) as opened:
            return opened.convert("RGB")

    def get_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Return the embedding of a PIL image.

        Args:
            image: Image to embed; it is passed through ``preprocess_val``.

        Returns:
            The output of ``encode_image(..., normalize=True)`` for this
            single image, as a list of floats.
        """
        batch = self.preprocess_val(image).unsqueeze(0).to(self.device)
        # Mixed precision is only enabled on CUDA. On CPU the autocast
        # context is created disabled, so inference runs in the model's own
        # dtype without the "unsupported dtype" warning that requesting
        # float32 CPU autocast produces.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.amp.autocast(
            device_type=self.device, enabled=use_amp
        ):
            image_features = self.model.encode_image(batch, normalize=True)
        return image_features.cpu().numpy()[0].tolist()

    async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]:
        """Return the embedding of an uploaded image file.

        Args:
            image_file: Upload whose full contents are read and decoded.

        Returns:
            The embedding as a list of floats.

        Raises:
            HTTPException: Status 400 if reading, decoding or embedding the
                upload fails for any reason.
        """
        try:
            contents = await image_file.read()
            img = self._load_rgb(io.BytesIO(contents))
            return self.get_embedding_from_pil(img)
        except Exception as e:
            # Chain the cause so the original traceback stays available.
            raise HTTPException(
                status_code=400, detail=f"Invalid image: {str(e)}"
            ) from e

    def get_embedding_from_base64(self, base64_data: str) -> List[float]:
        """Return the embedding of a base64-encoded image.

        Args:
            base64_data: Raw base64 text, or a data URI of the form
                ``<header>,<base64 payload>``.

        Returns:
            The embedding as a list of floats.

        Raises:
            HTTPException: Status 400 if decoding or embedding fails for
                any reason.
        """
        try:
            # Handle data URI format: the payload is everything after the
            # first comma.
            if ',' in base64_data:
                base64_data = base64_data.split(',', 1)[1]
            image_bytes = base64.b64decode(base64_data)
            image = self._load_rgb(io.BytesIO(image_bytes))
            return self.get_embedding_from_pil(image)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid base64 image: {str(e)}"
            ) from e

    def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]:
        """Return embeddings for every PNG/JPEG image directly inside a folder.

        Files are selected by extension (``.png``, ``.jpg``, ``.jpeg``,
        case-insensitive) and visited in ``os.listdir`` order. A file that
        cannot be processed is reported with ``print`` and skipped, so the
        result may contain fewer entries than there are matching files.

        Args:
            image_folder: Path of the directory to scan (not recursive).

        Returns:
            One embedding per successfully processed image.

        Raises:
            HTTPException: Status 404 if ``image_folder`` is not an existing
                directory.
        """
        # isdir (rather than exists) also rejects a path that points at a
        # regular file, which os.listdir could not handle.
        if not os.path.isdir(image_folder):
            raise HTTPException(
                status_code=404, detail=f"Folder not found: {image_folder}"
            )

        embeddings = []
        for image_name in os.listdir(image_folder):
            if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                image_path = os.path.join(image_folder, image_name)
                img = self._load_rgb(image_path)
                embeddings.append(self.get_embedding_from_pil(img))
            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                print(f"Error processing {image_name}: {str(e)}")
        return embeddings
|