# OpenLEM2 / app.py
# OpenLab-NLP's picture
# Update app.py
# 5d9caf3 verified
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import sentencepiece as spm
import gradio as gr
import requests
import os
# ----------------------
# File download utility
# ----------------------
def download_file(url, save_path, timeout=60):
    """Stream the resource at ``url`` into the local file ``save_path``.

    The body is written to ``<save_path>.part`` first and only renamed to
    ``save_path`` once the transfer has finished. The module decides whether
    to download by checking ``os.path.exists`` on the final path, so a
    half-written file left by an interrupted transfer must never appear
    under that name.

    Args:
        url: HTTP(S) address to fetch.
        save_path: Destination path of the finished file.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``
            argument, so a stalled server cannot hang start-up forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the file cannot be written or renamed.
    """
    partial_path = f"{save_path}.part"
    try:
        # The context manager returns the connection to the pool even when
        # the transfer fails midway.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as out_file:
                for chunk in response.iter_content(8192 * 2):
                    out_file.write(chunk)
        os.replace(partial_path, save_path)
    except BaseException:
        # Remove the incomplete temporary file, then let the error propagate.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise
    print(f"✅ {save_path} 저장됨")
# Local file names of the pretrained encoder weights and the tokenizer model.
MODEL_PATH = "encoder.weights.h5"
TOKENIZER_PATH = "bpe.model"

# (local path, remote URL) pairs; each asset is fetched only when its local
# file is missing, in the order listed.
_REMOTE_ASSETS = (
    (
        MODEL_PATH,
        "https://huggingface.co/OpenLab-NLP/openlem2/resolve/main/encoder_fit.weights.h5?download=true",
    ),
    (
        TOKENIZER_PATH,
        "https://huggingface.co/OpenLab-NLP/openlem2/resolve/main/bpe.model?download=true",
    ),
)
for _local_path, _remote_url in _REMOTE_ASSETS:
    if os.path.exists(_local_path):
        continue
    download_file(_remote_url, _local_path)
# Fixed token length: default truncation length in encode_sentence, padding
# target in pad_sentence, and width of the dummy batch used to build the model.
MAX_LEN = 128
# Default embedding width of EncoderBlock.
EMBED_DIM = 384
# NOTE(review): not referenced anywhere in the visible code (SentenceEncoder
# hard-codes latent_dim=384) — confirm whether it is still needed.
LATENT_DIM = 384
# NOTE(review): not referenced anywhere in the visible code — presumably a
# leftover from training; confirm.
DROP_RATE = 0.1
# ===============================
# 1️⃣ Load the tokenizer
# ===============================
# SentencePiece processor built from the downloaded tokenizer model file.
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
vocab_size = sp.get_piece_size()
# Id used for padding: the "<pad>" piece's id, or 0 when the lookup yields -1.
_pad_piece_id = sp.piece_to_id("<pad>")
pad_id = 0 if _pad_piece_id == -1 else _pad_piece_id
def encode_sentence(sentence, max_len=MAX_LEN):
    """Tokenize ``sentence`` into integer ids, keeping at most ``max_len``.

    Uses the module-level SentencePiece processor ``sp``; ids beyond
    ``max_len`` are dropped from the end.
    """
    token_ids = sp.encode(sentence, out_type=int)
    return token_ids[:max_len]
def pad_sentence(tokens):
    """Return ``tokens`` right-padded with ``pad_id`` up to ``MAX_LEN`` ids.

    A list that already has ``MAX_LEN`` or more ids comes back unchanged in
    content (no truncation happens here).
    """
    missing = MAX_LEN - len(tokens)
    filler = [pad_id] * missing
    return tokens + filler
