"""Gradio front end for the Small GPT service.

Builds a two-tab UI ("Generate" and "Train") whose buttons call into a single
module-level ``SmallGPTService`` instance.
"""

import gradio as gr

from small_gpt.config import SmallGPTConfig
from small_gpt.service import SmallGPTService

# Created once at import time and shared by every UI callback below.
config = SmallGPTConfig()
service = SmallGPTService(config=config)

# Kept flush-left so the heading, paragraph, and bullet list render as separate
# Markdown blocks regardless of how the component treats leading indentation.
INTRO_MARKDOWN = """
# Small GPT Python

A tiny GPT-style language model written in Python from scratch.

- Causal transformer decoder
- Word-level tokenizer
- No external pretrained LLM
- Local CPU training and generation
"""


def generate_text(prompt: str, max_new_tokens: float, temperature: float, top_k: float):
    """Forward a generation request from the UI to the service.

    The numeric widget values are coerced explicitly so the service always
    receives ``int`` / ``float`` arguments of the expected kind.

    Args:
        prompt: Text taken from the "Prompt" textbox.
        max_new_tokens: Value of the "Max New Tokens" slider; passed as ``int``.
        temperature: Value of the "Temperature" slider; passed as ``float``.
        top_k: Value of the "Top-K" slider; passed as ``int``.

    Returns:
        Whatever ``service.generate`` returns. The click handler maps it onto
        two output components (output text and status), so it is presumably a
        2-item sequence -- TODO confirm against ``SmallGPTService.generate``.
    """
    return service.generate(
        prompt=prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
    )


def train_model(extra_text: str, steps: float):
    """Forward a training request from the UI to the service.

    Args:
        extra_text: Text taken from the "Extra Training Text" textbox.
        steps: Value of the "Training Steps" slider; passed as ``int``.

    Returns:
        Whatever ``service.train`` returns; it is shown in the
        "Training Status" textbox.
    """
    return service.train(extra_text=extra_text, steps=int(steps))


def reset_model():
    """Call ``service.reset()`` and return its result for the status textbox."""
    return service.reset()


with gr.Blocks(
    title="Small GPT Python",
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Tab("Generate"):
        prompt_input = gr.Textbox(
            label="Prompt",
            value="User: hello\nAssistant:",
            lines=6,
        )
        with gr.Row():
            max_tokens_input = gr.Slider(10, 180, value=72, step=2, label="Max New Tokens")
            temperature_input = gr.Slider(0.2, 1.3, value=0.75, step=0.05, label="Temperature")
            top_k_input = gr.Slider(1, 20, value=8, step=1, label="Top-K")
        generate_button = gr.Button("Generate", variant="primary")
        output_text = gr.Textbox(label="Output", lines=10)
        output_status = gr.Textbox(label="Status", lines=4)

    with gr.Tab("Train"):
        extra_text_input = gr.Textbox(
            label="Extra Training Text",
            placeholder="Add more local text to continue training the small GPT model.",
            lines=10,
        )
        steps_input = gr.Slider(10, 400, value=120, step=10, label="Training Steps")
        train_button = gr.Button("Train / Continue Training", variant="primary")
        reset_button = gr.Button("Reset Model")
        train_status = gr.Textbox(label="Training Status", lines=6)

    # Event wiring is declared inside the Blocks context, after every
    # component it references has been created.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_input, temperature_input, top_k_input],
        outputs=[output_text, output_status],
    )
    train_button.click(
        fn=train_model,
        inputs=[extra_text_input, steps_input],
        outputs=[train_status],
    )
    # Reset takes no inputs; its result replaces the training status text.
    reset_button.click(
        fn=reset_model,
        outputs=[train_status],
    )


if __name__ == "__main__":
    demo.launch()