# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
| import os | |
| import gradio as gr | |
| from utils.ui_utils import get_points, undo_points | |
| from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag | |
| from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen | |
| LENGTH=480 # length of the square area displaying/editing images | |
| with gr.Blocks() as demo: | |
| # layout definition | |
| with gr.Row(): | |
| gr.Markdown(""" | |
| # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435) | |
| """) | |
| # UI components for editing real images | |
| with gr.Tab(label="Editing Real Image"): | |
| mask = gr.State(value=None) # store mask | |
| selected_points = gr.State([]) # store points | |
| original_image = gr.State(value=None) # store original input image | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""") | |
| canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask", | |
| show_label=True, height=LENGTH, width=LENGTH) # for mask painting | |
| train_lora_button = gr.Button("Train LoRA") | |
| with gr.Column(): | |
| gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""") | |
| input_image = gr.Image(type="numpy", label="Click Points", | |
| show_label=True, height=LENGTH, width=LENGTH) # for points clicking | |
| undo_button = gr.Button("Undo point") | |
| with gr.Column(): | |
| gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""") | |
| output_image = gr.Image(type="numpy", label="Editing Results", | |
| show_label=True, height=LENGTH, width=LENGTH) | |
| with gr.Row(): | |
| run_button = gr.Button("Run") | |
| clear_all_button = gr.Button("Clear All") | |
| # general parameters | |
| with gr.Row(): | |
| prompt = gr.Textbox(label="Prompt") | |
| lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path") | |
| lora_status_bar = gr.Textbox(label="display LoRA training status") | |
| # algorithm specific parameters | |
| with gr.Tab("Drag Config"): | |
| with gr.Row(): | |
| n_pix_step = gr.Number( | |
| value=40, | |
| label="number of pixel steps", | |
| info="Number of gradient descent (motion supervision) steps on latent.", | |
| precision=0) | |
| lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas") | |
| # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0) | |
| inversion_strength = gr.Slider(0, 1.0, | |
| value=0.75, | |
| label="inversion strength", | |
| info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") | |
| latent_lr = gr.Number(value=0.01, label="latent lr") | |
| start_step = gr.Number(value=0, label="start_step", precision=0, visible=False) | |
| start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False) | |
| with gr.Tab("Base Model Config"): | |
| with gr.Row(): | |
| local_models_dir = 'local_pretrained_models' | |
| local_models_choice = \ | |
| [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] | |
| model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", | |
| label="Diffusion Model Path", | |
| choices=[ | |
| "runwayml/stable-diffusion-v1-5", | |
| ] + local_models_choice | |
| ) | |
| vae_path = gr.Dropdown(value="default", | |
| label="VAE choice", | |
| choices=["default", | |
| "stabilityai/sd-vae-ft-mse"] + local_models_choice | |
| ) | |
| with gr.Tab("LoRA Parameters"): | |
| with gr.Row(): | |
| lora_step = gr.Number(value=200, label="LoRA training steps", precision=0) | |
| lora_lr = gr.Number(value=0.0002, label="LoRA learning rate") | |
| lora_rank = gr.Number(value=16, label="LoRA rank", precision=0) | |
| # UI components for editing generated images | |
| with gr.Tab(label="Editing Generated Image"): | |
| mask_gen = gr.State(value=None) # store mask | |
| selected_points_gen = gr.State([]) # store points | |
| original_image_gen = gr.State(value=None) # store the diffusion-generated image | |
| intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""") | |
| canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask", | |
| show_label=True, height=LENGTH, width=LENGTH) # for mask painting | |
| gen_img_button = gr.Button("Generate Image") | |
| with gr.Column(): | |
| gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""") | |
| input_image_gen = gr.Image(type="numpy", label="Click Points", | |
| show_label=True, height=LENGTH, width=LENGTH) # for points clicking | |
| undo_button_gen = gr.Button("Undo point") | |
| with gr.Column(): | |
| gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""") | |
| output_image_gen = gr.Image(type="numpy", label="Editing Results", | |
| show_label=True, height=LENGTH, width=LENGTH) | |
| with gr.Row(): | |
| run_button_gen = gr.Button("Run") | |
| clear_all_button_gen = gr.Button("Clear All") | |
| # general parameters | |
| with gr.Row(): | |
| pos_prompt_gen = gr.Textbox(label="Positive Prompt") | |
| neg_prompt_gen = gr.Textbox(label="Negative Prompt") | |
| with gr.Tab("Generation Config"): | |
| with gr.Row(): | |
| local_models_dir = 'local_pretrained_models' | |
| local_models_choice = \ | |
| [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] | |
| model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", | |
| label="Diffusion Model Path", | |
| choices=[ | |
| "runwayml/stable-diffusion-v1-5", | |
| "gsdf/Counterfeit-V2.5", | |
| "emilianJR/majicMIX_realistic", | |
| "SG161222/Realistic_Vision_V2.0", | |
| "stablediffusionapi/landscapesupermix", | |
| "huangzhe0803/ArchitectureRealMix", | |
| "stablediffusionapi/interiordesignsuperm" | |
| ] + local_models_choice | |
| ) | |
| vae_path_gen = gr.Dropdown(value="default", | |
| label="VAE choice", | |
| choices=["default", | |
| "stabilityai/sd-vae-ft-mse"] + local_models_choice | |
| ) | |
| lora_path_gen = gr.Textbox(value="", label="LoRA path") | |
| gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0) | |
| height = gr.Number(value=512, label="Height", precision=0) | |
| width = gr.Number(value=512, label="Width", precision=0) | |
| guidance_scale = gr.Number(value=7.5, label="CFG Scale") | |
| scheduler_name_gen = gr.Dropdown( | |
| value="DDIM", | |
| label="Scheduler", | |
| choices=[ | |
| "DDIM", | |
| "DPM++2M", | |
| "DPM++2M_karras" | |
| ] | |
| ) | |
| n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0) | |
| with gr.Tab(label="Drag Config"): | |
| with gr.Row(): | |
| n_pix_step_gen = gr.Number( | |
| value=40, | |
| label="Number of Pixel Steps", | |
| info="Number of gradient descent (motion supervision) steps on latent.", | |
| precision=0) | |
| lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas") | |
| # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0) | |
| inversion_strength_gen = gr.Slider(0, 1.0, | |
| value=0.75, | |
| label="Inversion Strength", | |
| info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") | |
| latent_lr_gen = gr.Number(value=0.01, label="latent lr") | |
| start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False) | |
| start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False) | |
| # Add a checkbox for users to select if they want a GIF of the process | |
| with gr.Row(): | |
| create_gif_checkbox = gr.Checkbox(label="create_GIF", value=False) | |
| create_tracking_point_checkbox = gr.Checkbox(label="create_tracking_point", value=False) | |
| gif_interval = gr.Number(value=10, label="interval_GIF", precision=0, info="The interval of the GIF, i.e. the number of steps between each frame of the GIF.") | |
| gif_fps = gr.Number(value=1, label="fps_GIF", precision=0, info="The fps of the GIF, i.e. the number of frames per second of the GIF.") | |
| # event definition | |
| # event for dragging user-input real image | |
| canvas.edit( | |
| store_img, | |
| [canvas], | |
| [original_image, selected_points, input_image, mask] | |
| ) | |
| input_image.select( | |
| get_points, | |
| [input_image, selected_points], | |
| [input_image], | |
| ) | |
| undo_button.click( | |
| undo_points, | |
| [original_image, mask], | |
| [input_image, selected_points] | |
| ) | |
| train_lora_button.click( | |
| train_lora_interface, | |
| [original_image, | |
| prompt, | |
| model_path, | |
| vae_path, | |
| lora_path, | |
| lora_step, | |
| lora_lr, | |
| lora_rank], | |
| [lora_status_bar] | |
| ) | |
| run_button.click( | |
| run_drag, | |
| [original_image, | |
| input_image, | |
| mask, | |
| prompt, | |
| selected_points, | |
| inversion_strength, | |
| lam, | |
| latent_lr, | |
| n_pix_step, | |
| model_path, | |
| vae_path, | |
| lora_path, | |
| start_step, | |
| start_layer, | |
| create_gif_checkbox, | |
| gif_interval, | |
| ], | |
| [output_image] | |
| ) | |
| clear_all_button.click( | |
| clear_all, | |
| [gr.Number(value=LENGTH, visible=False, precision=0)], | |
| [canvas, | |
| input_image, | |
| output_image, | |
| selected_points, | |
| original_image, | |
| mask] | |
| ) | |
| # event for dragging generated image | |
| canvas_gen.edit( | |
| store_img_gen, | |
| [canvas_gen], | |
| [original_image_gen, selected_points_gen, input_image_gen, mask_gen] | |
| ) | |
| input_image_gen.select( | |
| get_points, | |
| [input_image_gen, selected_points_gen], | |
| [input_image_gen], | |
| ) | |
| gen_img_button.click( | |
| gen_img, | |
| [ | |
| gr.Number(value=LENGTH, visible=False, precision=0), | |
| height, | |
| width, | |
| n_inference_step_gen, | |
| scheduler_name_gen, | |
| gen_seed, | |
| guidance_scale, | |
| pos_prompt_gen, | |
| neg_prompt_gen, | |
| model_path_gen, | |
| vae_path_gen, | |
| lora_path_gen, | |
| ], | |
| [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen] | |
| ) | |
| undo_button_gen.click( | |
| undo_points, | |
| [original_image_gen, mask_gen], | |
| [input_image_gen, selected_points_gen] | |
| ) | |
| run_button_gen.click( | |
| run_drag_gen, | |
| [ | |
| n_inference_step_gen, | |
| scheduler_name_gen, | |
| original_image_gen, # the original image generated by the diffusion model | |
| input_image_gen, # image with clicking, masking, etc. | |
| intermediate_latents_gen, | |
| guidance_scale, | |
| mask_gen, | |
| pos_prompt_gen, | |
| neg_prompt_gen, | |
| selected_points_gen, | |
| inversion_strength_gen, | |
| lam_gen, | |
| latent_lr_gen, | |
| n_pix_step_gen, | |
| model_path_gen, | |
| vae_path_gen, | |
| lora_path_gen, | |
| start_step_gen, | |
| start_layer_gen, | |
| create_gif_checkbox, | |
| create_tracking_point_checkbox, | |
| gif_interval, | |
| gif_fps | |
| ], | |
| [output_image_gen] | |
| ) | |
| clear_all_button_gen.click( | |
| clear_all_gen, | |
| [gr.Number(value=LENGTH, visible=False, precision=0)], | |
| [canvas_gen, | |
| input_image_gen, | |
| output_image_gen, | |
| selected_points_gen, | |
| original_image_gen, | |
| mask_gen, | |
| intermediate_latents_gen, | |
| ] | |
| ) | |
| demo.queue().launch(share=True, debug=True) | |