IdlecloudX committed on
Commit 27626d2 · verified · 1 Parent(s): b210526

Upload 2 files

Files changed (2)
  1. app.py +246 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,246 @@
+ import spaces
+ import logging
+ import os
+ import random
+ import re
+ import sys
+ import warnings
+
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+ import gradio as gr
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ from diffusers import ZImagePipeline
+ from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
+
+ # ==================== Environment Variables ==================================
+ MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
+ ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
+ ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
+ ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ # =============================================================================
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ warnings.filterwarnings("ignore")
+ logging.getLogger("transformers").setLevel(logging.ERROR)
+
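+ # Resolution presets grouped by base size; each label embeds "WxH ( aspect ratio )" and is parsed by get_resolution() below.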
+ RES_CHOICES = {
+     "1024": [
+         "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )",
+         "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )",
+         "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )",
+         "1344x576 ( 21:9 )", "576x1344 ( 9:21 )",
+     ],
+     "1280": [
+         "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )",
+         "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )",
+         "1024x1536 ( 2:3 )", "1600x896 ( 16:9 )", "896x1600 ( 9:16 )",
+         "1680x720 ( 21:9 )", "720x1680 ( 9:21 )",
+     ],
+ }
+
+ EXAMPLE_PROMPTS = [
+     ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
+     ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里..."],
+     ["一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子..."],
+     ["Young Chinese woman in red Hanfu, intricate embroidery..."],
+     ["A vertical digital illustration depicting a serene and majestic Chinese landscape..."],
+     ["一张虚构的英语电影《回忆之味》(The Taste of Memory)的电影海报..."],
+     ["一张方形构图的特写照片,主体是一片巨大的、鲜绿色的植物叶片..."],
+ ]
+
+ def get_resolution(resolution):
+     match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
+     if match:
+         return int(match.group(1)), int(match.group(2))
+     return 1024, 1024
+
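+ # Loads VAE, text encoder, tokenizer, and transformer, from the Hub when model_path is not a local directory, then optionally enables torch.compile.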
+ def load_models(model_path, enable_compile=False, attention_backend="native"):
+     print(f"Loading models from {model_path}...")
+
+     # `token` replaces the deprecated `use_auth_token` argument; True falls back to the locally stored token.
+     token = HF_TOKEN if HF_TOKEN else True
+
+     # Load VAE, Text Encoder, Tokenizer
+     if not os.path.exists(model_path):
+         vae = AutoencoderKL.from_pretrained(
+             f"{model_path}", subfolder="vae", torch_dtype=torch.bfloat16,
+             device_map="cuda", token=token,
+         )
+         text_encoder = AutoModel.from_pretrained(
+             f"{model_path}", subfolder="text_encoder", torch_dtype=torch.bfloat16,
+             device_map="cuda", token=token,
+         ).eval()
+         tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", token=token)
+     else:
+         vae = AutoencoderKL.from_pretrained(os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda")
+         text_encoder = AutoModel.from_pretrained(os.path.join(model_path, "text_encoder"), torch_dtype=torch.bfloat16, device_map="cuda").eval()
+         tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
+
+     tokenizer.padding_side = "left"
+
+     if enable_compile:
+         print("Enabling torch.compile optimizations...")
+         torch._inductor.config.conv_1x1_as_mm = True
+         torch._inductor.config.coordinate_descent_tuning = True
+         torch._inductor.config.epilogue_fusion = False
+         torch._inductor.config.coordinate_descent_check_all_directions = True
+         torch._inductor.config.max_autotune_gemm = True
+         torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
+         torch._inductor.config.triton.cudagraphs = False
+
+     pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
+
+     if enable_compile:
+         pipe.vae.disable_tiling()
+
+     # Load Transformer
+     if not os.path.exists(model_path):
+         transformer = ZImageTransformer2DModel.from_pretrained(
+             f"{model_path}", subfolder="transformer", token=token
+         ).to("cuda", torch.bfloat16)
+     else:
+         transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to("cuda", torch.bfloat16)
+
+     pipe.transformer = transformer
+     pipe.transformer.set_attention_backend(attention_backend)
+
+     if enable_compile:
+         print("Compiling transformer...")
+         pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
+
+     pipe.to("cuda", torch.bfloat16)
+     return pipe
+
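+ # Single denoising pass; a fresh FlowMatchEulerDiscreteScheduler is built per call so the time-shift value can vary between requests.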
+ def generate_image(pipe, prompt, width=1024, height=1024, seed=42, guidance_scale=5.0, num_inference_steps=50, shift=3.0, max_sequence_length=512, progress=gr.Progress(track_tqdm=True)):
+     generator = torch.Generator("cuda").manual_seed(seed)
+     scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
+     pipe.scheduler = scheduler
+
+     image = pipe(
+         prompt=prompt, height=height, width=width,
+         guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
+         generator=generator, max_sequence_length=max_sequence_length,
+     ).images[0]
+
+     return image
+
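+ # Three throwaway generations per preset resolution so torch.compile specializes its kernels for every shape before real traffic arrives.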
+ def warmup_model(pipe, resolutions):
+     print("Starting warmup phase...")
+     dummy_prompt = "warmup"
+     for res_str in resolutions:
+         try:
+             w, h = get_resolution(res_str)
+             for i in range(3):
+                 generate_image(pipe, prompt=dummy_prompt, width=w, height=h, num_inference_steps=9, guidance_scale=0.0, seed=42 + i)
+         except Exception as e:
+             print(f"Warmup failed for {res_str}: {e}")
+     print("Warmup completed.")
+
+ # Global Pipe Variable
+ pipe = None
+
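+ # One-time startup: load the pipeline and, if enabled, warm up every preset resolution.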
+ def init_app():
+     global pipe
+     try:
+         pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
+         print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
+
+         if ENABLE_WARMUP:
+             all_resolutions = []
+             for cat in RES_CHOICES.values():
+                 all_resolutions.extend(cat)
+             warmup_model(pipe, all_resolutions)
+
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         pipe = None
+ # Prompt Expander initialization has been removed
+
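+ # ZeroGPU entry point: @spaces.GPU attaches a GPU to the worker for the duration of each call.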
+ @spaces.GPU
+ def generate(prompt, width=1024, height=1024, seed=42, steps=9, shift=3.0, random_seed=True, gallery_images=None, progress=gr.Progress(track_tqdm=True)):
+     if pipe is None:
+         raise gr.Error("Model not loaded. Please check logs.")
+
+     if random_seed:
+         new_seed = random.randint(1, 1000000)
+     else:
+         new_seed = seed if seed != -1 else random.randint(1, 1000000)
+
+     image = generate_image(
+         pipe=pipe, prompt=prompt, width=int(width), height=int(height),
+         seed=new_seed, guidance_scale=0.0, num_inference_steps=int(steps + 1), shift=shift,
+     )
+
+     if gallery_images is None:
+         gallery_images = []
+     gallery_images.append(image)
+
+     return gallery_images, str(new_seed), int(new_seed)
+
+ # Initialize
+ init_app()
+
+ # ==================== AoTI (Ahead-of-Time Inductor compilation) ====================
+ # Safety check: only apply the optimization setup once the pipe has loaded successfully, to avoid an AttributeError.
+ if pipe is not None:
+     try:
+         pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
+         spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
+     except Exception as e:
+         print(f"Warning: Failed to load AoTI blocks: {e}")
+ else:
+     print("CRITICAL: Pipe is None. Model failed to load in init_app(). Check upstream errors.")
+
+ # ==================== UI Construction ====================
+ with gr.Blocks(title="Z-Image Demo") as demo:
+     gr.Markdown(
+         """<div align="center">
+ # Z-Image Generation Demo
+ [![GitHub](https://img.shields.io/badge/GitHub-Z--Image-181717?logo=github&logoColor=white)](https://github.com/Tongyi-MAI/Z-Image)
+ *An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*
+ </div>"""
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
+
+             with gr.Row():
+                 width = gr.Slider(label="Width", minimum=640, maximum=2048, value=1024, step=64)
+                 height = gr.Slider(label="Height", minimum=640, maximum=2048, value=1024, step=64)
+
+             with gr.Row():
+                 seed = gr.Number(label="Seed", value=42, precision=0)
+                 random_seed = gr.Checkbox(label="Random Seed", value=True)
+
+             with gr.Row():
+                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
+                 shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
+
+             generate_btn = gr.Button("Generate", variant="primary")
+
+             gr.Markdown("### 📝 Example Prompts")
+             gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
+
+         with gr.Column(scale=1):
+             output_gallery = gr.Gallery(
+                 label="Generated Images", columns=2, rows=2, height=600, object_fit="contain", format="png", interactive=False
+             )
+             used_seed = gr.Textbox(label="Seed Used", interactive=False)
+
+     generate_btn.click(
+         generate,
+         inputs=[prompt_input, width, height, seed, steps, shift, random_seed, output_gallery],
+         outputs=[output_gallery, used_seed, seed],
+         api_visibility="public",
+     )
+
+ css = '''
+ .fillable{max-width: 1230px !important}
+ '''
+ if __name__ == "__main__":
+     demo.launch(css=css, mcp_server=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ transformers
+ accelerate
+ spaces
+ openai
+ git+https://github.com/huggingface/diffusers.git
+ kernels