prediff_code / models /knowledge_alignment /alignment_pl_base.py
weatherforecast1024's picture
Upload folder using huggingface_hub
7667a87 verified
"""
Code is adapted from https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/diffusion/ddpm.py
"""
import warnings
from typing import Sequence, Union, Dict, Any, Optional, Callable
from functools import partial
from einops import rearrange
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import lightning.pytorch as pl
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from diffusers.models.autoencoder_kl import AutoencoderKLOutput, DecoderOutput
from models.model_utils.distributions import DiagonalGaussianDistribution
from models.core_model.diffusion_utils import make_beta_schedule, extract_into_tensor, default
from utils.layout import parse_layout_shape
from utils.optim import disabled_train, get_loss_fn
class AlignmentPL(pl.LightningModule):
def __init__(
self,
torch_nn_module: nn.Module,
target_fn: Callable,
layout: str = "NTHWC",
timesteps=1000,
beta_schedule="linear",
loss_type: str = "l2",
monitor="val_loss",
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
# latent diffusion
first_stage_model: Union[Dict[str, Any], nn.Module] = None,
cond_stage_model: Union[str, Dict[str, Any], nn.Module] = None,
num_timesteps_cond=None,
cond_stage_trainable=False,
cond_stage_forward=None,
scale_by_std=False,
scale_factor=1.0,
):
r"""
Parameters
----------
Parameters
----------
torch_nn_module: nn.Module
The `.forward()` method of model should have the following signature:
`output = model.forward(zt, t, y, zc, **kwargs)`
target_fn: Callable
The function that the `torch_nn_module` is going to learn.
The signature of `target_fn` should be:
`violation_score = target_fn(x, y=None, **kwargs)`
layout: str
e.g., "NTHWC", "NHWC".
timesteps: int
1000 by default.
beta_schedule: str
one of ["linear", "cosine", "sqrt_linear", "sqrt"].
loss_type: str
one of ["l2", "l1"].
monitor: str
name of logged var for selecting best val model.
linear_start: float
linear_end: float
cosine_s: float
given_betas: Optional
If provided, `linear_start`, `linear_end`, `cosine_s` take no effect.
If None, `linear_start`, `linear_end`, `cosine_s` are used to generate betas via `make_beta_schedule()`.
first_stage_model: Dict or nn.Module
Dict : configs for instantiating the first_stage_model.
nn.Module : a model that has method ".encode()" to encode the inputs.
cond_stage_model: str or Dict or nn.Module
"__is_first_stage__": use the first_stage_model also for encoding conditionings.
Dict : configs for instantiating the cond_stage_model.
nn.Module : a model that has method ".encode()" or use `self()` to encodes the conditionings.
cond_stage_trainable: bool
Whether to train the cond_stage_model jointly
num_timesteps_cond: int
cond_stage_forward: str
The name of the forward method of the cond_stage_model.
scale_by_std
scale_factor
"""
super(AlignmentPL, self).__init__()
self.torch_nn_module = torch_nn_module
self.target_fn = target_fn
self.loss_fn = get_loss_fn(loss_type)
self.layout = layout
self.parse_layout_shape(layout=layout)
if monitor is not None:
self.monitor = monitor
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.num_timesteps_cond = default(num_timesteps_cond, 1)
assert self.num_timesteps_cond <= timesteps
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
self.cond_stage_trainable = cond_stage_trainable
self.scale_by_std = scale_by_std
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_model)
self.instantiate_cond_stage(cond_stage_model, cond_stage_forward)
def parse_layout_shape(self, layout):
parsed_dict = parse_layout_shape(layout=layout)
self.batch_axis = parsed_dict["batch_axis"]
self.t_axis = parsed_dict["t_axis"]
self.h_axis = parsed_dict["h_axis"]
self.w_axis = parsed_dict["w_axis"]
self.c_axis = parsed_dict["c_axis"]
self.all_slice = [slice(None, None), ] * len(layout)
def extract_into_tensor(self, a, t, x_shape):
return extract_into_tensor(a=a, t=t, x_shape=x_shape,
batch_axis=self.batch_axis)
@property
def loss_mean_dim(self):
# mean over all dims except for batch_axis.
if not hasattr(self, "_loss_mean_dim"):
_loss_mean_dim = list(range(len(self.layout)))
_loss_mean_dim.pop(self.batch_axis)
self._loss_mean_dim = tuple(_loss_mean_dim)
return self._loss_mean_dim
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def make_cond_schedule(self, ):
cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
cond_ids[:self.num_timesteps_cond] = ids
self.register_buffer('cond_ids', cond_ids)
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
# TODO: restarted_from_ckpt not configured
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x, _ = self.get_input(batch)
x = x.to(self.device)
x = rearrange(x, f"{self.einops_layout} -> {self.einops_spatial_layout}")
z = self.encode_first_stage(x)
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def instantiate_first_stage(self, first_stage_model):
if isinstance(first_stage_model, nn.Module):
model = first_stage_model
else:
assert first_stage_model is None
raise NotImplementedError("No default first_stage_model supported yet!")
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, cond_stage_model, cond_stage_forward):
if cond_stage_model is None:
self.cond_stage_model = None
self.cond_stage_forward = None
return
is_first_stage_flag = cond_stage_model == "__is_first_stage__"
if cond_stage_model == "__is_first_stage__":
model = self.first_stage_model
if self.cond_stage_trainable:
warnings.warn("`cond_stage_trainable` is True while `cond_stage_model` is '__is_first_stage__'. "
"force `cond_stage_trainable` to be False")
self.cond_stage_trainable = False
elif isinstance(cond_stage_model, nn.Module):
model = cond_stage_model
else:
raise NotImplementedError
self.cond_stage_model = model
if (self.cond_stage_model is not None) and (not self.cond_stage_trainable):
for param in self.cond_stage_model.parameters():
param.requires_grad = False
if cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
cond_stage_forward = self.cond_stage_model.encode
else:
cond_stage_forward = self.cond_stage_model.__call__
else:
assert hasattr(self.cond_stage_model, cond_stage_forward)
cond_stage_forward = getattr(self.cond_stage_model, cond_stage_forward)
def wrapper(cond_stage_forward: Callable, is_first_stage_flag=False):
def func(c: Dict[str, Any]):
if is_first_stage_flag:
# in this case, `cond_stage_model` is equivalent to `self.first_stage_model`,
# which takes `torch.Tensor` instead of `Dict` as input.
c = c.get("y") # get the conditioning tensor
batch_size = c.shape[self.batch_axis]
c = rearrange(c, f"{self.einops_layout} -> {self.einops_spatial_layout}")
c = cond_stage_forward(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
elif isinstance(c, AutoencoderKLOutput):
c = c.latent_dist.mode()
else:
pass
if is_first_stage_flag:
c = rearrange(c, f"{self.einops_spatial_layout} -> {self.einops_layout}", N=batch_size)
return c
return func
self.cond_stage_forward = wrapper(cond_stage_forward, is_first_stage_flag)
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
elif isinstance(encoder_posterior, AutoencoderKLOutput):
z = encoder_posterior.latent_dist.sample()
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
r"""
Try the following approaches to encode the conditional input `c`:
1. `self.cond_stage_forward` is a str, call the method of `self.cond_stage_model`.
2. call `encode()` method of `self.cond_stage_model`.
3. call `forward()` of `self.cond_stage_model`, i.e., `self.cond_stage_model()`.
"""
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
@property
def einops_layout(self):
return " ".join(self.layout)
@property
def einops_spatial_layout(self):
if not hasattr(self, "_einops_spatial_layout"):
assert len(self.layout) == 4 or len(self.layout) == 5
self._einops_spatial_layout = "(N T) C H W" if self.layout.find("T") else "N C H W"
return self._einops_spatial_layout
@torch.no_grad()
def decode_first_stage(self, z, force_not_quantize=False):
z = 1. / self.scale_factor * z
batch_size = z.shape[self.batch_axis]
z = rearrange(z, f"{self.einops_layout} -> {self.einops_spatial_layout}")
output = self.first_stage_model.decode(z)
if isinstance(output, DecoderOutput):
output = output.sample
output = rearrange(output, f"{self.einops_spatial_layout} -> {self.einops_layout}", N=batch_size)
return output
@torch.no_grad()
def encode_first_stage(self, x):
encoder_posterior = self.first_stage_model.encode(x)
output = self.get_first_stage_encoding(encoder_posterior).detach()
return output
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (self.extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
self.extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
@torch.no_grad()
def get_input(self, batch, **kwargs):
r"""
dataset dependent
re-implement it for each specific dataset
Parameters
----------
batch: Any
raw data batch from specific dataloader
Returns
-------
out: Sequence[torch.Tensor, Dict[str, Any]]
out[0] should be a torch.Tensor which is the target to generate
out[1] should be a dict consists of several key-value pairs for conditioning
"""
return batch
def forward(self, batch, t=None, verbose=False, return_verbose=False, **kwargs):
# similar to latent diffusion
x, c, aux_input_dict = self.get_input(batch) # torch.Tensor, Dict[str, Any], Dict[str, Any]
if verbose:
print("inputs:")
print(f"x.shape = {x.shape}")
for key, val in c.items():
if hasattr(val, "shape"):
print(f"{key}.shape = {val.shape}")
batch_size = x.shape[self.batch_axis]
x = x.to(self.device)
x_spatial = rearrange(x, f"{self.einops_layout} -> {self.einops_spatial_layout}")
z = self.encode_first_stage(x_spatial)
if verbose:
print("after first stage:")
print(f"z.shape = {z.shape}")
# xrec = self.decode_first_stage(z)
z = rearrange(z, f"{self.einops_spatial_layout} -> {self.einops_layout}", N=batch_size)
if t is None:
t = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device).long()
y = c if isinstance(c, torch.Tensor) else c.get("y", None)
if self.cond_stage_model is not None:
assert c is not None
zc = self.cond_stage_forward(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t]
zc = self.q_sample(x_start=zc, t=tc, noise=torch.randn_like(c.float()))
if verbose and hasattr(zc, "shape"):
print(f"zc.shape = {zc.shape}")
else:
zc = y
if verbose and hasattr(y, "shape"):
print(f"y.shape = {y.shape}")
# calculate the loss
zt = self.q_sample(x_start=z, t=t, noise=torch.randn_like(z))
target = self.target_fn(x, y, **aux_input_dict)
pred = self.torch_nn_module(zt, t, y=y, zc=zc, **aux_input_dict)
loss = self.loss_fn(pred, target)
# other metrics
with torch.no_grad():
mae = F.l1_loss(pred, target).float().cpu().item()
avg_gt = torch.abs(target).mean().float().cpu().item()
loss_dict = {
"mae": mae,
"avg_gt": avg_gt,
"relative_mae": mae / (avg_gt + 1E-8),
}
if return_verbose:
return loss, loss_dict, \
{"pred": pred, "target": target, "t": t, "zc": zc, "zt": zt}
else:
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self(batch)
self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
return loss
def validation_step(self, batch, batch_idx):
loss, loss_dict = self(batch)
self.log("val_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
loss_dict = {f"val/{key}": val for key, val in loss_dict.items()}
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
return loss
def test_step(self, batch, batch_idx):
loss, loss_dict = self(batch)
self.log("test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
loss_dict = {f"test/{key}": val for key, val in loss_dict.items()}
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
return loss
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.torch_nn_module.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
opt = torch.optim.AdamW(params, lr=lr)
return opt