# EdgeCape Gradio Space (app entry point)
# NOTE(review): recovered from a HuggingFace Spaces page scrape; the Space
# header reported "Runtime error" at capture time.
import os
import subprocess
import sys

# Install the local package in editable mode so its modules (gradio_utils, ...)
# are importable below. Use the *current* interpreter explicitly instead of a
# shell string with bare `python`, which may resolve to a different Python.
# check=False keeps the original best-effort semantics (os.system also ignored
# the exit status).
subprocess.run([sys.executable, 'setup.py', 'develop'], check=False)

import argparse
import json
from pathlib import Path

import gradio as gr
import matplotlib

from gradio_utils.utils import (process_img, get_select_coords, select_skeleton,
                                reset_skeleton, reset_kp, process, update_examples)

LENGTH = 480  # Length of the square area displaying/editing images

# Headless backend: figures are rendered to images, no display server needed.
matplotlib.use('agg')

model_dir = Path('./checkpoints')

parser = argparse.ArgumentParser(description='EdgeCape Demo')
parser.add_argument('--checkpoint',
                    help='checkpoint path',
                    default='ckpt/1shot_split1.pth')
args = parser.parse_args()
checkpoint_path = args.checkpoint
# NOTE(review): hard-coded CUDA device — will fail on CPU-only hardware
# (a plausible cause of the Space's "Runtime error"); confirm the target
# hardware or add a CPU fallback where this value is consumed.
device = 'cuda'
TIMEOUT = 80
# Build the Gradio UI: three support-image panels (upload, keypoint clicking,
# skeleton clicking), a query-image + output row, and clickable examples.
with gr.Blocks() as demo:
    gr.Markdown('''
# We introduce EdgeCape, a novel framework that overcomes these limitations by predicting the graph's edge weights which optimizes localization.
To further leverage structural priors, we propose integrating Markovian Structural Bias, which modulates the self-attention interaction between nodes based on the number of hops between them.
We show that this improves the model’s ability to capture global spatial dependencies.
Evaluated on the MP-100 benchmark, which includes 100 categories and over 20K images,
EdgeCape achieves state-of-the-art results in the 1-shot setting and leads among similar-sized methods in the 5-shot setting, significantly improving keypoint localization accuracy.
### [Paper](https://arxiv.org/pdf/2411.16665) | [Project Page](https://orhir.github.io/edge_cape/)
## Instructions
1. Upload an image from the same category as the object you want to pose.
2. Mark keypoints on the middle image. When finished - press 'Confirm Clicked Points'.
3. Mark limbs on the right image.
4. Upload an image of the object you want to pose to the query image (**bottom**).
5. Click **Evaluate** to pose the query image.
''')

    # Shared per-session state threaded through all callbacks.
    global_state = gr.State({
        "images": {},
        "points": [],            # clicked keypoints on the support image
        "skeleton": [],          # (i, j) keypoint-index pairs forming limbs
        "prev_point": None,
        "curr_type_point": "start",
        "load_example": False,
    })

    with gr.Row():
        # Upload & Preprocess Image Column
        with gr.Column():
            gr.Markdown(
                """<p style="text-align: center; font-size: 20px">Upload & Preprocess Image</p>"""
            )
            support_image = gr.Image(
                height=LENGTH,
                width=LENGTH,
                type="pil",
                image_mode="RGB",
                label="Support Image",
                show_label=True,
                interactive=True,
            )
        # Click Points Column
        with gr.Column():
            gr.Markdown(
                """<p style="text-align: center; font-size: 20px">Click Points</p>"""
            )
            kp_support_image = gr.Image(
                type="pil",
                label="Keypoints Image",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=False,
                show_fullscreen_button=False,
            )
            with gr.Row():
                confirm_kp_button = gr.Button("Confirm Clicked Points", scale=3)
            with gr.Row():
                undo_kp_button = gr.Button("Undo Clicked Points", scale=3)
        # Editing Results Column
        with gr.Column():
            gr.Markdown(
                """<p style="text-align: center; font-size: 20px">Click Skeleton</p>"""
            )
            skel_support_image = gr.Image(
                type="pil",
                label="Skeleton Image",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=False,
                show_fullscreen_button=False,
            )
            with gr.Row():
                # Empty placeholder row keeps this column's button vertically
                # aligned with the two-button middle column.
                pass
            with gr.Row():
                undo_skel_button = gr.Button("Undo Skeleton")

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """<p style="text-align: center; font-size: 20px">Query Image</p>"""
            )
            query_image = gr.Image(
                type="pil",
                image_mode="RGB",
                label="Query Image",
                show_label=True,
                interactive=True,
            )
        with gr.Column():
            gr.Markdown(
                """<p style="text-align: center; font-size: 20px">Output</p>"""
            )
            output_img = gr.Plot(label="Output Image")

    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")

    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
        # Hidden textbox carrying the JSON-encoded points/skeleton of each
        # example row into update_examples.
        example_null = gr.Textbox(type='text', visible=False)
    with gr.Row():
        examples = gr.Examples([
            ['examples/dog2.png',
             'examples/dog1.png',
             json.dumps({
                 'points': [(232, 200), (312, 204), (228, 264), (316, 472), (316, 616), (296, 868), (412, 872),
                            (416, 624), (604, 608), (648, 860), (764, 852), (696, 608), (684, 432)],
                 'skeleton': [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5),
                              (3, 7), (7, 6), (3, 12), (12, 8), (8, 9),
                              (12, 11), (11, 10)],
             })
             ],
            ['examples/sofa1.jpg',
             'examples/sofa2.png',
             json.dumps({'points': [[272, 561], [193, 482], [339, 460], [445, 530], [264, 369], [203, 318], [354, 300],
                                    [457, 341], [345, 63], [187, 68]],
                         'skeleton': [[0, 4], [1, 5], [2, 6], [3, 7], [7, 6], [6, 5],
                                      [5, 4], [4, 7], [5, 9], [9, 8], [8, 6]],
                         })],
            ['examples/person1.jpeg',
             'examples/person2.jpeg',
             json.dumps({
                 'points': [[322, 488], [431, 486], [526, 644], [593, 486], [697, 492], [407, 728],
                            [522, 726], [625, 737], [515, 798]],
                 'skeleton': [[0, 1], [1, 3], [3, 4], [1, 2], [2, 3], [5, 6], [6, 7], [7, 8], [8, 5]],
             })]
        ],
            inputs=[support_image, query_image, example_null],
            outputs=[support_image, kp_support_image, skel_support_image, query_image, global_state],
            fn=update_examples,
            run_on_click=True,
            examples_per_page=5,
            cache_examples=False,
        )

    # --- Event wiring -------------------------------------------------------
    # Uploading a support image preprocesses it and shows it on the keypoint canvas.
    support_image.upload(process_img,
                         inputs=[support_image, global_state],
                         outputs=[kp_support_image, global_state])
    # Clicking the keypoint canvas records a point and redraws the overlay.
    kp_support_image.select(get_select_coords,
                            inputs=[global_state],
                            outputs=[global_state, kp_support_image],
                            queue=False)
    # Confirming points (or undoing the skeleton) resets the skeleton canvas.
    confirm_kp_button.click(reset_skeleton,
                            inputs=global_state,
                            outputs=skel_support_image)
    undo_kp_button.click(reset_kp,
                         inputs=global_state,
                         outputs=[kp_support_image, skel_support_image])
    undo_skel_button.click(reset_skeleton,
                           inputs=global_state,
                           outputs=skel_support_image)
    # Clicking the skeleton canvas links keypoints into limbs.
    skel_support_image.select(select_skeleton,
                              inputs=[global_state],
                              outputs=[global_state, skel_support_image])
    # Evaluate runs the pose-estimation model on the query image.
    eval_btn.click(fn=process,
                   inputs=[query_image, global_state],
                   outputs=[output_img])
if __name__ == "__main__":
    # `args` was already parsed at module level; reuse it instead of calling
    # parser.parse_args() a second time (same Namespace repr, one parse).
    print("Start app", args)
    # Close any Gradio servers left over from previous runs before launching.
    gr.close_all()
    demo.launch(show_api=False)