# RMSprop optimizer implementation
from abc import ABC

from .base import BaseOptimizer


class RMSpropOptimizer(BaseOptimizer, ABC):
    """
    RMSprop optimizer implementation.

    This optimizer keeps a moving average of squared gradients and uses it to
    normalize the gradient before each parameter update.
    """

    def __init__(self, params, lr=0.001, alpha=0.99, eps=1e-8):
        # Materialize params so the sequence can be iterated on every step.
        self.params = list(params)
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        # Per-parameter running average of the squared gradient.
        self.state = {p: {'mean_square': 0} for p in self.params}

    def step(self):
        for p in self.params:
            if p.grad is None:
                continue
            state = self.state[p]
            # Exponential moving average of the squared gradient.
            state['mean_square'] = self.alpha * state['mean_square'] + (1 - self.alpha) * (p.grad ** 2)
            # Normalize the update by the root of the running average.
            p.data -= self.lr * p.grad / (state['mean_square'].sqrt() + self.eps)

    def zero_grad(self):
        # Clear gradients; step() skips parameters whose grad is None,
        # so resetting to None keeps the two methods consistent.
        for p in self.params:
            p.grad = None

    def __repr__(self):
        return f"RMSpropOptimizer(lr={self.lr}, alpha={self.alpha}, eps={self.eps})"

    def state_dict(self):
        # Note: the returned dict is keyed by the parameter objects themselves.
        return {p: {'mean_square': state['mean_square']} for p, state in self.state.items()}
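

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the optimizer itself). It assumes
# the parameters are PyTorch tensors exposing .data and .grad; the target
# values and step count are purely illustrative. Because this module uses a
# relative import of BaseOptimizer, run it as part of its package, e.g.
# `python -m <package>.<module>` (package and module names are hypothetical).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # A single parameter pulled toward an illustrative target by a squared loss.
    w = torch.zeros(3, requires_grad=True)
    target = torch.tensor([1.0, 2.0, 3.0])
    opt = RMSpropOptimizer([w], lr=0.01)

    for _ in range(500):
        opt.zero_grad()
        loss = ((w - target) ** 2).sum()
        loss.backward()
        opt.step()

    print(opt)     # RMSpropOptimizer(lr=0.01, alpha=0.99, eps=1e-08)
    print(w.data)  # should move close to the target after enough steps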