skytnt committed on
Commit
1184e3c
·
1 Parent(s): 919cf9e

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +180 -163
pipeline.py CHANGED
@@ -1,9 +1,9 @@
1
  import inspect
2
  import re
3
- from typing import Callable, List, Optional, Union
4
  import PIL
5
  import numpy as np
6
  import torch
 
7
  from transformers import CLIPFeatureExtractor, CLIPTokenizer
8
 
9
  from diffusers.onnx_utils import OnnxRuntimeModel
@@ -14,7 +14,8 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
14
 
15
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
 
17
- re_attention = re.compile(r"""
 
18
  \\\(|
19
  \\\)|
20
  \\\[|
@@ -28,7 +29,9 @@ re_attention = re.compile(r"""
28
  ]|
29
  [^\\()\[\]:]+|
30
  :
31
- """, re.X)
 
 
32
 
33
 
34
  def parse_prompt_attention(text):
@@ -81,17 +84,17 @@ def parse_prompt_attention(text):
81
  text = m.group(0)
82
  weight = m.group(1)
83
 
84
- if text.startswith('\\'):
85
  res.append([text[1:], 1.0])
86
- elif text == '(':
87
  round_brackets.append(len(res))
88
- elif text == '[':
89
  square_brackets.append(len(res))
90
  elif weight is not None and len(round_brackets) > 0:
91
  multiply_range(round_brackets.pop(), float(weight))
92
- elif text == ')' and len(round_brackets) > 0:
93
  multiply_range(round_brackets.pop(), round_bracket_multiplier)
94
- elif text == ']' and len(square_brackets) > 0:
95
  multiply_range(square_brackets.pop(), square_bracket_multiplier)
96
  else:
97
  res.append([text, 1.0])
@@ -117,11 +120,7 @@ def parse_prompt_attention(text):
117
  return res
118
 
119
 
120
- def get_prompts_with_weights(
121
- pipe,
122
- prompt: List[str],
123
- max_length: int
124
- ):
125
  r"""
126
  Tokenize a list of prompts and return its tokens with weights of each token.
127
 
@@ -155,9 +154,7 @@ def get_prompts_with_weights(
155
  return tokens, weights
156
 
157
 
158
- def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
159
- no_boseos_middle=True,
160
- chunk_length=77):
161
  r"""
162
  Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
163
  """
@@ -166,27 +163,24 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
166
  for i in range(len(tokens)):
167
  tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
168
  if no_boseos_middle:
169
- weights[i] = [1.] + weights[i] + [1.] * (max_length - 1 - len(weights[i]))
170
  else:
171
  w = []
172
  if len(weights[i]) == 0:
173
- w = [1.] * weights_length
174
  else:
