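"""Gradio app for text-to-video and image-to-video generation with Wan2.2-TI2V-5B on Hugging Face ZeroGPU.

The defaults below (1280x704, 73 frames, 35 inference steps) are chosen so that a single
generation fits within the GPU duration requested via @spaces.GPU.
"""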
# IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
import spaces
# Standard library imports
import os
# Third-party imports (non-CUDA)
import numpy as np
from PIL import Image
import gradio as gr
# CUDA-related imports (must come after spaces)
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
# Model configuration
MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
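# bfloat16 is used for the transformer and text encoder; the VAE is loaded separately in
# float32 inside initialize_pipeline() below.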
# Global pipeline variable
pipe = None
def initialize_pipeline():
"""Initialize the Wan2.2 pipeline"""
global pipe
if pipe is None:
print("Loading Wan2.2-TI2V-5B model...")
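# Following the diffusers Wan examples, the VAE is kept in float32 rather than bfloat16,
# presumably to avoid precision-related artifacts when decoding frames.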
vae = AutoencoderKLWan.from_pretrained(
MODEL_ID,
subfolder="vae",
torch_dtype=torch.float32
)
pipe = WanPipeline.from_pretrained(
MODEL_ID,
vae=vae,
torch_dtype=dtype
)
pipe.to(device)
print("Model loaded successfully!")
return pipe
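# Note: if GPU memory were tight, pipe.enable_model_cpu_offload() could be used instead of
# pipe.to(device) above; with a full ZeroGPU allocation, moving the whole pipeline to the
# GPU is the simpler option.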
@spaces.GPU(duration=180) # Allocate GPU for 3 minutes (max allowed for Pro)
def generate_video(
prompt: str,
image: Image.Image = None,
width: int = 1280,
height: int = 704,
num_frames: int = 73,
num_inference_steps: int = 35,
guidance_scale: float = 5.0,
seed: int = -1
):
"""
Generate video from text prompt and optional image
Args:
prompt: Text description of the video to generate
image: Optional input image for image-to-video generation
width: Video width (default: 1280)
height: Video height (default: 704)
num_frames: Number of frames to generate (default: 73 for 3 seconds at 24fps)
num_inference_steps: Number of denoising steps (default: 35 for faster generation)
guidance_scale: Guidance scale for generation (default: 5.0)
seed: Random seed for reproducibility (-1 for random)
"""
try:
# Initialize pipeline
pipeline = initialize_pipeline()
# Set seed for reproducibility
if seed == -1:
seed = torch.randint(0, 2**32 - 1, (1,)).item()
generator = torch.Generator(device=device).manual_seed(seed)
# Prepare generation parameters
gen_params = {
"prompt": prompt,
"height": height,
"width": width,
"num_frames": num_frames,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": generator,
}
# Add image if provided (for image-to-video)
if image is not None:
gen_params["image"] = image
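# Caveat: WanPipeline is primarily a text-to-video pipeline; if it rejects the `image`
# keyword, the image-to-video path would need a dedicated pipeline such as diffusers'
# WanImageToVideoPipeline.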
# Generate video
print(f"Generating video with prompt: {prompt}")
print(f"Parameters: {width}x{height}, {num_frames} frames, seed: {seed}")
output = pipeline(**gen_params).frames[0]
# Export to video file
output_path = "output.mp4"
export_to_video(output, output_path, fps=24)
return output_path, f"Video generated successfully! Seed used: {seed}"
except Exception as e:
error_msg = f"Error generating video: {str(e)}"
print(error_msg)
return None, error_msg
# Create Gradio interface
with gr.Blocks(title="Wan2.2 Video Generation") as demo:
gr.Markdown(
"""
# Wan2.2 Video Generation
Generate high-quality videos from text prompts or images using the Wan2.2-TI2V-5B model.
This model supports both **Text-to-Video** and **Image-to-Video** generation at 720P/24fps.
**Note:** Generation takes 2-3 minutes. Settings are optimized for Zero GPU 3-minute limit.
"""
)
with gr.Row():
with gr.Column():
# Input controls
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate...",
lines=3,
value="Two anthropomorphic cats in comfy boxing gear fight on stage"
)
image_input = gr.Image(
label="Input Image (Optional - for Image-to-Video)",
type="pil",
sources=["upload"]
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
width_input = gr.Slider(
label="Width",
minimum=512,
maximum=1920,
step=64,
value=1280
)
height_input = gr.Slider(
label="Height",
minimum=512,
maximum=1080,
step=64,
value=704
)
num_frames_input = gr.Slider(
label="Number of Frames (more frames = longer video)",
minimum=25,
maximum=145,
step=24,
value=73,
info="73 frames ≈ 3 seconds at 24fps (optimized for Zero GPU limits)"
)
num_steps_input = gr.Slider(
label="Inference Steps (more steps = better quality, slower)",
minimum=20,
maximum=60,
step=5,
value=35
)
guidance_scale_input = gr.Slider(
label="Guidance Scale (higher = closer to prompt)",
minimum=1.0,
maximum=15.0,
step=0.5,
value=5.0
)
seed_input = gr.Number(
label="Seed (-1 for random)",
value=-1,
precision=0
)
generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
with gr.Column():
# Output
video_output = gr.Video(
label="Generated Video",
autoplay=True
)
status_output = gr.Textbox(
label="Status",
lines=2
)
# Examples
gr.Examples(
examples=[
[
"Two anthropomorphic cats in comfy boxing gear fight on stage",
None,
1280,
704,
73,
35,
5.0,
42
],
[
"A serene underwater scene with colorful coral reefs and tropical fish swimming gracefully",
None,
1280,
704,
73,
35,
5.0,
123
],
[
"A bustling futuristic city at night with neon lights and flying cars",
None,
1280,
704,
73,
35,
5.0,
456
],
[
"A peaceful mountain landscape with snow-capped peaks and a flowing river",
None,
1280,
704,
73,
35,
5.0,
789
],
],
inputs=[
prompt_input,
image_input,
width_input,
height_input,
num_frames_input,
num_steps_input,
guidance_scale_input,
seed_input
],
outputs=[video_output, status_output],
fn=generate_video,
cache_examples=False,
)
# Connect generate button
generate_btn.click(
fn=generate_video,
inputs=[
prompt_input,
image_input,
width_input,
height_input,
num_frames_input,
num_steps_input,
guidance_scale_input,
seed_input
],
outputs=[video_output, status_output]
)
gr.Markdown(
"""
## Tips for Best Results:
- Use detailed, descriptive prompts
- For image-to-video: Upload a clear image that matches your prompt
- Higher inference steps = better quality but slower generation
- Adjust guidance scale to balance creativity vs. prompt adherence
- Use the same seed to reproduce results
- Keep generation under 3 minutes to fit Zero GPU limits
## Model Information:
- Model: Wan2.2-TI2V-5B (5B parameters)
- Resolution: 720P (1280x704 or custom)
- Frame Rate: 24 fps
- Default Duration: 3 seconds (optimized for Zero GPU)
- Generation Time: ~2-3 minutes on Zero GPU (with optimized settings)
"""
)
# Launch the app
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch()
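# Outside a Hugging Face Space, the @spaces.GPU decorator should act as a no-op, so the
# same script can also be run on a local CUDA machine (assumption: a recent `spaces` package).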