abersbail's picture
Add small GPT Python Space
79078fe verified
import gradio as gr
from small_gpt.config import SmallGPTConfig
from small_gpt.service import SmallGPTService
# Module-level objects created once at import time and shared by every
# Gradio callback defined below.
# The configuration is built with no arguments, i.e. the package defaults.
config = SmallGPTConfig()
# Single service instance that the generate/train/reset callbacks call into.
service = SmallGPTService(config=config)
def generate_text(prompt, max_new_tokens, temperature, top_k):
    """Callback for the "Generate" button.

    Converts the raw widget values to the numeric types handed to
    ``service.generate`` and returns that call's result unchanged.

    Args:
        prompt: Text taken from the prompt box; forwarded as-is.
        max_new_tokens: Slider value, converted with ``int``.
        temperature: Slider value, converted with ``float``.
        top_k: Slider value, converted with ``int``.

    Returns:
        Whatever ``service.generate`` returns. The click wiring in this
        file routes it to two output components, so it is presumably a
        two-item result -- confirm against ``SmallGPTService.generate``.
    """
    token_budget = int(max_new_tokens)
    sampling_temperature = float(temperature)
    candidate_pool = int(top_k)
    return service.generate(
        prompt=prompt,
        max_new_tokens=token_budget,
        temperature=sampling_temperature,
        top_k=candidate_pool,
    )
def train_model(extra_text, steps):
    """Callback for the "Train / Continue Training" button.

    Args:
        extra_text: Text from the extra-training-text box; forwarded as-is.
        steps: Slider value, converted with ``int`` before being passed on.

    Returns:
        The result of ``service.train``, which the click wiring in this
        file sends to the training-status box.
    """
    step_count = int(steps)
    return service.train(extra_text=extra_text, steps=step_count)
def reset_model():
    """Callback for the "Reset Model" button.

    Returns:
        The result of ``service.reset``, which the click wiring in this
        file sends to the training-status box.
    """
    outcome = service.reset()
    return outcome
# Introductory copy rendered as Markdown at the top of the page.
_INTRO_MARKDOWN = """
# Small GPT Python
A tiny GPT-style language model written in Python from scratch.
- Causal transformer decoder
- Word-level tokenizer
- No external pretrained LLM
- Local CPU training and generation
"""

# Visual theme applied to the whole app.
_APP_THEME = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")

# NOTE(review): the nesting of the ``with`` blocks below was reconstructed
# from Gradio conventions (sliders share one row; buttons and result boxes
# sit directly in their tab; event wiring stays inside the Blocks context).
# Confirm against the intended layout.
with gr.Blocks(title="Small GPT Python", theme=_APP_THEME) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # ---- "Generate" tab: prompt, sampling controls, and result boxes ----
    with gr.Tab("Generate"):
        prompt_input = gr.Textbox(
            label="Prompt",
            value="User: hello\nAssistant:",
            lines=6,
        )
        with gr.Row():
            max_tokens_input = gr.Slider(
                minimum=10, maximum=180, value=72, step=2, label="Max New Tokens"
            )
            temperature_input = gr.Slider(
                minimum=0.2, maximum=1.3, value=0.75, step=0.05, label="Temperature"
            )
            top_k_input = gr.Slider(
                minimum=1, maximum=20, value=8, step=1, label="Top-K"
            )
        generate_button = gr.Button(value="Generate", variant="primary")
        output_text = gr.Textbox(label="Output", lines=10)
        output_status = gr.Textbox(label="Status", lines=4)

    # ---- "Train" tab: extra corpus text, step count, and status box ----
    with gr.Tab("Train"):
        extra_text_input = gr.Textbox(
            label="Extra Training Text",
            placeholder="Add more local text to continue training the small GPT model.",
            lines=10,
        )
        steps_input = gr.Slider(
            minimum=10, maximum=400, value=120, step=10, label="Training Steps"
        )
        train_button = gr.Button(value="Train / Continue Training", variant="primary")
        reset_button = gr.Button(value="Reset Model")
        train_status = gr.Textbox(label="Training Status", lines=6)

    # ---- Event wiring: connect each button to its callback ----
    # The generate callback's result is spread over two output components.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_input, temperature_input, top_k_input],
        outputs=[output_text, output_status],
    )
    train_button.click(
        fn=train_model,
        inputs=[extra_text_input, steps_input],
        outputs=[train_status],
    )
    # Reset takes no inputs; its result replaces the training status text.
    reset_button.click(fn=reset_model, outputs=[train_status])
# Start the Gradio server only when this file is run as a script, so that
# importing the module (e.g. to reuse `demo`) does not launch anything.
if __name__ == "__main__":
    demo.launch()