File size: 3,419 Bytes
ac2f896
a2cda0b
ac2f896
a2cda0b
db50b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cda0b
db50b86
 
 
a2cda0b
 
06a0dde
 
 
a2cda0b
 
 
 
 
db50b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cda0b
db50b86
 
 
a2cda0b
db50b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cda0b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import warnings

# Suppress all library warnings (e.g. model-loading chatter from
# transformers/torch) before the heavier imports below run.
# NOTE(review): this silences *every* warning, including deprecations —
# consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")

import sys
import gradio as gr
from classifier import MedSigLIPClassifier
from logger import logging
from exception import CustomExceptionHandling

# Initialize the classifier once at module import time.
# This might take a moment to download/load the model
classifier = MedSigLIPClassifier()


def infer(image, candidate_labels):
    """Run zero-shot classification of *image* against comma-separated labels.

    Args:
        image: PIL image from the Gradio image component (``None`` if the
            user submitted without uploading).
        candidate_labels: Comma-separated string of candidate label texts.

    Returns:
        The prediction produced by ``classifier.predict`` (label -> score
        mapping suitable for ``gr.Label``).

    Raises:
        gr.Error: If no image was uploaded or no non-empty labels were given.
        CustomExceptionHandling: For any other unexpected failure.
    """
    try:
        # `is None` rather than truthiness: image objects should not be
        # subjected to a boolean test, and None is the only "absent" value.
        if image is None:
            raise gr.Error("No image uploaded")

        # Split labels by comma and strip whitespace, dropping empty entries.
        labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]

        if not labels:
            raise gr.Error("No labels provided")

        # Call the classifier
        logging.info("Calling the classifier")
        return classifier.predict(image, labels)

    except gr.Error:
        # User-facing validation errors must reach the Gradio UI unwrapped;
        # previously they were swallowed by the generic handler below.
        raise
    except Exception as e:
        # Custom exception handling for unexpected failures only.
        raise CustomExceptionHandling(e, sys) from e


# Gradio interface.
# NOTE(review): `theme` is a gr.Blocks constructor argument, not a
# `launch()` argument — passing it to launch() (as before) had no effect,
# so it is set here on the Blocks instead.
with gr.Blocks(title="MedSigLIP Classifier", theme=gr.themes.Monochrome()) as demo:
    with gr.Column():

        gr.Markdown(
            """
            # MedSigLIP Zero-Shot Classification
            MedSigLIP is a medical adaptation of SigLIP that embeds medical images and text into a shared space. It uses 400M vision + 400M text encoders, supports 448×448 images and 64 text tokens, and is trained on broad de-identified medical image–text pairs plus natural images.
            Best for medical image tasks like data-efficient classification, zero-shot classification, and semantic retrieval.
            
            ## Links
            * Model Page: https://huggingface.co/google/medsiglip-448
            * Model Documentation: https://developers.google.com/health-ai-developer-foundations/medsiglip
            * GitHub: https://github.com/google-health/medsiglip
            """
        )

        with gr.Row():
            # Left column: image input, label text input, and run button.
            with gr.Column():
                # NOTE(review): gr.Image takes no `placeholder` parameter
                # (that belongs to Textbox); the previous kwarg was removed.
                image_input = gr.Image(type="pil", label="Image", height=310)
                text_input = gr.Textbox(
                    label="Labels",
                    placeholder="Enter your input labels here (comma separated)",
                )
                run_button = gr.Button("Run")
            # Right column: top-3 prediction scores plus cached examples.
            with gr.Column():
                output_label = gr.Label(label="Output", num_top_classes=3)

                # Clickable examples; lazily cached so the model only runs
                # the first time each example is selected.
                gr.Examples(
                    examples=[
                        [
                            "images/sample1.png",
                            "a photo of a leg with no rash, a photo of a leg with a rash",
                        ],
                        [
                            "images/sample2.png",
                            "a photo of an arm with no rash, a photo of an arm with a rash",
                        ],
                    ],
                    inputs=[image_input, text_input],
                    outputs=[output_label],
                    fn=infer,
                    cache_examples=True,
                    cache_mode="lazy",
                )

        # Wire the run button to the inference function.
        run_button.click(
            fn=infer, inputs=[image_input, text_input], outputs=[output_label]
        )

# Launch the app (theme is configured on gr.Blocks above).
demo.launch(debug=False)