Yuchan committed
Commit 395c8b3 · verified · 1 Parent(s): f8d58f5

Update Model.py

Files changed (1)
  1. Model.py +50 -63
Model.py CHANGED
@@ -69,7 +69,7 @@ vocab_size = sp.get_piece_size()
 print(f"✅ Vocabulary size: {vocab_size}")

 max_len = 512
-batch_size = 256
+batch_size = 32

 def text_to_ids(text):
     return sp.encode(text, out_type=int)
@@ -124,71 +124,58 @@ class SwiGLU(layers.Layer):
         out = self.W1(tf.nn.silu(a) * b)
         return tf.cast(out, x.dtype)

-class LoU(layers.Layer):
-    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
-        super().__init__()
-        self.d_model = d_model
-        self.clip_value = float(clip_value)
-        self.eps = float(eps)
-        self.Q = layers.Dense(d_model, dtype='float32')
-        self.K = layers.Dense(d_model, dtype='float32')
-        self.V = layers.Dense(d_model, dtype='float32')
-        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
-        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-
-        self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
-
-    def _ema_over_time(self, score, alpha_dynamic):
-        seq = tf.transpose(score, perm=[1, 0, 2])
-        alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
-
-        def step(prev_ema, inputs):
-            x_t, alpha_t = inputs
-            new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
-            return new
-
-        init = seq[0]
-        first_alpha = alpha_seq[0]
-        remaining_seq = seq[1:]
-        remaining_alpha = alpha_seq[1:]
-        elems = (remaining_seq, remaining_alpha)
-        ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
-        ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
-        ema = tf.transpose(ema_seq, perm=[1, 0, 2])
-        return ema
+class SparseCausalAttention(Layer):
+    def __init__(self, num_heads, head_dim, window_size=16, **kwargs):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.window_size = window_size  # local window size
+
+    def build(self, input_shape):
+        self.q_dense = Dense(self.num_heads * self.head_dim)
+        self.k_dense = Dense(self.num_heads * self.head_dim)
+        self.v_dense = Dense(self.num_heads * self.head_dim)
+        self.out_dense = Dense(input_shape[-1])

     def call(self, x):
-        x_f32 = tf.cast(x, tf.float32)
-        residual = x_f32
-        x_f32 = self.norm1(x)
-
-        q = self.Q(x_f32)
-        k = self.K(x_f32)
-        V = self.V(x_f32)
-        # previous code:
-        # g_q = tf.nn.sigmoid(q)
-        # g_k = tf.nn.sigmoid(k)
-
-        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
-        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
-        score = g_q * g_k
-
-        alpha_dynamic = self.alpha_linear(x_f32)
-        score_ema = self._ema_over_time(score, alpha_dynamic)
-        mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
-        denom = tf.maximum(mean_last, self.eps)
-        score_norm = score_ema / denom
-        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-        x_comb = score_clipped * V
-        out = self.proj(x_comb)
-        out = self.norm(out + residual)
-        return tf.cast(out, x.dtype)
+        batch_size, seq_len, dim = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
+
+        # Q, K, V
+        q = tf.reshape(self.q_dense(x), (batch_size, seq_len, self.num_heads, self.head_dim))
+        k = tf.reshape(self.k_dense(x), (batch_size, seq_len, self.num_heads, self.head_dim))
+        v = tf.reshape(self.v_dense(x), (batch_size, seq_len, self.num_heads, self.head_dim))
+
+        # Transpose for matmul: (batch, heads, seq, head_dim)
+        q = tf.transpose(q, perm=[0, 2, 1, 3])
+        k = tf.transpose(k, perm=[0, 2, 1, 3])
+        v = tf.transpose(v, perm=[0, 2, 1, 3])
+
+        # Scale
+        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
+        q = q / scale
+
+        # Sparse mask: local causal window
+        # each token i attends to positions max(i - window_size, 0) .. i
+        attn_scores = tf.matmul(q, k, transpose_b=True)  # (batch, heads, seq, seq)
+        mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)  # causal mask
+        # limit look-back to the window size
+        band_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), self.window_size, 0)
+        mask = mask * band_mask
+        mask = tf.reshape(mask, (1, 1, seq_len, seq_len))  # broadcastable
+        attn_scores = tf.where(mask > 0, attn_scores, tf.fill(tf.shape(attn_scores), -1e9))
+
+        attn_probs = tf.nn.softmax(attn_scores, axis=-1)
+        attn_output = tf.matmul(attn_probs, v)  # (batch, heads, seq, head_dim)
+
+        # Merge heads
+        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
+        attn_output = tf.reshape(attn_output, (batch_size, seq_len, self.num_heads*self.head_dim))
+        return self.out_dense(attn_output)

 class Lo(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.d = layers.Dense(256, activation='silu')
+        self.d = layers.Dense(64, activation='silu')
         self.w = layers.Dense(d_model)
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

@@ -200,7 +187,7 @@ class Lo(layers.Layer):
 class Block(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.lou = LoU(d_model)
+        self.lou = SparseCausalAttention(num_heads=2, head_dim=64)
         self.glu = SwiGLU(d_model)
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
         self.lo = Lo(d_model)
@@ -256,8 +243,8 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 model = ReLM(
     vocab_size=vocab_size,
     max_seq_len=max_len,
-    d_model=700,
-    n_layers=16
+    d_model=128,
+    n_layers=2
 )

 # optimizer setup