|
|
""" |
|
|
Embedding Service for generating image embeddings |
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import List, Dict, Any |
|
|
from PIL import Image |
|
|
import io |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import CLIPProcessor, CLIPModel |
|
|
|
|
|
|
|
|
class ImageEmbeddingModel:
    """Generate L2-normalized image embeddings with a CLIP model.

    The model and processor are loaded once in ``__init__`` and reused for
    every call. The model is placed on CUDA when available, otherwise on CPU.
    """

    # File extensions (compared case-insensitively) that
    # get_embeddings_from_folder treats as images.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'})

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """Load the CLIP model and its matching processor.

        Args:
            model_name: Model identifier passed to ``from_pretrained`` for both
                the model and the processor.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def generate_embedding(self, image_data: bytes) -> List[float]:
        """Generate an embedding for an encoded image.

        Args:
            image_data: Encoded image bytes in any format Pillow can decode.

        Returns:
            The unit-length (L2-normalized) image embedding as a list of floats.

        Raises:
            Whatever Pillow raises for undecodable data (e.g.
            ``PIL.UnidentifiedImageError``).
        """
        # Open as a context manager so the decoder is released once we hold
        # the converted copy; convert("RGB") normalizes grayscale, RGBA and
        # palette images to the 3-channel input the processor is given.
        with Image.open(io.BytesIO(image_data)) as source:
            image = source.convert("RGB")
        return self.generate_embedding_from_pil(image)

    def generate_embedding_from_pil(self, image: "Image.Image") -> List[float]:
        """Generate an embedding for an in-memory PIL image.

        Args:
            image: PIL image to embed.

        Returns:
            The image embedding scaled to unit L2 norm, as a list of floats.
            An all-zero feature vector is returned unchanged rather than
            producing NaNs from a division by zero.
        """
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # A single image was processed, so take the first (only) row.
        embedding = image_features.cpu().numpy()[0]
        norm = np.linalg.norm(embedding)
        # Guard the normalization: dividing a zero vector by 0 yields NaNs.
        if norm > 0:
            embedding = embedding / norm
        return embedding.tolist()

    def get_embeddings_from_folder(self, folder_path: str) -> Dict[str, Any]:
        """Generate embeddings for every image file directly inside a folder.

        Only regular files whose extension is in ``IMAGE_EXTENSIONS`` are
        processed; subdirectories are not traversed. Files are visited in
        sorted filename order so the result ordering is deterministic.

        Args:
            folder_path: Path to the folder containing images.

        Returns:
            A dict mapping each image filename to either
            ``{"embedding": [...], "status": "success"}`` or
            ``{"error": "<message>", "status": "failed"}``. If ``folder_path``
            is not an existing directory, a single ``{"error": "<message>"}``
            dict is returned instead.
        """
        if not os.path.isdir(folder_path):
            if os.path.exists(folder_path):
                return {"error": f"{folder_path} is not a directory"}
            return {"error": f"Folder {folder_path} does not exist"}

        results: Dict[str, Any] = {}
        for filename in sorted(os.listdir(folder_path)):
            if os.path.splitext(filename)[1].lower() not in self.IMAGE_EXTENSIONS:
                continue
            file_path = os.path.join(folder_path, filename)
            # Skip directories or other non-regular entries that merely carry
            # an image-like name.
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, 'rb') as f:
                    image_data = f.read()
                embedding = self.generate_embedding(image_data)
            except Exception as e:
                # Best effort by design: one unreadable or corrupt file must
                # not abort the whole folder, so record the failure per file.
                results[filename] = {
                    "error": str(e),
                    "status": "failed",
                }
            else:
                results[filename] = {
                    "embedding": embedding,
                    "status": "success",
                }

        return results