carlex3321 committed on
Commit
6f11d0a
·
verified ·
1 Parent(s): cb89de3

Update app_animatediff.py

Browse files
Files changed (1) hide show
  1. app_animatediff.py +62 -55
app_animatediff.py CHANGED
@@ -1,14 +1,14 @@
1
- # app_gradio_img2vid.py
2
  import os, io, tempfile
3
- from typing import Optional, List
4
  from PIL import Image
5
  import torch
6
  import gradio as gr
7
  from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
8
  from diffusers.utils import export_to_gif
9
 
10
- MODEL_ID = "SG161222/Realistic_Vision_V5.1_noVAE" # SD1.5 finetunado [attached_file:1]
11
- ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5-2" # MotionAdapter p/ SD1.4/1.5 [attached_file:1]
 
12
 
13
  pipe = None
14
 
@@ -16,19 +16,27 @@ def load_pipe(model_id: str, adapter_id: str, cpu_offload: bool):
16
  global pipe
17
  if pipe is not None:
18
  return pipe
 
19
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
20
 
21
- # 1) Carregar adapter SEM dtype
22
- adapter = MotionAdapter.from_pretrained(adapter_id) # antes: dtype=dtype (removido) [attached_file:1]
23
 
24
- # 2) Carregar a pipeline com dtype
25
- p = AnimateDiffPipeline.from_pretrained(
26
- model_id,
27
- motion_adapter=adapter,
28
- dtype=dtype # ou torch_dtype=dtype dependendo da versão instalada [attached_file:1]
29
- ) # [attached_file:1]
 
 
 
 
 
 
 
30
 
31
- # 3) Scheduler recomendado
32
  p.scheduler = DDIMScheduler.from_pretrained(
33
  model_id,
34
  subfolder="scheduler",
@@ -38,23 +46,22 @@ def load_pipe(model_id: str, adapter_id: str, cpu_offload: bool):
38
  steps_offset=1
39
  ) # [attached_file:1]
40
 
41
- # 4) Otimizações VAE nas APIs novas
42
- p.vae.enable_slicing() # [attached_file:1]
43
  try:
44
- p.vae.enable_tiling() # [attached_file:1]
45
  except Exception:
46
  pass
47
 
48
- # 5) Device/offload
49
  if cpu_offload and torch.cuda.is_available():
50
- p.enable_model_cpu_offload() # [attached_file:1]
51
  else:
52
  p.to("cuda" if torch.cuda.is_available() else "cpu")
53
 
54
  pipe = p
55
  return pipe
56
 
57
-
58
  def generate(
59
  image: Image.Image,
60
  prompt: str,
@@ -72,9 +79,12 @@ def generate(
72
  cpu_offload: bool
73
  ):
74
  if image is None or not prompt or not prompt.strip():
75
- return None, None, "Envie uma imagem e um prompt válidos."
 
76
  p = load_pipe(model_id_ui or MODEL_ID, adapter_id_ui or ADAPTER_ID, cpu_offload)
 
77
  gen = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))
 
78
  # img2vid sem IP-Adapter: NÃO passar ip_adapter_image
