Kenko / app.py
IniNLP247's picture
Update app.py
1bb45dc verified
raw
history blame
23.9 kB
# Inference app: NLP chat + CV emotion detection + TTS + memory management
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from deepface import DeepFace
import time
from transformers import AutoProcessor, DiaForConditionalGeneration
from sentence_transformers import SentenceTransformer
import numpy as np
import chromadb
from langchain_community.vectorstores import Chroma
from collections import defaultdict
from sklearn.cluster import DBSCAN
# Hugging Face Hub id of the causal language model used for chat replies.
model_name = "IniNLP247/Kenko-mental-health-llama-3-model"

print("Loading Kenko Mental Health Model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Shared text-generation pipeline; it returns only the newly generated text
# (return_full_text=False), so callers do not need to strip the prompt.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=512,
    # temperature / top_p are sampling parameters: they are only honoured when
    # sampling is enabled, so enable it explicitly instead of relying on the
    # checkpoint's generation config.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.pad_token_id,
)
print("Model loaded successfully!")
print("Loading Dia TTS Model...")
tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only a sensible default on GPU; use float32 when the model
# falls back to CPU, where float16 kernels may be slow or unavailable.
tts_dtype = torch.float16 if tts_device.startswith("cuda") else torch.float32
# Keep the checkpoint id in its own name so `tts_model` only ever refers to the
# loaded model object.
tts_model_id = "nari-labs/Dia-1.6B-0626"
tts_processor = AutoProcessor.from_pretrained(tts_model_id)
tts_model = DiaForConditionalGeneration.from_pretrained(
    tts_model_id,
    torch_dtype=tts_dtype,
).to(tts_device)
print("Dia TTS Model loaded successfully!")
print("Initializing Memory Components...")
chroma_client = chromadb.Client()
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")


def embed_function(texts):
    """Embed one string or a list of strings.

    Args:
        texts: A single string or a list of strings.

    Returns:
        A list with one embedding (list of floats) per input text.
    """
    if isinstance(texts, str):
        texts = [texts]
    return embedding_model.encode(texts).tolist()


class _SentenceTransformerEmbeddings:
    """Adapter exposing ``embed_function`` through LangChain's Embeddings API.

    The LangChain Chroma wrapper calls ``embed_documents`` / ``embed_query`` on
    the object it is given; a bare function has neither attribute, which makes
    every ``add_texts`` call fail. The instance stays callable so it can also
    be used like the plain function.
    """

    def embed_documents(self, texts):
        """Return one embedding per document in *texts*."""
        return embed_function(list(texts))

    def embed_query(self, text):
        """Return the embedding of a single query string."""
        return embed_function(text)[0]

    def __call__(self, texts):
        return embed_function(texts)