class DynamicConv(layers.Layer):
    """Token-wise dynamic 1-D convolution along the sequence axis.

    For every position a Dense layer produces ``k`` logits. Their softmax is
    used as the weights of a weighted sum over the ``k`` neighbouring token
    vectors (the window is centred on the position; the sequence is
    zero-padded at both ends so the output keeps the input length).
    """

    def __init__(self, k=7):
        super().__init__()
        assert k % 2 == 1, "kernel size should be odd for symmetric padding"
        self.k = k
        # Emits k logits per token; call() turns them into softmax weights.
        self.generator = layers.Dense(k)

    def call(self, x):
        """Mix each token with its neighbours; ``x`` is (B, L, D), output too."""
        in_shape = tf.shape(x)
        batch, length, depth = in_shape[0], in_shape[1], in_shape[2]

        # (B, L, k): one normalised weight per window slot and position.
        weights = tf.nn.softmax(self.generator(x), axis=-1)

        # "Same" padding along the sequence axis -> (B, L + 2*half, D).
        half = (self.k - 1) // 2
        padded = tf.pad(x, [[0, 0], [half, half], [0, 0]])

        # Treat the sequence as a 1-row image (B, 1, L + 2*half, D) so that
        # extract_patches yields every length-k window: (B, 1, L, k*D).
        windows = tf.image.extract_patches(
            images=tf.expand_dims(padded, axis=1),
            sizes=[1, 1, self.k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # Drop the dummy row axis and unfold the window axis -> (B, L, k, D).
        windows = tf.reshape(windows, [batch, length, self.k, depth])

        # Weighted sum over the window axis -> (B, L, D).
        return tf.reduce_sum(windows * tf.expand_dims(weights, axis=-1), axis=2)
class EncoderBlock(tf.keras.layers.Layer):
    """Encoder block: stacked DynamicConv layers followed by a gated FFN.

    Both sub-layer outputs are layer-normalised before being added to the
    running residual. The ``mask`` argument of ``call`` is accepted but not
    used by the computation.
    """

    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN, num_conv_layers=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len

        # Feed-forward projections; fc1's output is split in two halves
        # (gate / value) in call().
        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)

        # Stack of dynamic convolutions applied one after another.
        self.blocks = [DynamicConv(k=7) for _ in range(num_conv_layers)]

        self.ln = layers.LayerNormalization(epsilon=1e-5)   # input normalisation
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)  # conv branch
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)  # FFN branch

    def call(self, x, mask=None):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        normed = self.ln(x)

        # Run the normalised input through every DynamicConv in turn.
        conv_out = normed
        for conv in self.blocks:
            conv_out = conv(conv_out)

        # Conv residual: added to the normalised input, not the raw one.
        residual = normed + self.ln1(conv_out)

        # Gated FFN (SiLU gate * value), fed with the conv output.
        gate, value = tf.split(self.fc1(conv_out), 2, axis=-1)
        ffn_out = self.fc2(tf.nn.silu(gate) * value)

        # FFN residual.
        return residual + self.ln2(ffn_out)