175
  for j in range((len(weights[i]) - 1) // chunk_length + 1):
176
- w.append(1.) # weight for starting token in this chunk
177
- w += weights[i][j * chunk_length: min(len(weights[i]), (j + 1) * chunk_length)]
178
- w.append(1.) # weight for ending token in this chunk
179
- w += [1.] * (weights_length - len(w))
180
  weights[i] = w[:]
181
 
182
  return tokens, weights
183
 
184
 
185
  def get_unweighted_text_embeddings(
186
- pipe,
187
- text_input: np.array,
188
- chunk_length: int,
189
- no_boseos_middle: Optional[bool] = True
190
  ):
191
  """
192
  When the length of tokens is a multiple of the capacity of the text encoder,
@@ -197,7 +191,7 @@ def get_unweighted_text_embeddings(
197
  text_embeddings = []
198
  for i in range(max_embeddings_multiples):
199
  # extract the i-th chunk
200
- text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) * (chunk_length - 2) + 2].copy()
201
 
202
  # cover the head and the tail by the starting and the ending tokens
203
  text_input_chunk[:, 0] = text_input[0, 0]
@@ -224,14 +218,14 @@ def get_unweighted_text_embeddings(
224
 
225
 
226
  def get_weighted_text_embeddings(
227
- pipe,
228
- prompt: Union[str, List[str]],
229
- uncond_prompt: Optional[Union[str, List[str]]] = None,
230
- max_embeddings_multiples: Optional[int] = 4,
231
- no_boseos_middle: Optional[bool] = False,
232
- skip_parsing: Optional[bool] = False,
233
- skip_weighting: Optional[bool] = False,
234
- **kwargs
235
  ):
236
  r"""
237
  Prompts can be assigned with local weights using brackets. For example,
@@ -269,47 +263,67 @@ def get_weighted_text_embeddings(
269
  uncond_prompt = [uncond_prompt]
270
  uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
271
  else:
272
- prompt_tokens = [token[1:-1] for token in
273
- pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids]
274
- prompt_weights = [[1.] * len(token) for token in prompt_tokens]
 
 
275
  if uncond_prompt is not None:
276
  if isinstance(uncond_prompt, str):
277
  uncond_prompt = [uncond_prompt]
278
- uncond_tokens = [token[1:-1] for token in
279
- pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True,
280
- return_tensors="np").input_ids]
281
- uncond_weights = [[1.] * len(token) for token in uncond_tokens]
 
 
 
282
 
283
  # round up the longest length of tokens to a multiple of (model_max_length - 2)
284
  max_length = max([len(token) for token in prompt_tokens])
285
  if uncond_prompt is not None:
286
  max_length = max(max_length, max([len(token) for token in uncond_tokens]))
287
 
288
- max_embeddings_multiples = min(max_embeddings_multiples,
289
- (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1)
 
290
  max_embeddings_multiples = max(1, max_embeddings_multiples)
291
  max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
292
 
293
  # pad the length of tokens and weights
294
  bos = pipe.tokenizer.bos_token_id
295
  eos = pipe.tokenizer.eos_token_id
296
- prompt_tokens, prompt_weights = pad_tokens_and_weights(prompt_tokens, prompt_weights, max_length, bos, eos,
297
- no_boseos_middle=no_boseos_middle,
298
- chunk_length=pipe.tokenizer.model_max_length)
 
 
 
 
 
 
299
  prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
300
  if uncond_prompt is not None:
301
- uncond_tokens, uncond_weights = pad_tokens_and_weights(uncond_tokens, uncond_weights, max_length, bos, eos,
302
- no_boseos_middle=no_boseos_middle,
303
- chunk_length=pipe.tokenizer.model_max_length)
 
 
 
 
 
 
304
  uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
305
 
306
  # get the embeddings
307
- text_embeddings = get_unweighted_text_embeddings(pipe, prompt_tokens, pipe.tokenizer.model_max_length,
308
- no_boseos_middle=no_boseos_middle)
 
309
  prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
310
  if uncond_prompt is not None:
311
- uncond_embeddings = get_unweighted_text_embeddings(pipe, uncond_tokens, pipe.tokenizer.model_max_length,
312
- no_boseos_middle=no_boseos_middle)
 
313
  uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
314
 
315
  # assign weights to the prompts and normalize in the sense of mean
@@ -363,15 +377,15 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
363
  """
364
 
365
  def __init__(
366
- self,
367
- vae_encoder: OnnxRuntimeModel,
368
- vae_decoder: OnnxRuntimeModel,
369
- text_encoder: OnnxRuntimeModel,
370
- tokenizer: CLIPTokenizer,
371
- unet: OnnxRuntimeModel,
372
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
373
- safety_checker: OnnxRuntimeModel,
374
- feature_extractor: CLIPFeatureExtractor,
375
  ):
376
  super().__init__()
377
  self.register_modules(
@@ -387,26 +401,26 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
387
 
388
  @torch.no_grad()
389
  def __call__(
390
- self,
391
- prompt: Union[str, List[str]],
392
- negative_prompt: Optional[Union[str, List[str]]] = None,
393
- init_image: Union[np.ndarray, PIL.Image.Image] = None,
394
- mask_image: Union[np.ndarray, PIL.Image.Image] = None,
395
- height: int = 512,
396
- width: int = 512,
397
- num_inference_steps: int = 50,
398
- guidance_scale: float = 7.5,
399
- strength: float = 0.8,
400
- num_images_per_prompt: Optional[int] = 1,
401
- eta: float = 0.0,
402
- generator: Optional[np.random.RandomState] = None,
403
- latents: Optional[np.ndarray] = None,
404
- max_embeddings_multiples: Optional[int] = 3,
405
- output_type: Optional[str] = "pil",
406
- return_dict: bool = True,
407
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
408
- callback_steps: Optional[int] = 1,
409
- **kwargs,
410
  ):
411
  r"""
412
  Function invoked when calling the pipeline for generation.
@@ -417,10 +431,10 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
417
  negative_prompt (`str` or `List[str]`, *optional*):
418
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
419
  if `guidance_scale` is less than `1`).
420
- init_image (`torch.FloatTensor` or `PIL.Image.Image`):
421
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
422
  process.
423
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
424
  `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
425
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
426
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
@@ -449,10 +463,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
449
  eta (`float`, *optional*, defaults to 0.0):
450
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
451
  [`schedulers.DDIMScheduler`], will be ignored for others.
452
- generator (`torch.Generator`, *optional*):
453
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
454
- deterministic.
455
- latents (`torch.FloatTensor`, *optional*):
456
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
457
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
458
  tensor will be generated by sampling using the supplied random `generator`.
@@ -466,7 +479,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
466
  plain tuple.
467
  callback (`Callable`, *optional*):
468
  A function that will be called every `callback_steps` steps during inference. The function will be
469
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
470
  callback_steps (`int`, *optional*, defaults to 1):
471
  The frequency at which the `callback` function will be called. If not specified, the callback will be
472
  called at every step.
@@ -494,7 +507,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
494
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
495
 
496
  if (callback_steps is None) or (
497
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
498
  ):
499
  raise ValueError(
500
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -527,7 +540,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
527
  prompt=prompt,
528
  uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
529
  max_embeddings_multiples=max_embeddings_multiples,
530
- **kwargs
531
  )
532
 
533
  text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
@@ -587,8 +600,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
587
 
588
  # add noise to latents using the timesteps
589
  noise = generator.randn(*init_latents.shape).astype(latents_dtype)
590
- latents = self.scheduler.add_noise(torch.from_numpy(init_latents), torch.from_numpy(noise),
591
- timesteps).numpy()
 
592
 
593
  t_start = max(num_inference_steps - init_timestep + offset, 0)
594
  timesteps = self.scheduler.timesteps[t_start:]
@@ -623,8 +637,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
623
 
624
  if mask is not None:
625
  # masking
626
- init_latents_proper = self.scheduler.add_noise(torch.from_numpy(init_latents_orig),
627
- torch.from_numpy(noise), torch.tensor([t])).numpy()
 
628
  latents = (init_latents_proper * mask) + (latents * (1 - mask))
629
 
630
  # call the callback, if provided
@@ -636,20 +651,22 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
636
  # it seems like there is a problem with using a half-precision vae decoder if batchsize>1
637
  image = []
638
  for i in range(latents.shape[0]):
639
- image.append(self.vae_decoder(latent_sample=latents[i:i + 1])[0])
640
  image = np.concatenate(image)
641
 
642
  image = np.clip(image / 2 + 0.5, 0, 1)
643
  image = image.transpose((0, 2, 3, 1))
644
 
645
  if self.safety_checker is not None:
646
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image),
647
- return_tensors="np").pixel_values.astype(image.dtype)
648
- # There will throw an error if batchsize>1
 
649
  images, has_nsfw_concept = [], []
650
  for i in range(image.shape[0]):
651
- image_i, has_nsfw_concept_i = self.safety_checker(clip_input=safety_checker_input[i:i + 1],
652
- images=image[i:i + 1])
 
653
  images.append(image_i)
654
  has_nsfw_concept.append(has_nsfw_concept_i)
655
  image = np.concatenate(images)
@@ -665,23 +682,23 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
665
  return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
666
 
667
  def text2img(
668
- self,
669
- prompt: Union[str, List[str]],
670
- negative_prompt: Optional[Union[str, List[str]]] = None,
671
- height: int = 512,
672
- width: int = 512,
673
- num_inference_steps: int = 50,
674
- guidance_scale: float = 7.5,
675
- num_images_per_prompt: Optional[int] = 1,
676
- eta: float = 0.0,
677
- generator: Optional[np.random.RandomState] = None,
678
- latents: Optional[np.ndarray] = None,
679
- max_embeddings_multiples: Optional[int] = 3,
680
- output_type: Optional[str] = "pil",
681
- return_dict: bool = True,
682
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
683
- callback_steps: Optional[int] = 1,
684
- **kwargs,
685
  ):
686
  r"""
687
  Function for text-to-image generation.
@@ -710,7 +727,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
710
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
711
  [`schedulers.DDIMScheduler`], will be ignored for others.
712
  generator (`np.random.RandomState`, *optional*):
713
- A numpy RandomState to make generation deterministic.
714
  latents (`np.ndarray`, *optional*):
715
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
716
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -725,7 +742,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
725
  plain tuple.
726
  callback (`Callable`, *optional*):
727
  A function that will be called every `callback_steps` steps during inference. The function will be
728
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
729
  callback_steps (`int`, *optional*, defaults to 1):
730
  The frequency at which the `callback` function will be called. If not specified, the callback will be
731
  called at every step.
@@ -752,26 +769,26 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
752
  return_dict=return_dict,
753
  callback=callback,
754
  callback_steps=callback_steps,
755
- **kwargs
756
  )
757
 
758
  def img2img(
759
- self,
760
- init_image: Union[np.ndarray, PIL.Image.Image],
761
- prompt: Union[str, List[str]],
762
- negative_prompt: Optional[Union[str, List[str]]] = None,
763
- strength: float = 0.8,
764
- num_inference_steps: Optional[int] = 50,
765
- guidance_scale: Optional[float] = 7.5,
766
- num_images_per_prompt: Optional[int] = 1,
767
- eta: Optional[float] = 0.0,
768
- generator: Optional[np.random.RandomState] = None,
769
- max_embeddings_multiples: Optional[int] = 3,
770
- output_type: Optional[str] = "pil",
771
- return_dict: bool = True,
772
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
773
- callback_steps: Optional[int] = 1,
774
- **kwargs,
775
  ):
776
  r"""
777
  Function for image-to-image generation.
@@ -804,8 +821,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
804
  eta (`float`, *optional*, defaults to 0.0):
805
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
806
  [`schedulers.DDIMScheduler`], will be ignored for others.
807
- generator (`torch.Generator`, *optional*):
808
- A numpy RandomState to make generation deterministic.
809
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
810
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
811
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -816,7 +833,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
816
  plain tuple.
817
  callback (`Callable`, *optional*):
818
  A function that will be called every `callback_steps` steps during inference. The function will be
819
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
820
  callback_steps (`int`, *optional*, defaults to 1):
821
  The frequency at which the `callback` function will be called. If not specified, the callback will be
822
  called at every step.
@@ -842,27 +859,27 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
842
  return_dict=return_dict,
843
  callback=callback,
844
  callback_steps=callback_steps,
845
- **kwargs
846
  )
847
 
848
  def inpaint(
849
- self,
850
- init_image: Union[np.ndarray, PIL.Image.Image],
851
- mask_image: Union[np.ndarray, PIL.Image.Image],
852
- prompt: Union[str, List[str]],
853
- negative_prompt: Optional[Union[str, List[str]]] = None,
854
- strength: float = 0.8,
855
- num_inference_steps: Optional[int] = 50,
856
- guidance_scale: Optional[float] = 7.5,
857
- num_images_per_prompt: Optional[int] = 1,
858
- eta: Optional[float] = 0.0,
859
- generator: Optional[np.random.RandomState] = None,
860
- max_embeddings_multiples: Optional[int] = 3,
861
- output_type: Optional[str] = "pil",
862
- return_dict: bool = True,
863
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
864
- callback_steps: Optional[int] = 1,
865
- **kwargs,
866
  ):
867
  r"""
868
  Function for inpaint.
@@ -899,8 +916,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
899
  eta (`float`, *optional*, defaults to 0.0):
900
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
901
  [`schedulers.DDIMScheduler`], will be ignored for others.
902
- generator (`torch.Generator`, *optional*):
903
- A random RandomState to make generation deterministic.
904
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
905
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
906
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -911,7 +928,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
911
  plain tuple.
912
  callback (`Callable`, *optional*):
913
  A function that will be called every `callback_steps` steps during inference. The function will be
914
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
915
  callback_steps (`int`, *optional*, defaults to 1):
916
  The frequency at which the `callback` function will be called. If not specified, the callback will be
917
  called at every step.
@@ -938,5 +955,5 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
938
  return_dict=return_dict,
939
  callback=callback,
940
  callback_steps=callback_steps,
941
- **kwargs
942
  )
 
1
  import inspect
2
  import re
 
3
  import PIL
4
  import numpy as np
5
  import torch
6
+ from typing import Callable, List, Optional, Union
7
  from transformers import CLIPFeatureExtractor, CLIPTokenizer
8
 
9
  from diffusers.onnx_utils import OnnxRuntimeModel
 
14
 
15
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
 
17
+ re_attention = re.compile(
18
+ r"""
19
  \\\(|
20
  \\\)|
21
  \\\[|
 
29
  ]|
30
  [^\\()\[\]:]+|
31
  :
32
+ """,
33
+ re.X,
34
+ )
35
 
36
 
37
  def parse_prompt_attention(text):
 
84
  text = m.group(0)
85
  weight = m.group(1)
86
 
87
+ if text.startswith("\\"):
88
  res.append([text[1:], 1.0])
89
+ elif text == "(":
90
  round_brackets.append(len(res))
91
+ elif text == "[":
92
  square_brackets.append(len(res))
93
  elif weight is not None and len(round_brackets) > 0:
94
  multiply_range(round_brackets.pop(), float(weight))
95
+ elif text == ")" and len(round_brackets) > 0:
96
  multiply_range(round_brackets.pop(), round_bracket_multiplier)
97
+ elif text == "]" and len(square_brackets) > 0:
98
  multiply_range(square_brackets.pop(), square_bracket_multiplier)
99
  else:
100
  res.append([text, 1.0])
 
120
  return res
121
 
122
 
123
+ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
 
 
 
 
124
  r"""
125
  Tokenize a list of prompts and return its tokens with weights of each token.
126
 
 
154
  return tokens, weights
155
 
156
 
157
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
 
 
158
  r"""
159
  Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
160
  """
 
163
  for i in range(len(tokens)):
164
  tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
165
  if no_boseos_middle:
166
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
167
  else:
168
  w = []
169
  if len(weights[i]) == 0:
170
+ w = [1.0] * weights_length
171
  else:
172
  for j in range((len(weights[i]) - 1) // chunk_length + 1):
173
+ w.append(1.0) # weight for starting token in this chunk
174
+ w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
175
+ w.append(1.0) # weight for ending token in this chunk
176
+ w += [1.0] * (weights_length - len(w))
177
  weights[i] = w[:]
178
 
179
  return tokens, weights
180
 
181
 
182
  def get_unweighted_text_embeddings(
183
+ pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
 
 
 
184
  ):
185
  """
186
  When the length of tokens is a multiple of the capacity of the text encoder,
 
191
  text_embeddings = []
192
  for i in range(max_embeddings_multiples):
193
  # extract the i-th chunk
194
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
195
 
196
  # cover the head and the tail by the starting and the ending tokens
197
  text_input_chunk[:, 0] = text_input[0, 0]
 
218
 
219
 
220
  def get_weighted_text_embeddings(
221
+ pipe,
222
+ prompt: Union[str, List[str]],
223
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
224
+ max_embeddings_multiples: Optional[int] = 4,
225
+ no_boseos_middle: Optional[bool] = False,
226
+ skip_parsing: Optional[bool] = False,
227
+ skip_weighting: Optional[bool] = False,
228
+ **kwargs,
229
  ):
230
  r"""
231
  Prompts can be assigned with local weights using brackets. For example,
 
263
  uncond_prompt = [uncond_prompt]
264
  uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
265
  else:
266
+ prompt_tokens = [
267
+ token[1:-1]
268
+ for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
269
+ ]
270
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
271
  if uncond_prompt is not None:
272
  if isinstance(uncond_prompt, str):
273
  uncond_prompt = [uncond_prompt]
274
+ uncond_tokens = [
275
+ token[1:-1]
276
+ for token in pipe.tokenizer(
277
+ uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
278
+ ).input_ids
279
+ ]
280
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
281
 
282
  # round up the longest length of tokens to a multiple of (model_max_length - 2)
283
  max_length = max([len(token) for token in prompt_tokens])
284
  if uncond_prompt is not None:
285
  max_length = max(max_length, max([len(token) for token in uncond_tokens]))
286
 
287
+ max_embeddings_multiples = min(
288
+ max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
289
+ )
290
  max_embeddings_multiples = max(1, max_embeddings_multiples)
291
  max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
292
 
293
  # pad the length of tokens and weights
294
  bos = pipe.tokenizer.bos_token_id
295
  eos = pipe.tokenizer.eos_token_id
296
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
297
+ prompt_tokens,
298
+ prompt_weights,
299
+ max_length,
300
+ bos,
301
+ eos,
302
+ no_boseos_middle=no_boseos_middle,
303
+ chunk_length=pipe.tokenizer.model_max_length,
304
+ )
305
  prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
306
  if uncond_prompt is not None:
307
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
308
+ uncond_tokens,
309
+ uncond_weights,
310
+ max_length,
311
+ bos,
312
+ eos,
313
+ no_boseos_middle=no_boseos_middle,
314
+ chunk_length=pipe.tokenizer.model_max_length,
315
+ )
316
  uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
317
 
318
  # get the embeddings
319
+ text_embeddings = get_unweighted_text_embeddings(
320
+ pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
321
+ )
322
  prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
323
  if uncond_prompt is not None:
324
+ uncond_embeddings = get_unweighted_text_embeddings(
325
+ pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
326
+ )
327
  uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
328
 
329
  # assign weights to the prompts and normalize in the sense of mean
 
377
  """
378
 
379
  def __init__(
380
+ self,
381
+ vae_encoder: OnnxRuntimeModel,
382
+ vae_decoder: OnnxRuntimeModel,
383
+ text_encoder: OnnxRuntimeModel,
384
+ tokenizer: CLIPTokenizer,
385
+ unet: OnnxRuntimeModel,
386
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
387
+ safety_checker: OnnxRuntimeModel,
388
+ feature_extractor: CLIPFeatureExtractor,
389
  ):
390
  super().__init__()
391
  self.register_modules(
 
401
 
402
  @torch.no_grad()
403
  def __call__(
404
+ self,
405
+ prompt: Union[str, List[str]],
406
+ negative_prompt: Optional[Union[str, List[str]]] = None,
407
+ init_image: Union[np.ndarray, PIL.Image.Image] = None,
408
+ mask_image: Union[np.ndarray, PIL.Image.Image] = None,
409
+ height: int = 512,
410
+ width: int = 512,
411
+ num_inference_steps: int = 50,
412
+ guidance_scale: float = 7.5,
413
+ strength: float = 0.8,
414
+ num_images_per_prompt: Optional[int] = 1,
415
+ eta: float = 0.0,
416
+ generator: Optional[np.random.RandomState] = None,
417
+ latents: Optional[np.ndarray] = None,
418
+ max_embeddings_multiples: Optional[int] = 3,
419
+ output_type: Optional[str] = "pil",
420
+ return_dict: bool = True,
421
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
422
+ callback_steps: Optional[int] = 1,
423
+ **kwargs,
424
  ):
425
  r"""
426
  Function invoked when calling the pipeline for generation.
 
431
  negative_prompt (`str` or `List[str]`, *optional*):
432
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
433
  if `guidance_scale` is less than `1`).
434
+ init_image (`np.ndarray` or `PIL.Image.Image`):
435
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
436
  process.
437
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
438
  `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
439
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
440
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 
463
  eta (`float`, *optional*, defaults to 0.0):
464
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
465
  [`schedulers.DDIMScheduler`], will be ignored for others.
466
+ generator (`np.random.RandomState`, *optional*):
467
+ A np.random.RandomState to make generation deterministic.
468
+ latents (`np.ndarray`, *optional*):
 
469
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
470
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
471
  tensor will be generated by sampling using the supplied random `generator`.
 
479
  plain tuple.
480
  callback (`Callable`, *optional*):
481
  A function that will be called every `callback_steps` steps during inference. The function will be
482
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
483
  callback_steps (`int`, *optional*, defaults to 1):
484
  The frequency at which the `callback` function will be called. If not specified, the callback will be
485
  called at every step.
 
507
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
508
 
509
  if (callback_steps is None) or (
510
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
511
  ):
512
  raise ValueError(
513
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
 
540
  prompt=prompt,
541
  uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
542
  max_embeddings_multiples=max_embeddings_multiples,
543
+ **kwargs,
544
  )
545
 
546
  text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
 
600
 
601
  # add noise to latents using the timesteps
602
  noise = generator.randn(*init_latents.shape).astype(latents_dtype)
603
+ latents = self.scheduler.add_noise(
604
+ torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps
605
+ ).numpy()
606
 
607
  t_start = max(num_inference_steps - init_timestep + offset, 0)
608
  timesteps = self.scheduler.timesteps[t_start:]
 
637
 
638
  if mask is not None:
639
  # masking
640
+ init_latents_proper = self.scheduler.add_noise(
641
+ torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
642
+ ).numpy()
643
  latents = (init_latents_proper * mask) + (latents * (1 - mask))
644
 
645
  # call the callback, if provided
 
651
  # it seems like there is a problem with using a half-precision vae decoder if batchsize>1
652
  image = []
653
  for i in range(latents.shape[0]):
654
+ image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0])
655
  image = np.concatenate(image)
