# VoxDoc / app/models/rag_service.py
# Uploaded by joelthomas77 ("Upload app code"), commit 60d4850 (verified).
"""
RAG (Retrieval-Augmented Generation) Service — Production Healthcare Edition
Indexes past intake sessions as vector embeddings and retrieves
semantically similar cases to enrich SOAP note generation.
Production enhancements:
Phase 1:
- Medical-domain embeddings: PubMedBERT for clinical semantic accuracy
- Similarity threshold: discard low-relevance retrievals
- Clinical-aware chunking: per-SOAP-section indexing with metadata
- Cross-encoder reranking: precision boost on top-k candidates
- Retrieval confidence: scored per result for downstream gating
- PHI-safe indexing: redact PHI before embedding
Phase 3:
- Tenant isolation: org_id / provider_id scoped retrieval
- Encrypted vector store: AES-256-GCM encryption of ChromaDB persistence
- RAG audit trail: HIPAA-compliant logging of every retrieval
- Enhanced PHI verification: double-pass redaction with verification
"""
import hashlib
import json
import logging
import time
from typing import List, Dict, Any, Optional
from pathlib import Path
from app.config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Module-level lazy singletons
# ---------------------------------------------------------------------------
# Each starts as None and is populated on first use by its accessor below;
# nothing in this module resets them afterwards.
_embedding_model = None  # SentenceTransformer, set by _get_embedding_model()
_reranker_model = None  # CrossEncoder, set by _get_reranker_model() when enabled
_chroma_client = None  # chromadb.PersistentClient, set by _get_collection()
_collection = None  # "intake_sessions_v2" collection, set by _get_collection()
# SOAP section types used for clinical-aware chunking. Also used to build the
# per-section chunk IDs ("<session_id>::<section>"), to validate
# section_filter in retrieve_similar_sessions(), and to enumerate chunks in
# remove_session() / get_index_stats().
SOAP_SECTIONS = ("subjective", "objective", "assessment", "plan")
def _get_embedding_model():
    """Return the shared sentence-embedding model, loading it on first use.

    The model name is read from ``settings.rag_embedding_model`` (the module
    docstring describes it as a medical-domain model such as PubMedBERT).
    The loaded instance is cached in the module-level ``_embedding_model``
    and returned as-is on later calls.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model

    from sentence_transformers import SentenceTransformer

    name = settings.rag_embedding_model
    logger.info(f"Loading RAG embedding model: {name}")
    _embedding_model = SentenceTransformer(name)
    logger.info("RAG embedding model loaded")
    return _embedding_model
def _get_reranker_model():
    """Return the shared cross-encoder reranker, or None when reranking is off.

    On first use (and only if ``settings.rag_reranker_enabled`` is true) the
    model named by ``settings.rag_reranker_model`` is loaded and cached in
    ``_reranker_model``. Once cached, it is returned regardless of the flag.
    """
    global _reranker_model
    if _reranker_model is not None:
        return _reranker_model
    if not settings.rag_reranker_enabled:
        return None

    from sentence_transformers import CrossEncoder

    name = settings.rag_reranker_model
    logger.info(f"Loading RAG reranker model: {name}")
    _reranker_model = CrossEncoder(name)
    logger.info("RAG reranker model loaded")
    return _reranker_model
def _get_collection():
    """Return the ChromaDB collection, creating client and collection lazily.

    On first call this:
      * restores the plaintext store from its encrypted archive when
        ``settings.rag_vector_store_encryption_enabled`` is set (Phase 3.3),
      * ensures ``settings.rag_persist_dir`` exists,
      * opens a ``chromadb.PersistentClient`` there and gets/creates the
        ``intake_sessions_v2`` collection using cosine distance.

    Client and collection are cached in module globals for later calls.
    """
    global _chroma_client, _collection
    if _collection is not None:
        return _collection

    import chromadb

    store_path = settings.rag_persist_dir
    if settings.rag_vector_store_encryption_enabled:
        # Embedded ChromaDB reads plain files, so decrypt the archive first.
        _ensure_vector_store_decrypted(store_path)
    Path(store_path).mkdir(parents=True, exist_ok=True)

    _chroma_client = chromadb.PersistentClient(path=store_path)
    _collection = _chroma_client.get_or_create_collection(
        name="intake_sessions_v2",
        metadata={"hnsw:space": "cosine"},
    )
    logger.info(
        f"ChromaDB collection ready at '{store_path}' "
        f"({_collection.count()} documents)"
    )
    return _collection
def _embed(text: str) -> List[float]:
    """Return the L2-normalised embedding of *text* as a list of floats.

    The vector is also handed to the Phase 4.3 drift monitor
    (``app.models.rag_evaluation_service.record_embedding_for_drift``).
    Drift recording is best-effort: if it cannot be imported or raises, the
    failure is logged at DEBUG and the embedding is still returned.
    """
    model = _get_embedding_model()
    vec = model.encode(text, normalize_embeddings=True).tolist()
    try:
        from app.models.rag_evaluation_service import record_embedding_for_drift

        record_embedding_for_drift(vec)
    except Exception as exc:
        # Previously swallowed silently; keep it non-fatal but observable.
        # Lazy %-formatting: this runs once per embedded chunk.
        logger.debug("RAG: embedding drift recording skipped: %s", exc)
    return vec
def _redact_phi_for_embedding(text: str) -> str:
    """
    Return *text* with PHI stripped, for anything entering the vector store.

    Delegates to ``app.compliance.redact_for_vector_store`` (Phase 3.1
    double-pass redaction with verification). If verification still detects
    PHI patterns, a warning with the pattern count and types is logged.

    NOTE(review): the redacted text is returned (and therefore embedded and
    stored by callers) even when verification is not clean; only a warning
    is emitted.
    """
    from app.compliance import redact_for_vector_store

    redacted_text, check = redact_for_vector_store(text)
    if check["is_clean"]:
        return redacted_text
    logger.warning(
        f"RAG PHI verification: {check['phi_count']} patterns "
        f"still detected after redaction: {check['pattern_types']}"
    )
    return redacted_text
def _cosine_similarity_from_distance(distance: float) -> float:
"""ChromaDB cosine distance = 1 - cosine_similarity."""
return max(0.0, 1.0 - distance)
def _build_chunk_id(session_id: str, section: str) -> str:
"""Deterministic chunk ID for a session + section pair."""
return f"{session_id}::{section}"
def _hash_query(text: str) -> str:
"""SHA-256 hash of query text for audit logging (no PHI in logs)."""
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
def _rerank(query: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Re-score *candidates* with the cross-encoder, best first.

    Each candidate dict must carry a ``"document"`` key. On success every
    candidate gains a float ``"rerank_score"`` and the list is sorted in
    place by that score, descending. When no reranker is available, the
    list is empty, or scoring fails (logged as a warning), the list is
    returned in its incoming (cosine) order.
    """
    model = _get_reranker_model()
    if model is None or not candidates:
        return candidates

    try:
        scores = model.predict([(query, cand["document"]) for cand in candidates])
        for idx, score in enumerate(scores):
            candidates[idx]["rerank_score"] = float(score)
        candidates.sort(key=lambda cand: cand["rerank_score"], reverse=True)
    except Exception as exc:
        logger.warning(f"RAG: reranking failed, using cosine order: {exc}")
    return candidates
