from typing import Callable from torch import nn from torch.nn import functional as F def warmup_lambda(warmup_steps, min_lr_ratio=0.1): def ret_lambda(epoch): if epoch <= warmup_steps: return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps else: return 1.0 return ret_lambda def get_loss_fn(loss: str = "l2") -> Callable: if loss in ("l2", "mse"): return F.mse_loss elif loss in ("l1", "mae"): return F.l1_loss def disabled_train(self): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def disable_train(model: nn.Module): r""" Disable training to avoid error when used in pl.LightningModule """ model.eval() model.train = disabled_train for param in model.parameters(): param.requires_grad = False return model