Yuchan committed
Commit 8ca26fe · verified · 1 Parent(s): 586cdb5

Create Model.py

Files changed (1):
  1. Model.py +328 -0
Model.py ADDED
@@ -0,0 +1,328 @@
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
import sentencepiece as spm
import requests

# ⬇️ File download helper
def download_file(url, save_path):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(save_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print(f"✅ File saved: {save_path}")

# ⬇️ Download the tokenizer and dataset
download_file('https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/ko_unigram.model?download=true', 'ko_unigram.model')
download_file('https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet?download=true', 'dataset.parquet')

# ⬇️ Load the Parquet data
df = pd.read_parquet("dataset.parquet", engine="pyarrow")

# ⬇️ Convert to the "<start> question <sep> answer <end>" format
train_sentences = []

for conversations in df["conversations"]:
    for i in range(0, len(conversations) - 1, 2):
        item1, item2 = conversations[i], conversations[i + 1]
        if item1.get("from") == "human" and item2.get("from") == "gpt":
            prompt = item1.get("value", "").strip().replace("\n", " ")
            response = item2.get("value", "").strip().replace("\n", " ")
            full = f"<start> {prompt} <sep> {response} <end>"
            train_sentences.append(full)

print(f"Total number of sentences: {len(train_sentences)}")

# ⬇️ Load the tokenizer
sp = spm.SentencePieceProcessor()
sp.load("ko_unigram.model")

# ⬇️ Look up the special token IDs
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
start_id = sp.piece_to_id("<start>")
sep_id = sp.piece_to_id("<sep>")
end_id = sp.piece_to_id("<end>")
unk_id = sp.piece_to_id("<unk>")

vocab_size = sp.get_piece_size()
print(f"✅ Vocabulary size: {vocab_size}")
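
# Optional sanity check (a sketch; the loop names are illustrative): SentencePiece
# maps pieces that were never declared at training time to the <unk> id, so a
# collision here would silently corrupt the <start>/<sep>/<end> formatting below.
for _name, _tid in [("<start>", start_id), ("<sep>", sep_id), ("<end>", end_id)]:
    if _tid == unk_id:
        print(f"⚠️ {_name} is not a dedicated piece in ko_unigram.model")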

# ⬇️ Text <-> ID conversion helpers
def text_to_ids(text):
    return sp.encode(text, out_type=int)

def ids_to_text(ids):
    return sp.decode(ids)

# ⬇️ Preprocessing hyperparameters
max_len = 100
batch_size = 128

# ⬇️ Preprocessing with input/target masking: the loss is applied only to
# response tokens (everything after <sep>)
encoded_inputs = []
targets = []

for sentence in train_sentences:
    if "<sep>" not in sentence:
        continue

    sep_index = sentence.index("<sep>")
    input_text = sentence[:sep_index + len("<sep>")].strip()
    target_text = sentence[sep_index + len("<sep>"):].strip()

    input_ids = text_to_ids(input_text)
    # target_text already ends with <end> from the template above, so nothing
    # needs to be appended here
    target_ids = text_to_ids(target_text)

    full_input = input_ids + target_ids
    full_input = full_input[:max_len]

    target_mask = [0] * len(input_ids) + [1] * len(target_ids)
    target_mask = target_mask[:max_len]

    if len(full_input) < max_len:
        pad_len = max_len - len(full_input)
        full_input += [pad_id] * pad_len
        target_mask += [0] * pad_len

    encoded_inputs.append(full_input)

    # Next-token targets: shift left by one, then blank out every position that
    # is not part of the response so masked_loss ignores it
    target_seq = full_input[1:] + [end_id]
    target_seq = target_seq[:max_len]

    masked_target = [
        t if m == 1 else pad_id
        for t, m in zip(target_seq, target_mask)
    ]

    targets.append(masked_target)

# ⬇️ Convert to NumPy
encoded_inputs = np.array(encoded_inputs)
targets = np.array(targets)
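
# Optional inspection (a quick sketch): decoding the first pair shows the
# prompt/response boundary, and that targets keep only response tokens after
# masking. Filtering pad_id here is for display only.
if len(encoded_inputs) > 0:
    print("sample input :", ids_to_text([int(t) for t in encoded_inputs[0] if t != pad_id]))
    print("sample target:", ids_to_text([int(t) for t in targets[0] if t != pad_id]))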

# ⬇️ Build the TensorFlow Dataset
def data_generator():
    for input_seq, target_seq in zip(encoded_inputs, targets):
        yield input_seq, target_seq

dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32)
    )
)

dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

print("✅ TF Dataset created!")
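
# Optional sanity check (sketch): pull one batch to confirm shapes before
# training; both tensors should come out as (batch_size, max_len).
for _xb, _yb in dataset.take(1):
    print("batch shapes:", _xb.shape, _yb.shape)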

class Adapter(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        # Keep the internal computation in float32
        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
        self.p = layers.Dense(128, use_bias=True, dtype='float32')
        self._out_dtype = 'float32'
        self.ln = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
        self.ln1 = layers.LayerNormalization(epsilon=1e-5, dtype="float32")

    def call(self, x):
        # x may be bfloat16; cast to float32 for stable intermediate computation
        x_f32 = tf.cast(x, tf.float32)
        residual = x_f32
        x_f32 = self.ln(x_f32)
        x = self.p(x_f32)
        x = tf.nn.gelu(x)
        x = self.proj(x)
        x = self.ln1(x) + residual
        # Cast back to the output dtype for consistency
        return tf.cast(x, self._out_dtype)

class SwiGLU(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        # 2304 = 2 x 1152: the projection is split into two halves in call()
        self.proj = layers.Dense(2304)
        self.w1 = layers.Dense(d_model)
        self.ln = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
        self.ln1 = layers.LayerNormalization(epsilon=1e-5, dtype="float32")

    def call(self, x):
        x = self.ln(x)
        x = self.proj(x)
        # SwiGLU gating: silu(a) acts as a gate on b
        a, b = tf.split(x, 2, axis=-1)
        o = tf.nn.silu(a) * b
        o = self.ln1(self.w1(o))
        return o

class LowRankGLA(tf.keras.layers.Layer):
    def __init__(self, d_model, low_rank_dim, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.low_rank_dim = low_rank_dim

        # Low-rank down-projections for Q, K, V, G (d_model -> low_rank_dim)
        self.W_q_A = layers.Dense(low_rank_dim, use_bias=True)
        self.W_k_A = layers.Dense(low_rank_dim, use_bias=True)
        self.W_v_A = layers.Dense(low_rank_dim, use_bias=True)
        self.W_g_A = layers.Dense(low_rank_dim, use_bias=True)

        # Output up-projection (low_rank_dim -> d_model)
        self.output_dense_B = layers.Dense(d_model, use_bias=True)

    def call(self, inputs):
        # inputs shape: (batch_size, seq_len, d_model)

        # Low-rank projections
        q = self.W_q_A(inputs)
        k = self.W_k_A(inputs)
        v = self.W_v_A(inputs)
        g = self.W_g_A(inputs)

        # Sigmoid activations keep the attention weights and gate in (0, 1)
        q = tf.nn.sigmoid(q)
        k = tf.nn.sigmoid(k)
        g = tf.nn.sigmoid(g)

        # Gated linear attention via cumulative sums: position t aggregates
        # only positions <= t, so the mixing is causal by construction
        attn_weights = q * k  # (batch_size, seq_len, low_rank_dim)
        numerator = tf.cumsum(attn_weights * v, axis=1)
        denominator = tf.cumsum(attn_weights, axis=1) + 1e-12
        output = numerator / denominator
        output = output * g  # apply the gate

        # Project back up to d_model
        output = self.output_dense_B(output)

        return output

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "low_rank_dim": self.low_rank_dim,
        })
        return config
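
# Quick smoke test (sketch; the sizes are illustrative only): the block should
# preserve (batch, seq_len, d_model) regardless of low_rank_dim.
_gla_demo = LowRankGLA(d_model=8, low_rank_dim=4)
print(_gla_demo(tf.zeros([2, 5, 8])).shape)  # expected: (2, 5, 8)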

class Respiso(tf.keras.Model):
    def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.gla = LowRankGLA(d_model, 48)
        self.glu = SwiGLU(d_model)
        self.adapter = Adapter(d_model)
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
        self.lm_head = layers.Dense(vocab_size, use_bias=False)

    def call(self, x, training=False):
        x = self.token_embedding(x)
        # GLA is the only sequence-mixing layer, so it must be part of the
        # forward pass; without it each position would be predicted from its
        # own embedding alone
        x = self.gla(x)
        x = self.glu(x)
        x = self.adapter(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return tf.cast(logits, tf.float32)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def masked_loss(y_true, y_pred):
    loss = loss_fn(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)

def masked_perplexity(y_true, y_pred):
    loss = loss_fn(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
    return tf.exp(tf.minimum(avg_loss, 10.0))  # clamp for numerical stability
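
# Toy check (sketch; _tok is an arbitrary non-pad id): the pad position
# contributes nothing to the masked loss, so only the second, correctly
# predicted position is scored and the result comes out near zero.
_tok = (pad_id + 1) % vocab_size
_y_true = tf.constant([[pad_id, _tok]])
_y_pred = tf.one_hot([[_tok, _tok]], depth=vocab_size) * 10.0
print("masked loss on toy batch:", float(masked_loss(_y_true, _y_pred)))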

def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=False
    )

# Create the model
model = Respiso(
    vocab_size=vocab_size,
    max_seq_len=max_len,
    d_model=256,
    n_layers=1
)

# Optimizer setup
optimizer = tf.keras.optimizers.Adam(
    learning_rate=create_lr_schedule(),
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-8,
    clipnorm=1.0
)

# Compile the model
model.compile(
    optimizer=optimizer,
    loss=masked_loss,
    metrics=[masked_perplexity]
)

# Build the model by running a dummy input through it
dummy_input = np.zeros((1, max_len), dtype=np.int32)
model(dummy_input)
model.summary()

# Start training
history = model.fit(
    dataset,
    epochs=1,
    steps_per_epoch=encoded_inputs.shape[0] // batch_size,
    verbose=1
)

# Save the weights
model.save_weights("Cobra.weights.h5")
print("Model weights saved!")

def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)
    for step in range(max_gen):
        # Keep only the most recent max_len tokens as context
        if len(generated) > max_len:
            input_seq = generated[-max_len:]
        else:
            input_seq = generated
        input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(input_seq) - 1].numpy()
        # Penalize <end> and <pad> to discourage early or degenerate stops
        next_token_logits[end_id] -= 5.0
        next_token_logits[pad_id] -= 10.0
        # Top-p (nucleus) sampling: sample from the smallest set of tokens
        # whose cumulative probability exceeds p
        probs = tf.nn.softmax(next_token_logits / temperature).numpy()
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cumulative_probs, p)
        top_indices = sorted_indices[:cutoff + 1]
        top_probs = sorted_probs[:cutoff + 1]
        top_probs /= np.sum(top_probs)
        next_token_id = np.random.choice(top_indices, p=top_probs)
        if next_token_id == end_id and len(generated) >= min_len:
            break
        generated.append(int(next_token_id))
    return ids_to_text(generated)

print("\n\n===== Generation result =====")
print(generate_text_topp(model, "안녕", p=0.9))
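
# To reload later (sketch, assuming the same architecture hyperparameters):
# build the variables with a dummy forward pass, then restore the weights.
# _model2 = Respiso(vocab_size=vocab_size, max_seq_len=max_len, d_model=256, n_layers=1)
# _model2(np.zeros((1, max_len), dtype=np.int32))
# _model2.load_weights("Cobra.weights.h5")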