import torch
import spaces
import gradio as gr
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# ============================================================
# Model Settings
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False

# ============================================================
# GPU Check
# ============================================================
if torch.cuda.is_available():
    print(f"INFO: CUDA available: {torch.cuda.get_device_name(0)} (count={torch.cuda.device_count()})")
    device = "cuda:0"
    gpu_id = 0
else:
    raise RuntimeError("ERROR: CUDA not available. This program requires a CUDA-enabled GPU.")

# ============================================================
# Load Transformer
# ============================================================
print("INFO: Loading transformer block ...")
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    # Modules listed here are excluded from 4-bit quantization.
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
transformer = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Transformer block loaded.")

if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")

# ============================================================
# Load Text Encoder
# ============================================================
print("INFO: Loading text encoder ...")
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
text_encoder = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Text encoder loaded.")

if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")

# ============================================================
# Build Pipeline
# ============================================================
print("INFO: Building pipeline ...")
pipe = ZImagePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)

if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    print("INFO: CPU offload active")
else:
    pipe.to(device)
    print("INFO: Pipeline to GPU")

# ============================================================
# Inference Function for Gradio
# ============================================================
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    # Seed the generator so outputs are reproducible for a given seed.
    generator = torch.Generator(device).manual_seed(seed)
    output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=0.0,  # distilled Turbo model: run without classifier-free guidance
        generator=generator,
    )
    return output.images[0]

# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo — 4-bit Quantized Image Generator**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic middle-aged male portrait")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")

    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image],
    )

# ============================================================
# Launch
# ============================================================
demo.launch()
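# ============================================================
# Optional: headless smoke test
# ============================================================
# A minimal sketch for verifying the pipeline without the UI;
# uncomment and run it in place of demo.launch() above. The
# prompt and output file name here are arbitrary examples.
#
# image = generate_image(
#     prompt="A watercolor lighthouse at dusk",
#     height=1024,
#     width=1024,
#     steps=9,
#     seed=42,
# )
# image.save("smoke_test.png")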