Update pipeline.py
Browse files- pipeline.py +180 -163
pipeline.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import inspect
|
| 2 |
import re
|
| 3 |
-
from typing import Callable, List, Optional, Union
|
| 4 |
import PIL
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
|
|
| 7 |
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
| 8 |
|
| 9 |
from diffusers.onnx_utils import OnnxRuntimeModel
|
|
@@ -14,7 +14,8 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|
| 14 |
|
| 15 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 16 |
|
| 17 |
-
re_attention = re.compile(
|
|
|
|
| 18 |
\\\(|
|
| 19 |
\\\)|
|
| 20 |
\\\[|
|
|
@@ -28,7 +29,9 @@ re_attention = re.compile(r"""
|
|
| 28 |
]|
|
| 29 |
[^\\()\[\]:]+|
|
| 30 |
:
|
| 31 |
-
""",
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def parse_prompt_attention(text):
|
|
@@ -81,17 +84,17 @@ def parse_prompt_attention(text):
|
|
| 81 |
text = m.group(0)
|
| 82 |
weight = m.group(1)
|
| 83 |
|
| 84 |
-
if text.startswith(
|
| 85 |
res.append([text[1:], 1.0])
|
| 86 |
-
elif text ==
|
| 87 |
round_brackets.append(len(res))
|
| 88 |
-
elif text ==
|
| 89 |
square_brackets.append(len(res))
|
| 90 |
elif weight is not None and len(round_brackets) > 0:
|
| 91 |
multiply_range(round_brackets.pop(), float(weight))
|
| 92 |
-
elif text ==
|
| 93 |
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 94 |
-
elif text ==
|
| 95 |
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 96 |
else:
|
| 97 |
res.append([text, 1.0])
|
|
@@ -117,11 +120,7 @@ def parse_prompt_attention(text):
|
|
| 117 |
return res
|
| 118 |
|
| 119 |
|
| 120 |
-
def get_prompts_with_weights(
|
| 121 |
-
pipe,
|
| 122 |
-
prompt: List[str],
|
| 123 |
-
max_length: int
|
| 124 |
-
):
|
| 125 |
r"""
|
| 126 |
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 127 |
|
|
@@ -155,9 +154,7 @@ def get_prompts_with_weights(
|
|
| 155 |
return tokens, weights
|
| 156 |
|
| 157 |
|
| 158 |
-
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
|
| 159 |
-
no_boseos_middle=True,
|
| 160 |
-
chunk_length=77):
|
| 161 |
r"""
|
| 162 |
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 163 |
"""
|
|
@@ -166,27 +163,24 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
|
|
| 166 |
for i in range(len(tokens)):
|
| 167 |
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
| 168 |
if no_boseos_middle:
|
| 169 |
-
weights[i] = [1.] + weights[i] + [1.] * (max_length - 1 - len(weights[i]))
|
| 170 |
else:
|
| 171 |
w = []
|
| 172 |
if len(weights[i]) == 0:
|
| 173 |
-
w = [1.] * weights_length
|
| 174 |
else:
|
| 175 |
for j in range((len(weights[i]) - 1) // chunk_length + 1):
|
| 176 |
-
w.append(1.) # weight for starting token in this chunk
|
| 177 |
-
w += weights[i][j * chunk_length: min(len(weights[i]), (j + 1) * chunk_length)]
|
| 178 |
-
w.append(1.) # weight for ending token in this chunk
|
| 179 |
-
w += [1.] * (weights_length - len(w))
|
| 180 |
weights[i] = w[:]
|
| 181 |
|
| 182 |
return tokens, weights
|
| 183 |
|
| 184 |
|
| 185 |
def get_unweighted_text_embeddings(
|
| 186 |
-
|
| 187 |
-
text_input: np.array,
|
| 188 |
-
chunk_length: int,
|
| 189 |
-
no_boseos_middle: Optional[bool] = True
|
| 190 |
):
|
| 191 |
"""
|
| 192 |
When the length of tokens is a multiple of the capacity of the text encoder,
|
|
@@ -197,7 +191,7 @@ def get_unweighted_text_embeddings(
|
|
| 197 |
text_embeddings = []
|
| 198 |
for i in range(max_embeddings_multiples):
|
| 199 |
# extract the i-th chunk
|
| 200 |
-
text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) * (chunk_length - 2) + 2].copy()
|
| 201 |
|
| 202 |
# cover the head and the tail by the starting and the ending tokens
|
| 203 |
text_input_chunk[:, 0] = text_input[0, 0]
|
|
@@ -224,14 +218,14 @@ def get_unweighted_text_embeddings(
|
|
| 224 |
|
| 225 |
|
| 226 |
def get_weighted_text_embeddings(
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
):
|
| 236 |
r"""
|
| 237 |
Prompts can be assigned with local weights using brackets. For example,
|
|
@@ -269,47 +263,67 @@ def get_weighted_text_embeddings(
|
|
| 269 |
uncond_prompt = [uncond_prompt]
|
| 270 |
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
| 271 |
else:
|
| 272 |
-
prompt_tokens = [
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
| 275 |
if uncond_prompt is not None:
|
| 276 |
if isinstance(uncond_prompt, str):
|
| 277 |
uncond_prompt = [uncond_prompt]
|
| 278 |
-
uncond_tokens = [
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 284 |
max_length = max([len(token) for token in prompt_tokens])
|
| 285 |
if uncond_prompt is not None:
|
| 286 |
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
| 287 |
|
| 288 |
-
max_embeddings_multiples = min(
|
| 289 |
-
|
|
|
|
| 290 |
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 291 |
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 292 |
|
| 293 |
# pad the length of tokens and weights
|
| 294 |
bos = pipe.tokenizer.bos_token_id
|
| 295 |
eos = pipe.tokenizer.eos_token_id
|
| 296 |
-
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
|
| 300 |
if uncond_prompt is not None:
|
| 301 |
-
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
|
| 305 |
|
| 306 |
# get the embeddings
|
| 307 |
-
text_embeddings = get_unweighted_text_embeddings(
|
| 308 |
-
|
|
|
|
| 309 |
prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
|
| 310 |
if uncond_prompt is not None:
|
| 311 |
-
uncond_embeddings = get_unweighted_text_embeddings(
|
| 312 |
-
|
|
|
|
| 313 |
uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
|
| 314 |
|
| 315 |
# assign weights to the prompts and normalize in the sense of mean
|
|
@@ -363,15 +377,15 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 363 |
"""
|
| 364 |
|
| 365 |
def __init__(
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
):
|
| 376 |
super().__init__()
|
| 377 |
self.register_modules(
|
|
@@ -387,26 +401,26 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 387 |
|
| 388 |
@torch.no_grad()
|
| 389 |
def __call__(
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
):
|
| 411 |
r"""
|
| 412 |
Function invoked when calling the pipeline for generation.
|
|
@@ -417,10 +431,10 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 417 |
negative_prompt (`str` or `List[str]`, *optional*):
|
| 418 |
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 419 |
if `guidance_scale` is less than `1`).
|
| 420 |
-
init_image (`
|
| 421 |
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 422 |
process.
|
| 423 |
-
mask_image (`
|
| 424 |
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
|
| 425 |
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 426 |
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
|
@@ -449,10 +463,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 449 |
eta (`float`, *optional*, defaults to 0.0):
|
| 450 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 451 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 452 |
-
generator (`
|
| 453 |
-
A
|
| 454 |
-
|
| 455 |
-
latents (`torch.FloatTensor`, *optional*):
|
| 456 |
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 457 |
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 458 |
tensor will ge generated by sampling using the supplied random `generator`.
|
|
@@ -466,7 +479,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 466 |
plain tuple.
|
| 467 |
callback (`Callable`, *optional*):
|
| 468 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 469 |
-
called with the following arguments: `callback(step: int, timestep: int, latents:
|
| 470 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 471 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 472 |
called at every step.
|
|
@@ -494,7 +507,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 494 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 495 |
|
| 496 |
if (callback_steps is None) or (
|
| 497 |
-
|
| 498 |
):
|
| 499 |
raise ValueError(
|
| 500 |
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
|
@@ -527,7 +540,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 527 |
prompt=prompt,
|
| 528 |
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
| 529 |
max_embeddings_multiples=max_embeddings_multiples,
|
| 530 |
-
**kwargs
|
| 531 |
)
|
| 532 |
|
| 533 |
text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
|
|
@@ -587,8 +600,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 587 |
|
| 588 |
# add noise to latents using the timesteps
|
| 589 |
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
|
| 590 |
-
latents = self.scheduler.add_noise(
|
| 591 |
-
|
|
|
|
| 592 |
|
| 593 |
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
| 594 |
timesteps = self.scheduler.timesteps[t_start:]
|
|
@@ -623,8 +637,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 623 |
|
| 624 |
if mask is not None:
|
| 625 |
# masking
|
| 626 |
-
init_latents_proper = self.scheduler.add_noise(
|
| 627 |
-
|
|
|
|
| 628 |
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
| 629 |
|
| 630 |
# call the callback, if provided
|
|
@@ -636,20 +651,22 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 636 |
# it seems likes there is a problem for using half-precision vae decoder if batchsize>1
|
| 637 |
image = []
|
| 638 |
for i in range(latents.shape[0]):
|
| 639 |
-
image.append(self.vae_decoder(latent_sample=latents[i:i + 1])[0])
|
| 640 |
image = np.concatenate(image)
|
| 641 |
|
| 642 |
image = np.clip(image / 2 + 0.5, 0, 1)
|
| 643 |
image = image.transpose((0, 2, 3, 1))
|
| 644 |
|
| 645 |
if self.safety_checker is not None:
|
| 646 |
-
safety_checker_input = self.feature_extractor(
|
| 647 |
-
|
| 648 |
-
|
|
|
|
| 649 |
images, has_nsfw_concept = [], []
|
| 650 |
for i in range(image.shape[0]):
|
| 651 |
-
image_i, has_nsfw_concept_i = self.safety_checker(
|
| 652 |
-
|
|
|
|
| 653 |
images.append(image_i)
|
| 654 |
has_nsfw_concept.append(has_nsfw_concept_i)
|
| 655 |
image = np.concatenate(images)
|
|
@@ -665,23 +682,23 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 665 |
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 666 |
|
| 667 |
def text2img(
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
):
|
| 686 |
r"""
|
| 687 |
Function for text-to-image generation.
|
|
@@ -710,7 +727,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 710 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 711 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 712 |
generator (`np.random.RandomState`, *optional*):
|
| 713 |
-
A
|
| 714 |
latents (`np.ndarray`, *optional*):
|
| 715 |
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 716 |
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
|
@@ -725,7 +742,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 725 |
plain tuple.
|
| 726 |
callback (`Callable`, *optional*):
|
| 727 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 728 |
-
called with the following arguments: `callback(step: int, timestep: int, latents:
|
| 729 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 730 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 731 |
called at every step.
|
|
@@ -752,26 +769,26 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 752 |
return_dict=return_dict,
|
| 753 |
callback=callback,
|
| 754 |
callback_steps=callback_steps,
|
| 755 |
-
**kwargs
|
| 756 |
)
|
| 757 |
|
| 758 |
def img2img(
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
):
|
| 776 |
r"""
|
| 777 |
Function for image-to-image generation.
|
|
@@ -804,8 +821,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 804 |
eta (`float`, *optional*, defaults to 0.0):
|
| 805 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 806 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 807 |
-
generator (`
|
| 808 |
-
A
|
| 809 |
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 810 |
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 811 |
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
@@ -816,7 +833,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 816 |
plain tuple.
|
| 817 |
callback (`Callable`, *optional*):
|
| 818 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 819 |
-
called with the following arguments: `callback(step: int, timestep: int, latents:
|
| 820 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 821 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 822 |
called at every step.
|
|
@@ -842,27 +859,27 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 842 |
return_dict=return_dict,
|
| 843 |
callback=callback,
|
| 844 |
callback_steps=callback_steps,
|
| 845 |
-
**kwargs
|
| 846 |
)
|
| 847 |
|
| 848 |
def inpaint(
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
):
|
| 867 |
r"""
|
| 868 |
Function for inpaint.
|
|
@@ -899,8 +916,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 899 |
eta (`float`, *optional*, defaults to 0.0):
|
| 900 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 901 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 902 |
-
generator (`
|
| 903 |
-
A random
|
| 904 |
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 905 |
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 906 |
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
@@ -911,7 +928,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 911 |
plain tuple.
|
| 912 |
callback (`Callable`, *optional*):
|
| 913 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 914 |
-
called with the following arguments: `callback(step: int, timestep: int, latents:
|
| 915 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 916 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 917 |
called at every step.
|
|
@@ -938,5 +955,5 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
| 938 |
return_dict=return_dict,
|
| 939 |
callback=callback,
|
| 940 |
callback_steps=callback_steps,
|
| 941 |
-
**kwargs
|
| 942 |
)
|
|
|
|
| 1 |
import inspect
|
| 2 |
import re
|
|
|
|
| 3 |
import PIL
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
+
from typing import Callable, List, Optional, Union
|
| 7 |
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
| 8 |
|
| 9 |
from diffusers.onnx_utils import OnnxRuntimeModel
|
|
|
|
| 14 |
|
| 15 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 16 |
|
| 17 |
+
re_attention = re.compile(
|
| 18 |
+
r"""
|
| 19 |
\\\(|
|
| 20 |
\\\)|
|
| 21 |
\\\[|
|
|
|
|
| 29 |
]|
|
| 30 |
[^\\()\[\]:]+|
|
| 31 |
:
|
| 32 |
+
""",
|
| 33 |
+
re.X,
|
| 34 |
+
)
|
| 35 |
|
| 36 |
|
| 37 |
def parse_prompt_attention(text):
|
|
|
|
| 84 |
text = m.group(0)
|
| 85 |
weight = m.group(1)
|
| 86 |
|
| 87 |
+
if text.startswith("\\"):
|
| 88 |
res.append([text[1:], 1.0])
|
| 89 |
+
elif text == "(":
|
| 90 |
round_brackets.append(len(res))
|
| 91 |
+
elif text == "[":
|
| 92 |
square_brackets.append(len(res))
|
| 93 |
elif weight is not None and len(round_brackets) > 0:
|
| 94 |
multiply_range(round_brackets.pop(), float(weight))
|
| 95 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 96 |
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 97 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 98 |
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 99 |
else:
|
| 100 |
res.append([text, 1.0])
|
|
|
|
| 120 |
return res
|
| 121 |
|
| 122 |
|
| 123 |
+
def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
r"""
|
| 125 |
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 126 |
|
|
|
|
| 154 |
return tokens, weights
|
| 155 |
|
| 156 |
|
| 157 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
|
|
|
|
|
|
| 158 |
r"""
|
| 159 |
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 160 |
"""
|
|
|
|
| 163 |
for i in range(len(tokens)):
|
| 164 |
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
| 165 |
if no_boseos_middle:
|
| 166 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
| 167 |
else:
|
| 168 |
w = []
|
| 169 |
if len(weights[i]) == 0:
|
| 170 |
+
w = [1.0] * weights_length
|
| 171 |
else:
|
| 172 |
for j in range((len(weights[i]) - 1) // chunk_length + 1):
|
| 173 |
+
w.append(1.0) # weight for starting token in this chunk
|
| 174 |
+
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
|
| 175 |
+
w.append(1.0) # weight for ending token in this chunk
|
| 176 |
+
w += [1.0] * (weights_length - len(w))
|
| 177 |
weights[i] = w[:]
|
| 178 |
|
| 179 |
return tokens, weights
|
| 180 |
|
| 181 |
|
| 182 |
def get_unweighted_text_embeddings(
|
| 183 |
+
pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
|
|
|
|
|
|
|
|
|
|
| 184 |
):
|
| 185 |
"""
|
| 186 |
When the length of tokens is a multiple of the capacity of the text encoder,
|
|
|
|
| 191 |
text_embeddings = []
|
| 192 |
for i in range(max_embeddings_multiples):
|
| 193 |
# extract the i-th chunk
|
| 194 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
|
| 195 |
|
| 196 |
# cover the head and the tail by the starting and the ending tokens
|
| 197 |
text_input_chunk[:, 0] = text_input[0, 0]
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
def get_weighted_text_embeddings(
|
| 221 |
+
pipe,
|
| 222 |
+
prompt: Union[str, List[str]],
|
| 223 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
| 224 |
+
max_embeddings_multiples: Optional[int] = 4,
|
| 225 |
+
no_boseos_middle: Optional[bool] = False,
|
| 226 |
+
skip_parsing: Optional[bool] = False,
|
| 227 |
+
skip_weighting: Optional[bool] = False,
|
| 228 |
+
**kwargs,
|
| 229 |
):
|
| 230 |
r"""
|
| 231 |
Prompts can be assigned with local weights using brackets. For example,
|
|
|
|
| 263 |
uncond_prompt = [uncond_prompt]
|
| 264 |
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
| 265 |
else:
|
| 266 |
+
prompt_tokens = [
|
| 267 |
+
token[1:-1]
|
| 268 |
+
for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
|
| 269 |
+
]
|
| 270 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
| 271 |
if uncond_prompt is not None:
|
| 272 |
if isinstance(uncond_prompt, str):
|
| 273 |
uncond_prompt = [uncond_prompt]
|
| 274 |
+
uncond_tokens = [
|
| 275 |
+
token[1:-1]
|
| 276 |
+
for token in pipe.tokenizer(
|
| 277 |
+
uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
|
| 278 |
+
).input_ids
|
| 279 |
+
]
|
| 280 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
| 281 |
|
| 282 |
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 283 |
max_length = max([len(token) for token in prompt_tokens])
|
| 284 |
if uncond_prompt is not None:
|
| 285 |
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
| 286 |
|
| 287 |
+
max_embeddings_multiples = min(
|
| 288 |
+
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
|
| 289 |
+
)
|
| 290 |
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 291 |
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 292 |
|
| 293 |
# pad the length of tokens and weights
|
| 294 |
bos = pipe.tokenizer.bos_token_id
|
| 295 |
eos = pipe.tokenizer.eos_token_id
|
| 296 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 297 |
+
prompt_tokens,
|
| 298 |
+
prompt_weights,
|
| 299 |
+
max_length,
|
| 300 |
+
bos,
|
| 301 |
+
eos,
|
| 302 |
+
no_boseos_middle=no_boseos_middle,
|
| 303 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
| 304 |
+
)
|
| 305 |
prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
|
| 306 |
if uncond_prompt is not None:
|
| 307 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
| 308 |
+
uncond_tokens,
|
| 309 |
+
uncond_weights,
|
| 310 |
+
max_length,
|
| 311 |
+
bos,
|
| 312 |
+
eos,
|
| 313 |
+
no_boseos_middle=no_boseos_middle,
|
| 314 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
| 315 |
+
)
|
| 316 |
uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
|
| 317 |
|
| 318 |
# get the embeddings
|
| 319 |
+
text_embeddings = get_unweighted_text_embeddings(
|
| 320 |
+
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
| 321 |
+
)
|
| 322 |
prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
|
| 323 |
if uncond_prompt is not None:
|
| 324 |
+
uncond_embeddings = get_unweighted_text_embeddings(
|
| 325 |
+
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
| 326 |
+
)
|
| 327 |
uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
|
| 328 |
|
| 329 |
# assign weights to the prompts and normalize in the sense of mean
|
|
|
|
| 377 |
"""
|
| 378 |
|
| 379 |
def __init__(
|
| 380 |
+
self,
|
| 381 |
+
vae_encoder: OnnxRuntimeModel,
|
| 382 |
+
vae_decoder: OnnxRuntimeModel,
|
| 383 |
+
text_encoder: OnnxRuntimeModel,
|
| 384 |
+
tokenizer: CLIPTokenizer,
|
| 385 |
+
unet: OnnxRuntimeModel,
|
| 386 |
+
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
| 387 |
+
safety_checker: OnnxRuntimeModel,
|
| 388 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 389 |
):
|
| 390 |
super().__init__()
|
| 391 |
self.register_modules(
|
|
|
|
| 401 |
|
| 402 |
@torch.no_grad()
|
| 403 |
def __call__(
|
| 404 |
+
self,
|
| 405 |
+
prompt: Union[str, List[str]],
|
| 406 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 407 |
+
init_image: Union[np.ndarray, PIL.Image.Image] = None,
|
| 408 |
+
mask_image: Union[np.ndarray, PIL.Image.Image] = None,
|
| 409 |
+
height: int = 512,
|
| 410 |
+
width: int = 512,
|
| 411 |
+
num_inference_steps: int = 50,
|
| 412 |
+
guidance_scale: float = 7.5,
|
| 413 |
+
strength: float = 0.8,
|
| 414 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 415 |
+
eta: float = 0.0,
|
| 416 |
+
generator: Optional[np.random.RandomState] = None,
|
| 417 |
+
latents: Optional[np.ndarray] = None,
|
| 418 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 419 |
+
output_type: Optional[str] = "pil",
|
| 420 |
+
return_dict: bool = True,
|
| 421 |
+
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
| 422 |
+
callback_steps: Optional[int] = 1,
|
| 423 |
+
**kwargs,
|
| 424 |
):
|
| 425 |
r"""
|
| 426 |
Function invoked when calling the pipeline for generation.
|
|
|
|
| 431 |
negative_prompt (`str` or `List[str]`, *optional*):
|
| 432 |
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 433 |
if `guidance_scale` is less than `1`).
|
| 434 |
+
init_image (`np.ndarray` or `PIL.Image.Image`):
|
| 435 |
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 436 |
process.
|
| 437 |
+
mask_image (`np.ndarray` or `PIL.Image.Image`):
|
| 438 |
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
|
| 439 |
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 440 |
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
|
|
|
| 463 |
eta (`float`, *optional*, defaults to 0.0):
|
| 464 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 465 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 466 |
+
generator (`np.random.RandomState`, *optional*):
|
| 467 |
+
A np.random.RandomState to make generation deterministic.
|
| 468 |
+
latents (`np.ndarray`, *optional*):
|
|
|
|
| 469 |
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 470 |
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 471 |
tensor will ge generated by sampling using the supplied random `generator`.
|
|
|
|
| 479 |
plain tuple.
|
| 480 |
callback (`Callable`, *optional*):
|
| 481 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 482 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
|
| 483 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 484 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 485 |
called at every step.
|
|
|
|
| 507 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 508 |
|
| 509 |
if (callback_steps is None) or (
|
| 510 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 511 |
):
|
| 512 |
raise ValueError(
|
| 513 |
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
|
|
|
| 540 |
prompt=prompt,
|
| 541 |
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
| 542 |
max_embeddings_multiples=max_embeddings_multiples,
|
| 543 |
+
**kwargs,
|
| 544 |
)
|
| 545 |
|
| 546 |
text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
|
|
|
|
| 600 |
|
| 601 |
# add noise to latents using the timesteps
|
| 602 |
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
|
| 603 |
+
latents = self.scheduler.add_noise(
|
| 604 |
+
torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps
|
| 605 |
+
).numpy()
|
| 606 |
|
| 607 |
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
| 608 |
timesteps = self.scheduler.timesteps[t_start:]
|
|
|
|
| 637 |
|
| 638 |
if mask is not None:
|
| 639 |
# masking
|
| 640 |
+
init_latents_proper = self.scheduler.add_noise(
|
| 641 |
+
torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
|
| 642 |
+
).numpy()
|
| 643 |
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
| 644 |
|
| 645 |
# call the callback, if provided
|
|
|
|
| 651 |
# it seems likes there is a problem for using half-precision vae decoder if batchsize>1
|
| 652 |
image = []
|
| 653 |
for i in range(latents.shape[0]):
|
| 654 |
+
image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0])
|
| 655 |
image = np.concatenate(image)
|
| 656 |
|
| 657 |
image = np.clip(image / 2 + 0.5, 0, 1)
|
| 658 |
image = image.transpose((0, 2, 3, 1))
|
| 659 |
|
| 660 |
if self.safety_checker is not None:
|
| 661 |
+
safety_checker_input = self.feature_extractor(
|
| 662 |
+
self.numpy_to_pil(image), return_tensors="np"
|
| 663 |
+
).pixel_values.astype(image.dtype)
|
| 664 |
+
# There will throw an error if use safety_checker directly and batchsize>1
|
| 665 |
images, has_nsfw_concept = [], []
|
| 666 |
for i in range(image.shape[0]):
|
| 667 |
+
image_i, has_nsfw_concept_i = self.safety_checker(
|
| 668 |
+
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
| 669 |
+
)
|
| 670 |
images.append(image_i)
|
| 671 |
has_nsfw_concept.append(has_nsfw_concept_i)
|
| 672 |
image = np.concatenate(images)
|
|
|
|
| 682 |
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 683 |
|
| 684 |
def text2img(
|
| 685 |
+
self,
|
| 686 |
+
prompt: Union[str, List[str]],
|
| 687 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 688 |
+
height: int = 512,
|
| 689 |
+
width: int = 512,
|
| 690 |
+
num_inference_steps: int = 50,
|
| 691 |
+
guidance_scale: float = 7.5,
|
| 692 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 693 |
+
eta: float = 0.0,
|
| 694 |
+
generator: Optional[np.random.RandomState] = None,
|
| 695 |
+
latents: Optional[np.ndarray] = None,
|
| 696 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 697 |
+
output_type: Optional[str] = "pil",
|
| 698 |
+
return_dict: bool = True,
|
| 699 |
+
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
| 700 |
+
callback_steps: Optional[int] = 1,
|
| 701 |
+
**kwargs,
|
| 702 |
):
|
| 703 |
r"""
|
| 704 |
Function for text-to-image generation.
|
|
|
|
| 727 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 728 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 729 |
generator (`np.random.RandomState`, *optional*):
|
| 730 |
+
A np.random.RandomState to make generation deterministic.
|
| 731 |
latents (`np.ndarray`, *optional*):
|
| 732 |
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 733 |
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
|
|
|
| 742 |
plain tuple.
|
| 743 |
callback (`Callable`, *optional*):
|
| 744 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 745 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
|
| 746 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 747 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 748 |
called at every step.
|
|
|
|
| 769 |
return_dict=return_dict,
|
| 770 |
callback=callback,
|
| 771 |
callback_steps=callback_steps,
|
| 772 |
+
**kwargs,
|
| 773 |
)
|
| 774 |
|
| 775 |
def img2img(
|
| 776 |
+
self,
|
| 777 |
+
init_image: Union[np.ndarray, PIL.Image.Image],
|
| 778 |
+
prompt: Union[str, List[str]],
|
| 779 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 780 |
+
strength: float = 0.8,
|
| 781 |
+
num_inference_steps: Optional[int] = 50,
|
| 782 |
+
guidance_scale: Optional[float] = 7.5,
|
| 783 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 784 |
+
eta: Optional[float] = 0.0,
|
| 785 |
+
generator: Optional[np.random.RandomState] = None,
|
| 786 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 787 |
+
output_type: Optional[str] = "pil",
|
| 788 |
+
return_dict: bool = True,
|
| 789 |
+
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
| 790 |
+
callback_steps: Optional[int] = 1,
|
| 791 |
+
**kwargs,
|
| 792 |
):
|
| 793 |
r"""
|
| 794 |
Function for image-to-image generation.
|
|
|
|
| 821 |
eta (`float`, *optional*, defaults to 0.0):
|
| 822 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 823 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 824 |
+
generator (`np.random.RandomState`, *optional*):
|
| 825 |
+
A np.random.RandomState to make generation deterministic.
|
| 826 |
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 827 |
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 828 |
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
|
|
| 833 |
plain tuple.
|
| 834 |
callback (`Callable`, *optional*):
|
| 835 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 836 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
|
| 837 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 838 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 839 |
called at every step.
|
|
|
|
| 859 |
return_dict=return_dict,
|
| 860 |
callback=callback,
|
| 861 |
callback_steps=callback_steps,
|
| 862 |
+
**kwargs,
|
| 863 |
)
|
| 864 |
|
| 865 |
def inpaint(
|
| 866 |
+
self,
|
| 867 |
+
init_image: Union[np.ndarray, PIL.Image.Image],
|
| 868 |
+
mask_image: Union[np.ndarray, PIL.Image.Image],
|
| 869 |
+
prompt: Union[str, List[str]],
|
| 870 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 871 |
+
strength: float = 0.8,
|
| 872 |
+
num_inference_steps: Optional[int] = 50,
|
| 873 |
+
guidance_scale: Optional[float] = 7.5,
|
| 874 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 875 |
+
eta: Optional[float] = 0.0,
|
| 876 |
+
generator: Optional[np.random.RandomState] = None,
|
| 877 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 878 |
+
output_type: Optional[str] = "pil",
|
| 879 |
+
return_dict: bool = True,
|
| 880 |
+
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
| 881 |
+
callback_steps: Optional[int] = 1,
|
| 882 |
+
**kwargs,
|
| 883 |
):
|
| 884 |
r"""
|
| 885 |
Function for inpaint.
|
|
|
|
| 916 |
eta (`float`, *optional*, defaults to 0.0):
|
| 917 |
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 918 |
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 919 |
+
generator (`np.random.RandomState`, *optional*):
|
| 920 |
+
A np.random.RandomState to make generation deterministic.
|
| 921 |
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 922 |
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 923 |
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
|
|
| 928 |
plain tuple.
|
| 929 |
callback (`Callable`, *optional*):
|
| 930 |
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 931 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
|
| 932 |
callback_steps (`int`, *optional*, defaults to 1):
|
| 933 |
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 934 |
called at every step.
|
|
|
|
| 955 |
return_dict=return_dict,
|
| 956 |
callback=callback,
|
| 957 |
callback_steps=callback_steps,
|
| 958 |
+
**kwargs,
|
| 959 |
)
|