rahul7star committed
Commit a7fb1fd · verified
1 Parent(s): dc103ee

Create app_quant.py

Files changed (1):
  1. app_quant.py  +139 -0
app_quant.py ADDED
@@ -0,0 +1,139 @@
+ import torch
+ import spaces
+ import gradio as gr
+
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+ from diffusers import ZImagePipeline, AutoModel
+ from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+ # ============================================================
+ # Model Settings
+ # ============================================================
+ model_cache = "./weights/"
+ model_id = "Tongyi-MAI/Z-Image-Turbo"
+ torch_dtype = torch.bfloat16
+ USE_CPU_OFFLOAD = False
+
+ # ============================================================
+ # GPU Check
+ # ============================================================
+ if torch.cuda.is_available():
+     print(f"INFO: CUDA available: {torch.cuda.get_device_name(0)} (count={torch.cuda.device_count()})")
+     device = "cuda:0"
+     gpu_id = 0
+ else:
+     raise RuntimeError("ERROR: CUDA not available. This program requires a CUDA-enabled GPU.")
+
+ # ============================================================
+ # Load Transformer
+ # ============================================================
+ print("INFO: Loading transformer block ...")
+ quantization_config = DiffusersBitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     # Keep this module un-quantized; it is loaded in the compute dtype instead.
+     llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
+ )
+ transformer = AutoModel.from_pretrained(
+     model_id,
+     cache_dir=model_cache,
+     subfolder="transformer",
+     quantization_config=quantization_config,
+     torch_dtype=torch_dtype,
+     device_map=device,
+ )
+ print("INFO: Transformer block loaded.")
+
+ if USE_CPU_OFFLOAD:
+     transformer = transformer.to("cpu")
+
52
+ # Load Text Encoder
53
+ # ============================================================
54
+ print("INFO: Loading text encoder ...")
55
+ quantization_config = TransformersBitsAndBytesConfig(
56
+ load_in_4bit=True,
57
+ bnb_4bit_quant_type="nf4",
58
+ bnb_4bit_compute_dtype=torch.bfloat16,
59
+ bnb_4bit_use_double_quant=True,
60
+ )
61
+ text_encoder = AutoModel.from_pretrained(
62
+ model_id,
63
+ cache_dir=model_cache,
64
+ subfolder="text_encoder",
65
+ quantization_config=quantization_config,
66
+ torch_dtype=torch_dtype,
67
+ device_map=device,
68
+ )
69
+ print("INFO: Text encoder loaded.")
70
+
71
+ if USE_CPU_OFFLOAD:
72
+ text_encoder = text_encoder.to("cpu")
73
+
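+ # Optional sanity check (illustrative helper, not required by the app): rough
+ # footprint of the quantized modules, estimated from parameter/buffer storage.
+ def _footprint_gib(module):
+     size = sum(p.numel() * p.element_size() for p in module.parameters())
+     size += sum(b.numel() * b.element_size() for b in module.buffers())
+     return size / 1024**3
+
+ print(f"INFO: transformer ~{_footprint_gib(transformer):.2f} GiB, text encoder ~{_footprint_gib(text_encoder):.2f} GiB")
+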
+ # ============================================================
+ # Build Pipeline
+ # ============================================================
+ print("INFO: Building pipeline ...")
+ pipe = ZImagePipeline.from_pretrained(
+     model_id,
+     transformer=transformer,
+     text_encoder=text_encoder,
+     torch_dtype=torch_dtype,
+ )
+
+ if USE_CPU_OFFLOAD:
+     pipe.enable_model_cpu_offload(gpu_id=gpu_id)
+     print("INFO: CPU offload active")
+ else:
+     pipe.to(device)
+     print("INFO: Pipeline moved to GPU")
+
+ # ============================================================
+ # Inference Function for Gradio
+ # ============================================================
+ @spaces.GPU  # request a GPU for the duration of each call (Hugging Face Spaces)
+ def generate_image(prompt, height, width, steps, seed):
+     # Slider values may arrive as floats; the pipeline and the RNG need ints.
+     height, width, steps, seed = int(height), int(width), int(steps), int(seed)
+     generator = torch.Generator(device).manual_seed(seed)
+
+     output = pipe(
+         prompt=prompt,
+         height=height,
+         width=width,
+         num_inference_steps=steps,
+         guidance_scale=0.0,  # no classifier-free guidance for the Turbo checkpoint
+         generator=generator,
+     )
+
+     return output.images[0]
+
+
+ # ============================================================
+ # Gradio UI
+ # ============================================================
+ with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
+     gr.Markdown("# **Z-Image-Turbo — 4-bit Quantized Image Generator**")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(label="Prompt", value="Realistic mid-aged male image")
+             height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
+             width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
+             steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
+             seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
+
+             btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             output_image = gr.Image(label="Output Image")
+
+     btn.click(
+         generate_image,
+         inputs=[prompt, height, width, steps, seed],
+         outputs=[output_image],
+     )
+
+ # ============================================================
+ # Launch
+ # ============================================================
+ demo.launch()
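For a quick headless sanity check without the Gradio UI, the same pipeline call can be run directly. A minimal sketch, assuming it replaces the demo.launch() line above; the prompt, resolution, step count, and seed mirror the UI defaults, and "sample.png" is only an illustrative output path:

    image = pipe(
        prompt="Realistic mid-aged male image",
        height=1024,
        width=1024,
        num_inference_steps=9,
        guidance_scale=0.0,
        generator=torch.Generator(device).manual_seed(42),
    ).images[0]
    image.save("sample.png")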