rahul7star committed (verified)
Commit b573d19 · Parent: a9bf3f2

Update app_exp.py

Files changed (1)
app_exp.py: +87 -76
app_exp.py CHANGED
@@ -11,33 +11,21 @@ from PIL import Image
 from huggingface_hub import snapshot_download, hf_hub_download
 
 # ============================================================
-# 0️⃣ FlashAttention 3 Setup
+# 0️⃣ Install required packages
 # ============================================================
-# try:
-#     print("Attempting to download and install FlashAttention wheel...")
-#     flash_attention_wheel = hf_hub_download(
-#         repo_id="rahul7star/flash-attn-3",
-#         repo_type="model",
-#         filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
-#     )
-#     subprocess.run(["pip", "install", flash_attention_wheel], check=True)
-#     site.addsitedir(site.getsitepackages()[0])
-#     importlib.invalidate_caches()
-#     print("✅ FlashAttention installed successfully.")
-#     enable_fa3 = True
-# except Exception as e:
-#     print(f"⚠️ Could not install FlashAttention: {e}")
-#     print("Continuing without FlashAttention...")
-#     enable_fa3 = False
+subprocess.run(["pip3", "install", "-U", "cache-dit"], check=True)
+
+
+
+import cache_dit
 
 # ============================================================
-# 1️⃣ Repository Setup
+# 1️⃣ Repository & Weights
 # ============================================================
 REPO_PATH = "LongCat-Video"
 CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
 
 if not os.path.exists(REPO_PATH):
-    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
     subprocess.run(
         ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
         check=True
@@ -52,10 +40,10 @@ from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
 from longcat_video.context_parallel import context_parallel_util
 from transformers import AutoTokenizer, UMT5EncoderModel
 from diffusers.utils import export_to_video
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
 
-# Download weights if not present
 if not os.path.exists(CHECKPOINT_DIR):
-    print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
     snapshot_download(
         repo_id="meituan-longcat/LongCat-Video",
         local_dir=CHECKPOINT_DIR,
@@ -64,33 +52,61 @@ if not os.path.exists(CHECKPOINT_DIR):
     )
 
 # ============================================================
-# 2️⃣ Device & Models
+# 2️⃣ Device & Models (with cache & quantization)
 # ============================================================
 device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
-
-print(f"Device: {device}, dtype: {torch_dtype}")
-
+torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
 pipe = None
+
 try:
     cp_split_hw = context_parallel_util.get_optimal_split(1)
-
     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
-    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
+
+    # Text encoder with 4-bit quantization
+    text_encoder = UMT5EncoderModel.from_pretrained(
+        CHECKPOINT_DIR,
+        subfolder="text_encoder",
+        torch_dtype=torch_dtype,
+        quantization_config=TransformersBitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch_dtype
+        )
+    )
 
     vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)
 
+    # DiT model with FP8/4-bit quantization + cache
     dit = LongCatVideoTransformer3DModel.from_pretrained(
         CHECKPOINT_DIR,
         enable_flashattn3=enable_fa3,
-        enable_flashattn2=False,
         enable_xformers=True,
         subfolder="dit",
         cp_split_hw=cp_split_hw,
         torch_dtype=torch_dtype
     )
 
+    # Enable Cache-DiT
+    cache_dit.enable_cache(
+        cache_dit.BlockAdapter(
+            transformer=dit,
+            blocks=dit.blocks,
+            forward_pattern=cache_dit.ForwardPattern.Pattern_3,
+            check_forward_pattern=False,
+            has_separate_cfg=False
+        ),
+        cache_config=cache_dit.DBCacheConfig(
+            Fn_compute_blocks=1,
+            Bn_compute_blocks=1,
+            max_warmup_steps=5,
+            max_cached_steps=50,
+            max_continuous_cached_steps=50,
+            residual_diff_threshold=0.01,
+            num_inference_steps=50
+        )
+    )
+
     pipe = LongCatVideoPipeline(
         tokenizer=tokenizer,
         text_encoder=text_encoder,
@@ -99,55 +115,57 @@ try:
         dit=dit,
     )
     pipe.to(device)
-
-    # Load LoRA weights
-    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors'), 'cfg_step_lora')
-    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
-
-    print("✅ Models loaded successfully")
+    print("✅ Models loaded with Cache-DiT and quantization")
 
 except Exception as e:
     print(f"❌ Failed to load models: {e}")
     pipe = None
 
 # ============================================================
-# 3️⃣ Generation Helpers
+# 3️⃣ Generation Helper
 # ============================================================
 def torch_gc():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
-def generate_video(
+def check_duration(
     mode,
-    prompt,
-    neg_prompt,
+    prompt,
+    neg_prompt,
     image,
     height, width, resolution,
     seed,
     use_distill,
     use_refine,
-    duration_sec,
-    progress=gr.Progress(track_tqdm=True)
+    progress
 ):
+    if use_distill and resolution=="480p":
+        return 180
+    elif resolution=="720p":
+        return 360
+    else:
+        return 900
+
+@spaces.GPU(duration=180)
+def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
+                   seed, use_distill, use_refine, duration_sec, progress=gr.Progress(track_tqdm=True)):
+
     if pipe is None:
-        raise gr.Error("Models failed to load")
+        raise gr.Error("Models not loaded")
 
-    # Adaptive FPS for faster testing
     fps = 15 if use_distill else 30
     num_frames = int(duration_sec * fps)
     generator = torch.Generator(device=device).manual_seed(int(seed))
     is_distill = use_distill or use_refine
 
-    # Stage 1
     progress(0.2, desc="Stage 1: Base Video Generation")
     pipe.dit.enable_loras(['cfg_step_lora'] if is_distill else [])
-
     num_inference_steps = 12 if is_distill else 24
     guidance_scale = 2.0 if is_distill else 4.0
     curr_neg_prompt = "" if is_distill else neg_prompt
 
-    if mode == "t2v":
+    if mode=="t2v":
         output = pipe.generate_t2v(
             prompt=prompt,
             negative_prompt=curr_neg_prompt,
@@ -176,16 +194,13 @@ def generate_video(
     pipe.dit.disable_all_loras()
     torch_gc()
 
-    # Stage 2: Optional refinement
     if use_refine:
         progress(0.5, desc="Stage 2: Refinement")
         pipe.dit.enable_loras(['refinement_lora'])
         pipe.dit.enable_bsa()
-
-        stage1_video_pil = [(frame * 255).astype(np.uint8) for frame in output]
+        stage1_video_pil = [(frame*255).astype(np.uint8) for frame in output]
         stage1_video_pil = [Image.fromarray(img) for img in stage1_video_pil]
-        refine_image = Image.fromarray(image) if mode == 'i2v' else None
-
+        refine_image = Image.fromarray(image) if mode=='i2v' else None
         output = pipe.generate_refine(
             image=refine_image,
             prompt=prompt,
@@ -194,44 +209,43 @@ def generate_video(
             num_inference_steps=50,
             generator=generator
         )[0]
-
         pipe.dit.disable_all_loras()
         pipe.dit.disable_bsa()
         torch_gc()
 
-    # Export video
     progress(1.0, desc="Exporting video")
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
-        export_to_video(output, temp_video_file.name, fps=fps)
-        return temp_video_file.name
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
+        export_to_video(output, f.name, fps=fps)
+        return f.name
 
 # ============================================================
 # 4️⃣ Gradio UI
 # ============================================================
-css = ".fillable{max-width: 960px !important}"
+css=".fillable{max-width:960px !important}"
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# 🎬 LongCat-Video")
+    gr.Markdown("# 🎬 LongCat-Video with Cache-DiT & Quantization")
     gr.Markdown("13.6B parameter dense video-generation model by Meituan — [[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]")
 
-    with gr.Tabs() as tabs:
+    with gr.Tabs():
+        # Text-to-Video
         with gr.TabItem("Text-to-Video"):
             mode_t2v = gr.State("t2v")
             with gr.Row():
                 with gr.Column(scale=2):
                     prompt_t2v = gr.Textbox(label="Prompt", lines=4)
                     neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
-                    height_t2v = gr.Slider(256, 1024, step=64, value=480, label="Height")
-                    width_t2v = gr.Slider(256, 1024, step=64, value=832, label="Width")
-                    seed_t2v = gr.Number(value=42, label="Seed")
-                    distill_t2v = gr.Checkbox(value=True, label="Use Distill Mode")
-                    refine_t2v = gr.Checkbox(value=False, label="Use Refine Mode")
-                    duration_t2v = gr.Slider(1, 20, step=1, value=2, label="Video Duration (seconds)")
-
+                    height_t2v = gr.Slider(256,1024,step=64,value=480,label="Height")
+                    width_t2v = gr.Slider(256,1024,step=64,value=832,label="Width")
+                    seed_t2v = gr.Number(value=42,label="Seed")
+                    distill_t2v = gr.Checkbox(value=True,label="Use Distill Mode")
+                    refine_t2v = gr.Checkbox(value=False,label="Use Refine Mode")
+                    duration_t2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
                     t2v_button = gr.Button("Generate Video")
                 with gr.Column(scale=3):
                     video_output_t2v = gr.Video(label="Generated Video")
 
+        # Image-to-Video
         with gr.TabItem("Image-to-Video"):
             mode_i2v = gr.State("i2v")
             with gr.Row():
@@ -240,16 +254,15 @@ with gr.Blocks(css=css) as demo:
                     prompt_i2v = gr.Textbox(label="Prompt", lines=4)
                     neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
                     resolution_i2v = gr.Dropdown(["480p","720p"], value="480p", label="Resolution")
-                    seed_i2v = gr.Number(value=42, label="Seed")
-                    distill_i2v = gr.Checkbox(value=True, label="Use Distill Mode")
-                    refine_i2v = gr.Checkbox(value=False, label="Use Refine Mode")
-                    duration_i2v = gr.Slider(1, 20, step=1, value=2, label="Video Duration (seconds)")
-
+                    seed_i2v = gr.Number(value=42,label="Seed")
+                    distill_i2v = gr.Checkbox(value=True,label="Use Distill Mode")
+                    refine_i2v = gr.Checkbox(value=False,label="Use Refine Mode")
+                    duration_i2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
                     i2v_button = gr.Button("Generate Video")
                 with gr.Column(scale=3):
                     video_output_i2v = gr.Video(label="Generated Video")
 
-    # Event binding
+    # Bind events
    t2v_button.click(
        generate_video,
        inputs=[mode_t2v, prompt_t2v, neg_prompt_t2v, gr.State(None),
@@ -266,8 +279,6 @@ with gr.Blocks(css=css) as demo:
         outputs=video_output_i2v
     )
 
-# ============================================================
-# 5️⃣ Launch
-# ============================================================
-if __name__ == "__main__":
+# Launch
+if __name__=="__main__":
     demo.launch()
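Review notes:

This commit deletes the commented-out FlashAttention setup that was the only visible mention of enable_fa3, yet the DiT loader still passes enable_flashattn3=enable_fa3. Unless enable_fa3 is defined somewhere above the first hunk (not shown here), model loading raises a NameError, the except branch swallows it, and pipe stays None. A minimal guard, assuming the FA3 wheel exposes a flash_attn_interface module (an assumption, not something this commit shows):

    import importlib.util

    # Hypothetical guard, not part of this commit: keep enable_fa3 defined and
    # only enable FlashAttention 3 when the module is actually importable.
    enable_fa3 = importlib.util.find_spec("flash_attn_interface") is not None

Relatedly, DiffusersBitsAndBytesConfig is imported but never used, and the comment above the DiT loader promises "FP8/4-bit quantization + cache" while the call itself applies no quantization config.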
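The removed load_lora calls are not re-added anywhere in the new file, but generate_video still enables 'cfg_step_lora' and 'refinement_lora' by name, so distill and refine modes now reference LoRAs that were never loaded. Restoring the two lines from the pre-commit revision, right after pipe.to(device), would keep both modes working:

    # Restored verbatim from the previous revision of app_exp.py:
    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors'), 'cfg_step_lora')
    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')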
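check_duration is defined but never referenced; the decorator pins a fixed @spaces.GPU(duration=180). ZeroGPU also accepts a callable for duration, invoked with the same arguments as the decorated function, which is presumably what the helper was written for (this also assumes import spaces appears above the first hunk). A sketch of that wiring, with check_duration's signature aligned to generate_video's:

    def check_duration(mode, prompt, neg_prompt, image, height, width, resolution,
                       seed, use_distill, use_refine, duration_sec,
                       progress=gr.Progress(track_tqdm=True)):
        # Shorter GPU reservations for the cheaper configurations.
        if use_distill and resolution == "480p":
            return 180
        elif resolution == "720p":
            return 360
        return 900

    @spaces.GPU(duration=check_duration)  # dynamic duration instead of a fixed 180 s
    def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
                       seed, use_distill, use_refine, duration_sec,
                       progress=gr.Progress(track_tqdm=True)):
        ...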