656
 
657
  image = np.clip(image / 2 + 0.5, 0, 1)
658
  image = image.transpose((0, 2, 3, 1))
659
 
660
  if self.safety_checker is not None:
661
+ safety_checker_input = self.feature_extractor(
662
+ self.numpy_to_pil(image), return_tensors="np"
663
+ ).pixel_values.astype(image.dtype)
664
+ # The safety_checker throws an error if it is called directly with batchsize>1
665
  images, has_nsfw_concept = [], []
666
  for i in range(image.shape[0]):
667
+ image_i, has_nsfw_concept_i = self.safety_checker(
668
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
669
+ )
670
  images.append(image_i)
671
  has_nsfw_concept.append(has_nsfw_concept_i)
672
  image = np.concatenate(images)
 
682
  return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
683
 
684
  def text2img(
685
+ self,
686
+ prompt: Union[str, List[str]],
687
+ negative_prompt: Optional[Union[str, List[str]]] = None,
688
+ height: int = 512,
689
+ width: int = 512,
690
+ num_inference_steps: int = 50,
691
+ guidance_scale: float = 7.5,
692
+ num_images_per_prompt: Optional[int] = 1,
693
+ eta: float = 0.0,
694
+ generator: Optional[np.random.RandomState] = None,
695
+ latents: Optional[np.ndarray] = None,
696
+ max_embeddings_multiples: Optional[int] = 3,
697
+ output_type: Optional[str] = "pil",
698
+ return_dict: bool = True,
699
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
700
+ callback_steps: Optional[int] = 1,
701
+ **kwargs,
702
  ):
