| """ | |
| Therapy ASR Module - Multi-engine speech recognition for therapy applications. | |
| Supports: | |
| - Local Whisper (general speech, privacy-focused) | |
| - SpeechBrain (fine-tuned for atypical speech) | |
| - OpenAI Whisper API (fallback) | |
| """ | |
import io
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from api.config import settings

if settings.ENVIRONMENT == "development":
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.WARNING)

class ASREngine(str, Enum):
    """Available ASR engines."""
    WHISPER_LOCAL = "whisper_local"
    SPEECHBRAIN = "speechbrain"
    WHISPER_API = "whisper_api"
    AUTO = "auto"  # Automatically select based on user profile

@dataclass
class TranscriptionResult:
    """Structured transcription result."""
    text: str
    engine_used: ASREngine
    confidence: Optional[float] = None
    word_timestamps: Optional[list] = None
    language: Optional[str] = None

class TherapyASR:
    """
    Multi-engine ASR for therapy applications.

    Supports automatic engine selection based on the user's speech profile,
    with a fallback chain for reliability.
    """

    def __init__(self, default_engine: ASREngine = ASREngine.AUTO):
        self.default_engine = default_engine
        self._whisper_local_model = None
        self._whisper_processor = None  # set alongside the local model
        self._speechbrain_model = None
        self._openai_client = None
    def _get_openai_client(self):
        """Lazy load the OpenAI client."""
        if self._openai_client is None:
            from openai import OpenAI
            self._openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)
        return self._openai_client
    def _get_whisper_local(self):
        """Lazy load the local Whisper model."""
        if self._whisper_local_model is None:
            try:
                import torch
                from transformers import WhisperProcessor, WhisperForConditionalGeneration

                model_name = "openai/whisper-base"  # Start with base, upgrade as needed
                logging.info(f"Loading local Whisper model: {model_name}")
                self._whisper_processor = WhisperProcessor.from_pretrained(model_name)
                self._whisper_local_model = WhisperForConditionalGeneration.from_pretrained(model_name)

                # Use GPU if available
                if torch.cuda.is_available():
                    self._whisper_local_model = self._whisper_local_model.to("cuda")
                elif torch.backends.mps.is_available():
                    self._whisper_local_model = self._whisper_local_model.to("mps")

                logging.info("Local Whisper model loaded successfully")
            except ImportError as e:
                logging.warning(f"Local Whisper not available: {e}")
                raise
        return self._whisper_local_model
    def _get_speechbrain(self):
        """Lazy load the SpeechBrain model for atypical speech."""
        if self._speechbrain_model is None:
            try:
                # Import the pretrained interface explicitly; `speechbrain.pretrained`
                # is not guaranteed to be reachable via a bare `import speechbrain`
                from speechbrain.pretrained import EncoderASR

                # Use a pre-trained model; can be swapped for a fine-tuned version
                model_source = "speechbrain/asr-wav2vec2-commonvoice-en"
                logging.info(f"Loading SpeechBrain model: {model_source}")
                self._speechbrain_model = EncoderASR.from_hparams(
                    source=model_source,
                    savedir="models/speechbrain_asr",
                )
                logging.info("SpeechBrain model loaded successfully")
            except ImportError as e:
                logging.warning(f"SpeechBrain not available: {e}")
                raise
        return self._speechbrain_model
    def _select_engine(self, user_profile: Optional[dict] = None) -> ASREngine:
        """Select the appropriate ASR engine based on the user profile."""
        if self.default_engine != ASREngine.AUTO:
            return self.default_engine

        if user_profile:
            # Use SpeechBrain for users with speech conditions
            speech_condition = user_profile.get("speech_condition")
            if speech_condition in ["dysarthria", "apraxia", "autism", "stuttering"]:
                return ASREngine.SPEECHBRAIN
            # Use local Whisper for privacy-focused users
            if user_profile.get("privacy_mode") == "local":
                return ASREngine.WHISPER_LOCAL

        # Default to the API for best accuracy
        return ASREngine.WHISPER_API
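
    # Illustrative selection examples (hypothetical profiles, following the
    # rules above — not part of the original module):
    #   {"speech_condition": "dysarthria"} -> ASREngine.SPEECHBRAIN
    #   {"privacy_mode": "local"}          -> ASREngine.WHISPER_LOCAL
    #   None or {}                         -> ASREngine.WHISPER_API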
    def transcribe(
        self,
        audio_data: bytes,
        filename: str = "audio.wav",
        content_type: str = "audio/wav",
        user_profile: Optional[dict] = None,
        engine: Optional[ASREngine] = None,
    ) -> TranscriptionResult:
        """
        Transcribe audio using the most appropriate engine.

        Args:
            audio_data: Raw audio bytes
            filename: Original filename
            content_type: MIME type of the audio
            user_profile: Optional user profile for engine selection
            engine: Force a specific engine (overrides auto-selection)

        Returns:
            TranscriptionResult with text and metadata
        """
        # Treat an explicit AUTO override the same as no override, so
        # profile-based selection still runs
        if engine is None or engine == ASREngine.AUTO:
            selected_engine = self._select_engine(user_profile)
        else:
            selected_engine = engine
        logging.info(f"Transcribing with engine: {selected_engine.value}")

        # Try the selected engine, with the API as the fallback
        fallback_order = [selected_engine]
        if selected_engine != ASREngine.WHISPER_API:
            fallback_order.append(ASREngine.WHISPER_API)

        last_error = None
        for eng in fallback_order:
            try:
                if eng == ASREngine.WHISPER_API:
                    return self._transcribe_whisper_api(audio_data, filename, content_type)
                elif eng == ASREngine.WHISPER_LOCAL:
                    return self._transcribe_whisper_local(audio_data)
                elif eng == ASREngine.SPEECHBRAIN:
                    return self._transcribe_speechbrain(audio_data)
            except Exception as e:
                logging.warning(f"Engine {eng.value} failed: {e}")
                last_error = e
                continue

        raise RuntimeError(f"All ASR engines failed. Last error: {last_error}")
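
    # Illustrative fallback chains built by transcribe() above:
    #   WHISPER_LOCAL selected -> [WHISPER_LOCAL, WHISPER_API]
    #   SPEECHBRAIN selected   -> [SPEECHBRAIN, WHISPER_API]
    #   WHISPER_API selected   -> [WHISPER_API]  (no further fallback)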
    def _transcribe_whisper_api(
        self,
        audio_data: bytes,
        filename: str,
        content_type: str,
    ) -> TranscriptionResult:
        """Transcribe using the OpenAI Whisper API."""
        logging.info("Transcribing with OpenAI Whisper API")
        client = self._get_openai_client()

        file_data = (filename, audio_data, content_type)
        transcription = client.audio.transcriptions.create(
            model="whisper-1",
            file=file_data,
            response_format="verbose_json",
            timestamp_granularities=["word"],
        )

        # Extract word timestamps if the response includes them
        # (guard against a missing or None `words` attribute)
        word_timestamps = None
        if getattr(transcription, "words", None):
            word_timestamps = [
                {"word": w.word, "start": w.start, "end": w.end}
                for w in transcription.words
            ]

        return TranscriptionResult(
            text=transcription.text,
            engine_used=ASREngine.WHISPER_API,
            language=getattr(transcription, "language", None),
            word_timestamps=word_timestamps,
        )
    def _transcribe_whisper_local(self, audio_data: bytes) -> TranscriptionResult:
        """Transcribe using the local Whisper model."""
        logging.info("Transcribing with local Whisper")
        import torch
        import librosa

        model = self._get_whisper_local()

        # Load audio from bytes, resampled to Whisper's expected 16 kHz
        audio_array, _ = librosa.load(io.BytesIO(audio_data), sr=16000)

        # Convert the waveform to model input features
        input_features = self._whisper_processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt",
        ).input_features

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        input_features = input_features.to(device)

        # Generate the transcription
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = self._whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True,
        )[0]

        return TranscriptionResult(
            text=transcription.strip(),
            engine_used=ASREngine.WHISPER_LOCAL,
        )
    def _transcribe_speechbrain(self, audio_data: bytes) -> TranscriptionResult:
        """Transcribe using SpeechBrain (optimized for atypical speech)."""
        logging.info("Transcribing with SpeechBrain")
        import os
        import tempfile

        model = self._get_speechbrain()

        # SpeechBrain requires a file path, so write a temp file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_data)
            temp_path = f.name

        try:
            transcription = model.transcribe_file(temp_path)
            # Handle different return types
            if isinstance(transcription, list):
                text = transcription[0] if transcription else ""
            else:
                text = str(transcription)
            return TranscriptionResult(
                text=text.strip(),
                engine_used=ASREngine.SPEECHBRAIN,
            )
        finally:
            os.unlink(temp_path)

# Singleton instance for reuse
_therapy_asr_instance: Optional[TherapyASR] = None


def get_therapy_asr() -> TherapyASR:
    """Get or create the TherapyASR singleton."""
    global _therapy_asr_instance
    if _therapy_asr_instance is None:
        _therapy_asr_instance = TherapyASR()
    return _therapy_asr_instance

def transcribe_for_therapy(
    audio_data: bytes,
    filename: str = "audio.wav",
    content_type: str = "audio/wav",
    user_profile: Optional[dict] = None,
    engine: Optional[ASREngine] = None,
) -> TranscriptionResult:
    """
    Convenience function to transcribe audio for therapy.

    This is the main entry point for therapy transcription.
    """
    asr = get_therapy_asr()
    return asr.transcribe(
        audio_data=audio_data,
        filename=filename,
        content_type=content_type,
        user_profile=user_profile,
        engine=engine,
    )
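
# Minimal usage sketch (not part of the original module): assumes configured
# settings/API keys, and "sample.wav" is a hypothetical input file path.
if __name__ == "__main__":
    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()

    # A profile with a listed speech condition routes to SpeechBrain first,
    # with the Whisper API as the automatic fallback.
    result = transcribe_for_therapy(
        audio_bytes,
        filename="sample.wav",
        user_profile={"speech_condition": "stuttering"},
    )
    print(f"[{result.engine_used.value}] {result.text}")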