# Aduc-sdr-2_5 / app_animatediff.py
import os, tempfile
from typing import Optional
from PIL import Image
import torch
import gradio as gr
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
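# Gradio app: animates a Stable Diffusion 1.5 finetune with an AnimateDiff
# MotionAdapter and exports the result as a GIF (optionally MP4).
# Note: MP4 export additionally requires the imageio and imageio-ffmpeg
# packages, which are imported lazily inside generate().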
# Default models (adjust as desired)
MODEL_ID = "SG161222/Realistic_Vision_V5.1_noVAE"        # SD1.5 finetune
ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5-2"  # MotionAdapter for SD1.4/1.5
pipe = None
def load_pipe(model_id: str, adapter_id: str, cpu_offload: bool):
    global pipe
    if pipe is not None:
        return pipe
    # Preferred dtype: float16 on CUDA, otherwise float32
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # MotionAdapter does not accept a dtype argument in from_pretrained in current versions
    adapter = MotionAdapter.from_pretrained(adapter_id)
    # Load the pipeline with the chosen dtype
    try:
        p = AnimateDiffPipeline.from_pretrained(
            model_id,
            motion_adapter=adapter,
            dtype=dtype  # newer versions accept 'dtype'
        )
    except TypeError:
        p = AnimateDiffPipeline.from_pretrained(
            model_id,
            motion_adapter=adapter,
            torch_dtype=dtype  # fallback for versions that still use torch_dtype
        )
    # Scheduler recommended for temporal stability
    p.scheduler = DDIMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        clip_sample=False,
        timestep_spacing="linspace",
        beta_schedule="linear",
        steps_offset=1
    )
    # VRAM optimizations (the newer APIs live on the VAE)
    p.vae.enable_slicing()
    try:
        p.vae.enable_tiling()  # useful at higher resolutions
    except Exception:
        pass
    # Device placement / offload
    if cpu_offload and torch.cuda.is_available():
        p.enable_model_cpu_offload()  # lowers peak VRAM usage
    else:
        p.to("cuda" if torch.cuda.is_available() else "cpu")
    pipe = p
    return pipe
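# Note: load_pipe caches the first pipeline in the module-level `pipe` and
# returns it unchanged on later calls, so editing the Model/Adapter IDs in
# the UI after the first generation has no effect without a restart.
# A minimal warm-up sketch (hypothetical, outside the UI):
#   pipe = load_pipe(MODEL_ID, ADAPTER_ID, cpu_offload=False)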
def generate(
    image: Image.Image,
    prompt: str,
    negative_prompt: str,
    num_frames: int,
    steps: int,
    guidance: float,
    seed: int,
    width: Optional[int],
    height: Optional[int],
    fps: int,
    save_mp4: bool,
    model_id_ui: str,
    adapter_id_ui: str,
    cpu_offload: bool
):
    if image is None or not prompt or not prompt.strip():
        return None, None, "Provide a valid image and prompt."
    p = load_pipe(model_id_ui or MODEL_ID, adapter_id_ui or ADAPTER_ID, cpu_offload)
    gen = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))
    # img2vid without IP-Adapter: do NOT pass ip_adapter_image
    out = p(
        prompt=prompt,
        negative_prompt=negative_prompt or "",
        num_frames=int(num_frames),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=gen,
        width=int(width) if width else None,
        height=int(height) if height else None
    )
    frames = out.frames[0]  # list of PIL images
    # Save the GIF to a temporary path with a .gif extension (avoids a PIL error)
    temp_gif = os.path.join(tempfile.gettempdir(), "animation.gif")
    export_to_gif(frames, temp_gif, fps=int(fps))
    # Optional: write an MP4 with imageio-ffmpeg
    mp4_path = None
    if save_mp4:
        try:
            import imageio
            import numpy as np
            mp4_path = os.path.join(tempfile.gettempdir(), "animation.mp4")
            # Convert each PIL frame to the RGB ndarray the writer expects
            # (raw PIL .tobytes() is not a decodable image, so imread would fail)
            with imageio.get_writer(mp4_path, fps=int(fps), codec="libx264", quality=8) as writer:
                for fr in frames:
                    writer.append_data(np.asarray(fr.convert("RGB")))
        except Exception:
            mp4_path = None  # on failure, simply skip the MP4
    return temp_gif, mp4_path, f"Generated {len(frames)} frames @ {fps} fps."
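# A minimal headless sketch (hypothetical paths and prompt, bypassing the UI):
#   from PIL import Image
#   gif_path, mp4_path, msg = generate(
#       Image.open("input.png"),           # hypothetical input image
#       "cinematic waves, gentle motion",  # hypothetical prompt
#       "low quality", 16, 25, 7.5, 42,
#       512, 512, 8, save_mp4=False,
#       model_id_ui=MODEL_ID, adapter_id_ui=ADAPTER_ID, cpu_offload=False)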
def ui():
    with gr.Blocks(title="AnimateDiff img2vid") as demo:
        gr.Markdown("## AnimateDiff img2vid")
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Initial image")
                prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Describe style/motion...")
                negative = gr.Textbox(label="Negative prompt", lines=2, value="low quality, worst quality")
                with gr.Row():
                    frames = gr.Slider(8, 64, value=16, step=1, label="Frames")
                    steps = gr.Slider(4, 60, value=25, step=1, label="Steps")
                with gr.Row():
                    guidance = gr.Slider(0.5, 15.0, value=7.5, step=0.5, label="Guidance")
                    fps = gr.Slider(4, 30, value=8, step=1, label="FPS")
                with gr.Row():
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    width = gr.Number(value=None, precision=0, label="Width (optional)")
                    height = gr.Number(value=None, precision=0, label="Height (optional)")
                with gr.Row():
                    model_id_ui = gr.Textbox(value=MODEL_ID, label="Model ID (SD1.5 finetune)")
                    adapter_id_ui = gr.Textbox(value=ADAPTER_ID, label="MotionAdapter ID")
                with gr.Row():
                    cpu_offload = gr.Checkbox(value=False, label="CPU offload")
                    save_mp4 = gr.Checkbox(value=False, label="Save MP4")
                run_btn = gr.Button("Generate animation")
            with gr.Column(scale=1):
                video_out = gr.Video(label="Preview (GIF)")
                file_mp4 = gr.File(label="MP4 (download)", interactive=False)
                status = gr.Textbox(label="Status", interactive=False)

        def _run(*args):
            temp_gif, mp4_path, msg = generate(*args)
            return temp_gif, mp4_path, msg

        run_btn.click(
            _run,
            inputs=[image, prompt, negative, frames, steps, guidance, seed, width, height, fps, save_mp4, model_id_ui, adapter_id_ui, cpu_offload],
            outputs=[video_out, file_mp4, status]
        )
    return demo
if __name__ == "__main__":
    demo = ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=True)
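# Assumed runtime dependencies (versions not pinned here):
#   pip install torch diffusers transformers accelerate gradio pillow
# plus imageio and imageio-ffmpeg for the optional MP4 export.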