Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from small_gpt.config import SmallGPTConfig | |
| from small_gpt.service import SmallGPTService | |
# A single configuration object and a single model service are created once,
# at import time; every UI callback below goes through this shared `service`.
config = SmallGPTConfig()
service = SmallGPTService(config=config)
def generate_text(prompt, max_new_tokens, temperature, top_k):
    """Generate a continuation of *prompt* using the shared model service.

    The numeric settings arrive from UI sliders. They are normalised before
    being handed to ``service.generate``: token counts become ``int`` and the
    temperature becomes ``float``. The service's result is returned unchanged.
    """
    # Coerce in the same order the original call did: tokens, temperature, top-k.
    sampling_options = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "top_k": int(top_k),
    }
    return service.generate(prompt=prompt, **sampling_options)
def train_model(extra_text, steps):
    """Continue training the shared model via ``service.train``.

    ``extra_text`` is forwarded as-is; it may be an empty string if the
    textbox is left blank. ``steps`` comes from a slider and is coerced to
    ``int``. The service's return value is passed straight back to the UI.
    """
    step_count = int(steps)
    return service.train(extra_text=extra_text, steps=step_count)
def reset_model():
    """Delegate to ``service.reset()`` and return whatever it reports."""
    reset_result = service.reset()
    return reset_result
# Markdown header shown at the top of the app.
_INTRO_MARKDOWN = """
# Small GPT Python
A tiny GPT-style language model written in Python from scratch.
- Causal transformer decoder
- Word-level tokenizer
- No external pretrained LLM
- Local CPU training and generation
"""

# Two-tab Gradio UI: "Generate" samples from the model, "Train" continues
# training it or resets it. Each button is wired inside the tab that owns
# its components; listeners are registered in the order generate, train, reset.
with gr.Blocks(
    title="Small GPT Python",
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Tab("Generate"):
        prompt_input = gr.Textbox(
            label="Prompt",
            value="User: hello\nAssistant:",
            lines=6,
        )
        with gr.Row():
            max_tokens_input = gr.Slider(
                10, 180, value=72, step=2, label="Max New Tokens"
            )
            temperature_input = gr.Slider(
                0.2, 1.3, value=0.75, step=0.05, label="Temperature"
            )
            top_k_input = gr.Slider(1, 20, value=8, step=1, label="Top-K")
        generate_button = gr.Button("Generate", variant="primary")
        output_text = gr.Textbox(label="Output", lines=10)
        output_status = gr.Textbox(label="Status", lines=4)

        # generate_text's result fills the Output and Status boxes, in that order.
        generate_button.click(
            fn=generate_text,
            inputs=[prompt_input, max_tokens_input, temperature_input, top_k_input],
            outputs=[output_text, output_status],
        )

    with gr.Tab("Train"):
        extra_text_input = gr.Textbox(
            label="Extra Training Text",
            placeholder="Add more local text to continue training the small GPT model.",
            lines=10,
        )
        steps_input = gr.Slider(
            10, 400, value=120, step=10, label="Training Steps"
        )
        train_button = gr.Button("Train / Continue Training", variant="primary")
        reset_button = gr.Button("Reset Model")
        train_status = gr.Textbox(label="Training Status", lines=6)

        # Training and resetting both report into the same status box.
        train_button.click(
            fn=train_model,
            inputs=[extra_text_input, steps_input],
            outputs=[train_status],
        )
        reset_button.click(
            fn=reset_model,
            outputs=[train_status],
        )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script;
    # a plain import builds `demo` but does not launch it.
    demo.launch()