import spaces
import os
import json
from openai import OpenAI
import random
import re
import time
import torch
import gradio as gr
from ProT2I.prot2i_pipeline_sdxl import ProT2IPipeline
from ProT2I.processors import create_controller
from PIL import Image
import numpy as np
import difflib
_HEADER_ = '''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
    <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">Detail++ for SDXL</h1>
</div>

⭐⭐⭐**Tips:**
- ⭐We provide an LLM-based version that decomposes the prompt automatically: just enter a complex prompt with multiple attributes (color, style, etc.) in the `Prompt` textbox.
- ⭐If attributes overflow onto the wrong subjects, you can adaptively increase the `Threshold Value` for mask extraction.
- ⭐You can also adjust the sub-prompts manually in `Decomposed sub-prompts`. When doing so, please use the fixed format as follows:
    - The first line must start with <strong>[original]</strong>.
    - Subsequent lines must start with <strong>[sub-index][subject words]</strong>, where <em>subject words</em> names the subject whose attributes are currently being added.
    - Add one branch <strong>[sub-0][None]</strong> if you want to remove all confusing attributes first.
'''
def create_placeholder_image():
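    """Return a blank white 1024x1024 image, used as a fallback when generation fails."""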
    return Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
def get_diff_string(str1, str2):
    """Return the words present in str2 but not in str1, joined into one string."""
    diff = difflib.ndiff(str1.split(), str2.split())
    added_parts = [word[2:] for word in diff if word.startswith('+ ')]  # keep only added words
    return ' '.join(added_parts)
def process_text(prompt):
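    """
    Ask an OpenAI-compatible chat endpoint to decompose a complex prompt into
    attribute-focused sub-prompts. Returns the raw model reply, which is
    expected to be a JSON object of the form {"variants": [...]}.
    """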
    client = OpenAI(
        base_url="https://a1.aizex.me/v1",
        api_key=os.getenv('api_key'),
    )
    system_prompt = """**Detailed Instruction Prompt for Decomposing Image Descriptions**

You are provided with an original prompt that describes an image containing one or more subjects with detailed attributes (such as colors, clothing, objects, etc.). Your task is to generate a series of sub-prompts that decompose the original prompt into simpler, attribute-focused branches. Follow the steps and rules below exactly:

1. **Output Format Requirements:**
   - **First Line:**
     - Begin with `[original]` followed by a space and then the complete original prompt exactly as provided.
   - **Subsequent Lines:**
     - Each additional line must start with `[sub-index][subject]` where:
       - `sub-index` is a sequential number starting from 0.
       - `subject` is a keyword that indicates which subject's detailed attribute is being highlighted. If the added attribute is global, like the background, use `None`. For the first branch, use `None` as the subject keyword.
   - **Line Separation:**
     - Each sub-prompt must appear on its own line.

2. **Decomposition Rules:**
   - **Generic Version ([sub-0][None]):**
     - Create a version of the prompt with all specific detailed attributes (e.g., color adjectives, style adjectives) removed. This produces a simplified, generic description of the scene.
   - **Attribute-Specific Branches:**
     - For every distinct subject in the original prompt that has a specific attribute, generate a branch that reintroduces that particular attribute while keeping all other subjects in their generic state.
     - Each branch must re-add the attribute detail for only one subject. For example, if the original prompt mentions a "red hat" on one subject and a "blue tracksuit" on another, then:
       - One branch should reintroduce "red" for the hat.
       - Another branch should reintroduce "blue" for the tracksuit.
     - The keyword inside the brackets (after the sub-index) should indicate the subject whose attribute is restored (e.g., `hat`, `tracksuit`, `car`, etc.).

3. **General Guidelines:**
   - **Consistency:**
     - Ensure that the modified sub-prompts are logically consistent with the original description. Only one attribute should be reintroduced per branch, while all other attribute details remain generic.
   - **Precision:**
     - Follow the exact fixed format with square brackets and no extra characters or commentary.
   - **No Extra Text:**
     - Do not include any explanations, notes, or additional commentary in the output. The final output should only contain the sub-prompts as specified.
   - **Output Format:**
     - The output must be a valid JSON object with a single key `variants` that contains a list of sub-prompt strings.

4. **Example to Follow:**

Given the original prompt:
```
a man wearing a red hat and blue tracksuit is standing in front of a green sports car
```
The output should be:
```
{"variants": [
  "[original] a man wearing a red hat and blue tracksuit is standing in front of a green sports car",
  "[sub-0][None] a man wearing a hat and tracksuit is standing in front of a sports car",
  "[sub-1][hat] a man wearing a red hat and tracksuit is standing in front of a sports car",
  "[sub-2][tracksuit] a man wearing a hat and blue tracksuit is standing in front of a sports car",
  "[sub-3][car] a man wearing a hat and tracksuit is standing in front of a green sports car"
]}
```

5. **Another Example to Follow:**

Given the original prompt:
```
In a cyberpunk style city night, a VanGogh-style hound dog is standing in front of a lego-style sports car
```
The output should be:
```
{"variants": [
  "[original] In a cyberpunk style city night, a VanGogh-style hound dog is standing in front of a lego-style sports car",
  "[sub-0][None] In a city night, a hound dog is standing in front of a sports car",
  "[sub-1][None] In a cyberpunk style city night, a hound dog is standing in front of a sports car",
  "[sub-2][hound dog] In a city night, a VanGogh-style hound dog is standing in front of a sports car",
  "[sub-3][car] In a city night, a hound dog is standing in front of a lego-style sports car"
]}
```

6. **Task Summary:**
   - Read the given original prompt and output a set of sub-prompts using the format above.
   - The first sub-prompt ([sub-0][None]) should be the fully generic version.
   - Each subsequent sub-prompt should selectively reintroduce one detailed attribute corresponding to a subject from the original prompt.

Now, use this detailed instruction prompt to generate the decomposed sub-prompts for any provided original image description.
---
"""
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=0.7,
        # Ask for a strict JSON object so json.loads() in generate_image() can parse the reply.
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content
def init_pipeline():
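    """Load the ProT2I SDXL pipeline with RealVisXL V4.0 weights in fp16, on GPU if available."""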
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    pipe = ProT2IPipeline.from_pretrained("SG161222/RealVisXL_V4.0", use_safetensors=True, variant='fp16').to(torch.float16).to(device)
    return pipe, device
def parse_sub_prompts(text):
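    """
    Parse the fixed sub-prompt format into (sub_prompts, mask_words).

    Example input:
        [original] a man wearing a red hat
        [sub-0][None] a man wearing a hat
        [sub-1][hat] a man wearing a red hat
    returns sps = ['a man wearing a red hat', 'a man wearing a hat', 'a man wearing a red hat']
    and nps = [None, 'hat'].
    """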
    lines = [line.strip() for line in text.split('\n') if line.strip()]
    if not lines:
        raise ValueError("Please enter at least one line.")
    sps = []
    nps = []
    if not lines[0].lower().startswith("[original]"):
        raise ValueError("The first line must start with [original], indicating the original description.")
    sps.append(lines[0][len("[original]"):].strip())
    for line in lines[1:]:
        m = re.match(r"^\[sub-\d+\]\[([^\]]+)\]\s*(.*)$", line)
        if not m:
            raise ValueError(f"Sub-prompt format error: {line}\nFormat should be: [sub-index][mask] prompt")
        mask = m.group(1).strip()
        prompt = m.group(2).strip()
        sps.append(prompt)
        nps.append(mask if mask.lower() != "none" else None)
    return sps, nps
# `spaces.GPU` allocates a GPU for this call when the Space runs on ZeroGPU
# hardware (the reason `spaces` is imported above).
@spaces.GPU
def process_image(
    sub_prompts,
    n_self_replace,
    lb_threshold,
    attention_res,
    use_nurse,
    centroid_alignment,
    width,
    height,
    inference_steps,
    seed
):
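    """
    Run the progressive ProT2I editing chain over the parsed sub-prompts.
    Returns (final image, list of intermediate images, status log string).
    """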
    try:
        # gr.Number returns floats; the pipeline expects integer sizes and step counts.
        width, height, inference_steps = int(width), int(height), int(inference_steps)
        sps, nps = parse_sub_prompts(sub_prompts)
        if len(sps) != len(nps) + 1:
            placeholder_image = create_placeholder_image()
            err = f"Error: Number of sub-prompts ({len(sps)}) should be equal to number of masking words + 1 ({len(nps)}+1)"
            return placeholder_image, [placeholder_image] * 3, err
        pipe, device = init_pipeline()
        guidance_scale = 7.5
        n_cross = 0.0
        scale_factor = 1750
        scale_range = (1.0, 0.0)
        angle_loss_weight = 0.0
        max_refinement_steps = [6, 3]
        nursing_thresholds = {0: 26, 1: 25, 2: 24, 3: 23, 4: 22.5, 5: 22}
        save_cross_attention_maps = False
        if seed == -1:
            seed = random.randint(0, 1000000)
        g_cpu = torch.Generator().manual_seed(seed)
        controller_list = []
        run_name = f'runs-SDXL/{time.strftime("%Y%m%d-%H%M%S")}-{seed}'
        # Each editing step operates on a pair of consecutive sub-prompts.
        controller_np = [[sps[i - 1], sps[i]] for i in range(1, len(sps))]
        status_messages = [f"seed: {seed}"]
        for i in range(len(controller_np)):
            controller_kwargs = {
                "edit_type": "refine",
                "local_blend_words": nps[i],
                "n_cross_replace": {"default_": n_cross},
                "n_self_replace": float(n_self_replace),
                "lb_threshold": float(lb_threshold),
                "lb_prompt": [sps[0]] * 2,
                "is_nursing": use_nurse,
                "lb_res": (int(attention_res), int(attention_res)),
                "run_name": run_name,
                "save_map": save_cross_attention_maps,
            }
            if nps[i] is None:
                subject_str = ",".join([str(x) for x in nps if x is not None])
                status_messages.append(f"Remove attributes from {subject_str}")
            else:
                diff_str = get_diff_string(sps[i], sps[i + 1]) if i + 1 < len(sps) else ""
                if diff_str:
                    status_messages.append(f"Add {diff_str} to {nps[i]}")
            controller = create_controller(
                prompts=controller_np[i],
                cross_attention_kwargs=controller_kwargs,
                num_inference_steps=inference_steps,
                tokenizer=pipe.tokenizer,
                device=device,
                attn_res=(int(attention_res), int(attention_res))
            )
            controller_list.append(controller)
        cross_attention_kwargs = {
            "subprompts": sps,
            "set_controller": controller_list,
            "subject_words": nps if use_nurse else None,
            "nursing_threshold": nursing_thresholds,
            "max_refinement_steps": max_refinement_steps,
            "scale_factor": scale_factor,
            "scale_range": scale_range,
            "centroid_alignment": centroid_alignment,
            "angle_loss_weight": angle_loss_weight,
        }
        output = pipe(
            prompt=sps[-1],
            width=width,
            height=height,
            cross_attention_kwargs=cross_attention_kwargs,
            num_inference_steps=inference_steps,
            num_images_per_prompt=1,
            generator=g_cpu,
            attn_res=(int(attention_res), int(attention_res)),
        )[0]
        return output["images"][-1], output["images"], "\n".join(status_messages)
    except Exception as e:
        placeholder_image = create_placeholder_image()
        # `sub_prompts` is a raw string here, so use a fixed-size gallery fallback
        # instead of len(sub_prompts), which would count characters.
        return placeholder_image, [placeholder_image] * 3, f"Error: {str(e)}"
| article = r""" | |
| --- | |
| 📝 **Citation** | |
| <br> | |
| If our work is helpful for your research or applications, please cite us via: | |
| ```bibtex | |
| @misc{chen2025detailtrainingfreeenhancertexttoimage, | |
| title={Detail++: Training-Free Detail Enhancer for Text-to-Image Diffusion Models}, | |
| author={Lifeng Chen and Jiner Wang and Zihao Pan and Beier Zhu and Xiaofeng Yang and Chi Zhang}, | |
| year={2025}, | |
| eprint={2507.17853}, | |
| archivePrefix={arXiv}, | |
| primaryClass={cs.CV}, | |
| url={https://arxiv.org/abs/2507.17853}, | |
| } | |
| ``` | |
| 📧 **Contact** | |
| <br> | |
| If you have any questions, please feel free to open an issue or directly reach us out at <b>1633724411c@gmail.com</b>. | |
| """ | |
def generate_image(
    prompt,
    sub_prompts,
    n_self_replace,
    lb_threshold,
    attention_res,
    use_nurse,
    centroid_alignment,
    width,
    height,
    inference_steps,
    seed
):
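    """
    Gradio callback: if no sub-prompts were entered, ask the LLM to decompose
    the prompt first, then run process_image. Also returns the sub-prompts so
    the textbox is refreshed with whatever was actually used.
    """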
    try:
        if not sub_prompts or sub_prompts.strip() == "":
            gpt_output = process_text(prompt)
            new_sub_prompts = "\n".join(json.loads(gpt_output)["variants"])
        else:
            new_sub_prompts = sub_prompts
        image, gallery_list, status = process_image(
            new_sub_prompts,
            n_self_replace,
            lb_threshold,
            attention_res,
            use_nurse,
            centroid_alignment,
            width,
            height,
            inference_steps,
            seed
        )
        return image, gallery_list, status, new_sub_prompts
    except Exception as e:
        error_message = f"Error: {str(e)}"
        print(error_message)
        return None, [None] * 3, error_message, sub_prompts
# Create the Gradio interface
with gr.Blocks() as iface:
    gr.Markdown(_HEADER_)
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Decomposed sub-prompts", open=False):
                sub_prompts = gr.Textbox(
                    lines=7,
                    label="Sub-prompts",
                    placeholder="Enter sub-prompts, one per line, e.g.\n"
                                "[original]...\n"
                                "[sub-0][None]...\n"
                                "[sub-1][hat]...\n"
                                "..."
                )
            n_self_replace = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.1,
                label="Percentage of self-attention map substitution steps"
            )
            lb_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Threshold for latent mask extraction of subject words"
            )
            attention_res = gr.Number(
                label="Attention map resolution",
                value=32
            )
            with gr.Row():
                use_nurse = gr.Checkbox(
                    label="Use attention nursing",
                    value=True
                )
                centroid_alignment = gr.Checkbox(
                    label="Use centroid alignment",
                    value=False
                )
            with gr.Row():
                width = gr.Number(
                    label="Width",
                    value=1024
                )
                height = gr.Number(
                    label="Height",
                    value=1024
                )
            inference_steps = gr.Number(
                label="Inference steps",
                value=20
            )
            seed = gr.Number(
                label="Seed (-1 for random)",
                value=-1
            )
            generate_btn = gr.Button("Generate Image")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image")
            with gr.Accordion("Progressive Generation Process", open=False):
                gallery = gr.Gallery(
                    label="Generation Steps",
                    show_label=True,
                    elem_id="gallery",
                    columns=2,
                    rows=3,
                    height="auto"
                )
            output_status = gr.Textbox(label="Status", lines=7)
    # Callback: run generate_image and also refresh the sub-prompts textbox
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt,
            sub_prompts,
            n_self_replace,
            lb_threshold,
            attention_res,
            use_nurse,
            centroid_alignment,
            width,
            height,
            inference_steps,
            seed
        ],
        outputs=[output_image, gallery, output_status, sub_prompts]
    )
    # Examples
    example_data = [
        [
            "In a cyberpunk style city night, a cartoon style hound dog is standing in front of a lego style sports car",
            "",
            0.5,
            20,
            5
        ],
        [
            "A sketch-style robot is leaning against an oil-painting style tree",
            "",
            0.5,
            20,
            2
        ],
        [
            "a man wearing a red hat and blue tracksuit is standing in front of a green sports car",
            "",
            0.5,
            20,
            6
        ]
    ]
    gr.Examples(
        examples=example_data,
        inputs=[
            prompt,
            sub_prompts,
            lb_threshold,
            inference_steps,
            seed
        ]
    )
    gr.Markdown(article)
if __name__ == "__main__":
    iface.launch()