import torch
import spaces
import gradio as gr
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# ============================================================
# Model Settings
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
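# Set True to let diffusers stream model parts between CPU and GPU when VRAM
# is tight (slower per image, but lowers peak memory).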
USE_CPU_OFFLOAD = False
# ============================================================
# GPU Check
# ============================================================
if torch.cuda.is_available():
    print(f"INFO: CUDA available: {torch.cuda.get_device_name(0)} (count={torch.cuda.device_count()})")
    device = "cuda:0"
    gpu_id = 0
else:
    raise RuntimeError("ERROR: CUDA not available. This program requires a CUDA-enabled GPU.")
# ============================================================
# Load Transformer
# ============================================================
print("INFO: Loading transformer block ...")
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
transformer = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Transformer block loaded.")
if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")
# ============================================================
# Load Text Encoder
# ============================================================
print("INFO: Loading text encoder ...")
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
text_encoder = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Text encoder loaded.")
if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")
# ============================================================
# Build Pipeline
# ============================================================
print("INFO: Building pipeline ...")
pipe = ZImagePipeline.from_pretrained(
    model_id,
    cache_dir=model_cache,  # keep the remaining components (VAE, tokenizer, scheduler) in the same cache
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)
if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    print("INFO: CPU offload active")
else:
    pipe.to(device)
    print("INFO: Pipeline to GPU")
# ============================================================
# Inference Function for Gradio
# ============================================================
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    # Gradio sliders may deliver floats, so cast to int before use.
    generator = torch.Generator(device).manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=0.0,  # guidance disabled for the few-step Turbo model
        generator=generator,
    )
    return output.images[0]
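# Quick smoke test without the UI (hypothetical values; uncomment to try locally):
# generate_image("a red fox in fresh snow", 1024, 1024, 9, 42).save("smoke_test.png")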
# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo: 4-bit Quantized Image Generator**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic portrait of a middle-aged man")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")
    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image],
    )
# ============================================================
# Launch
# ============================================================
demo.launch()
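# launch() with defaults is what Hugging Face Spaces expects; for a local run,
# server_name="0.0.0.0" (a standard gradio launch kwarg) exposes the UI on the LAN.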