79
  out = p(
80
  prompt=prompt,
@@ -86,64 +96,61 @@ def generate(
86
  width=int(width) if width else None,
87
  height=int(height) if height else None
88
  ) # [attached_file:1]
89
- frames = out.frames[0] # lista de PIL [attached_file:1]
90
 
91
- # GIF em memória
92
- gif_buf = io.BytesIO()
93
- export_to_gif(frames, gif_buf, fps=int(fps)) # [attached_file:1]
94
- gif_buf.seek(0)
 
95
 
 
96
  mp4_path = None
97
  if save_mp4:
98
  try:
99
  import imageio
100
  mp4_path = os.path.join(tempfile.gettempdir(), "animation.mp4")
 
101
  with imageio.get_writer(mp4_path, fps=int(fps), codec="libx264", quality=8) as writer:
102
  for fr in frames:
103
- writer.append_data(imageio.v3.imread(io.BytesIO(fr.tobytes())))
104
  except Exception:
105
- mp4_path = None
106
 
107
- return gif_buf, mp4_path, f"Gerado {len(frames)} frames @ {fps} fps."
108
 
109
  def ui():
110
  with gr.Blocks(title="AnimateDiff img2vid") as demo:
111
- gr.Markdown("## AnimateDiff img2vid")
112
  with gr.Row():
113
  with gr.Column(scale=1):
114
- image = gr.Image(type="pil", label="Imagem inicial")
115
- prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Descreva estilo/movimento...")
116
- negative = gr.Textbox(label="Negative prompt", lines=2, value="low quality, worst quality")
117
  with gr.Row():
118
- frames = gr.Slider(8, 64, value=16, step=1, label="Frames")
119
- steps = gr.Slider(4, 60, value=25, step=1, label="Steps")
120
  with gr.Row():
121
- guidance = gr.Slider(0.5, 15.0, value=7.5, step=0.5, label="Guidance")
122
- fps = gr.Slider(4, 30, value=8, step=1, label="FPS")
123
  with gr.Row():
124
- seed = gr.Number(value=42, precision=0, label="Seed")
125
- width = gr.Number(value=None, precision=0, label="Largura (opcional)")
126
- height = gr.Number(value=None, precision=0, label="Altura (opcional)")
127
  with gr.Row():
128
- model_id_ui = gr.Textbox(value=MODEL_ID, label="Model ID (SD1.5 finetune)")
129
- adapter_id_ui = gr.Textbox(value=ADAPTER_ID, label="MotionAdapter ID")
130
  with gr.Row():
131
- cpu_offload = gr.Checkbox(value=False, label="CPU offload")
132
- save_mp4 = gr.Checkbox(value=False, label="Salvar MP4")
133
- run_btn = gr.Button("Gerar animação")
134
  with gr.Column(scale=1):
135
- video_out = gr.Video(label="Preview (GIF salvo temporário)", interactive=False)
136
- file_mp4 = gr.File(label="MP4 (download)", interactive=False)
137
- status = gr.Textbox(label="Status", interactive=False)
138
 
139
  def _run(*args):
140
- gif_buf, mp4_path, msg = generate(*args)
141
- temp_gif = None
142
- if gif_buf:
143
- temp_gif = os.path.join(tempfile.gettempdir(), "animation.gif")
144
- with open(temp_gif, "wb") as f:
145
- f.write(gif_buf.read())
146
- return temp_gif, mp4_path, msg
147
 
148
  run_btn.click(
149
  _run,
@@ -154,4 +161,4 @@ def ui():
154
 
155
  if __name__ == "__main__":
156
  demo = ui()
157
- demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=True)
 
 
1
  import os, io, tempfile
2
+ from typing import Optional
3
  from PIL import Image
4
  import torch
5
  import gradio as gr
6
  from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
7
  from diffusers.utils import export_to_gif
8
 
9
+ # Modelos padrão (ajuste se desejar)
10
+ MODEL_ID = "SG161222/Realistic_Vision_V5.1_noVAE" # SD1.5 finetunado [attached_file:1]
11
+ ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5-2" # MotionAdapter p/ SD1.4/1.5 [attached_file:1]
12
 
13
  pipe = None
14
 
 
16
  global pipe
17
  if pipe is not None:
18
  return pipe
19
+ # dtype preferível: float16 em CUDA, senão float32
20
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21
 
22
+ # MotionAdapter não aceita dtype em from_pretrained nas versões atuais
23
+ adapter = MotionAdapter.from_pretrained(adapter_id) # [attached_file:1]
24
 
25
+ # Carregar pipeline com dtype
26
+ try:
27
+ p = AnimateDiffPipeline.from_pretrained(
28
+ model_id,
29
+ motion_adapter=adapter,
30
+ dtype=dtype # novas versões aceitam 'dtype' [attached_file:1]
31
+ )
32
+ except TypeError:
33
+ p = AnimateDiffPipeline.from_pretrained(
34
+ model_id,
35
+ motion_adapter=adapter,
36
+ torch_dtype=dtype # fallback para versões que ainda usam torch_dtype [attached_file:1]
37
+ )
38
 
39
+ # Scheduler recomendado para estabilidade temporal
40
  p.scheduler = DDIMScheduler.from_pretrained(
41
  model_id,
42
  subfolder="scheduler",
 
46
  steps_offset=1
47
  ) # [attached_file:1]
48
 
49
+ # Otimizações de VRAM (APIs novas via VAE)
50
+ p.vae.enable_slicing() # [attached_file:1]
51
  try:
52
+ p.vae.enable_tiling() # útil em resoluções mais altas [attached_file:1]
53
  except Exception:
54
  pass
55
 
56
+ # Alocação de device / offload
57
  if cpu_offload and torch.cuda.is_available():
58
+ p.enable_model_cpu_offload() # reduz pico de VRAM [attached_file:1]
59
  else:
60
  p.to("cuda" if torch.cuda.is_available() else "cpu")
61
 
62
  pipe = p
63
  return pipe
64
 
 
65
  def generate(
66
  image: Image.Image,
67
  prompt: str,
 
79
  cpu_offload: bool
80
  ):
81
  if image is None or not prompt or not prompt.strip():
82
+ return None, None, "Envie uma imagem e um prompt válidos." # [attached_file:1]
83
+
84
  p = load_pipe(model_id_ui or MODEL_ID, adapter_id_ui or ADAPTER_ID, cpu_offload)
85
+
86
  gen = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))
