| | """Contains utility math functions. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Any, Callable, Literal, NamedTuple, Tuple, Union |
| |
|
| | import torch |
| | from torch import autograd |
| |
|
ActivationType = Literal[
    "linear",
    "exp",
    "sigmoid",
    "softplus",
    "relu_with_pushback",
    "hard_sigmoid_with_pushback",
]
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]


class ActivationPair(NamedTuple):
    """A pair of forward and inverse activation functions."""

    forward: ActivationFunction
    inverse: ActivationFunction


def create_activation_pair(activation_type: ActivationType) -> ActivationPair:
    """Create an activation function and its corresponding inverse function.

    Args:
        activation_type: The activation type to create.

    Returns:
        The activation function and the corresponding inverse function.
    """
    if activation_type == "linear":
        return ActivationPair(lambda x: x, lambda x: x)
    elif activation_type == "exp":
        return ActivationPair(torch.exp, torch.log)
    elif activation_type == "sigmoid":
        return ActivationPair(torch.sigmoid, inverse_sigmoid)
    elif activation_type == "softplus":
        return ActivationPair(torch.nn.functional.softplus, inverse_softplus)
    elif activation_type == "relu_with_pushback":
        # relu is not invertible below zero; the identity inverts it on its
        # non-negative output range.
        return ActivationPair(relu_with_pushback, lambda x: x)
    elif activation_type == "hard_sigmoid_with_pushback":
        # 6.0 * x - 3.0 inverts slope * x + 0.5 (slope 1/6) on the unclamped range.
        return ActivationPair(hard_sigmoid_with_pushback, lambda x: 6.0 * x - 3.0)
    else:
        raise ValueError(f"Unsupported activation function: {activation_type}.")


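# Illustrative sketch (not part of the original module): how a pair returned by
# create_activation_pair might be used to round-trip values.  The input range and
# tolerances below are assumptions chosen for the example, not guarantees.
#
#     pair = create_activation_pair("softplus")
#     x = torch.linspace(0.1, 5.0, steps=8)
#     y = pair.forward(x)
#     x_rec = pair.inverse(y)
#     torch.testing.assert_close(x_rec, x, atol=1e-5, rtol=1e-5)

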
def inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
    """Compute inverse sigmoid."""
    return torch.log(tensor / (1.0 - tensor))


def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
    """Compute inverse softplus."""
    tensor = tensor.clamp_min(eps)
    # Numerically stable form of log(exp(tensor) - 1) = tensor + log(1 - exp(-tensor)).
    sigmoid = torch.sigmoid(-tensor)
    exp = sigmoid / (1.0 - sigmoid)
    return tensor + torch.log(-exp + 1.0)


SoftClampRange = Tuple[Union[torch.Tensor, float], Union[torch.Tensor, float]]


def softclamp(
    tensor: torch.Tensor,
    min: SoftClampRange | None = None,
    max: SoftClampRange | None = None,
) -> torch.Tensor:
    """Clamp a tensor to min/max in a differentiable way.

    Args:
        tensor: The tensor to clamp.
        min: Pair of the threshold at which clamping starts and the value to clamp to.
            The first value should be larger than the second.
        max: Pair of the threshold at which clamping starts and the value to clamp to.
            The first value should be smaller than the second.

    Returns:
        The clamped tensor.
    """

    def normalize(clamp_range: SoftClampRange) -> torch.Tensor:
        value0, value1 = clamp_range
        return value0 + (value1 - value0) * torch.tanh((tensor - value0) / (value1 - value0))

    tensor_clamped = tensor
    if min is not None:
        tensor_clamped = torch.maximum(tensor_clamped, normalize(min))
    if max is not None:
        tensor_clamped = torch.minimum(tensor_clamped, normalize(max))

    return tensor_clamped


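# Illustrative sketch (not part of the original module): soft-clamping a tensor
# with an upper bound.  The pair (2.0, 3.0) is an assumed example choice: clamping
# starts around 2.0 and values saturate smoothly toward 3.0.
#
#     x = torch.linspace(-1.0, 10.0, steps=12, requires_grad=True)
#     y = softclamp(x, max=(2.0, 3.0))
#     y.sum().backward()  # gradients remain non-zero in the transition region

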
class ClampWithPushback(autograd.Function):
    """Implementation of the clamp_with_pushback function."""

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        min: float | None,
        max: float | None,
        pushback: float,
    ) -> torch.Tensor:
        """Apply clamp."""
        if min is not None and max is not None and min >= max:
            raise ValueError("Only min < max is supported.")

        ctx.save_for_backward(tensor)
        ctx.min = min
        ctx.max = max
        ctx.pushback = pushback
        return torch.clamp(tensor, min=min, max=max)

    @staticmethod
    def backward(
        ctx: Any, grad_in: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None]:
        """Compute gradient of clamp with pushback."""
        grad_out = grad_in.clone()
        (tensor,) = ctx.saved_tensors

        # For out-of-range elements, replace the incoming gradient with a constant
        # pushback gradient that drives the pre-activation back into range.
        if ctx.min is not None:
            mask_min = tensor < ctx.min
            grad_out[mask_min] = -ctx.pushback

        if ctx.max is not None:
            mask_max = tensor > ctx.max
            grad_out[mask_max] = ctx.pushback

        return grad_out, None, None, None


def clamp_with_pushback(
    tensor: torch.Tensor,
    min: float | None = None,
    max: float | None = None,
    pushback: float = 1e-2,
) -> torch.Tensor:
    """Variant of the clamp function which avoids the vanishing gradient problem.

    This function is equivalent to adding a regularizer of the form

        pushback * sum_i (
            relu(min - preactivation_i) + relu(preactivation_i - max)
        )

    to the full loss function, which pushes clamped values back.

    When used in minimization problems, pushback should be greater than
    zero. In maximization problems, pushback should be smaller than zero.
    """
    output = ClampWithPushback.apply(tensor, min, max, pushback)
    assert isinstance(output, torch.Tensor)
    return output


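# Illustrative sketch (not part of the original module): the constant pushback
# gradient received by out-of-range elements.  The values below are assumptions
# chosen for the example.
#
#     x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
#     y = clamp_with_pushback(x, min=0.0, max=1.0, pushback=1e-2)
#     y.sum().backward()
#     # x.grad is approximately [-0.01, 1.0, 0.01]: clamped entries receive the
#     # pushback gradient, the in-range entry keeps the ordinary gradient.

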
def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor:
    """Apply hard sigmoid with pushback.

    For compatibility reasons, we follow the default PyTorch implementation with a
    default slope of 1/6:

    https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    """
    return clamp_with_pushback(slope * x + 0.5, min=0.0, max=1.0)


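# Illustrative sketch (not part of the original module): with the default slope
# of 1/6 the forward pass should match torch.nn.functional.hardsigmoid up to
# floating-point rounding; only the backward pass differs for clamped inputs.
#
#     x = torch.linspace(-4.0, 4.0, steps=9)
#     torch.testing.assert_close(
#         hard_sigmoid_with_pushback(x), torch.nn.functional.hardsigmoid(x)
#     )

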
def relu_with_pushback(x: torch.Tensor) -> torch.Tensor:
    """Compute relu with pushback."""
    return clamp_with_pushback(x, min=0.0)