# ---------------------------------------------------------------------------
# Phase 3.2: Tenant isolation helpers
# ---------------------------------------------------------------------------
def _get_tenant_metadata(
    organization_id: Optional[str] = None,
    provider_id: Optional[str] = None,
) -> Dict[str, str]:
    """Return the tenant metadata attached to every indexed chunk.

    Missing (falsy) IDs fall back to ``settings.default_organization_id`` and
    ``settings.default_provider_id`` respectively.
    """
    org = organization_id or settings.default_organization_id
    provider = provider_id or settings.default_provider_id
    return {"organization_id": org, "provider_id": provider}
def _build_tenant_filter(
    organization_id: Optional[str] = None,
    provider_id: Optional[str] = None,
    extra_filter: Optional[Dict] = None,
) -> Optional[Dict]:
    """
    Build the ChromaDB ``where`` filter that enforces tenant isolation.

    With ``settings.multi_tenancy_enabled`` off, *extra_filter* is returned
    untouched (possibly None). Otherwise the filter always pins
    ``organization_id`` (falling back to the default org), optionally adds
    ``provider_id`` for within-org isolation, and appends *extra_filter*.
    Several clauses are combined with ``$and``; a single clause is returned
    bare, because ChromaDB's ``$and`` expects at least two operands.
    """
    if not settings.multi_tenancy_enabled:
        return extra_filter

    clauses: List[Dict] = [
        {"organization_id": organization_id or settings.default_organization_id}
    ]
    if provider_id:
        clauses.append({"provider_id": provider_id})
    if extra_filter:
        clauses.append(extra_filter)
    return clauses[0] if len(clauses) == 1 else {"$and": clauses}
