# app.py — MedSigLIP zero-shot classification Gradio demo
import warnings
warnings.filterwarnings("ignore")
import sys
import gradio as gr
from classifier import MedSigLIPClassifier
from logger import logging
from exception import CustomExceptionHandling
# Initialize the classifier once at module import time so it is shared by
# every request. This may take a while on the first run: the model weights
# are downloaded/loaded by MedSigLIPClassifier's constructor.
classifier = MedSigLIPClassifier()
def infer(image, candidate_labels):
    """Run zero-shot classification of an image against comma-separated labels.

    Args:
        image: PIL image from the Gradio image input (None when nothing was
            uploaded).
        candidate_labels: Comma-separated string of candidate label texts.

    Returns:
        The classifier's prediction for the given image and labels
        (whatever ``classifier.predict`` returns — presumably a
        label-to-probability mapping suitable for ``gr.Label``).

    Raises:
        gr.Error: If no image is uploaded or no non-empty labels are given.
        CustomExceptionHandling: If the classifier fails unexpectedly.
    """
    try:
        if image is None:
            raise gr.Error("No image uploaded")
        # Split labels by comma and drop empty/whitespace-only entries
        labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
        if not labels:
            raise gr.Error("No labels provided")
        # Call the classifier
        logging.info("Calling the classifier")
        return classifier.predict(image, labels)
    except gr.Error:
        # User-facing validation errors must propagate unchanged so Gradio
        # shows them in the UI — do NOT wrap them in the generic handler
        # below (the original code swallowed them into
        # CustomExceptionHandling, hiding the message from the user).
        raise
    except Exception as e:
        # Unexpected failures go through the project's exception wrapper,
        # preserving the original cause for debugging.
        raise CustomExceptionHandling(e, sys) from e
# Gradio interface
# Gradio interface.
# NOTE: the theme must be passed to the gr.Blocks constructor — Blocks.launch()
# has no `theme` parameter, so the original launch(..., theme=...) call raised
# TypeError at startup.
with gr.Blocks(title="MedSigLIP Classifier", theme=gr.themes.Monochrome()) as demo:
    with gr.Column():
        # Static description / links shown above the widgets.
        gr.Markdown(
            """
            # MedSigLIP Zero-Shot Classification
            MedSigLIP is a medical adaptation of SigLIP that embeds medical images and text into a shared space. It uses 400M vision + 400M text encoders, supports 448×448 images and 64 text tokens, and is trained on broad de-identified medical image–text pairs plus natural images.
            Best for medical image tasks like data-efficient classification, zero-shot classification, and semantic retrieval.
            ## Links
            * Model Page: https://huggingface.co/google/medsiglip-448
            * Model Documentation: https://developers.google.com/health-ai-developer-foundations/medsiglip
            * GitHub: https://github.com/google-health/medsiglip
            """
        )
        with gr.Row():
            # Left column: image upload, label input, and run button.
            with gr.Column():
                image_input = gr.Image(
                    type="pil", label="Image", placeholder="Upload an image", height=310
                )
                text_input = gr.Textbox(
                    label="Labels",
                    placeholder="Enter your input labels here (comma separated)",
                )
                run_button = gr.Button("Run")
            # Right column: top-3 predicted labels with scores.
            with gr.Column():
                output_label = gr.Label(label="Output", num_top_classes=3)
        # Clickable examples; lazily cached so the model only runs on demand.
        gr.Examples(
            examples=[
                [
                    "images/sample1.png",
                    "a photo of a leg with no rash, a photo of a leg with a rash",
                ],
                [
                    "images/sample2.png",
                    "a photo of an arm with no rash, a photo of an arm with a rash",
                ],
            ],
            inputs=[image_input, text_input],
            outputs=[output_label],
            fn=infer,
            cache_examples=True,
            cache_mode="lazy",
        )
        # Wire the run button to the inference function.
        run_button.click(
            fn=infer, inputs=[image_input, text_input], outputs=[output_label]
        )
# Launch the app (theme is configured on the Blocks above, not here).
demo.launch(debug=False)