arahrooh committed
Commit 086ffee · 1 Parent(s): 5164dc4

Deploy CGT-LLM-Beta RAG Chatbot with vector database
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ *.py[cod]
+ *.log
+ results/
+ *.csv
+ .DS_Store
README.md ADDED
@@ -0,0 +1,59 @@
+ ---
+ title: CGT-LLM-Beta RAG Chatbot
+ emoji: 🧬
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.0.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # CGT-LLM-Beta: Genetic Counseling RAG Chatbot
+
+ A Retrieval-Augmented Generation (RAG) chatbot for genetic counseling and cascade genetic testing questions.
+
+ ## Features
+
+ - **Evidence-based answers** from medical literature
+ - **Multiple education levels**: Middle School, High School, College, and Doctoral
+ - **Source document citations** with full chunk text
+ - **Similarity scoring** for transparency
+ - **Flesch-Kincaid readability scores** for all answers
+ - **Multiple LLM models** to choose from
+ - **100+ example questions** for testing
+
+ ## How to Use
+
+ 1. **Select a model** from the dropdown (default: Llama-3.2-3B-Instruct)
+ 2. **Choose your education level** for personalized answers
+ 3. **Enter your question** or select from example questions
+ 4. **View the answer** with readability score, sources, and similarity scores
+
+ ## Education Levels
+
+ - **Middle School**: Simplified version for ages 12-14
+ - **High School**: Simplified version for ages 15-18
+ - **College**: Professional version for undergraduate level
+ - **Doctoral**: Advanced version for medical professionals
+
+ ## Models Available
+
+ - Llama-3.2-3B-Instruct
+ - Mistral-7B-Instruct-v0.2
+ - Llama-4-Scout-17B-16E-Instruct
+ - MediPhi-Instruct
+ - MediPhi
+ - Phi-4-reasoning
+
+ ## Important Notes
+
+ ⚠️ **This chatbot provides informational answers based on medical literature. It is not a substitute for professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare providers for medical decisions.**
+
+ ## Technical Details
+
+ - **Vector Database**: ChromaDB with sentence-transformers embeddings
+ - **RAG System**: Retrieval-Augmented Generation with semantic search
+ - **Source Attribution**: Full document tracking with chunk-level citations
app.py ADDED
@@ -0,0 +1,664 @@
1
+ """
2
+ Gradio Chatbot Interface for CGT-LLM-Beta RAG System
3
+
4
+ This application provides a web interface for the RAG chatbot, allowing users to:
5
+ - Select different LLM models from a dropdown
6
+ - Choose education level for personalized answers (Middle School, High School, College, Doctoral)
7
+ - View answers with Flesch-Kincaid grade level scores
8
+ - See source documents and similarity scores for every answer
9
+
10
+ Usage:
11
+ python app.py
12
+
13
+ IMPORTANT: Before using, update the MODEL_MAP dictionary with correct HuggingFace paths
14
+ for models that currently have placeholder paths (Llama-4-Scout, MediPhi, Phi-4-reasoning).
15
+
16
+ For Hugging Face Spaces:
17
+ - Ensure vector database is built (run bot.py with indexing first)
18
+ - Model will be loaded on startup
19
+ - Access via the Gradio interface
20
+ """
21
+
22
+ import gradio as gr
23
+ import argparse
24
+ import sys
25
+ import os
26
+ from typing import Tuple, Optional
27
+ import logging
28
+ import textstat
29
+ import torch
30
+
31
+ # Import from bot.py
32
+ from bot import RAGBot, parse_args
33
+
34
+ # Set up logging
35
+ logging.basicConfig(level=logging.INFO)
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Model mapping: short name -> full HuggingFace path
39
+ MODEL_MAP = {
40
+ "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
41
+ "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
42
+ "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
43
+ "MediPhi-Instruct": "microsoft/MediPhi-Instruct",
44
+ "MediPhi": "microsoft/MediPhi",
45
+ "Phi-4-reasoning": "microsoft/Phi-4-reasoning",
46
+ }
47
+
48
+ # Education level mapping
49
+ EDUCATION_LEVELS = {
50
+ "Middle School": "middle_school",
51
+ "High School": "high_school",
52
+ "College": "college",
53
+ "Doctoral": "doctoral"
54
+ }
55
+
56
+ # Example questions from the results CSV (hardcoded for easy access)
57
+ EXAMPLE_QUESTIONS = [
58
+ "Can a BRCA2 variant skip a generation?",
59
+ "Can a PMS2 variant skip a generation?",
60
+ "Can an EPCAM/MSH2 variant skip a generation?",
61
+ "Can an MLH1 variant skip a generation?",
62
+ "Can an MSH2 variant skip a generation?",
63
+ "Can an MSH6 variant skip a generation?",
64
+ "Can I pass this MSH2 variant to my kids?",
65
+ "Can only women carry a BRCA inherited mutation?",
66
+ "Does GINA cover life or disability insurance?",
67
+ "Does having a BRCA1 mutation mean I will definitely have cancer?",
68
+ "Does having a BRCA2 mutation mean I will definitely have cancer?",
69
+ "Does having a PMS2 mutation mean I will definitely have cancer?",
70
+ "Does having an EPCAM/MSH2 mutation mean I will definitely have cancer?",
71
+ "Does having an MLH1 mutation mean I will definitely have cancer?",
72
+ "Does having an MSH2 mutation mean I will definitely have cancer?",
73
+ "Does having an MSH6 mutation mean I will definitely have cancer?",
74
+ "Does this BRCA1 genetic variant affect my cancer treatment?",
75
+ "Does this BRCA2 genetic variant affect my cancer treatment?",
76
+ "Does this EPCAM/MSH2 genetic variant affect my cancer treatment?",
77
+ "Does this MLH1 genetic variant affect my cancer treatment?",
78
+ "Does this MSH2 genetic variant affect my cancer treatment?",
79
+ "Does this MSH6 genetic variant affect my cancer treatment?",
80
+ "Does this PMS2 genetic variant affect my cancer treatment?",
81
+ "How can I cope with this diagnosis?",
82
+ "How can I get my kids tested?",
83
+ "How can I help others with my condition?",
84
+ "How might my genetic test results change over time?",
85
+ "I don't talk to my family/parents/sister/brother. How can I share this with them?",
86
+ "I have a BRCA pathogenic variant and I want to have children, what are my options?",
87
+ "Is genetic testing for my family members covered by insurance?",
88
+ "Is new research being done on my condition?",
89
+ "Is this BRCA1 variant something I inherited?",
90
+ "Is this BRCA2 variant something I inherited?",
91
+ "Is this EPCAM/MSH2 variant something I inherited?",
92
+ "Is this MLH1 variant something I inherited?",
93
+ "Is this MSH2 variant something I inherited?",
94
+ "Is this MSH6 variant something I inherited?",
95
+ "Is this PMS2 variant something I inherited?",
96
+ "My relative doesn't have insurance. What should they do?",
97
+ "People who test positive for a genetic mutation are they at risk of losing their health insurance?",
98
+ "Should I contact my male and female relatives?",
99
+ "Should my family members get tested?",
100
+ "What are the Risks and Benefits of Risk-Reducing Surgeries for Lynch Syndrome?",
101
+ "What are the recommendations for my family members if I have a BRCA1 mutation?",
102
+ "What are the recommendations for my family members if I have a BRCA2 mutation?",
103
+ "What are the recommendations for my family members if I have a PMS2 mutation?",
104
+ "What are the recommendations for my family members if I have an EPCAM/MSH2 mutation?",
105
+ "What are the recommendations for my family members if I have an MLH1 mutation?",
106
+ "What are the recommendations for my family members if I have an MSH2 mutation?",
107
+ "What are the recommendations for my family members if I have an MSH6 mutation?",
108
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have a BRCA mutation?",
109
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an EPCAM/MSH2 mutation?",
110
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an MSH2 mutation?",
111
+ "What does a BRCA1 genetic variant mean for me?",
112
+ "What does a BRCA2 genetic variant mean for me?",
113
+ "What does a PMS2 genetic variant mean for me?",
114
+ "What does an EPCAM/MSH2 genetic variant mean for me?",
115
+ "What does an MLH1 genetic variant mean for me?",
116
+ "What does an MSH2 genetic variant mean for me?",
117
+ "What does an MSH6 genetic variant mean for me?",
118
+ "What if I feel overwhelmed?",
119
+ "What if I want to have children and have a hereditary cancer gene? What are my reproductive options?",
120
+ "What if a family member doesn't want to get tested?",
121
+ "What is Lynch Syndrome?",
122
+ "What is my cancer risk if I have BRCA1 Hereditary Breast and Ovarian Cancer syndrome?",
123
+ "What is my cancer risk if I have BRCA2 Hereditary Breast and Ovarian Cancer syndrome?",
124
+ "What is my cancer risk if I have MLH1 Lynch syndrome?",
125
+ "What is my cancer risk if I have MSH2 or EPCAM-associated Lynch syndrome?",
126
+ "What is my cancer risk if I have MSH6 Lynch syndrome?",
127
+ "What is my cancer risk if I have PMS2 Lynch syndrome?",
128
+ "What other resources are available to help me?",
129
+ "What screening tests do you recommend for BRCA1 carriers?",
130
+ "What screening tests do you recommend for BRCA2 carriers?",
131
+ "What screening tests do you recommend for EPCAM/MSH2 carriers?",
132
+ "What screening tests do you recommend for MLH1 carriers?",
133
+ "What screening tests do you recommend for MSH2 carriers?",
134
+ "What screening tests do you recommend for MSH6 carriers?",
135
+ "What screening tests do you recommend for PMS2 carriers?",
136
+ "What steps can I take to manage my cancer risk if I have Lynch syndrome?",
137
+ "What types of cancers am I at risk for with a BRCA1 mutation?",
138
+ "What types of cancers am I at risk for with a BRCA2 mutation?",
139
+ "What types of cancers am I at risk for with a PMS2 mutation?",
140
+ "What types of cancers am I at risk for with an EPCAM/MSH2 mutation?",
141
+ "What types of cancers am I at risk for with an MLH1 mutation?",
142
+ "What types of cancers am I at risk for with an MSH2 mutation?",
143
+ "What types of cancers am I at risk for with an MSH6 mutation?",
144
+ "Where can I find a genetic counselor?",
145
+ "Which of my relatives are at risk?",
146
+ "Who are my first-degree relatives?",
147
+ "Who do my family members call to have genetic testing?",
148
+ "Why do some families with Lynch syndrome have more cases of cancer than others?",
149
+ "Why should I share my BRCA1 genetic results with family?",
150
+ "Why should I share my BRCA2 genetic results with family?",
151
+ "Why should I share my EPCAM/MSH2 genetic results with family?",
152
+ "Why should I share my MLH1 genetic results with family?",
153
+ "Why should I share my MSH2 genetic results with family?",
154
+ "Why should I share my MSH6 genetic results with family?",
155
+ "Why should I share my PMS2 genetic results with family?",
156
+ "Why would my relatives want to know if they have this? What can they do about it?",
157
+ "Will my insurance cover testing for my parents/brother/sister?",
158
+ "Will this affect my health insurance?",
159
+ ]
160
+
161
+
162
+ class GradioRAGInterface:
163
+ """Wrapper class to integrate RAGBot with Gradio"""
164
+
165
+ def __init__(self, initial_bot: RAGBot):
166
+ self.bot = initial_bot
167
+ self.current_model = initial_bot.args.model
168
+ self.data_dir = initial_bot.args.data_dir
169
+ logger.info("GradioRAGInterface initialized")
170
+
171
+ def _find_file_path(self, filename: str) -> str:
172
+ """Find the full file path for a given filename"""
173
+ from pathlib import Path
174
+ data_path = Path(self.data_dir)
175
+
176
+ if not data_path.exists():
177
+ return ""
178
+
179
+ # Search for the file recursively
180
+ for file_path in data_path.rglob(filename):
181
+ return str(file_path)
182
+
183
+ return ""
184
+
185
+ def reload_model(self, model_short_name: str) -> str:
186
+ """Reload the model when user selects a different one"""
187
+ if model_short_name not in MODEL_MAP:
188
+ return f"Error: Unknown model '{model_short_name}'"
189
+
190
+ new_model_path = MODEL_MAP[model_short_name]
191
+
192
+ # If same model, no need to reload
193
+ if new_model_path == self.current_model:
194
+ return f"Model already loaded: {model_short_name}"
195
+
196
+ try:
197
+ logger.info(f"Reloading model from {self.current_model} to {new_model_path}")
198
+
199
+ # Update args
200
+ self.bot.args.model = new_model_path
201
+
202
+ # Clear old model from memory
203
+ if self.bot.model is not None:
204
+ del self.bot.model
205
+ del self.bot.tokenizer
206
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
207
+
208
+ # Load new model
209
+ self.bot._load_model()
210
+ self.current_model = new_model_path
211
+
212
+ return f"✓ Model loaded: {model_short_name}"
213
+ except Exception as e:
214
+ logger.error(f"Error reloading model: {e}", exc_info=True)
215
+ return f"✗ Error loading model: {str(e)}"
216
+
217
+ def process_question(
218
+ self,
219
+ question: str,
220
+ model_name: str,
221
+ education_level: str,
222
+ k: int,
223
+ temperature: float,
224
+ max_tokens: int
225
+ ) -> Tuple[str, str, str, str, str]:
226
+ """
227
+ Process a single question and return formatted results
228
+
229
+ Returns:
230
+ Tuple of (answer, flesch_score, sources, similarity_scores, question_category)
231
+ """
232
+ import time
233
+
234
+ if not question or not question.strip():
235
+ return "Please enter a question.", "N/A", "", "", ""
236
+
237
+ try:
238
+ start_time = time.time()
239
+ logger.info(f"Processing question: {question[:50]}...")
240
+
241
+ # Reload model if changed (this can take 1-3 minutes)
242
+ if model_name in MODEL_MAP:
243
+ model_path = MODEL_MAP[model_name]
244
+ if model_path != self.current_model:
245
+ logger.info(f"Model changed, reloading from {self.current_model} to {model_path}")
246
+ reload_status = self.reload_model(model_name)
247
+ if reload_status.startswith("✗"):
248
+ return f"Error: {reload_status}", "N/A", "", "", ""
249
+ logger.info(f"Model reloaded in {time.time() - start_time:.1f}s")
250
+
251
+ # Update bot args for this query
252
+ self.bot.args.k = k
253
+ self.bot.args.temperature = temperature
254
+ # Limit max_tokens for faster generation in Gradio
255
+ self.bot.args.max_new_tokens = min(max_tokens, 512) # Cap at 512 for faster responses
256
+
257
+ # Categorize question
258
+ logger.info("Categorizing question...")
259
+ question_group = self.bot._categorize_question(question)
260
+
261
+ # Retrieve relevant chunks with similarity scores
262
+ logger.info("Retrieving relevant documents...")
263
+ retrieve_start = time.time()
264
+ context_chunks, similarity_scores = self.bot.retrieve_with_scores(question, k)
265
+ logger.info(f"Retrieved {len(context_chunks)} chunks in {time.time() - retrieve_start:.2f}s")
266
+
267
+ if not context_chunks:
268
+ return (
269
+ "I don't have enough information to answer this question. Please try rephrasing or asking about a different topic.",
270
+ "N/A",
271
+ "No sources found",
272
+ "No matches found",
273
+ question_group
274
+ )
275
+
276
+ # Format similarity scores
277
+ similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores])
278
+
279
+ # Format sources with chunk text and file paths
280
+ sources_list = []
281
+ for i, (chunk, score) in enumerate(zip(context_chunks, similarity_scores)):
282
+ # Try to find the file path
283
+ file_path = self._find_file_path(chunk.filename)
284
+
285
+ source_info = f"""
286
+ {'='*80}
287
+ SOURCE {i+1} | Similarity: {score:.3f}
288
+ {'='*80}
289
+ 📄 File: {chunk.filename}
290
+ 📍 Path: {file_path if file_path else 'File path not found (search in Data Resources directory)'}
291
+ 📊 Chunk: {chunk.chunk_id + 1}/{chunk.total_chunks} (Position: {chunk.start_pos}-{chunk.end_pos})
292
+
293
+ 📝 Full Chunk Text:
294
+ {chunk.text}
295
+
296
+ """
297
+ sources_list.append(source_info)
298
+
299
+ sources = "\n".join(sources_list)
300
+
301
+ # Generation kwargs
302
+ gen_kwargs = {
303
+ 'max_new_tokens': min(max_tokens, 512), # Cap for faster responses
304
+ 'temperature': temperature,
305
+ 'top_p': self.bot.args.top_p,
306
+ 'repetition_penalty': self.bot.args.repetition_penalty
307
+ }
308
+
309
+ # Generate answer based on education level
310
+ answer = ""
311
+ flesch_score = 0.0
312
+
313
+ # Generate original answer first (needed for all enhancement levels)
314
+ logger.info("Generating original answer...")
315
+ gen_start = time.time()
316
+ prompt = self.bot.format_prompt(context_chunks, question)
317
+ original_answer = self.bot.generate_answer(prompt, **gen_kwargs)
318
+ logger.info(f"Original answer generated in {time.time() - gen_start:.1f}s")
319
+
320
+ # Enhance based on education level
321
+ logger.info(f"Enhancing answer for {education_level} level...")
322
+ enhance_start = time.time()
323
+ if education_level == "middle_school":
324
+ # Simplify to middle school level
325
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="middle_school")
326
+
327
+ elif education_level == "high_school":
328
+ # Simplify to high school level
329
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="high_school")
330
+
331
+ elif education_level == "college":
332
+ # Enhance to college level
333
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="college")
334
+
335
+ elif education_level == "doctoral":
336
+ # Enhance to doctoral/professional level
337
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="doctoral")
338
+ else:
339
+ answer = "Invalid education level selected."
340
+ flesch_score = 0.0
341
+
342
+ logger.info(f"Answer enhanced in {time.time() - enhance_start:.1f}s")
343
+ total_time = time.time() - start_time
344
+ logger.info(f"Total processing time: {total_time:.1f}s")
345
+
346
+ # Clean the answer - remove special tokens and formatting
347
+ import re
348
+ cleaned_answer = answer
349
+
350
+ # Remove special tokens (case-insensitive)
351
+ special_tokens = [
352
+ "<|end|>",
353
+ "<|endoftext|>",
354
+ "<|end_of_text|>",
355
+ "<|eot_id|>",
356
+ "<|start_header_id|>",
357
+ "<|end_header_id|>",
358
+ "<|assistant|>",
361
+ ]
362
+ for token in special_tokens:
363
+ # Remove case-insensitive
364
+ cleaned_answer = re.sub(re.escape(token), '', cleaned_answer, flags=re.IGNORECASE)
365
+
366
+ # Remove any remaining special token patterns like <|...|>
367
+ cleaned_answer = re.sub(r'<\|[^|]+\|>', '', cleaned_answer)
368
+
369
+ # Remove any markdown-style headers that might have been added
370
+ cleaned_answer = re.sub(r'^\*\*.*?\*\*.*?\n', '', cleaned_answer, flags=re.MULTILINE)
371
+
372
+ # Clean up extra whitespace and newlines
373
+ cleaned_answer = re.sub(r'\n\s*\n\s*\n+', '\n\n', cleaned_answer) # Multiple newlines to double
374
+ cleaned_answer = re.sub(r'^\s+|\s+$', '', cleaned_answer, flags=re.MULTILINE) # Trim lines
375
+ cleaned_answer = cleaned_answer.strip()
376
+
377
+ # Return just the clean answer (no headers or metadata)
378
+ return (
379
+ cleaned_answer,
380
+ f"{flesch_score:.1f}",
381
+ sources,
382
+ similarity_scores_str,
383
+ question_group # Add question category as 5th return value
384
+ )
385
+
386
+ except Exception as e:
387
+ logger.error(f"Error processing question: {e}", exc_info=True)
388
+ return (
389
+ f"An error occurred while processing your question: {str(e)}",
390
+ "N/A",
391
+ "",
392
+ "",
393
+ "Error"
394
+ )
395
+
396
+
397
+ def create_interface(initial_bot: RAGBot) -> gr.Blocks:
398
+ """Create and configure the Gradio interface"""
399
+
400
+ interface = GradioRAGInterface(initial_bot)
401
+
402
+ # Get initial model name from bot
403
+ initial_model_short = None
404
+ for short_name, full_path in MODEL_MAP.items():
405
+ if full_path == initial_bot.args.model:
406
+ initial_model_short = short_name
407
+ break
408
+ if initial_model_short is None:
409
+ initial_model_short = list(MODEL_MAP.keys())[0]
410
+
411
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
412
+ gr.Markdown("""
413
+ # 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot
414
+
415
+ Ask questions about genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.
416
+
417
+ The chatbot uses a Retrieval-Augmented Generation (RAG) system to provide evidence-based answers from medical literature.
418
+ """)
419
+
420
+ with gr.Row():
421
+ with gr.Column(scale=2):
422
+ question_input = gr.Textbox(
423
+ label="Your Question",
424
+ placeholder="e.g., What is Lynch Syndrome? What screening is recommended for BRCA1 carriers?",
425
+ lines=3
426
+ )
427
+
428
+ with gr.Row():
429
+ model_dropdown = gr.Dropdown(
430
+ choices=list(MODEL_MAP.keys()),
431
+ value=initial_model_short,
432
+ label="Select Model",
433
+ info="Choose which LLM model to use for generating answers"
434
+ )
435
+
436
+ education_dropdown = gr.Dropdown(
437
+ choices=list(EDUCATION_LEVELS.keys()),
438
+ value=list(EDUCATION_LEVELS.keys())[0],
439
+ label="Education Level",
440
+ info="Select your education level for personalized answers"
441
+ )
442
+
443
+ with gr.Accordion("Advanced Settings", open=False):
444
+ k_slider = gr.Slider(
445
+ minimum=1,
446
+ maximum=10,
447
+ value=5,
448
+ step=1,
449
+ label="Number of document chunks to retrieve (k)"
450
+ )
451
+ temperature_slider = gr.Slider(
452
+ minimum=0.1,
453
+ maximum=1.0,
454
+ value=0.2,
455
+ step=0.1,
456
+ label="Temperature (lower = more focused)"
457
+ )
458
+ max_tokens_slider = gr.Slider(
459
+ minimum=128,
460
+ maximum=1024,
461
+ value=512,
462
+ step=128,
463
+ label="Max Tokens (lower = faster responses)"
464
+ )
465
+
466
+ submit_btn = gr.Button("Ask Question", variant="primary", size="lg")
467
+
468
+ with gr.Column(scale=3):
469
+ answer_output = gr.Textbox(
470
+ label="Answer",
471
+ lines=20,
472
+ interactive=False,
473
+ elem_classes=["answer-box"]
474
+ )
475
+
476
+ with gr.Row():
477
+ flesch_output = gr.Textbox(
478
+ label="Flesch-Kincaid Grade Level",
479
+ value="N/A",
480
+ interactive=False,
481
+ scale=1
482
+ )
483
+
484
+ similarity_output = gr.Textbox(
485
+ label="Similarity Scores",
486
+ value="",
487
+ interactive=False,
488
+ scale=1
489
+ )
490
+
491
+ category_output = gr.Textbox(
492
+ label="Question Category",
493
+ value="",
494
+ interactive=False,
495
+ scale=1
496
+ )
497
+
498
+ sources_output = gr.Textbox(
499
+ label="Source Documents (with Chunk Text)",
500
+ lines=15,
501
+ interactive=False,
502
+ info="Shows the retrieved document chunks with full text. File paths are shown for easy access."
503
+ )
504
+
505
+ # Example questions - all questions from the results CSV (scrollable)
506
+ gr.Markdown("### 💡 Example Questions")
507
+ gr.Markdown(f"Select a question below to use it in the chatbot ({len(EXAMPLE_QUESTIONS)} questions - scrollable dropdown):")
508
+
509
+ # Use Dropdown which is naturally scrollable with many options
510
+ example_questions_dropdown = gr.Dropdown(
511
+ choices=EXAMPLE_QUESTIONS,
512
+ label="Example Questions",
513
+ value=None,
514
+ info="Open the dropdown and scroll through all questions. Select one to use it.",
515
+ interactive=True,
516
+ container=True,
517
+ scale=1
518
+ )
519
+
520
+ # Update question input when dropdown selection changes
521
+ def update_question_from_dropdown(selected_question):
522
+ return selected_question if selected_question else ""
523
+
524
+ example_questions_dropdown.change(
525
+ fn=update_question_from_dropdown,
526
+ inputs=example_questions_dropdown,
527
+ outputs=question_input
528
+ )
529
+
530
+ # Footer
531
+ gr.Markdown("""
532
+ ---
533
+ **Note:** This chatbot provides informational answers based on medical literature.
534
+ It is not a substitute for professional medical advice, diagnosis, or treatment.
535
+ Always consult with qualified healthcare providers for medical decisions.
536
+ """)
537
+
538
+ # Connect the submit button
539
+ def process_with_education_level(question, model, education, k, temp, max_tok):
540
+ education_key = EDUCATION_LEVELS[education]
541
+ return interface.process_question(question, model, education_key, k, temp, max_tok)
542
+
543
+ submit_btn.click(
544
+ fn=process_with_education_level,
545
+ inputs=[
546
+ question_input,
547
+ model_dropdown,
548
+ education_dropdown,
549
+ k_slider,
550
+ temperature_slider,
551
+ max_tokens_slider
552
+ ],
553
+ outputs=[
554
+ answer_output,
555
+ flesch_output,
556
+ sources_output,
557
+ similarity_output,
558
+ category_output
559
+ ]
560
+ )
561
+
562
+ # Also allow Enter key to submit
563
+ question_input.submit(
564
+ fn=process_with_education_level,
565
+ inputs=[
566
+ question_input,
567
+ model_dropdown,
568
+ education_dropdown,
569
+ k_slider,
570
+ temperature_slider,
571
+ max_tokens_slider
572
+ ],
573
+ outputs=[
574
+ answer_output,
575
+ flesch_output,
576
+ sources_output,
577
+ similarity_output,
578
+ category_output
579
+ ]
580
+ )
581
+
582
+ return demo
583
+
584
+
585
+ def main():
586
+ """Main function to launch the Gradio app"""
587
+ # Parse arguments with defaults suitable for Gradio
588
+ parser = argparse.ArgumentParser(description="Gradio Interface for CGT-LLM-Beta RAG Chatbot")
589
+
590
+ # Model and database settings
591
+ parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
592
+ help='HuggingFace model name')
593
+ parser.add_argument('--vector-db-dir', default='./chroma_db',
594
+ help='Directory for ChromaDB persistence')
595
+ parser.add_argument('--data-dir', default='./Data Resources',
596
+ help='Directory containing documents (for indexing if needed)')
597
+
598
+ # Generation parameters
599
+ parser.add_argument('--max-new-tokens', type=int, default=1024,
600
+ help='Maximum new tokens to generate')
601
+ parser.add_argument('--temperature', type=float, default=0.2,
602
+ help='Generation temperature')
603
+ parser.add_argument('--top-p', type=float, default=0.9,
604
+ help='Top-p sampling parameter')
605
+ parser.add_argument('--repetition-penalty', type=float, default=1.1,
606
+ help='Repetition penalty')
607
+
608
+ # Retrieval parameters
609
+ parser.add_argument('--k', type=int, default=5,
610
+ help='Number of chunks to retrieve per question')
611
+
612
+ # Other settings
613
+ parser.add_argument('--skip-indexing', action='store_true',
614
+ help='Skip document indexing (use existing vector DB)')
615
+ parser.add_argument('--verbose', action='store_true',
616
+ help='Enable verbose logging')
617
+ parser.add_argument('--share', action='store_true',
618
+ help='Create a public Gradio share link')
619
+ parser.add_argument('--server-name', type=str, default='127.0.0.1',
620
+ help='Server name (0.0.0.0 for public access)')
621
+ parser.add_argument('--server-port', type=int, default=7860,
622
+ help='Server port')
623
+
624
+ args = parser.parse_args()
625
+
626
+ # Set logging level
627
+ if args.verbose:
628
+ logging.getLogger().setLevel(logging.DEBUG)
629
+
630
+ logger.info("Initializing RAGBot for Gradio interface...")
631
+ logger.info(f"Model: {args.model}")
632
+ logger.info(f"Vector DB: {args.vector_db_dir}")
633
+
634
+ try:
635
+ # Initialize bot
636
+ bot = RAGBot(args)
637
+
638
+ # Check if vector database exists and has documents
639
+ collection_stats = bot.vector_retriever.get_collection_stats()
640
+ if collection_stats.get('total_chunks', 0) == 0:
641
+ logger.warning("Vector database is empty. You may need to run indexing first:")
642
+ logger.warning(" python bot.py --data-dir './Data Resources' --vector-db-dir './chroma_db'")
643
+ logger.warning("Continuing anyway - the chatbot will work but may not find relevant documents.")
644
+
645
+ # Create and launch Gradio interface
646
+ demo = create_interface(bot)
647
+
648
+ logger.info(f"Launching Gradio interface on http://{args.server_name}:{args.server_port}")
649
+ demo.launch(
650
+ server_name=args.server_name,
651
+ server_port=args.server_port,
652
+ share=args.share
653
+ )
654
+
655
+ except KeyboardInterrupt:
656
+ logger.info("Interrupted by user")
657
+ sys.exit(0)
658
+ except Exception as e:
659
+ logger.error(f"Error launching Gradio app: {e}", exc_info=True)
660
+ sys.exit(1)
661
+
662
+
663
+ if __name__ == "__main__":
664
+ main()
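For quick smoke-testing outside the Gradio UI (an editorial sketch, not part of this commit), the same pipeline can be driven through the `RAGBot` API that `process_question` uses above. The `Namespace` fields mirror the CLI flags defined in `main()`, and `generate_answer` is assumed to accept the same keyword arguments passed via `gen_kwargs`:

```python
from argparse import Namespace

from bot import RAGBot

# Mirror the defaults from app.py's argument parser (trimmed to what this flow touches)
args = Namespace(
    model="meta-llama/Llama-3.2-3B-Instruct",
    vector_db_dir="./chroma_db",
    data_dir="./Data Resources",
    max_new_tokens=512,
    temperature=0.2,
    top_p=0.9,
    repetition_penalty=1.1,
    k=5,
    skip_indexing=True,
    verbose=False,
)

bot = RAGBot(args)

question = "What is Lynch Syndrome?"
chunks, scores = bot.retrieve_with_scores(question, k=args.k)
prompt = bot.format_prompt(chunks, question)
answer = bot.generate_answer(
    prompt,
    max_new_tokens=args.max_new_tokens,
    temperature=args.temperature,
    top_p=args.top_p,
    repetition_penalty=args.repetition_penalty,
)
print(answer)
```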
bot.py ADDED
@@ -0,0 +1,1743 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ RAG Chatbot Implementation for CGT-LLM-Beta with Vector Database
4
+ Production-ready local RAG system with ChromaDB and MPS acceleration for Apple Silicon
5
+ """
6
+
7
+ import argparse
8
+ import csv
9
+ import json
10
+ import logging
11
+ import os
12
+ import re
13
+ import sys
14
+ import time
15
+ import hashlib
16
+ from pathlib import Path
17
+ from typing import List, Tuple, Dict, Any, Optional, Union
18
+ from dataclasses import dataclass
19
+ from collections import defaultdict
20
+
21
+ import textstat
22
+
23
+ import torch
24
+ import numpy as np
25
+ import pandas as pd
26
+ from tqdm import tqdm
27
+
28
+ # Optional imports with graceful fallbacks
29
+ try:
30
+ import chromadb
31
+ from chromadb.config import Settings
32
+ CHROMADB_AVAILABLE = True
33
+ except ImportError:
34
+ CHROMADB_AVAILABLE = False
35
+ print("Warning: chromadb not available. Install with: pip install chromadb")
36
+
37
+ try:
38
+ from sentence_transformers import SentenceTransformer
39
+ SENTENCE_TRANSFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SENTENCE_TRANSFORMERS_AVAILABLE = False
42
+ print("Warning: sentence-transformers not available. Install with: pip install sentence-transformers")
43
+
44
+ try:
45
+ import pypdf
46
+ PDF_AVAILABLE = True
47
+ except ImportError:
48
+ PDF_AVAILABLE = False
49
+ print("Warning: pypdf not available. PDF files will be skipped.")
50
+
51
+ try:
52
+ from docx import Document
53
+ DOCX_AVAILABLE = True
54
+ except ImportError:
55
+ DOCX_AVAILABLE = False
56
+ print("Warning: python-docx not available. DOCX files will be skipped.")
57
+
58
+ try:
59
+ from rank_bm25 import BM25Okapi
60
+ BM25_AVAILABLE = True
61
+ except ImportError:
62
+ BM25_AVAILABLE = False
63
+ print("Warning: rank-bm25 not available. BM25 retrieval disabled.")
64
+
65
+ # Configure logging
66
+ logging.basicConfig(
67
+ level=logging.INFO,
68
+ format='%(asctime)s - %(levelname)s - %(message)s',
69
+ handlers=[
70
+ logging.StreamHandler(),
71
+ logging.FileHandler('rag_bot.log')
72
+ ]
73
+ )
74
+ logger = logging.getLogger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class Document:
79
+ """Represents a document with metadata"""
80
+ filename: str
81
+ content: str
82
+ filepath: str
83
+ file_type: str
84
+ chunk_count: int = 0
85
+ file_hash: str = ""
86
+
87
+
88
+ @dataclass
89
+ class Chunk:
90
+ """Represents a text chunk with metadata"""
91
+ text: str
92
+ filename: str
93
+ chunk_id: int
94
+ total_chunks: int
95
+ start_pos: int
96
+ end_pos: int
97
+ metadata: Dict[str, Any]
98
+ chunk_hash: str = ""
99
+
100
+
101
+ class VectorRetriever:
102
+ """ChromaDB-based vector retrieval"""
103
+
104
+ def __init__(self, collection_name: str = "cgt_documents", persist_directory: str = "./chroma_db"):
105
+ if not CHROMADB_AVAILABLE:
106
+ raise ImportError("ChromaDB is required for vector retrieval")
107
+
108
+ self.collection_name = collection_name
109
+ self.persist_directory = persist_directory
110
+
111
+ # Initialize ChromaDB client
112
+ self.client = chromadb.PersistentClient(path=persist_directory)
113
+
114
+ # Get or create collection
115
+ try:
116
+ self.collection = self.client.get_collection(name=collection_name)
117
+ logger.info(f"Loaded existing collection '{collection_name}' with {self.collection.count()} documents")
118
+ except:
119
+ self.collection = self.client.create_collection(
120
+ name=collection_name,
121
+ metadata={"description": "CGT-LLM-Beta document collection"}
122
+ )
123
+ logger.info(f"Created new collection '{collection_name}'")
124
+
125
+ # Initialize embedding model
126
+ if SENTENCE_TRANSFORMERS_AVAILABLE:
127
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
128
+ logger.info("Loaded sentence-transformers embedding model")
129
+ else:
130
+ self.embedding_model = None
131
+ logger.warning("Sentence-transformers not available, using ChromaDB default embeddings")
132
+
133
+ def add_documents(self, chunks: List[Chunk]) -> None:
134
+ """Add document chunks to the vector database"""
135
+ if not chunks:
136
+ return
137
+
138
+ logger.info(f"Adding {len(chunks)} chunks to vector database...")
139
+
140
+ # Prepare data for ChromaDB
141
+ documents = []
142
+ metadatas = []
143
+ ids = []
144
+
145
+ for chunk in chunks:
146
+ chunk_id = f"{chunk.filename}_{chunk.chunk_id}"
147
+ documents.append(chunk.text)
148
+
149
+ metadata = {
150
+ "filename": chunk.filename,
151
+ "chunk_id": chunk.chunk_id,
152
+ "total_chunks": chunk.total_chunks,
153
+ "start_pos": chunk.start_pos,
154
+ "end_pos": chunk.end_pos,
155
+ "chunk_hash": chunk.chunk_hash,
156
+ **chunk.metadata
157
+ }
158
+ metadatas.append(metadata)
159
+ ids.append(chunk_id)
160
+
161
+ # Add to collection
162
+ try:
163
+ self.collection.add(
164
+ documents=documents,
165
+ metadatas=metadatas,
166
+ ids=ids
167
+ )
168
+ logger.info(f"Successfully added {len(chunks)} chunks to vector database")
169
+ except Exception as e:
170
+ logger.error(f"Error adding documents to vector database: {e}")
171
+
172
+ def search(self, query: str, k: int = 5) -> List[Tuple[Chunk, float]]:
173
+ """Search for similar chunks using vector similarity"""
174
+ try:
175
+ # Perform vector search
176
+ results = self.collection.query(
177
+ query_texts=[query],
178
+ n_results=k
179
+ )
180
+
181
+ chunks_with_scores = []
182
+ if results['documents'] and results['documents'][0]:
183
+ for i, (doc, metadata, distance) in enumerate(zip(
184
+ results['documents'][0],
185
+ results['metadatas'][0],
186
+ results['distances'][0]
187
+ )):
188
+ # Convert distance to similarity score (ChromaDB uses cosine distance)
189
+ similarity_score = 1 - distance
190
+
191
+ chunk = Chunk(
192
+ text=doc,
193
+ filename=metadata['filename'],
194
+ chunk_id=metadata['chunk_id'],
195
+ total_chunks=metadata['total_chunks'],
196
+ start_pos=metadata['start_pos'],
197
+ end_pos=metadata['end_pos'],
198
+ metadata={k: v for k, v in metadata.items()
199
+ if k not in ['filename', 'chunk_id', 'total_chunks', 'start_pos', 'end_pos', 'chunk_hash']},
200
+ chunk_hash=metadata.get('chunk_hash', '')
201
+ )
202
+ chunks_with_scores.append((chunk, similarity_score))
203
+
204
+ return chunks_with_scores
205
+
206
+ except Exception as e:
207
+ logger.error(f"Error searching vector database: {e}")
208
+ return []
209
+
210
+ def get_collection_stats(self) -> Dict[str, Any]:
211
+ """Get statistics about the collection"""
212
+ try:
213
+ count = self.collection.count()
214
+ return {
215
+ "total_chunks": count,
216
+ "collection_name": self.collection_name,
217
+ "persist_directory": self.persist_directory
218
+ }
219
+ except Exception as e:
220
+ logger.error(f"Error getting collection stats: {e}")
221
+ return {}
222
+
223
+
224
+ class RAGBot:
225
+ """Main RAG chatbot class with vector database"""
226
+
227
+ def __init__(self, args):
228
+ self.args = args
229
+ self.device = self._setup_device()
230
+ self.model = None
231
+ self.tokenizer = None
232
+ self.vector_retriever = None
233
+
234
+ # Load model
235
+ self._load_model()
236
+
237
+ # Initialize vector retriever
238
+ self._setup_vector_retriever()
239
+
240
+ def _setup_device(self) -> str:
241
+ """Setup device with MPS support for Apple Silicon"""
242
+ if torch.backends.mps.is_available():
243
+ device = "mps"
244
+ logger.info("Using device: mps (Apple Silicon)")
245
+ elif torch.cuda.is_available():
246
+ device = "cuda"
247
+ logger.info("Using device: cuda")
248
+ else:
249
+ device = "cpu"
250
+ logger.info("Using device: cpu")
251
+
252
+ return device
253
+
254
+ def _load_model(self):
255
+ """Load the specified LLM model and tokenizer"""
256
+ try:
257
+ model_name = self.args.model
258
+ logger.info(f"Loading model: {model_name}...")
259
+ from transformers import AutoTokenizer, AutoModelForCausalLM
260
+
261
+ # Load tokenizer
262
+ self.tokenizer = AutoTokenizer.from_pretrained(
263
+ model_name,
264
+ trust_remote_code=True
265
+ )
266
+
267
+ # Determine appropriate torch dtype based on device and model
268
+ # Use float16 for MPS/CUDA, float32 for CPU
269
+ # Some models work better with bfloat16
270
+ if self.device == "mps":
271
+ torch_dtype = torch.float16
272
+ elif self.device == "cuda":
273
+ torch_dtype = torch.float16
274
+ else:
275
+ torch_dtype = torch.float32
276
+
277
+ # Load model with appropriate settings
278
+ model_kwargs = {
279
+ "torch_dtype": torch_dtype,
280
+ "trust_remote_code": True,
281
+ }
282
+
283
+ # For MPS, use device_map; for CUDA, let it auto-detect
284
+ if self.device == "mps":
285
+ model_kwargs["device_map"] = self.device
286
+ elif self.device == "cuda":
287
+ model_kwargs["device_map"] = "auto"
288
+ # For CPU, don't specify device_map
289
+
290
+ self.model = AutoModelForCausalLM.from_pretrained(
291
+ model_name,
292
+ **model_kwargs
293
+ )
294
+
295
+ # Move to device if not using device_map
296
+ if self.device == "cpu":
297
+ self.model = self.model.to(self.device)
298
+
299
+ # Set pad token if not already set
300
+ if self.tokenizer.pad_token is None:
301
+ if self.tokenizer.eos_token is not None:
302
+ self.tokenizer.pad_token = self.tokenizer.eos_token
303
+ else:
304
+ # Some models might need a different approach
305
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
306
+
307
+ logger.info(f"Model {model_name} loaded successfully on {self.device}")
308
+
309
+ except Exception as e:
310
+ logger.error(f"Failed to load model {self.args.model}: {e}")
311
+ logger.error("Make sure the model name is correct and you have access to it on HuggingFace")
312
+ logger.error("For private models, ensure you're logged in: huggingface-cli login")
313
+ sys.exit(2)
314
+
315
+ def _setup_vector_retriever(self):
316
+ """Setup the vector retriever"""
317
+ try:
318
+ self.vector_retriever = VectorRetriever(
319
+ collection_name="cgt_documents",
320
+ persist_directory=self.args.vector_db_dir
321
+ )
322
+ logger.info("Vector retriever initialized successfully")
323
+ except Exception as e:
324
+ logger.error(f"Failed to setup vector retriever: {e}")
325
+ sys.exit(2)
326
+
327
+ def _calculate_file_hash(self, filepath: str) -> str:
328
+ """Calculate hash of file for change detection"""
329
+ try:
330
+ with open(filepath, 'rb') as f:
331
+ return hashlib.md5(f.read()).hexdigest()
332
+ except:
333
+ return ""
334
+
335
+ def _calculate_chunk_hash(self, text: str) -> str:
336
+ """Calculate hash of chunk text"""
337
+ return hashlib.md5(text.encode('utf-8')).hexdigest()
338
+
339
+ def load_corpus(self, data_dir: str) -> List[Document]:
340
+ """Load all documents from the data directory"""
341
+ logger.info(f"Loading corpus from {data_dir}")
342
+ documents = []
343
+ data_path = Path(data_dir)
344
+
345
+ if not data_path.exists():
346
+ logger.error(f"Data directory {data_dir} does not exist")
347
+ sys.exit(1)
348
+
349
+ # Supported file extensions
350
+ supported_extensions = {'.txt', '.md', '.json', '.csv'}
351
+ if PDF_AVAILABLE:
352
+ supported_extensions.add('.pdf')
353
+ if DOCX_AVAILABLE:
354
+ supported_extensions.add('.docx')
355
+ supported_extensions.add('.doc')
356
+
357
+ # Find all files recursively
358
+ files = []
359
+ for ext in supported_extensions:
360
+ files.extend(data_path.rglob(f"*{ext}"))
361
+
362
+ logger.info(f"Found {len(files)} files to process")
363
+
364
+ # Process files with progress bar
365
+ for file_path in tqdm(files, desc="Loading documents"):
366
+ try:
367
+ content = self._read_file(file_path)
368
+ if content.strip(): # Only add non-empty documents
369
+ file_hash = self._calculate_file_hash(file_path)
370
+ doc = Document(
371
+ filename=file_path.name,
372
+ content=content,
373
+ filepath=str(file_path),
374
+ file_type=file_path.suffix.lower(),
375
+ file_hash=file_hash
376
+ )
377
+ documents.append(doc)
378
+ logger.debug(f"Loaded {file_path.name} ({len(content)} chars)")
379
+ else:
380
+ logger.warning(f"Skipping empty file: {file_path.name}")
381
+
382
+ except Exception as e:
383
+ logger.error(f"Failed to load {file_path.name}: {e}")
384
+ continue
385
+
386
+ logger.info(f"Successfully loaded {len(documents)} documents")
387
+ return documents
388
+
389
+ def _read_file(self, file_path: Path) -> str:
390
+ """Read content from various file types"""
391
+ suffix = file_path.suffix.lower()
392
+
393
+ try:
394
+ if suffix == '.txt':
395
+ return file_path.read_text(encoding='utf-8')
396
+
397
+ elif suffix == '.md':
398
+ return file_path.read_text(encoding='utf-8')
399
+
400
+ elif suffix == '.json':
401
+ with open(file_path, 'r', encoding='utf-8') as f:
402
+ data = json.load(f)
403
+ if isinstance(data, dict):
404
+ return json.dumps(data, indent=2)
405
+ else:
406
+ return str(data)
407
+
408
+ elif suffix == '.csv':
409
+ df = pd.read_csv(file_path)
410
+ return df.to_string()
411
+
412
+ elif suffix == '.pdf' and PDF_AVAILABLE:
413
+ text = ""
414
+ with open(file_path, 'rb') as f:
415
+ pdf_reader = pypdf.PdfReader(f)
416
+ for page in pdf_reader.pages:
417
+ text += page.extract_text() + "\n"
418
+ return text
419
+
420
+ elif suffix in ['.docx', '.doc'] and DOCX_AVAILABLE:
421
+ # NOTE: the Document dataclass above shadows docx.Document, so use the module directly
+ import docx
+ doc = docx.Document(file_path)
422
+ text = ""
423
+ for paragraph in doc.paragraphs:
424
+ text += paragraph.text + "\n"
425
+ return text
426
+
427
+ else:
428
+ logger.warning(f"Unsupported file type: {suffix}")
429
+ return ""
430
+
431
+ except Exception as e:
432
+ logger.error(f"Error reading {file_path}: {e}")
433
+ return ""
434
+
435
+ def chunk_documents(self, docs: List[Document], chunk_size: int, overlap: int) -> List[Chunk]:
436
+ """Chunk documents into smaller pieces"""
437
+ logger.info(f"Chunking {len(docs)} documents (size={chunk_size}, overlap={overlap})")
438
+ chunks = []
439
+
440
+ for doc in docs:
441
+ doc_chunks = self._chunk_text(
442
+ doc.content,
443
+ doc.filename,
444
+ chunk_size,
445
+ overlap
446
+ )
447
+ chunks.extend(doc_chunks)
448
+
449
+ # Update document metadata
450
+ doc.chunk_count = len(doc_chunks)
451
+
452
+ logger.info(f"Created {len(chunks)} chunks from {len(docs)} documents")
453
+ return chunks
454
+
455
+ def _chunk_text(self, text: str, filename: str, chunk_size: int, overlap: int) -> List[Chunk]:
456
+ """Split text into overlapping chunks"""
457
+ # Clean text
458
+ text = re.sub(r'\s+', ' ', text.strip())
459
+
460
+ # Simple token-based chunking (approximate)
461
+ words = text.split()
462
+ chunks = []
463
+
464
+ for i in range(0, len(words), chunk_size - overlap):
465
+ chunk_words = words[i:i + chunk_size]
466
+ chunk_text = ' '.join(chunk_words)
467
+
468
+ if chunk_text.strip():
469
+ chunk_hash = self._calculate_chunk_hash(chunk_text)
470
+ chunk = Chunk(
471
+ text=chunk_text,
472
+ filename=filename,
473
+ chunk_id=len(chunks),
474
+ total_chunks=0, # Will be updated later
475
+ start_pos=i,
476
+ end_pos=i + len(chunk_words),
477
+ metadata={
478
+ 'word_count': len(chunk_words),
479
+ 'char_count': len(chunk_text)
480
+ },
481
+ chunk_hash=chunk_hash
482
+ )
483
+ chunks.append(chunk)
484
+
485
+ # Update total_chunks for each chunk
486
+ for chunk in chunks:
487
+ chunk.total_chunks = len(chunks)
488
+
489
+ return chunks
490
+
491
+ def build_or_update_index(self, chunks: List[Chunk], force_rebuild: bool = False) -> None:
492
+ """Build or update the vector index"""
493
+ if not chunks:
494
+ logger.warning("No chunks provided for indexing")
495
+ return
496
+
497
+ # Check if we need to rebuild
498
+ collection_stats = self.vector_retriever.get_collection_stats()
499
+ existing_count = collection_stats.get('total_chunks', 0)
500
+
501
+ if existing_count > 0 and not force_rebuild:
502
+ logger.info(f"Vector database already contains {existing_count} chunks. Use --force-rebuild to rebuild.")
503
+ return
504
+
505
+ if force_rebuild and existing_count > 0:
506
+ logger.info("Force rebuild requested. Clearing existing collection...")
507
+ try:
508
+ self.vector_retriever.client.delete_collection(self.vector_retriever.collection_name)
509
+ self.vector_retriever.collection = self.vector_retriever.client.create_collection(
510
+ name=self.vector_retriever.collection_name,
511
+ metadata={"description": "CGT-LLM-Beta document collection"}
512
+ )
513
+ except Exception as e:
514
+ logger.error(f"Error clearing collection: {e}")
515
+
516
+ # Add chunks to vector database
517
+ self.vector_retriever.add_documents(chunks)
518
+
519
+ logger.info("Vector index built successfully")
520
+
521
+ def retrieve(self, query: str, k: int) -> List[Chunk]:
522
+ """Retrieve relevant chunks for a query using vector search"""
523
+ results = self.vector_retriever.search(query, k)
524
+ chunks = [chunk for chunk, score in results]
525
+
526
+ if self.args.verbose:
527
+ logger.info(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
528
+ for i, (chunk, score) in enumerate(results):
529
+ logger.info(f" {i+1}. {chunk.filename} (score: {score:.3f})")
530
+
531
+ return chunks
532
+
533
+ def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
534
+ """Retrieve relevant chunks with similarity scores
535
+
536
+ Returns:
537
+ Tuple of (chunks, scores) where scores are similarity scores for each chunk
538
+ """
539
+ results = self.vector_retriever.search(query, k)
540
+ chunks = [chunk for chunk, score in results]
541
+ scores = [score for chunk, score in results]
542
+
543
+ if self.args.verbose:
544
+ logger.info(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
545
+ for i, (chunk, score) in enumerate(results):
546
+ logger.info(f" {i+1}. {chunk.filename} (score: {score:.3f})")
547
+
548
+ return chunks, scores
549
+
550
+ def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
551
+ """Format the prompt with context and question, ensuring it fits within token limits"""
552
+ context_parts = []
553
+ for chunk in context_chunks:
554
+ context_parts.append(f"{chunk.text}")
555
+
556
+ context = "\n".join(context_parts)
557
+
558
+ # Try to use the tokenizer's chat template if available
559
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
560
+ try:
561
+ messages = [
562
+ {"role": "system", "content": "You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative."},
563
+ {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
564
+ ]
565
+ base_prompt = self.tokenizer.apply_chat_template(
566
+ messages,
567
+ tokenize=False,
568
+ add_generation_prompt=True
569
+ )
570
+ except Exception as e:
571
+ logger.warning(f"Failed to use chat template, falling back to manual format: {e}")
572
+ base_prompt = self._format_prompt_manual(context, question)
573
+ else:
574
+ # Fall back to manual formatting (for Llama models)
575
+ base_prompt = self._format_prompt_manual(context, question)
576
+
577
+ # Check if prompt is too long and truncate context if needed
578
+ max_context_tokens = 1200 # Leave room for generation
579
+ try:
580
+ tokenized = self.tokenizer(base_prompt, return_tensors="pt")
581
+ current_tokens = tokenized['input_ids'].shape[1]
582
+ except Exception as e:
583
+ logger.warning(f"Tokenization error, using base prompt as-is: {e}")
584
+ return base_prompt
585
+
586
+ if current_tokens > max_context_tokens:
587
+ # Truncate context to fit within limits
588
+ try:
589
+ context_tokens = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]
590
+ available_tokens = max_context_tokens - (current_tokens - context_tokens)
591
+
592
+ if available_tokens > 0:
593
+ # Truncate context to fit
594
+ truncated_context = self.tokenizer.decode(
595
+ self.tokenizer(context, return_tensors="pt", truncation=True, max_length=available_tokens)['input_ids'][0],
596
+ skip_special_tokens=True
597
+ )
598
+
599
+ # Reformat with truncated context
600
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
601
+ try:
602
+ messages = [
603
+ {"role": "system", "content": "You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative."},
604
+ {"role": "user", "content": f"Context: {truncated_context}\n\nQuestion: {question}"}
605
+ ]
606
+ prompt = self.tokenizer.apply_chat_template(
607
+ messages,
608
+ tokenize=False,
609
+ add_generation_prompt=True
610
+ )
611
+ except:
612
+ prompt = self._format_prompt_manual(truncated_context, question)
613
+ else:
614
+ prompt = self._format_prompt_manual(truncated_context, question)
615
+ else:
616
+ # If even basic prompt is too long, use minimal format
617
+ prompt = self._format_prompt_manual(context[:500] + "...", question)
618
+ except Exception as e:
619
+ logger.warning(f"Error truncating context: {e}, using base prompt")
620
+ prompt = base_prompt
621
+ else:
622
+ prompt = base_prompt
623
+
624
+ return prompt
625
+
626
+ def _format_prompt_manual(self, context: str, question: str) -> str:
627
+ """Manual prompt formatting for models without chat templates (e.g., Llama)"""
628
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
629
+
630
+ You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative.<|eot_id|><|start_header_id|>user<|end_header_id|>
631
+
632
+ Context: {context}
633
+
634
+ Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
635
+
636
+ """
637
+
638
+ def format_improved_prompt(self, context_chunks: List[Chunk], question: str) -> Tuple[str, str]:
639
+ """Format an improved prompt with better tone, structure, and medical appropriateness
640
+
641
+ Returns:
642
+ Tuple of (prompt, prompt_text) where prompt_text is the system prompt instructions
643
+ """
644
+ context_parts = []
645
+ for chunk in context_chunks:
646
+ context_parts.append(f"{chunk.text}")
647
+
648
+ context = "\n".join(context_parts)
649
+
650
+ # Improved prompt with all the feedback incorporated
651
+ improved_prompt_text = """Provide a concise, neutral, and informative answer based on the provided medical context.
652
+
653
+ CRITICAL GUIDELINES:
654
+ - Format your response as clear, well-structured sentences and paragraphs
655
+ - Be concise and direct - focus on answering the specific question asked
656
+ - Use neutral, factual language - do NOT tell the questioner how to feel (avoid phrases like 'don't worry', 'the good news is', etc.)
657
+ - Do NOT use leading or coercive language - present information neutrally to preserve patient autonomy
658
+ - Do NOT make specific medical recommendations - instead state that management decisions should be made with a healthcare provider
659
+ - Use third-person voice only - never claim to be a medical professional or assistant
660
+ - Use consistent terminology: use 'children' (not 'offspring') consistently
661
+ - Do NOT include hypothetical examples with specific names (e.g., avoid 'Aunt Jenna' or similar)
662
+ - Include important distinctions when relevant (e.g., somatic vs. germline variants, reproductive risks)
663
+ - When citing sources, be consistent - always specify which guidelines or sources when mentioned
664
+ - Remove any formatting markers like asterisks (*) or bold markers
665
+ - Do NOT include phrases like 'Here's a rewritten version' - just provide the answer directly
666
+
667
+ If the question asks about medical management, screening, or interventions, conclude with: 'Management recommendations are individualized and should be discussed with a healthcare provider or genetic counselor.'"""
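+
+ # Note: this guideline text is reused in two places - it becomes the system message for
+ # generation below, and process_questions/write_csv embed it in the results CSV header so
+ # each output file records the exact instructions it was generated with.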
668
+
669
+ # Try to use the tokenizer's chat template if available
670
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
671
+ try:
672
+ messages = [
673
+ {"role": "system", "content": improved_prompt_text},
674
+ {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
675
+ ]
676
+ base_prompt = self.tokenizer.apply_chat_template(
677
+ messages,
678
+ tokenize=False,
679
+ add_generation_prompt=True
680
+ )
681
+ except Exception as e:
682
+ logger.warning(f"Failed to use chat template for improved prompt, falling back to manual format: {e}")
683
+ base_prompt = self._format_improved_prompt_manual(context, question, improved_prompt_text)
684
+ else:
685
+ # Fall back to manual formatting (for Llama models)
686
+ base_prompt = self._format_improved_prompt_manual(context, question, improved_prompt_text)
687
+
688
+ # Check if prompt is too long and truncate context if needed
689
+ max_context_tokens = 1200 # Leave room for generation
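+ # 1200 input tokens keeps the full prompt under the 1500-token cap applied when
+ # generate_answer tokenizes it, leaving headroom in the model's context window
+ # for the generated answer.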
690
+ try:
691
+ tokenized = self.tokenizer(base_prompt, return_tensors="pt")
692
+ current_tokens = tokenized['input_ids'].shape[1]
693
+ except Exception as e:
694
+ logger.warning(f"Tokenization error for improved prompt, using base prompt as-is: {e}")
695
+ return base_prompt, improved_prompt_text
696
+
697
+ if current_tokens > max_context_tokens:
698
+ # Truncate context to fit within limits
699
+ try:
700
+ context_tokens = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]
701
+ available_tokens = max_context_tokens - (current_tokens - context_tokens)
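+ # (current_tokens - context_tokens) approximates the fixed prompt overhead
+ # (system text, question, template tokens); the remainder of the budget is
+ # what can still be spent on retrieved context.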
702
+
703
+ if available_tokens > 0:
704
+ # Truncate context to fit
705
+ truncated_context = self.tokenizer.decode(
706
+ self.tokenizer(context, return_tensors="pt", truncation=True, max_length=available_tokens)['input_ids'][0],
707
+ skip_special_tokens=True
708
+ )
709
+
710
+ # Reformat with truncated context
711
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
712
+ try:
713
+ messages = [
714
+ {"role": "system", "content": improved_prompt_text},
715
+ {"role": "user", "content": f"Context: {truncated_context}\n\nQuestion: {question}"}
716
+ ]
717
+ prompt = self.tokenizer.apply_chat_template(
718
+ messages,
719
+ tokenize=False,
720
+ add_generation_prompt=True
721
+ )
722
+ except Exception:
723
+ prompt = self._format_improved_prompt_manual(truncated_context, question, improved_prompt_text)
724
+ else:
725
+ prompt = self._format_improved_prompt_manual(truncated_context, question, improved_prompt_text)
726
+ else:
727
+ # If even basic prompt is too long, use minimal format
728
+ prompt = self._format_improved_prompt_manual(context[:500] + "...", question, improved_prompt_text)
729
+ except Exception as e:
730
+ logger.warning(f"Error truncating context for improved prompt: {e}, using base prompt")
731
+ prompt = base_prompt
732
+ else:
733
+ prompt = base_prompt
734
+
735
+ return prompt, improved_prompt_text
736
+
737
+ def _format_improved_prompt_manual(self, context: str, question: str, improved_prompt_text: str) -> str:
738
+ """Manual prompt formatting for improved prompts (for models without chat templates)"""
739
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
740
+
741
+ {improved_prompt_text}<|eot_id|><|start_header_id|>user<|end_header_id|>
742
+
743
+ Context: {context}
744
+
745
+ Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
746
+
747
+ """
748
+
749
+ def generate_answer(self, prompt: str, **gen_kwargs) -> str:
750
+ """Generate answer using the language model"""
751
+ try:
752
+ if self.args.verbose:
753
+ logger.info(f"Full prompt (first 500 chars): {prompt[:500]}...")
754
+
755
+ # Tokenize input with more conservative limit to leave room for generation
756
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1500)
757
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
758
+
759
+ if self.args.verbose:
760
+ logger.info(f"Input tokens: {inputs['input_ids'].shape}")
761
+
762
+ # Generate
763
+ with torch.no_grad():
764
+ outputs = self.model.generate(
765
+ **inputs,
766
+ max_new_tokens=gen_kwargs.get('max_new_tokens', 512),
767
+ temperature=gen_kwargs.get('temperature', 0.7),
768
+ top_p=gen_kwargs.get('top_p', 0.95),
769
+ repetition_penalty=gen_kwargs.get('repetition_penalty', 1.05),
770
+ do_sample=True,
771
+ pad_token_id=self.tokenizer.eos_token_id,
772
+ eos_token_id=self.tokenizer.eos_token_id,
773
+ use_cache=True,
774
+ num_beams=1
775
+ )
776
+
777
+ # Decode response without skipping special tokens to preserve full length
778
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
779
+
780
+ if self.args.verbose:
781
+ logger.info(f"Full response (first 1000 chars): {response[:1000]}...")
782
+ logger.info(f"Looking for 'Answer:' in response: {'Answer:' in response}")
783
+ if "Answer:" in response:
784
+ answer_part = response.split("Answer:")[-1]
785
+ logger.info(f"Answer part (first 200 chars): {answer_part[:200]}...")
786
+
787
+ # Debug: Show the full response to understand the structure
788
+ logger.info(f"Full response length: {len(response)}")
789
+ logger.info(f"Prompt length: {len(prompt)}")
790
+ logger.info(f"Response after prompt (first 500 chars): {response[len(prompt):][:500]}...")
791
+
792
+ # Extract the answer more robustly by looking for the end of the prompt
793
+ # Find the actual end of the prompt in the response
794
+ prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>\n\n"
795
+ if prompt_end_marker in response:
796
+ answer = response.split(prompt_end_marker)[-1].strip()
797
+ else:
798
+ # Fallback to character-based extraction
799
+ answer = response[len(prompt):].strip()
800
+
801
+ if self.args.verbose:
802
+ logger.info(f"Full LLM output (first 200 chars): {answer[:200]}...")
803
+ logger.info(f"Full LLM output length: {len(answer)} characters")
804
+ logger.info(f"Full LLM output (last 200 chars): ...{answer[-200:]}")
805
+
806
+ # Only do minimal cleanup to preserve the full response
807
+ # Remove special tokens that might interfere with display, but preserve content
808
+ if "<|start_header_id|>" in answer:
809
+ # Only remove if it's at the very end
810
+ if answer.endswith("<|start_header_id|>"):
811
+ answer = answer[:-len("<|start_header_id|>")].strip()
812
+ if "<|eot_id|>" in answer:
813
+ # Only remove if it's at the very end
814
+ if answer.endswith("<|eot_id|>"):
815
+ answer = answer[:-len("<|eot_id|>")].strip()
816
+ if "<|end_of_text|>" in answer:
817
+ # Only remove if it's at the very end
818
+ if answer.endswith("<|end_of_text|>"):
819
+ answer = answer[:-len("<|end_of_text|>")].strip()
820
+
821
+ # Final validation - only reject if completely empty
822
+ if not answer or len(answer) < 3:
823
+ answer = "I don't know."
824
+
825
+ if self.args.verbose:
826
+ logger.info(f"Final answer: '{answer}'")
827
+
828
+ return answer
829
+
830
+ except Exception as e:
831
+ logger.error(f"Generation error: {e}")
832
+ return "I encountered an error while generating the answer."
833
+
834
+ def process_questions(self, questions_path: str, **kwargs) -> List[Tuple[str, str, str, str, float, str, float, str, float, str, str]]:
835
+ """Process all questions and generate answers with multiple readability levels
836
+
837
+ Returns:
838
+ List of tuples: (question, answer, sources, question_group, original_flesch,
839
+ middle_school_answer, middle_school_flesch,
840
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores)
841
+ """
842
+ logger.info(f"Processing questions from {questions_path}")
843
+
844
+ # Load questions
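+ # (the questions file is expected to hold one plain-text question per line;
+ # blank lines are skipped by the filter below)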
845
+ try:
846
+ with open(questions_path, 'r', encoding='utf-8') as f:
847
+ questions = [line.strip() for line in f if line.strip()]
848
+ except Exception as e:
849
+ logger.error(f"Failed to load questions: {e}")
850
+ sys.exit(1)
851
+
852
+ logger.info(f"Found {len(questions)} questions to process")
853
+
854
+ qa_pairs = []
855
+
856
+ # Get the improved prompt text for CSV header by calling format_improved_prompt with empty chunks
857
+ # This will give us the prompt text without actually generating
858
+ _, improved_prompt_text = self.format_improved_prompt([], "")
859
+
860
+ # Initialize CSV file with headers
861
+ self.write_csv([], kwargs.get('output_file', 'results.csv'), append=False, improved_prompt_text=improved_prompt_text)
862
+
863
+ # Process each question
864
+ for i, question in enumerate(tqdm(questions, desc="Processing questions")):
865
+ logger.info(f"Question {i+1}/{len(questions)}: {question[:50]}...")
866
+
867
+ try:
868
+ # Categorize question
869
+ question_group = self._categorize_question(question)
870
+
871
+ # Retrieve relevant chunks with similarity scores
872
+ context_chunks, similarity_scores = self.retrieve_with_scores(question, self.args.k)
873
+
874
+ # Format similarity scores as a string (comma-separated, 3 decimal places)
875
+ similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores]) if similarity_scores else "0.000"
876
+
877
+ if not context_chunks:
878
+ answer = "I don't know."
879
+ sources = "No sources found"
880
+ middle_school_answer = "I don't know."
881
+ high_school_answer = "I don't know."
882
+ improved_answer = "I don't know."
883
+ original_flesch = 0.0
884
+ middle_school_flesch = 0.0
885
+ high_school_flesch = 0.0
886
+ similarity_scores_str = "0.000"
887
+ else:
888
+ # Format original prompt
889
+ prompt = self.format_prompt(context_chunks, question)
890
+
891
+ # Generate original answer
892
+ start_time = time.time()
893
+ answer = self.generate_answer(prompt, **kwargs)
894
+ gen_time = time.time() - start_time
895
+
896
+ # Generate improved answer
897
+ improved_prompt, _ = self.format_improved_prompt(context_chunks, question)
898
+ improved_start = time.time()
899
+ improved_answer = self.generate_answer(improved_prompt, **kwargs)
900
+ improved_time = time.time() - improved_start
901
+
902
+ # Clean up improved answer - remove unwanted phrases and formatting
903
+ improved_answer = self._clean_improved_answer(improved_answer)
904
+ logger.info(f"Improved answer generated in {improved_time:.2f}s")
905
+
906
+ # Extract source documents
907
+ sources = self._extract_sources(context_chunks)
908
+
909
+ # Calculate original answer Flesch score
910
+ try:
911
+ original_flesch = textstat.flesch_kincaid_grade(answer)
912
+ except Exception:
913
+ original_flesch = 0.0
914
+
915
+ # Generate middle school version
916
+ readability_start = time.time()
917
+ middle_school_answer, middle_school_flesch = self.enhance_readability(answer, "middle_school")
918
+ readability_time = time.time() - readability_start
919
+ logger.info(f"Middle school readability in {readability_time:.2f}s")
920
+
921
+ # Generate high school version
922
+ readability_start = time.time()
923
+ high_school_answer, high_school_flesch = self.enhance_readability(answer, "high_school")
924
+ readability_time = time.time() - readability_start
925
+ logger.info(f"High school readability in {readability_time:.2f}s")
926
+
927
+ logger.info(f"Generated answer in {gen_time:.2f}s")
928
+ logger.info(f"Sources: {sources}")
929
+ logger.info(f"Similarity scores: {similarity_scores_str}")
930
+ logger.info(f"Original Flesch: {original_flesch:.1f}, Middle School: {middle_school_flesch:.1f}, High School: {high_school_flesch:.1f}")
931
+
932
+ qa_pairs.append((question, answer, sources, question_group, original_flesch,
933
+ middle_school_answer, middle_school_flesch,
934
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str))
935
+
936
+ # Write incrementally to CSV after each question
937
+ self.write_csv([(question, answer, sources, question_group, original_flesch,
938
+ middle_school_answer, middle_school_flesch,
939
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str)],
940
+ kwargs.get('output_file', 'results.csv'), append=True, improved_prompt_text=improved_prompt_text)
941
+ logger.info(f"Progress saved: {i+1}/{len(questions)} questions completed")
942
+
943
+ except Exception as e:
944
+ logger.error(f"Error processing question {i+1}: {e}")
945
+ error_answer = "I encountered an error processing this question."
946
+ sources = "Error retrieving sources"
947
+ question_group = self._categorize_question(question)
948
+ original_flesch = 0.0
949
+ middle_school_answer = "I encountered an error processing this question."
950
+ high_school_answer = "I encountered an error processing this question."
951
+ improved_answer = "I encountered an error processing this question."
952
+ middle_school_flesch = 0.0
953
+ high_school_flesch = 0.0
954
+ similarity_scores_str = "0.000"
955
+ qa_pairs.append((question, error_answer, sources, question_group, original_flesch,
956
+ middle_school_answer, middle_school_flesch,
957
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str))
958
+
959
+ # Still write the error to CSV
960
+ self.write_csv([(question, error_answer, sources, question_group, original_flesch,
961
+ middle_school_answer, middle_school_flesch,
962
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str)],
963
+ kwargs.get('output_file', 'results.csv'), append=True, improved_prompt_text=improved_prompt_text)
964
+ logger.info(f"Error saved: {i+1}/{len(questions)} questions completed")
965
+
966
+ return qa_pairs
967
+
968
+ def _clean_readability_answer(self, answer: str, target_level: str) -> str:
969
+ """Clean up readability-enhanced answers to remove unwanted phrases and formatting
970
+
971
+ Args:
972
+ answer: The readability-enhanced answer
973
+ target_level: Either "middle_school" or "high_school"
974
+ """
975
+ cleaned = answer
976
+
977
+ # Remove the "Here's a rewritten version" phrases
978
+ if target_level == "middle_school":
979
+ unwanted_phrases = [
980
+ "Here's a rewritten version of the text at a middle school reading level:",
981
+ "Here's a rewritten version of the text at a middle school reading level",
982
+ "Here is a rewritten version of the text at a middle school reading level:",
983
+ "Here is a rewritten version of the text at a middle school reading level",
984
+ "Here's a rewritten version at a middle school reading level:",
985
+ "Here's a rewritten version at a middle school reading level",
986
+ ]
987
+ elif target_level == "high_school":
988
+ unwanted_phrases = [
989
+ "Here's a rewritten version of the text at a high school reading level",
990
+ "Here's a rewritten version of the text at a high school reading level:",
991
+ "Here is a rewritten version of the text at a high school reading level",
992
+ "Here is a rewritten version of the text at a high school reading level:",
993
+ "Here's a rewritten version at a high school reading level",
994
+ "Here's a rewritten version at a high school reading level:",
995
+ ]
996
+ else:
997
+ unwanted_phrases = []
998
+
999
+ for phrase in unwanted_phrases:
1000
+ if phrase.lower() in cleaned.lower():
1001
+ # Find and remove the phrase (case-insensitive)
1002
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
1003
+ cleaned = pattern.sub("", cleaned).strip()
1004
+ # Remove leading colons, semicolons, or dashes
1005
+ cleaned = re.sub(r'^[:;\-]\s*', '', cleaned).strip()
1006
+
1007
+ # Remove asterisks (but preserve bullet points if they use •)
1008
+ cleaned = re.sub(r'\*\*', '', cleaned) # Remove bold markers
1009
+ cleaned = re.sub(r'\(\*\)', '', cleaned) # Remove (*)
1010
+ cleaned = re.sub(r'\*', '', cleaned) # Remove remaining asterisks
1011
+
1012
+ # Clean up extra whitespace
1013
+ cleaned = ' '.join(cleaned.split())
1014
+
1015
+ return cleaned
1016
+
1017
+ def _clean_improved_answer(self, answer: str) -> str:
1018
+ """Clean up improved answer to remove unwanted phrases and formatting"""
1019
+ # Remove phrases like "Here's a rewritten version" or similar
1020
+ unwanted_phrases = [
1021
+ "Here's a rewritten version",
1022
+ "Here's a version",
1023
+ "Here is a rewritten version",
1024
+ "Here is a version",
1025
+ "Here's the answer",
1026
+ "Here is the answer"
1027
+ ]
1028
+
1029
+ cleaned = answer
1030
+ for phrase in unwanted_phrases:
1031
+ if phrase.lower() in cleaned.lower():
1032
+ # Find and remove the phrase and any following colon/semicolon
1033
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
1034
+ cleaned = pattern.sub("", cleaned).strip()
1035
+ # Remove leading colons, semicolons, or dashes
1036
+ cleaned = re.sub(r'^[:;\-]\s*', '', cleaned).strip()
1037
+
1038
+ # Remove formatting markers like (*) or ** but preserve bullet points
1039
+ cleaned = re.sub(r'\*\*', '', cleaned) # Remove bold markers
1040
+ cleaned = re.sub(r'\(\*\)', '', cleaned) # Remove (*)
1041
+ # Note: Single asterisks are left alone as they might be used for formatting
1042
+ # (stripping them blindly could remove intended emphasis or list markers)
1043
+
1044
+ # Remove "Don't worry" and similar emotional management phrases
1045
+ emotional_phrases = [
1046
+ r"don't worry[^.]*\.\s*",
1047
+ r"Don't worry[^.]*\.\s*",
1048
+ r"the good news is[^.]*\.\s*",
1049
+ r"The good news is[^.]*\.\s*",
1050
+ ]
1051
+ for pattern in emotional_phrases:
1052
+ cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)
1053
+
1054
+ # Clean up extra whitespace
1055
+ cleaned = ' '.join(cleaned.split())
1056
+
1057
+ return cleaned
1058
+
1059
+ def diagnose_system(self, sample_questions: List[str] = None) -> Dict[str, Any]:
1060
+ """Diagnose the document loading, chunking, and retrieval system
1061
+
1062
+ Args:
1063
+ sample_questions: Optional list of questions to test retrieval
1064
+
1065
+ Returns:
1066
+ Dictionary with diagnostic information
1067
+ """
1068
+ diagnostics = {
1069
+ 'vector_db_stats': {},
1070
+ 'document_stats': {},
1071
+ 'chunk_stats': {},
1072
+ 'retrieval_tests': []
1073
+ }
1074
+
1075
+ # Check vector database
1076
+ try:
1077
+ stats = self.vector_retriever.get_collection_stats()
1078
+ diagnostics['vector_db_stats'] = {
1079
+ 'total_chunks': stats.get('total_chunks', 0),
1080
+ 'collection_name': stats.get('collection_name', 'unknown'),
1081
+ 'status': 'OK' if stats.get('total_chunks', 0) > 0 else 'EMPTY'
1082
+ }
1083
+ except Exception as e:
1084
+ diagnostics['vector_db_stats'] = {
1085
+ 'status': 'ERROR',
1086
+ 'error': str(e)
1087
+ }
1088
+
1089
+ # Check document discovery and sample a few files for readability (without indexing the full corpus)
1090
+ try:
1091
+ data_path = Path(self.args.data_dir)
1092
+ if data_path.exists():
1093
+ supported_extensions = {'.txt', '.md', '.json', '.csv'}
1094
+ if PDF_AVAILABLE:
1095
+ supported_extensions.add('.pdf')
1096
+ if DOCX_AVAILABLE:
1097
+ supported_extensions.add('.docx')
1098
+ supported_extensions.add('.doc')
1099
+
1100
+ files = []
1101
+ for ext in supported_extensions:
1102
+ files.extend(data_path.rglob(f"*{ext}"))
1103
+
1104
+ # Sample a few files to check content
1105
+ sample_files = files[:5] if len(files) > 5 else files
1106
+ file_samples = []
1107
+ for file_path in sample_files:
1108
+ try:
1109
+ content = self._read_file(file_path)
1110
+ file_samples.append({
1111
+ 'filename': file_path.name,
1112
+ 'size_chars': len(content),
1113
+ 'size_words': len(content.split()),
1114
+ 'readable': True
1115
+ })
1116
+ except Exception as e:
1117
+ file_samples.append({
1118
+ 'filename': file_path.name,
1119
+ 'readable': False,
1120
+ 'error': str(e)
1121
+ })
1122
+
1123
+ diagnostics['document_stats'] = {
1124
+ 'total_files_found': len(files),
1125
+ 'sample_files': file_samples,
1126
+ 'status': 'OK'
1127
+ }
1128
+ else:
1129
+ diagnostics['document_stats'] = {
1130
+ 'status': 'ERROR',
1131
+ 'error': f'Data directory {self.args.data_dir} does not exist'
1132
+ }
1133
+ except Exception as e:
1134
+ diagnostics['document_stats'] = {
1135
+ 'status': 'ERROR',
1136
+ 'error': str(e)
1137
+ }
1138
+
1139
+ # Test chunking on a sample document
1140
+ try:
1141
+ if diagnostics['document_stats'].get('status') == 'OK':
1142
+ sample_file = None
1143
+ for file_info in diagnostics['document_stats'].get('sample_files', []):
1144
+ if file_info.get('readable', False):
1145
+ # Find the actual file
1146
+ data_path = Path(self.args.data_dir)
1147
+ for ext in ['.txt', '.md', '.pdf', '.docx']:
1148
+ files = list(data_path.rglob(f"*{file_info['filename']}"))
1149
+ if files:
1150
+ sample_file = files[0]
1151
+ break
1152
+ if sample_file:
1153
+ break
1154
+
1155
+ if sample_file:
1156
+ content = self._read_file(sample_file)
1157
+ # Create a dummy document (Document is already imported at top)
1158
+ sample_doc = Document(
1159
+ filename=sample_file.name,
1160
+ content=content,
1161
+ filepath=str(sample_file),
1162
+ file_type=sample_file.suffix.lower(),
1163
+ file_hash=""
1164
+ )
1165
+
1166
+ # Test chunking
1167
+ sample_chunks = self._chunk_text(
1168
+ content,
1169
+ sample_file.name,
1170
+ self.args.chunk_size,
1171
+ self.args.chunk_overlap
1172
+ )
1173
+
1174
+ chunk_lengths = [len(chunk.text.split()) for chunk in sample_chunks]
1175
+
1176
+ diagnostics['chunk_stats'] = {
1177
+ 'sample_document': sample_file.name,
1178
+ 'total_chunks': len(sample_chunks),
1179
+ 'avg_chunk_size_words': sum(chunk_lengths) / len(chunk_lengths) if chunk_lengths else 0,
1180
+ 'min_chunk_size_words': min(chunk_lengths) if chunk_lengths else 0,
1181
+ 'max_chunk_size_words': max(chunk_lengths) if chunk_lengths else 0,
1182
+ 'chunk_size_setting': self.args.chunk_size,
1183
+ 'chunk_overlap_setting': self.args.chunk_overlap,
1184
+ 'status': 'OK'
1185
+ }
1186
+ except Exception as e:
1187
+ diagnostics['chunk_stats'] = {
1188
+ 'status': 'ERROR',
1189
+ 'error': str(e)
1190
+ }
1191
+
1192
+ # Test retrieval with sample questions
1193
+ if sample_questions and diagnostics['vector_db_stats'].get('status') == 'OK':
1194
+ for question in sample_questions:
1195
+ try:
1196
+ context_chunks = self.retrieve(question, self.args.k)
1197
+ sources = self._extract_sources(context_chunks)
1198
+
1199
+ # Get similarity scores
1200
+ results = self.vector_retriever.search(question, self.args.k)
1201
+
1202
+ # Get sample chunk text (first 200 chars of first chunk)
1203
+ sample_chunk_text = context_chunks[0].text[:200] + "..." if context_chunks else "N/A"
1204
+
1205
+ diagnostics['retrieval_tests'].append({
1206
+ 'question': question,
1207
+ 'chunks_retrieved': len(context_chunks),
1208
+ 'sources': sources,
1209
+ 'similarity_scores': [f"{score:.3f}" for _, score in results],
1210
+ 'sample_chunk_preview': sample_chunk_text,
1211
+ 'status': 'OK' if context_chunks else 'NO_RESULTS'
1212
+ })
1213
+ except Exception as e:
1214
+ diagnostics['retrieval_tests'].append({
1215
+ 'question': question,
1216
+ 'status': 'ERROR',
1217
+ 'error': str(e)
1218
+ })
1219
+
1220
+ return diagnostics
1221
+
1222
+ def print_diagnostics(self, diagnostics: Dict[str, Any]) -> None:
1223
+ """Print diagnostic information in a readable format"""
1224
+ print("\n" + "="*80)
1225
+ print("SYSTEM DIAGNOSTICS")
1226
+ print("="*80)
1227
+
1228
+ # Vector DB Stats
1229
+ print("\n📊 VECTOR DATABASE:")
1230
+ vdb = diagnostics.get('vector_db_stats', {})
1231
+ print(f" Status: {vdb.get('status', 'UNKNOWN')}")
1232
+ print(f" Total chunks: {vdb.get('total_chunks', 0)}")
1233
+ print(f" Collection: {vdb.get('collection_name', 'unknown')}")
1234
+ if 'error' in vdb:
1235
+ print(f" Error: {vdb['error']}")
1236
+
1237
+ # Document Stats
1238
+ print("\n📄 DOCUMENT LOADING:")
1239
+ doc_stats = diagnostics.get('document_stats', {})
1240
+ print(f" Status: {doc_stats.get('status', 'UNKNOWN')}")
1241
+ print(f" Total files found: {doc_stats.get('total_files_found', 0)}")
1242
+ if 'sample_files' in doc_stats:
1243
+ print(f" Sample files:")
1244
+ for file_info in doc_stats['sample_files']:
1245
+ if file_info.get('readable', False):
1246
+ print(f" ✓ {file_info['filename']}: {file_info.get('size_chars', 0):,} chars, {file_info.get('size_words', 0):,} words")
1247
+ else:
1248
+ print(f" ✗ {file_info['filename']}: {file_info.get('error', 'unreadable')}")
1249
+ if 'error' in doc_stats:
1250
+ print(f" Error: {doc_stats['error']}")
1251
+
1252
+ # Chunk Stats
1253
+ print("\n✂️ CHUNKING:")
1254
+ chunk_stats = diagnostics.get('chunk_stats', {})
1255
+ print(f" Status: {chunk_stats.get('status', 'UNKNOWN')}")
1256
+ if chunk_stats.get('status') == 'OK':
1257
+ print(f" Sample document: {chunk_stats.get('sample_document', 'N/A')}")
1258
+ print(f" Total chunks from sample: {chunk_stats.get('total_chunks', 0)}")
1259
+ print(f" Average chunk size: {chunk_stats.get('avg_chunk_size_words', 0):.1f} words")
1260
+ print(f" Chunk size range: {chunk_stats.get('min_chunk_size_words', 0)} - {chunk_stats.get('max_chunk_size_words', 0)} words")
1261
+ print(f" Settings: size={chunk_stats.get('chunk_size_setting', 0)}, overlap={chunk_stats.get('chunk_overlap_setting', 0)}")
1262
+ if 'error' in chunk_stats:
1263
+ print(f" Error: {chunk_stats['error']}")
1264
+
1265
+ # Retrieval Tests
1266
+ if diagnostics.get('retrieval_tests'):
1267
+ print("\n🔍 RETRIEVAL TESTS:")
1268
+ for test in diagnostics['retrieval_tests']:
1269
+ print(f"\n Question: {test.get('question', 'N/A')}")
1270
+ print(f" Status: {test.get('status', 'UNKNOWN')}")
1271
+ if test.get('status') == 'OK':
1272
+ print(f" Chunks retrieved: {test.get('chunks_retrieved', 0)}")
1273
+ print(f" Sources: {test.get('sources', 'N/A')}")
1274
+ scores = test.get('similarity_scores', [])
1275
+ if scores:
1276
+ print(f" Similarity scores: {', '.join(scores)}")
1277
+ # Warn if scores are low
1278
+ try:
1279
+ score_values = [float(s) for s in scores]
1280
+ if max(score_values) < 0.3:
1281
+ print(f" ⚠️ WARNING: Low similarity scores - retrieved chunks may not be very relevant")
1282
+ elif max(score_values) < 0.5:
1283
+ print(f" ⚠️ NOTE: Moderate similarity - consider increasing --k or checking chunk quality")
1284
+ except Exception:
1285
+ pass
1286
+ if 'sample_chunk_preview' in test:
1287
+ print(f" Sample chunk preview: {test['sample_chunk_preview']}")
1288
+ elif 'error' in test:
1289
+ print(f" Error: {test['error']}")
1290
+
1291
+ print("\n" + "="*80 + "\n")
1292
+
1293
+ def _extract_sources(self, context_chunks: List[Chunk]) -> str:
1294
+ """Extract source document names from context chunks"""
1295
+ sources = []
1296
+ for chunk in context_chunks:
1297
+ # Debug: Print chunk filename if verbose
1298
+ if self.args.verbose:
1299
+ logger.info(f"Chunk filename: {chunk.filename}")
1300
+
1301
+ # Extract filename from chunk attribute (not metadata)
1302
+ source = chunk.filename if hasattr(chunk, 'filename') and chunk.filename else 'Unknown source'
1303
+ # Clean up the source name
1304
+ if source.endswith('.pdf'):
1305
+ source = source[:-4] # Remove .pdf extension
1306
+ elif source.endswith('.txt'):
1307
+ source = source[:-4] # Remove .txt extension
1308
+ elif source.endswith('.md'):
1309
+ source = source[:-3] # Remove .md extension
1310
+
1311
+ sources.append(source)
1312
+
1313
+ # Remove duplicates while preserving order
1314
+ unique_sources = []
1315
+ for source in sources:
1316
+ if source not in unique_sources:
1317
+ unique_sources.append(source)
1318
+
1319
+ return "; ".join(unique_sources)
1320
+
1321
+ def _categorize_question(self, question: str) -> str:
1322
+ """Categorize a question into one of 5 categories"""
1323
+ question_lower = question.lower()
1324
+
1325
+ # Gene-Specific Recommendations
1326
+ if any(gene in question_lower for gene in ['msh2', 'mlh1', 'msh6', 'pms2', 'epcam', 'brca1', 'brca2']):
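+ # (MSH2/MLH1/MSH6/PMS2/EPCAM are Lynch syndrome genes; BRCA1/BRCA2 are hereditary
+ # breast and ovarian cancer genes - the keyword check below then looks for
+ # screening/management language before assigning this group)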
1327
+ if any(kw in question_lower for kw in ['screening', 'surveillance', 'prevention', 'recommendation', 'risk', 'cancer risk', 'steps', 'management']):
1328
+ return "Gene-Specific Recommendations"
1329
+
1330
+ # Inheritance Patterns
1331
+ if any(kw in question_lower for kw in ['inherit', 'inherited', 'pass', 'skip a generation', 'generation', 'can i pass']):
1332
+ return "Inheritance Patterns"
1333
+
1334
+ # Family Risk Assessment
1335
+ if any(kw in question_lower for kw in ['family member', 'relative', 'first-degree', 'family risk', 'which relative', 'should my family']):
1336
+ return "Family Risk Assessment"
1337
+
1338
+ # Genetic Variant Interpretation
1339
+ if any(kw in question_lower for kw in ['what does', 'genetic variant mean', 'variant mean', 'mutation mean', 'genetic result']):
1340
+ return "Genetic Variant Interpretation"
1341
+
1342
+ # Support and Resources
1343
+ if any(kw in question_lower for kw in ['cope', 'overwhelmed', 'resource', 'genetic counselor', 'support', 'research', 'help', 'insurance', 'gina']):
1344
+ return "Support and Resources"
1345
+
1346
+ # Default to Genetic Variant Interpretation if unclear
1347
+ return "Genetic Variant Interpretation"
1348
+
1349
+ def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
1350
+ """Enhance answer readability to different levels and calculate Flesch-Kincaid Grade Level
1351
+
1352
+ Args:
1353
+ answer: The original answer to simplify or enhance
1354
+ target_level: One of "middle_school", "high_school", "college", or "doctoral"
1355
+
1356
+ Returns:
1357
+ Tuple of (enhanced_answer, grade_level)
1358
+ """
1359
+ try:
1360
+ # Define prompts for different reading levels
1361
+ if target_level == "middle_school":
1362
+ level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
1363
+ instructions = """
1364
+ - Use simpler medical terms or explain them
1365
+ - Medium-length sentences
1366
+ - Clear, structured explanations
1367
+ - Keep important medical information accessible"""
1368
+ elif target_level == "high_school":
1369
+ level_description = "high school reading level (ages 15-18, 9th-12th grade)"
1370
+ instructions = """
1371
+ - Use appropriate medical terminology with context
1372
+ - Varied sentence length
1373
+ - Comprehensive yet accessible explanations
1374
+ - Maintain technical accuracy while ensuring clarity"""
1375
+ elif target_level == "college":
1376
+ level_description = "college reading level (undergraduate level, ages 18-22)"
1377
+ instructions = """
1378
+ - Use standard medical terminology with brief explanations
1379
+ - Professional and clear writing style
1380
+ - Include relevant clinical context
1381
+ - Maintain scientific accuracy and precision
1382
+ - Appropriate for undergraduate students in health sciences"""
1383
+ elif target_level == "doctoral":
1384
+ level_description = "doctoral/professional reading level (graduate level, medical professionals)"
1385
+ instructions = """
1386
+ - Use advanced medical and scientific terminology
1387
+ - Include detailed clinical and research context
1388
+ - Reference specific mechanisms, pathways, and evidence
1389
+ - Provide comprehensive technical explanations
1390
+ - Appropriate for medical professionals, researchers, and graduate students
1391
+ - Include nuanced discussions of clinical implications and research findings"""
1392
+ else:
1393
+ raise ValueError(f"Unknown target_level: {target_level}. Must be one of: middle_school, high_school, college, doctoral")
1394
+
1395
+ # Create a prompt to enhance the medical answer for the target level
1396
+ # Try to use chat template if available, otherwise use manual format
1397
+ system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
1398
+ {instructions}
1399
+ - Keep the same important information but adapt the complexity
1400
+ - Provide context for technical terms
1401
+ - Ensure the answer is informative yet understandable"""
1402
+
1403
+ user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
1404
+
1405
+ # Try to use chat template if available
1406
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
1407
+ try:
1408
+ messages = [
1409
+ {"role": "system", "content": system_message},
1410
+ {"role": "user", "content": user_message}
1411
+ ]
1412
+ readability_prompt = self.tokenizer.apply_chat_template(
1413
+ messages,
1414
+ tokenize=False,
1415
+ add_generation_prompt=True
1416
+ )
1417
+ except Exception as e:
1418
+ logger.warning(f"Failed to use chat template for readability, falling back to manual format: {e}")
1419
+ readability_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1420
+
1421
+ {system_message}
1422
+
1423
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
1424
+
1425
+ {user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1426
+
1427
+ """
1428
+ else:
1429
+ # Fall back to manual formatting (for Llama models)
1430
+ readability_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1431
+
1432
+ {system_message}
1433
+
1434
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
1435
+
1436
+ {user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1437
+
1438
+ """
1439
+
1440
+ # Generate simplified answer
1441
+ inputs = self.tokenizer(readability_prompt, return_tensors="pt", truncation=True, max_length=2048)
1442
+ if self.device == "mps":
1443
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
1444
+
1445
+ # Adjust generation parameters based on target level
1446
+ if target_level in ["college", "doctoral"]:
1447
+ max_tokens = 512 # Reduced from 1024 for faster responses
1448
+ temp = 0.4 # Slightly higher temperature for more natural flow
1449
+ else:
1450
+ max_tokens = 384 # Reduced from 512 for faster responses
1451
+ temp = 0.3 # Lower temperature for more consistent simplification
1452
+
1453
+ with torch.no_grad():
1454
+ outputs = self.model.generate(
1455
+ **inputs,
1456
+ max_new_tokens=max_tokens,
1457
+ temperature=temp,
1458
+ top_p=0.9,
1459
+ repetition_penalty=1.05,
1460
+ do_sample=True,
1461
+ pad_token_id=self.tokenizer.eos_token_id,
1462
+ eos_token_id=self.tokenizer.eos_token_id,
1463
+ use_cache=True,
1464
+ num_beams=1
1465
+ )
1466
+
1467
+ # Decode response
1468
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
1469
+
1470
+ # Extract enhanced answer
1471
+ # Try to find the assistant response marker
1472
+ prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>\n\n"
1473
+ if prompt_end_marker in response:
1474
+ simplified_answer = response.split(prompt_end_marker)[-1].strip()
1475
+ elif "<|assistant|>" in response:
1476
+ # Some chat templates use <|assistant|>
1477
+ simplified_answer = response.split("<|assistant|>")[-1].strip()
1478
+ else:
1479
+ # Fallback: extract everything after the prompt
1480
+ simplified_answer = response[len(readability_prompt):].strip()
1481
+
1482
+ # Clean up special tokens
1483
+ if "<|eot_id|>" in simplified_answer:
1484
+ if simplified_answer.endswith("<|eot_id|>"):
1485
+ simplified_answer = simplified_answer[:-len("<|eot_id|>")].strip()
1486
+ if "<|end_of_text|>" in simplified_answer:
1487
+ if simplified_answer.endswith("<|end_of_text|>"):
1488
+ simplified_answer = simplified_answer[:-len("<|end_of_text|>")].strip()
1489
+
1490
+ # Clean up unwanted phrases and formatting
1491
+ simplified_answer = self._clean_readability_answer(simplified_answer, target_level)
1492
+
1493
+ # Calculate Flesch-Kincaid Grade Level
1494
+ try:
1495
+ grade_level = textstat.flesch_kincaid_grade(simplified_answer)
1496
+ except Exception:
1497
+ grade_level = 0.0
1498
+
1499
+ if self.args.verbose:
1500
+ logger.info(f"Simplified answer length: {len(simplified_answer)} characters")
1501
+ logger.info(f"Flesch-Kincaid Grade Level: {grade_level:.1f}")
1502
+
1503
+ return simplified_answer, grade_level
1504
+
1505
+ except Exception as e:
1506
+ logger.error(f"Error enhancing readability: {e}")
1507
+ # Fallback: return original answer with estimated grade level
1508
+ try:
1509
+ grade_level = textstat.flesch_kincaid_grade(answer)
1510
+ except Exception:
1511
+ grade_level = 12.0 # Default to high school level
1512
+ return answer, grade_level
1513
+
1514
+ def write_csv(self, qa_pairs: List[Tuple[str, str, str, str, float, str, float, str, float, str, str]], output_path: str, append: bool = False, improved_prompt_text: str = "") -> None:
1515
+ """Write Q&A pairs to CSV file in results folder
1516
+
1517
+ Expected tuple format: (question, answer, sources, question_group, original_flesch,
1518
+ middle_school_answer, middle_school_flesch,
1519
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores)
1520
+ """
1521
+ # Ensure results directory exists
1522
+ os.makedirs('results', exist_ok=True)
1523
+
1524
+ # If output_path doesn't already have results/ prefix, add it
1525
+ if not output_path.startswith('results/'):
1526
+ output_path = f'results/{output_path}'
1527
+
1528
+ if append:
1529
+ logger.info(f"Appending results to {output_path}")
1530
+ else:
1531
+ logger.info(f"Writing results to {output_path}")
1532
+
1533
+ # Create output directory if needed
1534
+ output_path = Path(output_path)
1535
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1536
+
1537
+ try:
1538
+ # Check if file exists and if we're appending
1539
+ file_exists = output_path.exists()
1540
+ write_mode = 'a' if append and file_exists else 'w'
1541
+
1542
+ with open(output_path, write_mode, newline='', encoding='utf-8') as f:
1543
+ writer = csv.writer(f)
1544
+
1545
+ # Write header only if creating new file or first append
1546
+ if not append or not file_exists:
1547
+ # Create improved answer header with prompt text
1548
+ improved_header = f'improved_answer (PROMPT: {improved_prompt_text})'
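+ # Embedding the prompt text in the column header makes each results file
+ # self-documenting about the instructions used for the improved answers.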
1549
+ writer.writerow(['question', 'question_group', 'answer', 'original_flesch', 'sources',
1550
+ 'similarity_scores', 'middle_school_answer', 'middle_school_flesch',
1551
+ 'high_school_answer', 'high_school_flesch', improved_header])
1552
+
1553
+ for data in qa_pairs:
1554
+ # Unpack the data tuple
1555
+ (question, answer, sources, question_group, original_flesch,
1556
+ middle_school_answer, middle_school_flesch,
1557
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores) = data
1558
+
1559
+ # Clean and escape the answers for CSV
1560
+ def clean_text(text):
1561
+ # Replace newlines with spaces and clean up formatting
1562
+ cleaned = text.replace('\n', ' ').replace('\r', ' ')
1563
+ # Remove extra whitespace but preserve the full content
1564
+ cleaned = ' '.join(cleaned.split())
1565
+ # Quote characters are left untouched; csv.writer below handles CSV quoting,
1566
+ # so doubling them here would produce double-escaped quotes in the output
1567
+ return cleaned
1568
+
1569
+ clean_question = clean_text(question)
1570
+ clean_answer = clean_text(answer)
1571
+ clean_sources = clean_text(sources)
1572
+ clean_middle_school = clean_text(middle_school_answer)
1573
+ clean_high_school = clean_text(high_school_answer)
1574
+ clean_improved = clean_text(improved_answer)
1575
+
1576
+ # Log the full answer length for debugging
1577
+ if self.args.verbose:
1578
+ logger.info(f"Writing answer length: {len(clean_answer)} characters")
1579
+ logger.info(f"Middle school answer length: {len(clean_middle_school)} characters")
1580
+ logger.info(f"High school answer length: {len(clean_high_school)} characters")
1581
+ logger.info(f"Improved answer length: {len(clean_improved)} characters")
1582
+ logger.info(f"Question group: {question_group}")
1583
+
1584
+ # Use proper CSV quoting - let csv.writer handle the quoting
1585
+ writer.writerow([
1586
+ clean_question,
1587
+ question_group,
1588
+ clean_answer,
1589
+ f"{original_flesch:.1f}",
1590
+ clean_sources,
1591
+ similarity_scores, # Similarity scores (comma-separated)
1592
+ clean_middle_school,
1593
+ f"{middle_school_flesch:.1f}",
1594
+ clean_high_school,
1595
+ f"{high_school_flesch:.1f}",
1596
+ clean_improved
1597
+ ])
1598
+
1599
+ if append:
1600
+ logger.info(f"Appended {len(qa_pairs)} Q&A pairs to {output_path}")
1601
+ else:
1602
+ logger.info(f"Successfully wrote {len(qa_pairs)} Q&A pairs to {output_path}")
1603
+
1604
+ except Exception as e:
1605
+ logger.error(f"Failed to write CSV: {e}")
1606
+ sys.exit(4)
1607
+
1608
+
1609
+ def parse_args():
1610
+ """Parse command line arguments"""
1611
+ parser = argparse.ArgumentParser(description="RAG Chatbot for CGT-LLM-Beta with Vector Database")
1612
+
1613
+ # File paths
1614
+ parser.add_argument('--data-dir', default='./Data Resources',
1615
+ help='Directory containing documents to index')
1616
+ parser.add_argument('--questions', default='./questions.txt',
1617
+ help='File containing questions (one per line)')
1618
+ parser.add_argument('--out', default='./answers.csv',
1619
+ help='Output CSV file for answers')
1620
+ parser.add_argument('--vector-db-dir', default='./chroma_db',
1621
+ help='Directory for ChromaDB persistence')
1622
+
1623
+ # Retrieval parameters
1624
+ parser.add_argument('--k', type=int, default=5,
1625
+ help='Number of chunks to retrieve per question')
1626
+
1627
+ # Chunking parameters
1628
+ parser.add_argument('--chunk-size', type=int, default=500,
1629
+ help='Size of text chunks in tokens')
1630
+ parser.add_argument('--chunk-overlap', type=int, default=200,
1631
+ help='Overlap between chunks in tokens')
1632
+
1633
+ # Model selection
1634
+ parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
1635
+ help='HuggingFace model name to use (e.g., meta-llama/Llama-3.2-3B-Instruct, mistralai/Mistral-7B-Instruct-v0.2)')
1636
+
1637
+ # Generation parameters
1638
+ parser.add_argument('--max-new-tokens', type=int, default=1024,
1639
+ help='Maximum new tokens to generate')
1640
+ parser.add_argument('--temperature', type=float, default=0.2,
1641
+ help='Generation temperature')
1642
+ parser.add_argument('--top-p', type=float, default=0.9,
1643
+ help='Top-p sampling parameter')
1644
+ parser.add_argument('--repetition-penalty', type=float, default=1.1,
1645
+ help='Repetition penalty')
1646
+
1647
+ # Database options
1648
+ parser.add_argument('--force-rebuild', action='store_true',
1649
+ help='Force rebuild of vector database')
1650
+ parser.add_argument('--skip-indexing', action='store_true',
1651
+ help='Skip document indexing, use existing database')
1652
+
1653
+ # Other options
1654
+ parser.add_argument('--seed', type=int, default=42,
1655
+ help='Random seed for reproducibility')
1656
+ parser.add_argument('--verbose', action='store_true',
1657
+ help='Enable verbose logging')
1658
+ parser.add_argument('--dry-run', action='store_true',
1659
+ help='Build index and test retrieval without generation')
1660
+ parser.add_argument('--diagnose', action='store_true',
1661
+ help='Run system diagnostics and exit')
1662
+
1663
+ return parser.parse_args()
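+
+ # Illustrative invocations (paths and model name match the defaults above; adjust to your setup):
+ #   python app.py --data-dir "./Data Resources" --questions ./questions.txt --out answers.csv \
+ #       --model meta-llama/Llama-3.2-3B-Instruct --k 5 --verbose
+ #   python app.py --skip-indexing --diagnose     # check the existing vector DB and retrieval
+ #   python app.py --dry-run                      # build/update the index without generating answers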
1664
+
1665
+
1666
+ def main():
1667
+ """Main function"""
1668
+ args = parse_args()
1669
+
1670
+ # Set random seed
1671
+ torch.manual_seed(args.seed)
1672
+ np.random.seed(args.seed)
1673
+
1674
+ # Set logging level
1675
+ if args.verbose:
1676
+ logging.getLogger().setLevel(logging.DEBUG)
1677
+
1678
+ logger.info("Starting RAG Chatbot with Vector Database")
1679
+ logger.info(f"Arguments: {vars(args)}")
1680
+
1681
+ try:
1682
+ # Initialize bot
1683
+ bot = RAGBot(args)
1684
+
1685
+ # Check if we should skip indexing
1686
+ if not args.skip_indexing:
1687
+ # Load and process documents
1688
+ documents = bot.load_corpus(args.data_dir)
1689
+ if not documents:
1690
+ logger.error("No documents found to process")
1691
+ sys.exit(3)
1692
+
1693
+ # Chunk documents
1694
+ chunks = bot.chunk_documents(documents, args.chunk_size, args.chunk_overlap)
1695
+ if not chunks:
1696
+ logger.error("No chunks created from documents")
1697
+ sys.exit(3)
1698
+
1699
+ # Build or update index
1700
+ bot.build_or_update_index(chunks, args.force_rebuild)
1701
+ else:
1702
+ logger.info("Skipping document indexing, using existing vector database")
1703
+
1704
+ # Run diagnostics if requested
1705
+ if args.diagnose:
1706
+ sample_questions = [
1707
+ "What is Lynch Syndrome?",
1708
+ "What does a BRCA1 genetic variant mean?",
1709
+ "What screening tests are recommended for MSH2 carriers?"
1710
+ ]
1711
+ diagnostics = bot.diagnose_system(sample_questions=sample_questions)
1712
+ bot.print_diagnostics(diagnostics)
1713
+ return
1714
+
1715
+ if args.dry_run:
1716
+ logger.info("Dry run completed successfully")
1717
+ return
1718
+
1719
+ # Process questions
1720
+ generation_kwargs = {
1721
+ 'max_new_tokens': args.max_new_tokens,
1722
+ 'temperature': args.temperature,
1723
+ 'top_p': args.top_p,
1724
+ 'repetition_penalty': args.repetition_penalty
1725
+ }
1726
+
1727
+ qa_pairs = bot.process_questions(args.questions, output_file=args.out, **generation_kwargs)
1728
+
1729
+ logger.info("RAG Chatbot completed successfully")
1730
+
1731
+ except KeyboardInterrupt:
1732
+ logger.info("Interrupted by user")
1733
+ sys.exit(0)
1734
+ except Exception as e:
1735
+ logger.error(f"Unexpected error: {e}")
1736
+ if args.verbose:
1737
+ import traceback
1738
+ traceback.print_exc()
1739
+ sys.exit(1)
1740
+
1741
+
1742
+ if __name__ == "__main__":
1743
+ main()
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80fe29380be0f587de8c3d0df3bbd891219ebe35d3ab4e007721d322ca704b9f
3
+ size 18888520
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56091853c1c20a1ec97ba4a7935cb7ab95f58b91d1ca56b990bf768f7bd2df88
3
+ size 100
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:754f12ddf66368443039e44c7d3625dbfa54c42604f231054e5c8ab8df162ebb
3
+ size 548379
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e72c9f5fb80c8fa3f488f68172cf32cdaf226d94cb6cff09ff68990b34fbb04c
3
+ size 45080
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0046b8333ff42649a27896a5da1f0fd89ee54954221fde9172dfe284d94262b
3
+ size 99820
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70340ab0d0dddb6b5bcf29c0e09f316b0f695f6645be0231302346d5af463700
3
+ size 294584320
requirements.txt ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # RAG Chatbot with Vector Database - Requirements
3
+ # =============================================================================
4
+ # Production-ready dependencies for medical document analysis and Q&A
5
+
6
+ # Core ML/AI Framework
7
+ torch>=2.0.0 # PyTorch for model inference
8
+ transformers>=4.30.0 # Hugging Face transformers
9
+ accelerate>=0.20.0 # Model loading optimization
10
+ safetensors>=0.3.0 # Safe model loading
11
+
12
+ # Vector Database & Embeddings
13
+ chromadb>=0.4.0 # Vector database for fast retrieval
14
+ sentence-transformers>=2.2.0 # Semantic embeddings (all-MiniLM-L6-v2)
15
+
16
+ # Data Processing
17
+ pandas>=1.3.0 # Data manipulation and CSV handling
18
+ numpy>=1.21.0 # Numerical computing
19
+ scikit-learn>=1.0.0 # ML utilities and TF-IDF
20
+
21
+ # Text Analysis & Readability
22
+ textstat>=0.7.0 # Flesch-Kincaid Grade Level calculation
23
+ nltk>=3.8.0 # Natural language processing utilities
24
+
25
+ # Document Processing (Core)
26
+ pypdf>=3.0.0 # PDF document parsing
27
+ python-docx>=0.8.11 # DOCX document parsing
28
+
29
+ # Optional Document Processing
30
+ rank-bm25>=0.2.2 # BM25 retrieval algorithm (alternative to TF-IDF)
31
+
32
+ # Utilities & Progress
33
+ tqdm>=4.65.0 # Progress bars
34
+ pathlib2>=2.3.0 # Backport of pathlib for legacy Python; not required on Python 3
35
+
36
+ # Web Interface
37
+ gradio>=4.0.0 # Gradio web interface for chatbot
38
+
39
+ # Development & Testing (Optional)
40
+ pytest>=7.0.0 # Testing framework
41
+ black>=22.0.0 # Code formatting
42
+ flake8>=4.0.0 # Code linting
43
+
44
+ # Performance Monitoring (Optional)
45
+ psutil>=5.8.0 # System resource monitoring
46
+ memory-profiler>=0.60.0 # Memory usage profiling
47
+
48
+ # =============================================================================
49
+ # Installation Notes:
50
+ # =============================================================================
51
+ # 1. Install with: pip install -r requirements.txt
52
+ # 2. For Apple Silicon: PyTorch will automatically use MPS acceleration
53
+ # 3. Optional packages can be installed separately if needed
54
+ # 4. Model files (~6GB) will be downloaded on first run
55
+ # =============================================================================