Commit 451276c
Parent: b9964ef

Clamp of decay params applied to the parameter data so that gradients are valid moving forward. Fix suggested by GitHub user kuviki.

Files changed:
- models/ctm.py (+4, -1)
- models/ctm_qamnist.py (+4, -1)
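
The reasoning behind the fix, sketched below: torch.clamp has zero gradient outside its bounds, so clamping the decay parameters inside the forward computation means a parameter that drifts out of [0, 15] stops receiving gradient and can never move back into range. Writing the clamped values to the parameter's .data instead snaps them back into range outside the autograd graph, and the subsequent exp(-param) then yields a valid, nonzero gradient. A minimal standalone sketch of the two behaviours (hypothetical values, not the CTM forward pass itself):

import torch

p = torch.nn.Parameter(torch.tensor([-2.0, 20.0]))  # drifted outside [0, 15]

# Clamp inside the computation: clamp() passes no gradient for out-of-range
# inputs, so both entries of p receive zero gradient and stay stuck.
torch.exp(-torch.clamp(p, 0, 15)).sum().backward()
print(p.grad)  # tensor([0., 0.])

p.grad = None

# Clamp applied to .data, as in this commit: the values are snapped back into
# [0, 15] outside the autograd graph, and exp(-p) then has a valid gradient.
p.data = torch.clamp(p, 0, 15)
torch.exp(-p).sum().backward()
print(p.grad)  # tensor([-1.0000e+00, -3.0590e-07]), nonzero again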
models/ctm.py CHANGED

@@ -500,7 +500,10 @@ class ContinuousThoughtMachine(nn.Module):
 
         # --- Initialise Recurrent Synch Values ---
         decay_alpha_action, decay_beta_action = None, None
-        r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
+        self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
+        self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
+        r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
+
         _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
         # Compute learned weighting for synchronisation
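For context on the clamp range (an illustrative check, not code from the repo): r_action and r_out are per-element decay factors of the form exp(-decay_param), so constraining the parameters to [0, 15] keeps every factor in roughly (3.1e-7, 1]. A parameter that drifted negative would instead give a factor greater than 1, so the exponentially weighted synchronisation recursions would grow across internal ticks rather than decay:

import torch

decay = torch.tensor([-1.0, 0.0, 15.0])       # one entry has drifted negative
print(torch.exp(-decay))                      # tensor([2.7183e+00, 1.0000e+00, 3.0590e-07])
print(torch.exp(-torch.clamp(decay, 0, 15)))  # clamped: every factor is <= 1
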
models/ctm_qamnist.py CHANGED

@@ -147,7 +147,10 @@ class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
 
         # --- Initialise Recurrent Synch Values ---
         decay_alpha_action, decay_beta_action = None, None
-        r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
+        self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
+        self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
+        r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
+
         _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
 
         prev_input = None