import gradio as gr
import os
import sys
import torch
import spaces  # Required for the @spaces.GPU decorator when running on HF Spaces GPU hardware
import shutil
import uuid
import subprocess
from glob import glob
from huggingface_hub import snapshot_download
import tempfile
from moviepy.editor import VideoFileClip
from pydub import AudioSegment
import argparse
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler

# Ensure these imports match your project structure if latentsync isn't installed globally
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed
from latentsync.whisper.audio2feature import Audio2Feature

# --- Model Download ---
# Best practice: define a cache dir if needed, or let HF Hub handle it
MODEL_CACHE_DIR = "./checkpoints"
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

print("Downloading models...")
try:
    snapshot_download(
        repo_id="chunyu-li/LatentSync",
        local_dir=MODEL_CACHE_DIR,
        # allow_patterns=["*.pt", "*.yaml", "*.json"],  # Optionally specify patterns
        # ignore_patterns=["*.safetensors"],  # Optionally ignore patterns
        local_dir_use_symlinks=False,  # Recommended on Spaces to avoid symlink issues
    )
    print("Model download complete.")
except Exception as e:
    print(f"Error downloading models: {e}")
    # Decide how to handle download failure (e.g., exit, retry, or use local files if available).
    # For simplicity we proceed, but a real app might exit here.
    # sys.exit(1)

# --- Device Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)  # Important for inference

if torch.cuda.is_available():
    torch_dtype = torch.float16
    print(f"Running on GPU ({torch.cuda.get_device_name(0)}) with dtype {torch_dtype}")
else:
    torch_dtype = torch.float32
    print(f"Running on CPU with dtype {torch_dtype}")


# --- Preprocessing Functions ---
def process_video(input_video_path, temp_dir, max_duration=10):
    """
    Crop video to max_duration if longer. Saves to temp_dir.
    """
    os.makedirs(temp_dir, exist_ok=True)
    output_video_path = os.path.join(temp_dir, f"cropped_{uuid.uuid4()}.mp4")  # Unique name
    try:
        print(f"Processing video: {input_video_path}")
        with VideoFileClip(input_video_path) as video:
            if video.duration > max_duration:
                print(f"Video duration ({video.duration}s) > {max_duration}s. Cropping.")
                video = video.subclip(0, max_duration)
            else:
                print(f"Video duration ({video.duration}s) <= {max_duration}s. No cropping needed.")
            # Use recommended codecs; adjust bitrate/preset for the quality/speed trade-off
            video.write_videofile(output_video_path, codec="libx264", audio_codec="aac", logger=None)  # suppress verbose logs
        print(f"Processed video saved to: {output_video_path}")
        return output_video_path
    except Exception as e:
        print(f"Error processing video {input_video_path}: {e}")
        # Clean up the partial file if it exists
        if os.path.exists(output_video_path):
            os.remove(output_video_path)
        raise  # Re-raise the exception


def process_audio(file_path, temp_dir, max_duration_ms=8000):
    """
    Trim audio to max_duration_ms if longer. Saves WAV to temp_dir.
    """
    os.makedirs(temp_dir, exist_ok=True)
    output_path = os.path.join(temp_dir, f"trimmed_audio_{uuid.uuid4()}.wav")  # Unique name
    try:
        print(f"Processing audio: {file_path}")
        audio = AudioSegment.from_file(file_path)
        if len(audio) > max_duration_ms:
            print(f"Audio duration ({len(audio)}ms) > {max_duration_ms}ms. Trimming.")
Trimming.") audio = audio[:max_duration_ms] else: print(f"Audio duration ({len(audio)}ms) <= {max_duration_ms}ms. No trimming needed.") audio.export(output_path, format="wav") print(f"Processed audio saved at: {output_path}") return output_path except Exception as e: print(f"Error processing audio {file_path}: {e}") # Clean up partial file if it exists if os.path.exists(output_path): os.remove(output_path) raise # Re-raise the exception # --- Main Inference Function --- # Use @spaces.GPU decorator ONLY if deploying on Hugging Face Spaces GPU hardware # Remove it if running locally or on CPU Spaces. @spaces.GPU # type: ignore # Add type ignore if 'spaces' is conditionally imported or might not be present def main(video_path, audio_path, progress=gr.Progress(track_tqdm=True)): """ Main function to perform lip synchronization. Handles preprocessing, model loading, inference, and cleanup. """ if not video_path or not audio_path: raise gr.Error("Please provide both a video and an audio file.") print(f"\n--- Starting Job ---") print(f"Received Video: {video_path}") print(f"Received Audio: {audio_path}") # --- Configuration & Model Paths --- # It's often cleaner to load OmegaConf config outside the main function if static # but keeping it here is fine too. unet_config_path = "configs/unet/second_stage.yaml" # Relative to MODEL_CACHE_DIR or root config = OmegaConf.load(os.path.join(MODEL_CACHE_DIR, unet_config_path)) inference_ckpt_path = os.path.join(MODEL_CACHE_DIR, "latentsync_unet.pt") vae_model_path = "stabilityai/sd-vae-ft-mse" # Standard VAE # Determine Whisper model based on UNet config cross_attention_dim = config.model.get("cross_attention_dim", 768) # Use .get for safety if cross_attention_dim == 768: whisper_model_name = "small.pt" elif cross_attention_dim == 384: whisper_model_name = "tiny.pt" else: raise ValueError(f"Unsupported cross_attention_dim: {cross_attention_dim}") whisper_model_path = os.path.join(MODEL_CACHE_DIR, "whisper", whisper_model_name) print(f"Using UNet config: {unet_config_path}") print(f"Using UNet checkpoint: {inference_ckpt_path}") print(f"Using VAE: {vae_model_path}") print(f"Using Whisper model: {whisper_model_path}") # --- Preprocessing Handling --- # Decide if preprocessing (cropping/trimming) is always needed, or conditional # For API usage, you might want to *always* enforce limits for consistency & resource control. # Let's enforce limits here. 
    # Create a temporary directory for this specific job
    temp_dir = tempfile.mkdtemp()
    print(f"Created temporary directory: {temp_dir}")

    processed_video_path = None
    processed_audio_path = None

    try:
        # Process video (ensure it is within limits)
        processed_video_path = process_video(video_path, temp_dir, max_duration=10)  # Enforce 10s max
        # Process audio (ensure it is within limits)
        processed_audio_path = process_audio(audio_path, temp_dir, max_duration_ms=8000)  # Enforce 8s max

        # --- Load Models (load after the potentially long download/preprocessing) ---
        print("Loading models...")
        progress(0.1, desc="Loading VAE...")
        vae = AutoencoderKL.from_pretrained(vae_model_path, torch_dtype=torch_dtype).to(device)
        # Explicitly set the scaling factor if needed, though this is often handled by the pipeline
        # vae.config.scaling_factor = 0.18215
        # vae.config.shift_factor = 0

        progress(0.3, desc="Loading Audio Encoder...")
        audio_encoder = Audio2Feature(model_path=whisper_model_path, device=device, num_frames=config.data.num_frames)

        progress(0.5, desc="Loading UNet...")
        unet, _ = UNet3DConditionModel.from_pretrained(
            OmegaConf.to_container(config.model),
            inference_ckpt_path,
            device=device,
            torch_dtype=torch_dtype,  # Pass dtype here
        )
        # unet = unet.to(dtype=torch_dtype)  # Already handled by from_pretrained when dtype is passed

        # Enable xFormers if available (check compatibility with the installed torch version)
        if is_xformers_available():
            try:
                unet.enable_xformers_memory_efficient_attention()
                print("xFormers memory efficient attention enabled.")
            except Exception as e:
                print(f"Could not enable xFormers: {e}. Running without it.")

        progress(0.7, desc="Setting up Pipeline...")
        scheduler = DDIMScheduler.from_pretrained(os.path.join(MODEL_CACHE_DIR, "configs"))  # Load scheduler config
        pipeline = LipsyncPipeline(
            vae=vae,
            audio_encoder=audio_encoder,
            unet=unet,
            scheduler=scheduler,
        ).to(device)
        print("Models loaded and pipeline ready.")

        # --- Inference ---
        # Note: accelerate's set_seed also seeds NumPy, which only accepts 32-bit seeds,
        # so reduce the 64-bit value returned by torch.seed() into that range.
        seed = torch.seed() % (2**32)
        set_seed(seed)
        print(f"Using Seed: {seed}")

        output_dir = "outputs"  # Directory for final outputs
        os.makedirs(output_dir, exist_ok=True)
        unique_id = str(uuid.uuid4())
        video_out_path = os.path.join(output_dir, f"video_out_{unique_id}.mp4")

        print(f"Starting inference. Output will be saved to: {video_out_path}")
        progress(0.8, desc="Generating Lip Sync...")

        # Note: the pipeline call may run its own tqdm progress bar, which can conflict with
        # gr.Progress. If you see double progress bars, consider removing gr.Progress or
        # silencing the pipeline's own tqdm output.
        pipeline(
            video_path=processed_video_path,  # Use the processed video
            audio_path=processed_audio_path,  # Use the processed audio
            video_out_path=video_out_path,
            video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),  # Mask may not be needed by API users
            num_frames=config.data.num_frames,
            num_inference_steps=config.run.inference_steps,
            guidance_scale=1.0,
            weight_dtype=torch_dtype,  # Pass the determined dtype
            width=config.data.resolution,
            height=config.data.resolution,
            # Pass the progress object here if the pipeline supports it
        )

        progress(1.0, desc="Completed!")
        print(f"Inference complete. Output video: {video_out_path}")
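        # Optional sanity check (not part of the original pipeline call, a minimal addition):
        # confirm the pipeline actually wrote the output file before returning, so a silent
        # failure surfaces as a clear Gradio error instead of a dead path downstream.
        if not os.path.exists(video_out_path):
            raise gr.Error("Inference finished but no output video was found at the expected path.")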
Output video: {video_out_path}") return video_out_path # Return the path to the final output video except Exception as e: print(f"Error during main execution: {e}") import traceback traceback.print_exc() # Print detailed traceback for debugging # Raise a Gradio specific error to show it nicely in the UI/API response raise gr.Error(f"An error occurred: {e}") finally: # --- Cleanup --- if temp_dir and os.path.exists(temp_dir): try: shutil.rmtree(temp_dir) print(f"Cleaned up temporary directory: {temp_dir}") except Exception as e: print(f"Error cleaning up temporary directory {temp_dir}: {e}") print("--- Job Finished ---") # --- Gradio Interface Definition --- css=""" div#col-container { margin: 0 auto; max-width: 982px; /* Adjust as needed */ } /* Add other CSS styling if desired */ """ # Using a standard theme for broader compatibility, replace if custom theme is essential theme = gr.themes.Soft( primary_hue="blue", secondary_hue="purple", ).set( # Example of setting specific component properties if needed # button_primary_background_fill="#007bff", ) with gr.Blocks(theme=theme, css=css) as demo: gr.Markdown("# LatentSync: Audio Conditioned Latent Diffusion Models for Lip Sync") gr.Markdown("Generate lip-synced videos based on an input video and audio.") gr.Markdown("**Note:** Input videos longer than 10s and audio longer than 8s will be automatically trimmed.") gr.Markdown("**Demo by [Sunder Ali Khowaja](https://sander-ali.github.io) - [X](https://x.com/SunderAKhowaja) -[Github](https://github.com/sander-ali) -[Hugging Face](https://huggingface.co/SunderAli17)**") with gr.Row(): with gr.Column(scale=1): video_input = gr.Video(label="Reference Video", info="Upload the video to lip-sync (max 10s used).", format="mp4") audio_input = gr.Audio(label="Target Audio", info="Upload the audio to sync lips to (max 8s used).", type="filepath") # type="filepath" is crucial submit_btn = gr.Button("Generate Lip Sync", variant="primary") with gr.Column(scale=1): video_result = gr.Video(label="Result", info="Generated lip-synced video.") with gr.Row(): gr.Examples( examples_per_page=3, examples=[ # --- CORRECTED PATHS --- ["assets/demo1_video.mp4", "assets/demo1_audio.wav"], ["assets/demo2_video.mp4", "assets/demo2_audio.wav"], ["assets/demo3_video.mp4", "assets/demo3_audio.wav"], ], inputs=[video_input, audio_input], outputs=[video_result], # Output to video_result for examples fn=main, # Make examples clickable cache_examples=False, # Set True if inputs/outputs are deterministic and you want caching # Ensure the 'assets' directory exists at the root of your Space repo! ) # Link the button click to the main function submit_btn.click( fn=main, inputs=[video_input, audio_input], outputs=[video_result], api_name="predict" # <-- This makes fn 'main' available at the /run/predict endpoint ) # --- Launch the Gradio App --- # Use queue() for handling multiple users/long inference times # show_api=True exposes the API documentation link in the footer demo.queue().launch( show_api=True, show_error=True, # Shows Python errors in the UI console, helpful for debugging share=True, # Set to True to create a temporary public link (useful for testing) # server_name="0.0.0.0", # Optional: Makes it accessible on local network # server_port=7860, # Optional: Specify port debug=True # Optional: More verbose logging )