# ---------------------------------------------------------------------------
# Phase 3.3: Encrypted vector store helpers
# ---------------------------------------------------------------------------
# Suffix string-appended to the persist directory path to name the encrypted
# archive, i.e. "<rag_persist_dir>.encrypted".
# NOTE(review): this is plain string concatenation, so a configured path with
# a trailing "/" places the archive *inside* the directory — confirm the
# setting never ends with a separator.
_ENCRYPTED_MARKER = ".encrypted"
def _ensure_vector_store_decrypted(persist_dir: str) -> None:
    """
    Restore the plaintext vector store from its encrypted archive, if present.

    The archive at ``persist_dir + _ENCRYPTED_MARKER`` holds a gzip tarball
    of the ChromaDB files, encrypted with ``app.encryption.encrypt_bytes``.
    It is decrypted in memory and extracted into *persist_dir* using the
    tarfile "data" extraction filter, which refuses unsafe members.
    Re-encryption is done by ``encrypt_vector_store()``.

    If no archive exists (first run, or store never encrypted) this is a
    no-op. Any failure is logged and re-raised.

    NOTE(review): the archive is not removed after extraction, so it is
    re-extracted over the working directory on every start; if the process
    stopped without re-encrypting, newer plaintext files would be replaced by
    the older archive contents — confirm this is intended.
    """
    archive = Path(persist_dir + _ENCRYPTED_MARKER)
    destination = Path(persist_dir)
    if not archive.exists():
        # No encrypted archive — first run or already decrypted
        return

    try:
        from app.encryption import decrypt_bytes
        import tarfile
        import io

        logger.info("Decrypting vector store...")
        plaintext = decrypt_bytes(archive.read_bytes())
        destination.mkdir(parents=True, exist_ok=True)
        with tarfile.open(fileobj=io.BytesIO(plaintext), mode="r:gz") as bundle:
            bundle.extractall(path=str(destination), filter="data")
        logger.info(f"Vector store decrypted to {persist_dir}")
    except Exception as exc:
        logger.error(f"Failed to decrypt vector store: {exc}")
        raise
def encrypt_vector_store() -> Optional[str]:
    """
    Encrypt the vector store directory to an archive file.

    Every regular file under ``settings.rag_persist_dir`` is packed (in sorted,
    deterministic order) into an in-memory tar.gz. The tarball is encrypted
    with ``app.encryption.encrypt_bytes`` and written to
    ``<rag_persist_dir>.encrypted``.

    The ciphertext is first written and fsynced to a sibling ``.tmp`` file,
    then atomically moved over the archive with ``os.replace``. A crash or
    disk-full error mid-write therefore never truncates the previous
    encrypted copy. The archive and temp file are excluded from the tarball
    in case they live inside the directory (see ``_ENCRYPTED_MARKER``).

    Call this on shutdown or periodically to keep an encrypted at-rest copy.
    The plaintext directory is left in place.

    Returns:
        The path of the encrypted archive, or None if encryption is
        disabled, the directory does not exist, or encryption failed.
        Failures are logged, not raised.
    """
    if not settings.rag_vector_store_encryption_enabled:
        return None
    persist_dir = settings.rag_persist_dir
    target_path = Path(persist_dir)
    encrypted_path = Path(persist_dir + _ENCRYPTED_MARKER)
    if not target_path.exists():
        return None
    tmp_path = encrypted_path.with_name(encrypted_path.name + ".tmp")
    try:
        from app.encryption import encrypt_bytes
        import io
        import os
        import tarfile

        logger.info("Encrypting vector store...")
        # Never archive our own output (possible when persist_dir ends in "/").
        skip = {encrypted_path.resolve(), tmp_path.resolve()}
        tar_buffer = io.BytesIO()
        with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar:
            for file_path in sorted(target_path.rglob("*")):
                if not file_path.is_file() or file_path.resolve() in skip:
                    continue
                arcname = file_path.relative_to(target_path)
                tar.add(str(file_path), arcname=str(arcname))
        tar_data = tar_buffer.getvalue()
        encrypted_data = encrypt_bytes(tar_data)

        # Write-then-rename so the existing archive is replaced atomically.
        with open(tmp_path, "wb") as fh:
            fh.write(encrypted_data)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_path, encrypted_path)

        logger.info(
            f"Vector store encrypted: {len(tar_data)} bytes → "
            f"{len(encrypted_data)} bytes at {encrypted_path}"
        )
        return str(encrypted_path)
    except Exception as exc:
        logger.error(f"Failed to encrypt vector store: {exc}")
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError as cleanup_exc:
            logger.debug(f"Could not remove temporary archive {tmp_path}: {cleanup_exc}")
        return None
