# LongCat-Video / app_exp.py
import spaces
import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
from PIL import Image
# Define paths
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
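# The checkpoint directory is expected to hold the subfolders loaded below:
# tokenizer/, text_encoder/, vae/, scheduler/, dit/, plus lora/ with
# cfg_step_lora.safetensors and refinement_lora.safetensors.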
# Clone repo if missing
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    subprocess.run(
        ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
        check=True,
    )
sys.path.insert(0, os.path.abspath(REPO_PATH))
# Imports from LongCat repo
from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
# Download model weights if missing
if not os.path.exists(CHECKPOINT_DIR):
    snapshot_download(
        repo_id="meituan-longcat/LongCat-Video",
        local_dir=CHECKPOINT_DIR,
        local_dir_use_symlinks=False,
        ignore_patterns=["*.md", "*.gitattributes", "assets/*"],
    )
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
print("--- Initializing Models ---")
try:
    cp_split_hw = context_parallel_util.get_optimal_split(1)

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

    # 4-bit (NF4) quantization config; defined here but currently unused, since
    # quantization_config is commented out in the DiT load below.
    bnb_4bit_config = DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=True,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
        # quantization_config=bnb_4bit_config,  # uncomment to load the DiT in 4-bit
    )

    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    ).to(device)

    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors'), 'cfg_step_lora')
    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')

    print("--- Models loaded successfully ---")
except Exception as e:
    print("❌ Model load error:", e)
    pipe = None
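# Minimal sketch of calling the pipeline directly (outside Gradio), assuming the
# load above succeeded; the prompt, seed, frame count, and output filename are
# illustrative, the argument names mirror the generate_t2v call further down.
# Uncomment to smoke-test the distilled sampler:
#
# if pipe is not None:
#     g = torch.Generator(device=device).manual_seed(0)
#     pipe.dit.enable_loras(['cfg_step_lora'])
#     frames = pipe.generate_t2v(
#         prompt="a cat walking on a beach at sunset",
#         negative_prompt="",
#         height=480, width=832, num_frames=60,
#         num_inference_steps=16, use_distill=True, guidance_scale=1.0,
#         generator=g,
#     )[0]
#     pipe.dit.disable_all_loras()
#     export_to_video(frames, "smoke_test.mp4", fps=30)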
# -------------------- GPU Cleanup --------------------
def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
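# torch_gc() is called in generate_video between the base sampling pass and the
# refinement pass (and again after refinement) to release cached VRAM.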
# -------------------- Video Generation --------------------
def check_duration(*_args, duration_t2v=2, **_kwargs):
    fps = 30
    return duration_t2v * fps + 30
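# check_duration above is passed as duration=check_duration to @spaces.GPU, so the
# ZeroGPU allocation scales with the requested clip: duration_t2v * 30 fps plus a
# 30-second buffer, e.g. the default 2 s clip requests 2 * 30 + 30 = 90 s of GPU time.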
@spaces.GPU(duration=check_duration)
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    duration_t2v=2,
    progress=gr.Progress(track_tqdm=True),
):
    if pipe is None:
        raise gr.Error("Models failed to load.")

    generator = torch.Generator(device=device).manual_seed(int(seed))
    num_frames = int(duration_t2v * 30)  # ✅ duration-based frame count
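    # In this app, refinement builds on the distilled base pass, so ticking either
    # checkbox switches the first stage to the cfg_step_lora schedule set up below
    # (16 steps, guidance 1.0, empty negative prompt).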
    print(prompt)

    is_distill = use_distill or use_refine
    if is_distill:
        pipe.dit.enable_loras(['cfg_step_lora'])
        num_inference_steps = 16
        guidance_scale = 1.0
        neg = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        neg = neg_prompt
if mode == "t2v":
output = pipe.generate_t2v(
prompt=prompt,
negative_prompt=neg,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
use_distill=is_distill,
guidance_scale=guidance_scale,
generator=generator,
)[0]
else:
pil_image = Image.fromarray(image)
output = pipe.generate_i2v(
image=pil_image,
prompt=prompt,
negative_prompt=neg,
resolution=resolution,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
use_distill=is_distill,
guidance_scale=guidance_scale,
generator=generator,
)[0]
    if is_distill:
        pipe.dit.disable_all_loras()
    torch_gc()

    if use_refine:
        progress(0.5, desc="Refining")
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()

        frames = [(frame * 255).astype(np.uint8) for frame in output]
        frames = [Image.fromarray(f) for f in frames]
        ref_img = Image.fromarray(image) if mode == "i2v" else None

        output = pipe.generate_refine(
            image=ref_img,
            prompt=prompt,
            stage1_video=frames,
            num_cond_frames=1 if mode == "i2v" else 0,
            num_inference_steps=50,
            generator=generator,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(output, tmp.name, fps=30)
        print("video generated")
        return tmp.name
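# Sketch of a direct call (hypothetical prompt and seed); the T2V tab below wires
# the same positional arguments, passing None for the image and resolution slots:
#
# video_path = generate_video(
#     "t2v", "a red fox running through fresh snow", "", None,
#     480, 832, None,
#     42, True, False, duration_t2v=2,
# )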
# -------------------- Gradio UI --------------------
css = ".fillable{max-width:960px !important}"
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video")
    gr.Markdown("13.6B parameter dense video-generation model — [HuggingFace](https://huggingface.co/meituan-longcat/LongCat-Video)")

    with gr.Tabs():
        # --- T2V ---
        with gr.TabItem("Text-to-Video"):
            mode_t2v = gr.State("t2v")
            prompt_t2v = gr.Textbox(label="Prompt", lines=4)
            neg_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
            height_t2v = gr.Slider(256, 1024, value=480, step=64, label="Height")
            width_t2v = gr.Slider(256, 1024, value=832, step=64, label="Width")
            seed_t2v = gr.Number(label="Seed", value=42)
            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False)
            duration_t2v = gr.Slider(1, 20, step=1, value=2, label="Duration (seconds)")  # ✅ added
            t2v_button = gr.Button("Generate Video")
            video_out_t2v = gr.Video(label="Generated Video")

            t2v_button.click(
                fn=generate_video,
                inputs=[mode_t2v, prompt_t2v, neg_t2v, gr.State(None),
                        height_t2v, width_t2v, gr.State(None),
                        seed_t2v, distill_t2v, refine_t2v, duration_t2v],
                outputs=video_out_t2v,
            )

        # --- I2V ---
        with gr.TabItem("Image-to-Video"):
            mode_i2v = gr.State("i2v")
            image_i2v = gr.Image(type="numpy", label="Input Image")
            prompt_i2v = gr.Textbox(label="Prompt", lines=4)
            neg_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
            resolution_i2v = gr.Dropdown(["480p", "720p"], value="480p", label="Resolution")
            seed_i2v = gr.Number(label="Seed", value=42)
            distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False)
            duration_i2v = gr.Slider(1, 20, step=1, value=2, label="Duration (seconds)")  # ✅ added
            i2v_button = gr.Button("Generate Video")
            video_out_i2v = gr.Video(label="Generated Video")

            i2v_button.click(
                fn=generate_video,
                inputs=[mode_i2v, prompt_i2v, neg_i2v, image_i2v,
                        gr.State(None), gr.State(None), resolution_i2v,
                        seed_i2v, distill_i2v, refine_i2v, duration_i2v],
                outputs=video_out_i2v,
            )
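# Optional launch variant with a request queue, useful when several users share
# the Space (the max_size value is illustrative):
#
# demo.queue(max_size=8).launch()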
if __name__ == "__main__":
    demo.launch()