| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
| from accelerate import Accelerator |
|
|
| |
# Pick the compute device once; every model and input tensor below is moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
|
|
| |
# Hugging Face Accelerate state object.
# NOTE(review): `accelerator` is not referenced anywhere in the visible code —
# models and tensors are placed with `.to(device)` instead. Consider removing it,
# or deriving `device` from `accelerator.device` so there is one source of truth.
accelerator = Accelerator()
|
|
| |
# Checkpoints passed to `load_model`, one per selectable model.
model_dirs = [
    "muhammadAhmed22/fine_tuned_gpt2",
    "muhammadAhmed22/MiriFurgpt2-recipes",
    "muhammadAhmed22/auhide-chef-gpt-en",
]

# Registries populated by the loading loop; keys are the final "/"-separated
# component of each entry in `model_dirs`.
models = dict()
tokenizers = dict()
|
|
def load_model(model_dir):
    """Load a causal language model and its tokenizer from *model_dir*.

    Weights are loaded as ``float16`` when ``device`` is a CUDA device and as
    ``float32`` otherwise, and the model is moved to ``device``.

    Args:
        model_dir: Identifier or path accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    on_gpu = device.type == "cuda"
    weights_dtype = torch.float16 if on_gpu else torch.float32
    lm = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=weights_dtype)
    tok = AutoTokenizer.from_pretrained(model_dir)

    # Reuse EOS as the pad token when none is defined, so the `padding=True`
    # encoding and the `pad_token_id` passed to generate() have a value.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    return lm.to(device), tok
|
|
| |
# Load every checkpoint in `model_dirs`. A model is registered (and therefore
# offered in the UI) only after both loading and the warm-up generation succeed;
# any failure is logged and that model is skipped so the others still load.
for model_dir in model_dirs:
    model_name = model_dir.split("/")[-1]
    try:
        model, tokenizer = load_model(model_dir)

        # Warm-up pass: a few one-token generations (presumably to absorb
        # first-call setup cost before real requests arrive). The attention
        # mask and pad id are passed explicitly so generate() does not have to
        # guess them.
        dummy_inputs = ["Hello", "What is a recipe?", "Explain cooking basics"]
        for dummy_input in dummy_inputs:
            warmup = tokenizer(dummy_input, return_tensors="pt").to(device)
            with torch.no_grad():
                model.generate(
                    input_ids=warmup["input_ids"],
                    attention_mask=warmup["attention_mask"],
                    max_new_tokens=1,
                    pad_token_id=tokenizer.pad_token_id,
                )
    except Exception as e:
        # Best-effort startup: report and move on to the next checkpoint.
        print(f"Failed to load model from {model_dir}: {e}")
    else:
        models[model_name] = model
        tokenizers[model_name] = tokenizer
        print(f"Loaded model and tokenizer from {model_dir}.")
|
|
def get_response(prompt, model_name, user_type):
    """Generate an answer to *prompt* with the model registered as *model_name*.

    The prompt is wrapped in a persona template selected by *user_type*.
    Unrecognised skill levels fall back to a plain ``"<prompt>\\nAnswer:"``
    template.

    Args:
        prompt: The user's cooking question.
        model_name: Key into the ``models`` / ``tokenizers`` registries.
        user_type: One of "Professional", "Beginner", "Intermediate", "Expert";
            any other value selects the plain template.

    Returns:
        Only the newly generated text (the prompt is not echoed back), stripped
        of surrounding whitespace. Returns "Model not loaded correctly." if
        *model_name* is not registered.
    """
    if model_name not in models:
        return "Model not loaded correctly."

    model = models[model_name]
    tokenizer = tokenizers[model_name]

    user_type_templates = {
        "Professional": f"As a professional chef, {prompt}\nAnswer:",
        "Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
        "Intermediate": f"As an intermediate cook, {prompt}\nAnswer:",
        "Expert": f"As an expert chef, {prompt}\nAnswer:",
    }
    prompt_template = user_type_templates.get(user_type, f"{prompt}\nAnswer:")

    # NOTE(review): with the tokenizer's default truncation side, a prompt longer
    # than 512 tokens may lose its trailing "\nAnswer:" cue — confirm
    # `truncation_side` if long prompts matter.
    encoding = tokenizer(
        prompt_template,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    max_new_tokens = 200

    with torch.no_grad():
        output = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_new_tokens=max_new_tokens,
            # temperature/top_p only take effect when sampling is enabled;
            # without do_sample=True generation is greedy and they are ignored.
            do_sample=True,
            num_beams=1,
            repetition_penalty=1.1,
            temperature=0.7,
            top_p=0.85,
            # early_stopping is omitted: it only applies to beam search.
            pad_token_id=tokenizer.pad_token_id,
        )

    # For causal LMs generate() returns prompt + continuation; decode only the
    # continuation so the internal persona template is not shown to the user.
    prompt_length = encoding["input_ids"].shape[1]
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
|
|
def process_input(prompt, model_name, user_type):
    """Gradio click handler: reject empty input, otherwise delegate to get_response.

    A prompt that is ``None``, empty, or whitespace-only gets a fixed message.
    Anything else is forwarded unchanged, and get_response's result is returned.
    """
    is_blank = not prompt or not prompt.strip()
    if is_blank:
        return "Please provide a valid prompt."
    return get_response(prompt, model_name, user_type)
|
|
| |
# Build the Gradio UI: a question box, model and skill-level pickers, a submit
# button and a read-only response box, wired to process_input().
with gr.Blocks(css="""
body {
    background-color: #f8f8f8;
    font-family: 'Helvetica Neue', Arial, sans-serif;
}
.title {
    font-size: 2.6rem;
    font-weight: 700;
    color: #ff6347;
    text-align: center;
    margin-bottom: 1.5rem;
}
.container {
    max-width: 800px;
    margin: auto;
    padding: 2rem;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 12px 24px rgba(0, 0, 0, 0.1);
}
.button {
    background-color: #ff6347;
    color: white;
    padding: 0.8rem 1.8rem;
    font-size: 1.1rem;
    border: none;
    border-radius: 8px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    margin-top: 1.5rem;
    width: 100%;
}
.button:hover {
    background-color: #ff4500;
}
.gradio-interface .gr-textbox {
    margin-bottom: 1.5rem;
    width: 100%;
    border-radius: 8px;
    padding: 1rem;
    border: 1px solid #ddd;
    font-size: 1rem;
    background-color: #f9f9f9;
    color: #333;
}
.gradio-interface .gr-radio, .gradio-interface .gr-dropdown {
    margin-bottom: 1.5rem;
    width: 100%;
    border-radius: 8px;
    padding: 1rem;
    border: 1px solid #ddd;
    background-color: #f9f9f9;
    font-size: 1rem;
    color: #333;
}
.gradio-interface .gr-textbox[readonly] {
    background-color: #f5f5f5;
    color: #333;
    font-size: 1rem;
}
""") as demo:

    # Page header, styled by the `.title` CSS rule.
    gr.Markdown("<div class='title'>Cookspert: Your Personal AI Chef</div>")

    # Skill levels; each matches a prompt template key in get_response().
    user_types = ["Professional", "Beginner", "Intermediate", "Expert"]

    with gr.Column(scale=1, min_width=350):
        prompt = gr.Textbox(
            label="Enter Your Cooking Question",
            placeholder="What would you like to ask?",
            lines=3,
        )

        # Only successfully loaded models are listed. Pre-select the first one
        # so submitting without touching the radio doesn't send None (which
        # get_response() reports as "Model not loaded correctly.").
        model_name = gr.Radio(
            label="Choose Model",
            choices=list(models.keys()),
            value=next(iter(models), None),
            interactive=True,
        )

        # The default must be one of `choices`: a value outside them (such as
        # "Home Cook") silently falls back to get_response()'s untemplated
        # prompt, and Gradio may warn about it.
        user_type = gr.Dropdown(
            label="Select Your Skill Level",
            choices=user_types,
            value="Beginner",
        )

        submit_button = gr.Button("chef gpt", elem_classes="button")

        response = gr.Textbox(
            label="Response",
            placeholder="Your answer will appear here...",
            lines=15,
            interactive=False,
            show_copy_button=True,
            max_lines=20,
        )

    submit_button.click(fn=process_input, inputs=[prompt, model_name, user_type], outputs=response)
|
|
# Entry point: serve the UI only when this file is executed directly.
if __name__ == "__main__":
    # NOTE(review): binding 0.0.0.0 together with share=True exposes the app
    # beyond localhost — confirm this is intended outside demos.
    demo.launch(
        debug=True,
        share=True,
        server_name="0.0.0.0",
    )