import torch

def get_sample_align_fn(sample_align_model):
    r"""
    Code is adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/scripts/classifier_sample.py#L54-L61
    """
    def sample_align_fn(x, *args, **kwargs):
        r"""
        Calculates `grad(log(p(y|x)))`
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

        Parameters
        ----------
        x:  torch.Tensor

        Returns
        -------
        grad
        """
        # Gradients may be globally disabled during sampling (e.g. under
        # `torch.no_grad()`), so re-enable them locally for this computation.
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = sample_align_model(x_in, *args, **kwargs)
            grad = torch.autograd.grad(logits.sum(), x_in, allow_unused=True)[0]
            return grad
    return sample_align_fn
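
# Usage sketch (`align_model` below is a placeholder for any differentiable
# model that scores `x`, e.g. an average-intensity predictor; it is not
# defined in this file):
#
#   align_fn = get_sample_align_fn(align_model)
#   grad = align_fn(x_t)  # guidance gradient w.r.t. the noisy sample `x_t`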

def get_alignment_kwargs_avg_x(context_seq=None, target_seq=None):
    r"""
    Please customize this function for generating knowledge "avg_x_gt"
    that guides the inference.
    E.g., this function uses 2.0 ground-truth future average intensity as "avg_x_gt" for demonstration.

    Parameters
    ----------
    context_seq:    torch.Tensor, aka "y"
    target_seq:     torch.Tensor, aka "x"

    Returns
    -------
    alignment_kwargs:   Dict
    """
    multiplier = 2.0
    batch_size = target_seq.shape[0]
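    # Flatten all non-batch dims, take the per-sample mean intensity,
    # then scale it by `multiplier` to form the guiding target.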
    ret = torch.mean(target_seq.view(batch_size, -1),
                     dim=1, keepdim=True) * multiplier
    return {"avg_x_gt": ret}