Yuchan
committed on
Update Model.py
Model.py
CHANGED
@@ -153,7 +153,6 @@ class LoU(layers.Layer):
         self.Vr = Lo(d_model)
         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
         self.O = layers.Dense(d_model, dtype='float32')
-        self.glu = SwiGLU(d_model)
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

@@ -201,16 +200,26 @@ class LoU(layers.Layer):
         score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
         x_comb = score_clipped * V
         out = self.proj(x_comb)
-        out = self.glu(out)
         out = self.norm(out + residual)
         return tf.cast(out, x.dtype)
 
+class Block(layers.Layer):
+    def __init__(self, d_model):
+        super().__init__()
+        self.lou = LoU(d_model)
+        self.glu = SwiGLU(d_model)
+
+    def call(self, x):
+        x = self.lou(x)
+        x = self.glu(x)
+        return x
+
 class ReLM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
         self.token_embedding = layers.Embedding(vocab_size, d_model)
         self.pos_embedding = layers.Embedding(max_seq_len, d_model)
-        self.blocks = [
+        self.blocks = [Block(d_model) for _ in range(n_layers)]
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
 
     def call(self, x, training=False):
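
For context: this commit factors the SwiGLU feed-forward out of LoU into a small Block wrapper (LoU followed by SwiGLU), and ReLM now builds a stack of n_layers such blocks. The sketch below shows one way the new pieces would wire together end to end; the SwiGLU definition and the body of ReLM.call are not part of this diff, so both are assumptions here, not the author's code.

import tensorflow as tf
from tensorflow.keras import layers

# Assumed SwiGLU: a SiLU-gated feed-forward. The real layer is defined
# elsewhere in Model.py and is not shown in this diff.
class SwiGLU(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.gate = layers.Dense(d_model * 2, dtype='float32')
        self.out = layers.Dense(d_model, dtype='float32')

    def call(self, x):
        a, b = tf.split(self.gate(x), 2, axis=-1)   # split into value and gate halves
        return self.out(tf.nn.silu(a) * b)

# Assumed ReLM.call body: embed tokens and positions, run the Block stack,
# then apply the final LayerNormalization. It mirrors the fields the diff
# defines (token_embedding, pos_embedding, blocks, ln_f) but is not confirmed by it.
def call(self, x, training=False):
    seq_len = tf.shape(x)[1]
    pos = tf.range(seq_len)
    h = self.token_embedding(x) + self.pos_embedding(pos)
    for block in self.blocks:          # n_layers x (LoU -> SwiGLU)
        h = block(h)
    return self.ln_f(h)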