Yuchan
committed on
Update Model.py
Model.py
CHANGED
@@ -112,21 +112,6 @@ dataset = dataset.shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True
 with strategy.scope():
     dist_dataset = strategy.experimental_distribute_dataset(dataset)
 
-
-class Lo(layers.Layer):
-    def __init__(self, d_model):
-        super().__init__()
-        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
-        self.p = layers.Dense(128, use_bias=True, dtype='float32')
-        self._out_dtype = 'float32'
-
-    def call(self, x):
-        x_f32 = tf.cast(x, tf.float32)
-        x = self.proj(x_f32)
-        x = tf.nn.gelu(x)
-        x = self.p(x)
-        return tf.cast(x, self._out_dtype)
-
 class SwiGLU(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
@@ -148,9 +133,6 @@ class LoU(layers.Layer):
         self.Q = layers.Dense(d_model, dtype='float32')
         self.K = layers.Dense(d_model, dtype='float32')
         self.V = layers.Dense(d_model, dtype='float32')
-        self.Qr = Lo(d_model)
-        self.Kr = Lo(d_model)
-        self.Vr = Lo(d_model)
         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
         self.O = layers.Dense(d_model, dtype='float32')
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
@@ -185,10 +167,6 @@ class LoU(layers.Layer):
         q = self.Q(x_f32)
         k = self.K(x_f32)
         V = self.V(x_f32)
-        q = self.Qr(q)
-        k = self.Kr(k)
-        V = self.Vr(V)
-
         # previous version:
         # g_q = tf.nn.sigmoid(q)
         # g_k = tf.nn.sigmoid(k)
@@ -208,16 +186,28 @@ class LoU(layers.Layer):
         out = self.norm(out + residual)
         return tf.cast(out, x.dtype)
 
+class Lo(layers.Layer):
+    def __init__(self, d_model):
+        super().__init__()
+        self.d = layers.Dense(256, activation='silu')
+        self.w = layers.Dense(d_model)
+    def call(self, x):
+        p = self.d(x)
+        p = self.w(p)
+        return p + x
+
 class Block(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
         self.lou = LoU(d_model)
         self.glu = SwiGLU(d_model)
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.lo = Lo(d_model)
 
     def call(self, x):
         x = self.lou(x)
         x = self.norm(self.glu(x)) + x
+        x = self.lo(x)
         return x
 
 class ReLM(tf.keras.Model):
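In short, the commit removes the old Lo transform that was applied separately to q, k, and V inside LoU, and instead appends a single residual MLP to the end of each Block. Below is a minimal standalone sketch of the new layer as it appears in the last hunk; the imports, the d_model value, and the input shape are illustrative assumptions, not taken from Model.py.

import tensorflow as tf
from tensorflow.keras import layers

# The new Lo layer from this commit: Dense(256) with SiLU, then a
# Dense back down to d_model, added to the input as a residual.
class Lo(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.d = layers.Dense(256, activation='silu')
        self.w = layers.Dense(d_model)

    def call(self, x):
        p = self.d(x)
        p = self.w(p)
        return p + x  # the second Dense restores d_model, so the add is shape-safe

# Illustrative shapes only (batch=2, seq=16, d_model=512 are assumptions):
x = tf.random.normal([2, 16, 512])
print(Lo(512)(x).shape)  # (2, 16, 512): the layer is shape-preserving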
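For contrast, here is an annotated copy of the Lo class that the first hunk deletes. Two differences from its replacement stand out in the code: there is no residual connection, and the second Dense is a fixed 128 units, so Qr, Kr, and Vr emitted width-128 tensors that matched d_model only when d_model happened to be 128.

# The removed Lo, copied from the deleted lines above and annotated:
class Lo(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
        self.p = layers.Dense(128, use_bias=True, dtype='float32')  # fixed width
        self._out_dtype = 'float32'

    def call(self, x):
        x_f32 = tf.cast(x, tf.float32)  # force float32 under mixed precision
        x = self.proj(x_f32)            # d_model-wide projection
        x = tf.nn.gelu(x)               # GELU nonlinearity
        x = self.p(x)                   # always projects to 128 units
        return tf.cast(x, self._out_dtype)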