import os
import sys
from contextlib import contextmanager

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torchdyn.core import NeuralODE

from tts.layers.ffn import (AdaLNFinalLayer, AdaLNMLP,
                            GaussianFourierTimeEmbedding)


@contextmanager
def suppress_stdout():
    original_stdout = sys.stdout
    try:
        sys.stdout = open(os.devnull, "w")
        yield
    finally:
        sys.stdout.close()
        sys.stdout = original_stdout


def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
) -> torch.Tensor:
    B, N, C = logits.shape
    logits = logits / temperature

    # Apply top-k
    if top_k > 0:
        top_k = min(top_k, C)
        topk_values, _ = torch.topk(logits, top_k, dim=-1)
        kth_value = topk_values[..., -1, None]
        logits = torch.where(
            logits < kth_value, torch.full_like(logits, float("-inf")), logits
        )

    # Apply top-p (nucleus) sampling
    if top_p > 0.0 and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)

        # Create mask for tokens to remove
        cutoff_mask = cumulative_probs > top_p
        cutoff_mask[..., 0] = 0  # Always keep at least one token
        sorted_logits[cutoff_mask] = float("-inf")

        # Map back to original logits shape
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_indices, sorted_logits
        )

    # Convert logits to probabilities
    probs = F.softmax(logits, dim=-1)

    # Sample
    samples = torch.multinomial(probs.view(-1, C), num_samples=1).view(B, N)
    return samples


class LogitsHead(nn.Module):
    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.logits_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, pre_logits):
        return self.logits_proj(pre_logits)

    def compute_loss(
        self,
        pre_logits: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        logits = self(pre_logits)
        if mask is not None:
            flat_logits = logits[mask]
            flat_target = target[mask]
        else:
            flat_logits = rearrange(logits, "b n l -> (b n) l")
            flat_target = rearrange(target, "b n -> (b n)")
        loss = nn.functional.cross_entropy(
            flat_logits,
            flat_target,
        )
        return {"cross_entropy": loss}

    def predict(self, x: torch.Tensor, *args, **kwargs):
        return sample_from_logits(self(x), *args, **kwargs)


class ContinuousHead(nn.Module):
    def __init__(self, hidden_dim: int, feature_dim: int):
        super().__init__()
        self.continuous_head = nn.Linear(hidden_dim, feature_dim)

    def forward(self, x: torch.Tensor):
        return self.continuous_head(x)

    def predict(self, x: torch.Tensor):
        return self(x)

    def compute_loss(
        self,
        pre_logits: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        if mask is not None:
            pre_logits = pre_logits[mask]
            target = target[mask]
        return {"mse": nn.functional.mse_loss(self(pre_logits), target)}


class VelocityHead(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        feature_dim: int,
        num_layers: int,
        cond_dim: int | None = None,
    ):
        super().__init__()
        cond_dim = cond_dim if cond_dim is not None else hidden_dim
        self.feature_embed = nn.Linear(feature_dim, hidden_dim)
        self.cond_embed = nn.Linear(cond_dim, hidden_dim)
        self.time_embed = GaussianFourierTimeEmbedding(hidden_dim // 2)
        self.adaln_mlp = nn.ModuleList(
            [AdaLNMLP(hidden_dim) for _ in range(num_layers)]
        )
        self.adaln_final_layer = AdaLNFinalLayer(hidden_dim, feature_dim)
        self.feature_dim = feature_dim

    def forward(
        self,
        cond: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor | None = None,
        cond_drop_mask: torch.BoolTensor | None = None,
    ):
        cond = self.cond_embed(cond)
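        # Classifier-free guidance: positions flagged by `cond_drop_mask` have
        # their conditioning zeroed out, both for condition dropout during
        # training and for the unconditional branch at inference.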
        if cond_drop_mask is not None:
            cond[cond_drop_mask] = 0.0
        cond += self.time_embed(t)[:, None]
        x = self.feature_embed(x)
        for layer in self.adaln_mlp:
            x = layer(x, cond)
        y = self.adaln_final_layer(x, cond)
        return y

    def compute_loss(
        self,
        cond: torch.Tensor,
        x1: torch.Tensor,
        mask: torch.Tensor | None,
        sigma: float = 1e-5,
        t: torch.Tensor | None = None,
        x0: torch.Tensor | None = None,
        cfg_drop_rate: float = 0.1,
    ):
        """
        Conditional flow matching (CFM) loss.
        """
        if t is None:
            t = torch.rand(cond.shape[0], device=cond.device)
        if x0 is None:
            x0 = torch.randn_like(x1)
        flow_target = x1 - (1 - sigma) * x0
        alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
        xt = alpha * x0 + t.view(-1, 1, 1) * x1
        if self.training and cfg_drop_rate > 0.0:
            # Drop conditioning for a random subset of positions so the model also
            # learns the unconditional velocity field (classifier-free guidance).
            cond_drop_mask = (
                torch.rand(cond.shape[:2], device=cond.device) < cfg_drop_rate
            )
        else:
            cond_drop_mask = None
        flow_pred = self(cond, xt, t, cond_drop_mask=cond_drop_mask)
        if mask is not None:
            flow_pred = flow_pred[mask]
            flow_target = flow_target[mask]
        loss = nn.functional.mse_loss(flow_pred, flow_target)
        return {"diffusion": loss}

    def predict(
        self,
        pre_prediction: torch.Tensor,
        pre_prediction_ref: torch.Tensor | None = None,
        solver: str = "euler",
        sensitivity: str = "adjoint",
        num_steps: int = 10,
        cfg: float = 1.0,
        cfg_ref: float = 1.5,
        temperature: float = 1.0,
        **kwargs,
    ):
        if cfg == 1.0:
            if pre_prediction_ref is None:

                def solver_fn(t, Xt, *args, **kwargs):
                    return self(pre_prediction, Xt, t.unsqueeze(0))

            else:
                raise NotImplementedError
        else:
            if pre_prediction_ref is None:

                def solver_fn(t, Xt, *args, **kwargs):
                    cond_uncond = torch.cat(
                        (pre_prediction, torch.zeros_like(pre_prediction)),
                        dim=0,
                    )
                    cond_uncond = self(cond_uncond, Xt.repeat(2, 1, 1), t.unsqueeze(0))
                    cond, uncond = cond_uncond.chunk(2, dim=0)
                    cond_uncond_cfg = uncond + cfg * (cond - uncond)
                    return cond_uncond_cfg

            else:

                def solver_fn(t, Xt, *args, **kwargs):
                    cond_uncond_ref = torch.cat(
                        (
                            pre_prediction,
                            pre_prediction_ref,
                            torch.zeros_like(pre_prediction),
                        ),
                        dim=0,
                    )
                    cond_uncond = self(
                        cond_uncond_ref, Xt.repeat(3, 1, 1), t.unsqueeze(0)
                    )
                    cond, ref, uncond = cond_uncond.chunk(3, dim=0)
                    # Apply reference guidance first, then standard classifier-free
                    # guidance against the unconditional branch.
                    cond_uncond_cfg = ref + cfg_ref * (cond - ref)
                    cond_uncond_cfg_ref_cfg = uncond + cfg * (cond_uncond_cfg - uncond)
                    return cond_uncond_cfg_ref_cfg

        # get rid of torchdyn warning
        with suppress_stdout():
            node_ = NeuralODE(solver_fn, solver=solver, sensitivity=sensitivity)
        t_span = torch.linspace(0, 1, num_steps + 1, device=pre_prediction.device)
        traj = node_.trajectory(
            torch.randn(
                pre_prediction.shape[0],
                1,
                self.feature_dim,
                device=pre_prediction.device,
            )
            * temperature,
            t_span=t_span,
        )
        prediction = traj[-1]
        return prediction


class StopPredictionHead(nn.Module):
    def __init__(self, dim: int, weight_loss: float = 1.0):
        super().__init__()
        self.proj = nn.Linear(dim, 1)
        self.weight_loss = weight_loss

    def forward(self, pre_prediction: torch.Tensor):
        return torch.sigmoid(self.proj(pre_prediction))

    def predict(self, pre_prediction: torch.Tensor):
        return self(pre_prediction)

    def compute_loss(
        self,
        pre_prediction: torch.Tensor,
        target: torch.Tensor,
    ):
        logits = self.proj(pre_prediction)
        # Uniform weighting: scaling the mean BCE by `weight_loss` is equivalent to
        # passing a constant per-element weight and avoids any shape mismatch.
        bce = self.weight_loss * nn.functional.binary_cross_entropy_with_logits(
            logits.squeeze(-1),
            target.to(logits.dtype),
        )
        return {"stop_bce": bce}
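

if __name__ == "__main__":
    # Illustrative smoke test only; the dimensions below are arbitrary example
    # values, not taken from any config in this repository.
    torch.manual_seed(0)

    logits_head = LogitsHead(hidden_dim=64, vocab_size=100)
    tokens = logits_head.predict(torch.randn(2, 16, 64), temperature=0.9, top_k=8)
    print("sampled tokens:", tokens.shape)  # (2, 16)

    velocity_head = VelocityHead(hidden_dim=64, feature_dim=8, num_layers=2)
    cond = torch.randn(2, 16, 64)  # conditioning states (batch, frames, hidden_dim)
    x1 = torch.randn(2, 16, 8)     # target features (batch, frames, feature_dim)
    print("CFM loss:", velocity_head.compute_loss(cond, x1, mask=None))

    # Single-frame sampling via the flow-matching ODE (requires torchdyn).
    frame = velocity_head.predict(cond[:, :1], num_steps=4)
    print("sampled frame:", frame.shape)  # (2, 1, 8)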