Spaces: Running on Zero
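Below is the Space's app.py: a Gradio demo that serves Tongyi-MAI/Z-Image-Turbo with 4-bit (NF4) quantized weights.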
import torch
import spaces
import gradio as gr
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# ============================================================
# Model Settings
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False
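# Set USE_CPU_OFFLOAD = True to keep idle submodules in system RAM and
# stream them to the GPU on demand: slower, but fits in much less VRAM.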

# ============================================================
# GPU Check
# ============================================================
if torch.cuda.is_available():
    print(f"INFO: CUDA available: {torch.cuda.get_device_name(0)} (count={torch.cuda.device_count()})")
    device = "cuda:0"
    gpu_id = 0
else:
    raise RuntimeError("ERROR: CUDA not available. This program requires a CUDA-enabled GPU.")

# ============================================================
# Load Transformer
# ============================================================
print("INFO: Loading transformer block ...")
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
transformer = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Transformer block loaded.")
if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")

# ============================================================
# Load Text Encoder
# ============================================================
print("INFO: Loading text encoder ...")
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
text_encoder = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Text encoder loaded.")
if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")

# ============================================================
# Build Pipeline
# ============================================================
print("INFO: Building pipeline ...")
pipe = ZImagePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)
if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    print("INFO: CPU offload active")
else:
    pipe.to(device)
    print("INFO: Pipeline to GPU")

# ============================================================
# Inference Function for Gradio
# ============================================================
@spaces.GPU  # required on ZeroGPU Spaces: allocates a GPU for each call
def generate_image(prompt, height, width, steps, seed):
    generator = torch.Generator(device=device).manual_seed(seed)
    output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=0.0,
        generator=generator,
    )
    return output.images[0]
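# Example direct call for a quick smoke test, bypassing the UI
# (the argument values here are illustrative, not from the Space):
#   image = generate_image("a red fox in fresh snow", 1024, 1024, 8, 42)
#   image.save("fox.png")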

# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo — 4-bit Quantized Image Generator**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic portrait of a middle-aged man")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")
    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image],
    )

# ============================================================
# Launch
# ============================================================
demo.launch()
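# To expose the demo on a temporary public URL when running locally,
# launch with:
#   demo.launch(share=True)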