LukeDarlow committed
Commit 451276c · Parent: b9964ef

Apply the clamp of the decay params to the parameters' .data, rather than inside the forward computation, so that gradients remain valid moving forward. Fix suggested by GitHub user kuviki.
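Why the change matters, as a minimal sketch with hypothetical values (not code from this repo): `torch.clamp` passes zero gradient for any entry outside its range, so a decay parameter that drifts outside [0, 15] could never be pulled back by the optimiser. Clamping the parameter's `.data` outside the autograd graph and then exponentiating the in-range values keeps every entry trainable:

```python
import torch

# Old behaviour: clamp inside the forward graph. Entries outside [0, 15]
# receive exactly zero gradient through torch.clamp, so they stay stuck.
p = torch.nn.Parameter(torch.tensor([-2.0, 5.0, 20.0]))  # hypothetical values
r = torch.exp(-torch.clamp(p, 0, 15))
r.sum().backward()
print(p.grad)  # tensor([ 0.0000, -0.0067,  0.0000]) -> entries 0 and 2 are dead

# New behaviour: clamp the stored values in place, outside the graph, then
# exponentiate the now in-range parameter; every entry keeps a usable gradient.
p.grad = None
p.data = torch.clamp(p, 0, 15)
r = torch.exp(-p)
r.sum().backward()
print(p.grad)  # all three entries now receive a nonzero gradient
```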

Files changed (2)
  1. models/ctm.py +4 -1
  2. models/ctm_qamnist.py +4 -1
models/ctm.py CHANGED
```diff
@@ -500,7 +500,10 @@ class ContinuousThoughtMachine(nn.Module):
 
         # --- Initialise Recurrent Synch Values ---
         decay_alpha_action, decay_beta_action = None, None
-        r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
+        self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15)  # Fix from GitHub user kuviki
+        self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
+        r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
+
         _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
         # Compute learned weighting for synchronisation
 
```
models/ctm_qamnist.py CHANGED
```diff
@@ -147,7 +147,10 @@ class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
 
         # --- Initialise Recurrent Synch Values ---
         decay_alpha_action, decay_beta_action = None, None
-        r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
+        self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15)  # Fix from GitHub user kuviki
+        self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
+        r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
+
         _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
 
         prev_input = None
```
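For reference, an equivalent and arguably more idiomatic way to write the same in-place clamp (a sketch of an alternative, not what this commit does) is `clamp_` under `torch.no_grad()`, which mutates the parameter without building a throwaway node in the graph:

```python
import torch

decay_params = torch.nn.Parameter(torch.rand(8) * 20 - 2)  # hypothetical shape/init
with torch.no_grad():
    decay_params.clamp_(0, 15)   # in-place; invisible to autograd
r = torch.exp(-decay_params)     # gradients w.r.t. decay_params stay valid everywhere
```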