Yuchan
committed on
Update Model.py
Model.py
CHANGED
@@ -69,7 +69,7 @@ vocab_size = sp.get_piece_size()
 print(f"✅ Vocabulary size: {vocab_size}")
 
 max_len = 512
-batch_size =
+batch_size = 32
 
 def text_to_ids(text):
     return sp.encode(text, out_type=int)
@@ -124,71 +124,58 @@ class SwiGLU(layers.Layer):
         out = self.W1(tf.nn.silu(a) * b)
         return tf.cast(out, x.dtype)
 
-class
-    def __init__(self,
-        super().__init__()
-        self.
-        self.
-        self.
-
-
-        self.
-        self.
-        self.
-        self.
-
-        self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
-
-    def _ema_over_time(self, score, alpha_dynamic):
-        seq = tf.transpose(score, perm=[1, 0, 2])
-        alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
-
-        def step(prev_ema, inputs):
-            x_t, alpha_t = inputs
-            new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
-            return new
-
-        init = seq[0]
-        first_alpha = alpha_seq[0]
-        remaining_seq = seq[1:]
-        remaining_alpha = alpha_seq[1:]
-        elems = (remaining_seq, remaining_alpha)
-        ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
-        ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
-        ema = tf.transpose(ema_seq, perm=[1, 0, 2])
-        return ema
+class SparseCausalAttention(Layer):
+    def __init__(self, num_heads, head_dim, window_size=16, **kwargs):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.window_size = window_size  # local window size
+
+    def build(self, input_shape):
+        self.q_dense = Dense(self.num_heads * self.head_dim)
+        self.k_dense = Dense(self.num_heads * self.head_dim)
+        self.v_dense = Dense(self.num_heads * self.head_dim)
+        self.out_dense = Dense(input_shape[-1])
 
     def call(self, x):
-
-
-
-
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        batch_size, seq_len, dim = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
+
+        # Q, K, V
+        q = tf.reshape(self.q_dense(x), (batch_size, seq_len, self.num_heads, self.head_dim))
+        k = tf.reshape(self.k_dense(x), (batch_size, seq_len, self.num_heads, self.head_dim))
+        v = tf.reshape(self.v_dense(x), (batch_size, seq_len, self.num_heads, self.head_dim))
+
+        # Transpose for matmul: (batch, heads, seq, head_dim)
+        q = tf.transpose(q, perm=[0, 2, 1, 3])
+        k = tf.transpose(k, perm=[0, 2, 1, 3])
+        v = tf.transpose(v, perm=[0, 2, 1, 3])
+
+        # Scale
+        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
+        q = q / scale
+
+        # Sparse mask computation: local window
+        # each token i attends to positions max(i - window_size, 0) .. i
+        attn_scores = tf.matmul(q, k, transpose_b=True)  # (batch, heads, seq, seq)
+        mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)  # causal mask
+        # Limit the window size
+        band_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), self.window_size, 0)
+        mask = mask * band_mask
+        mask = tf.reshape(mask, (1, 1, seq_len, seq_len))  # broadcastable
+        attn_scores = tf.where(mask > 0, attn_scores, tf.fill(tf.shape(attn_scores), -1e9))
+
+        attn_probs = tf.nn.softmax(attn_scores, axis=-1)
+        attn_output = tf.matmul(attn_probs, v)  # (batch, heads, seq, head_dim)
+
+        # Merge heads
+        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
+        attn_output = tf.reshape(attn_output, (batch_size, seq_len, self.num_heads*self.head_dim))
+        return self.out_dense(attn_output)
 
 class Lo(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.d = layers.Dense(
+        self.d = layers.Dense(64, activation='silu')
         self.w = layers.Dense(d_model)
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
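Note on the new masking logic: the added call combines a lower-triangular causal mask with a band of width window_size, so position i can only attend to positions max(i - window_size, 0) through i. A minimal standalone sketch of just that mask construction (small seq_len and window_size chosen here for illustration; only TensorFlow is assumed):

import tensorflow as tf

seq_len, window_size = 8, 2  # illustration values, not the ones used in Model.py

# Same construction as in SparseCausalAttention.call
causal = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)         # lower triangle incl. diagonal
band = tf.linalg.band_part(tf.ones((seq_len, seq_len)), window_size, 0)  # main diagonal plus window_size sub-diagonals
mask = causal * band                                                     # 1.0 where attention is allowed

print(mask.numpy().astype(int))
# Each row i contains min(i, window_size) + 1 ones: the token itself plus at
# most window_size previous tokens.
print(tf.reduce_sum(mask, axis=-1).numpy())  # [1. 2. 3. 3. 3. 3. 3. 3.]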
@@ -200,7 +187,7 @@ class Lo(layers.Layer):
 class Block(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.lou =
+        self.lou = SparseCausalAttention(num_heads=2, head_dim=64)
         self.glu = SwiGLU(d_model)
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
         self.lo = Lo(d_model)
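For reference, a quick shape check of the new layer as Block constructs it (num_heads=2, head_dim=64, so the concatenated heads match d_model=128). This is a sketch, not part of the commit: it assumes the file is importable as Model and that Layer/Dense inside the class resolve to tf.keras.layers.Layer/Dense.

import tensorflow as tf
from Model import SparseCausalAttention  # hypothetical import path for this file

attn = SparseCausalAttention(num_heads=2, head_dim=64)  # same arguments as in Block.__init__
x = tf.random.normal((4, 32, 128))                      # (batch, seq_len, d_model)
y = attn(x)                                             # build() runs on first call
print(y.shape)  # (4, 32, 128): out_dense projects back to the input feature size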
@@ -256,8 +243,8 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 model = ReLM(
     vocab_size=vocab_size,
     max_seq_len=max_len,
-    d_model=
-    n_layers=
+    d_model=128,
+    n_layers=2
 )
 
 # Optimizer settings
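With the committed defaults (max_len = 512 and the layer's default window_size = 16), the local window keeps only a small fraction of the (seq, seq) score matrix; note the scores are still computed densely and masked with -1e9 rather than skipped. A back-of-the-envelope check (assumes only TensorFlow):

import tensorflow as tf

max_len, window_size = 512, 16  # max_len from Model.py, window_size is the layer default

causal = tf.linalg.band_part(tf.ones((max_len, max_len)), -1, 0)
band = tf.linalg.band_part(tf.ones((max_len, max_len)), window_size, 0)
kept = int(tf.reduce_sum(causal * band))

total = max_len * max_len
print(f"unmasked scores per head: {kept} / {total} ({100 * kept / total:.2f}%)")  # 8568 / 262144 (3.27%)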