import spaces  # must be imported before torch on ZeroGPU Spaces

import logging
import os
import random
import re
import sys
import warnings

import gradio as gr
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from diffusers import ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel

# ==================== Environment Variables ==================================
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
HF_TOKEN = os.environ.get("HF_TOKEN")
# =============================================================================
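# Example configuration (hypothetical values; all variables above are optional):
#   export MODEL_PATH=/data/checkpoints/Z-Image-Turbo  # local checkpoint instead of the Hub repo id
#   export ENABLE_COMPILE=false                        # skip torch.compile for faster cold starts
#   export ATTENTION_BACKEND=native                    # fallback if the flash_3 backend is unavailable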
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )", "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )", "1600x896 ( 16:9 )", "896x1600 ( 9:16 )",
        "1680x720 ( 21:9 )", "720x1680 ( 9:21 )",
    ],
}
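# Every width/height above is a multiple of 16, which (presumably) keeps the
# latent grid aligned with the VAE stride and the transformer patch size.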
# Chinese prompts are kept verbatim (the model accepts both languages);
# English glosses are given in the comments.
EXAMPLE_PROMPTS = [
    # A man and his poodle in matching outfits at a dog show, indoor lighting, audience in the background.
    ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
    # A moody, low-key portrait of an elegant Chinese beauty in a dark room...
    ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里..."],
    # A mid-shot phone selfie of a young East Asian woman with long black hair...
    ["一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子..."],
    ["Young Chinese woman in red Hanfu, intricate embroidery..."],
    ["A vertical digital illustration depicting a serene and majestic Chinese landscape..."],
    # A movie poster for the fictional English-language film "The Taste of Memory"...
    ["一张虚构的英语电影《回忆之味》(The Taste of Memory)的电影海报..."],
    # A square-format close-up photo of a huge, vivid-green plant leaf...
    ["一张方形构图的特写照片,主体是一片巨大的、鲜绿色的植物叶片..."],
]

def get_resolution(resolution):
    # Parse "WxH ( ratio )" labels; accepts "x" or "×" as the separator.
    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if match:
        return int(match.group(1)), int(match.group(2))
    return 1024, 1024
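# e.g. get_resolution("1344x576 ( 21:9 )") -> (1344, 576); unparsable input falls back to 1024x1024.
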
def load_models(model_path, enable_compile=False, attention_backend="native"):
    print(f"Loading models from {model_path}...")
    token = HF_TOKEN if HF_TOKEN else True

    # Load VAE, text encoder, and tokenizer. If model_path is not a local
    # directory, treat it as a Hub repo id with standard subfolders.
    if not os.path.exists(model_path):
        vae = AutoencoderKL.from_pretrained(
            model_path, subfolder="vae", torch_dtype=torch.bfloat16,
            device_map="cuda", token=token,
        )
        text_encoder = AutoModel.from_pretrained(
            model_path, subfolder="text_encoder", torch_dtype=torch.bfloat16,
            device_map="cuda", token=token,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer", token=token)
    else:
        vae = AutoencoderKL.from_pretrained(os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda")
        text_encoder = AutoModel.from_pretrained(os.path.join(model_path, "text_encoder"), torch_dtype=torch.bfloat16, device_map="cuda").eval()
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))

    tokenizer.padding_side = "left"

    if enable_compile:
        print("Enabling torch.compile optimizations...")
        torch._inductor.config.conv_1x1_as_mm = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.coordinate_descent_check_all_directions = True
        torch._inductor.config.max_autotune_gemm = True
        torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
        torch._inductor.config.triton.cudagraphs = False

    pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
    if enable_compile:
        pipe.vae.disable_tiling()

    # Load the transformer separately so it can be compiled before being attached.
    if not os.path.exists(model_path):
        transformer = ZImageTransformer2DModel.from_pretrained(
            model_path, subfolder="transformer", token=token
        ).to("cuda", torch.bfloat16)
    else:
        transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to("cuda", torch.bfloat16)

    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)

    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)

    pipe.to("cuda", torch.bfloat16)
    return pipe

def generate_image(pipe, prompt, width=1024, height=1024, seed=42, guidance_scale=5.0, num_inference_steps=50, shift=3.0, max_sequence_length=512, progress=gr.Progress(track_tqdm=True)):
    generator = torch.Generator("cuda").manual_seed(seed)
    # Rebuild the scheduler on every call so the requested time shift takes effect.
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
    pipe.scheduler = scheduler
    image = pipe(
        prompt=prompt, height=height, width=width,
        guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
        generator=generator, max_sequence_length=max_sequence_length,
    ).images[0]
    return image
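# Standalone usage sketch (assumes a CUDA GPU and access to the checkpoint;
# the prompt and output filename are arbitrary examples):
#   pipe = load_models("Tongyi-MAI/Z-Image-Turbo", enable_compile=False, attention_backend="native")
#   img = generate_image(pipe, "a watercolor fox", width=1024, height=1024, seed=0,
#                        guidance_scale=0.0, num_inference_steps=9)
#   img.save("fox.png")
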
def warmup_model(pipe, resolutions):
    print("Starting warmup phase...")
    dummy_prompt = "warmup"
    for res_str in resolutions:
        try:
            w, h = get_resolution(res_str)
            for i in range(3):
                generate_image(pipe, prompt=dummy_prompt, width=w, height=h, num_inference_steps=9, guidance_scale=0.0, seed=42 + i)
        except Exception as e:
            print(f"Warmup failed for {res_str}: {e}")
    print("Warmup completed.")
# Global pipeline handle, populated by init_app().
pipe = None


def init_app():
    global pipe
    try:
        pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
        print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
        if ENABLE_WARMUP:
            all_resolutions = []
            for cat in RES_CHOICES.values():
                all_resolutions.extend(cat)
            warmup_model(pipe, all_resolutions)
    except Exception as e:
        print(f"Error loading model: {e}")
        pipe = None

# Prompt Expander initialization removed.

@spaces.GPU  # ZeroGPU: attach a GPU for the duration of each generation call
def generate(prompt, width=1024, height=1024, seed=42, steps=9, shift=3.0, random_seed=True, gallery_images=None, progress=gr.Progress(track_tqdm=True)):
    if pipe is None:
        raise gr.Error("Model not loaded. Please check logs.")
    if random_seed:
        new_seed = random.randint(1, 1000000)
    else:
        new_seed = seed if seed != -1 else random.randint(1, 1000000)
    image = generate_image(
        pipe=pipe, prompt=prompt, width=int(width), height=int(height),
        # +1 keeps the scheduler step count at 9 for the default slider value, matching warmup.
        seed=new_seed, guidance_scale=0.0, num_inference_steps=int(steps + 1), shift=shift,
    )
    if gallery_images is None:
        gallery_images = []
    gallery_images.append(image)
    return gallery_images, str(new_seed), int(new_seed)

# Initialize at import time so model loading and warmup run during Space startup.
init_app()
# ==================== AoTI (Ahead-of-Time Inductor compilation) ====================
# Safety check: only apply the optimization once the pipe has loaded successfully,
# to avoid an AttributeError on a None pipe.
if pipe is not None:
    try:
        pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
        spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
    except Exception as e:
        print(f"Warning: Failed to load AoTI blocks: {e}")
else:
    print("CRITICAL: Pipe is None. Model failed to load in init_app(). Check upstream errors.")
# ==================== UI Construction ====================
css = '''
.fillable{max-width: 1230px !important}
'''

with gr.Blocks(title="Z-Image Demo", css=css) as demo:
    gr.Markdown(
        """<div align="center">

# Z-Image Generation Demo

[GitHub](https://github.com/Tongyi-MAI/Z-Image)

*An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*

</div>"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
            with gr.Row():
                width = gr.Slider(label="Width", minimum=640, maximum=2048, value=1024, step=64)
                height = gr.Slider(label="Height", minimum=640, maximum=2048, value=1024, step=64)
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
            with gr.Row():
                # The Turbo checkpoint runs at a fixed step count, so this slider is read-only.
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
                shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images", columns=2, rows=2, height=600, object_fit="contain", format="png", interactive=False
            )
            used_seed = gr.Textbox(label="Seed Used", interactive=False)

    generate_btn.click(
        generate,
        inputs=[prompt_input, width, height, seed, steps, shift, random_seed, output_gallery],
        outputs=[output_gallery, used_seed, seed],
        api_visibility="public",
    )
if __name__ == "__main__":
    # css is passed to gr.Blocks above; launch() does not accept a css argument.
    demo.launch(mcp_server=True)