87
+
88
  # img2vid sem IP-Adapter: NÃO passar ip_adapter_image
89
  out = p(
90
  prompt=prompt,
 
96
  width=int(width) if width else None,
97
  height=int(height) if height else None
98
  ) # [attached_file:1]
 
99
 
100
+ frames = out.frames[0] # lista de PILs [attached_file:1]
101
+
102
+ # Salvar GIF em caminho temporário com extensão .gif (evita erro do PIL)
103
+ temp_gif = os.path.join(tempfile.gettempdir(), "animation.gif")
104
+ export_to_gif(frames, temp_gif, fps=int(fps)) # [attached_file:1]
105
 
106
+ # Opcional: gravar MP4 com imageio-ffmpeg
107
  mp4_path = None
108
  if save_mp4:
109
  try:
110
  import imageio
111
  mp4_path = os.path.join(tempfile.gettempdir(), "animation.mp4")
112
+ # Converter cada frame PIL para ndarray esperado pelo writer
113
  with imageio.get_writer(mp4_path, fps=int(fps), codec="libx264", quality=8) as writer:
114
  for fr in frames:
115
+ writer.append_data(imageio.v3.imread(io.BytesIO(fr.convert("RGB").tobytes())))
116
  except Exception:
117
+ mp4_path = None # se falhar, apenas não retorna MP4
118
 
119
+ return temp_gif, mp4_path, f"Gerado {len(frames)} frames @ {fps} fps." # [attached_file:1]
120
 
121
  def ui():
122
  with gr.Blocks(title="AnimateDiff img2vid") as demo:
123
+ gr.Markdown("## AnimateDiff img2vid") # [attached_file:1]
124
  with gr.Row():
125
  with gr.Column(scale=1):
126
+ image = gr.Image(type="pil", label="Imagem inicial") # [attached_file:1]
127
+ prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Descreva estilo/movimento...") # [attached_file:1]
128
+ negative = gr.Textbox(label="Negative prompt", lines=2, value="low quality, worst quality") # [attached_file:1]
129
  with gr.Row():
130
+ frames = gr.Slider(8, 64, value=16, step=1, label="Frames") # [attached_file:1]
131
+ steps = gr.Slider(4, 60, value=25, step=1, label="Steps") # [attached_file:1]
132
  with gr.Row():
133
+ guidance = gr.Slider(0.5, 15.0, value=7.5, step=0.5, label="Guidance") # [attached_file:1]
134
+ fps = gr.Slider(4, 30, value=8, step=1, label="FPS") # [attached_file:1]
135
  with gr.Row():
136
+ seed = gr.Number(value=42, precision=0, label="Seed") # [attached_file:1]
137
+ width = gr.Number(value=None, precision=0, label="Largura (opcional)") # [attached_file:1]
138
+ height = gr.Number(value=None, precision=0, label="Altura (opcional)") # [attached_file:1]
139
  with gr.Row():
140
+ model_id_ui = gr.Textbox(value=MODEL_ID, label="Model ID (SD1.5 finetune)") # [attached_file:1]
141
+ adapter_id_ui = gr.Textbox(value=ADAPTER_ID, label="MotionAdapter ID") # [attached_file:1]
142
  with gr.Row():
143
+ cpu_offload = gr.Checkbox(value=False, label="CPU offload") # [attached_file:1]
144
+ save_mp4 = gr.Checkbox(value=False, label="Salvar MP4") # [attached_file:1]
145
+ run_btn = gr.Button("Gerar animação") # [attached_file:1]
146
  with gr.Column(scale=1):
147
+ video_out = gr.Video(label="Preview (GIF)") # [attached_file:1]
148
+ file_mp4 = gr.File(label="MP4 (download)", interactive=False) # [attached_file:1]
149
+ status = gr.Textbox(label="Status", interactive=False) # [attached_file:1]
150
 
151
  def _run(*args):
152
+ temp_gif, mp4_path, msg = generate(*args)
153
+ return temp_gif, mp4_path, msg # [attached_file:1]
 
 
 
 
 
154
 
155
  run_btn.click(
156
  _run,
 
161
 
162
  if __name__ == "__main__":
163
  demo = ui()
164
+ demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=True) # [attached_file:1]