| | from abc import ABC |
| | from typing import Optional |
| |
|
| | import math |
| | import torch |
| | from torch import _dynamo |
# Ask TorchDynamo (torch.compile) to fall back to eager execution instead of
# raising when graph compilation fails.
# NOTE(review): this is a process-wide setting applied as a side effect of
# importing this module.
_dynamo.config.suppress_errors = True
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch.nn.functional import mse_loss, l1_loss, binary_cross_entropy, cross_entropy, kl_div, nll_loss |
| | from pyro.distributions.conjugate import BetaBinomial |
| | from pyro.distributions import Normal |
| | from torch_geometric.nn import MessagePassing |
| |
|
| |
|
class NeighborEmbedding(MessagePassing, ABC):
    """Mix each node's features with a distance-gated sum over its neighbours.

    For every edge the radial-basis features ``edge_attr`` are projected from
    ``num_rbf`` to ``hidden_channels`` and scaled by a ``CosineCutoff`` of the
    edge length.  The result gates the source-node features, which are summed
    per target node ("add" aggregation).  The node's own features and this
    aggregate are concatenated and mixed back down to ``hidden_channels``.
    """

    def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper):
        super(NeighborEmbedding, self).__init__(aggr="add")
        self.distance_proj = nn.Linear(num_rbf, hidden_channels)
        self.combine = nn.Linear(hidden_channels * 2, hidden_channels)
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise both linear layers and zero their biases."""
        for layer in (self.distance_proj, self.combine):
            nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        """Return updated node features of width ``hidden_channels``.

        Args:
            x: node features.
            edge_index: ``(2, E)`` edge list.
            edge_weight: per-edge length fed to the cutoff envelope.
            edge_attr: per-edge radial basis features (``num_rbf`` wide).
        """
        # Self-loops are removed so a node is not counted among its own neighbours.
        not_self_loop = edge_index[0] != edge_index[1]
        if not not_self_loop.all():
            edge_index = edge_index[:, not_self_loop]
            edge_weight = edge_weight[not_self_loop]
            edge_attr = edge_attr[not_self_loop]

        envelope = self.cutoff(edge_weight).view(-1, 1)
        gate = self.distance_proj(edge_attr) * envelope
        aggregated = self.propagate(edge_index, x=x, W=gate, size=None)
        return self.combine(torch.cat([x, aggregated], dim=1))

    def message(self, x_j, W):
        # Gate each source-node feature vector by its edge filter.
        return x_j * W
| |
|
| |
|
class GaussianSmearing(nn.Module):
    """Expand scalar distances onto ``num_rbf`` evenly spaced Gaussians.

    Centres (``offset``) are spaced uniformly on ``[cutoff_lower, cutoff_upper]``.
    All Gaussians share one exponent coefficient ``coeff = -0.5 / spacing**2``.
    With ``trainable=True`` both are ``nn.Parameter``s; otherwise they are
    registered as buffers.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True):
        super(GaussianSmearing, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.num_rbf = num_rbf
        self.trainable = trainable

        offset, coeff = self._initial_params()
        for name, value in (("coeff", coeff), ("offset", offset)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return ``(centres, coeff)`` computed from the stored cutoffs and ``num_rbf``."""
        centres = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf)
        spacing = centres[1] - centres[0]
        return centres, -0.5 / spacing ** 2

    def reset_parameters(self):
        """Overwrite ``offset``/``coeff`` in place with their initial values."""
        fresh_offset, fresh_coeff = self._initial_params()
        self.offset.data.copy_(fresh_offset)
        self.coeff.data.copy_(fresh_coeff)

    def forward(self, dist):
        """Map ``dist`` of shape ``(...)`` to ``(..., num_rbf)``."""
        shifted = dist.unsqueeze(-1) - self.offset
        return torch.exp(self.coeff * torch.pow(shifted, 2))
| |
|
| |
|
class ExpNormalSmearing(nn.Module):
    """Radial basis with Gaussians placed in ``exp(-alpha * (d - cutoff_lower))`` space.

    Basis ``k`` is ``exp(-betas[k] * (exp(-alpha * (d - cutoff_lower)) - means[k]) ** 2)``
    multiplied by a cosine envelope that reaches zero at ``cutoff_upper``.
    ``alpha = 5 / (cutoff_upper - cutoff_lower)``.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True):
        super(ExpNormalSmearing, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.num_rbf = num_rbf
        self.trainable = trainable

        # The envelope is built with a lower cutoff of 0, not ``cutoff_lower``.
        self.cutoff_fn = CosineCutoff(0, cutoff_upper)
        self.alpha = 5.0 / (cutoff_upper - cutoff_lower)

        means, betas = self._initial_params()
        for name, value in (("means", means), ("betas", betas)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return ``(means, betas)``.

        Means are spaced linearly from ``exp(-(cutoff_upper - cutoff_lower))`` to 1.
        All betas share the value ``(2 / num_rbf * (1 - start)) ** -2``.
        """
        start_value = torch.exp(
            torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower)
        )
        means = torch.linspace(start_value, 1, self.num_rbf)
        width = 2 / self.num_rbf * (1 - start_value)
        betas = torch.tensor([width ** -2] * self.num_rbf)
        return means, betas

    def reset_parameters(self):
        """Overwrite ``means``/``betas`` in place with their initial values."""
        fresh_means, fresh_betas = self._initial_params()
        self.means.data.copy_(fresh_means)
        self.betas.data.copy_(fresh_betas)

    def forward(self, dist):
        """Map ``dist`` of shape ``(...)`` to ``(..., num_rbf)``."""
        dist = dist.unsqueeze(-1)
        transformed = torch.exp(self.alpha * (-dist + self.cutoff_lower))
        return self.cutoff_fn(dist) * torch.exp(
            -self.betas * (transformed - self.means) ** 2
        )
| |
|
| |
|
class ExpNormalSmearingUnlimited(nn.Module):
    """Exp-normal radial basis with a fixed decay rate and no cutoff envelope.

    Basis ``k`` is ``exp(-betas[k] * (exp(-d / 20) - means[k]) ** 2)``.
    ``cutoff_lower`` and ``cutoff_upper`` are accepted but not used, which
    keeps the constructor signature aligned with the other smearing classes.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True):
        super(ExpNormalSmearingUnlimited, self).__init__()
        self.num_rbf = num_rbf
        self.trainable = trainable

        # Fixed decay rate of exp(-alpha * d); independent of any cutoff.
        self.alpha = 1 / 20

        means, betas = self._initial_params()
        for name, value in (("means", means), ("betas", betas)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return ``(means, betas)``: means linear on ``[0.1, 1]``, all betas equal."""
        start_value = 0.1
        means = torch.linspace(start_value, 1, self.num_rbf)
        beta = (2 / self.num_rbf * (1 - start_value)) ** -2
        betas = torch.tensor([beta] * self.num_rbf)
        return means, betas

    def reset_parameters(self):
        """Overwrite ``means``/``betas`` in place with their initial values."""
        fresh_means, fresh_betas = self._initial_params()
        self.means.data.copy_(fresh_means)
        self.betas.data.copy_(fresh_betas)

    def forward(self, dist):
        """Map ``dist`` of shape ``(...)`` to ``(..., num_rbf)``."""
        decayed = torch.exp(self.alpha * (-dist.unsqueeze(-1)))
        return torch.exp(-self.betas * (decayed - self.means) ** 2)
| |
|
| |
|
class ShiftedSoftplus(nn.Module):
    """Softplus shifted down by ``log(2)`` so that the output at 0 is 0."""

    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        # Stored as a Python float, so it is neither a parameter nor a buffer.
        two = torch.tensor(2.0)
        self.shift = two.log().item()

    def forward(self, x):
        activated = F.softplus(x)
        return activated - self.shift
| |
|
| |
|
class CosineCutoff(nn.Module):
    """Smooth cosine envelope applied to edge lengths.

    * ``cutoff_lower`` not positive: ``0.5 * (cos(pi * d / upper) + 1)``, which is
      1 at ``d = 0`` and 0 at ``d = upper``; forced to 0 for ``d >= upper``.
    * ``cutoff_lower > 0``: ``0.5 * (cos(pi * (2 * (d - lower) / (upper - lower) + 1)) + 1)``,
      a bump that is 0 at ``lower``, 1 at the midpoint and 0 at ``upper``;
      forced to 0 outside the open interval ``(lower, upper)``.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super(CosineCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper

    def forward(self, distances):
        lower, upper = self.cutoff_lower, self.cutoff_upper
        if not (lower > 0):
            envelope = 0.5 * (torch.cos(distances * math.pi / upper) + 1.0)
            return envelope * (distances < upper).float()

        phase = math.pi * (2 * (distances - lower) / (upper - lower) + 1.0)
        envelope = 0.5 * (torch.cos(phase) + 1.0)
        # Zero everything outside the open interval (lower, upper).
        envelope = envelope * (distances < upper).float()
        return envelope * (distances > lower).float()
| |
|
| |
|
class Distance(nn.Module):
    """Compute edge lengths (and optionally edge vectors) for a given edge list.

    Edges shorter than ``cutoff_lower`` are dropped.  ``cutoff_upper`` is stored
    but not applied in ``forward``.  With ``loop=True``, self-loop edges get
    length 0 and are always kept, whatever ``cutoff_lower`` is.
    """

    def __init__(
        self,
        cutoff_lower,
        cutoff_upper,
        return_vecs=False,
        loop=False,
    ):
        super(Distance, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.return_vecs = return_vecs
        self.loop = loop

    def forward(self, pos, edge_index):
        """Return ``(edge_index, edge_weight, edge_vec_or_None)`` after filtering.

        ``edge_vec`` is ``pos[src] - pos[dst]``; it is returned only when
        ``return_vecs`` is set, otherwise the third element is ``None``.
        """
        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

        if self.loop:
            # The norm is only evaluated on non-self-loop edges; self-loops keep 0.
            is_real_edge = edge_index[0] != edge_index[1]
            edge_weight = torch.zeros(
                edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype
            )
            edge_weight[is_real_edge] = torch.norm(edge_vec[is_real_edge], dim=-1)
            keep = (edge_weight >= self.cutoff_lower) | ~is_real_edge
        else:
            edge_weight = torch.norm(edge_vec, dim=-1)
            keep = edge_weight >= self.cutoff_lower

        kept_vecs = edge_vec[keep] if self.return_vecs else None
        return edge_index[:, keep], edge_weight[keep], kept_vecs
| |
|
| |
|
class DistanceV2(nn.Module):
    """Edge lengths from ``pos`` plus four stacked edge vectors from ``pos``/``coords``.

    ``return_vecs`` is stored but the edge vectors are always returned.
    """

    def __init__(
        self,
        return_vecs=True,
        loop=False,
    ):
        super(DistanceV2, self).__init__()
        self.return_vecs = return_vecs
        self.loop = loop

    def forward(self, pos, coords, edge_index):
        """Return ``(edge_index, edge_weight, edge_vec)``.

        ``edge_vec[..., k]`` stacks, for each edge ``src -> dst``:
        ``pos[dst] - pos[src]``, then ``coords[dst, :, j] - coords[src, :, 0]``
        for ``j = 0, 2, 3``.  NOTE(review): the original variable names suggest
        coords slots 0/2/3 are CB/N/O; confirm against the data pipeline.
        ``edge_weight`` is the norm of ``pos[dst] - pos[src]`` (0 for self-loops
        when ``loop`` is set).
        """
        src, dst = edge_index[0], edge_index[1]
        ca_ca = pos[dst] - pos[src]
        anchor = coords[src, :, [0]]
        cb_cb = coords[dst, :, [0]] - anchor
        cb_N = coords[dst, :, [2]] - anchor
        cb_O = coords[dst, :, [3]] - anchor
        edge_vec = torch.stack([ca_ca, cb_cb, cb_N, cb_O], dim=-1)

        if self.loop:
            off_diagonal = src != dst
            edge_weight = torch.zeros(ca_ca.size(0), device=ca_ca.device, dtype=ca_ca.dtype)
            edge_weight[off_diagonal] = torch.norm(ca_ca[off_diagonal], dim=-1)
        else:
            edge_weight = torch.norm(ca_ca, dim=-1)

        return edge_index, edge_weight, edge_vec
| |
|
| |
|
| |
|
# Registry of radial-basis expansion classes, keyed by their config name.
rbf_class_mapping = dict(
    gauss=GaussianSmearing,
    expnorm=ExpNormalSmearing,
    expnormunlim=ExpNormalSmearingUnlimited,
)
| |
|
| |
|
class AbsTanh(nn.Module):
    """Element-wise ``|tanh(x)|``; the output lies in ``[0, 1]``."""

    def __init__(self):
        super(AbsTanh, self).__init__()

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x).abs()
| |
|
| |
|
class Tanh2(nn.Module):
    """Element-wise ``tanh(x) ** 2``; the output lies in ``[0, 1]``."""

    def __init__(self):
        super(Tanh2, self).__init__()

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x).square()
| |
|
def gelu(x):
    """Exact (erf-based) GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))``.

    The tanh approximation used by OpenAI GPT,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
    gives slightly different results.
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
| |
|
# Registry of activation layers, keyed by config name.  Values are
# nn.Module subclasses (not instances).
# NOTE(review): nn.Softmax() built without ``dim`` relies on PyTorch's
# deprecated implicit-dim inference and warns when called.
act_class_mapping = dict(
    [
        ("ssp", ShiftedSoftplus),
        ("softplus", nn.Softplus),
        ("silu", nn.SiLU),
        ("leaky_relu", nn.LeakyReLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
        ("pass", nn.Identity),
        ("abs_tanh", AbsTanh),
        ("tanh2", Tanh2),
        ("softmax", nn.Softmax),
        ("gelu", nn.GELU),
    ]
)
| |
|
| |
|
def cosin_contrastive_loss(input, target, margin=0):
    """Pairwise contrastive loss on cosine similarity.

    For every unordered pair ``i < j`` with ``s_ij = cos(input_i, input_j)``:

    * same label: ``max(0, 1 - s_ij)`` (pull together);
    * different label: ``max(0, s_ij - margin)`` (push apart until below ``margin``).

    The summed pair losses are divided by the number of pairs ``n * (n - 1) / 2``.

    Args:
        input: ``(n, d)`` embeddings.
        target: ``(n,)`` or ``(n, 1)`` labels.
        margin: similarity below which differently-labelled pairs stop contributing.

    Returns:
        Scalar tensor.  With fewer than two samples there are no pairs and 0 is
        returned, instead of dividing by a zero pair count (which gives NaN).
    """
    if input.shape[0] < 2:
        return torch.tensor(0, dtype=input.dtype, device=input.device)
    if target.ndim == 1:
        target = target.unsqueeze(1)

    similarity = F.cosine_similarity(input.unsqueeze(1), input.unsqueeze(0), dim=2)
    # +1 for same-label pairs, -1 for different-label pairs.
    sign = torch.eq(target, target.T).float() * 2 - 1
    # sign=+1 -> 1 - s ; sign=-1 -> s - margin
    loss = -similarity * sign + (sign + 1) / 2 + (sign - 1) * margin / 2

    num_pairs = target.shape[0] * (target.shape[0] - 1) / 2
    # triu(diagonal=1) keeps each unordered pair once and drops self-pairs.
    return torch.clamp(loss.triu(diagonal=1), min=0).sum() / num_pairs
| |
|
| |
|
def euclid_contrastive_loss(input, target, margin: Optional[float] = None):
    """Pairwise contrastive loss on Euclidean distance.

    For every unordered pair ``i < j`` with distance ``d_ij``:

    * same label: ``+d_ij`` (pull together);
    * different label and ``d_ij <= margin``: ``-d_ij`` (push apart);
    * different label and ``d_ij > margin``: ``0``.

    The summed pair losses are divided by the number of pairs ``n * (n - 1) / 2``.

    Args:
        input: ``(n, d)`` embeddings.
        target: ``(n,)`` or ``(n, 1)`` labels.
        margin: distance beyond which differently-labelled pairs stop
            contributing.  Defaults to ``10 * d``.

    Returns:
        Scalar tensor.  With fewer than two samples there are no pairs and 0 is
        returned, instead of dividing by a zero pair count (which gives NaN).
    """
    if input.shape[0] < 2:
        return torch.tensor(0, dtype=input.dtype, device=input.device)
    if target.ndim == 1:
        target = target.unsqueeze(1)
    if margin is None:
        margin = 10 * input.shape[1]

    dist = torch.cdist(input, input)
    # +1 for same-label pairs, -1 for different-label pairs.
    sign = torch.eq(target, target.T).float() * 2 - 1
    # Differently-labelled pairs already farther apart than ``margin`` are ignored.
    beyond_margin = (dist > margin).float() * (sign == -1).float()
    loss = dist * sign * (1 - beyond_margin)

    num_pairs = target.shape[0] * (target.shape[0] - 1) / 2
    return loss.triu(diagonal=1).sum() / num_pairs
| |
|
| |
|
class WeightedCombinedLoss(nn.modules.loss._WeightedLoss):
    """Module wrapper around ``combined_loss`` that stores the class weights.

    ``weight[:2]`` is forwarded as ``weight_1`` (first head: patho/benign) and
    ``weight[2:]`` as ``weight_2`` (second head).  ``task_weight`` scales the
    second head's loss term.
    NOTE(review): earlier docs describe a 5-element weight (2 patho/benign +
    3 benign/GoF/LoF); confirm how many entries the second head really needs.
    ``ignore_index`` is stored but not used by ``forward``.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None,
                 task_weight: float = 10.0,
                 size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.task_weight = task_weight

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        class_weights = self.weight
        return combined_loss(
            input,
            target,
            weight=self.task_weight,
            weight_1=class_weights[:2],
            weight_2=class_weights[2:],
            reduction=self.reduction,
        )
| |
|
| |
|
class WeightedLoss1(nn.modules.loss._WeightedLoss):
    """Binary cross-entropy with a positive-class re-weighting factor.

    Only ``weight[0]`` and ``weight[1]`` are read.  Samples with
    ``target == 1`` get weight ``weight[1] / weight[0]``; all other samples
    get weight 1.  ``task_weight`` and ``ignore_index`` are stored but unused.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None,
                 task_weight: float = 10.0,
                 size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.task_weight = task_weight

    def forward(self, input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor=None, reduce=True, reduction=None) -> torch.Tensor:
        """Compute the re-weighted BCE.

        Args:
            input: predicted probabilities, same shape as ``target``.
            target: binary labels (positives are ``target == 1``).
            weight: optional per-sample weights, multiplied into the class weights.
            reduce: if falsy (and not ``None``), return the unreduced loss,
                i.e. behave like ``reduction='none'``.
            reduction: ``'mean'``, ``'sum'`` or ``'none'``; defaults to ``self.reduction``.
        """
        if reduction is None:
            reduction = self.reduction
        # Passing the legacy ``reduce`` flag to binary_cross_entropy makes it
        # ignore ``reduction`` (and warn), so translate reduce=False here.
        if reduce is not None and not reduce:
            reduction = 'none'

        class_weights = self.weight[:2]
        # Match the input dtype so double/half inputs are accepted by BCE.
        target_1 = target.to(input.dtype)
        weight_loss_1 = torch.ones_like(target_1, dtype=input.dtype, device=input.device)
        weight_loss_1[target == 1] = class_weights[1] / class_weights[0]
        if weight is not None:
            weight_loss_1 *= weight
        return binary_cross_entropy(input=input,
                                    target=target_1,
                                    weight=weight_loss_1,
                                    reduction=reduction)
| |
|
| |
|
class WeightedLoss2(nn.modules.loss._WeightedLoss):
    """BCE on samples labelled ``+1``/``-1`` with a class re-weighting factor.

    Labels are expected in ``{-1, 0, 1}``.  Only samples labelled ``±1``
    contribute; their BCE target is ``(1 - label) / 2`` (``1 -> 0``, ``-1 -> 1``).
    Samples with ``label == 1`` are weighted by ``weight[3] / weight[2]``.
    ``task_weight`` and ``ignore_index`` are stored but unused.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None,
                 task_weight: float = 10.0,
                 size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.task_weight = task_weight

    def forward(self, input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor=None, reduce=True, reduction=None) -> torch.Tensor:
        """Compute the loss.

        Args:
            input: predicted probabilities, same shape as ``target``.
            target: labels in ``{-1, 0, 1}``.
            weight: optional per-sample weights.
            reduce: if falsy (and not ``None``), behave like ``reduction='none'``.
            reduction: ``'mean'``, ``'sum'`` or ``'none'``; defaults to ``self.reduction``.

        Returns:
            The BCE over ``±1``-labelled samples.  If there are none, a zero
            derived from the BCE over all samples is returned instead.
        """
        if reduction is None:
            reduction = self.reduction
        # Passing the legacy ``reduce`` flag to binary_cross_entropy makes it
        # ignore ``reduction`` (and warn), so translate reduce=False here.
        if reduce is not None and not reduce:
            reduction = 'none'

        # Maps -1 -> 1, 0 -> 0, 1 -> 1 (indicator of a non-zero label).
        target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).to(input.dtype)
        loss_1 = binary_cross_entropy(input=input,
                                      target=target_1,
                                      weight=weight,
                                      reduction=reduction)

        labelled = (target == -1) | (target == 1)
        if not labelled.any():
            return 0 * loss_1

        class_weights = self.weight[2:]
        # Maps 1 -> 0 and -1 -> 1.
        target_2 = (1/2 * (-target + 1)).to(input.dtype)
        weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device)
        weight_loss_2[target == 1] = class_weights[1] / class_weights[0]
        if weight is not None:
            weight_loss_2 *= weight
        return binary_cross_entropy(input=input[labelled],
                                    target=target_2[labelled],
                                    weight=weight_loss_2[labelled],
                                    reduction=reduction)
| |
|
| |
|
class WeightedLoss3(nn.modules.loss._WeightedLoss):
    """Beta-Binomial negative log-likelihood on samples labelled ``+1``/``-1``.

    ``input[:, 0]`` and ``input[:, 1]`` are used as ``concentration1`` and
    ``concentration0`` of a ``BetaBinomial(total_count=1)``.  Labels are
    expected in ``{-1, 0, 1}``.  Only ``±1`` samples contribute, with outcome
    ``(1 - label) / 2``.  Samples with ``label == 1`` are weighted by
    ``weight[3] / weight[2]``.  ``task_weight`` and ``ignore_index`` are stored
    but unused.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None,
                 task_weight: float = 10.0,
                 size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.task_weight = task_weight

    def forward(self, input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor=None, reduce=True, reduction=None) -> torch.Tensor:
        """Compute the loss.

        Args:
            input: ``(n, 2)`` positive concentrations.
            target: ``(n,)`` labels in ``{-1, 0, 1}``.
            weight: optional per-sample weights.
            reduce: if falsy (and not ``None``), behave like ``reduction='none'``.
            reduction: ``'mean'``, ``'sum'`` or ``'none'``; defaults to ``self.reduction``.

        Returns:
            The weighted NLL over ``±1``-labelled samples, reduced per ``reduction``.
            If there are no such samples, a zero derived from a BCE on the
            Beta mean ``input[:, 0] / (input[:, 0] + input[:, 1])`` is returned.
        """
        if reduction is None:
            reduction = self.reduction
        # Passing the legacy ``reduce`` flag to binary_cross_entropy makes it
        # ignore ``reduction`` (and warn), so translate reduce=False here.
        if reduce is not None and not reduce:
            reduction = 'none'

        conc1, conc0 = input[:, 0], input[:, 1]
        # Maps -1 -> 1, 0 -> 0, 1 -> 1 (indicator of a non-zero label).
        target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).to(input.dtype)
        loss_1 = binary_cross_entropy(input=conc1 / (conc1 + conc0),
                                      target=target_1,
                                      weight=weight,
                                      reduction=reduction)

        labelled = (target == -1) | (target == 1)
        if not labelled.any():
            return 0 * loss_1

        class_weights = self.weight[2:]
        # Maps 1 -> 0 and -1 -> 1.
        target_2 = (1/2 * (-target + 1)).to(input.dtype)
        weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device)
        weight_loss_2[target == 1] = class_weights[1] / class_weights[0]
        if weight is not None:
            weight_loss_2 *= weight

        nll = -BetaBinomial(
            concentration1=conc1[labelled],
            concentration0=conc0[labelled],
            total_count=1
        ).log_prob(target_2[labelled])
        nll = nll * weight_loss_2[labelled]

        # Honour ``reduction`` here too; with the default 'mean' this is unchanged.
        if reduction == 'none':
            return nll
        if reduction == 'sum':
            return nll.sum()
        return nll.mean()
| |
|
| |
|
class RegressionWeightedLoss(nn.modules.loss._WeightedLoss):
    """Weighted BCE on ``±1`` labels plus an MSE regression term.

    Column 0 of ``input``/``target`` is a classification probability and its
    label in ``{-1, 0, 1}``.  Columns ``1:`` are regression outputs and targets.
    Only ``±1``-labelled rows contribute to the BCE, with target
    ``(1 - label) / 2``.  Rows with ``label == 1`` are weighted by
    ``weight[3] / weight[2]``.  ``task_weight`` and ``ignore_index`` are stored
    but unused.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None,
                 task_weight: float = 10.0,
                 size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.task_weight = task_weight

    def forward(self, input, target) -> torch.Tensor:
        """Return ``BCE(±1 rows) + MSE(columns 1:)``.

        When no row is labelled ``±1`` the BCE term is replaced by a zero
        derived from the BCE over all rows.  The regression term does not
        depend on the labels, so it is still included in that case.
        """
        regression_loss = mse_loss(input=input[:, 1:],
                                   target=target[:, 1:],
                                   reduction=self.reduction)

        target = target[:, [0]]
        input = input[:, [0]]
        # Maps -1 -> 1, 0 -> 0, 1 -> 1 (indicator of a non-zero label).
        target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).to(input.dtype)
        loss_1 = binary_cross_entropy(input=input,
                                      target=target_1,
                                      reduction=self.reduction)

        labelled = (target == -1) | (target == 1)
        if not labelled.any():
            return 0 * loss_1 + regression_loss

        class_weights = self.weight[2:]
        # Maps 1 -> 0 and -1 -> 1.
        target_2 = (1/2 * (-target + 1)).to(input.dtype)
        weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device)
        weight_loss_2[target == 1] = class_weights[1] / class_weights[0]
        loss_2 = binary_cross_entropy(input=input[labelled],
                                      target=target_2[labelled],
                                      weight=weight_loss_2[labelled],
                                      reduction=self.reduction)
        return loss_2 + regression_loss