703
  r"""
704
  Function for text-to-image generation.
 
727
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
728
  [`schedulers.DDIMScheduler`], will be ignored for others.
729
  generator (`np.random.RandomState`, *optional*):
730
+ A np.random.RandomState to make generation deterministic.
731
  latents (`np.ndarray`, *optional*):
732
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
733
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
 
742
  plain tuple.
743
  callback (`Callable`, *optional*):
744
  A function that will be called every `callback_steps` steps during inference. The function will be
745
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
746
  callback_steps (`int`, *optional*, defaults to 1):
747
  The frequency at which the `callback` function will be called. If not specified, the callback will be
748
  called at every step.
 
769
  return_dict=return_dict,
770
  callback=callback,
771
  callback_steps=callback_steps,
772
+ **kwargs,
773
  )
774
 
775
  def img2img(
776
+ self,
777
+ init_image: Union[np.ndarray, PIL.Image.Image],
778
+ prompt: Union[str, List[str]],
779
+ negative_prompt: Optional[Union[str, List[str]]] = None,
780
+ strength: float = 0.8,
781
+ num_inference_steps: Optional[int] = 50,
782
+ guidance_scale: Optional[float] = 7.5,
783
+ num_images_per_prompt: Optional[int] = 1,
784
+ eta: Optional[float] = 0.0,
785
+ generator: Optional[np.random.RandomState] = None,
786
+ max_embeddings_multiples: Optional[int] = 3,
787
+ output_type: Optional[str] = "pil",
788
+ return_dict: bool = True,
789
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
790
+ callback_steps: Optional[int] = 1,
791
+ **kwargs,
792
  ):
793
  r"""
