atnikos committed on
Commit d8530c7 · 1 Parent(s): f66aca9

basic setup

.gitignore ADDED
@@ -0,0 +1,35 @@
+.err
+*.out
+/cluster_scripts
+/condor_logs
+lightning_logs
+sinc-env
+fast-cluster
+eval-deps
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 license: other
+models: ["openai/clip-vit-large-patch14"]
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,3 +1,4 @@
+from geometry_utils import diffout2motion
 import gradio as gr
 import spaces
 import torch
@@ -5,6 +6,36 @@ import random
 
 zero = torch.Tensor([0]).cuda()
 print(zero.device) # <-- 'cpu' 🤔
+# Gül Varol
+
+WEBSITE = """
+<div class="embed_hidden">
+<h1 style='text-align: center'> ACRONYM: The actual title </h1>
+
+<h2 style='text-align: center'>
+<a href="https://google.com" target="_blank"><nobr>fname m. lname</nobr></a> &emsp;
+<a href="https://google.com" target="_blank"><nobr>fname m. lname</nobr></a> &emsp;
+<a href="https://google.com" target="_blank"><nobr>fname m. lname</nobr></a>
+</h2>
+
+<h2 style='text-align: center'>
+<nobr>XXX 2024</nobr>
+</h2>
+
+<h3 style="text-align:center;">
+<a target="_blank" href="https://arxiv.org/"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>
+<a target="_blank" href="https://github.com/"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>
+<a target="_blank" href="google.com"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>
+<a target="_blank" href="bibfile.com"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
+</h3>
+
+<h3> Description </h3>
+<p>
+This space illustrates <a href='project.com' target='_blank'><b>XXX</b></a>, a method for XXX.
+What does it do?
+</p>
+</div>
+"""
 
 @spaces.GPU
 def greet(n):
@@ -20,17 +51,74 @@ def clear():
 
 def random_number():
     return str(random.uniform(0, 100))
+from huggingface_hub import hf_hub_download
+
+def download_models():
+    REPO_ID = 'athn-nik/example-model'
 
+    return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
+
 with gr.Blocks() as demo:
+    gr.Markdown(WEBSITE)
+
     input_text = gr.Textbox(label="Input Text")
-    output_text = gr.Textbox(label="Output Text")
+    # output_text = gr.Textbox(label="Output Text")
 
     with gr.Row():
         retrieve_button = gr.Button("Retrieve")
         clear_button = gr.Button("Clear")
         random_button = gr.Button("Random")
+    from normalization import Normalizer
+    normalizer = Normalizer()
+    # tmed_den = load_model()
+    from diffusion import create_diffusion
+    from text_encoder import ClipTextEncoder
+    from tmed_denoiser import TMED_denoiser
+    model_ckpt = download_models()
+    checkpoint = torch.load(model_ckpt)
+    print(checkpoint.keys())
+    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
+    # load_state_dict() returns the (missing, unexpected) keys, not the module,
+    # so construct the denoiser first and then load the weights into it
+    tmed_denoiser = TMED_denoiser()
+    tmed_denoiser.load_state_dict(checkpoint, strict=False)
+    text_encoder = ClipTextEncoder()
+    # NOTE: input_text is the Textbox component here; its value is only
+    # available inside an event handler
+    texts_cond = [input_text]
+    diffusion_process = create_diffusion(timestep_respacing=None,
+                                         learn_sigma=False, sigma_small=True,
+                                         diffusion_steps=300,
+                                         noise_schedule='squaredcos_cap_v2',
+                                         predict_xstart=True)  # noise vs sample
+    # uncond_tokens = [""] * len(texts_cond)
+    # if self.condition == 'text':
+    #     uncond_tokens.extend(texts_cond)
+    # elif self.condition == 'text_uncond':
+    #     uncond_tokens.extend(uncond_tokens)
+    bsz = 1
+    seqlen_tgt = 180
+    no_of_texts = len(texts_cond)
+    # prepend two sets of empty (unconditional) prompts for classifier-free guidance
+    texts_cond = [''] * no_of_texts + texts_cond
+    texts_cond = [''] * no_of_texts + texts_cond
+    text_emb, text_mask = text_encoder(texts_cond)
 
-    retrieve_button.click(greet, inputs=input_text, outputs=output_text)
+    cond_emb_motion = torch.zeros(1, bsz,
+                                  512,
+                                  device='cuda')
+    cond_motion_mask = torch.ones((bsz, 1),
+                                  dtype=bool, device='cuda')
+    mask_target = torch.ones((1, bsz),
+                             dtype=bool, device='cuda')
+    # complete noise
+    diff_out = tmed_denoiser.diffusion_reverse(text_emb,
+                                               text_mask,
+                                               cond_emb_motion,
+                                               cond_motion_mask,
+                                               mask_target,
+                                               diffusion_process,
+                                               init_vec=None,
+                                               init_from='noise',
+                                               gd_text=4.0,
+                                               gd_motion=2.0,
+                                               steps_num=300)
+    edited_motion = diffout2motion(diff_out)
     clear_button.click(clear, outputs=input_text)
     random_button.click(random_number, outputs=input_text)
 
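Note (editor's sketch, not part of the commit): `torch.load` without `map_location` tries to restore tensors to the device they were saved from, which can fail during the CPU-only build phase of a ZeroGPU Space. A safer variant of the checkpoint-loading lines above, reusing the repo id from the diff:

    import torch
    from huggingface_hub import hf_hub_download

    ckpt_path = hf_hub_download('athn-nik/example-model',
                                filename="min_checkpoint.ckpt")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # strip the "denoiser." prefix that the training wrapper added to each key
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}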
 
deps/statistics_bodilex.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a0be87962557d3149203eb4586f3e670c1bd7785765ad8cef9ed91f6277a2c2
+size 4826
diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+def create_diffusion(
+    timestep_respacing,
+    noise_schedule="linear",
+    use_kl=False,
+    sigma_small=False,
+    predict_xstart=False,
+    learn_sigma=True,
+    rescale_learned_sigmas=False,
+    diffusion_steps=1000
+):
+    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+    if use_kl:
+        loss_type = gd.LossType.RESCALED_KL
+    elif rescale_learned_sigmas:
+        loss_type = gd.LossType.RESCALED_MSE
+    else:
+        loss_type = gd.LossType.MSE
+    if timestep_respacing is None or timestep_respacing == "":
+        timestep_respacing = [diffusion_steps]
+    return SpacedDiffusion(
+        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+        betas=betas,
+        model_mean_type=(
+            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+        ),
+        model_var_type=(
+            (
+                gd.ModelVarType.FIXED_LARGE
+                if not sigma_small
+                else gd.ModelVarType.FIXED_SMALL
+            )
+            if not learn_sigma
+            else gd.ModelVarType.LEARNED_RANGE
+        ),
+        loss_type=loss_type
+        # rescale_timesteps=rescale_timesteps,
+    )
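Note (editor's sketch, not part of the commit): how `create_diffusion` is exercised by `app.py` above; all argument values are taken from that call site.

    from diffusion import create_diffusion

    # 300-step cosine-schedule process that predicts x_0 directly.
    diffusion_process = create_diffusion(
        timestep_respacing=None,             # no respacing: keep all 300 steps
        noise_schedule='squaredcos_cap_v2',
        learn_sigma=False,
        sigma_small=True,                    # ModelVarType.FIXED_SMALL
        predict_xstart=True,                 # ModelMeanType.START_X
        diffusion_steps=300,
    )
    print(diffusion_process.num_timesteps)   # 300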
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, th.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for th.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + th.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+    )
+
+
+def approx_standard_normal_cdf(x):
+    """
+    A fast approximation of the cumulative distribution function of the
+    standard normal.
+    """
+    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+    """
+    Compute the log-likelihood of a continuous Gaussian distribution.
+    :param x: the targets
+    :param means: the Gaussian mean Tensor.
+    :param log_scales: the Gaussian log stddev Tensor.
+    :return: a tensor like x of log probabilities (in nats).
+    """
+    centered_x = x - means
+    inv_stdv = th.exp(-log_scales)
+    normalized_x = centered_x * inv_stdv
+    log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+    return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+    """
+    Compute the log-likelihood of a Gaussian distribution discretizing to a
+    given image.
+    :param x: the target images. It is assumed that this was uint8 values,
+              rescaled to the range [-1, 1].
+    :param means: the Gaussian mean Tensor.
+    :param log_scales: the Gaussian log stddev Tensor.
+    :return: a tensor like x of log probabilities (in nats).
+    """
+    assert x.shape == means.shape == log_scales.shape
+    centered_x = x - means
+    inv_stdv = th.exp(-log_scales)
+    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+    cdf_plus = approx_standard_normal_cdf(plus_in)
+    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+    cdf_min = approx_standard_normal_cdf(min_in)
+    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+    cdf_delta = cdf_plus - cdf_min
+    log_probs = th.where(
+        x < -0.999,
+        log_cdf_plus,
+        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+    )
+    assert log_probs.shape == x.shape
+    return log_probs
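Note (editor's sketch, not part of the commit): two quick sanity checks for the helpers above; the KL between identical Gaussians is exactly zero, and the tanh-based CDF approximation tracks the exact Normal CDF closely.

    import torch as th
    from diffusion.diffusion_utils import approx_standard_normal_cdf, normal_kl

    # KL(N(0,1) || N(0,1)) == 0
    print(normal_kl(th.tensor(0.0), th.tensor(0.0), 0.0, 0.0))  # tensor(0.)

    # max deviation of the approximation from the exact standard normal CDF
    x = th.linspace(-3.0, 3.0, 601)
    exact = th.distributions.Normal(0.0, 1.0).cdf(x)
    print((approx_standard_normal_cdf(x) - exact).abs().max())  # on the order of 1e-4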
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,875 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+    """
+    Which type of output the model predicts.
+    """
+
+    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
+    START_X = enum.auto()  # the model predicts x_0
+    EPSILON = enum.auto()  # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+    """
+    What is used as the model's output variance.
+    The LEARNED_RANGE option has been added to allow the model to predict
+    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+    """
+
+    LEARNED = enum.auto()
+    FIXED_SMALL = enum.auto()
+    FIXED_LARGE = enum.auto()
+    LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
+    RESCALED_MSE = (
+        enum.auto()
+    )  # use raw MSE loss (with RESCALED_KL when learning variances)
+    KL = enum.auto()  # use the variational lower-bound
+    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB
+
+    def is_vb(self):
+        return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+    warmup_time = int(num_diffusion_timesteps * warmup_frac)
+    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+    return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+    """
+    This is the deprecated API for creating beta schedules.
+    See get_named_beta_schedule() for the new library of schedules.
+    """
+    if beta_schedule == "quad":
+        betas = (
+            np.linspace(
+                beta_start ** 0.5,
+                beta_end ** 0.5,
+                num_diffusion_timesteps,
+                dtype=np.float64,
+            )
+            ** 2
+        )
+    elif beta_schedule == "linear":
+        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+    elif beta_schedule == "warmup10":
+        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+    elif beta_schedule == "warmup50":
+        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+    elif beta_schedule == "const":
+        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
+        betas = 1.0 / np.linspace(
+            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+        )
+    else:
+        raise NotImplementedError(beta_schedule)
+    assert betas.shape == (num_diffusion_timesteps,)
+    return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+    """
+    Get a pre-defined beta schedule for the given name.
+    The beta schedule library consists of beta schedules which remain similar
+    in the limit of num_diffusion_timesteps.
+    Beta schedules may be added, but should not be removed or changed once
+    they are committed to maintain backwards compatibility.
+    """
+    if schedule_name == "linear":
+        # Linear schedule from Ho et al, extended to work for any number of
+        # diffusion steps.
+        scale = 1000 / num_diffusion_timesteps
+        return get_beta_schedule(
+            "linear",
+            beta_start=scale * 0.0001,
+            beta_end=scale * 0.02,
+            num_diffusion_timesteps=num_diffusion_timesteps,
+        )
+    elif schedule_name == "squaredcos_cap_v2":
+        return betas_for_alpha_bar(
+            num_diffusion_timesteps,
+            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+        )
+    else:
+        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas)
+
+
+class GaussianDiffusion:
+    """
+    Utilities for training and sampling diffusion models.
+    Original ported from this codebase:
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+    :param betas: a 1-D numpy array of betas for each diffusion timestep,
+                  starting at T and going to 1.
+    """
+
+    def __init__(
+        self,
+        *,
+        betas,
+        model_mean_type,
+        model_var_type,
+        loss_type
+    ):
+
+        self.model_mean_type = model_mean_type
+        self.model_var_type = model_var_type
+        self.loss_type = loss_type
+
+        # Use float64 for accuracy.
+        betas = np.array(betas, dtype=np.float64)
+        self.betas = betas
+        assert len(betas.shape) == 1, "betas must be 1-D"
+        assert (betas > 0).all() and (betas <= 1).all()
+
+        self.num_timesteps = int(betas.shape[0])
+
+        alphas = 1.0 - betas
+        self.alphas_cumprod = np.cumprod(alphas, axis=0)
+        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        self.posterior_variance = (
+            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.posterior_log_variance_clipped = np.log(
+            np.append(self.posterior_variance[1], self.posterior_variance[1:])
+        ) if len(self.posterior_variance) > 1 else np.array([])
+
+        self.posterior_mean_coef1 = (
+            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
+        self.posterior_mean_coef2 = (
+            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+        )
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def q_sample(self, x_start, t, noise=None):
+        """
+        Diffuse the data for a given number of diffusion steps.
+        In other words, sample from q(x_t | x_0).
+        :param x_start: the initial data batch.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :param noise: if specified, the split-out normal noise.
+        :return: A noisy version of x_start.
+        """
+        if noise is None:
+            noise = th.randn_like(x_start)
+        assert noise.shape == x_start.shape
+        return (
+            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def q_posterior_mean_variance(self, x_start, x_t, t):
+        """
+        Compute the mean and variance of the diffusion posterior:
+            q(x_{t-1} | x_t, x_0)
+        """
+        assert x_start.shape == x_t.shape
+        posterior_mean = (
+            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = _extract_into_tensor(
+            self.posterior_log_variance_clipped, t, x_t.shape
+        )
+        assert (
+            posterior_mean.shape[0]
+            == posterior_variance.shape[0]
+            == posterior_log_variance_clipped.shape[0]
+            == x_start.shape[0]
+        )
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None):
+        """
+        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+        the initial x, x_0.
+        :param model: the model, which takes a signal and a batch of timesteps
+                      as input.
+        :param x: the [N x C x ...] tensor at time t.
+        :param t: a 1-D Tensor of timesteps.
+        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+        :param denoised_fn: if not None, a function which applies to the
+            x_start prediction before it is used to sample. Applies before
+            clip_denoised.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :return: a dict with the following keys:
+                 - 'mean': the model mean output.
+                 - 'variance': the model variance output.
+                 - 'log_variance': the log of 'variance'.
+                 - 'pred_xstart': the prediction for x_0.
+        """
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        B, C = x.shape[:2]
+        assert t.shape == (B,)
+        model_output = model(x, t, **model_kwargs)
+        if isinstance(model_output, tuple):
+            model_output, extra = model_output
+        else:
+            extra = None
+
+        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+            assert model_output.shape == (B, C * 2, *x.shape[2:])
+            model_output, model_var_values = th.split(model_output, C, dim=1)
+            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+            # The model_var_values is [-1, 1] for [min_var, max_var].
+            frac = (model_var_values + 1) / 2
+            model_log_variance = frac * max_log + (1 - frac) * min_log
+            model_variance = th.exp(model_log_variance)
+        else:
+            model_variance, model_log_variance = {
+                # for fixedlarge, we set the initial (log-)variance like so
+                # to get a better decoder log likelihood.
+                ModelVarType.FIXED_LARGE: (
+                    np.append(self.posterior_variance[1], self.betas[1:]),
+                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+                ),
+                ModelVarType.FIXED_SMALL: (
+                    self.posterior_variance,
+                    self.posterior_log_variance_clipped,
+                ),
+            }[self.model_var_type]
+            model_variance = _extract_into_tensor(model_variance, t, x.shape)
+            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+        def process_xstart(x):
+            if denoised_fn is not None:
+                x = denoised_fn(x)
+            if clip_denoised:
+                return x.clamp(-1, 1)
+            return x
+
+        if self.model_mean_type == ModelMeanType.START_X:
+            pred_xstart = process_xstart(model_output)
+        else:
+            pred_xstart = process_xstart(
+                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+            )
+        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+        return {
+            "mean": model_mean,
+            "variance": model_variance,
+            "log_variance": model_log_variance,
+            "pred_xstart": pred_xstart,
+            "extra": extra,
+        }
+
+    def _predict_xstart_from_eps(self, x_t, t, eps):
+        assert x_t.shape == eps.shape
+        return (
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+        )
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+        return (
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+        """
+        Compute the mean for the previous step, given a function cond_fn that
+        computes the gradient of a conditional log probability with respect to
+        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+        condition on y.
+        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+        """
+        gradient = cond_fn(x, t, **model_kwargs)
+        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+        return new_mean
+
+    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+        """
+        Compute what the p_mean_variance output would have been, should the
+        model's score function be conditioned by cond_fn.
+        See condition_mean() for details on cond_fn.
+        Unlike condition_mean(), this instead uses the conditioning strategy
+        from Song et al (2020).
+        """
+        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+        out = p_mean_var.copy()
+        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+        return out
+
+    def p_sample(
+        self,
+        model,
+        x,
+        t,
+        clip_denoised=False,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+    ):
+        """
+        Sample x_{t-1} from the model at the given timestep.
+        :param model: the model to sample from.
+        :param x: the current tensor at x_{t-1}.
+        :param t: the value of t, starting at 0 for the first diffusion step.
+        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+        :param denoised_fn: if not None, a function which applies to the
+            x_start prediction before it is used to sample.
+        :param cond_fn: if not None, this is a gradient function that acts
+                        similarly to the model.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :return: a dict containing the following keys:
+                 - 'sample': a random sample from the model.
+                 - 'pred_xstart': a prediction of x_0.
+        """
+        out = self.p_mean_variance(
+            model,
+            x,
+            t,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            model_kwargs=model_kwargs,
+        )
+        noise = th.randn_like(x)
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
+        if cond_fn is not None:
+            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+    def p_sample_loop(
+        self,
+        model,
+        shape,
+        noise=None,
+        clip_denoised=False,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        device=None,
+        progress=False,
+    ):
+        """
+        Generate samples from the model.
+        :param model: the model module.
+        :param shape: the shape of the samples, (N, C, H, W).
+        :param noise: if specified, the noise from the encoder to sample.
+                      Should be of the same shape as `shape`.
+        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+        :param denoised_fn: if not None, a function which applies to the
+            x_start prediction before it is used to sample.
+        :param cond_fn: if not None, this is a gradient function that acts
+                        similarly to the model.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :param device: if specified, the device to create the samples on.
+                       If not specified, use a model parameter's device.
+        :param progress: if True, show a tqdm progress bar.
+        :return: a non-differentiable batch of samples.
+        """
+        final = None
+        for sample in self.p_sample_loop_progressive(
+            model,
+            shape,
+            noise=noise,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            cond_fn=cond_fn,
+            model_kwargs=model_kwargs,
+            device=device,
+            progress=progress,
+        ):
+            final = sample
+        return final["sample"]
+
+    def p_sample_loop_progressive(
+        self,
+        model,
+        shape,
+        noise=None,
+        clip_denoised=False,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        device=None,
+        progress=False,
+    ):
+        """
+        Generate samples from the model and yield intermediate samples from
+        each timestep of diffusion.
+        Arguments are the same as p_sample_loop().
+        Returns a generator over dicts, where each dict is the return value of
+        p_sample().
+        """
+        if device is None:
+            device = next(model.parameters()).device
+        assert isinstance(shape, (tuple, list))
+        if noise is not None:
+            img = noise
+        else:
+            img = th.randn(*shape, device=device)
+        indices = list(range(self.num_timesteps))[::-1]
+
+        if progress:
+            # Lazy import so that we don't depend on tqdm.
+            from tqdm.auto import tqdm
+
+            indices = tqdm(indices)
+
+        for i in indices:
+            t = th.tensor([i] * shape[0], device=device)
+            with th.no_grad():
+                out = self.p_sample(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=False,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    model_kwargs=model_kwargs,
+                )
+                yield out
+                img = out["sample"]
+
+    def ddim_sample(
+        self,
+        model,
+        x,
+        t,
+        clip_denoised=False,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        eta=0.0,
+    ):
+        """
+        Sample x_{t-1} from the model using DDIM.
+        Same usage as p_sample().
+        """
+        out = self.p_mean_variance(
+            model,
+            x,
+            t,
+            clip_denoised=False,
+            denoised_fn=denoised_fn,
+            model_kwargs=model_kwargs,
+        )
+        if cond_fn is not None:
+            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+        # Usually our model outputs epsilon, but we re-derive it
+        # in case we used x_start or x_prev prediction.
+        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+        sigma = (
+            eta
+            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+        )
+        # Equation 12.
+        noise = th.randn_like(x)
+        mean_pred = (
+            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+        )
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
+        sample = mean_pred + nonzero_mask * sigma * noise
+        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+    def ddim_reverse_sample(
+        self,
+        model,
+        x,
+        t,
+        clip_denoised=False,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        eta=0.0,
+    ):
+        """
+        Sample x_{t+1} from the model using DDIM reverse ODE.
+        """
+        assert eta == 0.0, "Reverse ODE only for deterministic path"
+        out = self.p_mean_variance(
+            model,
+            x,
+            t,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            model_kwargs=model_kwargs,
+        )
+        if cond_fn is not None:
+            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+        # Usually our model outputs epsilon, but we re-derive it
+        # in case we used x_start or x_prev prediction.
+        eps = (
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+            - out["pred_xstart"]
+        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+        # Equation 12. reversed
+        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+    def ddim_sample_loop(
+        self,
+        model,
+        shape,
+        noise=None,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        device=None,
+        progress=False,
+        eta=0.0,
+    ):
+        """
+        Generate samples from the model using DDIM.
+        Same usage as p_sample_loop().
+        """
+        final = None
+        for sample in self.ddim_sample_loop_progressive(
+            model,
+            shape,
+            noise=noise,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            cond_fn=cond_fn,
+            model_kwargs=model_kwargs,
+            device=device,
+            progress=progress,
+            eta=eta,
+        ):
+            final = sample
+        return final["sample"]
+
+    def ddim_sample_loop_progressive(
+        self,
+        model,
+        shape,
+        noise=None,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        device=None,
+        progress=False,
+        eta=0.0,
+    ):
+        """
+        Use DDIM to sample from the model and yield intermediate samples from
+        each timestep of DDIM.
+        Same usage as p_sample_loop_progressive().
+        """
+        if device is None:
+            device = next(model.parameters()).device
+        assert isinstance(shape, (tuple, list))
+        if noise is not None:
+            img = noise
+        else:
+            img = th.randn(*shape, device=device)
+        indices = list(range(self.num_timesteps))[::-1]
+
+        if progress:
+            # Lazy import so that we don't depend on tqdm.
+            from tqdm.auto import tqdm
+
+            indices = tqdm(indices)
+
+        for i in indices:
+            t = th.tensor([i] * shape[0], device=device)
+            with th.no_grad():
+                out = self.ddim_sample(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=clip_denoised,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    model_kwargs=model_kwargs,
+                    eta=eta,
+                )
+                yield out
+                img = out["sample"]
+
+    def _vb_terms_bpd(
+        self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None
+    ):
+        """
+        Get a term for the variational lower-bound.
+        The resulting units are bits (rather than nats, as one might expect).
+        This allows for comparison to other papers.
+        :return: a dict with the following keys:
+                 - 'output': a shape [N] tensor of NLLs or KLs.
+                 - 'pred_xstart': the x_0 predictions.
+        """
+        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+            x_start=x_start, x_t=x_t, t=t
+        )
+        out = self.p_mean_variance(
+            model, x_t, t, clip_denoised=False, model_kwargs=model_kwargs
+        )
+        kl = normal_kl(
+            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+        )
+        kl = mean_flat(kl) / np.log(2.0)
+
+        decoder_nll = -discretized_gaussian_log_likelihood(
+            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+        )
+        assert decoder_nll.shape == x_start.shape
+        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+        # At the first timestep return the decoder NLL,
+        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+        output = th.where((t == 0), decoder_nll, kl)
+        return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+        """
+        Compute training losses for a single timestep.
+        :param model: the model to evaluate loss on.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :param t: a batch of timestep indices.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :param noise: if specified, the specific Gaussian noise to try to remove.
+        :return: a dict with the key "loss" containing a tensor of shape [N].
+                 Some mean or variance settings may also have other keys.
+        """
+        if model_kwargs is None:
+            model_kwargs = {}
+        if noise is None:
+            noise = th.randn_like(x_start)
+        x_t = self.q_sample(x_start, t, noise=noise)
+
+        terms = {}
+
+        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+            terms["loss"] = self._vb_terms_bpd(
+                model=model,
+                x_start=x_start,
+                x_t=x_t,
+                t=t,
+                clip_denoised=False,
+                model_kwargs=model_kwargs,
+            )["output"]
+            if self.loss_type == LossType.RESCALED_KL:
+                terms["loss"] *= self.num_timesteps
+        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+            model_output = model(x_t, t, **model_kwargs)
+
+            if self.model_var_type in [
+                ModelVarType.LEARNED,
+                ModelVarType.LEARNED_RANGE,
+            ]:
+                B, C = x_t.shape[:2]
+                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+                model_output, model_var_values = th.split(model_output, C, dim=1)
+                # Learn the variance using the variational bound, but don't let
+                # it affect our mean prediction.
+                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+                terms["vb"] = self._vb_terms_bpd(
+                    model=lambda *args, r=frozen_out: r,
+                    x_start=x_start,
+                    x_t=x_t,
+                    t=t,
+                    clip_denoised=False,
+                )["output"]
+                if self.loss_type == LossType.RESCALED_MSE:
+                    # Divide by 1000 for equivalence with initial implementation.
+                    # Without a factor of 1/1000, the VB term hurts the MSE term.
+                    terms["vb"] *= self.num_timesteps / 1000.0
+
+            target = {
+                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+                    x_start=x_start, x_t=x_t, t=t
+                )[0],
+                ModelMeanType.START_X: x_start,
+                ModelMeanType.EPSILON: noise,
+            }[self.model_mean_type]
+            assert model_output.shape == target.shape == x_start.shape
+            terms["mse"] = mean_flat((target - model_output) ** 2)
+            terms["target"] = target
+            terms['model_output'] = model_output
+            if "vb" in terms:
+                terms["loss"] = terms["mse"] + terms["vb"]
+            else:
+                terms["loss"] = terms["mse"]
+        else:
+            raise NotImplementedError(self.loss_type)
+
+        return terms
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(
+            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+        )
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):
+        """
+        Compute the entire variational lower-bound, measured in bits-per-dim,
+        as well as other related quantities.
+        :param model: the model to evaluate loss on.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :param clip_denoised: if True, clip denoised samples.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :return: a dict containing the following keys:
+                 - total_bpd: the total variational lower-bound, per batch element.
+                 - prior_bpd: the prior term in the lower-bound.
+                 - vb: an [N x T] tensor of terms in the lower-bound.
+                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+        """
+        device = x_start.device
+        batch_size = x_start.shape[0]
+
+        vb = []
+        xstart_mse = []
+        mse = []
+        for t in list(range(self.num_timesteps))[::-1]:
+            t_batch = th.tensor([t] * batch_size, device=device)
+            noise = th.randn_like(x_start)
+            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+            # Calculate VLB term at the current timestep
+            with th.no_grad():
+                out = self._vb_terms_bpd(
+                    model,
+                    x_start=x_start,
+                    x_t=x_t,
+                    t=t_batch,
+                    clip_denoised=clip_denoised,
+                    model_kwargs=model_kwargs,
+                )
+            vb.append(out["output"])
+            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+            mse.append(mean_flat((eps - noise) ** 2))
+
+        vb = th.stack(vb, dim=1)
+        xstart_mse = th.stack(xstart_mse, dim=1)
+        mse = th.stack(mse, dim=1)
+
+        prior_bpd = self._prior_bpd(x_start)
+        total_bpd = vb.sum(dim=1) + prior_bpd
+        return {
+            "total_bpd": total_bpd,
+            "prior_bpd": prior_bpd,
+            "vb": vb,
+            "xstart_mse": xstart_mse,
+            "mse": mse,
+        }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res + th.zeros(broadcast_shape, device=timesteps.device)
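Note (editor's sketch, not part of the commit): a minimal use of `GaussianDiffusion` for the forward (noising) process only, on toy data; no denoiser model is needed for `q_sample`.

    import torch as th
    from diffusion.gaussian_diffusion import (GaussianDiffusion, LossType,
                                              ModelMeanType, ModelVarType,
                                              get_named_beta_schedule)

    betas = get_named_beta_schedule("squaredcos_cap_v2", 300)
    process = GaussianDiffusion(betas=betas,
                                model_mean_type=ModelMeanType.START_X,
                                model_var_type=ModelVarType.FIXED_SMALL,
                                loss_type=LossType.MSE)
    x0 = th.randn(2, 8)           # toy "clean" batch: N=2 samples, 8 features
    t = th.tensor([0, 299])       # an early and a late timestep
    xt = process.q_sample(x0, t)  # draw from q(x_t | x_0)
    print(xt.shape)               # torch.Size([2, 8])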
diffusion/respace.py ADDED
@@ -0,0 +1,129 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+    """
+    Create a list of timesteps to use from an original diffusion process,
+    given the number of timesteps we want to take from equally-sized portions
+    of the original process.
+    For example, if there's 300 timesteps and the section counts are [10,15,20]
+    then the first 100 timesteps are strided to be 10 timesteps, the second 100
+    are strided to be 15 timesteps, and the final 100 are strided to be 20.
+    If the stride is a string starting with "ddim", then the fixed striding
+    from the DDIM paper is used, and only one section is allowed.
+    :param num_timesteps: the number of diffusion steps in the original
+                          process to divide up.
+    :param section_counts: either a list of numbers, or a string containing
+                           comma-separated numbers, indicating the step count
+                           per section. As a special case, use "ddimN" where N
+                           is a number of steps to use the striding from the
+                           DDIM paper.
+    :return: a set of diffusion steps from the original process to use.
+    """
+    if isinstance(section_counts, str):
+        if section_counts.startswith("ddim"):
+            desired_count = int(section_counts[len("ddim") :])
+            for i in range(1, num_timesteps):
+                if len(range(0, num_timesteps, i)) == desired_count:
+                    return set(range(0, num_timesteps, i))
+            raise ValueError(
+                f"cannot create exactly {num_timesteps} steps with an integer stride"
+            )
+        section_counts = [int(x) for x in section_counts.split(",")]
+    size_per = num_timesteps // len(section_counts)
+    extra = num_timesteps % len(section_counts)
+    start_idx = 0
+    all_steps = []
+    for i, section_count in enumerate(section_counts):
+        size = size_per + (1 if i < extra else 0)
+        if size < section_count:
+            raise ValueError(
+                f"cannot divide section of {size} steps into {section_count}"
+            )
+        if section_count <= 1:
+            frac_stride = 1
+        else:
+            frac_stride = (size - 1) / (section_count - 1)
+        cur_idx = 0.0
+        taken_steps = []
+        for _ in range(section_count):
+            taken_steps.append(start_idx + round(cur_idx))
+            cur_idx += frac_stride
+        all_steps += taken_steps
+        start_idx += size
+    return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+    """
+    A diffusion process which can skip steps in a base diffusion process.
+    :param use_timesteps: a collection (sequence or set) of timesteps from the
+                          original diffusion process to retain.
+    :param kwargs: the kwargs to create the base diffusion process.
+    """
+
+    def __init__(self, use_timesteps, **kwargs):
+        self.use_timesteps = set(use_timesteps)
+        self.timestep_map = []
+        self.original_num_steps = len(kwargs["betas"])
+
+        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+        last_alpha_cumprod = 1.0
+        new_betas = []
+        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+            if i in self.use_timesteps:
+                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                last_alpha_cumprod = alpha_cumprod
+                self.timestep_map.append(i)
+        kwargs["betas"] = np.array(new_betas)
+        super().__init__(**kwargs)
+
+    def p_mean_variance(
+        self, model, *args, **kwargs
+    ):  # pylint: disable=signature-differs
+        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+    def training_losses(
+        self, model, *args, **kwargs
+    ):  # pylint: disable=signature-differs
+        return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+    def condition_mean(self, cond_fn, *args, **kwargs):
+        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+    def condition_score(self, cond_fn, *args, **kwargs):
+        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+    def _wrap_model(self, model):
+        if isinstance(model, _WrappedModel):
+            return model
+        return _WrappedModel(
+            model, self.timestep_map, self.original_num_steps
+        )
+
+    def _scale_timesteps(self, t):
+        # Scaling is done by the wrapped model.
+        return t
+
+
+class _WrappedModel:
+    def __init__(self, model, timestep_map, original_num_steps):
+        self.model = model
+        self.timestep_map = timestep_map
+        # self.rescale_timesteps = rescale_timesteps
+        self.original_num_steps = original_num_steps
+
+    def __call__(self, x, ts, **kwargs):
+        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+        new_ts = map_tensor[ts]
+        # if self.rescale_timesteps:
+        #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+        return self.model(x, new_ts, **kwargs)
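Note (editor's sketch, not part of the commit): the worked example from the `space_timesteps` docstring, plus the DDIM-style striding.

    from diffusion.respace import space_timesteps

    # 300 original steps, three 100-step sections strided to 10, 15 and 20 steps
    steps = space_timesteps(300, [10, 15, 20])
    print(len(steps))  # 45

    # "ddim50": keep 50 evenly spaced steps out of 300 (stride 6)
    print(sorted(space_timesteps(300, "ddim50"))[:5])  # [0, 6, 12, 18, 24]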
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+         indices = th.from_numpy(indices_np).long().to(device)
+         weights_np = 1 / (len(p) * p[indices_np])
+         weights = th.from_numpy(weights_np).float().to(device)
+         return indices, weights
+
+
+ class UniformSampler(ScheduleSampler):
+     def __init__(self, diffusion):
+         self.diffusion = diffusion
+         self._weights = np.ones([diffusion.num_timesteps])
+
+     def weights(self):
+         return self._weights
+
+
+ class LossAwareSampler(ScheduleSampler):
+     def update_with_local_losses(self, local_ts, local_losses):
+         """
+         Update the reweighting using losses from a model.
+         Call this method from each rank with a batch of timesteps and the
+         corresponding losses for each of those timesteps.
+         This method will perform synchronization to make sure all of the ranks
+         maintain the exact same reweighting.
+         :param local_ts: an integer Tensor of timesteps.
+         :param local_losses: a 1D Tensor of losses.
+         """
+         batch_sizes = [
+             th.tensor([0], dtype=th.int32, device=local_ts.device)
+             for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(
+             batch_sizes,
+             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+         )
+
+         # Pad all_gather batches to be the maximum batch size.
+         batch_sizes = [x.item() for x in batch_sizes]
+         max_bs = max(batch_sizes)
+
+         timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
+         loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
+         dist.all_gather(timestep_batches, local_ts)
+         dist.all_gather(loss_batches, local_losses)
+         timesteps = [
+             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+         ]
+         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+         self.update_with_all_losses(timesteps, losses)
+
+     @abstractmethod
+     def update_with_all_losses(self, ts, losses):
+         """
+         Update the reweighting using losses from a model.
+         Sub-classes should override this method to update the reweighting
+         using losses from the model.
+         This method directly updates the reweighting without synchronizing
+         between workers. It is called by update_with_local_losses from all
+         ranks with identical arguments. Thus, it should have deterministic
+         behavior to maintain state across workers.
+         :param ts: a list of int timesteps.
+         :param losses: a list of float losses, one per timestep.
+         """
+
+
+ class LossSecondMomentResampler(LossAwareSampler):
+     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+         self.diffusion = diffusion
+         self.history_per_term = history_per_term
+         self.uniform_prob = uniform_prob
+         self._loss_history = np.zeros(
+             [diffusion.num_timesteps, history_per_term], dtype=np.float64
+         )
+         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+     def weights(self):
+         if not self._warmed_up():
+             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+         weights /= np.sum(weights)
+         weights *= 1 - self.uniform_prob
+         weights += self.uniform_prob / len(weights)
+         return weights
+
+     def update_with_all_losses(self, ts, losses):
+         for t, loss in zip(ts, losses):
+             if self._loss_counts[t] == self.history_per_term:
+                 # Shift out the oldest loss term.
+                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                 self._loss_history[t, -1] = loss
+             else:
+                 self._loss_history[t, self._loss_counts[t]] = loss
+                 self._loss_counts[t] += 1
+
+     def _warmed_up(self):
+         return (self._loss_counts == self.history_per_term).all()
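A minimal sketch (editorial, not part of the commit) of how this resampler behaves, assuming only an object exposing num_timesteps; the StubDiffusion below is hypothetical. Before history_per_term losses have been recorded for every timestep the weights are uniform; afterwards they follow the per-timestep RMS loss.

import numpy as np

class StubDiffusion:  # hypothetical stand-in for the real diffusion object
    num_timesteps = 4

sampler = LossSecondMomentResampler(StubDiffusion(), history_per_term=2)
assert np.allclose(sampler.weights(), 1.0)  # not warmed up yet: uniform
for _ in range(2):
    sampler.update_with_all_losses(ts=[0, 1, 2, 3], losses=[4.0, 1.0, 1.0, 1.0])
w = sampler.weights()
assert w[0] > w[1]  # the noisier timestep 0 is now sampled more often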
gen_utils.py ADDED
@@ -0,0 +1,11 @@
+ import torch
+ import numpy as np
+ def cast_dict_to_tensors(d, device="cpu"):
+     if isinstance(d, dict):
+         return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
+     elif isinstance(d, np.ndarray):
+         return torch.from_numpy(d).float().to(device)
+     elif isinstance(d, torch.Tensor):
+         return d.to(device)
+     else:
+         return d
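A quick sketch (editorial) of the recursion on a nested statistics dict; the values here are invented:

import numpy as np
import torch

stats = {"body_pose": {"mean": np.zeros(126), "std": np.ones(126)}}
stats_t = cast_dict_to_tensors(stats)
# numpy leaves become float32 tensors; the nesting is preserved
assert stats_t["body_pose"]["mean"].dtype == torch.float32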
geometry_utils.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ # NOTE: written as a method: `self` must provide cat_inputs/uncat_inputs, (un)norm_inputs, input_feats and device (see normalization.py)
+ def diffout2motion(self, diffout):
+     from src.tools.transforms3d import apply_rot_delta, remove_z_rot, get_z_rot, change_for, transform_body_pose
+     # - "body_transl_delta_pelv_xy_wo_z"
+     # - "body_transl_z"
+     # - "z_orient_delta"
+     # - "body_orient_xy"
+     # - "body_pose"
+     # - "body_joints_local_wo_z_rot"
+     feats_unnorm = self.cat_inputs(self.unnorm_inputs(
+                                    self.uncat_inputs(diffout,
+                                                      self.input_feats_dims),
+                                    self.input_feats))[0]
+     # FIRST POSE FOR GENERATION & DELTAS FOR INTEGRATION
+     if "body_joints_local_wo_z_rot" in self.input_feats:
+         idx = self.input_feats.index("body_joints_local_wo_z_rot")
+         feats_unnorm = feats_unnorm[..., :-self.input_feats_dims[idx]]
+
+     first_trans = torch.zeros(*diffout.shape[:-1], 3,
+                               device=self.device)[:, [0]]
+     if 'z_orient_delta' in self.input_feats:
+         first_orient_z = torch.eye(3, device=self.device).unsqueeze(0)  # (1, 3, 3)
+         first_orient_z = first_orient_z.repeat(feats_unnorm.shape[0], 1, 1)  # (B, 3, 3)
+         first_orient_z = transform_body_pose(first_orient_z, 'rot->6d')
+
+         # --> first_orient_z convert to 6d
+         # integrate z orient delta --> z component of the orientation
+         z_orient_delta = feats_unnorm[..., 9:15]
+
+
+         prev_z = first_orient_z
+         full_z_angle = [first_orient_z[:, None]]
+         for i in range(1, z_orient_delta.shape[1]):
+             curr_z = apply_rot_delta(prev_z, z_orient_delta[:, i])
+             prev_z = curr_z.clone()
+             full_z_angle.append(curr_z[:, None])
+         full_z_angle = torch.cat(full_z_angle, dim=1)
+         full_z_angle_rotmat = get_z_rot(full_z_angle)
+         # full_orient = torch.cat([full_z_angle, xy_orient], dim=-1)
+         xy_orient = feats_unnorm[..., 3:9]
+         xy_orient_rotmat = transform_body_pose(xy_orient, '6d->rot')
+         # xy_orient = remove_z_rot(xy_orient, in_format="6d")
+
+         # GLOBAL ORIENTATION
+         # full_z_angle = transform_body_pose(full_z_angle_rotmat,
+         #                                    'rot->6d')
+
+         # full_global_orient = apply_rot_delta(full_z_angle,
+         #                                      xy_orient)
+         full_global_orient_rotmat = full_z_angle_rotmat @ xy_orient_rotmat
+         full_global_orient = transform_body_pose(full_global_orient_rotmat,
+                                                  'rot->6d')
+
+         first_trans = self.cat_inputs(self.unnorm_inputs(
+                                       [first_trans],
+                                       ['body_transl'])
+                                       )[0]
+
+         # apply deltas
+         # get velocity in global c.f. and add it to the state position
+         assert 'body_transl_delta_pelv' in self.input_feats
+         pelvis_delta = feats_unnorm[..., :3]
+         trans_vel_pelv = change_for(pelvis_delta[:, 1:],
+                                     full_global_orient_rotmat[:, :-1],
+                                     forward=False)
+
+         # new_state_pos = prev_trans_norm.squeeze() + trans_vel_pelv
+         full_trans = torch.cumsum(trans_vel_pelv, dim=1) + first_trans
+         full_trans = torch.cat([first_trans, full_trans], dim=1)
+
+         # "body_transl_delta_pelv_xy_wo_z"
+         # first_trans = self.cat_inputs(self.unnorm_inputs(
+         #                               [first_trans],
+         #                               ['body_transl'])
+         #                               )[0]
+
+         # pelvis_xy = pelvis_delta_xy
+         # FULL TRANSLATION
+         # full_trans = torch.cat([pelvis_xy,
+         #                         feats_unnorm[..., 2:3][:, 1:]], dim=-1)
+         #############
+     full_rots = torch.cat([full_global_orient,
+                            feats_unnorm[..., -21*6:]],
+                           dim=-1)
+     full_motion_unnorm = torch.cat([full_trans,
+                                     full_rots], dim=-1)
+
+     return full_motion_unnorm
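The heart of this function is turning per-frame deltas back into absolute quantities; a self-contained sketch (editorial) of the translation part, with shapes and values invented for illustration:

import torch

B, S = 2, 5
first_trans = torch.zeros(B, 1, 3)        # starting pelvis position
deltas = 0.01 * torch.randn(B, S - 1, 3)  # per-frame displacements, frames 1..S-1
full_trans = torch.cumsum(deltas, dim=1) + first_trans
full_trans = torch.cat([first_trans, full_trans], dim=1)  # (B, S, 3) trajectory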
model_utils.py ADDED
@@ -0,0 +1,63 @@
+ import numpy as np
+ import torch
+ from torch import nn
+
+
+ class TimestepEmbedderMDM(nn.Module):
+     def __init__(self, latent_dim):
+         super().__init__()
+         self.latent_dim = latent_dim
+
+         time_embed_dim = self.latent_dim
+         self.sequence_pos_encoder = PositionalEncoding(d_model=self.latent_dim)
+         # TODO add time embedding learnable
+         self.time_embed = nn.Sequential(
+             nn.Linear(self.latent_dim, time_embed_dim),
+             nn.SiLU(),
+             nn.Linear(time_embed_dim, time_embed_dim),
+         )
+
+     def forward(self, timesteps):
+         return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1,
+                  max_len=5000, batch_first=False, negative=False):
+         super().__init__()
+         self.batch_first = batch_first
+
+         self.dropout = nn.Dropout(p=dropout)
+         self.max_len = max_len
+
+         self.negative = negative
+
+         if negative:
+             pe = torch.zeros(2*max_len, d_model)
+             position = torch.arange(-max_len, max_len, dtype=torch.float).unsqueeze(1)
+         else:
+             pe = torch.zeros(max_len, d_model)
+             position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+
+         self.register_buffer('pe', pe, persistent=False)
+
+     def forward(self, x, hist_frames=0):
+         if not self.negative:
+             center = 0
+             assert hist_frames == 0
+             first = 0
+         else:
+             center = self.max_len
+             first = center - hist_frames
+         if self.batch_first:
+             last = first + x.shape[1]
+             x = x + self.pe.permute(1, 0, 2)[:, first:last, :]
+         else:
+             last = first + x.shape[0]
+             x = x + self.pe[first:last, :]
+         return self.dropout(x)
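How the two modules fit together, as a small sketch (editorial): the embedder indexes the precomputed sinusoidal table with the diffusion timesteps and pushes the result through the MLP.

import torch

embedder = TimestepEmbedderMDM(latent_dim=512)
t = torch.tensor([0, 10, 999])   # one diffusion timestep per batch element
emb = embedder(t)                # pe[t] is (3, 1, 512); permute gives (1, 3, 512)
print(emb.shape)                 # torch.Size([1, 3, 512])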
normalization.py ADDED
@@ -0,0 +1,150 @@
+ from os.path import exists
+ from gen_utils import cast_dict_to_tensors
+ from einops import rearrange
+ from torch import Tensor
+ from typing import List, Union
+ import torch
+ import numpy as np
+
+ class Normalizer:
+     def __init__(self, statistics_path: str = 'deps/statistics_bodilex.npy', nfeats: int = 207,
+                  input_feats: List[str] = ["body_transl_delta_pelv",
+                                            "body_orient_xy",
+                                            "z_orient_delta", "body_pose",
+                                            "body_joints_local_wo_z_rot"],
+                  dim_per_feat: List[int] = [3, 6, 6, 126, 66], *args, **kwargs):
+
+         self.stats = self.load_norm_statistics(statistics_path, 'cuda')
+         # from src.model.utils.tools import pack_to_render
+         # mr = pack_to_render(aa.detach().cpu(), trans=None)
+         # mr = {k: v[0] for k, v in mr.items()}
+         # fname = render_motion(aitrenderer, mr,
+         #                       "/home/nathanasiou/Desktop/conditional_action_gen/modilex/pose_test",
+         #                       pose_repr='aa',
+         #                       text_for_vid=str(keyids[0]),
+         #                       color=color_map['generated'],
+         #                       smpl_layer=smpl_layer)
+
+         self.nfeats = nfeats
+         self.dim_per_feat = dim_per_feat
+         self.input_feats_dims = list(dim_per_feat)
+         self.input_feats = list(input_feats)
+
+
+     def load_norm_statistics(self, path, device):
+         # workaround for cluster local/sync
+         assert exists(path)
+         stats = np.load(path, allow_pickle=True)[()]
+         return cast_dict_to_tensors(stats, device=device)
+
+     def norm_and_cat(self, batch, features_types):
+         """
+         Turn batch data into the format the forward() function expects.
+         """
+         seq_first = lambda t: rearrange(t, 'b s ... -> s b ...')
+         input_batch = {}
+         ## PREPARE INPUT ##
+         motion_condition = any('source' in key for key in batch.keys())
+         mo_types = ['source', 'target']
+         for mot in mo_types:
+             list_of_feat_tensors = [seq_first(batch[f'{feat_type}_{mot}'])
+                                     for feat_type in features_types if f'{feat_type}_{mot}' in batch.keys()]
+             # normalise and cat to a unified feature vector
+             list_of_feat_tensors_normed = self.norm_inputs(list_of_feat_tensors,
+                                                            features_types)
+             # list_of_feat_tensors_normed = [x[1:] if 'delta' in nx else x for nx,
+             #                                x in zip(features_types,
+             #                                         list_of_feat_tensors_normed)]
+             x_norm, _ = self.cat_inputs(list_of_feat_tensors_normed)
+             input_batch[mot] = x_norm
+         return input_batch
+
+     def norm_and_cat_single_motion(self, batch, features_types):
+         """
+         Turn batch data into the format the forward() function expects.
+         """
+         seq_first = lambda t: rearrange(t, 'b s ... -> s b ...')
+         input_batch = {}
+         ## PREPARE INPUT ##
+
+         list_of_feat_tensors = [seq_first(batch[feat_type])
+                                 for feat_type in features_types]
+         # normalise and cat to a unified feature vector
+         list_of_feat_tensors_normed = self.norm_inputs(list_of_feat_tensors,
+                                                        features_types)
+         # list_of_feat_tensors_normed = [x[1:] if 'delta' in nx else x for nx,
+         #                                x in zip(features_types,
+         #                                         list_of_feat_tensors_normed)]
+
+         x_norm, _ = self.cat_inputs(list_of_feat_tensors_normed)
+         input_batch['motion'] = x_norm
+         return input_batch
+
+     def norm(self, x, stats):
+         mean = stats['mean'].to('cuda')
+         std = stats['std'].to('cuda')
+         return (x - mean) / (std + 1e-5)
+
+     def unnorm(self, x, stats):
+         mean = stats['mean'].to('cuda')
+         std = stats['std'].to('cuda')
+         return x * (std + 1e-5) + mean
+
+     def unnorm_state(self, state_norm: Tensor) -> Tensor:
+         # unnorm state
+         return self.cat_inputs(
+             self.unnorm_inputs(self.uncat_inputs(state_norm,
+                                                  self.first_pose_feats_dims),
+                                self.first_pose_feats))[0]
+
+     def unnorm_delta(self, delta_norm: Tensor) -> Tensor:
+         # unnorm delta
+         return self.cat_inputs(
+             self.unnorm_inputs(self.uncat_inputs(delta_norm,
+                                                  self.input_feats_dims),
+                                self.input_feats))[0]
+
+     def norm_state(self, state: Tensor) -> Tensor:
+         # normalise state
+         return self.cat_inputs(
+             self.norm_inputs(self.uncat_inputs(state,
+                                                self.first_pose_feats_dims),
+                              self.first_pose_feats))[0]
+
+     def norm_delta(self, delta: Tensor) -> Tensor:
+         # normalise delta
+         return self.cat_inputs(
+             self.norm_inputs(self.uncat_inputs(delta, self.input_feats_dims),
+                              self.input_feats))[0]
+
+     def cat_inputs(self, x_list: List[Tensor]):
+         """
+         Cat the inputs to a unified vector and return their lengths in order
+         to un-cat them later.
+         """
+         return torch.cat(x_list, dim=-1), [x.shape[-1] for x in x_list]
+
+     def uncat_inputs(self, x: Tensor, lengths: List[int]):
+         """
+         Split the unified feature vector back to its original parts.
+         """
+         return torch.split(x, lengths, dim=-1)
+
+     def norm_inputs(self, x_list: List[Tensor], names: List[str]):
+         """
+         Normalise inputs using the self.stats metrics.
+         """
+         x_norm = []
+         for x, name in zip(x_list, names):
+
+             x_norm.append(self.norm(x, self.stats[name]))
+         return x_norm
+
+     def unnorm_inputs(self, x_list: List[Tensor], names: List[str]):
+         """
+         Un-normalise inputs using the self.stats metrics.
+         """
+         x_unnorm = []
+         for x, name in zip(x_list, names):
+             x_unnorm.append(self.unnorm(x, self.stats[name]))
+         return x_unnorm
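Because cat_inputs returns the per-feature widths, packing and unpacking round-trip exactly; a sketch (editorial, bypassing __init__ so no statistics file is needed):

import torch

n = Normalizer.__new__(Normalizer)              # skip loading deps/statistics_bodilex.npy
feats = [torch.randn(8, 3), torch.randn(8, 6)]  # e.g. translation + orientation features
packed, dims = n.cat_inputs(feats)              # packed: (8, 9), dims: [3, 6]
parts = n.uncat_inputs(packed, dims)
assert all(torch.equal(a, b) for a, b in zip(feats, parts))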
requirements.txt CHANGED
@@ -1,3 +1,4 @@
  spaces
  gradio==4.36.1
  torch
+ transformers==4.41.2
text_encoder.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ from typing import List, Union
+
+ import torch
+ from torch import Tensor, nn
+
+ class ClipTextEncoder(nn.Module):
+     def __init__(
+         self,
+         modelpath: str = 'deps/clip-vit-large-patch14',  # clip-vit-base-patch32
+         finetune: bool = False,
+         **kwargs
+     ) -> None:
+
+         super().__init__()
+         from transformers import logging
+         from transformers import AutoModel, AutoTokenizer
+         logging.set_verbosity_error()
+         # Tokenizer
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+         self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
+         self.text_model = AutoModel.from_pretrained(modelpath)
+
+         # Don't train the model
+         if not finetune:
+             self.text_model.eval()
+             for p in self.text_model.parameters():
+                 p.requires_grad = False
+
+         # Then configure the model
+         self.max_length = self.tokenizer.model_max_length
+         self.text_encoded_dim = self.text_model.config.text_config.hidden_size
+
+     def forward(self, texts: List[str]):
+         # get prompt text embeddings
+         text_inputs = self.tokenizer(
+             texts,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids.to(self.text_model.device)
+         txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)
+         # truncate to the max length CLIP can handle
+         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+             text_input_ids = text_input_ids[:, :self.tokenizer.
+                                             model_max_length]
+
+         # use pooled output if latent dim is two-dimensional
+         # pooled = 0 if self.latent_dim[0] == 1 else 1  # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim)
+         # text encoder forward; CLIP must use get_text_features
+         # (batch_size, seq_length, text_encoded_dim)
+         text_embeddings = self.text_model.text_model(text_input_ids,
+                                                      # attention_mask=txt_att_mask
+                                                      ).last_hidden_state
+
+         return text_embeddings, txt_att_mask.bool()
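A usage sketch (editorial), assuming the CLIP weights are fetched via the Hub id openai/clip-vit-large-patch14 declared in the README rather than the local deps/ path:

encoder = ClipTextEncoder(modelpath="openai/clip-vit-large-patch14")
emb, mask = encoder(["a person walks forward and waves"])
print(emb.shape, mask.shape)  # (1, 77, 768) and (1, 77) for CLIP ViT-L/14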
tmed_denoiser.py ADDED
@@ -0,0 +1,404 @@
+ import torch
+ import torch.nn as nn
+ from model_utils import TimestepEmbedderMDM
+ from model_utils import PositionalEncoding
+
+ class TMED_denoiser(nn.Module):
+
+     def __init__(self,
+                  nfeats: int = 207,
+                  condition: str = "text",
+                  latent_dim: int = 512,
+                  ff_size: int = 1024,
+                  num_layers: int = 8,
+                  num_heads: int = 4,
+                  dropout: float = 0.1,
+                  activation: str = "gelu",
+                  text_encoded_dim: int = 768,
+                  pred_delta_motion: bool = False,
+                  use_sep: bool = True,
+                  **kwargs) -> None:
+
+         super().__init__()
+         self.latent_dim = latent_dim
+         self.pred_delta_motion = pred_delta_motion
+         self.text_encoded_dim = text_encoded_dim
+         self.condition = condition
+         self.feat_comb_coeff = nn.Parameter(torch.tensor([1.0]))
+         self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
+         self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
+         self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
+
+         # emb proj
+         if self.condition in ["text", "text_uncond"]:
+             # text condition
+             # project time from text_encoded_dim to latent_dim
+             self.embed_timestep = TimestepEmbedderMDM(self.latent_dim)
+
+             # FIXME me TODO this
+             # self.time_embedding = TimestepEmbedderMDM(self.latent_dim)
+
+             # project time+text to latent_dim
+             if text_encoded_dim != self.latent_dim:
+                 # todo 10.24 debug why relu
+                 self.emb_proj = nn.Linear(text_encoded_dim, self.latent_dim)
+         else:
+             raise TypeError(f"condition type {self.condition} not supported")
+         self.use_sep = use_sep
+         self.query_pos = PositionalEncoding(self.latent_dim, dropout)
+         self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
+         if self.use_sep:
+             self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))
+
+         # use torch transformer
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=self.latent_dim,
+             nhead=num_heads,
+             dim_feedforward=ff_size,
+             dropout=dropout,
+             activation=activation)
+         self.encoder = nn.TransformerEncoder(encoder_layer,
+                                              num_layers=num_layers)
+
+     def forward(self,
+                 noised_motion,
+                 timestep,
+                 in_motion_mask,
+                 text_embeds,
+                 condition_mask,
+                 motion_embeds=None,
+                 lengths=None,
+                 **kwargs):
+         # 0. dimension matching
+         # noised_motion [latent_dim[0], batch_size, latent_dim] <= [batch_size, latent_dim[0], latent_dim[1]]
+         bs = noised_motion.shape[0]
+         noised_motion = noised_motion.permute(1, 0, 2)
+         # 0. check lengths for no vae (diffusion only)
+         # if lengths not in [None, []]:
+         motion_in_mask = in_motion_mask
+
+         # time_embedding | text_embedding | frames_source | frames_target
+         # 1 * lat_d | max_text * lat_d | max_frames * lat_d | max_frames * lat_d
+
+
+         # 1. time embedding
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         timesteps = timestep.expand(noised_motion.shape[1]).clone()
+         time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
+         # make it S first
+         # time_emb = self.time_embedding(time_emb).unsqueeze(0)
+         if self.condition in ["text", "text_uncond"]:
+             # make it seq first
+             text_embeds = text_embeds.permute(1, 0, 2)
+             if self.text_encoded_dim != self.latent_dim:
+                 # [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim]
+                 text_emb_latent = self.emb_proj(text_embeds)
+             else:
+                 text_emb_latent = text_embeds
+             # source_motion_zeros = torch.zeros(*noised_motion.shape[:2],
+             #                                   self.latent_dim,
+             #                                   device=noised_motion.device)
+             # aux_fake_mask = torch.zeros(condition_mask.shape[0],
+             #                             noised_motion.shape[0],
+             #                             device=noised_motion.device)
+             # condition_mask = torch.cat((condition_mask, aux_fake_mask),
+             #                            1).bool().to(noised_motion.device)
+             emb_latent = torch.cat((time_emb, text_emb_latent), 0)
+
+             if motion_embeds is not None:
+                 zeroes_mask = (motion_embeds == 0).all(dim=-1)
+                 if motion_embeds.shape[-1] != self.latent_dim:
+                     motion_embeds_proj = self.pose_proj_in_source(motion_embeds)
+                     motion_embeds_proj[zeroes_mask] = 0
+                 else:
+                     motion_embeds_proj = motion_embeds
+
+         else:
+             raise TypeError(f"condition type {self.condition} not supported")
+         # 4. transformer
+         # if self.diffusion_only:
+         proj_noised_motion = self.pose_proj_in_target(noised_motion)
+
+         if self.use_sep:
+
+             sep_token_batch = torch.tile(self.sep_token, (bs,)).reshape(bs,
+                                                                         -1)
+             xseq = torch.cat((emb_latent, motion_embeds_proj,
+                               sep_token_batch[None],
+                               proj_noised_motion), axis=0)
+         else:
+             xseq = torch.cat((emb_latent, motion_embeds_proj,
+                               proj_noised_motion), axis=0)
+         # if self.ablation_skip_connection:
+         #     xseq = self.query_pos(xseq)
+         #     tokens = self.encoder(xseq)
+         # else:
+         #     # adding the timestep embed
+         #     # [seqlen+1, bs, d]
+         #     # todo change to query_pos_decoder
+         xseq = self.query_pos(xseq)
+         # BUILD the mask now
+         if motion_embeds is None:
+             time_token_mask = torch.ones((bs, time_emb.shape[0]),
+                                          dtype=torch.bool, device=xseq.device)
+             aug_mask = torch.cat((time_token_mask,
+                                   condition_mask[:, :text_emb_latent.shape[0]],
+                                   motion_in_mask), 1)
+         else:
+             time_token_mask = torch.ones((bs, time_emb.shape[0]),
+                                          dtype=torch.bool,
+                                          device=xseq.device)
+             if self.use_sep:
+                 sep_token_mask = torch.ones((bs, self.sep_token.shape[0]),
+                                             dtype=torch.bool,
+                                             device=xseq.device)
+             if self.use_sep:
+                 aug_mask = torch.cat((time_token_mask,
+                                       condition_mask[:, :text_emb_latent.shape[0]],
+                                       condition_mask[:, text_emb_latent.shape[0]:],
+                                       sep_token_mask,
+                                       motion_in_mask,
+                                       ), 1)
+             else:
+                 aug_mask = torch.cat((time_token_mask,
+                                       condition_mask[:, :text_emb_latent.shape[0]],
+                                       condition_mask[:, text_emb_latent.shape[0]:],
+                                       motion_in_mask,
+                                       ), 1)
+         tokens = self.encoder(xseq, src_key_padding_mask=~aug_mask)
+
+         # if self.diffusion_only:
+         if motion_embeds is not None:
+             denoised_motion_proj = tokens[emb_latent.shape[0]:]
+             if self.use_sep:
+                 useful_tokens = motion_embeds_proj.shape[0]+1
+             else:
+                 useful_tokens = motion_embeds_proj.shape[0]
+             denoised_motion_proj = denoised_motion_proj[useful_tokens:]
+         else:
+             denoised_motion_proj = tokens[emb_latent.shape[0]:]
+
+         denoised_motion = self.pose_proj_out(denoised_motion_proj)
+         if self.pred_delta_motion and motion_embeds is not None:
+             import torch.nn.functional as F
+             tgt_size = len(denoised_motion)
+             if len(denoised_motion) > len(motion_embeds):
+                 pad_for_src = tgt_size - len(motion_embeds)
+                 motion_embeds = F.pad(motion_embeds,
+                                       (0, 0, 0, 0, 0, pad_for_src))
+             denoised_motion = denoised_motion + motion_embeds[:tgt_size]
+
+         denoised_motion[~motion_in_mask.T] = 0
+         # zero for padded area
+         # else:
+         #     sample = tokens[:sample.shape[0]]
+         # 5. [batch_size, latent_dim[0], latent_dim[1]] <= [latent_dim[0], batch_size, latent_dim[1]]
+         denoised_motion = denoised_motion.permute(1, 0, 2)
+         return denoised_motion
+
+     def forward_with_guidance(self,
+                               noised_motion,
+                               timestep,
+                               in_motion_mask,
+                               text_embeds,
+                               condition_mask,
+                               guidance_motion,
+                               guidance_text_n_motion,
+                               motion_embeds=None,
+                               lengths=None,
+                               inpaint_dict=None,
+                               max_steps=None,
+                               prob_way='3way',
+                               **kwargs):
+         # if motion embeds is None
+         # TODO put here that you have two
+         # implement 2 cases for that case
+         # text unconditional more or less 2 replicas
+         # timestep
+         if max_steps is not None:
+             curr_ts = timestep[0].item()
+             g_m = max(1, guidance_motion*2*curr_ts/max_steps)
+             guidance_motion = g_m
+             g_t_tm = max(1, guidance_text_n_motion*2*curr_ts/max_steps)
+             guidance_text_n_motion = g_t_tm
+
+         if motion_embeds is None:
+             half = noised_motion[: len(noised_motion) // 2]
+             combined = torch.cat([half, half], dim=0)
+             model_out = self.forward(combined, timestep,
+                                      in_motion_mask=in_motion_mask,
+                                      text_embeds=text_embeds,
+                                      condition_mask=condition_mask,
+                                      motion_embeds=motion_embeds,
+                                      lengths=lengths)
+             uncond_eps, cond_eps_text = torch.split(model_out, len(model_out) // 2,
+                                                     dim=0)
+             # make it BxSxfeatures
+             if inpaint_dict is not None:
+                 import torch.nn.functional as F
+                 source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
+                 if source_mot.shape[1] >= uncond_eps.shape[1]:
+                     source_mot = source_mot[:, :uncond_eps.shape[1]]
+                 else:
+                     pad = uncond_eps.shape[1] - source_mot.shape[1]
+                     # Pad the tensor on the second dimension (time)
+                     source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)
+
+                 mot_len = source_mot.shape[1]
+                 # concat mask for all the frames
+                 mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
+                                                                           mot_len,
+                                                                           1)
+                 uncond_eps = uncond_eps*(~mask_src_parts) + source_mot*mask_src_parts
+                 cond_eps_text = cond_eps_text*(~mask_src_parts) + source_mot*mask_src_parts
+             half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
+             eps = torch.cat([half_eps, half_eps], dim=0)
+         else:
+             third = noised_motion[: len(noised_motion) // 3]
+             combined = torch.cat([third, third, third], dim=0)
+             model_out = self.forward(combined, timestep,
+                                      in_motion_mask=in_motion_mask,
+                                      text_embeds=text_embeds,
+                                      condition_mask=condition_mask,
+                                      motion_embeds=motion_embeds,
+                                      lengths=lengths)
+             # For exact reproducibility reasons, we apply classifier-free guidance on only
+             # three channels by default. The standard approach to cfg applies it to all channels.
+             # This can be done by uncommenting the following line and commenting-out the line following that.
+             # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+             # eps, rest = model_out[:, :3], model_out[:, 3:]
+             uncond_eps, cond_eps_motion, cond_eps_text_n_motion = torch.split(model_out,
+                                                                               len(model_out) // 3,
+                                                                               dim=0)
+             if inpaint_dict is not None:
+                 import torch.nn.functional as F
+                 source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
+                 if source_mot.shape[1] >= uncond_eps.shape[1]:
+                     source_mot = source_mot[:, :uncond_eps.shape[1]]
+                 else:
+                     pad = uncond_eps.shape[1] - source_mot.shape[1]
+                     # Pad the tensor on the second dimension (time)
+                     source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)
+
+                 mot_len = source_mot.shape[1]
+                 # concat mask for all the frames
+                 mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
+                                                                           mot_len,
+                                                                           1)
+                 uncond_eps = uncond_eps*(~mask_src_parts) + source_mot*mask_src_parts
+                 cond_eps_motion = cond_eps_motion*(~mask_src_parts) + source_mot*mask_src_parts
+                 cond_eps_text_n_motion = cond_eps_text_n_motion*(~mask_src_parts) + source_mot*mask_src_parts
+             if prob_way == '3way':
+                 third_eps = uncond_eps + guidance_motion * (cond_eps_motion - uncond_eps) + \
+                     guidance_text_n_motion * (cond_eps_text_n_motion - cond_eps_motion)
+             if prob_way == '2way':
+                 third_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text_n_motion - uncond_eps)
+
+             eps = torch.cat([third_eps, third_eps, third_eps], dim=0)
+         return eps
+
+     def _diffusion_reverse(self, text_embeds, text_masks_from_enc,
+                            motion_embeds, cond_motion_masks,
+                            inp_motion_mask, diff_process,
+                            init_vec=None,
+                            init_from='noise',
+                            gd_text=None, gd_motion=None,
+                            mode='full_cond',
+                            return_init_noise=False,
+                            steps_num=None,
+                            inpaint_dict=None,
+                            use_linear=False,
+                            prob_way='3way'):
+         # guidance_scale_text: 7.5
+         # guidance_scale_motion: 1.5
+         # init latents
+
+         bsz = inp_motion_mask.shape[0]
+         assert mode in ['full_cond', 'text_cond', 'mot_cond']
+         assert inp_motion_mask is not None
+         # len_to_gen = max(lengths) if not self.input_deltas else max(lengths) + 1
+         if init_vec is None:
+             initial_latents = torch.randn(
+                 (bsz, inp_motion_mask.shape[1], 207),
+                 device=inp_motion_mask.device,
+                 dtype=torch.float,
+             )
+         else:
+             initial_latents = init_vec
+
+         gd_scale_text = 2.0
+         gd_scale_motion = 4.0
+
+         if text_embeds is not None:
+             max_text_len = text_embeds.shape[1]
+         else:
+             max_text_len = 0
+         max_motion_len = cond_motion_masks.shape[1]
+         text_masks = text_masks_from_enc.clone()
+         nomotion_mask = torch.zeros(bsz, max_motion_len,
+                                     dtype=torch.bool).to('cuda')
+         motion_masks = torch.cat([nomotion_mask,
+                                   cond_motion_masks,
+                                   cond_motion_masks],
+                                  dim=0)
+         aug_mask = torch.cat([text_masks,
+                               motion_masks],
+                              dim=1)
+
+
+         # Setup classifier-free guidance:
+         if motion_embeds is not None:
+             z = torch.cat([initial_latents, initial_latents, initial_latents], 0)
+         else:
+             z = torch.cat([initial_latents, initial_latents], 0)
+
+         # y_null = torch.tensor([1000] * n, device=device)
+         # y = torch.cat([y, y_null], 0)
+         if use_linear:
+             max_steps_diff = diff_process.num_timesteps
+         else:
+             max_steps_diff = None
+         if motion_embeds is not None:
+             model_kwargs = dict(  # noised_motion=latent_model_input,
+                                 # timestep=t,
+                                 in_motion_mask=torch.cat([inp_motion_mask,
+                                                           inp_motion_mask,
+                                                           inp_motion_mask], 0),
+                                 text_embeds=text_embeds,
+                                 condition_mask=aug_mask,
+                                 motion_embeds=torch.cat([torch.zeros_like(motion_embeds),
+                                                          motion_embeds,
+                                                          motion_embeds], 1),
+                                 guidance_motion=gd_motion,
+                                 guidance_text_n_motion=gd_text,
+                                 inpaint_dict=inpaint_dict,
+                                 max_steps=max_steps_diff,
+                                 prob_way=prob_way)
+         else:
+             model_kwargs = dict(  # noised_motion=latent_model_input,
+                                 # timestep=t,
+                                 in_motion_mask=torch.cat([inp_motion_mask,
+                                                           inp_motion_mask], 0),
+                                 text_embeds=text_embeds,
+                                 condition_mask=aug_mask,
+                                 motion_embeds=None,
+                                 guidance_motion=gd_motion,
+                                 guidance_text_n_motion=gd_text,
+                                 inpaint_dict=inpaint_dict,
+                                 max_steps=max_steps_diff)
+
+         # model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
+         # Sample images:
+         samples = diff_process.p_sample_loop(self.forward_with_guidance,
+                                              z.shape, z,
+                                              clip_denoised=False,
+                                              model_kwargs=model_kwargs,
+                                              progress=True,
+                                              device=initial_latents.device,)
+         _, _, samples = samples.chunk(3, dim=0)  # Remove null class samples
+
+         final_diffout = samples.permute(1, 0, 2)
+         if return_init_noise:
+             return initial_latents, final_diffout
+         else:
+             return final_diffout
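The guidance arithmetic in forward_with_guidance reduces to a two-step extrapolation; a sketch with invented tensors (names mirror the code above):

import torch

uncond_eps = torch.zeros(1, 4, 207)
cond_eps_motion = torch.full((1, 4, 207), 0.5)
cond_eps_text_n_motion = torch.ones(1, 4, 207)
guidance_motion, guidance_text_n_motion = 2.0, 3.0

third_eps = uncond_eps \
    + guidance_motion * (cond_eps_motion - uncond_eps) \
    + guidance_text_n_motion * (cond_eps_text_n_motion - cond_eps_motion)
# with both guidance weights at 1.0 this collapses to cond_eps_text_n_motion;
# larger weights extrapolate away from the weaker conditions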