import os

import numpy as np
import tensorflow as tf
import torch
from dotenv import load_dotenv
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Load environment variables from a .env file, if one is present.
load_dotenv()

# Hugging Face access token for the gated MedSigLIP model (None if unset).
access_token = os.environ.get("ACCESS_TOKEN")


class MedSigLIPClassifier:
    """Zero-shot classifier for medical images built on MedSigLIP."""

    def __init__(self, model_id: str = "google/medsiglip-448"):
        """Load the model and processor for the given model ID."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModel.from_pretrained(model_id, token=access_token).to(
            self.device
        )
        self.processor = AutoProcessor.from_pretrained(
            model_id, use_fast=True, token=access_token
        )

    def _resize(self, image: Image.Image) -> Image.Image:
        """Resize to 448x448 with TensorFlow bilinear resizing (no antialiasing)
        to match the preprocessing used when MedSigLIP was trained."""
        resized = tf.image.resize(
            images=np.asarray(image),  # HxWx3 uint8 array from the PIL image
            size=[448, 448],
            method="bilinear",
            antialias=False,
        )
        # tf.image.resize returns float32; cast back to uint8 for PIL.
        return Image.fromarray(resized.numpy().astype(np.uint8))

    def predict(
        self, image: Image.Image, candidate_labels: list[str]
    ) -> dict[str, float]:
        """Return a probability for each candidate label for the given image."""
        # The model expects 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        resized_image = self._resize(image)

        # Tokenize the labels and preprocess the image. SigLIP-family models
        # were trained with text padded to max_length, so pad the same way.
        inputs = self.processor(
            text=candidate_labels,
            images=resized_image,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # logits_per_image has shape (num_images, num_labels); softmax over the
        # label axis turns each image's scores into a distribution over labels.
        logits_per_image = outputs.logits_per_image
        probs = torch.softmax(logits_per_image, dim=1)

        # A single image was passed, so take row 0 and map labels to probs.
        probs_list = probs[0].tolist()
        return dict(zip(candidate_labels, probs_list))
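

# ---------------------------------------------------------------------------
# Illustrative sketch only: `_softmax_label_probs` is a hypothetical helper
# that MedSigLIPClassifier does not use. It restates the post-processing at
# the end of `predict` with NumPy alone, so the zero-shot scoring step can be
# checked without downloading the model. It assumes `logits` holds one
# image's logits, one per label, in the same order as `labels` (i.e. the
# contents of `logits_per_image[0]`).
# ---------------------------------------------------------------------------
def _softmax_label_probs(logits, labels):
    """Hypothetical helper: softmax over labels, returned as {label: prob}."""
    import numpy as np  # local import keeps this sketch self-contained

    scores = np.asarray(logits, dtype=np.float64)
    if scores.shape != (len(labels),):
        raise ValueError("expected exactly one logit per label")
    # Subtracting the max before exponentiating avoids overflow and does not
    # change the softmax result.
    exp_scores = np.exp(scores - scores.max())
    probs = exp_scores / exp_scores.sum()
    return {label: float(p) for label, p in zip(labels, probs)}


# Example: equal logits give a uniform distribution over the labels.
# _softmax_label_probs([2.0, 2.0], ["normal", "pneumonia"])
# -> {"normal": 0.5, "pneumonia": 0.5}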