Yuchan committed on
Commit
3bd0fac
·
verified ·
1 Parent(s): f9eec12

Update Model.py

Files changed (1)
  1. Model.py +26 -78
Model.py CHANGED
@@ -116,37 +116,35 @@ with strategy.scope():
116
  class Lo(layers.Layer):
117
  def __init__(self, d_model):
118
  super().__init__()
119
- # keep internal computation in float32
120
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
121
  self.p = layers.Dense(128, use_bias=True, dtype='float32')
122
  self._out_dtype = 'float32'
123
 
124
  def call(self, x):
125
- # x may be bfloat16; cast to float32 for stable intermediate computation
126
  x_f32 = tf.cast(x, tf.float32)
127
  x = self.proj(x_f32)
128
  x = tf.nn.gelu(x)
129
  x = self.p(x)
130
- # cast back to model dtype for consistency
131
  return tf.cast(x, self._out_dtype)
132
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  class LoU(layers.Layer):
134
- """
135
- Stabilized LoSoU layer (uses a dynamic alpha)
136
- - alpha is computed dynamically from the input: alpha = sigmoid(Linear(x))
137
- - uses an exponential moving average (EMA) over time instead of a cumulative sum (alpha: smoothing factor)
138
- - internal computation runs in float32 (improves stability under TPU bfloat16)
139
- - clips the EMA result and applies a small epsilon
140
- - safe split handling (assumes an even last dimension; otherwise the last dimension must be padded)
141
- """
142
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
143
  super().__init__()
144
- # most of the computation runs in float32
145
  self.d_model = d_model
146
  self.clip_value = float(clip_value)
147
  self.eps = float(eps)
148
-
149
- # projection / gating layers in float32
150
  self.Q = layers.Dense(d_model, dtype='float32')
151
  self.K = layers.Dense(d_model, dtype='float32')
152
  self.V = layers.Dense(d_model, dtype='float32')
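
For reference, the EMA described in the deleted docstring above follows the recurrence ema_t = alpha_t * x_t + (1 - alpha_t) * ema_{t-1} with ema_0 = x_0. A minimal standalone sketch of that recurrence (NumPy, illustrative only, not part of this commit):

import numpy as np

def ema_reference(x, alpha):
    # x: (L, D) sequence, alpha: (L, 1) per-step smoothing factors in (0, 1)
    ema = [x[0]]
    for t in range(1, x.shape[0]):
        ema.append(alpha[t] * x[t] + (1.0 - alpha[t]) * ema[-1])
    return np.stack(ema, axis=0)  # (L, D); ema[0] == x[0]
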
@@ -155,100 +153,55 @@ class LoU(layers.Layer):
155
  self.Vr = Lo(d_model)
156
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
157
  self.O = layers.Dense(d_model, dtype='float32')
 
158
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
159
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
160
-
161
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
162
 
163
  def _ema_over_time(self, score, alpha_dynamic):
164
- # score: (B, L, D) float32 in [0,1] roughly
165
- # alpha_dynamic: (B, L, 1) float32 in [0,1]
166
-
167
- # transpose to (L, B, D) to scan over time steps
168
- seq = tf.transpose(score, perm=[1, 0, 2]) # (L, B, D)
169
- alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2]) # (L, B, 1)
170
 
171
  def step(prev_ema, inputs):
172
  x_t, alpha_t = inputs
173
- # prev_ema: (B, D), x_t: (B, D), alpha_t: (B, 1)
174
  new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
175
  return new
176
 
177
- # initialize the EMA with the first time step
178
- init = seq[0] # (B, D)
179
- first_alpha = alpha_seq[0] # (B, 1)
180
-
181
- # elems for scan must have shapes (L-1, B, D) and (L-1, B, 1)
182
- remaining_seq = seq[1:] # (L-1, B, D)
183
- remaining_alpha = alpha_seq[1:] # (L-1, B, 1)
184
-
185
- # elems is a tuple of two tensors: (x_t, alpha_t)
186
  elems = (remaining_seq, remaining_alpha)
187
-
188
  ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
189
- # prepend the initial value
190
- ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0) # (L, B, D)
191
-
192
- # transpose back to (B, L, D)
193
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
194
  return ema
195
 
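
The tf.scan above threads the EMA state along the time axis in time-major layout: transpose to (L, B, D), initialize with the first step, scan the remaining steps, prepend the initializer, and transpose back. A small illustrative check against a plain Python loop, assuming Model.py's existing imports (tensorflow as tf) and the LoU class defined here; sizes are arbitrary:

layer = LoU(d_model=8)                      # hypothetical small width for the check
x = tf.random.uniform((2, 6, 8))            # (B, L, D), float32
alpha = tf.random.uniform((2, 6, 1))        # (B, L, 1) smoothing factors
ema = layer._ema_over_time(x, alpha)        # (B, L, D)
ref = [x[0, 0]]
for t in range(1, 6):
    ref.append(alpha[0, t] * x[0, t] + (1.0 - alpha[0, t]) * ref[-1])
# ema[0] should match tf.stack(ref, axis=0) up to float tolerance
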
196
  def call(self, x):
197
- # x: (B, L, d_model) maybe bfloat16 or float32
198
- # cast to float32 for all internal computations
199
  x_f32 = tf.cast(x, tf.float32)
200
  residual = x_f32
201
  x_f32 = self.norm1(x)
202
 
203
- # Q, K, V
204
  q = self.Q(x_f32)
205
  k = self.K(x_f32)
206
  V = self.V(x_f32)
207
-
208
  q = self.Qr(q)
209
  k = self.Kr(k)
210
  V = self.Vr(V)
211
 
212
- # gating signals in (0,1)
213
  g_q = tf.nn.sigmoid(q)
214
  g_k = tf.nn.sigmoid(k)
215
-
216
- # elementwise product -> bounded roughly [0,1]
217
  score = g_q * g_k
218
-
219
- # dynamic alpha: (B, L, d_model) -> (B, L, 1)
220
- alpha_dynamic = self.alpha_linear(x_f32) # (B, L, 1)
221
- # optional post-processing of alpha_dynamic if needed (e.g. min/max clamping)
222
- # ex: alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)
223
-
224
- # EMA across time (stable alternative to cumsum)
225
  score_ema = self._ema_over_time(score, alpha_dynamic)
226
-
227
- # optionally normalize by (mean + eps) across last dim to reduce scale variations
228
- mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True) # (B, L, 1)
229
  denom = tf.maximum(mean_last, self.eps)
230
  score_norm = score_ema / denom
231
-
232
- # clip to avoid extremes
233
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
234
-
235
- # combine with V
236
- x_comb = score_clipped * V # (B, L, d_model)
237
-
238
- out = self.proj(x_comb) # (B, L, d_model)
239
-
240
- # ensure out dim even for split
241
- d = out.shape[-1] # this is an int (static shape)
242
- if d is not None and d % 2 == 1:
243
- out = tf.pad(out, [[0,0],[0,0],[0,1]])
244
-
245
- a, b = tf.split(out, 2, axis=-1)
246
- gated = tf.nn.silu(a) * b
247
- out = self.O(gated)
248
-
249
  out = self.norm(out + residual)
250
-
251
- # cast back to original dtype for downstream layers
252
  return tf.cast(out, x.dtype)
253
 
254
  class ReLM(tf.keras.Model):
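
The block removed above gates the output of proj by splitting it into halves and computing silu(a) * b before the final O projection; this commit factors that pattern out into the separate SwiGLU layer added further down. A minimal illustration of the split-and-gate step itself (toy width, not the model's real dimensions):

import tensorflow as tf

out = tf.random.uniform((2, 5, 16))   # stand-in for proj's output, (B, L, d_model)
a, b = tf.split(out, 2, axis=-1)      # two (B, L, 8) halves
gated = tf.nn.silu(a) * b             # (B, L, 8), previously fed into self.O
print(gated.shape)                    # (2, 5, 8)
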
@@ -257,20 +210,15 @@ class ReLM(tf.keras.Model):
257
  self.token_embedding = layers.Embedding(vocab_size, d_model)
258
  self.pos_embedding = layers.Embedding(max_seq_len, d_model)
259
  self.blocks = [LoU(d_model) for _ in range(n_layers)]
260
-
261
- # keep LayerNormalization in float32 to avoid precision issues
262
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
263
 
264
  def call(self, x, training=False):
265
  batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
266
  positions = tf.range(seq_len)[tf.newaxis, :]
267
-
268
  x = self.token_embedding(x) + self.pos_embedding(positions)
269
  for block in self.blocks:
270
  x = block(x)
271
-
272
  x = self.ln_f(x)
273
-
274
  embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
275
  logits = tf.matmul(x, embedding_matrix, transpose_b=True)
276
  return tf.cast(logits, tf.float32)
@@ -301,7 +249,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
301
  model = ReLM(
302
  vocab_size=vocab_size,
303
  max_seq_len=max_len,
304
- d_model=128,
305
  n_layers=2
306
  )
307
 
 
116
  class Lo(layers.Layer):
117
  def __init__(self, d_model):
118
  super().__init__()
 
119
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
120
  self.p = layers.Dense(128, use_bias=True, dtype='float32')
121
  self._out_dtype = 'float32'
122
 
123
  def call(self, x):
 
124
  x_f32 = tf.cast(x, tf.float32)
125
  x = self.proj(x_f32)
126
  x = tf.nn.gelu(x)
127
  x = self.p(x)
 
128
  return tf.cast(x, self._out_dtype)
129
 
130
+ class SwiGLU(layers.Layer):
131
+ def __init__(self, d_model):
132
+ super().__init__()
133
+ self.W = layers.Dense(3500, dtype='float32')
134
+ self.W1 = layers.Dense(d_model, dtype='float32')
135
+ def call(self, x):
136
+ x_f32 = tf.cast(x, tf.float32)
137
+ h = self.W(x_f32)
138
+ a, b = tf.split(h, 2, axis=-1)
139
+ out = self.W1(tf.nn.silu(a) * b)
140
+ return tf.cast(out, x.dtype)
141
+
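
A quick shape check for the new layer, assuming Model.py's imports and the SwiGLU class above (illustrative only): Dense(3500) splits into two 1750-wide halves, which are gated and projected back to d_model.

glu = SwiGLU(d_model=700)                 # the width this commit configures below
y = glu(tf.random.uniform((2, 5, 700)))   # (B, L, 700) in, (B, L, 700) out
print(y.shape)                            # (2, 5, 700)
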
142
  class LoU(layers.Layer):
 
 
 
 
 
 
 
 
143
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
144
  super().__init__()
 
145
  self.d_model = d_model
146
  self.clip_value = float(clip_value)
147
  self.eps = float(eps)
 
 
148
  self.Q = layers.Dense(d_model, dtype='float32')
149
  self.K = layers.Dense(d_model, dtype='float32')
150
  self.V = layers.Dense(d_model, dtype='float32')
 
153
  self.Vr = Lo(d_model)
154
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
155
  self.O = layers.Dense(d_model, dtype='float32')
156
+ self.glu = SwiGLU(d_model)
157
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
158
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
159
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
160
 
161
  def _ema_over_time(self, score, alpha_dynamic):
162
+ seq = tf.transpose(score, perm=[1, 0, 2])
163
+ alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
 
 
 
 
164
 
165
  def step(prev_ema, inputs):
166
  x_t, alpha_t = inputs
 
167
  new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
168
  return new
169
 
170
+ init = seq[0]
171
+ first_alpha = alpha_seq[0]
172
+ remaining_seq = seq[1:]
173
+ remaining_alpha = alpha_seq[1:]
 
 
 
 
 
174
  elems = (remaining_seq, remaining_alpha)
 
175
  ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
176
+ ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
 
 
 
177
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
178
  return ema
179
 
180
  def call(self, x):
 
 
181
  x_f32 = tf.cast(x, tf.float32)
182
  residual = x_f32
183
  x_f32 = self.norm1(x)
184
 
 
185
  q = self.Q(x_f32)
186
  k = self.K(x_f32)
187
  V = self.V(x_f32)
 
188
  q = self.Qr(q)
189
  k = self.Kr(k)
190
  V = self.Vr(V)
191
 
 
192
  g_q = tf.nn.sigmoid(q)
193
  g_k = tf.nn.sigmoid(k)
 
 
194
  score = g_q * g_k
195
+ alpha_dynamic = self.alpha_linear(x_f32)
 
 
 
 
 
 
196
  score_ema = self._ema_over_time(score, alpha_dynamic)
197
+ mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
 
 
198
  denom = tf.maximum(mean_last, self.eps)
199
  score_norm = score_ema / denom
 
 
200
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
201
+ x_comb = score_clipped * V
202
+ out = self.proj(x_comb)
203
+ out = self.glu(out)
 
 
 
 
 
 
 
 
 
 
 
 
204
  out = self.norm(out + residual)
 
 
205
  return tf.cast(out, x.dtype)
206
 
207
  class ReLM(tf.keras.Model):
 
210
  self.token_embedding = layers.Embedding(vocab_size, d_model)
211
  self.pos_embedding = layers.Embedding(max_seq_len, d_model)
212
  self.blocks = [LoU(d_model) for _ in range(n_layers)]
 
 
213
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
214
 
215
  def call(self, x, training=False):
216
  batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
217
  positions = tf.range(seq_len)[tf.newaxis, :]
 
218
  x = self.token_embedding(x) + self.pos_embedding(positions)
219
  for block in self.blocks:
220
  x = block(x)
 
221
  x = self.ln_f(x)
 
222
  embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
223
  logits = tf.matmul(x, embedding_matrix, transpose_b=True)
224
  return tf.cast(logits, tf.float32)
 
249
  model = ReLM(
250
  vocab_size=vocab_size,
251
  max_seq_len=max_len,
252
+ d_model=700,
253
  n_layers=2
254
  )
255
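
For an end-to-end sanity check of the updated configuration, something along these lines should produce per-token logits over the vocabulary (illustrative only; vocab_size and max_len come from earlier in Model.py, and the batch and sequence sizes here are arbitrary with seq_len <= max_len):

dummy = tf.zeros((2, 16), dtype=tf.int32)   # (batch, seq_len) token ids
logits = model(dummy)
print(logits.shape, logits.dtype)           # (2, 16, vocab_size), float32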