"""
Therapy TTS Module - Text-to-speech for therapy and AAC applications.

Supports:
- WhisperSpeech (fast, voice cloning)
- OpenAI TTS API (fallback)
- Edge TTS (lightweight fallback)
"""

import io
import logging
from enum import Enum
from typing import Optional
from dataclasses import dataclass

from api.config import settings

if settings.ENVIRONMENT == "development":
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.WARNING)


class TTSEngine(str, Enum):
    """Available TTS engines."""
    WHISPERSPEECH = "whisperspeech"
    OPENAI_TTS = "openai_tts"
    EDGE_TTS = "edge_tts"
    AUTO = "auto"


class TTSVoice(str, Enum):
    """Preset voice options."""
    NEUTRAL = "neutral"
    WARM = "warm"
    CLEAR = "clear"
    SLOW = "slow"  # For therapy exercises
    CUSTOM = "custom"  # Voice cloning


@dataclass
class TTSResult:
    """TTS synthesis result."""
    audio_bytes: bytes
    format: str  # wav, mp3
    sample_rate: int
    engine_used: TTSEngine
    duration_seconds: Optional[float] = None


class TherapyTTS:
    """
    TTS engine for therapy applications.

    Features:
    - Voice cloning from reference audio
    - Adjustable speed for therapy exercises
    - Multiple engine support with fallback
    """

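    # Typical usage (illustrative sketch; which engines actually work depends on
    # the installed optional dependencies and configured API keys):
    #
    #     tts = TherapyTTS()
    #     result = tts.synthesize("Say the word: sunshine", speed=0.9)
    #     # result.audio_bytes holds encoded audio, ready to stream or save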
    def __init__(self, default_engine: TTSEngine = TTSEngine.AUTO):
        self.default_engine = default_engine
        self._whisperspeech_pipe = None
        self._openai_client = None

    def _get_openai_client(self):
        """Lazy load OpenAI client."""
        if self._openai_client is None:
            from openai import OpenAI
            self._openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)
        return self._openai_client

    def _get_whisperspeech(self):
        """Lazy load WhisperSpeech pipeline."""
        if self._whisperspeech_pipe is None:
            try:
                from whisperspeech.pipeline import Pipeline
                logging.info("Loading WhisperSpeech pipeline...")
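                # s2a-q4-tiny-en+pl: a small quantised semantic-to-acoustic
                # checkpoint covering English and Polish; larger WhisperSpeech
                # checkpoints generally trade load time and memory for quality.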
                self._whisperspeech_pipe = Pipeline(
                    s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model'
                )
                logging.info("WhisperSpeech loaded successfully")
            except ImportError as e:
                logging.warning(f"WhisperSpeech not available: {e}")
                raise
        return self._whisperspeech_pipe

    def _select_engine(self, voice_reference: Optional[bytes] = None) -> TTSEngine:
        """Select TTS engine based on requirements."""
        if self.default_engine != TTSEngine.AUTO:
            return self.default_engine

        # Use WhisperSpeech for voice cloning
        if voice_reference:
            return TTSEngine.WHISPERSPEECH

        # Default to OpenAI for quality
        return TTSEngine.OPENAI_TTS

    def synthesize(
        self,
        text: str,
        voice: TTSVoice = TTSVoice.NEUTRAL,
        speed: float = 1.0,
        voice_reference: Optional[bytes] = None,
        engine: Optional[TTSEngine] = None,
        output_format: str = "wav"
    ) -> TTSResult:
        """
        Synthesize speech from text.

        Args:
            text: Text to synthesize
            voice: Voice preset to use
            speed: Speech rate (0.5 = slow, 1.0 = normal, 2.0 = fast)
            voice_reference: Audio bytes for voice cloning
            engine: Force specific engine
            output_format: Output format (wav, mp3)

        Returns:
            TTSResult with audio bytes
        """
        selected_engine = engine or self._select_engine(voice_reference)
        logging.info(f"Synthesizing with engine: {selected_engine.value}")

        # Fallback chain: try the selected engine first, then fall back to OpenAI TTS.
        fallback_order = [selected_engine]
        if selected_engine != TTSEngine.OPENAI_TTS:
            fallback_order.append(TTSEngine.OPENAI_TTS)

        last_error = None
        for eng in fallback_order:
            try:
                if eng == TTSEngine.OPENAI_TTS:
                    return self._synthesize_openai(text, voice, speed, output_format)
                elif eng == TTSEngine.WHISPERSPEECH:
                    return self._synthesize_whisperspeech(
                        text, voice_reference, speed, output_format
                    )
                elif eng == TTSEngine.EDGE_TTS:
                    return self._synthesize_edge_tts(text, voice, speed, output_format)
            except Exception as e:
                logging.warning(f"Engine {eng.value} failed: {e}")
                last_error = e
                continue

        raise RuntimeError(f"All TTS engines failed. Last error: {last_error}") from last_error

    def _synthesize_openai(
        self,
        text: str,
        voice: TTSVoice,
        speed: float,
        output_format: str
    ) -> TTSResult:
        """Synthesize using OpenAI TTS API."""
        logging.info("Synthesizing with OpenAI TTS")

        client = self._get_openai_client()

        # Map voice presets to OpenAI voices
        voice_map = {
            TTSVoice.NEUTRAL: "alloy",
            TTSVoice.WARM: "nova",
            TTSVoice.CLEAR: "onyx",
            TTSVoice.SLOW: "alloy",  # slowness comes from the speed parameter, not the voice
            TTSVoice.CUSTOM: "alloy",
        }

        response = client.audio.speech.create(
            model="tts-1",
            voice=voice_map.get(voice, "alloy"),
            input=text,
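            # Per OpenAI's API docs, `speed` must lie within the 0.25-4.0 range.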
            speed=speed,
            response_format="wav" if output_format == "wav" else "mp3"
        )

        audio_bytes = response.content

        return TTSResult(
            audio_bytes=audio_bytes,
            format=output_format,
            sample_rate=24000,  # tts-1 renders audio at 24 kHz
            engine_used=TTSEngine.OPENAI_TTS
        )

    def _synthesize_whisperspeech(
        self,
        text: str,
        voice_reference: Optional[bytes],
        speed: float,
        output_format: str
    ) -> TTSResult:
        """Synthesize using WhisperSpeech with optional voice cloning."""
        logging.info("Synthesizing with WhisperSpeech")

        # torch/numpy are imported lazily so the module stays importable without
        # the WhisperSpeech extras installed.
        import torch
        import numpy as np

        pipe = self._get_whisperspeech()

        # Generate audio
        if voice_reference:
            # Voice cloning mode
            import tempfile
            import os

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(voice_reference)
                ref_path = f.name

            try:
                audio = pipe.generate(text, speaker=ref_path)
            finally:
                os.unlink(ref_path)
        else:
            audio = pipe.generate(text)

        # Convert to bytes
        if isinstance(audio, torch.Tensor):
            audio_np = audio.cpu().numpy()
        else:
            audio_np = np.array(audio)

        # Ensure correct shape
        if audio_np.ndim > 1:
            audio_np = audio_np.squeeze()

        # Apply speed adjustment if needed (time-stretching changes tempo without shifting pitch)
        if speed != 1.0:
            import librosa
            audio_np = librosa.effects.time_stretch(audio_np, rate=speed)

        # Encode to WAV bytes; WhisperSpeech's vocoder produces 24 kHz audio
        import soundfile as sf
        buffer = io.BytesIO()
        sf.write(buffer, audio_np, 24000, format='WAV')
        buffer.seek(0)

        return TTSResult(
            audio_bytes=buffer.read(),
            format="wav",
            sample_rate=24000,
            engine_used=TTSEngine.WHISPERSPEECH,
            duration_seconds=len(audio_np) / 24000
        )

    def _synthesize_edge_tts(
        self,
        text: str,
        voice: TTSVoice,
        speed: float,
        output_format: str
    ) -> TTSResult:
        """Synthesize using Edge TTS (lightweight fallback)."""
        logging.info("Synthesizing with Edge TTS")

        import asyncio
        import edge_tts

        # Map voice presets to Edge TTS voices
        voice_map = {
            TTSVoice.NEUTRAL: "en-US-JennyNeural",
            TTSVoice.WARM: "en-US-AriaNeural",
            TTSVoice.CLEAR: "en-US-GuyNeural",
            TTSVoice.SLOW: "en-US-JennyNeural",
            TTSVoice.CUSTOM: "en-US-JennyNeural",
        }

        async def _generate():
            communicate = edge_tts.Communicate(
                text,
                voice_map.get(voice, "en-US-JennyNeural"),
                rate=f"{round((speed - 1) * 100):+d}%"  # e.g. speed 0.9 -> "-10%"
            )
            buffer = io.BytesIO()
            async for chunk in communicate.stream():
                if chunk["type"] == "audio":
                    buffer.write(chunk["data"])
            return buffer.getvalue()

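        # asyncio.run() raises if an event loop is already running (e.g. inside an
        # async web handler); in that situation run this method in a worker thread.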
        audio_bytes = asyncio.run(_generate())

        return TTSResult(
            audio_bytes=audio_bytes,
            format="mp3",  # edge-tts streams MP3 regardless of the requested output_format
            sample_rate=24000,
            engine_used=TTSEngine.EDGE_TTS
        )

    def generate_therapy_prompt(
        self,
        exercise_type: str,
        target_text: str,
        **kwargs
    ) -> TTSResult:
        """
        Generate therapy exercise audio prompt.

        Args:
            exercise_type: Type of exercise (repeat_after_me, pronunciation, etc.)
            target_text: The text to practice
            **kwargs: Additional synthesis parameters

        Returns:
            TTSResult with exercise audio
        """
        prompts = {
            "repeat_after_me": f"Please repeat after me: {target_text}",
            "pronunciation": f"Let's practice saying: {target_text}. Listen carefully.",
            "slower": f"Now try saying it more slowly: {target_text}",
            "word_by_word": f"Let's break it down. {target_text}",
            "encouragement": f"Great try! Let's practice {target_text} again.",
        }

        prompt_text = prompts.get(exercise_type, target_text)

        # Default to a slower rate and a clear voice for therapy prompts, while
        # still letting callers override either one via kwargs.
        speed = kwargs.pop("speed", 0.9)
        voice = kwargs.pop("voice", TTSVoice.CLEAR)

        return self.synthesize(
            text=prompt_text,
            speed=speed,
            voice=voice,
            **kwargs
        )


# Singleton instance
_therapy_tts_instance: Optional[TherapyTTS] = None


def get_therapy_tts() -> TherapyTTS:
    """Get or create TherapyTTS singleton."""
    global _therapy_tts_instance
    if _therapy_tts_instance is None:
        _therapy_tts_instance = TherapyTTS()
    return _therapy_tts_instance


def synthesize_speech(
    text: str,
    voice: TTSVoice = TTSVoice.NEUTRAL,
    speed: float = 1.0,
    voice_reference: Optional[bytes] = None
) -> TTSResult:
    """Convenience function for TTS synthesis."""
    tts = get_therapy_tts()
    return tts.synthesize(
        text=text,
        voice=voice,
        speed=speed,
        voice_reference=voice_reference
    )
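

# Minimal manual smoke test (illustrative sketch, not part of the module's API).
# It assumes a working backend in your environment: a valid settings.OPENAI_API_KEY
# for the OpenAI path, or the optional whisperspeech / edge-tts packages installed.
# The output filename is only an example.
if __name__ == "__main__":
    demo = synthesize_speech("Hello! Let's practice speaking together.", speed=0.9)
    out_name = f"tts_demo_output.{demo.format}"
    with open(out_name, "wb") as out_file:
        out_file.write(demo.audio_bytes)
    print(
        f"Wrote {len(demo.audio_bytes)} bytes to {out_name} "
        f"(engine={demo.engine_used.value}, {demo.sample_rate} Hz)"
    )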