import contextlib
import copy
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import torch

from .models import UNet2DConditionModel
from .utils import deprecate, is_transformers_available


if is_transformers_available():
    import transformers


def set_seed(seed: int):
    """
    Helper function for reproducible behavior: sets the seed in `random`, `numpy`, and `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # safe to call even when CUDA is not available
    torch.cuda.manual_seed_all(seed)
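
# Illustrative usage (not part of this module): call once at the top of a training
# script, e.g. `set_seed(42)`, before any model, optimizer, or dataloader is built.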


def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Gather the per-timestep values and expand them to the shape of `timesteps`.
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr
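
# Illustrative sketch (not part of this module): `compute_snr` is typically used for
# Min-SNR loss weighting in a training loop. `snr_gamma` (the weighting cap, e.g. 5.0)
# and `per_example_mse` are hypothetical names used only for this example.
#
#     snr = compute_snr(noise_scheduler, timesteps)
#     mse_loss_weights = torch.clamp(snr, max=snr_gamma) / snr
#     loss = (per_example_mse * mse_loss_weights).mean()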


def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collects the LoRA parameters of a `UNet2DConditionModel`.

    Returns:
        A state dict containing just the LoRA parameters.
    """
    lora_state_dict = {}

    for name, module in unet.named_modules():
        if hasattr(module, "set_lora_layer"):
            lora_layer = getattr(module, "lora_layer")
            if lora_layer is not None:
                current_lora_layer_sd = lora_layer.state_dict()
                for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                    # the matrix name identifies the LoRA down/up projection weight
                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

    return lora_state_dict
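
# Illustrative sketch (assumed usage, not part of this module): extract and save only
# the LoRA weights of a trained UNet.
#
#     lora_sd = unet_lora_state_dict(unet)
#     torch.save(lora_sd, "unet_lora.pt")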


class EMAModel:
    """
    Exponential Moving Average of model weights.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU. (Deprecated; move the weights with `to()` instead.)

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # use EMA warmup when a `torch.nn.Module` is passed, for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMAModel":
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)
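
    # Illustrative sketch (assumed usage, not part of this class): round-tripping EMA
    # weights of a `UNet2DConditionModel` through `save_pretrained` / `from_pretrained`.
    # `unet` is a hypothetical trained model instance.
    #
    #     ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
    #     ema_unet.save_pretrained("ema_unet")
    #     ema_unet = EMAModel.from_pretrained("ema_unet", UNet2DConditionModel)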

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        # clamp the decay to [min_decay, decay]
        cur_decay_value = min(cur_decay_value, self.decay)
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value
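
    # Worked example of the warmup schedule (inv_gamma=1, power=2/3), matching the notes
    # in `__init__`: at step 31_600, 1 - (1 + 31_600) ** (-2 / 3) ≈ 1 - 1 / 998.6 ≈ 0.999;
    # at step 1_000_000 the value is ≈ 0.9999.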

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext
        if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
            import deepspeed

        for s_param, param in zip(self.shadow_params, parameters):
            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
                # under ZeRO-3 the parameter is sharded, so gather it before reading
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

            with context_manager():
                if param.requires_grad:
                    s_param.sub_(one_minus_decay * (s_param - param))
                else:
                    s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers
        """
        # `.to()` accepts None arguments, so this also works when only one of them is given
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch's Optimizer.state_dict convention, references to (not copies of)
        # the shadow tensors are returned.
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }
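
    # Illustrative sketch (assumed usage): the returned dict can be checkpointed and
    # later restored with `load_state_dict`.
    #
    #     torch.save(ema_model.state_dict(), "ema_state.pt")
    #     ema_model.load_state_dict(torch.load("ema_state.pt"))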

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method.
        After validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # free the stored copies to save memory
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # work on a deep copy so the caller's dict is not mutated
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
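

# Illustrative end-to-end sketch (assumed usage; `model`, `optimizer`, `dataloader`,
# `compute_loss`, and `validate` are hypothetical objects, not part of this module):
#
#     ema = EMAModel(model.parameters(), decay=0.9999, use_ema_warmup=True)
#     for batch in dataloader:
#         loss = compute_loss(model, batch)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.step(model.parameters())   # update the shadow weights after each optimizer step
#
#     ema.store(model.parameters())      # stash the raw weights
#     ema.copy_to(model.parameters())    # evaluate with the EMA weights
#     validate(model)
#     ema.restore(model.parameters())    # put the raw weights back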