Update fluxcombined.py
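This commit stops force-casting the VAE to float32 (the cast is commented out in both denoising paths), completes the VAE decode calls so the tensor is unwrapped with return_dict=False, and removes the old edit() body along with the duplicate edit2() entry point. For orientation only, a minimal sketch of the decode path as it reads after the change, following the FluxPipeline conventions used in this file (the helper name decode_packed_latents is illustrative, not part of the file):

def decode_packed_latents(pipe, latents, height, width, output_type="pil"):
    # Mirrors the post-commit else-branch: unpack, undo VAE scaling/shift, decode, postprocess.
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    # The VAE now stays in its load-time dtype; return_dict=False returns a tuple, [0] is the image tensor.
    image = pipe.vae.decode(latents, return_dict=False)[0]
    return pipe.image_processor.postprocess(image, output_type=output_type)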
Files changed: fluxcombined.py (+5 -283)

fluxcombined.py CHANGED
@@ -874,7 +874,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # initialize the random noise for denoising
         latents = random_latents.clone().detach()

-        self.vae = self.vae.to(torch.float32)
+        # self.vae = self.vae.to(torch.float32)

         # 9. Denoising loop
         self.transformer.eval()
@@ -959,7 +959,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
@@ -973,7 +973,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     def get_diff_image(self, latents):
         latents = self._unpack_latents(latents, 1024, 1024, self.vae_scale_factor)
         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-        image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type="pt")
         return image

@@ -983,7 +983,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         img = img.resize((512, 512))
         return custom_image_processor(img).unsqueeze(0).to(device)

-
     @torch.no_grad()
     def edit(
         self,
@@ -1019,283 +1018,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         mask_image=None,
         source_steps=1,
     ):
-
-        height = height or self.default_sample_size * self.vae_scale_factor
-        width = width or self.default_sample_size * self.vae_scale_factor
-
-        # 1. Check inputs. Raise error if not correct
-        self.check_inputs(
-            prompt,
-            prompt_2,
-            height,
-            width,
-            # negative_prompt=negative_prompt,
-            # negative_prompt_2=negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            # negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-            max_sequence_length=max_sequence_length,
-        )
-
-        self._guidance_scale = guidance_scale
-        self._joint_attention_kwargs = joint_attention_kwargs
-        self._interrupt = False
-
-        # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        device = self._execution_device
-
-        lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-        )
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
-        (
-            prompt_embeds,
-            pooled_prompt_embeds,
-            text_ids,
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt_edit(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-            do_true_cfg=do_true_cfg,
-        )
-        # text_ids = text_ids.repeat(batch_size, 1, 1)
-
-        if do_true_cfg:
-            # Concatenate embeddings
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
-        # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
-        random_latents, latent_image_ids = self.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            latents,
-        )
-        # latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1)
-
-        # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = random_latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        # 4. Preprocess image
-        image = self.image_processor.preprocess(input_image)
-        image = image.to(device=device, dtype=self.transformer.dtype)
-        latents = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
-
-
-        # Convert PIL image to tensor
-        if mask_image:
-            from torchvision import transforms as TF
-
-            h, w = latents.shape[2], latents.shape[3]
-            mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
-            mask = TF.Resize((h, w), interpolation=TF.InterpolationMode.NEAREST)(mask)
-            mask = (mask > 0.5).float()
-            mask = mask.squeeze(0)#.squeeze(0) # Remove the added dimensions
-        else:
-            mask = torch.ones_like(latents).to(device=device)
-
-        print(mask.shape, latents.shape)
-
-        bool_mask = mask.unsqueeze(0).unsqueeze(0).expand_as(latents)
-        mask=(1-bool_mask*1.0).to(latents.dtype)
-
-        masked_latents = (latents * mask).clone().detach() # apply the mask and get gt_latents
-        masked_latents = self._pack_latents(masked_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        source_latents = (latents).clone().detach() # apply the mask and get gt_latents
-        source_latents = self._pack_latents(source_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        mask = self._pack_latents(mask, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
-
-        # initialize the random noise for denoising
-        latents = random_latents.clone().detach()
-
-        self.vae = self.vae.to(torch.float32)
-
-        # 9. Denoising loop
-        self.transformer.eval()
-        self.vae.eval()
-
-        # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
-
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-                    guidance = guidance.expand(latent_model_input.shape[0])
-                else:
-                    guidance = None
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
-                    # noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
-                    noise_pred = noise_pred + (1-mask)*(noise_pred - neg_noise_pred) * true_cfg
-                # else:
-                # neg_noise_pred, noise_pred = noise_pred.chunk(2)
-
-                # perform CG
-                if i < max_steps:
-                    opt_latents = latents.detach().clone()
-                    with torch.enable_grad():
-                        opt_latents = opt_latents.detach().requires_grad_()
-                        opt_latents = torch.autograd.Variable(opt_latents, requires_grad=True)
-                        # optimizer = torch.optim.Adam([opt_latents], lr=learning_rate)
-
-                        for _ in range(optimization_steps):
-                            latents_p = self.scheduler.step_final(noise_pred, t, opt_latents, return_dict=False)[0]
-                            if i < source_steps:
-                                loss = (1000*torch.nn.functional.mse_loss(latents_p, source_latents, reduction='none')).mean()
-                            else:
-                                loss = (1000*torch.nn.functional.mse_loss(latents_p, masked_latents, reduction='none')*mask).mean()
-
-                            grad = torch.autograd.grad(loss, opt_latents)[0]
-                            # grad = torch.clamp(grad, -0.5, 0.5)
-                            opt_latents = opt_latents - learning_rate * grad
-
-                        latents = opt_latents.detach().clone()
-
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-        if output_type == "latent":
-            image = latents
-
-        else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
-
-        # Offload all models
-        self.maybe_free_model_hooks()
-
-        if not return_dict:
-            return (image,)
-
-        return FluxPipelineOutput(images=image)
-
-    @torch.no_grad()
-    def edit2(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Union[str, List[str]] = None, #
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        true_cfg: float = 1.0, #
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 28,
-        timesteps: List[int] = None,
-        guidance_scale: float = 3.5,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 512,
-        optimization_steps: int = 3,
-        learning_rate: float = 0.8,
-        max_steps: int = 5,
-        input_image = None,
-        save_masked_image = False,
-        output_path="",
-        mask_image=None,
-        source_steps=1,
-    ):
         r"""
         Function invoked when calling the pipeline for generation.

@@ -1498,7 +1220,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # initialize the random noise for denoising
         latents = random_latents.clone().detach()

-        self.vae = self.vae.to(torch.float32)
+        # self.vae = self.vae.to(torch.float32)

         # 9. Denoising loop
         self.transformer.eval()
@@ -1594,7 +1316,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
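For reference, the edit() entry point that survives this commit keeps the signature recorded in the removed edit2() block above (prompt, input_image, mask_image, optimization_steps, learning_rate, max_steps, source_steps, plus the usual FluxPipeline arguments). A hypothetical call, assuming the Space's custom pipeline class is importable from fluxcombined; the checkpoint id, file paths, and prompt below are placeholders, not taken from the Space:

import torch
from PIL import Image
from fluxcombined import FluxPipeline  # assumed import path for this Space's custom pipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint id; use whatever the Space actually loads
    torch_dtype=torch.bfloat16,
).to("cuda")

result = pipe.edit(
    prompt="a cat wearing a red scarf",   # target prompt for the edited image
    input_image=Image.open("source.png"),
    mask_image=Image.open("mask.png"),    # optional; None lets the whole image change
    num_inference_steps=28,
    guidance_scale=3.5,
    optimization_steps=3,                 # inner latent-optimization steps per denoising step
    learning_rate=0.8,
    max_steps=5,                          # number of denoising steps that run the optimization
    source_steps=1,                       # early steps matched against the full source latents
)
result.images[0].save("edited.png")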