| | from typing import Dict, List, Any |
| | import io |
| | import base64 |
| | from PIL import Image |
| | import torch |
| | import open_clip |
| |
|
# Pick the compute device once at import time: Apple-silicon MPS first, then
# CUDA, falling back to CPU. `device` is the string that EndpointHandler uses
# for the model and for every input tensor.
# `torch.backends.mps` only exists on newer torch builds, so look it up
# defensively instead of crashing with AttributeError on older installs.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
| |
|
| |
|
| |
|
class EndpointHandler():
    """Zero-shot image classifier backed by an open_clip CLIP model.

    The tokenizer, model and image preprocessing transform are loaded once at
    construction. Each call then scores one base64-encoded image against a
    caller-supplied list of class names.
    """

    def __init__(self, path='hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K'):
        """Load the tokenizer, model and preprocessing transform for ``path``.

        Args:
            path: open_clip model identifier, passed unchanged to
                ``open_clip.get_tokenizer`` and
                ``open_clip.create_model_from_pretrained``.
        """
        self.tokenizer = open_clip.get_tokenizer(path)
        self.model, self.preprocess = open_clip.create_model_from_pretrained(path)
        # open_clip's docs note that models are created in train mode. Switch
        # to eval so dropout / stochastic depth / norm layers behave
        # deterministically during inference.
        self.model = self.model.to(device).eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify one image against a list of candidate labels.

        Args:
            data: request payload. Two keys are consumed, i.e. popped from
                ``data``:
                ``classes``: label text(s) passed to the tokenizer;
                ``base64_image``: the encoded image file, base64 string.

        Returns:
            A dict with:
              ``text_probs``: softmax probability for each entry of ``classes``;
              ``image_features``: L2-normalised image embedding;
              ``text_features``: L2-normalised embedding of the FIRST class
              only (this shape is kept for backward compatibility).

        Raises:
            KeyError: if ``classes`` or ``base64_image`` is missing.
        """
        classes = data.pop('classes')
        base64_image = data.pop('base64_image')

        image_bytes = base64.b64decode(base64_image)
        # Force 3-channel RGB so palette, greyscale or RGBA uploads are
        # accepted by the CLIP preprocessing pipeline.
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # unsqueeze(0) adds the batch dimension expected by encode_image.
        image = self.preprocess(pil_image).unsqueeze(0).to(device)
        text = self.tokenizer(classes).to(device)

        with torch.no_grad():
            image_features = self.model.encode_image(image)
            text_features = self.model.encode_text(text)
            # Normalise so the matrix product below yields cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            # Cosine similarities are scaled by a fixed 100.0 before softmax
            # (the factor used in open_clip's zero-shot example).
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # [0] strips the single-image batch dimension; for text_features it
        # selects the first class's embedding only.
        return {
            "text_probs": text_probs.tolist()[0],
            "image_features": image_features.tolist()[0],
            "text_features": text_features.tolist()[0],
        }
| | |
| |
|
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # Local smoke test: classify a single image and print the ranked labels.
    # An image path may be given as the first CLI argument; otherwise the
    # original sample photo is used.
    default_image_path = "/Users/mpa/Library/Mobile Documents/com~apple~CloudDocs/mac/work/zillow-scrapper/properties/76031221/1af0f3c34bff2173ab74ae46a5905d4a-cc_ft_1536.jpg"
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_image_path

    # Read the image before loading the (large) model, so a bad path fails fast.
    with open(image_path, "rb") as f:
        image_data = f.read()
    base64_image = base64.b64encode(image_data).decode("utf-8")

    handler = EndpointHandler()

    # Keep our own reference: the handler pops 'classes' out of `data`.
    classes = ["bedroom", "kitchen", "bathroom", "living room", "dining room", "patio", "backyard", "front yard", "garage", "pool"]
    data = {
        "classes": classes,
        "base64_image": base64_image,
    }
    results = handler(data)

    print('output')
    ranked = sorted(zip(classes, results["text_probs"]), key=lambda pair: pair[1], reverse=True)
    for label, prob in ranked:
        print(f"{label:>12}: {prob:.4f}")