Yuchan committed on
Commit
a6ed1c9
·
verified ·
1 Parent(s): 3e4952f

Update Model.py

Browse files
Files changed (1) hide show
  1. Model.py +12 -3
Model.py CHANGED
@@ -153,7 +153,6 @@ class LoU(layers.Layer):
153
  self.Vr = Lo(d_model)
154
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
155
  self.O = layers.Dense(d_model, dtype='float32')
156
- self.glu = SwiGLU(d_model)
157
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
158
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
159
 
@@ -201,16 +200,26 @@ class LoU(layers.Layer):
201
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
202
  x_comb = score_clipped * V
203
  out = self.proj(x_comb)
204
- out = self.glu(out)
205
  out = self.norm(out + residual)
206
  return tf.cast(out, x.dtype)
207
 
 
 
 
 
 
 
 
 
 
 
 
208
  class ReLM(tf.keras.Model):
209
  def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
210
  super().__init__()
211
  self.token_embedding = layers.Embedding(vocab_size, d_model)
212
  self.pos_embedding = layers.Embedding(max_seq_len, d_model)
213
- self.blocks = [LoU(d_model) for _ in range(n_layers)]
214
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
215
 
216
  def call(self, x, training=False):
 
153
  self.Vr = Lo(d_model)
154
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
155
  self.O = layers.Dense(d_model, dtype='float32')
 
156
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
157
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
158
 
 
200
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
201
  x_comb = score_clipped * V
202
  out = self.proj(x_comb)
 
203
  out = self.norm(out + residual)
204
  return tf.cast(out, x.dtype)
205
 
206
+ class Block(layers.Layer):
207
+ def __init__(self, d_model):
208
+ super().__init__()
209
+ self.lou = LoU(d_model)
210
+ self.glu = SwiGLU(d_model)
211
+
212
+ def call(self, x):
213
+ x = self.lou(x)
214
+ x = self.glu(x)
215
+ return x
216
+
217
  class ReLM(tf.keras.Model):
218
  def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
219
  super().__init__()
220
  self.token_embedding = layers.Embedding(vocab_size, d_model)
221
  self.pos_embedding = layers.Embedding(max_seq_len, d_model)
222
+ self.blocks = [Block(d_model) for _ in range(n_layers)]
223
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
224
 
225
  def call(self, x, training=False):