Yuchan committed
Update Mo.py

Mo.py CHANGED
@@ -48,7 +48,7 @@ TOKENIZER_PATH = "ko_unigram.model"
 
 if not os.path.exists(DATA_PATH):
     download_file(
-        "https://huggingface.co/datasets/Yuchan5386/
+        "https://huggingface.co/datasets/Yuchan5386/1/resolve/main/shuffled_corpus.txt?download=true",
         DATA_PATH
     )
 
@@ -68,7 +68,7 @@ unk_id = sp.piece_to_id("<unk>")
 vocab_size = sp.get_piece_size()
 print(f"✅ Vocabulary size: {vocab_size}")
 
-max_len =
+max_len = 256
 batch_size = 256
 
 def text_to_ids(text):
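For context on the hunk above: text_to_ids (not shown in this diff) presumably encodes text with the SentencePiece tokenizer loaded earlier and fits it to max_len. A minimal sketch of that assumed behavior; the truncation and <pad> handling here are assumptions, not code from this commit:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("ko_unigram.model")

def text_to_ids_sketch(text, max_len=256):
    # Encode to token ids, truncate to max_len, and right-pad
    # (assumed behavior; the real text_to_ids is not in this diff).
    ids = sp.encode_as_ids(text)[:max_len]
    pad_id = sp.piece_to_id("<pad>")
    return ids + [pad_id] * (max_len - len(ids))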
@@ -99,7 +99,7 @@ def txt_stream(file_path):
     )
 
 
-LIMIT =
+LIMIT = 36757266
 
 dataset = tf.data.Dataset.from_generator(
     lambda: txt_stream(DATA_PATH),
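For context: tf.data.Dataset.from_generator lets the 36,757,266-line corpus stream lazily from disk instead of being loaded into memory at once. A minimal sketch of the pattern, with txt_stream_sketch as a hypothetical stand-in for the repo's txt_stream:

import tensorflow as tf

def txt_stream_sketch(file_path, limit=36757266):
    # Lazily yield one stripped line at a time, stopping at `limit`.
    with open(file_path, encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            yield line.strip()

dataset = tf.data.Dataset.from_generator(
    lambda: txt_stream_sketch("shuffled_corpus.txt"),
    output_signature=tf.TensorSpec(shape=(), dtype=tf.string),
)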
@@ -117,7 +117,7 @@ with strategy.scope():
     class SwiGLU(layers.Layer):
         def __init__(self, d_model, d_ff):
             super().__init__()
-            self.proj = layers.Dense(
+            self.proj = layers.Dense(960)
             self.out = layers.Dense(d_model)
         def call(self, x):
             x_proj = self.proj(x)
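For context: a SwiGLU feed-forward block typically projects to twice the hidden width, splits the result into a value half and a gate half, and multiplies the value by swish(gate). A minimal sketch of that standard formulation; the hard-coded Dense(960) above suggests 2 × 480, but that split is an assumption, not confirmed by this diff:

import tensorflow as tf
from tensorflow.keras import layers

class SwiGLUSketch(layers.Layer):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # One fused projection yields both the value and the gate halves.
        self.proj = layers.Dense(2 * d_ff)
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        value, gate = tf.split(x_proj, 2, axis=-1)
        # Swish-gated linear unit: value * silu(gate), projected back to d_model.
        return self.out(value * tf.nn.silu(gate))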
@@ -257,7 +257,7 @@ def masked_perplexity(y_true, y_pred, eps=0.1):
 # Model creation & compilation
 # =======================
 with strategy.scope():
-    model = ReLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=
+    model = ReLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=384, n_layers=3)
     dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
     _ = model(dummy_input, training=False)
     model.summary()
@@ -271,7 +271,7 @@ with strategy.scope():
     model.save_weights("tf_model.weights.h5")
     print("✅ Model weights saved!")
 
-    def generate_text_topp(model, prompt, max_len=
+    def generate_text_topp(model, prompt, max_len=500, max_gen=500, p=0.9, temperature=0.8, min_len=20):
         model_input = text_to_ids(f"<start> {prompt}")
         model_input = model_input[:max_len]
         generated = list(model_input)
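For context: the new signature above fixes the nucleus-sampling hyperparameters (p=0.9, temperature=0.8). One sampling step of top-p decoding typically looks like the following; a minimal NumPy sketch of the technique, not the commit's actual implementation:

import numpy as np

def top_p_sample(logits, p=0.9, temperature=0.8):
    # Temperature-scale the logits, then softmax.
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Keep the smallest prefix of tokens (by descending probability)
    # whose cumulative mass reaches p, renormalize, and sample from it.
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), p)) + 1
    keep = order[:cutoff]
    kept = probs[keep] / probs[keep].sum()
    return int(np.random.choice(keep, p=kept))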
|