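"""Gradio demo for Tongyi-MAI/Z-Image-Turbo.

Loads the transformer and text encoder in 4-bit NF4 (bitsandbytes) to reduce
VRAM usage, assembles a ZImagePipeline, and serves a simple text-to-image UI.
"""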
import torch
import spaces
import gradio as gr

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# ============================================================
# Model Settings
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
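# Set to True to keep idle components in CPU RAM via model CPU offload
# (lower VRAM use, slower inference).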
USE_CPU_OFFLOAD = False

# ============================================================
# GPU Check
# ============================================================
if torch.cuda.is_available():
    print(f"INFO: CUDA available: {torch.cuda.get_device_name(0)} (count={torch.cuda.device_count()})")
    device = "cuda:0"
    gpu_id = 0
else:
    raise RuntimeError("ERROR: CUDA not available. This program requires a CUDA-enabled GPU.")

# ============================================================
# Load Transformer
# ============================================================
print("INFO: Loading transformer block ...")
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
transformer = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Transformer block loaded.")

if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")

# ============================================================
# Load Text Encoder
# ============================================================
print("INFO: Loading text encoder ...")
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
text_encoder = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Text encoder loaded.")

if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")

# ============================================================
# Build Pipeline
# ============================================================
print("INFO: Building pipeline ...")
pipe = ZImagePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)

if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    print("INFO: CPU offload active")
else:
    pipe.to(device)
    print("INFO: Pipeline to GPU")

# ============================================================
# Inference Function for Gradio
# ============================================================
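# @spaces.GPU requests a GPU slot when the app runs on a Hugging Face ZeroGPU
# Space; on regular hardware the decorator has no effect.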
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    """Generate one image from a text prompt with the quantized pipeline."""
    # Gradio sliders may deliver floats; the generator and pipeline expect ints.
    height, width, steps, seed = int(height), int(width), int(steps), int(seed)
    generator = torch.Generator(device).manual_seed(seed)

    output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=0.0,  # the distilled Turbo checkpoint runs without CFG
        generator=generator,
    )

    return output.images[0]


# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo — 4bit Quantized Image Generator**")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic mid-aged male image")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")

            btn = gr.Button("Generate", variant="primary")
        
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")

    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image],
    )

# ============================================================
# Launch
# ============================================================
demo.launch()