import gradio as gr
from huggingface_hub import InferenceClient
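
# Gradio chat UI for a Llama 3.2 3B model fine-tuned on FineTome-100k and
# served through the Hugging Face Inference API (KTH ID2223 Lab 2).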

MODEL_ID = "Marcus719/Llama-3.2-3B-Instruct-Lab2"

client = InferenceClient(model=MODEL_ID)
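# Note: the client calls the serverless Inference API anonymously; a gated or
# private model would also need a token, e.g. InferenceClient(model=MODEL_ID,
# token=...) with a token supplied via a Space secret (setup is illustrative).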


def chat(message, history, system_message, max_tokens, temperature, top_p):
    """Stream a response from the model via the Hugging Face Inference API."""
    # Rebuild the conversation in the OpenAI-style message format expected by
    # the chat-completion endpoint: system prompt first, then prior turns.
    messages = [{"role": "system", "content": system_message}]

    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    # Stream tokens, yielding the accumulated text so the UI updates live.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            token = chunk.choices[0].delta.content
            response += token
            yield response
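
# Rough local smoke test for the generator above, kept commented out so the
# remote API isn't called at import time (the prompt is illustrative):
#
#     for partial in chat("Hello!", [], DEFAULT_SYSTEM_PROMPT, 64, 0.7, 0.95):
#         print(partial)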


DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."


with gr.Blocks(theme=gr.themes.Soft(), title="🦙 Llama 3.2 ChatBot") as demo:
    gr.Markdown(
        """
        # 🦙 Llama 3.2 3B Instruct - Fine-tuned on FineTome

        **KTH ID2223 Scalable Machine Learning - Lab 2**

        This chatbot uses my fine-tuned Llama 3.2 3B model trained on the FineTome-100k dataset.

        📦 Model: [Marcus719/Llama-3.2-3B-Instruct-Lab2](https://huggingface.co/Marcus719/Llama-3.2-3B-Instruct-Lab2)
        """
    )

    chatbot = gr.Chatbot(label="Chat", height=450, show_copy_button=True)
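    # History uses Gradio's classic [user, assistant] pair format; newer Gradio
    # releases also support an OpenAI-style format via type="messages".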

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            scale=4,
            container=False,
            autofocus=True,
        )
        submit_btn = gr.Button("Send 🚀", scale=1, variant="primary")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=2,
        )
        with gr.Row():
            max_tokens = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
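            # Temperature rescales token probabilities (lower = more focused);
            # top-p (nucleus) sampling keeps the smallest token set whose
            # cumulative probability reaches p.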

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat")
        retry_btn = gr.Button("🔄 Regenerate")

    gr.Examples(
        examples=[
            "Hello! Can you introduce yourself?",
            "Explain machine learning in simple terms.",
            "What is the difference between fine-tuning and pre-training?",
            "Write a short poem about AI.",
        ],
        inputs=msg,
        label="💡 Try these examples",
    )

    def user_input(message, history):
        # Append the new user turn with an empty reply slot and clear the textbox.
        return "", history + [[message, None]]

    def bot_response(history, system_prompt, max_tokens, temperature, top_p):
        # Stream the assistant's reply into the last history entry.
        if not history:
            yield history
            return
        message = history[-1][0]
        history_for_model = history[:-1]
        for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
            history[-1][1] = response
            yield history

    def retry_last(history, system_prompt, max_tokens, temperature, top_p):
        # Discard the last reply and regenerate it for the same user message.
        if not history:
            yield history
            return
        history[-1][1] = None
        message = history[-1][0]
        history_for_model = history[:-1]
        for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
            history[-1][1] = response
            yield history

    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot
    )
    submit_btn.click(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)
    retry_btn.click(retry_last, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot)
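
    # Streaming responses rely on Gradio's queue; on older (3.x) Gradio
    # versions an explicit demo.queue() before launch() may be required.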

    gr.Markdown(
        """
        ---
        ### 📝 About This Project

        **Fine-tuning Details:**
        - Base Model: `meta-llama/Llama-3.2-3B-Instruct`
        - Dataset: [FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
        - Method: QLoRA (4-bit quantization + LoRA)
        - Framework: [Unsloth](https://github.com/unslothai/unsloth)

        Built with ❤️ for KTH ID2223 Lab 2
        """
    )


if __name__ == "__main__":
    demo.launch()
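
# Run locally with `python app.py`; a Hugging Face Space using the Gradio SDK
# executes this file and serves `demo` automatically.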