weatherforecast1024's picture
Upload folder using huggingface_hub
7667a87 verified
import torch
def get_sample_align_fn(sample_align_model):
r"""
Code is adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/scripts/classifier_sample.py#L54-L61
"""
def sample_align_fn(x, *args, **kwargs):
r"""
Calculates `grad(log(p(y|x)))`
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
Parameters
----------
x: torch.Tensor
Returns
-------
grad
"""
# with torch.inference_mode(False):
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
logits = sample_align_model(x_in, *args, **kwargs)
grad = torch.autograd.grad(logits.sum(), x_in, allow_unused=True)[0]
return grad
return sample_align_fn
def get_alignment_kwargs_avg_x(context_seq=None, target_seq=None, ):
r"""
Please customize this function for generating knowledge "avg_x_gt"
that guides the inference.
E.g., this function uses 2.0 ground-truth future average intensity as "avg_x_gt" for demonstration.
Parameters
----------
context_seq: torch.Tensor, aka "y"
target_seq: torch.Tensor, aka "x"
Returns
-------
alignment_kwargs: Dict
"""
multiplier = 2.0
batch_size = target_seq.shape[0]
ret = torch.mean(target_seq.view(batch_size, -1),
dim=1, keepdim=True) * multiplier
return {"avg_x_gt": ret}