794
  Function for image-to-image generation.
 
821
  eta (`float`, *optional*, defaults to 0.0):
822
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
823
  [`schedulers.DDIMScheduler`], will be ignored for others.
824
+ generator (`np.random.RandomState`, *optional*):
825
+ A np.random.RandomState to make generation deterministic.
826
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
827
  The maximum length of the prompt embeddings, expressed as a multiple of the max output length of the text encoder.
828
  output_type (`str`, *optional*, defaults to `"pil"`):
 
833
  plain tuple.
834
  callback (`Callable`, *optional*):
835
  A function that will be called every `callback_steps` steps during inference. The function will be
836
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
837
  callback_steps (`int`, *optional*, defaults to 1):
838
  The frequency at which the `callback` function will be called. If not specified, the callback will be
839
  called at every step.
 
859
  return_dict=return_dict,
860
  callback=callback,
861
  callback_steps=callback_steps,
862
+ **kwargs,
863
  )
864
 
865
  def inpaint(
866
+ self,
867
+ init_image: Union[np.ndarray, PIL.Image.Image],
868
+ mask_image: Union[np.ndarray, PIL.Image.Image],
869
+ prompt: Union[str, List[str]],
870
+ negative_prompt: Optional[Union[str, List[str]]] = None,
871
+ strength: float = 0.8,
872
+ num_inference_steps: Optional[int] = 50,
873
+ guidance_scale: Optional[float] = 7.5,
874
+ num_images_per_prompt: Optional[int] = 1,
875
+ eta: Optional[float] = 0.0,
876
+ generator: Optional[np.random.RandomState] = None,
877
+ max_embeddings_multiples: Optional[int] = 3,
878
+ output_type: Optional[str] = "pil",
879
+ return_dict: bool = True,
880
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
881
+ callback_steps: Optional[int] = 1,
882
+ **kwargs,
883
  ):
884
  r"""
885
  Function for inpainting.
 
916
  eta (`float`, *optional*, defaults to 0.0):
917
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
918
  [`schedulers.DDIMScheduler`], will be ignored for others.
919
+ generator (`np.random.RandomState`, *optional*):
920
+ A np.random.RandomState to make generation deterministic.
921
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
922
  The maximum length of the prompt embeddings, expressed as a multiple of the max output length of the text encoder.
923
  output_type (`str`, *optional*, defaults to `"pil"`):
 
928
  plain tuple.
929
  callback (`Callable`, *optional*):
930
  A function that will be called every `callback_steps` steps during inference. The function will be
931
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
932
  callback_steps (`int`, *optional*, defaults to 1):
933
  The frequency at which the `callback` function will be called. If not specified, the callback will be
934
  called at every step.
 
955
  return_dict=return_dict,
956
  callback=callback,
957
  callback_steps=callback_steps,
958
+ **kwargs,
959
  )