Yuchan commited on
Commit
0c184aa
·
verified ·
1 Parent(s): a1c82ef

Update Inference.py

Browse files
Files changed (1) hide show
  1. Inference.py +1 -1
Inference.py CHANGED
@@ -66,7 +66,7 @@ class LoSoU(layers.Layer):
66
  # projection / gating layers in float32
67
  self.Q = layers.Dense(96, dtype='float32')
68
  self.K = layers.Dense(96, dtype='float32')
69
- self.V = layers.Dense(96, dtype='float32') # Lo already handles casting to model dtype; we'll cast back to float32
70
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
71
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
72
 
 
66
  # projection / gating layers in float32
67
  self.Q = layers.Dense(96, dtype='float32')
68
  self.K = layers.Dense(96, dtype='float32')
69
+ self.V = layers.Dense(96, activation='gelu', dtype='float32')
70
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
71
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
72