Akshat7777 commited on
Commit
730aa68
·
verified ·
1 Parent(s): 78e4a3d

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +42 -25
generate.py CHANGED
@@ -1,37 +1,54 @@
1
  import torch
2
- from diffusers import AnimateDiffPipeline, DDIMScheduler
 
3
 
4
- # Load model only once (on import)
5
  def load_model():
6
- # Example AnimateDiff model from HF Hub
7
- model_id = "guoyww/animatediff-motion-adapter-v1-5"
 
 
 
 
8
 
9
- pipe = AnimateDiffPipeline.from_pretrained(
10
- model_id,
11
- torch_dtype=torch.float16
12
- )
 
 
13
 
14
- scheduler = DDIMScheduler.from_pretrained(
15
- "runwayml/stable-diffusion-v1-5",
16
- subfolder="scheduler"
17
- )
 
18
 
19
- pipe.scheduler = scheduler
20
- pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- return pipe
 
23
 
 
 
 
24
 
25
- # Global model instance (loaded once)
26
  pipe = load_model()
27
 
28
 
29
- # Generation function
30
- def generate(prompt: str, num_inference_steps: int = 50, guidance_scale: float = 7.5):
31
- with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
32
- result = pipe(
33
- prompt,
34
- num_inference_steps=num_inference_steps,
35
- guidance_scale=guidance_scale
36
- )
37
- return result.frames[0] # returning first frame (you can adapt this to video/gif)
 
 
 
 
 
 
 
1
  import torch
2
+ from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
3
+ from diffusers.utils import export_to_gif
4
 
 
5
  def load_model():
6
+ try:
7
+ # Load Motion Adapter
8
+ adapter = MotionAdapter.from_pretrained(
9
+ "guoyww/animatediff-motion-adapter-v1-5",
10
+ torch_dtype=torch.float16
11
+ )
12
 
13
+ # Load AnimateDiff pipeline with Stable Diffusion 1.5
14
+ pipeline = AnimateDiffPipeline.from_pretrained(
15
+ "runwayml/stable-diffusion-v1-5",
16
+ motion_adapter=adapter,
17
+ torch_dtype=torch.float16
18
+ )
19
 
20
+ # Use Euler scheduler (smoother animations)
21
+ pipeline.scheduler = EulerDiscreteScheduler.from_config(
22
+ pipeline.scheduler.config,
23
+ timestep_spacing="trailing"
24
+ )
25
 
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ pipeline = pipeline.to(device)
28
 
29
+ print("✅ Models loaded successfully!")
30
+ return pipeline
31
 
32
+ except Exception as e:
33
+ print(f"❌ Error during model loading: {e}")
34
+ raise
35
 
36
+ # Load once globally
37
  pipe = load_model()
38
 
39
 
40
+ def generate(prompt: str, num_frames: int = 16, steps: int = 25, guidance: float = 7.5, seed: int = 42, out_path: str = "output.gif"):
41
+ """
42
+ Generate an animated GIF from a text prompt.
43
+ """
44
+ generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
45
+ result = pipe(
46
+ prompt=prompt,
47
+ num_frames=num_frames,
48
+ num_inference_steps=steps,
49
+ guidance_scale=guidance,
50
+ generator=generator
51
+ )
52
+ frames = result.frames[0]
53
+ export_to_gif(frames, out_path)
54
+ return out_path