| """ | |
| Gradio Chatbot Interface for CGT-LLM-Beta RAG System | |
| This application provides a web interface for the RAG chatbot, allowing users to: | |
| - Select different LLM models from a dropdown | |
| - Choose education level for personalized answers (Middle School, High School, Professional, Improved) | |
| - View answers with Flesch-Kincaid grade level scores | |
| - See source documents and similarity scores for every answer | |
| Usage: | |
| python app.py | |
| IMPORTANT: Before using, update the MODEL_MAP dictionary with correct HuggingFace paths | |
| for models that currently have placeholder paths (Llama-4-Scout, MediPhi, Phi-4-reasoning). | |
| For Hugging Face Spaces: | |
| - Ensure vector database is built (run bot.py with indexing first) | |
| - Model will be loaded on startup | |
| - Access via the Gradio interface | |
| """ | |
import argparse
import logging
import os
import re
import sys
import time
from pathlib import Path
from typing import Tuple

import gradio as gr
import torch

# Import from bot.py
from bot import RAGBot

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
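# This module expects bot.py to provide a RAGBot class exposing the interface
# used below: _load_model(), _categorize_question(), retrieve_with_scores(),
# format_prompt(), generate_answer(), enhance_readability(), plus an `args`
# namespace, `model`/`tokenizer` attributes, and a `vector_retriever` with
# get_collection_stats().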
# Model mapping: short name -> full HuggingFace path
MODEL_MAP = {
    "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
    "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "MediPhi-Instruct": "microsoft/MediPhi-Instruct",
    "MediPhi": "microsoft/MediPhi",
    "Phi-4-reasoning": "microsoft/Phi-4-reasoning",
}
# Education level mapping: UI label -> internal key
EDUCATION_LEVELS = {
    "Middle School": "middle_school",
    "High School": "high_school",
    "College": "college",
    "Doctoral": "doctoral",
}
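# The values above are passed verbatim as the target_level argument to
# RAGBot.enhance_readability(), so they must match the levels bot.py accepts.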
# Example questions from the results CSV (hardcoded for easy access)
EXAMPLE_QUESTIONS = [
    "Can a BRCA2 variant skip a generation?",
    "Can a PMS2 variant skip a generation?",
    "Can an EPCAM/MSH2 variant skip a generation?",
    "Can an MLH1 variant skip a generation?",
    "Can an MSH2 variant skip a generation?",
    "Can an MSH6 variant skip a generation?",
    "Can I pass this MSH2 variant to my kids?",
    "Can only women carry a BRCA inherited mutation?",
    "Does GINA cover life or disability insurance?",
    "Does having a BRCA1 mutation mean I will definitely have cancer?",
    "Does having a BRCA2 mutation mean I will definitely have cancer?",
    "Does having a PMS2 mutation mean I will definitely have cancer?",
    "Does having an EPCAM/MSH2 mutation mean I will definitely have cancer?",
    "Does having an MLH1 mutation mean I will definitely have cancer?",
    "Does having an MSH2 mutation mean I will definitely have cancer?",
    "Does having an MSH6 mutation mean I will definitely have cancer?",
    "Does this BRCA1 genetic variant affect my cancer treatment?",
    "Does this BRCA2 genetic variant affect my cancer treatment?",
    "Does this EPCAM/MSH2 genetic variant affect my cancer treatment?",
    "Does this MLH1 genetic variant affect my cancer treatment?",
    "Does this MSH2 genetic variant affect my cancer treatment?",
    "Does this MSH6 genetic variant affect my cancer treatment?",
    "Does this PMS2 genetic variant affect my cancer treatment?",
    "How can I cope with this diagnosis?",
    "How can I get my kids tested?",
    "How can I help others with my condition?",
    "How might my genetic test results change over time?",
    "I don't talk to my family/parents/sister/brother. How can I share this with them?",
    "I have a BRCA pathogenic variant and I want to have children, what are my options?",
    "Is genetic testing for my family members covered by insurance?",
    "Is new research being done on my condition?",
    "Is this BRCA1 variant something I inherited?",
    "Is this BRCA2 variant something I inherited?",
    "Is this EPCAM/MSH2 variant something I inherited?",
    "Is this MLH1 variant something I inherited?",
    "Is this MSH2 variant something I inherited?",
    "Is this MSH6 variant something I inherited?",
    "Is this PMS2 variant something I inherited?",
    "My relative doesn't have insurance. What should they do?",
    "People who test positive for a genetic mutation are they at risk of losing their health insurance?",
    "Should I contact my male and female relatives?",
    "Should my family members get tested?",
    "What are the Risks and Benefits of Risk-Reducing Surgeries for Lynch Syndrome?",
    "What are the recommendations for my family members if I have a BRCA1 mutation?",
    "What are the recommendations for my family members if I have a BRCA2 mutation?",
    "What are the recommendations for my family members if I have a PMS2 mutation?",
    "What are the recommendations for my family members if I have an EPCAM/MSH2 mutation?",
    "What are the recommendations for my family members if I have an MLH1 mutation?",
    "What are the recommendations for my family members if I have an MSH2 mutation?",
    "What are the recommendations for my family members if I have an MSH6 mutation?",
    "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have a BRCA mutation?",
    "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an EPCAM/MSH2 mutation?",
    "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an MSH2 mutation?",
    "What does a BRCA1 genetic variant mean for me?",
    "What does a BRCA2 genetic variant mean for me?",
    "What does a PMS2 genetic variant mean for me?",
    "What does an EPCAM/MSH2 genetic variant mean for me?",
    "What does an MLH1 genetic variant mean for me?",
    "What does an MSH2 genetic variant mean for me?",
    "What does an MSH6 genetic variant mean for me?",
    "What if I feel overwhelmed?",
    "What if I want to have children and have a hereditary cancer gene? What are my reproductive options?",
    "What if a family member doesn't want to get tested?",
    "What is Lynch Syndrome?",
    "What is my cancer risk if I have BRCA1 Hereditary Breast and Ovarian Cancer syndrome?",
    "What is my cancer risk if I have BRCA2 Hereditary Breast and Ovarian Cancer syndrome?",
    "What is my cancer risk if I have MLH1 Lynch syndrome?",
    "What is my cancer risk if I have MSH2 or EPCAM-associated Lynch syndrome?",
    "What is my cancer risk if I have MSH6 Lynch syndrome?",
    "What is my cancer risk if I have PMS2 Lynch syndrome?",
    "What other resources are available to help me?",
    "What screening tests do you recommend for BRCA1 carriers?",
    "What screening tests do you recommend for BRCA2 carriers?",
    "What screening tests do you recommend for EPCAM/MSH2 carriers?",
    "What screening tests do you recommend for MLH1 carriers?",
    "What screening tests do you recommend for MSH2 carriers?",
    "What screening tests do you recommend for MSH6 carriers?",
    "What screening tests do you recommend for PMS2 carriers?",
    "What steps can I take to manage my cancer risk if I have Lynch syndrome?",
    "What types of cancers am I at risk for with a BRCA1 mutation?",
    "What types of cancers am I at risk for with a BRCA2 mutation?",
    "What types of cancers am I at risk for with a PMS2 mutation?",
    "What types of cancers am I at risk for with an EPCAM/MSH2 mutation?",
    "What types of cancers am I at risk for with an MLH1 mutation?",
    "What types of cancers am I at risk for with an MSH2 mutation?",
    "What types of cancers am I at risk for with an MSH6 mutation?",
    "Where can I find a genetic counselor?",
    "Which of my relatives are at risk?",
    "Who are my first-degree relatives?",
    "Who do my family members call to have genetic testing?",
    "Why do some families with Lynch syndrome have more cases of cancer than others?",
    "Why should I share my BRCA1 genetic results with family?",
    "Why should I share my BRCA2 genetic results with family?",
    "Why should I share my EPCAM/MSH2 genetic results with family?",
    "Why should I share my MLH1 genetic results with family?",
    "Why should I share my MSH2 genetic results with family?",
    "Why should I share my MSH6 genetic results with family?",
    "Why should I share my PMS2 genetic results with family?",
    "Why would my relatives want to know if they have this? What can they do about it?",
    "Will my insurance cover testing for my parents/brother/sister?",
    "Will this affect my health insurance?",
]
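# NOTE: EXAMPLE_QUESTIONS mirrors the evaluation results CSV; if that CSV
# changes, this hardcoded list should be updated to match.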
class GradioRAGInterface:
    """Wrapper class to integrate RAGBot with Gradio"""

    def __init__(self, initial_bot: RAGBot):
        self.bot = initial_bot
        self.current_model = initial_bot.args.model
        self.data_dir = initial_bot.args.data_dir
        logger.info("GradioRAGInterface initialized")

    def _find_file_path(self, filename: str) -> str:
        """Find the full file path for a given filename"""
        data_path = Path(self.data_dir)
        if not data_path.exists():
            return ""
        # Search for the file recursively; the first match wins
        for file_path in data_path.rglob(filename):
            return str(file_path)
        return ""
    def reload_model(self, model_short_name: str) -> str:
        """Reload the model when the user selects a different one"""
        if model_short_name not in MODEL_MAP:
            return f"Error: Unknown model '{model_short_name}'"
        new_model_path = MODEL_MAP[model_short_name]
        # If it is the same model, there is no need to reload
        if new_model_path == self.current_model:
            return f"Model already loaded: {model_short_name}"
        try:
            logger.info(f"Reloading model from {self.current_model} to {new_model_path}")
            # Update args
            self.bot.args.model = new_model_path
            # Release the old model and tokenizer before loading the new one
            if self.bot.model is not None:
                self.bot.model = None
                self.bot.tokenizer = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            # Load the new model
            self.bot._load_model()
            self.current_model = new_model_path
            return f"✓ Model loaded: {model_short_name}"
        except Exception as e:
            logger.error(f"Error reloading model: {e}", exc_info=True)
            return f"✗ Error loading model: {str(e)}"
    def process_question(
        self,
        question: str,
        model_name: str,
        education_level: str,
        k: int,
        temperature: float,
        max_tokens: int,
    ) -> Tuple[str, str, str, str, str]:
        """
        Process a single question and return formatted results

        Returns:
            Tuple of (answer, flesch_score, sources, similarity_scores, question_category)
        """
        if not question or not question.strip():
            return "Please enter a question.", "N/A", "", "", ""
        try:
            start_time = time.time()
            logger.info(f"Processing question: {question[:50]}...")
            # Reload the model if it changed (this can take 1-3 minutes)
            if model_name in MODEL_MAP:
                model_path = MODEL_MAP[model_name]
                if model_path != self.current_model:
                    logger.info(f"Model changed, reloading from {self.current_model} to {model_path}")
                    reload_status = self.reload_model(model_name)
                    if reload_status.startswith("✗"):
                        return f"Error: {reload_status}", "N/A", "", "", ""
                    logger.info(f"Model reloaded in {time.time() - start_time:.1f}s")
            # Update bot args for this query
            self.bot.args.k = k
            self.bot.args.temperature = temperature
            # Cap max_tokens at 512 for faster generation in Gradio
            self.bot.args.max_new_tokens = min(max_tokens, 512)
            # Categorize the question
            logger.info("Categorizing question...")
            question_group = self.bot._categorize_question(question)
            # Retrieve relevant chunks with similarity scores
            logger.info("Retrieving relevant documents...")
            retrieve_start = time.time()
            context_chunks, similarity_scores = self.bot.retrieve_with_scores(question, k)
            logger.info(f"Retrieved {len(context_chunks)} chunks in {time.time() - retrieve_start:.2f}s")
            if not context_chunks:
                return (
                    "I don't have enough information to answer this question. "
                    "Please try rephrasing or asking about a different topic.",
                    "N/A",
                    "No sources found",
                    "No matches found",
                    question_group,
                )
            # Format similarity scores
            similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores])
            # Format sources with chunk text and file paths
            sources_list = []
            for i, (chunk, score) in enumerate(zip(context_chunks, similarity_scores)):
                # Try to find the file path
                file_path = self._find_file_path(chunk.filename)
                source_info = f"""
{'=' * 80}
SOURCE {i + 1} | Similarity: {score:.3f}
{'=' * 80}
📄 File: {chunk.filename}
📍 Path: {file_path if file_path else 'File path not found (search in Data Resources directory)'}
📊 Chunk: {chunk.chunk_id + 1}/{chunk.total_chunks} (Position: {chunk.start_pos}-{chunk.end_pos})
📝 Full Chunk Text:
{chunk.text}
"""
                sources_list.append(source_info)
            sources = "\n".join(sources_list)
            # Generation kwargs (max_new_tokens capped at 512 for faster responses)
            gen_kwargs = {
                'max_new_tokens': min(max_tokens, 512),
                'temperature': temperature,
                'top_p': self.bot.args.top_p,
                'repetition_penalty': self.bot.args.repetition_penalty,
            }
            # Generate the original answer first (needed for all enhancement levels)
            logger.info("Generating original answer...")
            gen_start = time.time()
            prompt = self.bot.format_prompt(context_chunks, question)
            original_answer = self.bot.generate_answer(prompt, **gen_kwargs)
            logger.info(f"Original answer generated in {time.time() - gen_start:.1f}s")
            # Simplify or enhance the answer to the requested education level
            logger.info(f"Enhancing answer for {education_level} level...")
            enhance_start = time.time()
            if education_level in EDUCATION_LEVELS.values():
                answer, flesch_score = self.bot.enhance_readability(
                    original_answer, target_level=education_level
                )
            else:
                answer = "Invalid education level selected."
                flesch_score = 0.0
            logger.info(f"Answer enhanced in {time.time() - enhance_start:.1f}s")
            total_time = time.time() - start_time
            logger.info(f"Total processing time: {total_time:.1f}s")
            # Clean the answer - remove special tokens and formatting
            cleaned_answer = answer
            # Special tokens to strip (case-insensitive)
            special_tokens = [
                "<|end|>",
                "<|endoftext|>",
                "<|end_of_text|>",
                "<|eot_id|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "<|assistant|>",
            ]
            for token in special_tokens:
                cleaned_answer = re.sub(re.escape(token), '', cleaned_answer, flags=re.IGNORECASE)
            # Remove any remaining special token patterns like <|...|>
            cleaned_answer = re.sub(r'<\|[^|]+\|>', '', cleaned_answer)
            # Remove any markdown-style headers that might have been added
            cleaned_answer = re.sub(r'^\*\*.*?\*\*.*?\n', '', cleaned_answer, flags=re.MULTILINE)
            # Collapse runs of blank lines into a single blank line
            cleaned_answer = re.sub(r'\n\s*\n\s*\n+', '\n\n', cleaned_answer)
            # Trim whitespace at the start and end of each line
            cleaned_answer = re.sub(r'^\s+|\s+$', '', cleaned_answer, flags=re.MULTILINE)
            cleaned_answer = cleaned_answer.strip()
            # Return just the clean answer (no headers or metadata)
            return (
                cleaned_answer,
                f"{flesch_score:.1f}",
                sources,
                similarity_scores_str,
                question_group,
            )
        except Exception as e:
            logger.error(f"Error processing question: {e}", exc_info=True)
            return (
                f"An error occurred while processing your question: {str(e)}",
                "N/A",
                "",
                "",
                "Error",
            )
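# NOTE: process_question mutates shared state (self.bot.args and the loaded
# model), so concurrent Gradio requests are not isolated from one another.
# If the app is served with multiple concurrent workers, it is safest to
# serialize requests (for example, with a threading.Lock around this method).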
def create_interface(initial_bot: RAGBot) -> gr.Blocks:
    """Create and configure the Gradio interface"""
    interface = GradioRAGInterface(initial_bot)
    # Map the bot's full model path back to its short name for the dropdown
    initial_model_short = None
    for short_name, full_path in MODEL_MAP.items():
        if full_path == initial_bot.args.model:
            initial_model_short = short_name
            break
    if initial_model_short is None:
        initial_model_short = list(MODEL_MAP.keys())[0]

    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown("""
        # 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot

        Ask questions about genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.
        The chatbot uses a Retrieval-Augmented Generation (RAG) system to provide evidence-based answers from medical literature.
        """)
        with gr.Row():
            with gr.Column(scale=2):
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="e.g., What is Lynch Syndrome? What screening is recommended for BRCA1 carriers?",
                    lines=3
                )
                with gr.Row():
                    model_dropdown = gr.Dropdown(
                        choices=list(MODEL_MAP.keys()),
                        value=initial_model_short,
                        label="Select Model",
                        info="Choose which LLM model to use for generating answers"
                    )
                    education_dropdown = gr.Dropdown(
                        choices=list(EDUCATION_LEVELS.keys()),
                        value=list(EDUCATION_LEVELS.keys())[0],
                        label="Education Level",
                        info="Select your education level for personalized answers"
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of document chunks to retrieve (k)"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        label="Temperature (lower = more focused)"
                    )
                    max_tokens_slider = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=128,
                        label="Max Tokens (lower = faster responses)"
                    )
                submit_btn = gr.Button("Ask Question", variant="primary", size="lg")
            with gr.Column(scale=3):
                answer_output = gr.Textbox(
                    label="Answer",
                    lines=20,
                    interactive=False,
                    elem_classes=["answer-box"]
                )
                with gr.Row():
                    flesch_output = gr.Textbox(
                        label="Flesch-Kincaid Grade Level",
                        value="N/A",
                        interactive=False,
                        scale=1
                    )
                    similarity_output = gr.Textbox(
                        label="Similarity Scores",
                        value="",
                        interactive=False,
                        scale=1
                    )
                    category_output = gr.Textbox(
                        label="Question Category",
                        value="",
                        interactive=False,
                        scale=1
                    )
                sources_output = gr.Textbox(
                    label="Source Documents (with Chunk Text)",
                    lines=15,
                    interactive=False,
                    info="Shows the retrieved document chunks with full text. File paths are shown for easy access."
                )
        # Example questions from the results CSV, presented in a scrollable dropdown
        gr.Markdown("### 💡 Example Questions")
        gr.Markdown(f"Select a question below to use it in the chatbot ({len(EXAMPLE_QUESTIONS)} questions - scrollable dropdown):")
        example_questions_dropdown = gr.Dropdown(
            choices=EXAMPLE_QUESTIONS,
            label="Example Questions",
            value=None,
            info="Open the dropdown and scroll through all questions. Select one to use it.",
            interactive=True,
            container=True,
            scale=1
        )

        # Copy the dropdown selection into the question input
        def update_question_from_dropdown(selected_question):
            return selected_question if selected_question else ""

        example_questions_dropdown.change(
            fn=update_question_from_dropdown,
            inputs=example_questions_dropdown,
            outputs=question_input
        )
        # Footer
        gr.Markdown("""
        ---
        **Note:** This chatbot provides informational answers based on medical literature.
        It is not a substitute for professional medical advice, diagnosis, or treatment.
        Always consult with qualified healthcare providers for medical decisions.
        """)

        # Map the education label to its internal key before processing
        def process_with_education_level(question, model, education, k, temp, max_tok):
            education_key = EDUCATION_LEVELS[education]
            return interface.process_question(question, model, education_key, k, temp, max_tok)

        # Both the submit button and the Enter key trigger the same handler
        submit_inputs = [
            question_input,
            model_dropdown,
            education_dropdown,
            k_slider,
            temperature_slider,
            max_tokens_slider,
        ]
        submit_outputs = [
            answer_output,
            flesch_output,
            sources_output,
            similarity_output,
            category_output,
        ]
        submit_btn.click(
            fn=process_with_education_level,
            inputs=submit_inputs,
            outputs=submit_outputs
        )
        question_input.submit(
            fn=process_with_education_level,
            inputs=submit_inputs,
            outputs=submit_outputs
        )

    return demo
def main():
    """Main function to launch the Gradio app"""
    # Parse arguments with defaults suitable for Gradio
    parser = argparse.ArgumentParser(description="Gradio Interface for CGT-LLM-Beta RAG Chatbot")
    # Model and database settings
    parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
                        help='HuggingFace model name')
    parser.add_argument('--vector-db-dir', default='./chroma_db',
                        help='Directory for ChromaDB persistence')
    parser.add_argument('--data-dir', default='./Data Resources',
                        help='Directory containing documents (for indexing if needed)')
    # Generation parameters
    parser.add_argument('--max-new-tokens', type=int, default=1024,
                        help='Maximum new tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.2,
                        help='Generation temperature')
    parser.add_argument('--top-p', type=float, default=0.9,
                        help='Top-p sampling parameter')
    parser.add_argument('--repetition-penalty', type=float, default=1.1,
                        help='Repetition penalty')
    # Retrieval parameters
    parser.add_argument('--k', type=int, default=5,
                        help='Number of chunks to retrieve per question')
    # Other settings
    parser.add_argument('--skip-indexing', action='store_true',
                        help='Skip document indexing (use existing vector DB)')
    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging')
    parser.add_argument('--share', action='store_true',
                        help='Create a public Gradio share link')
    parser.add_argument('--server-name', type=str, default='127.0.0.1',
                        help='Server name (0.0.0.0 for public access)')
    parser.add_argument('--server-port', type=int, default=7860,
                        help='Server port')
    args = parser.parse_args()

    # Set logging level
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    logger.info("Initializing RAGBot for Gradio interface...")
    logger.info(f"Model: {args.model}")
    logger.info(f"Vector DB: {args.vector_db_dir}")
    try:
        # Initialize the bot
        bot = RAGBot(args)
        # Check whether the vector database exists and has documents
        collection_stats = bot.vector_retriever.get_collection_stats()
        if collection_stats.get('total_chunks', 0) == 0:
            logger.warning("Vector database is empty. You may need to run indexing first:")
            logger.warning("  python bot.py --data-dir './Data Resources' --vector-db-dir './chroma_db'")
            logger.warning("Continuing anyway - the chatbot will work but may not find relevant documents.")
        # Create and launch the Gradio interface
        demo = create_interface(bot)
        # On Hugging Face Spaces, return the demo (Spaces handles launching);
        # locally, launch it ourselves
        if os.getenv("SPACE_ID") or os.getenv("SYSTEM") == "spaces":
            return demo
        else:
            logger.info(f"Launching Gradio interface on http://{args.server_name}:{args.server_port}")
            demo.launch(
                server_name=args.server_name,
                server_port=args.server_port,
                share=args.share
            )
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        sys.exit(0)
    except Exception as e:
        logger.error(f"Error launching Gradio app: {e}", exc_info=True)
        sys.exit(1)
# For Hugging Face Spaces: create the demo at module level.
# Spaces imports this module and looks for a 'demo' variable.
def create_demo_for_spaces():
    """Create the demo for Hugging Face Spaces"""
    try:
        # Default args for Spaces (no CLI parsing; mirrors the defaults in main())
        args = argparse.Namespace(
            model='meta-llama/Llama-3.2-3B-Instruct',
            vector_db_dir='./chroma_db',
            data_dir='./Data Resources',
            max_new_tokens=1024,
            temperature=0.2,
            top_p=0.9,
            repetition_penalty=1.1,
            k=5,
            skip_indexing=True,
            verbose=False,
            share=False,
            server_name='0.0.0.0',
            server_port=7860,
            seed=42,
        )
        bot = RAGBot(args)
        return create_interface(bot)
    except Exception as e:
        logger.error(f"Error creating demo for Spaces: {e}", exc_info=True)
        # Return a simple error demo so Spaces still has something to display
        with gr.Blocks() as error_demo:
            gr.Markdown(f"# Error Initializing Chatbot\n\nAn error occurred: {str(e)}")
        return error_demo
# Create the demo at module level only when running on Hugging Face Spaces;
# creating it unconditionally would load the model a second time when this
# file is executed locally via main().
if os.getenv("SPACE_ID") or os.getenv("SYSTEM") == "spaces":
    demo = create_demo_for_spaces()
# For local execution
if __name__ == "__main__":
    main()