"""
Therapy ASR Module - Multi-engine speech recognition for therapy applications.
Supports:
- Local Whisper (general speech, privacy-focused)
- SpeechBrain (fine-tuned for atypical speech)
- OpenAI Whisper API (fallback)
"""
import io
import logging
from enum import Enum
from typing import Optional
from dataclasses import dataclass
from api.config import settings

if settings.ENVIRONMENT == "development":
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.WARNING)


class ASREngine(str, Enum):
    """Available ASR engines."""
    WHISPER_LOCAL = "whisper_local"
    SPEECHBRAIN = "speechbrain"
    WHISPER_API = "whisper_api"
    AUTO = "auto"  # Automatically select based on user profile


@dataclass
class TranscriptionResult:
    """Structured transcription result."""
    text: str
    engine_used: ASREngine
    confidence: Optional[float] = None
    word_timestamps: Optional[list] = None
    language: Optional[str] = None
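

# Illustrative user profile shape consumed by TherapyASR._select_engine below.
# Only the keys read by this module are shown; the full profile schema is
# assumed to be defined upstream and may contain more fields:
#
#     {"speech_condition": "dysarthria",  # routes to SpeechBrain
#      "privacy_mode": "local"}           # routes to local Whisper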


class TherapyASR:
    """
    Multi-engine ASR for therapy applications.

    Supports automatic engine selection based on the user's speech profile,
    with a fallback chain for reliability.
    """

    def __init__(self, default_engine: ASREngine = ASREngine.AUTO):
        self.default_engine = default_engine
        self._whisper_local_model = None
        self._whisper_processor = None  # set together with the local Whisper model
        self._speechbrain_model = None
        self._openai_client = None

    def _get_openai_client(self):
        """Lazy load OpenAI client."""
        if self._openai_client is None:
            from openai import OpenAI
            self._openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)
        return self._openai_client

    def _get_whisper_local(self):
        """Lazy load local Whisper model."""
        if self._whisper_local_model is None:
            try:
                import torch
                from transformers import WhisperProcessor, WhisperForConditionalGeneration

                model_name = "openai/whisper-base"  # Start with base, upgrade as needed
                logging.info(f"Loading local Whisper model: {model_name}")
                self._whisper_processor = WhisperProcessor.from_pretrained(model_name)
                self._whisper_local_model = WhisperForConditionalGeneration.from_pretrained(model_name)
                # Use GPU if available
                if torch.cuda.is_available():
                    self._whisper_local_model = self._whisper_local_model.to("cuda")
                elif torch.backends.mps.is_available():
                    self._whisper_local_model = self._whisper_local_model.to("mps")
                logging.info("Local Whisper model loaded successfully")
            except ImportError as e:
                logging.warning(f"Local Whisper not available: {e}")
                raise
        return self._whisper_local_model

    def _get_speechbrain(self):
        """Lazy load SpeechBrain model for atypical speech."""
        if self._speechbrain_model is None:
            try:
                import speechbrain as sb

                # Use pre-trained model, can be swapped for a fine-tuned version
                model_source = "speechbrain/asr-wav2vec2-commonvoice-en"
                logging.info(f"Loading SpeechBrain model: {model_source}")
                self._speechbrain_model = sb.pretrained.EncoderASR.from_hparams(
                    source=model_source,
                    savedir="models/speechbrain_asr"
                )
                logging.info("SpeechBrain model loaded successfully")
            except ImportError as e:
                logging.warning(f"SpeechBrain not available: {e}")
                raise
        return self._speechbrain_model

    def _select_engine(self, user_profile: Optional[dict] = None) -> ASREngine:
        """Select appropriate ASR engine based on user profile."""
        if self.default_engine != ASREngine.AUTO:
            return self.default_engine
        if user_profile:
            # Use SpeechBrain for users with speech conditions
            speech_condition = user_profile.get("speech_condition")
            if speech_condition in ["dysarthria", "apraxia", "autism", "stuttering"]:
                return ASREngine.SPEECHBRAIN
            # Use local Whisper for privacy-focused users
            if user_profile.get("privacy_mode") == "local":
                return ASREngine.WHISPER_LOCAL
        # Default to API for best accuracy
        return ASREngine.WHISPER_API

    def transcribe(
        self,
        audio_data: bytes,
        filename: str = "audio.wav",
        content_type: str = "audio/wav",
        user_profile: Optional[dict] = None,
        engine: Optional[ASREngine] = None
    ) -> TranscriptionResult:
        """
        Transcribe audio using the most appropriate engine.

        Args:
            audio_data: Raw audio bytes
            filename: Original filename
            content_type: MIME type of audio
            user_profile: Optional user profile for engine selection
            engine: Force specific engine (overrides auto-selection)

        Returns:
            TranscriptionResult with text and metadata
        """
        selected_engine = engine or self._select_engine(user_profile)
        logging.info(f"Transcribing with engine: {selected_engine.value}")
        # Try selected engine with fallback chain
        fallback_order = [selected_engine]
        if selected_engine != ASREngine.WHISPER_API:
            fallback_order.append(ASREngine.WHISPER_API)
        last_error = None
        for eng in fallback_order:
            try:
                if eng == ASREngine.WHISPER_API:
                    return self._transcribe_whisper_api(audio_data, filename, content_type)
                elif eng == ASREngine.WHISPER_LOCAL:
                    return self._transcribe_whisper_local(audio_data)
                elif eng == ASREngine.SPEECHBRAIN:
                    return self._transcribe_speechbrain(audio_data)
            except Exception as e:
                logging.warning(f"Engine {eng.value} failed: {e}")
                last_error = e
                continue
        raise RuntimeError(f"All ASR engines failed. Last error: {last_error}")

    def _transcribe_whisper_api(
        self,
        audio_data: bytes,
        filename: str,
        content_type: str
    ) -> TranscriptionResult:
        """Transcribe using OpenAI Whisper API."""
        logging.info("Transcribing with OpenAI Whisper API")
        client = self._get_openai_client()
        file_data = (filename, audio_data, content_type)
        transcription = client.audio.transcriptions.create(
            model="whisper-1",
            file=file_data,
            response_format="verbose_json",
            timestamp_granularities=["word"]
        )
        # Extract word timestamps if available (guards against a missing or empty field)
        word_timestamps = None
        if getattr(transcription, "words", None):
            word_timestamps = [
                {"word": w.word, "start": w.start, "end": w.end}
                for w in transcription.words
            ]
        return TranscriptionResult(
            text=transcription.text,
            engine_used=ASREngine.WHISPER_API,
            language=getattr(transcription, "language", None),
            word_timestamps=word_timestamps
        )

    def _transcribe_whisper_local(self, audio_data: bytes) -> TranscriptionResult:
        """Transcribe using local Whisper model."""
        logging.info("Transcribing with local Whisper")
        import torch
        import librosa

        model = self._get_whisper_local()
        # Load audio from bytes, resampled to the 16 kHz Whisper expects
        audio_array, sr = librosa.load(io.BytesIO(audio_data), sr=16000)
        # Process audio
        input_features = self._whisper_processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features
        # Move to same device as model
        device = next(model.parameters()).device
        input_features = input_features.to(device)
        # Generate transcription
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = self._whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        return TranscriptionResult(
            text=transcription.strip(),
            engine_used=ASREngine.WHISPER_LOCAL
        )

    def _transcribe_speechbrain(self, audio_data: bytes) -> TranscriptionResult:
        """Transcribe using SpeechBrain (optimized for atypical speech)."""
        logging.info("Transcribing with SpeechBrain")
        import tempfile
        import os

        model = self._get_speechbrain()
        # SpeechBrain requires a file path, so write a temp file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_data)
            temp_path = f.name
        try:
            transcription = model.transcribe_file(temp_path)
            # Handle different return types
            if isinstance(transcription, list):
                text = transcription[0] if transcription else ""
            else:
                text = str(transcription)
            return TranscriptionResult(
                text=text.strip(),
                engine_used=ASREngine.SPEECHBRAIN
            )
        finally:
            os.unlink(temp_path)


# Singleton instance for reuse
_therapy_asr_instance: Optional[TherapyASR] = None


def get_therapy_asr() -> TherapyASR:
    """Get or create TherapyASR singleton."""
    global _therapy_asr_instance
    if _therapy_asr_instance is None:
        _therapy_asr_instance = TherapyASR()
    return _therapy_asr_instance


def transcribe_for_therapy(
    audio_data: bytes,
    filename: str = "audio.wav",
    content_type: str = "audio/wav",
    user_profile: Optional[dict] = None,
    engine: Optional[ASREngine] = None
) -> TranscriptionResult:
    """
    Convenience function to transcribe audio for therapy.

    This is the main entry point for therapy transcription.
    """
    asr = get_therapy_asr()
    return asr.transcribe(
        audio_data=audio_data,
        filename=filename,
        content_type=content_type,
        user_profile=user_profile,
        engine=engine
    )
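

# Minimal usage sketch: reads an audio file from disk and runs it through the
# auto-selection logic. The file path and profile values are illustrative
# placeholders, and a real run needs the dependencies for whichever engine is
# selected (openai / transformers / speechbrain) plus the api.config settings.
if __name__ == "__main__":
    with open("sample.wav", "rb") as audio_file:  # placeholder path
        sample_bytes = audio_file.read()
    result = transcribe_for_therapy(
        audio_data=sample_bytes,
        filename="sample.wav",
        content_type="audio/wav",
        user_profile={"speech_condition": "dysarthria"},  # routed to SpeechBrain
    )
    print(f"[{result.engine_used.value}] {result.text}")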