global_vector_store = Chroma(
    client=chroma_client,
    embedding_function=_SentenceTransformerEmbeddings(),
)
print("Memory components initialized!")
# Latest emotion reading shared across the module. analyze_emotion() rebinds
# this name to a fresh dict on every successful analysis; the status and
# prompt-context helpers only read it.
# "last_update" is None until a reading has been stored, which readers treat
# as "no data yet"; afterwards it holds a time.time() timestamp.
current_emotion_state = {
    "dominant": "neutral",
    "confidence": 0.0,
    "all_emotions": {},
    "last_update": None
}
class AdvancedMemorySystem:
    """
    Multi-tier memory system inspired by human memory:
    - Working Memory: Current conversation (high priority)
    - Short-term Memory: Recent session with decay
    - Long-term Memory: Semantic clusters of important themes
    - Emotional Memory: Affective associations and patterns

    Working-memory items are dicts with keys ``user``, ``bot``, ``emotion`` and
    ``timestamp``. Short-term and long-term items are dicts with keys ``text``,
    ``embedding``, ``importance``, ``timestamp`` and ``emotion``.
    """

    def __init__(self, embedding_model, vector_store):
        """Store the collaborators and start with empty memory tiers.

        Args:
            embedding_model: Object whose ``encode(text)`` returns a vector.
            vector_store: Object exposing ``add_texts(texts=..., metadatas=...)``.
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.reset()

    @staticmethod
    def _cosine_similarity(vec_a, vec_b):
        """Return the cosine similarity of two vectors (0.0 if either is all zeros)."""
        denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        if denominator == 0:
            # Avoid a division by zero (NaN scores) for degenerate embeddings.
            return 0.0
        return float(np.dot(vec_a, vec_b) / denominator)

    def _find_best_cluster(self, embedding):
        """Return ``(cluster_id, similarity)`` of the cluster whose centroid is closest.

        Returns ``(None, -1)`` when no cluster scores above -1.
        """
        best_cluster_id = None
        best_similarity = -1
        for cluster_id, cluster_memories in self.semantic_clusters.items():
            if not cluster_memories:
                continue
            centroid = np.mean([m['embedding'] for m in cluster_memories], axis=0)
            similarity = self._cosine_similarity(embedding, centroid)
            if similarity > best_similarity:
                best_similarity = similarity
                best_cluster_id = cluster_id
        return best_cluster_id, best_similarity

    def _remove_from_short_term(self, memory):
        """Drop *memory* from short-term memory, matching by identity.

        ``list.remove`` / ``in`` compare dicts by value, which can raise
        ``ValueError`` when two items share a text and their NumPy embeddings
        get compared element-wise; identity matching avoids that.
        """
        self.short_term_memory = [
            item for item in self.short_term_memory if item is not memory
        ]

    def calculate_importance(self, text, emotion, user_engagement):
        """Calculate memory importance using multiple factors.

        Starts at 0.5 and adds 0.3 for a high-intensity emotion, 0.2 for a
        message longer than 30 words and 0.3 when a therapeutic keyword occurs
        (substring match). The result is capped at 1.0.

        Args:
            text: The user's message.
            emotion: Emotion label attached to the message.
            user_engagement: Accepted for interface compatibility; currently unused.

        Returns:
            Importance score, at most 1.0.
        """
        importance = 0.5
        high_intensity_emotions = ["fear", "angry", "sad", "surprise"]
        if emotion in high_intensity_emotions:
            importance += 0.3
        if len(text.split()) > 30:
            importance += 0.2
        therapeutic_keywords = [
            "trauma", "suicide", "self-harm", "abuse", "panic",
            "breakthrough", "progress", "better", "worse", "relationship"
        ]
        lowered = text.lower()
        if any(kw in lowered for kw in therapeutic_keywords):
            importance += 0.3
        return min(importance, 1.0)

    def add_to_working_memory(self, user_msg, bot_msg, emotion, timestamp):
        """Add to immediate working memory (sliding window).

        Once more than five exchanges are held, the oldest one is moved to
        short-term memory.
        """
        self.working_memory.append({
            "user": user_msg,
            "bot": bot_msg,
            "emotion": emotion,
            "timestamp": timestamp
        })
        if len(self.working_memory) > 5:
            oldest = self.working_memory.pop(0)
            self._consolidate_to_short_term(oldest)

    def _consolidate_to_short_term(self, memory_item):
        """Move from working to short-term memory with importance scoring.

        The exchange is also written to the vector store on a best-effort
        basis: a failure there is logged and does not affect short-term memory.
        """
        text = f"User: {memory_item['user']}\nKenko: {memory_item['bot']}"
        embedding = self.embedding_model.encode(text)
        importance = self.calculate_importance(
            memory_item['user'],
            memory_item['emotion'],
            len(memory_item['user'].split())
        )
        self.short_term_memory.append({
            "text": text,
            "embedding": embedding,
            "importance": importance,
            "timestamp": memory_item['timestamp'],
            "emotion": memory_item['emotion']
        })
        try:
            self.vector_store.add_texts(
                texts=[text],
                metadatas=[{"importance": importance, "timestamp": memory_item['timestamp']}]
            )
        except Exception as e:
            print(f"Vector store error: {e}")

    def apply_temporal_decay(self, current_time):
        """Apply decay to short-term memories over time.

        Each call multiplies every item's importance by
        ``exp(-0.01 * minutes_since_timestamp)``, so repeated calls compound.
        Items that drop below 0.15 are moved to long-term memory.

        Args:
            current_time: Current time on the same scale as the stored
                timestamps (the difference is divided by 60, i.e. treated as
                seconds).
        """
        decay_rate = 0.01
        # Iterate over a snapshot: consolidation removes items from
        # short_term_memory, and mutating the list mid-iteration would skip
        # the element that follows each removed one.
        for memory in list(self.short_term_memory):
            time_elapsed = (current_time - memory['timestamp']) / 60
            decay_factor = np.exp(-decay_rate * time_elapsed)
            memory['importance'] *= decay_factor
            if memory['importance'] < 0.15:
                self._consolidate_to_long_term(memory)

    def _consolidate_to_long_term(self, memory):
        """Cluster similar memories into semantic long-term memory.

        The memory joins the closest cluster when the cosine similarity to its
        centroid exceeds 0.7; otherwise it starts a new cluster. It is removed
        from short-term memory in either case.
        """
        if not self.semantic_clusters:
            self.semantic_clusters[0] = [memory]
        else:
            best_cluster, best_similarity = self._find_best_cluster(memory['embedding'])
            if best_cluster is not None and best_similarity > 0.7:
                self.semantic_clusters[best_cluster].append(memory)
            else:
                new_cluster_id = max(self.semantic_clusters.keys()) + 1
                self.semantic_clusters[new_cluster_id] = [memory]
        self._remove_from_short_term(memory)

    def track_emotional_transition(self, prev_emotion, current_emotion, context):
        """Track emotional state transitions for pattern recognition.

        Every call is logged; the context is additionally stored as a trigger
        for ``current_emotion`` when the emotion actually changed.
        """
        self.emotional_memory["emotion_transitions"].append({
            "from": prev_emotion,
            "to": current_emotion,
            "context": context,
            "timestamp": time.time()
        })
        if prev_emotion != current_emotion:
            self.emotional_memory["trigger_patterns"][current_emotion].append(context)

    def analyze_conversation_themes(self):
        """Identify recurring themes via word frequency over short-term memory.

        Returns:
            Up to five of the most frequent words longer than four characters
            (stopwords excluded), or ``[]`` when fewer than three short-term
            memories exist.
        """
        if len(self.short_term_memory) < 3:
            return []
        all_text = " ".join([m['text'] for m in self.short_term_memory])
        stopwords = {"the", "a", "is", "in", "and", "to", "of", "i", "my", "me", "you", "that", "it"}
        word_freq = defaultdict(int)
        for word in all_text.lower().split():
            if word not in stopwords and len(word) > 4:
                word_freq[word] += 1
        themes = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:5]
        self.conversation_themes = [theme[0] for theme in themes]
        return self.conversation_themes

    def retrieve_contextual_memory(self, query, current_emotion):
        """Advanced retrieval using multiple memory tiers.

        Args:
            query: Text to match memories against.
            current_emotion: Emotion label; short-term memories carrying the
                same label get a 1.2x score boost.

        Returns:
            Dict with keys ``working`` (last three exchanges), ``short_term``
            (top three by similarity x importance), ``long_term`` (first two
            items of the best cluster when its similarity exceeds 0.6),
            ``emotional`` (last two triggers for the emotion) and ``themes``.
        """
        context = {
            "working": [],
            "short_term": [],
            "long_term": [],
            "emotional": [],
            "themes": []
        }
        context["working"] = self.working_memory[-3:]

        # Embed the query once and reuse it for both similarity searches.
        query_embedding = None
        if self.short_term_memory or self.semantic_clusters:
            query_embedding = self.embedding_model.encode(query)

        if self.short_term_memory:
            scored_memories = []
            for memory in self.short_term_memory:
                similarity = self._cosine_similarity(query_embedding, memory['embedding'])
                final_score = similarity * memory['importance']
                if memory['emotion'] == current_emotion:
                    final_score *= 1.2
                scored_memories.append((final_score, memory))
            scored_memories.sort(reverse=True, key=lambda x: x[0])
            context["short_term"] = [m[1] for m in scored_memories[:3]]

        if self.semantic_clusters:
            best_cluster_id, best_cluster_score = self._find_best_cluster(query_embedding)
            if best_cluster_id is not None and best_cluster_score > 0.6:
                cluster = self.semantic_clusters[best_cluster_id]
                context["long_term"] = cluster[:2]

        if current_emotion in self.emotional_memory["trigger_patterns"]:
            triggers = self.emotional_memory["trigger_patterns"][current_emotion]
            context["emotional"] = triggers[-2:]
        context["themes"] = self.analyze_conversation_themes()
        return context

    def update_user_model(self, message, emotion):
        """Build a psychological profile of the user over time.

        Sets the communication style from the message length and records each
        concern category whose keywords occur in the message (substring match).
        ``emotion`` is accepted but not used here.
        """
        word_count = len(message.split())
        if word_count > 50:
            style = "detailed"
        elif word_count < 10:
            style = "concise"
        else:
            style = "moderate"
        self.user_model["communication_style"] = style
        concern_keywords = {
            "anxiety": ["anxious", "worried", "panic", "nervous", "anxiety"],
            "depression": ["sad", "depressed", "hopeless", "empty", "depression"],
            "relationships": ["partner", "relationship", "friend", "family"],
            "work_stress": ["work", "job", "career", "boss", "stress"]
        }
        lowered = message.lower()
        for concern, keywords in concern_keywords.items():
            if any(kw in lowered for kw in keywords):
                if concern not in self.user_model["recurring_concerns"]:
                    self.user_model["recurring_concerns"].append(concern)

    def generate_memory_context_string(self, contextual_memory):
        """Format retrieved memories into prompt context.

        Args:
            contextual_memory: Dict shaped like the return value of
                ``retrieve_contextual_memory``.

        Returns:
            The non-empty sections joined by blank lines ("" if there are none).
        """
        context_parts = []
        if contextual_memory["working"]:
            recent = "\n".join([
                f"User: {m['user']}\nKenko: {m['bot']}"
                for m in contextual_memory["working"]
            ])
            context_parts.append(f"### Recent Conversation:\n{recent}")
        if contextual_memory["short_term"]:
            important = "\n".join([m['text'] for m in contextual_memory["short_term"]])
            context_parts.append(f"### Important Recent Context:\n{important}")
        if contextual_memory["long_term"]:
            longterm = "\n".join([m['text'] for m in contextual_memory["long_term"]])
            context_parts.append(f"### Related Past Discussions:\n{longterm}")
        if contextual_memory["emotional"]:
            emotional = ", ".join(contextual_memory["emotional"][:3])
            context_parts.append(f"### Emotional Pattern: Previously triggered by: {emotional}")
        if contextual_memory["themes"]:
            themes = ", ".join(contextual_memory["themes"])
            context_parts.append(f"### Session Themes: {themes}")
        if self.user_model["recurring_concerns"]:
            concerns = ", ".join(self.user_model["recurring_concerns"])
            context_parts.append(f"### Recurring Concerns: {concerns}")
        return "\n\n".join(context_parts)

    def reset(self):
        """Reset all memory tiers (the embedding model and vector store are kept)."""
        self.working_memory = []
        self.short_term_memory = []
        self.semantic_clusters = defaultdict(list)
        self.emotional_memory = {
            "emotion_transitions": [],
            "trigger_patterns": defaultdict(list),
            "coping_effectiveness": {}
        }
        self.conversation_themes = []
        self.user_model = {
            "communication_style": None,
            "recurring_concerns": [],
            "progress_indicators": [],
            "relational_patterns": []
        }
# Module-wide memory instance built on the shared embedding model and vector store.
print("🔄 Initializing Advanced Memory System...")
advanced_memory = AdvancedMemorySystem(embedding_model, global_vector_store)
print("✅ Advanced Memory System initialized!")
# Baseline emotion label used before any reading has been recorded.
previous_emotion = "neutral"
def update_emotion_status():
    """Return the Markdown status line describing the latest emotion reading.

    Yields a "waiting" message when no reading exists yet, an "outdated"
    message when the reading is more than 60 seconds old, and otherwise the
    dominant emotion with its confidence and age in whole seconds.
    """
    last_seen = current_emotion_state["last_update"]
    if last_seen is None:
        return "*Waiting for emotion data...*"

    age_seconds = time.time() - last_seen
    if age_seconds > 60:
        return "*Emotion data outdated - please ensure webcam is active*"

    label = current_emotion_state["dominant"].capitalize()
    score = current_emotion_state["confidence"]
    return (
        f"**Current Emotion:** {label} ({score:.1f}% confidence)\n"
        f"*Last updated: {int(age_seconds)}s ago*"
    )
def analyze_emotion(image):
    """Run facial emotion analysis on a webcam frame and record the result.

    On success the module-level ``current_emotion_state`` is replaced with the
    new reading (confidence kept on DeepFace's percentage scale, which is what
    the status and prompt helpers format with a ``%`` sign).

    Args:
        image: Frame to analyse, or ``None`` when no frame is available.

    Returns:
        Dict mapping capitalised emotion names to confidences in the 0-1 range
        expected by ``gr.Label``, ordered from most to least likely. Returns
        ``{}`` when there is no frame or the analysis fails.
    """
    global current_emotion_state
    if image is None:
        return {}
    try:
        result = DeepFace.analyze(
            img_path=image,
            actions=['emotion'],
            enforce_detection=False,
            detector_backend='opencv'
        )
        # Depending on the DeepFace version the result is a dict or a list of
        # per-face dicts; use the first face in the latter case.
        face = result[0] if isinstance(result, list) else result
        emotions = face['emotion']
        dominant = face['dominant_emotion']
        current_emotion_state = {
            "dominant": dominant,
            "confidence": emotions[dominant],
            "all_emotions": emotions,
            "last_update": time.time()
        }
        ranked = sorted(emotions.items(), key=lambda x: x[1], reverse=True)
        # DeepFace reports percentages (0-100); the label component renders
        # values as fractions, so rescale and cast to plain Python floats.
        return {
            emotion.capitalize(): float(score) / 100.0
            for emotion, score in ranked
        }
    except Exception as e:
        # Best-effort: a failed frame must not break the webcam stream.
        print(f"Emotion analysis error: {str(e)}")
        return {}
def get_emotion_context():
    """Build the prompt fragment describing the user's detected emotion.

    Returns an empty string when no reading has been stored yet or the latest
    reading is more than 60 seconds old.
    """
    stamp = current_emotion_state["last_update"]
    if stamp is None:
        return ""
    is_stale = time.time() - stamp > 60
    if is_stale:
        return ""

    label = current_emotion_state["dominant"]
    score = current_emotion_state["confidence"]
    return f"\n[User's Current Detected Emotion: {label} ({score:.1f}% confidence)]"
def chat_with_kenko(message, history):
    """Generate Kenko's reply to *message*, given the prior chat history.

    Args:
        message: The user's new message.
        history: Iterable of ``(user_msg, bot_msg)`` pairs from earlier turns.

    Returns:
        The stripped model output, or an apology string containing the error
        text when generation raises.
    """
    # Flatten earlier turns into the transcript format used in the prompt.
    transcript = "".join(
        f"User: {past_user}\nKenko: {past_bot}\n\n"
        for past_user, past_bot in history
    )
    emotion_context = get_emotion_context()
    prompt = (
        "### Instruction:\n"
        "You are Kenko, a compassionate mental health therapist. Provide empathetic, helpful, and professional responses to support the user's mental wellbeing.\n"
        f"{emotion_context}\n"
        f"{transcript}User: {message}\n"
        "### Response:\n"
    )
    try:
        generated = pipe(prompt)[0]['generated_text']
        return generated.strip()
    except Exception as e:
        return f"I'm sorry, I'm having trouble processing your message right now. Error: {str(e)}"
def _decoded_audio_to_numpy(decoded):
    """Return the first decoded clip as a float32 NumPy array.

    Accepts the shapes the decode step may hand back: a dict holding
    ``'audio_values'``, a tensor batch, a list of clips, or anything
    array-like. Tensors are detached and moved to the CPU first, because
    ``np.array`` cannot read a tensor that lives on a GPU.
    """
    if isinstance(decoded, dict) and 'audio_values' in decoded:
        clip = decoded['audio_values'][0]
    elif isinstance(decoded, (torch.Tensor, list)):
        clip = decoded[0]
    else:
        return np.array(decoded).squeeze().astype(np.float32)
    if isinstance(clip, torch.Tensor):
        clip = clip.detach().cpu().numpy()
    return np.array(clip).astype(np.float32)


def generate_tts(text):
    """Synthesise speech for *text* with the Dia TTS model.

    Only the first 600 characters are spoken.

    Args:
        text: Text to convert to speech.

    Returns:
        ``(sample_rate, audio_array)`` with a float32 array, or ``None`` when
        the decoded audio is empty or any step fails (the error is logged).
    """
    try:
        text = text[:600]
        print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")
        inputs = tts_processor(text=text, return_tensors="pt", padding=True)
        inputs = {k: v.to(tts_device) for k, v in inputs.items()}
        print(f"[TTS] Inputs prepared, generating audio codes...")
        with torch.no_grad():
            generated_ids = tts_model.generate(**inputs, max_length=2500)
        print(f"[TTS] Audio codes generated, shape: {generated_ids.shape}")
        print(f"[TTS] Decoding codes to waveform...")
        audio_values = tts_processor.batch_decode(generated_ids, return_tensors="pt")
        audio_arr = _decoded_audio_to_numpy(audio_values)
        sample_rate = 44100
        print(f"[TTS] Audio decoded: {len(audio_arr)} samples at {sample_rate}Hz = {len(audio_arr)/sample_rate:.2f} seconds")
        if len(audio_arr) == 0:
            print("Decoded audio is empty!")
            return None
        return (sample_rate, audio_arr)
    except Exception as e:
        # Best-effort: the chat reply is still shown when speech synthesis fails.
        print(f"TTS generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
# Custom stylesheet handed to gr.Blocks(css=...) below: sets the app-wide font
# family and defines an ".emotion-box" bordered-panel class.
# NOTE(review): no component visible in this file assigns the "emotion-box"
# class — confirm whether it is still needed.
css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.emotion-box {
border: 2px solid #4CAF50;
border-radius: 10px;
padding: 10px;
margin: 10px 0;
}
"""
# Gradio UI: chat + voice output on the left, webcam emotion panel on the right.
with gr.Blocks(
    title="Kenko - Mental Health Assistant",
    theme=gr.themes.Soft(),
    css=css
) as demo:
    gr.Markdown("""
# 💚 Kenko - Your Emotion-Aware Mental Health Assistant
Welcome! I'm Kenko, an AI mental health therapist enhanced with real-time emotion detection.
Allow webcam access to enable emotion-aware responses that adapt to how you're feeling.
*Please remember: I'm an AI assistant and cannot replace professional mental health care. In crisis situations, please contact emergency services or a mental health professional.*
""")
    with gr.Row():
        with gr.Column(scale=2):
            # NOTE(review): avatar_images is given emoji strings here; confirm
            # the installed Gradio version accepts them (it is documented as
            # taking image paths/URLs).
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                bubble_full_width=False,
                avatar_images=("👤", "🧠")
            )
            audio_output = gr.Audio(
                label="Kenko's Voice Response",
                autoplay=True,
                show_label=True
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Share what's on your mind... (press Enter to send)",
                    container=False,
                    scale=7,
                    lines=2,
                    max_lines=4
                )
                send_btn = gr.Button("Send 💬", scale=1, variant="primary")
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1, variant="secondary")
                examples_btn = gr.Button("💡 Example Topics", scale=1, variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown("### 📸 Emotion Detection")
            gr.Markdown("*Your emotional state helps me provide more personalized support*")
            webcam_input = gr.Image(
                sources=["webcam"],
                type="numpy",
                streaming=True,
                label="Live Webcam Feed"
            )
            emotion_output = gr.Label(
                num_top_classes=7,
                label="Detected Emotions"
            )
            emotion_status = gr.Markdown("*Waiting for emotion data...*")
    # Hidden until the "Example Topics" button is pressed.
    with gr.Row(visible=False) as examples_row:
        gr.Examples(
            examples=[
                "I've been feeling really anxious lately and I don't know why.",
                "I'm having trouble sleeping and my mind won't stop racing.",
                "I feel overwhelmed with work and personal responsibilities.",
                "I'm struggling with low self-esteem and negative thoughts.",
                "I'm having difficulty in my relationships.",
                "I feel lonely and isolated.",
                "I'm dealing with grief and loss.",
                "I want to build better coping strategies."
            ],
            inputs=msg,
            label="Try these conversation starters:"
        )
    with gr.Accordion("ℹ️ About Kenko", open=False):
        gr.Markdown("""
**What I can help with:**
- Active listening and emotional support (now emotion-aware!)
- Coping strategies and stress management techniques
- Guidance on anxiety, depression, and mood concerns
- Relationship and communication advice
- Mindfulness and self-care suggestions
- Building healthy habits and routines
**Emotion Detection Feature:**
- Real-time facial emotion analysis
- Adapts responses based on your current emotional state
- Updates automatically every 30 seconds
- Completely optional - works without webcam too
**Important Notes:**
- I'm an AI trained to provide mental health support
- For immediate crisis support, contact emergency services (911) or crisis hotlines
- Consider professional therapy for ongoing mental health needs
- I don't diagnose conditions or prescribe medications
**Privacy:** Your conversations and emotion data are not stored or shared.
""")

    @spaces.GPU
    def respond(message, chat_history):
        """Handle one chat turn: generate the text reply, then its speech.

        Args:
            message: Text from the input box.
            chat_history: List of ``(user, bot)`` tuples shown in the chatbot.

        Returns:
            ``("", chat_history, audio)`` — the cleared input box, the updated
            history and the TTS result (``None`` for blank input or when
            synthesis fails).
        """
        if not message.strip():
            return "", chat_history, None
        start = time.time()
        bot_response = chat_with_kenko(message, chat_history)
        text_time = time.time() - start
        print(f"Text Generation Time: {text_time:.2f} seconds: {len(bot_response)} characters")
        chat_history.append((message, bot_response))
        tts_start = time.time()
        print(f"Generating TTS for: '{bot_response[:100]}...'")
        audio = generate_tts(bot_response)
        tts_time = time.time() - tts_start
        print(f"TTS Generation Time: {tts_time:.2f} seconds")
        print(f"TOTAL TIME: {time.time() - start:.2f}s")
        return "", chat_history, audio

    def toggle_examples():
        """Reveal the example-prompts row (it is never hidden again)."""
        return gr.Row(visible=True)

    submit = msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
    send = send_btn.click(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
    # Two outputs need two return values: an empty history and no audio.
    clear_btn.click(lambda: ([], None), None, outputs=[chatbot, audio_output])
    examples_btn.click(toggle_examples, outputs=examples_row)
    # Analyse one webcam frame per second while the stream is active.
    webcam_input.stream(
        analyze_emotion,
        inputs=webcam_input,
        outputs=emotion_output,
        stream_every=1,
        time_limit=60
    )
    # Refresh the textual emotion status every 5 seconds.
    timer = gr.Timer(value=5)
    timer.tick(
        fn=update_emotion_status,
        outputs=[emotion_status]
    )
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    print("🚀 Starting Kenko Mental Health Assistant with Emotion Detection...")
    # share=True requests a public share link; show_error=True surfaces
    # handler exceptions in the UI.
    demo.launch(
        share=True,
        show_error=True
    )