Yuchan
committed on
Update Mo.py
Mo.py
CHANGED
@@ -118,74 +118,58 @@ with strategy.scope():
 class SwiGLU(layers.Layer):
     def __init__(self, d_model, d_ff):
         super().__init__()
-        self.proj = layers.Dense(
+        self.proj = layers.Dense(dff)
         self.out = layers.Dense(d_model)
     def call(self, x):
         x_proj = self.proj(x)
         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
         return self.out(x_val * tf.nn.silu(x_gate))
 
[removed layer definition: its class line, __init__ body (several self.* attribute assignments), and the start of call() are illegible in this view; only the tail of call() below survives]
-        V = self.V(x_f32)
-        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
-        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
-        score = g_q * g_k
-
-        alpha_dynamic = self.alpha_linear(x_f32)
-        score_ema = self._ema_over_time(score, alpha_dynamic)
-        mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
-        denom = tf.maximum(mean_last, self.eps)
-        score_norm = score_ema / denom
-        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-        x_comb = score_clipped * V
-
-        # In the LoU block, x_comb + residual is followed by the CrossBlock
-        out = self.norm(x_comb + residual)
-        out = self.glu(out)
-        return tf.cast(out, x.dtype)
+
+class MHLA(layers.Layer):
+    def __init__(self, embed_dim, num_heads=8, dropout=0.0):
+        super().__init__()
+        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.Wq = layers.Dense(embed_dim, use_bias=False)
+        self.Wk = layers.Dense(embed_dim, use_bias=False)
+        self.Wv = layers.Dense(embed_dim, use_bias=False)
+        self.out = layers.Dense(embed_dim)
+        self.dropout = layers.Dropout(dropout)
+
+    def split_heads(self, x):
+        # [B, L, D] -> [B, num_heads, L, head_dim]
+        B, L, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
+        x = tf.reshape(x, (B, L, self.num_heads, self.head_dim))
+        return tf.transpose(x, perm=[0, 2, 1, 3])
+
+    def combine_heads(self, x):
+        # [B, num_heads, L, head_dim] -> [B, L, D]
+        x = tf.transpose(x, perm=[0, 2, 1, 3])
+        B, L, H, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
+        return tf.reshape(x, (B, L, H*D))
+
+    def call(self, x, training=False):
+        q = tf.nn.elu(self.Wq(x)) + 1
+        k = tf.nn.elu(self.Wk(x)) + 1
+        v = self.Wv(x)
+
+        q = self.split_heads(q)
+        k = self.split_heads(k)
+        v = self.split_heads(v)
+
+        # causal linear attention cumulative sum
+        k_cum = tf.cumsum(k, axis=2)
+        kv_cum = tf.cumsum(k * v, axis=2)
+
+        z = 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)
+        out = (q * kv_cum) * z
+        out = self.combine_heads(out)
+        out = self.dropout(out, training=training)
+        return self.out(out)
+
 
 class Lo(layers.Layer):
     def __init__(self, d_model):
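
For context on the hunk above: the cumulative sums in MHLA.call compute a causal, linear-attention-style mix in which, at each position t, the output is q_t * (running sum of k_s * v_s over s <= t), normalized by the dot product of q_t with the running sum of k_s, with q and k kept positive via elu(x) + 1. Below is a minimal single-head sketch of that recurrence (eager TensorFlow 2.x assumed; the tensors, sizes, and names are illustrative, not taken from Mo.py), checking the vectorized cumsum form against an explicit per-position loop.

import tensorflow as tf

L, D = 5, 4
q = tf.nn.elu(tf.random.normal([L, D])) + 1.0   # positive features, as in MHLA.call
k = tf.nn.elu(tf.random.normal([L, D])) + 1.0
v = tf.random.normal([L, D])

# Vectorized form: prefix sums over the time axis (axis=2 in MHLA, axis=0 here for one head).
k_cum = tf.cumsum(k, axis=0)
kv_cum = tf.cumsum(k * v, axis=0)
z = 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)
out_vec = (q * kv_cum) * z

# Same computation, position by position: step t only sees s <= t (causal).
rows = []
for t in range(L):
    num = tf.reduce_sum(k[:t + 1] * v[:t + 1], axis=0)            # sum_{s<=t} k_s * v_s
    den = tf.reduce_sum(q[t] * tf.reduce_sum(k[:t + 1], axis=0))  # q_t . sum_{s<=t} k_s
    rows.append(q[t] * num / den)
out_loop = tf.stack(rows)

print(bool(tf.reduce_all(tf.abs(out_vec - out_loop) < 1e-5)))  # expected: True

Because no L x L attention matrix is ever formed, time and memory stay linear in sequence length; note that k * v here is an elementwise per-feature product, not the outer-product state used in some linear-attention formulations.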

@@ -202,7 +186,8 @@ class Lo(layers.Layer):
 class Block(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.lou =
+        self.lou = MHLA(d_model, 8)
+        self.glu = SwiGLU(d_model, 1154)
         self.lo = Lo(d_model)
 
     def call(self, x):
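
The updated Block wires these pieces together: an 8-head MHLA (so d_model must be divisible by 8) and a SwiGLU built with a 1154-wide projection, which tf.split divides into two 577-wide halves, so that width has to stay even. Below is a rough shape-check sketch of the two new layers on a dummy batch, assuming the classes from this diff are in scope and that dff inside SwiGLU.__init__ refers to the d_ff argument (otherwise the constructor would raise a NameError); this is illustrative only, not code from Mo.py.

import tensorflow as tf

d_model, seq_len, batch = 512, 16, 2
x = tf.random.normal([batch, seq_len, d_model])

lou = MHLA(d_model, num_heads=8)   # 512 % 8 == 0, so the head split in MHLA works
glu = SwiGLU(d_model, 1154)        # assumes dff == d_ff; 1154 splits into two 577-wide halves

y = lou(x)   # [2, 16, 512]: causal linear attention, then the output Dense(embed_dim)
y = glu(y)   # [2, 16, 512]: gated halves recombined by Dense(d_model)
print(y.shape)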