Update pipeline.py (+21 −21): fill in the previously empty body of
`AnimateDiffControlNetPipeline.encode_image`, copied from
`StableDiffusionPipeline.encode_image` in diffusers.
@@ -374,28 +374,28 @@ class AnimateDiffControlNetPipeline(
|
|
| 374 |
|
| 375 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
| 376 |
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
|
| 398 |
-
|
| 399 |
|
| 400 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
| 401 |
def prepare_ip_adapter_image_embeds(
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an IP-Adapter conditioning image with ``self.image_encoder``.

    Args:
        image: A PIL image / array accepted by ``self.feature_extractor``, or an
            already-preprocessed ``torch.Tensor`` of pixel values.
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt (int): Each embedding row is repeated this many
            times (``repeat_interleave`` on the batch dim) to match the number
            of images generated per prompt.
        output_hidden_states (bool, optional): When truthy, return the
            penultimate hidden state of the encoder instead of the pooled
            ``image_embeds`` (used by IP-Adapter "plus" variants).

    Returns:
        Tuple of ``(cond, uncond)`` embeddings. In the hidden-states branch the
        unconditional embedding comes from encoding an all-zeros image; in the
        pooled branch it is simply a zeros tensor of the same shape.
    """
    # Match the encoder's parameter dtype so mixed-precision encoders work.
    dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        # Raw (PIL/array) input: run it through the CLIP feature extractor first.
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        # IP-Adapter-plus path: take the second-to-last hidden state.
        image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        # Unconditional branch encodes a black (all-zeros) image.
        uncond_image_enc_hidden_states = self.image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return image_enc_hidden_states, uncond_image_enc_hidden_states
    else:
        # Standard IP-Adapter path: pooled image embedding; uncond is zeros.
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds
|
| 400 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
| 401 |
def prepare_ip_adapter_image_embeds(
|