import spaces
import os
import json
from openai import OpenAI
import random
import re
import time
import torch
import gradio as gr
from ProT2I.prot2i_pipeline_sdxl import ProT2IPipeline
from ProT2I.processors import create_controller
from PIL import Image
import numpy as np
import difflib
_HEADER_ = '''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">Detail++ for SDXL</h1>
</div>
⭐⭐⭐**Tips:**
- ⭐An LLM can decompose the prompt automatically: just enter a complex prompt with multiple attributes (color, style, etc.) in the `Prompt` textbox and leave the sub-prompts empty.
- ⭐If attributes spill onto the wrong subjects (attribute overflow), increase the `Threshold Value` used for mask extraction.
- ⭐You can also edit the sub-prompts manually in `Decomposed sub-prompts`, using the fixed format below (a full example follows this list):
  - The first line must start with <strong>[original]</strong>.
  - Subsequent lines must start with <strong>[sub-index][subject words]</strong>, where <em>subject words</em> names the subject whose attribute that branch adds.
  - Add a <strong>[sub-0][None]</strong> branch if you want to first remove all potentially confusing attributes.
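For example:
```
[original] a man wearing a red hat and blue tracksuit is standing in front of a green sports car
[sub-0][None] a man wearing a hat and tracksuit is standing in front of a sports car
[sub-1][hat] a man wearing a red hat and tracksuit is standing in front of a sports car
[sub-2][tracksuit] a man wearing a hat and blue tracksuit is standing in front of a sports car
[sub-3][car] a man wearing a hat and tracksuit is standing in front of a green sports car
```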
'''
def create_placeholder_image():
    # Blank white 1024x1024 RGB image, shown in place of a result when generation fails.
    return Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
def get_diff_string(str1, str2):
"""
str1 and str2 are two strings.
This function returns the difference between the two strings as a string.
"""
diff = difflib.ndiff(str1.split(), str2.split())
added_parts = [word[2:] for word in diff if word.startswith('+ ')] # get added parts
return ' '.join(added_parts)
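# For example, get_diff_string("a hat", "a red hat") returns "red";
# process_image uses this to report which attribute each editing branch adds.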
def process_text(prompt):
client = OpenAI(
base_url="https://a1.aizex.me/v1",
        api_key=os.getenv('api_key'),
)
system_prompt = """**Detailed Instruction Prompt for Decomposing Image Descriptions**
You are provided with an original prompt that describes an image containing one or more subjects with detailed attributes (such as colors, clothing, objects, etc.). Your task is to generate a series of sub-prompts that decompose the original prompt into simpler, attribute-focused branches. Follow the steps and rules below exactly:
1. **Output Format Requirements:**
- **First Line:**
- Begin with `[original]` followed by a space and then the complete original prompt exactly as provided.
- **Subsequent Lines:**
- Each additional line must start with `[sub-index][subject]` where:
- `sub-index` is a sequential number starting from 0.
- `subject` is a keyword that indicates which subject's detailed attribute is being highlighted. If the attribute added is global, like background, use `None`. For the first branch, use `None` as the subject keyword.
- **Line Separation:**
- Each sub-prompt must appear on its own line.
2. **Decomposition Rules:**
- **Generic Version ([sub-0][None]):**
- Create a version of the prompt that has all specific detailed attributes (e.g., color adjectives, style adjectives) removed. This produces a simplified, generic description of the scene.
- **Attribute-Specific Branches:**
- For every distinct subject in the original prompt that has a specific attribute, generate a branch that reintroduces that particular attribute while keeping all other subjects in their generic state.
- Each branch must re-add the attribute detail for only one subject. For example, if the original prompt mentions a “red hat” on one subject and a “blue tracksuit” on another, then:
- One branch should reintroduce “red” for the hat.
- Another branch should reintroduce “blue” for the tracksuit.
- The keyword inside the brackets (after the sub-index) should indicate the subject whose attribute is restored (e.g., `hat`, `tracksuit`, `car`, etc.).
3. **General Guidelines:**
- **Consistency:**
- Ensure that the modified sub-prompts are logically consistent with the original description. Only one attribute should be reintroduced per branch, while all other attribute details remain generic.
- **Precision:**
- Follow the exact fixed format with square brackets and no extra characters or commentary.
- **No Extra Text:**
- Do not include any explanations, notes, or additional commentary in the output. The final output should only contain the sub-prompts as specified.
- **Output format:**
- The output should be a JSON object with a single key `variants` that contains a list of sub-prompts.
4. **Example to Follow:**
Given the original prompt:
```
a man wearing a red hat and blue tracksuit is standing in front of a green sports car
```
The output should be:
```
{"variants":
[
[original] a man wearing a red hat and blue tracksuit is standing in front of a green sports car
[sub-0][None] a man wearing a hat and tracksuit is standing in front of a sports car
[sub-1][hat] a man wearing a red hat and tracksuit is standing in front of a sports car
[sub-2][tracksuit] a man wearing a hat and blue tracksuit is standing in front of a sports car
[sub-3][car] a man wearing a hat and tracksuit is standing in front of a green sports car
]
}
```
5. **Another Example to Follow:**
Given the original prompt:
```
In a cyberpunk style city night, a VanGogh-style hound dog is standing in front of a lego-style sports car
```
The output should be:
```
{"variants":
[
[original] In a cyberpunk style city night, a VanGogh-style hound dog is standing in front of a Lego-style sports car
[sub-0][None] In a city night, a hound dog is standing in front of a sports car
[sub-1][None] In a cyberpunk style city night, a hound dog is standing in front of a sports car
[sub-2][hound dog] In a city night, a VanGogh-style hound dog is standing in front of a sports car
[sub-3][car] In a city night, a hound dog is standing in front of a Lego-style sports car
]
}
```
6. **Task Summary:**
- Your task is to read the given original prompt and output a set of sub-prompts using the format above.
- The first sub-prompt ([sub-0][None]) should be the fully generic version.
- Each subsequent sub-prompt should selectively reintroduce one detailed attribute corresponding to a subject from the original prompt.
Now, use this detailed instruction prompt to generate the decomposed sub-prompts for any provided original image description.
---
"""
response = client.chat.completions.create(
model="gpt-4-turbo",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
temperature=0.7,
)
return response.choices[0].message.content
def init_pipeline():
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
pipe = ProT2IPipeline.from_pretrained("SG161222/RealVisXL_V4.0", use_safetensors=True, variant='fp16').to(torch.float16).to(device)
return pipe, device
def parse_sub_prompts(text):
lines = [line.strip() for line in text.split('\n') if line.strip()]
if not lines:
raise ValueError("Please enter at least one line.")
sps = []
nps = []
    if not lines[0].lower().startswith("[original]"):
        raise ValueError("The first line must start with [original].")
sps.append(lines[0][len("[original]"):].strip())
for line in lines[1:]:
m = re.match(r"^\[sub-\d+\]\[([^\]]+)\]\s*(.*)$", line)
if not m:
raise ValueError(f"Sub-prompt format error: {line}\nFormat should be: [sub-index][mask] prompt")
mask = m.group(1).strip()
prompt = m.group(2).strip()
sps.append(prompt)
nps.append(mask if mask.lower() != "none" else None)
return sps, nps
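# For example, parse_sub_prompts(
#     "[original] a red hat\n[sub-0][None] a hat\n[sub-1][hat] a red hat"
# ) returns (["a red hat", "a hat", "a red hat"], [None, "hat"]).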
def process_image(
sub_prompts,
n_self_replace,
lb_threshold,
attention_res,
use_nurse,
centroid_alignment,
width,
height,
inference_steps,
seed
):
try:
sps, nps = parse_sub_prompts(sub_prompts)
if len(sps) != len(nps) + 1:
placeholder_image = create_placeholder_image()
err = f"Error: Number of sub-prompts ({len(sps)}) should be equal to number of masking words + 1 ({len(nps)}+1)"
return placeholder_image, [placeholder_image] * 3, err
pipe, device = init_pipeline()
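        # Fixed hyperparameters for this demo (not exposed in the UI).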
guidance_scale = 7.5
n_cross = 0.0
scale_factor = 1750
scale_range = (1.0, 0.0)
angle_loss_weight = 0.0
max_refinement_steps = [6, 3]
nursing_thresholds = {0: 26, 1: 25, 2: 24, 3: 23, 4: 22.5, 5: 22}
save_cross_attention_maps = False
if seed == -1:
seed = random.randint(0, 1000000)
g_cpu = torch.Generator().manual_seed(seed)
controller_list = []
run_name = f'runs-SDXL/{time.strftime("%Y%m%d-%H%M%S")}-{seed}'
controller_np = [[sps[i-1], sps[i]] for i in range(1, len(sps))]
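        # Consecutive (previous, current) sub-prompt pairs; each pair drives one
        # attention controller for a single editing step.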
status_messages = [f"seed: {seed}"]
for i in range(len(controller_np)):
controller_kwargs = {
"edit_type": "refine",
"local_blend_words": nps[i],
"n_cross_replace": {"default_": n_cross},
"n_self_replace": float(n_self_replace),
"lb_threshold": float(lb_threshold),
"lb_prompt": [sps[0]]*2,
"is_nursing": use_nurse,
"lb_res": (int(attention_res), int(attention_res)),
"run_name": run_name,
"save_map": save_cross_attention_maps,
}
if nps[i] is None:
subject_str = ",".join([str(x) for x in nps if x is not None])
status_messages.append(f"Remove attributes from {subject_str}")
else:
diff_str = get_diff_string(sps[i], sps[i+1]) if i+1 < len(sps) else ""
if diff_str:
status_messages.append(f"Add {diff_str} to {nps[i]}")
controller = create_controller(
prompts=controller_np[i],
cross_attention_kwargs=controller_kwargs,
num_inference_steps=inference_steps,
tokenizer=pipe.tokenizer,
device=device,
attn_res=(int(attention_res), int(attention_res))
)
controller_list.append(controller)
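        # Bundle the sub-prompts and per-step controllers for the single
        # progressive generation pass below.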
cross_attention_kwargs = {
"subprompts": sps,
"set_controller": controller_list,
"subject_words": nps if use_nurse else None,
"nursing_threshold": nursing_thresholds,
"max_refinement_steps": max_refinement_steps,
"scale_factor": scale_factor,
"scale_range": scale_range,
"centroid_alignment": centroid_alignment,
"angle_loss_weight": angle_loss_weight,
}
output = pipe(
prompt=sps[-1],
width=width,
height=height,
cross_attention_kwargs=cross_attention_kwargs,
num_inference_steps=inference_steps,
num_images_per_prompt=1,
generator=g_cpu,
attn_res=(int(attention_res), int(attention_res)),
)[0]
return output["images"][-1], output["images"], "\n".join(status_messages)
except Exception as e:
placeholder_image = create_placeholder_image()
        return placeholder_image, [placeholder_image] * 3, f"Error: {str(e)}"
article = r"""
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@misc{chen2025detailtrainingfreeenhancertexttoimage,
title={Detail++: Training-Free Detail Enhancer for Text-to-Image Diffusion Models},
author={Lifeng Chen and Jiner Wang and Zihao Pan and Beier Zhu and Xiaofeng Yang and Chi Zhang},
year={2025},
eprint={2507.17853},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2507.17853},
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or reach us directly at <b>1633724411c@gmail.com</b>.
"""
@spaces.GPU
def generate_image(
prompt,
sub_prompts,
n_self_replace,
lb_threshold,
attention_res,
use_nurse,
centroid_alignment,
width,
height,
inference_steps,
seed
):
try:
if not sub_prompts or sub_prompts.strip() == "":
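            # No manual sub-prompts provided: let the LLM (process_text) decompose the prompt.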
gpt_output = process_text(prompt)
new_sub_prompts = "\n".join(json.loads(gpt_output)["variants"])
else:
new_sub_prompts = sub_prompts
image, gallery_list, status = process_image(
new_sub_prompts,
n_self_replace,
lb_threshold,
attention_res,
use_nurse,
centroid_alignment,
width,
height,
inference_steps,
seed
)
return image, gallery_list, status, new_sub_prompts
except Exception as e:
error_message = f"Error: {str(e)}"
print(error_message)
return None, [None] * 3, error_message, sub_prompts
# Create Gradio interface
with gr.Blocks() as iface:
gr.Markdown(_HEADER_)
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(label="Prompt")
with gr.Accordion("Decomposed sub-prompts", open=False):
sub_prompts = gr.Textbox(
lines=7,
label="Sub-prompts",
placeholder="Enter sub-prompts, one per line, e.g.\n"
"[original]...\n"
"[sub-0][None]...\n"
"[sub-1][hat]...\n"
"..."
)
n_self_replace = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.8,
step=0.1,
label="Percetange of self-attention map substitution steps"
)
lb_threshold = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.25,
step=0.05,
label="Threshold for latent mask extraction of subject words"
)
attention_res = gr.Number(
label="Attention map resolution",
value=32
)
with gr.Row():
use_nurse = gr.Checkbox(
label="Use attention nursing",
value=True
)
centroid_alignment = gr.Checkbox(
label="Use centroid alignment",
value=False
)
with gr.Row():
width = gr.Number(
label="Width",
value=1024
)
height = gr.Number(
label="Height",
value=1024
)
inference_steps = gr.Number(
label="Inference steps",
value=20
)
seed = gr.Number(
label="Seed (-1 for random)",
value=-1
)
generate_btn = gr.Button("Generate Image")
with gr.Column(scale=1):
output_image = gr.Image(label="Generated Image")
with gr.Accordion("Progressive Generating Process", open=False):
gallery = gr.Gallery(
label="Generation Steps",
show_label=True,
elem_id="gallery",
columns=2,
rows=3,
height="auto"
)
output_status = gr.Textbox(label="Status", lines=7)
    # Callback: run generate_image and also update the sub_prompts textbox with the decomposed sub-prompts.
generate_btn.click(
fn=generate_image,
inputs=[
prompt,
sub_prompts,
n_self_replace,
lb_threshold,
attention_res,
use_nurse,
centroid_alignment,
width,
height,
inference_steps,
seed
],
outputs=[output_image, gallery, output_status, sub_prompts]
)
# Examples
example_data = [
[
"In a cyberpunk style city night, a cartoon style hound dog is standing in front of a lego style sports car",
"",
0.5,
20,
5
],
[
"A sketch-style robot is leaning against an oil-painting style tree",
"",
0.5,
20,
2
],
[
"a man wearing a red hat and blue tracksuit is standing in front of a green sports car",
"",
0.5,
20,
6
]
]
gr.Examples(
examples=example_data,
inputs=[
prompt,
sub_prompts,
lb_threshold,
inference_steps,
seed
]
)
gr.Markdown(article)
if __name__ == "__main__":
iface.launch()