| | import torch |
| |
|
| | def get_sample_align_fn(sample_align_model): |
| | r""" |
| | Code is adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/scripts/classifier_sample.py#L54-L61 |
| | """ |
| | def sample_align_fn(x, *args, **kwargs): |
| | r""" |
| | Calculates `grad(log(p(y|x)))` |
| | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). |
| | |
| | Parameters |
| | ---------- |
| | x: torch.Tensor |
| | |
| | Returns |
| | ------- |
| | grad |
| | """ |
| | |
| | with torch.enable_grad(): |
| | x_in = x.detach().requires_grad_(True) |
| | logits = sample_align_model(x_in, *args, **kwargs) |
| | grad = torch.autograd.grad(logits.sum(), x_in, allow_unused=True)[0] |
| | return grad |
| | return sample_align_fn |
| |
|
| | def get_alignment_kwargs_avg_x(context_seq=None, target_seq=None, ): |
| | r""" |
| | Please customize this function for generating knowledge "avg_x_gt" |
| | that guides the inference. |
| | E.g., this function uses 2.0 ground-truth future average intensity as "avg_x_gt" for demonstration. |
| | |
| | Parameters |
| | ---------- |
| | context_seq: torch.Tensor, aka "y" |
| | target_seq: torch.Tensor, aka "x" |
| | |
| | Returns |
| | ------- |
| | alignment_kwargs: Dict |
| | """ |
| | multiplier = 2.0 |
| | batch_size = target_seq.shape[0] |
| | ret = torch.mean(target_seq.view(batch_size, -1), |
| | dim=1, keepdim=True) * multiplier |
| | return {"avg_x_gt": ret} |