Yuchan
committed on
Update Model.py
Browse files
Model.py
CHANGED
|
@@ -145,7 +145,7 @@ class Lo(layers.Layer):
|
|
| 145 |
super().__init__()
|
| 146 |
# keep internal computation in float32
|
| 147 |
self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
|
| 148 |
-
self.p = layers.Dense(
|
| 149 |
self._out_dtype = 'float32'
|
| 150 |
|
| 151 |
def call(self, x):
|
|
@@ -174,21 +174,15 @@ class LoSoU(layers.Layer):
|
|
| 174 |
self.eps = float(eps)
|
| 175 |
|
| 176 |
# projection / gating layers in float32
|
| 177 |
-
self.Q = layers.Dense(
|
| 178 |
-
self.K = layers.Dense(
|
| 179 |
-
self.V =
|
|
|
|
|
|
|
|
|
|
| 180 |
self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
|
| 181 |
self.O = layers.Dense(d_model, dtype='float32')
|
| 182 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 183 |
-
|
| 184 |
-
# layer for computing a dynamic alpha
|
| 185 |
-
# alpha must lie in [0, 1], so use a sigmoid
|
| 186 |
-
# compute alpha per sample from the d_model features of input x
|
| 187 |
-
# e.g.: (B, L, d_model) -> (B, L, 1) -> (B, L, 1) with sigmoid
|
| 188 |
-
# or: (B, L, d_model) -> (B, L, d_model) -> global reduce -> (B, L, 1)
|
| 189 |
-
# simply: use the same alpha at every position (based on the input mean)
|
| 190 |
-
# or: vary per position (computed for each position)
|
| 191 |
-
# here: computed per position, yielding (B, L, 1)
|
| 192 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 193 |
|
| 194 |
def _ema_over_time(self, score, alpha_dynamic):
|
|
@@ -231,9 +225,13 @@ class LoSoU(layers.Layer):
|
|
| 231 |
residual = x_f32
|
| 232 |
|
| 233 |
# Q, K, V
|
| 234 |
-
q = self.Q(x_f32)
|
| 235 |
-
k = self.K(x_f32)
|
| 236 |
-
V =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
# gating signals in (0,1)
|
| 239 |
g_q = tf.nn.sigmoid(q)
|
|
|
|
| 145 |
super().__init__()
|
| 146 |
# keep internal computation in float32
|
| 147 |
self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
|
| 148 |
+
self.p = layers.Dense(128, use_bias=True, dtype='float32')
|
| 149 |
self._out_dtype = 'float32'
|
| 150 |
|
| 151 |
def call(self, x):
|
|
|
|
| 174 |
self.eps = float(eps)
|
| 175 |
|
| 176 |
# projection / gating layers in float32
|
| 177 |
+
self.Q = layers.Dense(d_model, dtype='float32')
|
| 178 |
+
self.K = layers.Dense(d_model, dtype='float32')
|
| 179 |
+
self.V = layers.Dense(d_model, dtype='float32')
|
| 180 |
+
self.Qr = Lo(d_model)
|
| 181 |
+
self.Kr = Lo(d_model)
|
| 182 |
+
self.Vr = Lo(d_model)
|
| 183 |
self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
|
| 184 |
self.O = layers.Dense(d_model, dtype='float32')
|
| 185 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 187 |
|
| 188 |
def _ema_over_time(self, score, alpha_dynamic):
|
|
|
|
| 225 |
residual = x_f32
|
| 226 |
|
| 227 |
# Q, K, V
|
| 228 |
+
q = self.Q(x_f32)
|
| 229 |
+
k = self.K(x_f32)
|
| 230 |
+
V = self.V(x_f32)
|
| 231 |
+
|
| 232 |
+
q = self.Qr(q)
|
| 233 |
+
k = self.Kr(k)
|
| 234 |
+
v = self.Vr(v)
|
| 235 |
|
| 236 |
# gating signals in (0,1)
|
| 237 |
g_q = tf.nn.sigmoid(q)
|