Yuchan committed on
Commit
1928697
Β·
verified Β·
1 Parent(s): 1896fcf

Update Model.py

Browse files
Files changed (1) hide show
  1. Model.py +14 -16
Model.py CHANGED
@@ -145,7 +145,7 @@ class Lo(layers.Layer):
145
  super().__init__()
146
  # λ‚΄λΆ€ 계산은 float32둜 μœ μ§€
147
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
148
- self.p = layers.Dense(96, use_bias=True, dtype='float32')
149
  self._out_dtype = 'float32'
150
 
151
  def call(self, x):
@@ -174,21 +174,15 @@ class LoSoU(layers.Layer):
174
  self.eps = float(eps)
175
 
176
  # projection / gating layers in float32
177
- self.Q = layers.Dense(96, dtype='float32')
178
- self.K = layers.Dense(96, dtype='float32')
179
- self.V = Lo(d_model) # Lo already handles casting to model dtype; we'll cast back to float32
 
 
 
180
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
181
  self.O = layers.Dense(d_model, dtype='float32')
182
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
183
-
184
- # 동적 alpha 계산을 μœ„ν•œ λ ˆμ΄μ–΄
185
- # alphaλŠ” [0, 1] λ²”μœ„μ—¬μ•Ό ν•˜λ―€λ‘œ sigmoid μ‚¬μš©
186
- # μž…λ ₯ x의 d_model 차원을 μ‚¬μš©ν•˜μ—¬ 각 μƒ˜ν”Œμ— λŒ€ν•΄ alpha 계산
187
- # 예: (B, L, d_model) -> (B, L, 1) -> (B, L, 1) with sigmoid
188
- # λ˜λŠ” (B, L, d_model) -> (B, L, d_model) -> global reduce -> (B, L, 1)
189
- # κ°„λ‹¨νžˆ 각 μœ„μΉ˜μ— λŒ€ν•΄ λ™μΌν•œ alpha μ‚¬μš© (μž…λ ₯의 평균 기반)
190
- # λ˜λŠ” μœ„μΉ˜λ³„λ‘œ λ‹€λ₯΄κ²Œ μ‚¬μš© (각 μœ„μΉ˜μ— λŒ€ν•΄ 계산)
191
- # μ—¬κΈ°μ„œλŠ” μœ„μΉ˜λ³„λ‘œ λ‹€λ₯΄κ²Œ 계산 (B, L, 1)
192
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
193
 
194
  def _ema_over_time(self, score, alpha_dynamic):
@@ -231,9 +225,13 @@ class LoSoU(layers.Layer):
231
  residual = x_f32
232
 
233
  # Q, K, V
234
- q = self.Q(x_f32) # (B, L, 96)
235
- k = self.K(x_f32) # (B, L, 96)
236
- V = tf.cast(self.V(x), tf.float32) # ensure V's output is float32
 
 
 
 
237
 
238
  # gating signals in (0,1)
239
  g_q = tf.nn.sigmoid(q)
 
145
  super().__init__()
146
  # λ‚΄λΆ€ 계산은 float32둜 μœ μ§€
147
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
148
+ self.p = layers.Dense(128, use_bias=True, dtype='float32')
149
  self._out_dtype = 'float32'
150
 
151
  def call(self, x):
 
174
  self.eps = float(eps)
175
 
176
  # projection / gating layers in float32
177
+ self.Q = layers.Dense(d_model, dtype='float32')
178
+ self.K = layers.Dense(d_model, dtype='float32')
179
+ self.V = layers.Dense(d_model, dtype='float32')
180
+ self.Qr = Lo(d_model)
181
+ self.Kr = Lo(d_model)
182
+ self.Vr = Lo(d_model)
183
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
184
  self.O = layers.Dense(d_model, dtype='float32')
185
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
 
 
 
 
 
 
 
 
186
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
187
 
188
  def _ema_over_time(self, score, alpha_dynamic):
 
225
  residual = x_f32
226
 
227
  # Q, K, V
228
+ q = self.Q(x_f32)
229
+ k = self.K(x_f32)
230
+ V = self.V(x_f32)
231
+
232
+ q = self.Qr(q)
233
+ k = self.Kr(k)
234
+ v = self.Vr(v)
235
 
236
  # gating signals in (0,1)
237
  g_q = tf.nn.sigmoid(q)