File size: 916 Bytes
7667a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import functools
from typing import Callable

from torch import nn
from torch.nn import functional as F


def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    """Build a linear learning-rate warmup multiplier.

    The returned callable maps a step/epoch index to a multiplicative factor,
    presumably for use as the ``lr_lambda`` of a ``LambdaLR`` scheduler
    (confirm against callers).

    Args:
        warmup_steps: number of steps over which the factor ramps up linearly.
            A value ``<= 0`` disables warmup (the factor is always 1.0).
        min_lr_ratio: factor returned at step 0.

    Returns:
        A function ``f(epoch)`` that returns
        ``min_lr_ratio + (1 - min_lr_ratio) * epoch / warmup_steps`` for
        ``epoch < warmup_steps`` and ``1.0`` afterwards.
    """
    def ret_lambda(epoch):
        # The ``warmup_steps > 0`` guard avoids dividing by zero when warmup is
        # disabled. Using ``<`` instead of ``<=`` makes the value at
        # ``epoch == warmup_steps`` exactly 1.0 instead of a float approximation.
        if warmup_steps > 0 and epoch < warmup_steps:
            return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
        return 1.0
    return ret_lambda


def get_loss_fn(loss: str = "l2") -> Callable:
    """Return the ``torch.nn.functional`` loss function named by *loss*.

    Args:
        loss: ``"l2"`` or ``"mse"`` selects ``F.mse_loss``;
            ``"l1"`` or ``"mae"`` selects ``F.l1_loss``.

    Returns:
        The matching functional loss.

    Raises:
        ValueError: if *loss* is not one of the supported names.
    """
    if loss in ("l2", "mse"):
        return F.mse_loss
    if loss in ("l1", "mae"):
        return F.l1_loss
    # Fail here rather than returning None, which would only surface later as
    # an opaque "'NoneType' object is not callable" at the call site.
    raise ValueError(
        f"Unsupported loss {loss!r}; expected one of 'l2', 'mse', 'l1', 'mae'."
    )


def disabled_train(self, mode: bool = True):
    """Stand-in for ``nn.Module.train`` that ignores *mode* and returns *self*.

    Overwrite ``model.train`` with this function so that the train/eval mode
    no longer changes. *mode* is accepted and ignored, so the function can be
    called with the same arguments as ``nn.Module.train``.

    This works both when it is bound to a module and when it is stored unbound
    as an instance attribute. In the unbound case, a parent's ``train(mode)``
    recursion passes *mode* as ``self`` and ignores the return value.
    """
    return self


def disable_train(model: nn.Module) -> nn.Module:
    """Disable training to avoid errors when used in a ``pl.LightningModule``.

    Puts *model* in eval mode and makes its ``train()`` a no-op that returns
    *model*. It also turns off ``requires_grad`` on all of its parameters.
    Only ``model.train`` itself is replaced; calling ``train()`` directly on
    one of its children still changes that child.

    Args:
        model: module to freeze; modified in place.

    Returns:
        The same *model* object.
    """
    # Switch to eval mode first: once ``train`` is replaced below, ``eval()``
    # no longer reaches the real ``nn.Module.train``.
    model.eval()
    # Bind ``model`` explicitly instead of storing the bare function. An
    # unbound function on the instance makes a direct ``model.train()`` raise
    # TypeError and makes ``model.eval()`` return False. A partial of a
    # module-level function can still be pickled, and ``copy.deepcopy``
    # rebinds it to the copy.
    model.train = functools.partial(disabled_train, model)
    model.requires_grad_(False)
    return model