# ---------------------------------------------------------------------------
# Phase 3.4: RAG audit trail
# ---------------------------------------------------------------------------
def _record_rag_audit(
    action: str,
    query_hash: str,
    retrieved_session_ids: List[str],
    similarities: List[float],
    organization_id: Optional[str] = None,
    provider_id: Optional[str] = None,
    user_id: Optional[str] = None,
    extra_details: Optional[Dict] = None,
) -> None:
    """
    Record one RAG operation (index / retrieve / delete) in the audit trail.

    The entry is emitted as a ``RAG_AUDIT: <json>`` log line for SIEM ingestion
    and appended to the daily JSONL audit file. Only the query *hash* is
    recorded, never the query text (HIPAA "minimum necessary"). Similarities
    are rounded to 4 decimals. Missing tenant IDs fall back to the configured
    defaults. Does nothing when ``settings.rag_audit_enabled`` is false.
    Failures are logged as warnings and never raised.
    """
    if not settings.rag_audit_enabled:
        return
    try:
        rounded = [round(value, 4) for value in similarities]
        entry = {
            "action": action,
            "query_hash": query_hash,
            "retrieved_session_ids": retrieved_session_ids,
            "similarities": rounded,
            "result_count": len(retrieved_session_ids),
            "organization_id": organization_id or settings.default_organization_id,
            "provider_id": provider_id or settings.default_provider_id,
            "user_id": user_id,
            "timestamp": time.time(),
            "threshold": settings.rag_similarity_threshold,
        }
        if extra_details:
            entry["details"] = extra_details

        serialized = json.dumps(entry)
        logger.info(f"RAG_AUDIT: {serialized}")
        # Persistent copy alongside the structured log line.
        _append_to_audit_file(entry)
    except Exception as exc:
        logger.warning(f"RAG audit logging failed: {exc}")
