import torch
import torch.nn.functional as F
import math
from torch.optim import Optimizer # Ensure Optimizer is imported for custom classes

class DCLR(Optimizer):
    """Optimizer with an entropy-damped, gradient-norm-dependent step size.

    Each parameter update uses eta(t) = lr * exp(-lambda_ * ||grad||^2 / (H + epsilon)),
    where H is the mean entropy of the softmax over the model's output
    activations, so the step shrinks when gradients are large relative to the
    prediction entropy.
    """

    def __init__(self, params, lr=0.01, lambda_=1.0, epsilon=1e-8, delta=1e-12, verbose=True):
        defaults = dict(lr=lr, lambda_=lambda_, epsilon=epsilon, delta=delta, verbose=verbose)
        super(DCLR, self).__init__(params, defaults)

    def step(self, closure=None, output_activations=None):
        # The caller must supply the batch's output activations (logits) so
        # the entropy term that modulates the step size can be computed.
        if output_activations is None:
            raise ValueError("Output activations must be provided to compute entropy.")

        loss = None
        if closure is not None:
            loss = closure()

        # Mean per-sample entropy of the predicted class distribution;
        # delta keeps log() finite when a probability is exactly zero.
        probs = F.softmax(output_activations, dim=1)
        log_probs = torch.log(probs + self.defaults['delta'])
        entropy = -torch.sum(probs * log_probs, dim=1).mean()

        for group in self.param_groups:
            lr_0 = group['lr']
            lambda_ = group['lambda_']
            epsilon = group['epsilon']
            verbose = group['verbose']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                grad_norm_sq = grad.norm() ** 2

                # Entropy-damped step size: a large gradient norm relative to
                # the output entropy shrinks the effective learning rate.
                eta_t = lr_0 * math.exp(-lambda_ * grad_norm_sq.item() / (entropy.item() + epsilon))

                if verbose:
                    print(f"[DCLR] Entropy: {entropy.item():.6f} | GradNorm²: {grad_norm_sq.item():.6f} | η(t): {eta_t:.6e}")

                # Use the add_(other, alpha=...) signature rather than the
                # deprecated add_(alpha, other) overload.
                p.data.add_(grad, alpha=-eta_t)

        return loss
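

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the toy model, synthetic batch,
# and hyperparameters below are assumptions, not part of the original file.
# They show how DCLR expects the logits to be passed into step() so it can
# compute the entropy term of eta(t) = lr * exp(-lambda_ * ||grad||^2 / (H + epsilon)).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    model = torch.nn.Linear(20, 5)          # hypothetical toy classifier
    x = torch.randn(64, 20)                 # synthetic batch of 64 samples
    y = torch.randint(0, 5, (64,))

    optimizer = DCLR(model.parameters(), lr=0.01, lambda_=1.0, verbose=False)

    for epoch in range(3):
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        # Detach the logits: step() only needs their values for the entropy term.
        optimizer.step(output_activations=logits.detach())
        print(f"epoch {epoch}: loss = {loss.item():.4f}")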