Yuchan
committed on
Update Mo.py
Browse files
Mo.py
CHANGED
|
@@ -124,6 +124,17 @@ class SwiGLU(layers.Layer):
|
|
| 124 |
out = self.W1(tf.nn.silu(a) * b)
|
| 125 |
return tf.cast(out, x.dtype)
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
class LoU(layers.Layer):
|
| 128 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 129 |
super().__init__()
|
|
@@ -137,9 +148,7 @@ class LoU(layers.Layer):
|
|
| 137 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 138 |
|
| 139 |
self.glu = SwiGLU(d_model, 320)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def call(self, x, z):
|
| 143 |
x_f32 = tf.cast(x, tf.float32)
|
| 144 |
residual = x_f32
|
| 145 |
x_f32 = self.norm1(x)
|
|
@@ -171,7 +180,6 @@ class LoU(layers.Layer):
|
|
| 171 |
x_comb = score_clipped * V
|
| 172 |
|
| 173 |
out = self.norm(x_comb + residual)
|
| 174 |
-
out = self.cross(out, z)
|
| 175 |
out = self.glu(out)
|
| 176 |
return tf.cast(out, x.dtype)
|
| 177 |
|
|
|
|
| 124 |
out = self.W1(tf.nn.silu(a) * b)
|
| 125 |
return tf.cast(out, x.dtype)
|
| 126 |
|
| 127 |
+
class SwiGLU(layers.Layer):
|
| 128 |
+
def __init__(self, d_model, d_ff):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.proj = layers.Dense(d_ff)
|
| 131 |
+
self.out = layers.Dense(d_model)
|
| 132 |
+
def call(self, x):
|
| 133 |
+
x_proj = self.proj(x)
|
| 134 |
+
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 135 |
+
return self.out(x_val * tf.nn.silu(x_gate))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
class LoU(layers.Layer):
|
| 139 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 140 |
super().__init__()
|
|
|
|
| 148 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 149 |
|
| 150 |
self.glu = SwiGLU(d_model, 320)
|
| 151 |
+
def call(self, x):
|
|
|
|
|
|
|
| 152 |
x_f32 = tf.cast(x, tf.float32)
|
| 153 |
residual = x_f32
|
| 154 |
x_f32 = self.norm1(x)
|
|
|
|
| 180 |
x_comb = score_clipped * V
|
| 181 |
|
| 182 |
out = self.norm(x_comb + residual)
|
|
|
|
| 183 |
out = self.glu(out)
|
| 184 |
return tf.cast(out, x.dtype)
|
| 185 |
|