| |
|
| |
|
class GPLoss(nn.modules.loss._WeightedLoss):
    """Placeholder loss module; no ``forward`` is defined yet.

    NOTE(review): calling an instance raises ``NotImplementedError`` (from
    ``nn.Module``) until a ``forward`` is implemented.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None,
                 size_average=None, reduce=None,
                 reduction: str = 'mean') -> None:
        # Takes ``self`` (it was missing, so GPLoss() could not be constructed)
        # and forwards the standard _WeightedLoss options.
        super().__init__(weight, size_average, reduce, reduction)
| |
|
| |
|
def combined_loss(input: torch.Tensor, target: torch.Tensor,
                  weight: float=10.0,
                  weight_1: Optional[torch.Tensor]=None,
                  weight_2: Optional[torch.Tensor]=None,
                  reduction: str = 'mean') -> torch.Tensor:
    """Two-head BCE loss for labels in ``{-1, 0, 1}``.

    Head 1 (``input[:, 0]``) has target 1 for labels ``±1`` and 0 for label 0.
    Head 2 (``input[:, 1]``) is trained only on ``±1`` samples, with target
    ``(label + 1) / 2`` (``1 -> 1``, ``-1 -> 0``).

    Args:
        input: ``(n, 2)`` probabilities.
        target: ``(n,)`` or ``(n, 1)`` labels.
        weight: multiplier on the head-2 loss.
        weight_1: ``(2,)`` class weights for head 1; positive samples are
            weighted by ``weight_1[0] / weight_1[1]``.  ``None`` means no re-weighting.
        weight_2: same for head 2.
        reduction: forwarded to ``binary_cross_entropy``.

    Returns:
        ``loss_1`` if no sample is labelled ``±1``; ``loss_2`` if no sample is
        labelled 0; otherwise ``loss_1 + weight * loss_2``.
    """
    if target.ndim == 2:
        target = target.squeeze(1)

    # Maps -1 -> 1, 0 -> 0, 1 -> 1 (indicator of a non-zero label).
    target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).to(input.dtype)
    # Maps -1 -> 0, 0 -> 0.5, 1 -> 1; the 0.5 entries are filtered out below.
    target_2 = (1/2 * (target + 1)).to(input.dtype)

    weight_loss_1 = torch.ones_like(target_1, dtype=input.dtype, device=input.device)
    if weight_1 is not None:
        weight_loss_1[target_1 == 1] = weight_1[0] / weight_1[1]
    loss_1 = binary_cross_entropy(input=input[:, 0],
                                  target=target_1,
                                  weight=weight_loss_1,
                                  reduction=reduction)

    labelled = (target == -1) | (target == 1)
    if not labelled.any():
        return loss_1

    weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device)
    if weight_2 is not None:
        weight_loss_2[target_2 == 1] = weight_2[0] / weight_2[1]
    loss_2 = binary_cross_entropy(input=input[labelled, 1],
                                  target=target_2[labelled],
                                  weight=weight_loss_2[labelled],
                                  reduction=reduction)

    if not (target == 0).any():
        return loss_2
    return loss_1 + weight * loss_2
| |
|
| |
|
def gaussian_loss(input: torch.Tensor, target: torch.Tensor):
    """Gaussian negative log-likelihood with a softplus-parameterised scale.

    ``input[:, 0]`` is the predicted mean and ``softplus(input[:, 1])`` the
    predicted standard deviation.  The loss is the mean NLL of ``target``
    plus the mean predicted scale, which adds a penalty on larger scales.

    A ``(n, 1)`` target is flattened to ``(n,)`` first.  Otherwise it would
    broadcast against the ``(n,)`` mean into an ``(n, n)`` log-prob matrix.
    """
    if target.ndim == 2 and target.shape[1] == 1:
        target = target.squeeze(1)
    scale = torch.nn.functional.softplus(input[:, 1])
    nll = -Normal(loc=input[:, 0], scale=scale).log_prob(target).mean()
    return nll + scale.mean()
| |
|
| |
|
def mse_loss_weighted(input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor=None, reduce=True, reduction=None) -> torch.Tensor:
    """Element-wise squared error, optionally weighted and reduced.

    Args:
        input, target: tensors of the same shape.
        weight: optional tensor multiplied in place into the squared errors
            (must broadcast to their shape).
        reduce: if falsy, return the unreduced (weighted) squared errors.
        reduction: ``'sum'`` sums, ``'none'`` skips reduction, and anything
            else (including the default ``None`` and ``'mean'``) averages.
    """
    mse = (input - target).pow(2)
    if weight is not None:
        mse *= weight
    if not reduce or reduction == 'none':
        return mse
    if reduction == 'sum':
        return mse.sum()
    return mse.mean()
| |
|
| |
|
# Registry of loss callables keyed by config name.  Plain entries are
# functions; the ``*Loss*`` entries are nn.Module classes (not instances).
loss_fn_mapping = dict(
    mse_loss=mse_loss,
    mse_loss_weighted=mse_loss_weighted,
    l1_loss=l1_loss,
    binary_cross_entropy=binary_cross_entropy,
    cross_entropy=cross_entropy,
    kl_div=kl_div,
    cosin_contrastive_loss=cosin_contrastive_loss,
    euclid_contrastive_loss=euclid_contrastive_loss,
    combined_loss=combined_loss,
    weighted_combined_loss=WeightedCombinedLoss,
    weighted_loss=WeightedLoss2,
    weighted_loss_betabinomial=WeightedLoss3,
    gaussian_loss=gaussian_loss,
    weighted_loss_pretrain=WeightedLoss1,
    regression_weighted_loss=RegressionWeightedLoss,
    GP_loss=GPLoss,
)
| |
|
| |
|
def get_template_fn(template):
    """Return ``(fn, num_features)`` for a named pairwise-distance featurisation.

    ``num_features`` is the size of the last axis of ``fn(pos)``:

    * ``'plain-distance'`` -> ``(plain_distance, 1)``
    * ``'exp-normal-smearing-distance'`` -> ``(exp_normal_smearing_distance, 50)``

    Raises:
        ValueError: for any other template name (rather than implicitly
            returning ``None``).
    """
    if template == 'plain-distance':
        return plain_distance, 1
    elif template == 'exp-normal-smearing-distance':
        return exp_normal_smearing_distance, 50
    raise ValueError(
        f"Unknown template {template!r}; expected 'plain-distance' or "
        f"'exp-normal-smearing-distance'"
    )
| |
|
def plain_distance(pos):
    """Pairwise Euclidean distances between atom slot 3 of every position.

    For ``pos`` of shape ``(..., L, A, C)``, slot ``pos[..., 3, :]`` is taken
    (the original variable name calls it CA; NOTE(review): confirm the atom
    ordering).  Returns ``(..., L, L, 1)``.  A small epsilon is added under the
    square root, presumably so the gradient stays finite on the zero diagonal.
    """
    eps = 1e-10
    ca = pos[..., 3, :]
    delta = ca.unsqueeze(-3) - ca.unsqueeze(-2)
    squared = delta.pow(2).sum(dim=-1, keepdim=True)
    return (eps + squared) ** 0.5
| |
|
def exp_normal_smearing_distance(pos, cutoff_upper=100, cutoff_lower=0, num_rbf=50):
    """Exp-normal radial basis expansion of ``plain_distance(pos)``.

    Basis ``k`` is ``exp(-beta * (exp(-alpha * (d - cutoff_lower)) - means[k]) ** 2)``
    with ``alpha = 5 / (cutoff_upper - cutoff_lower)``.  ``means`` run linearly
    from ``exp(-(cutoff_upper - cutoff_lower))`` to 1, and every basis uses the
    same ``beta``.  The result is multiplied by ``0.5 * (cos(pi * d / cutoff_upper) + 1)``,
    which is forced to zero for ``d >= cutoff_upper``.  That envelope ignores
    ``cutoff_lower``.

    Returns:
        A tensor whose trailing ``1`` axis from ``plain_distance`` is broadcast
        to ``num_rbf``, i.e. ``(..., L, L, num_rbf)``.
    """
    device = pos.device
    alpha = 5.0 / (cutoff_upper - cutoff_lower)
    # Basis constants are created directly on ``device`` (default float dtype),
    # instead of being built on CPU from a list of 0-d tensors and copied on every call.
    start_value = math.exp(-cutoff_upper + cutoff_lower)
    means = torch.linspace(start_value, 1, num_rbf, device=device)
    betas = torch.full((num_rbf,), (2 / num_rbf * (1 - start_value)) ** -2, device=device)

    dist = plain_distance(pos)
    cutoffs = 0.5 * (torch.cos(dist * math.pi / cutoff_upper) + 1.0)
    cutoffs = cutoffs * (dist < cutoff_upper).float()
    return cutoffs * torch.exp(
        -betas * (torch.exp(alpha * (-dist + cutoff_lower)) - means) ** 2
    )
| |
|
| |
|