Westlake-AGI-Lab committed
Commit f92b956 · verified · 1 Parent(s): 8190b16

Update app.py

Files changed (1):
  1. app.py +288 -126

app.py CHANGED
@@ -1,154 +1,316 @@
- import gradio as gr
- import numpy as np
  import random
-
- import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "SG161222/RealVisXL_V4.0" # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
      width,
      height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
  ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
              )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
              )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
              )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
              with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
                  )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
                  )
-
              with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
                  )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
                  )
-
-     gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
          inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
              width,
              height,
-             guidance_scale,
-             num_inference_steps,
          ],
-         outputs=[result, seed],
      )

  if __name__ == "__main__":
-     demo.launch()
+ import os
  import random
+ import time
  import torch
+ import gradio as gr
+ from ProT2I.prot2i_pipeline_sdxl import ProT2IPipeline
+ from ProT2I.processors import create_controller
+ from PIL import Image
+ import numpy as np
+ import difflib
+ import spaces
+
+ _HEADER_ = '''
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">ProT2I for SDXL</h1>
+ </div>
+
+ ⭐⭐⭐**Tips:**
+ - ⭐`Sub-prompts:` Enter the decomposed sub-prompts, one per line.
+ - ⭐`Subject Masking Words:` Enter the subject word for each edit step, one per line. (Leave a blank line if you want to remove all attributes first.)
+ - ⭐We provide examples at the bottom that you can try.
+ - ⭐If attributes overflow into other regions, increase the `Threshold Value` for mask extraction.
+ '''
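+ # Illustrative input format for the tips above (mirrors the first example at the bottom):
+ #   Sub-prompts:            a car and a bench | a blue car and a bench | a car and a green bench
+ #   Subject masking words:  car | bench       (each entered on its own line in the UI)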
+
+ def create_placeholder_image():
+     # Plain white 512x512 image returned when generation fails or inputs are invalid
+     return Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
+
+ def get_diff_string(str1, str2):
+     """
+     Return the words that appear in `str2` but not in `str1`,
+     in order, joined into a single space-separated string.
+     """
+     diff = difflib.ndiff(str1.split(), str2.split())
+     added_parts = [word[2:] for word in diff if word.startswith('+ ')]  # keep only added words
+     return ' '.join(added_parts)
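+ # e.g. get_diff_string("a car and a bench", "a blue car and a bench") returns "blue"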
+
+ def init_pipeline():
+     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+     # Load fp16 weights; model CPU offload trades some speed for a lower peak VRAM footprint
+     pipe = ProT2IPipeline.from_pretrained("SG161222/RealVisXL_V4.0", use_safetensors=True, variant='fp16').to(torch.float16)
+     pipe.enable_model_cpu_offload()
+     return pipe, device
+
+ def process_image(
+     sub_prompts,
+     lb_words,
+     n_self_replace,
+     lb_threshold,
+     attention_res,
+     use_nurse,
+     centroid_alignment,
      width,
      height,
+     inference_steps,
+     seed
  ):
+     try:
+         # Initialize pipeline
+         pipe, device = init_pipeline()
+
+         # Process sub-prompts: one per line, blank lines dropped
+         sps = [prompt.strip() for prompt in sub_prompts.split('\n') if prompt.strip()]
+
+         # Process semantic masking words: a blank line maps to None ("remove all attributes")
+         nps = []
+         for word in lb_words.split('\n'):
+             if word.strip():
+                 nps.append(word.strip())
+             else:
+                 nps.append(None)
+
+         # Validate inputs: one masking word per edit step, i.e. one fewer than the sub-prompts
+         if len(nps) + 1 != len(sps):
+             placeholder_image = create_placeholder_image()
+             return placeholder_image, [placeholder_image] * 3, f"Error: Number of semantic masks ({len(nps)}) should be one less than number of sub-prompts ({len(sps)})"
+
+         # Set fixed parameters from config
+         guidance_scale = 7.5  # NOTE: defined here but not forwarded to the pipeline call below
+         n_cross = 0.0
+         scale_factor = 1750
+         scale_range = (1.0, 0.0)
+         angle_loss_weight = 0.0
+         max_refinement_steps = [6, 3]
+         nursing_thresholds = {
+             0: 26, 1: 25, 2: 24, 3: 23, 4: 22.5, 5: 22,
+         }
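+         # (Assumption: the keys are early denoising-step indices and the values are the
+         #  attention-nursing thresholds applied at those steps; later steps are not nursed.)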
+         save_cross_attention_maps = False
+
+         if seed == -1:
+             seed = random.randint(0, 1000000)
+         g_cpu = torch.Generator().manual_seed(seed)
+
+         # Create one controller per pair of consecutive sub-prompts
+         controller_list = []
+         run_name = f'runs-SDXL/{time.strftime("%Y%m%d-%H%M%S")}-{seed}'
+         controller_np = [[sps[i-1], sps[i]] for i in range(1, len(sps))]
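+         # Controller i steers the edit sps[i] -> sps[i+1]; the base prompt sps[0] is reused
+         # below (via "lb_prompt") when extracting the latent-blend masks.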
+
+         # Prepare status messages
+         status_messages = [f"seed: {seed}"]
+
+         for i in range(len(controller_np)):
+             controller_kwargs = {
+                 "edit_type": "refine",
+                 "local_blend_words": nps[i],
+                 "n_cross_replace": {"default_": n_cross},
+                 "n_self_replace": float(n_self_replace),
+                 "lb_threshold": float(lb_threshold) + 1,
+                 "lb_prompt": [sps[0]] * 2,
+                 "is_nursing": use_nurse,
+                 "lb_res": (int(attention_res), int(attention_res)),
+                 "run_name": run_name,
+                 "save_map": save_cross_attention_maps,
+             }
+
+             # Get the words added going from sps[i] to sps[i+1]
+             if nps[i] is None:
+                 subject_string = ",".join(nps[1:])
+                 status_messages.append(f"Remove attributes from {subject_string}")
+             else:
+                 diff_str = get_diff_string(sps[i], sps[i+1])
+                 if diff_str:
+                     status_messages.append(f"Add {diff_str} to {nps[i]}")
+
+             controller = create_controller(
+                 prompts=controller_np[i],
+                 cross_attention_kwargs=controller_kwargs,
+                 num_inference_steps=int(inference_steps),
+                 tokenizer=pipe.tokenizer,
+                 device=device,
+                 attn_res=(int(attention_res), int(attention_res))
              )
+             controller_list.append(controller)
+
+         # Set up cross attention kwargs
+         cross_attention_kwargs = {
+             "subprompts": sps,
+             "set_controller": controller_list,
+             "subject_words": nps if use_nurse else None,
+             "nursing_threshold": nursing_thresholds,
+             "max_refinement_steps": max_refinement_steps,
+             "scale_factor": scale_factor,
+             "scale_range": scale_range,
+             "centroid_alignment": centroid_alignment,
+             "angle_loss_weight": angle_loss_weight,
+         }
+
+         # Generate images (gr.Number returns floats, so cast sizes and step counts to int)
+         output = pipe(
+             prompt=sps[-1],  # Use the last sub-prompt as the final prompt
+             width=int(width),
+             height=int(height),
+             cross_attention_kwargs=cross_attention_kwargs,
+             num_inference_steps=int(inference_steps),
+             num_images_per_prompt=1,
+             generator=g_cpu,
+             attn_res=(int(attention_res), int(attention_res)),
+         )[0]
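+         # output["images"] holds one image per progressive edit stage; the last is the final result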
+
+         return output["images"][-1], output["images"], "\n".join(status_messages)
+
+     except Exception as e:
+         placeholder_image = create_placeholder_image()
+         return placeholder_image, [placeholder_image] * 3, f"Error: {str(e)}"

+ # Create Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown(_HEADER_)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             sub_prompts = gr.Textbox(
+                 lines=5,
+                 label="Sub-prompts",
+                 placeholder="Enter sub-prompts, one per line..."
              )
+
+             lb_words = gr.Textbox(
+                 lines=4,
+                 label="Subject masking words",
+                 placeholder="Enter subject words, one per line..."
              )
+
+             n_self_replace = gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 value=0.8,
+                 step=0.1,
+                 label="Percentage of self-attention map substitution steps"
+             )
+
+             lb_threshold = gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 value=0.25,
+                 step=0.05,
+                 label="Threshold for latent mask extraction of subject words"
+             )
+
+             attention_res = gr.Number(
+                 label="Attention map resolution",
+                 value=32
+             )
+
              with gr.Row():
+                 use_nurse = gr.Checkbox(
+                     label="Use attention nursing",
+                     value=True
                  )
+
+                 centroid_alignment = gr.Checkbox(
+                     label="Use centroid alignment",
+                     value=False
                  )
+
              with gr.Row():
+                 width = gr.Number(
+                     label="Width",
+                     value=1024
                  )
+
+                 height = gr.Number(
+                     label="Height",
+                     value=1024
                  )
+
+             inference_steps = gr.Number(
+                 label="Inference steps",
+                 value=20
+             )
+
+             seed = gr.Number(
+                 label="Seed (-1 for random)",
+                 value=-1
+             )
+
+             generate_btn = gr.Button("Generate Image")
+
+         with gr.Column(scale=1):
+             output_image = gr.Image(label="Generated Image")
+
+             with gr.Accordion("Progressive Generation Process", open=False):
+                 gallery = gr.Gallery(
+                     label="Generation Steps",
+                     show_label=True,
+                     elem_id="gallery",
+                     columns=2,
+                     rows=3,
+                     height="auto"
+                 )
+
+             output_status = gr.Textbox(label="Status", lines=4)
+
+     # Connect the generate button to the process_image function
+     generate_btn.click(
+         fn=process_image,
          inputs=[
+             sub_prompts,
+             lb_words,
+             n_self_replace,
+             lb_threshold,
+             attention_res,
+             use_nurse,
+             centroid_alignment,
              width,
              height,
+             inference_steps,
+             seed
+         ],
+         outputs=[output_image, gallery, output_status]
+     )
+
+     # Examples
+     example_data = [
+         [
+             "a car and a bench\na blue car and a bench\na car and a green bench",
+             "car\nbench",
+             0.1,
+             20,
+             1
+         ],
+         [
+             "In a cyberpunk style city night, a hound dog is standing in front of a sports car\nVan Gogh style hound dog\nLego-style sports car",
+             "dog\ncar",
+             0.25,
+             20,
+             2
          ],
+         [
+             "A sketch-style robot is leaning a oil-painting style tree\nA robot is leaning a tree\nA sketch-style robot is leaning a tree\nA robot is leaning a oil-painting style tree",
+             "\nrobot\ntree",
+             0.25,
+             20,
+             0
+         ],
+         [
+             "a man wearing a red hat and blue tracksuit is standing in front of a green sports car\na man wearing a hat and tracksuit is standing in front of a sports car\na man wearing a red hat and tracksuit is standing in front of a sports car\na man wearing a hat and blue tracksuit is standing in front of a sports car\na man wearing a hat and tracksuit is standing in front of a green sports car",
+             "\nhat\ntracksuit\ncar",
+             0.25,
+             20,
+             6
+         ],
+     ]
+
+     gr.Examples(
+         examples=example_data,
+         inputs=[
+             sub_prompts,
+             lb_words,
+             lb_threshold,
+             inference_steps,
+             seed
+         ]
      )
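+     # Note: the examples populate only these five inputs; all other controls keep their defaults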

  if __name__ == "__main__":
+     iface.launch(share=True, server_port=7549)