"""
Image embedding service for generating vector embeddings from images
"""

import base64
import contextlib
import io
import os
from typing import List, Tuple

import open_clip
import torch
from fastapi import UploadFile, HTTPException
from PIL import Image


class ImageEmbeddingModel:
    """Class for handling image embedding using CLIP model"""
    
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess_train, self.preprocess_val = self._initialize_model()
        
    def _initialize_model(self) -> Tuple:
        """Initialize the CLIP model and preprocessing transforms for image embeddings"""
        model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(self.model_name)
        model.to(self.device)
        model.eval()
        return model, preprocess_train, preprocess_val
    
    def get_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Get embedding from a PIL image"""
        processed_image = self.preprocess_val(image).unsqueeze(0).to(self.device)

        if self.device == 'cuda':
            # Mixed precision is only used on GPU
            autocast_context = torch.amp.autocast(device_type='cuda')
        else:
            # On CPU, skip autocast so inference stays in float32
            autocast_context = contextlib.nullcontext()
        with torch.no_grad(), autocast_context:
            image_features = self.model.encode_image(processed_image, normalize=True)

        return image_features.cpu().numpy()[0].tolist()
    
    async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]:
        """Get embedding from uploaded image file"""
        try:
            contents = await image_file.read()
            img = Image.open(io.BytesIO(contents)).convert("RGB")
            return self.get_embedding_from_pil(img)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}")
    
    def get_embedding_from_base64(self, base64_data: str) -> List[float]:
        """Get embedding from base64 encoded image"""
        try:
            # Handle data URI format
            if ',' in base64_data:
                base64_data = base64_data.split(',')[1]
                
            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            return self.get_embedding_from_pil(image)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}")
    
    def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]:
        """Get embeddings from all images in a folder"""
        embeddings = []

        if not os.path.exists(image_folder):
            raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}")

        # Sort for deterministic ordering, since the returned embeddings carry no filenames
        for image_name in sorted(os.listdir(image_folder)):
            if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                try:
                    image_path = os.path.join(image_folder, image_name)
                    img = Image.open(image_path).convert("RGB")
                    embeddings.append(self.get_embedding_from_pil(img))
                except Exception as e:
                    print(f"Error processing {image_name}: {str(e)}")

        return embeddings
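

# --- Usage sketch (illustrative only, not part of the service) ---
# A minimal example of how this class might be exercised, assuming the
# open_clip model name "ViT-B-32" is available; the commented folder path
# is a hypothetical placeholder, not something defined by this module.
if __name__ == "__main__":
    embedder = ImageEmbeddingModel("ViT-B-32")

    # Single image -> one embedding vector
    sample = Image.new("RGB", (224, 224), color="white")
    vector = embedder.get_embedding_from_pil(sample)
    print(f"Embedding dimension: {len(vector)}")

    # Batch over a folder of .png/.jpg/.jpeg files (placeholder path)
    # vectors = embedder.get_embeddings_from_folder("./images")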