import gradio as gr
import torch
import spaces  # Hugging Face Spaces helper (provides the @spaces.GPU decorator)
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image
# Load models
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    if input_boxes is not None:
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores
    return masks, scores
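
# Usage sketch (not part of the app flow; illustrative values only): the SAM and
# SAM-HQ processors expect point prompts nested per image, e.g. a single point
# is passed as input_points=[[[x, y]]].
#
#   raw_image = Image.open("some_image.png")  # hypothetical path
#   masks, scores = predict_masks_and_scores(
#       sam_hq_model, sam_hq_processor, raw_image,
#       input_points=[[[450, 600]]],  # one point at (x=450, y=600)
#   )
#   # masks[0] holds the post-processed masks; scores are the predicted IoU scores.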
def encode_pil_to_base64(pil_image):
    buffer = BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def compare_images_points_and_masks(user_image, input_boxes, input_points):
    # If the uploaded image matches a known example (by size), predict on the
    # original, un-annotated image instead of the annotated preview.
    for example_path, example_data in example_data_map.items():
        if example_data["size"] == list(user_image.size):
            user_image = Image.open(example_data["original_image_path"])
    # The Dataframe components arrive as pandas DataFrames; keep only non-empty
    # rows and cast coordinates to int.
    input_boxes = input_boxes.values.tolist()
    input_points = input_points.values.tolist()
    input_boxes = [[[int(coord) for coord in box] for box in input_boxes if any(box)]]
    input_points = [[[int(coord) for coord in point] for point in input_points if any(point)]]
    input_boxes = input_boxes if input_boxes[0] else None
    input_points = input_points if input_points[0] else None
    if input_boxes is None and input_points is None:
        # Guard: without this, the branches below would leave img1_b64/img2_b64 undefined.
        raise gr.Error("Please provide at least one point or box prompt.")
    sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    if input_boxes and input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM_HQ')
    elif input_boxes:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], None, model_name='SAM_HQ')
    elif input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], None, input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], None, input_points[0], model_name='SAM_HQ')
    print('user_image', user_image)
    print("img1_b64", img1_b64)
    print("img2_b64", img2_b64)
    html_code = f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
      <div style="position: relative; width: 100%;">
        <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
        <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
          <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%;">
        </div>
        <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
      </div>
      <input type="range" min="0" max="100" value="0"
        style="width:100%; margin-top: 10px;"
        oninput="
          const val = this.value;
          const container = document.getElementById('imageCompareContainer');
          const width = container.offsetWidth;
          const clipValue = 100 - val;
          document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
          document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
        ">
    </div>
    """
    return html_code
def load_examples(json_file="examples.json"):
    with open(json_file, "r") as f:
        examples = json.load(f)
    return examples

examples = load_examples()
example_paths = [example["image_path"] for example in examples]
example_data_map = {
    example["image_path"]: {
        "original_image_path": example["original_image_path"],
        "points": example["points"],
        "boxes": example["boxes"],
        "size": example["size"],
    }
    for example in examples
}
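
# load_examples() assumes examples.json provides, for each example, exactly the
# fields read above. A hypothetical entry (illustrative paths and coordinates only):
#
# [
#   {
#     "image_path": "examples/truck_annotated.png",
#     "original_image_path": "examples/truck.png",
#     "points": [[500, 375]],
#     "boxes": [[75, 275, 1725, 850]],
#     "size": [1800, 1200]
#   }
# ]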
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

with gr.Blocks(theme=theme, title="Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it.")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")

    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Example image (click below to load)",
            interactive=False,
            height=500,
            show_label=True
        )
        gr.Examples(
            examples=example_paths,
            inputs=[image_input],
            label="Click an example to try",
        )

    result_html = gr.HTML(elem_id="result-html")

    with gr.Row():
        points_input = gr.Dataframe(
            headers=["x", "y"],
            label="Points",
            datatype=["number", "number"],
            col_count=(2, "fixed")
        )
        boxes_input = gr.Dataframe(
            headers=["x0", "y0", "x1", "y1"],
            label="Boxes",
            datatype=["number", "number", "number", "number"],
            col_count=(4, "fixed")
        )

    def on_image_change(image):
        # When an example is loaded, pre-fill its point and box prompts (matched by image size).
        for example_path, example_data in example_data_map.items():
            print(image.size)
            if example_data["size"] == list(image.size):
                return example_data["points"], example_data["boxes"]
        return [], []

    image_input.change(
        fn=on_image_change,
        inputs=[image_input],
        outputs=[points_input, boxes_input]
    )

    compare_button = gr.Button("Compare points and masks")
    compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html)
    gr.HTML("""
    <style>
        #result-html {
            min-height: 500px;
            border: 1px solid #ccc;
            padding: 10px;
            box-sizing: border-box;
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
        }
    </style>
    """)

demo.launch()