def _append_to_audit_file(entry: Dict[str, Any]) -> None:
    """
    Append one audit entry as a JSON line to the daily RAG audit file.

    The file is ``<rag_persist_dir>/audit/rag_audit_<YYYY-MM-DD>.jsonl``. It is
    rotated daily by UTC date, and the directory is created if missing.

    Failures are logged at WARNING, not raised. A lost audit record is a
    compliance-relevant event and must be visible in normal logs, but audit
    I/O must never break the RAG call that produced it.
    """
    try:
        from datetime import datetime, timezone

        audit_dir = Path(settings.rag_persist_dir) / "audit"
        audit_dir.mkdir(parents=True, exist_ok=True)
        # datetime.utcnow() is deprecated (3.12+); an aware UTC "now"
        # yields the same calendar date string.
        date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        audit_file = audit_dir / f"rag_audit_{date_str}.jsonl"
        with open(audit_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
    except Exception as exc:
        logger.warning(f"RAG audit file write failed: {exc}")
def get_rag_audit_logs(
    date: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """
    Read RAG audit logs for compliance review.

    Args:
        date: UTC date string (YYYY-MM-DD) selecting the daily log file.
            Defaults to today (UTC). Anything other than the strict ASCII
            ``YYYY-MM-DD`` shape is rejected, which prevents path traversal.
        limit: Maximum number of entries to return. ``None`` returns every
            entry; values <= 0 return an empty list.

    Returns:
        Up to ``limit`` audit entries, most recent first. Malformed JSON lines
        are skipped. Any problem (no audit dir/file, invalid date, I/O error)
        yields an empty list rather than an exception.
    """
    try:
        import re
        from collections import deque
        from datetime import datetime, timezone

        if limit is not None and limit <= 0:
            return []
        audit_dir = Path(settings.rag_persist_dir) / "audit"
        if not audit_dir.exists():
            return []
        date_str = date or datetime.now(timezone.utc).strftime("%Y-%m-%d")
        # [0-9], not \d: \d also matches non-ASCII Unicode digits.
        if not re.fullmatch(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", date_str):
            logger.warning("Invalid audit log date format rejected: %r", date_str)
            return []
        audit_file = audit_dir / f"rag_audit_{date_str}.jsonl"
        # Defence in depth: the resolved file must sit directly inside the
        # audit directory. A string-prefix test would also accept siblings
        # such as ".../audit_old/...".
        if audit_file.resolve().parent != audit_dir.resolve():
            logger.warning("Path traversal attempt in audit log date rejected")
            return []
        if not audit_file.exists():
            return []

        # Retain only the newest `limit` entries instead of the whole day.
        newest = deque(maxlen=limit)
        with open(audit_file, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    newest.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        # File order is oldest-first; return most recent first.
        return list(reversed(newest))
    except Exception as exc:
        logger.warning(f"Failed to read RAG audit logs: {exc}")
        return []
# ---------------------------------------------------------------------------
# Metrics helpers (lazy import to avoid circular dependency)
# ---------------------------------------------------------------------------
def _record_retrieval_metrics(
latency_s: float,
num_results: int,
threshold_met: bool,
similarities: List[float],
):
"""Push RAG-specific metrics to the monitoring system."""
try:
from app.metrics import (
RAG_RETRIEVAL_LATENCY,
RAG_RETRIEVAL_COUNT,
RAG_SIMILARITY_SCORE,
RAG_FALLBACK_COUNT,
RAG_INDEX_SIZE,
)
RAG_RETRIEVAL_LATENCY.observe(latency_s)
RAG_RETRIEVAL_COUNT.inc(threshold_met="true" if threshold_met else "false")
for sim in similarities:
RAG_SIMILARITY_SCORE.observe(sim)
if not threshold_met:
RAG_FALLBACK_COUNT.inc()
try:
collection = _get_collection()
RAG_INDEX_SIZE.set(collection.count())
except Exception:
pass
except ImportError:
pass
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def index_session(
    session_id: str,
    transcript: str,
    chief_complaint: Optional[str] = None,
    soap_subjective: Optional[str] = None,
    soap_objective: Optional[str] = None,
    soap_assessment: Optional[str] = None,
    soap_plan: Optional[str] = None,
    organization_id: Optional[str] = None,
    provider_id: Optional[str] = None,
) -> None:
    """
    Index a single intake session into the vector store.

    A "full" chunk is always written:
      * Embedding: the redacted transcript, plus the chief complaint when one
        is specified.
      * Stored document: the chief complaint and any SOAP sections, or the
        first 500 characters of the transcript when there are none.

    When ``settings.rag_chunking_enabled`` is true, each non-empty SOAP
    section is also stored as its own chunk with ``section_type`` metadata,
    enabling section-specific retrieval.

    All text is PHI-redacted before embedding (Phase 3.1). Each section is
    redacted once and reused for the full document and its section chunk.
    Every chunk carries tenant metadata (Phase 3.2).

    Chunk IDs are deterministic (``<session_id>::<section>``), so re-indexing
    overwrites previous chunks. Section chunks from an earlier indexing that
    are not produced this time (section now empty, or chunking disabled) are
    deleted so stale clinical text cannot be retrieved.

    Indexing is best-effort: failures are logged as warnings, never raised.
    """
    if not settings.rag_enabled:
        return
    try:
        collection = _get_collection()
        # PHI-safe: redact before embedding (Phase 3.1 — verified)
        safe_transcript = _redact_phi_for_embedding(transcript.strip())
        safe_cc = _redact_phi_for_embedding(chief_complaint or "")
        cc_meta = (safe_cc or "")[:200]
        # Phase 3.2: tenant metadata
        tenant_meta = _get_tenant_metadata(organization_id, provider_id)

        soap_map = {
            "subjective": soap_subjective,
            "objective": soap_objective,
            "assessment": soap_assessment,
            "plan": soap_plan,
        }
        # Redact each non-empty section exactly once; redaction is a
        # double pass with verification, so repeating it is wasted work.
        safe_sections = {
            section: _redact_phi_for_embedding(content)
            for section, content in soap_map.items()
            if content
        }

        # --- Full-session chunk (always indexed) ---
        embed_parts = [safe_transcript]
        if safe_cc and safe_cc != "not specified":
            embed_parts.append(f"Chief complaint: {safe_cc}")
        embed_text = " ".join(embed_parts)

        full_doc_parts: List[str] = []
        if safe_cc:
            full_doc_parts.append(f"Chief complaint: {safe_cc}")
        full_doc_parts.extend(
            f"{section.title()}: {safe_content}"
            for section, safe_content in safe_sections.items()
        )
        full_document = (
            "\n".join(full_doc_parts) if full_doc_parts else safe_transcript[:500]
        )

        ids = [_build_chunk_id(session_id, "full")]
        embeddings = [_embed(embed_text)]
        documents = [full_document]
        metadatas = [{
            "session_id": session_id,
            "section_type": "full",
            # Any SOAP section counts, consistent with the section chunks
            # below, which are always tagged has_soap=True.
            "has_soap": bool(safe_sections),
            "chief_complaint": cc_meta,
            **tenant_meta,
        }]

        # --- Per-section chunks (if chunking enabled) ---
        if settings.rag_chunking_enabled:
            for section, safe_content in safe_sections.items():
                section_doc = f"{section.title()}: {safe_content}"
                section_embed_text = (
                    f"Chief complaint: {safe_cc}. {section_doc}" if safe_cc else section_doc
                )
                ids.append(_build_chunk_id(session_id, section))
                embeddings.append(_embed(section_embed_text))
                documents.append(section_doc)
                metadatas.append({
                    "session_id": session_id,
                    "section_type": section,
                    "has_soap": True,
                    "chief_complaint": cc_meta,
                    **tenant_meta,
                })

        collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

        # Remove section chunks left over from a previous indexing of this
        # session; `get` first so only existing IDs are passed to `delete`.
        written = set(ids)
        stale_ids = [
            chunk_id
            for chunk_id in (_build_chunk_id(session_id, s) for s in SOAP_SECTIONS)
            if chunk_id not in written
        ]
        if stale_ids:
            try:
                existing = collection.get(ids=stale_ids, include=[])["ids"]
                if existing:
                    collection.delete(ids=existing)
            except Exception as exc:
                logger.warning(
                    f"RAG: failed to prune stale chunks for session {session_id}: {exc}"
                )

        # Phase 3.4: Audit the indexing operation
        _record_rag_audit(
            action="index",
            query_hash="N/A",
            retrieved_session_ids=[session_id],
            similarities=[],
            organization_id=tenant_meta["organization_id"],
            provider_id=tenant_meta["provider_id"],
            extra_details={"chunks_indexed": len(ids)},
        )
        logger.info(
            f"RAG: indexed session {session_id} "
            f"({len(ids)} chunk{'s' if len(ids) > 1 else ''}, "
            f"org={tenant_meta['organization_id']})"
        )
    except Exception as exc:
        logger.warning(f"RAG: failed to index session {session_id}: {exc}")
def retrieve_similar_sessions(
    transcript: str,
    top_k: Optional[int] = None,
    exclude_id: Optional[str] = None,
    section_filter: Optional[str] = None,
    organization_id: Optional[str] = None,
    provider_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve the top-k most similar past sessions for a given transcript.

    Pipeline:
      1. Redact PHI from the query and embed it.
      2. Fetch up to ``settings.rag_initial_retrieval_k`` chunks from ChromaDB.
         The query is scoped to the tenant (Phase 3.2) and, when
         ``section_filter`` is one of SOAP_SECTIONS, to that section type.
      3. Drop chunks belonging to ``exclude_id`` or below
         ``settings.rag_similarity_threshold``.
      4. Rerank with the cross-encoder when enabled and more than one
         candidate remains.
      5. Keep the best chunk per session, up to ``top_k`` sessions
         (``settings.rag_top_k`` when ``top_k`` is falsy).
      6. Label confidence from cosine similarity: >= 0.85 "high",
         >= 0.75 "medium", otherwise "low".
      7. Record metrics and an audit entry (Phase 3.4; query hash only).

    Returns:
        A list of dicts with these keys: ``id``, ``session_id``,
        ``document``, ``metadata``, ``distance``, ``similarity``,
        ``retrieval_confidence``, and ``rerank_score`` when reranking
        succeeded.

        Returns ``[]`` when RAG is disabled, the collection is empty, or
        retrieval fails; errors are logged, not raised. Metrics failures are
        isolated and never discard results.
    """
    if not settings.rag_enabled:
        return []
    t0 = time.time()
    query_hash = _hash_query(transcript)
    try:
        collection = _get_collection()
        total = collection.count()
        if total == 0:
            logger.debug("RAG collection is empty — skipping retrieval")
            return []
        k = top_k or settings.rag_top_k
        initial_k = min(settings.rag_initial_retrieval_k, total)
        if initial_k <= 0:
            return []

        # PHI-safe query
        safe_query = _redact_phi_for_embedding(transcript)
        query_embedding = _embed(safe_query)

        # Phase 3.2: Build tenant-scoped filter
        section_where = None
        if section_filter and section_filter in SOAP_SECTIONS:
            section_where = {"section_type": section_filter}
        where_filter = _build_tenant_filter(
            organization_id=organization_id,
            provider_id=provider_id,
            extra_filter=section_where,
        )
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=initial_k,
            include=["documents", "metadatas", "distances"],
            where=where_filter,
        )

        # --- Step 1: candidates above the similarity threshold ---
        threshold = settings.rag_similarity_threshold
        candidates: List[Dict[str, Any]] = []
        rows = zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
        for doc_id, document, metadata, distance in rows:
            # ChromaDB yields None for chunks stored without metadata.
            metadata = metadata or {}
            session_id = metadata.get("session_id", doc_id.split("::")[0])
            if exclude_id and session_id == exclude_id:
                continue
            similarity = _cosine_similarity_from_distance(distance)
            if similarity < threshold:
                continue
            candidates.append({
                "id": doc_id,
                "session_id": session_id,
                "document": document,
                "metadata": metadata,
                "distance": distance,
                "similarity": round(similarity, 4),
            })

        # --- Step 2: rerank with cross-encoder ---
        # Recorded in the audit trail, so it must reflect what actually ran.
        reranker_used = settings.rag_reranker_enabled and len(candidates) > 1
        if reranker_used:
            candidates = _rerank(safe_query, candidates)

        # --- Step 3: keep the best-ranked chunk per session, top-k ---
        deduped: List[Dict[str, Any]] = []
        seen_sessions = set()
        for c in candidates:
            sid = c["session_id"]
            if sid in seen_sessions:
                continue
            seen_sessions.add(sid)
            deduped.append(c)
            if len(deduped) >= k:
                break

        # --- Step 4: retrieval confidence ---
        for item in deduped:
            sim = item["similarity"]
            if sim >= 0.85:
                item["retrieval_confidence"] = "high"
            elif sim >= 0.75:
                item["retrieval_confidence"] = "medium"
            else:
                item["retrieval_confidence"] = "low"

        latency = time.time() - t0
        similarities = [c["similarity"] for c in deduped]
        # Telemetry must never turn a successful retrieval into [].
        try:
            _record_retrieval_metrics(
                latency_s=latency,
                num_results=len(deduped),
                threshold_met=len(deduped) > 0,
                similarities=similarities,
            )
        except Exception as exc:
            logger.debug(f"RAG: retrieval metrics failed: {exc}")

        # Phase 3.4: Audit trail
        _record_rag_audit(
            action="retrieve",
            query_hash=query_hash,
            retrieved_session_ids=[c["session_id"] for c in deduped],
            similarities=similarities,
            organization_id=organization_id,
            provider_id=provider_id,
            user_id=user_id,
            extra_details={
                "candidates_before_filter": len(candidates),
                "latency_s": round(latency, 3),
                "reranker_used": reranker_used,
            },
        )
        logger.info(
            f"RAG: retrieved {len(deduped)} results "
            f"(from {len(candidates)} candidates, "
            f"threshold={threshold}, "
            f"latency={latency:.3f}s)"
        )
        return deduped
    except Exception as exc:
        logger.warning(f"RAG: retrieval failed: {exc}")
        return []
def remove_session(session_id: str) -> None:
    """Delete every chunk ("full" plus each SOAP section) of a session.

    All possible chunk IDs are passed to ``collection.delete``, whether or not
    they were ever written. The deletion is audited under the default tenant,
    since no tenant IDs are passed. Failures are logged as warnings, never
    raised. No-op when RAG is disabled.
    """
    if not settings.rag_enabled:
        return
    try:
        collection = _get_collection()
        chunk_ids = [
            _build_chunk_id(session_id, part) for part in ("full", *SOAP_SECTIONS)
        ]
        collection.delete(ids=chunk_ids)
        _record_rag_audit(
            action="delete",
            query_hash="N/A",
            retrieved_session_ids=[session_id],
            similarities=[],
        )
        logger.info(f"RAG: removed session {session_id} (all chunks)")
    except Exception as exc:
        logger.warning(f"RAG: failed to remove session {session_id}: {exc}")
def get_index_stats() -> Dict[str, Any]:
    """Return a summary of the RAG vector store and its configuration.

    Per-section counts are obtained by fetching all matching IDs for each
    ``section_type`` ("full" plus the SOAP sections). A section whose lookup
    fails is reported as ``"unknown"``. Returns ``{"enabled": False}`` when
    RAG is off, and ``{"enabled": True, "error": ...}`` if the collection
    cannot be read.
    """
    if not settings.rag_enabled:
        return {"enabled": False}
    try:
        collection = _get_collection()
        total_chunks = collection.count()
        per_section: Dict[str, Any] = {}
        for section_type in ("full", *SOAP_SECTIONS):
            try:
                per_section[section_type] = len(
                    collection.get(where={"section_type": section_type}, include=[])["ids"]
                )
            except Exception:
                per_section[section_type] = "unknown"

        return {
            "enabled": True,
            "total_chunks": total_chunks,
            "chunks_by_section": per_section,
            "embedding_model": settings.rag_embedding_model,
            "reranker_model": (
                settings.rag_reranker_model if settings.rag_reranker_enabled else "disabled"
            ),
            "similarity_threshold": settings.rag_similarity_threshold,
            "top_k": settings.rag_top_k,
            "initial_retrieval_k": settings.rag_initial_retrieval_k,
            "chunking_enabled": settings.rag_chunking_enabled,
            "persist_dir": settings.rag_persist_dir,
            "multi_tenancy_enabled": settings.multi_tenancy_enabled,
            "vector_store_encrypted": settings.rag_vector_store_encryption_enabled,
            "audit_enabled": settings.rag_audit_enabled,
        }
    except Exception as exc:
        return {"enabled": True, "error": str(exc)}
def retrieve_enriched_context(
    transcript: str,
    top_k: Optional[int] = None,
    exclude_id: Optional[str] = None,
    organization_id: Optional[str] = None,
    provider_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Combine tenant-scoped similar-session retrieval with knowledge-base
    clinical guidelines (Phase 2 integration, Phase 3.2 scoping).

    Returns a dict with:
      - ``similar_sessions``: result of ``retrieve_similar_sessions()``
      - ``clinical_guidelines``: result of ``knowledge_base_service``'s
        ``retrieve_guidelines``, or ``[]`` if that import/call fails (the
        failure is logged at DEBUG)
      - ``has_context``: True when either list is non-empty

    NOTE(review): ``retrieve_guidelines`` receives the raw transcript, whereas
    session retrieval redacts PHI first — confirm the knowledge-base service
    neither logs nor persists its query.
    """
    sessions = retrieve_similar_sessions(
        transcript,
        top_k=top_k,
        exclude_id=exclude_id,
        organization_id=organization_id,
        provider_id=provider_id,
        user_id=user_id,
    )

    guidelines = []
    try:
        from app.models.knowledge_base_service import retrieve_guidelines

        guidelines = retrieve_guidelines(transcript)
    except Exception as exc:
        logger.debug(f"RAG: knowledge base retrieval skipped: {exc}")

    return {
        "similar_sessions": sessions,
        "clinical_guidelines": guidelines,
        "has_context": bool(sessions) or bool(guidelines),
    }
def is_ready() -> bool:
    """Report whether RAG is enabled and its ChromaDB collection can be opened.

    Opening the collection may trigger the lazy initialisation in
    ``_get_collection()``. Any exception raised there yields False.
    """
    if settings.rag_enabled:
        try:
            _get_collection()
        except Exception:
            return False
        return True
    return False