class L2NormLayer(layers.Layer):
    """Layer that L2-normalises its input along a configurable axis."""

    def __init__(self, axis=1, epsilon=1e-10, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` scaled to unit L2 norm along ``self.axis``."""
        return tf.math.l2_normalize(
            inputs,
            axis=self.axis,
            epsilon=self.epsilon,
        )

    def get_config(self):
        """Return the layer config, including ``axis`` and ``epsilon``."""
        config = {"axis": self.axis, "epsilon": self.epsilon}
        # Base-class entries are merged last, as in a dict-unpacking literal.
        config.update(super().get_config())
        return config
class SentenceEncoder(tf.keras.Model):
    """Sentence encoder mapping padded token ids to a unit-length vector.

    Pipeline: token + position embeddings -> two EncoderBlocks -> final
    LayerNorm -> learned attention pooling over non-pad positions -> linear
    projection -> L2 normalisation.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of token/position embeddings and encoder blocks.
        latent_dim: Width of the returned sentence vector.
        max_len: Number of rows in the position embedding table.
        pad_id: Token id treated as padding when pooling.
    """

    def __init__(self, vocab_size, embed_dim=384, latent_dim=384, max_len=128, pad_id=pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        # Forward the configured sizes so the blocks match the embeddings
        # even when non-default values are passed to this constructor.
        self.blocks = [
            EncoderBlock(embed_dim=embed_dim, seq_len=max_len) for _ in range(2)
        ]
        self.attn_pool = layers.Dense(1)
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
        # Plain linear projection (no activation).
        self.latent = layers.Dense(latent_dim, activation=None)
        self.l2norm = L2NormLayer()

    def call(self, x):
        """Encode a batch of token ids ``x`` (batch, length) into unit vectors."""
        positions = tf.range(tf.shape(x)[1])[tf.newaxis, :]
        x_embed = self.embed(x) + self.pos_embed(positions)

        # 1.0 for real tokens, 0.0 for padding.
        mask = tf.cast(tf.not_equal(x, self.pad_id), tf.float32)

        x = x_embed
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)

        # Attention pooling: padded positions get a very negative score so
        # their softmax weight is effectively zero.
        scores = self.attn_pool(x)
        scores = tf.where(tf.equal(mask[..., tf.newaxis], 0), -1e9, scores)
        scores = tf.nn.softmax(scores, axis=1)
        pooled = tf.reduce_sum(x * scores, axis=1)

        latent = self.latent(pooled)
        # Return the L2-normalised latent vector.
        return self.l2norm(latent)
# 3️⃣ Load the model
# ===============================
encoder = SentenceEncoder(vocab_size=vocab_size)
# Build the model: push one all-zero batch through it before loading weights.
_warmup_ids = np.zeros((1, MAX_LEN), dtype=np.int32)
encoder(_warmup_ids)
encoder.load_weights(MODEL_PATH)
# ===============================
# 4️⃣ Vectorization function
# ===============================
def get_sentence_vector(sentence):
    """Encode ``sentence`` with the global encoder and return a unit vector.

    The sentence is tokenized, padded, run through ``encoder`` as a batch of
    one, and the resulting NumPy vector is divided by its Euclidean norm.
    """
    token_ids = encode_sentence(sentence)
    batch = np.array([pad_sentence(token_ids)])
    embedding = encoder(batch).numpy()[0]
    norm = np.linalg.norm(embedding)
    return embedding / norm
# ===============================
# 5️⃣ Find the most similar sentence
# ===============================
def find_most_similar(query, s1, s2, s3):
    """Pick the candidate sentence whose embedding is closest to ``query``.

    Each sentence is embedded with ``get_sentence_vector``; because those
    vectors are unit length, the dot product equals cosine similarity.

    Args:
        query: Sentence to search with.
        s1, s2, s3: Candidate sentences.

    Returns:
        A dict with two entries: "가장 비슷한 문장" (the most similar
        candidate) and "유사도" (its similarity as a Python float). When
        several candidates tie, the first one wins (``np.argmax``).
    """
    candidates = [s1, s2, s3]
    candidate_vectors = np.stack(
        [get_sentence_vector(c) for c in candidates]
    ).astype(np.float32)
    query_vector = get_sentence_vector(query)

    sims = candidate_vectors @ query_vector  # cosine similarity
    # Plain int so list indexing does not depend on NumPy integer types.
    top_idx = int(np.argmax(sims))

    # Keys restored from mis-decoded text to the intended Korean labels.
    return {
        "가장 비슷한 문장": candidates[top_idx],
        "유사도": float(sims[top_idx]),
    }
# ===============================
# 6️⃣ Gradio UI
# ===============================
# Gradio UI: one query box, three candidate boxes, a JSON result panel and a
# button that runs find_most_similar. The user-facing strings were
# mis-decoded (mojibake) and are restored to the intended Korean text.
with gr.Blocks() as demo:
    gr.Markdown("## 🔍 문장 유사도 검색기 (쿼리 1개 + 후보 3개)")
    with gr.Row():
        query_input = gr.Textbox(label="검색할 문장 (Query)", placeholder="여기에 입력")
    with gr.Row():
        s1_input = gr.Textbox(label="검색 후보 1")
        s2_input = gr.Textbox(label="검색 후보 2")
        s3_input = gr.Textbox(label="검색 후보 3")
    output = gr.JSON(label="결과")
    search_btn = gr.Button("가장 비슷한 문장 찾기")
    # The event handler has to be registered inside the Blocks context.
    search_btn.click(
        fn=find_most_similar,
        inputs=[query_input, s1_input, s2_input, s3_input],
        outputs=output,
    )

demo.launch()