# app_quant.py — Z-Image-Turbo 4-bit quantized text-to-image Gradio app.
# Provenance (from the Hugging Face file-viewer header this file was copied from):
#   Space: Image2Video / app_quant.py — author: rahul7star — commit a7fb1fd (verified) — 4.51 kB
import torch
import spaces
import gradio as gr
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# ============================================================
# Model Settings
# ============================================================
# Directory where downloaded checkpoint files are cached locally.
model_cache = "./weights/"
# Hugging Face repository id of the Z-Image-Turbo checkpoint.
model_id = "Tongyi-MAI/Z-Image-Turbo"
# Compute dtype used for the non-quantized parts of the models.
torch_dtype = torch.bfloat16
# When True, models are parked on CPU and streamed to the GPU on demand.
USE_CPU_OFFLOAD = False
# ============================================================
# GPU Check
# ============================================================
# Fail fast when no CUDA device is present: the 4-bit (bitsandbytes)
# quantized models below cannot run without one.
if not torch.cuda.is_available():
    raise RuntimeError("ERROR: CUDA not available. This program requires a CUDA-enabled GPU.")
print(f"INFO: CUDA available: {torch.cuda.get_device_name(0)} (count={torch.cuda.device_count()})")
device = "cuda:0"
gpu_id = 0
# ============================================================
# Load Transformer
# ============================================================
print("INFO: Loading transformer block ...")
# 4-bit NF4 quantization for the diffusion transformer; double quantization
# shrinks the quantization constants themselves. One submodule is skipped
# (kept unquantized) — presumably because quantizing the modulation layer
# hurts quality; NOTE(review): confirm against the model card.
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
transformer = AutoModel.from_pretrained(
    model_id,
    subfolder="transformer",
    cache_dir=model_cache,
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Transformer block loaded.")
# NOTE(review): moving a bitsandbytes-quantized model that was loaded with
# device_map back to CPU may not be supported — verify if offload is enabled.
if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")
# ============================================================
# Load Text Encoder
# ============================================================
print("INFO: Loading text encoder ...")
# Same 4-bit NF4 scheme as the transformer, but via the transformers-side
# config class (the text encoder is a transformers model, not a diffusers one).
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
text_encoder = AutoModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    cache_dir=model_cache,
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
print("INFO: Text encoder loaded.")
# NOTE(review): same caveat as the transformer — a .to("cpu") on a
# bitsandbytes-quantized model may raise; verify before enabling offload.
if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")
# ============================================================
# Build Pipeline
# ============================================================
print("INFO: Building pipeline ...")
# Assemble the full pipeline around the two pre-quantized components; the
# remaining parts (VAE, tokenizer, scheduler) come from the checkpoint.
pipe = ZImagePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)
if not USE_CPU_OFFLOAD:
    pipe.to(device)
    print("INFO: Pipeline to GPU")
else:
    # Submodules are shuttled to the GPU one at a time during inference.
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    print("INFO: CPU offload active")
# ============================================================
# Inference Function for Gradio
# ============================================================
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    """Generate one image with the quantized Z-Image-Turbo pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        height: Output height in pixels (slider value; may arrive as float).
        width: Output width in pixels (slider value; may arrive as float).
        steps: Number of denoising steps.
        seed: RNG seed for reproducible generation.

    Returns:
        The first generated image from the pipeline output (PIL image).
    """
    # Gradio Slider values are delivered as floats; torch.Generator.manual_seed
    # and the pipeline's integer parameters require ints — cast explicitly.
    generator = torch.Generator(device).manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=0.0,  # Turbo (distilled) model: run without CFG
        generator=generator,
    )
    return output.images[0]
# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo — 4bit Quantized Image Generator**")
    with gr.Row():
        # Left column: all generation controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic mid-aged male image")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        # Right column: the generated result.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")
    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image],
    )

# ============================================================
# Launch
# ============================================================
demo.launch()