import spaces
import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
from PIL import Image

# Define paths
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone repo if missing
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    subprocess.run(
        ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
        check=True
    )

sys.path.insert(0, os.path.abspath(REPO_PATH))

# Imports from LongCat repo
from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# Download model weights if missing
if not os.path.exists(CHECKPOINT_DIR):
    snapshot_download(
        repo_id="meituan-longcat/LongCat-Video",
        local_dir=CHECKPOINT_DIR,
        ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
    )

pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models ---")
try:
    cp_split_hw = context_parallel_util.get_optimal_split(1)

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer")
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler")

    # 4-bit NF4 quantization config (prepared, but left disabled below)
    bnb_4bit_config = DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=True,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
        # quantization_config=bnb_4bit_config,  # uncomment to load the DiT in 4-bit NF4
    )
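
    # Note (assumption): passing quantization_config above should load the DiT
    # in 4-bit NF4 via diffusers' bitsandbytes loading path, provided
    # LongCatVideoTransformer3DModel supports quantized loading; it is left
    # disabled here so the default bf16 weights are used.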

    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    ).to(device)

    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors'), 'cfg_step_lora')
    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
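    # cfg_step_lora powers the distilled few-step sampling path (16 steps,
    # guidance 1.0); refinement_lora is enabled only during the optional
    # second-stage refine pass in generate_video below.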

    print("--- Models loaded successfully ---")
except Exception as e:
    print("❌ Model load error:", e)
    pipe = None


# -------------------- GPU Cleanup --------------------
def torch_gc():
    """Release cached and IPC-held CUDA memory between generation stages."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# -------------------- Video Generation --------------------
def check_duration(*args, **kwargs):
    # ZeroGPU duration hook: spaces calls this with the same arguments as
    # generate_video and expects an estimated GPU time budget in seconds.
    # duration_t2v is the 11th positional input (see the click wiring below),
    # so it must be read from the positional args, not a keyword default.
    duration_t2v = kwargs.get("duration_t2v", args[10] if len(args) > 10 else 2)
    # Rough budget: one second of GPU time per frame at 30 fps, plus 30 s overhead.
    return duration_t2v * 30 + 30

@spaces.GPU(duration=check_duration)
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    duration_t2v=2,
    progress=gr.Progress(track_tqdm=True)
):
    if pipe is None:
        raise gr.Error("Models failed to load.")

    generator = torch.Generator(device=device).manual_seed(int(seed))
    num_frames = int(duration_t2v * 30)  # frame count from requested duration at 30 fps
    print(f"Prompt: {prompt}")

    is_distill = use_distill or use_refine
    if is_distill:
        pipe.dit.enable_loras(['cfg_step_lora'])
        num_inference_steps = 16
        guidance_scale = 1.0
        neg = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        neg = neg_prompt

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    else:
        if image is None:
            raise gr.Error("Please upload an input image for image-to-video.")
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=neg,
            resolution=resolution,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]

    if is_distill:
        pipe.dit.disable_all_loras()

    torch_gc()

    if use_refine:
        progress(0.5, desc="Refining")
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()
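        # enable_bsa() turns on the block-sparse attention path that the
        # LongCat-Video refinement stage uses to keep attention tractable
        # at higher resolution.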

        frames = [(frame * 255).astype(np.uint8) for frame in output]
        frames = [Image.fromarray(f) for f in frames]
        ref_img = Image.fromarray(image) if mode == "i2v" else None

        output = pipe.generate_refine(
            image=ref_img,
            prompt=prompt,
            stage1_video=frames,
            num_cond_frames=1 if mode == "i2v" else 0,
            num_inference_steps=50,
            generator=generator,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(output, tmp.name, fps=30)
        print(f"Video saved to {tmp.name}")
        return tmp.name
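
# Illustrative smoke test (commented out): calls generate_video directly,
# bypassing the Gradio UI. Values are examples only; assumes the models loaded
# and a GPU is available. None fills the slots T2V does not use (image,
# resolution).
# video_path = generate_video(
#     "t2v", "a cat surfing at sunset", "blurry, low quality", None,
#     480, 832, None, 42, True, False, duration_t2v=2,
# )
# print(video_path)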


# -------------------- Gradio UI --------------------
css = ".fillable{max-width:960px !important}"
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video")
    gr.Markdown("13.6B parameter dense video-generation model — [HuggingFace](https://huggingface.co/meituan-longcat/LongCat-Video)")

    with gr.Tabs():
        # --- T2V ---
        with gr.TabItem("Text-to-Video"):
            mode_t2v = gr.State("t2v")
            prompt_t2v = gr.Textbox(label="Prompt", lines=4)
            neg_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
            height_t2v = gr.Slider(256, 1024, value=480, step=64, label="Height")
            width_t2v = gr.Slider(256, 1024, value=832, step=64, label="Width")
            seed_t2v = gr.Number(label="Seed", value=42)
            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False)
            duration_t2v = gr.Slider(1, 20, step=1, value=2, label="Duration (seconds)")

            t2v_button = gr.Button("Generate Video")
            video_out_t2v = gr.Video(label="Generated Video")

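            # gr.State(None) fills the inputs T2V does not use (image and
            # resolution), so both tabs share one generate_video signature.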
            t2v_button.click(
                fn=generate_video,
                inputs=[mode_t2v, prompt_t2v, neg_t2v, gr.State(None),
                        height_t2v, width_t2v, gr.State(None),
                        seed_t2v, distill_t2v, refine_t2v, duration_t2v],
                outputs=video_out_t2v
            )

        # --- I2V ---
        with gr.TabItem("Image-to-Video"):
            mode_i2v = gr.State("i2v")
            image_i2v = gr.Image(type="numpy", label="Input Image")
            prompt_i2v = gr.Textbox(label="Prompt", lines=4)
            neg_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
            resolution_i2v = gr.Dropdown(["480p", "720p"], value="480p", label="Resolution")
            seed_i2v = gr.Number(label="Seed", value=42)
            distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False)
            duration_i2v = gr.Slider(1, 20, step=1, value=2, label="Duration (seconds)")

            i2v_button = gr.Button("Generate Video")
            video_out_i2v = gr.Video(label="Generated Video")

            i2v_button.click(
                fn=generate_video,
                inputs=[mode_i2v, prompt_i2v, neg_i2v, image_i2v,
                        gr.State(None), gr.State(None), resolution_i2v,
                        seed_i2v, distill_i2v, refine_i2v, duration_i2v],
                outputs=video_out_i2v
            )

if __name__ == "__main__":
    demo.launch()