| | from typing import Dict, Any, List, Generator |
| | import torch |
| | import os |
| | import logging |
| | from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline |
| | import numpy as np |
| | from queue import Queue, Empty |
| | import threading |
| | import base64 |
| | import uuid |
| |
|
class EndpointHandler:
    """Endpoint handler that drives the speech-to-speech (``s2s_pipeline``) pipeline.

    The pipeline is built and started once in ``__init__``. Clients then use a
    polling protocol through ``__call__``:

    * ``{"request_type": "start", "input_type": "text" | "speech", "inputs": ...}``
      opens a session, feeds the input into the pipeline, starts a background
      collector thread and returns ``{"session_id": ..., "status": "new"}``.
    * ``{"request_type": "continue", "session_id": ..., ["inputs": <base64 audio>]}``
      optionally feeds more audio and returns the output collected since the
      previous poll (base64-encoded), or ``None`` when nothing new arrived.

    NOTE(review): every session shares the single ``send_audio_chunks_queue``.
    If several sessions' collector threads are alive at the same time, they
    compete for that queue's items.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """Build, configure and start the pipeline.

        Args:
            path: Not used by this implementation.
        """
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = get_default_arguments(
            mode='none',
            log_level='DEBUG',
            lm_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct',
            tts_compile_mode='default',
            stt_compile_mode='reduce-overhead',
            tts_model_name='ylacombe/parler-tiny-v1-jenny',
        )

        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )

        self.pipeline_manager.start()

        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.final_output_queue = Queue()
        # session_id -> {'status', 'chunks', 'last_sent_index', 'buffer', 'collector_thread'}
        self.sessions = {}
        # Number of int16 samples per frame pushed to recv_audio_chunks_queue.
        self.vad_chunk_size = 512
        self.sample_rate = 16000
        # Most recently started collector thread (each session also stores its own).
        self.output_collector_thread = None

    def _process_audio_chunk(self, audio_data: bytes, session_id: str):
        """Split raw audio into fixed-size frames and push them to the pipeline.

        ``audio_data`` is read as int16 samples via ``np.frombuffer``. It is cut
        into frames of ``self.vad_chunk_size`` samples, and the final partial
        frame is zero-padded to full length. Each frame's bytes are put on
        ``recv_audio_chunks_queue``.

        Args:
            audio_data: Raw sample bytes. If the length is not a multiple of 2,
                ``np.frombuffer`` raises ``ValueError``.
            session_id: Unused; frames are not tagged with a session.
        """
        self._logger.debug('processing audio chunk')
        samples = np.frombuffer(audio_data, dtype=np.int16)
        frame = self.vad_chunk_size
        recv_queue = self.queues_and_events['recv_audio_chunks_queue']
        for start in range(0, len(samples), frame):
            chunk = samples[start:start + frame]
            if len(chunk) < frame:
                chunk = np.pad(chunk, (0, frame - len(chunk)), 'constant')
            recv_queue.put(chunk.tobytes())

    def _collect_output(self, session_id):
        """Drain pipeline output into ``self.sessions[session_id]['chunks']``.

        Runs until an ``"END"`` / ``b"END"`` sentinel is read, then marks the
        session ``'completed'`` and returns. NumPy arrays are stored as their
        raw bytes; any other item is stored unchanged.
        """
        send_queue = self.queues_and_events['send_audio_chunks_queue']
        session = self.sessions[session_id]
        while True:
            try:
                output = send_queue.get(timeout=2)
            except Empty:
                # Nothing yet; keep polling until a sentinel arrives.
                continue
            if isinstance(output, (str, bytes)) and output in (b"END", "END"):
                session['status'] = 'completed'
                break
            if isinstance(output, np.ndarray):
                output = output.tobytes()
            session['chunks'].append(output)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch on ``data["request_type"]`` (default ``"start"``).

        Raises:
            ValueError: if the request type is neither ``"start"`` nor ``"continue"``.
        """
        request_type = data.get("request_type", "start")

        if request_type == "start":
            return self._handle_start_request(data)
        elif request_type == "continue":
            return self._handle_continue_request(data)
        else:
            raise ValueError(f"Unsupported request type: {request_type}")

    def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Open a session, feed its input to the pipeline and start a collector.

        ``data["input_type"]`` defaults to ``"text"``. Text is put on
        ``text_prompt_queue``. Speech is base64-decoded and framed by
        ``_process_audio_chunk``.

        Raises:
            ValueError: for an unsupported ``input_type``, or for invalid
                base64 speech (``binascii.Error`` subclasses ``ValueError``).
                In both cases the error is raised before the session is
                registered.
        """
        self._logger.debug("Starting new session")
        session_id = str(uuid.uuid4())

        input_type = data.get("input_type", "text")
        input_data = data.get("inputs", "")
        self._logger.debug("input_type: %s", input_type)

        # Validate/decode first so a bad request leaves no orphaned session.
        if input_type == "speech":
            audio_bytes = base64.b64decode(input_data)
            self._logger.debug("input_data: %d bytes of audio", len(audio_bytes))
        elif input_type == "text":
            audio_bytes = None
            self._logger.debug("input_data: %s", input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        session = {
            'status': 'new',
            'chunks': [],
            'last_sent_index': 0,
            'buffer': b'',
        }
        self.sessions[session_id] = session

        if audio_bytes is None:
            self.queues_and_events['text_prompt_queue'].put(input_data)
        else:
            self._process_audio_chunk(audio_bytes, session_id)

        # Daemon thread: the collector blocks on the queue until a sentinel
        # arrives and must not keep the interpreter alive on its own.
        collector = threading.Thread(
            target=self._collect_output, args=(session_id,), daemon=True
        )
        session['collector_thread'] = collector
        self.output_collector_thread = collector
        collector.start()

        return {"session_id": session_id, "status": "new"}

    def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Feed optional extra audio and return output collected since last poll.

        Returns:
            ``{"session_id", "status", "output"}``. ``output`` is the base64
            encoding of the newly collected chunks joined together, or ``None``
            when there are none.

        Raises:
            ValueError: if ``session_id`` is missing or unknown.
        """
        session_id = data.get("session_id")
        self._logger.debug("continue request, session_id: %s", session_id)
        if not session_id or session_id not in self.sessions:
            raise ValueError("Invalid or missing session_id")

        session = self.sessions[session_id]

        if not self.queues_and_events['should_listen'].is_set():
            # should_listen is cleared: mark the session as processing; any
            # audio in this request is not forwarded to the pipeline.
            session['status'] = 'processing'
            self._logger.debug('should_listen is not set, processing')
        elif "inputs" in data:
            audio_bytes = base64.b64decode(data["inputs"])
            self._logger.debug("input_data: %d bytes of audio", len(audio_bytes))
            self._process_audio_chunk(audio_bytes, session_id)

        # Snapshot the end index once: the collector thread may append
        # concurrently. Slicing and then reading len() afterwards would mark
        # unsent chunks as sent.
        end = len(session['chunks'])
        chunks_to_send = session['chunks'][session['last_sent_index']:end]
        session['last_sent_index'] = end

        output = None
        if chunks_to_send:
            self._logger.debug('chunks_to_send')
            output = base64.b64encode(b''.join(chunks_to_send)).decode('utf-8')

        return {
            "session_id": session_id,
            "status": session['status'],
            "output": output,
        }

    def cleanup(self):
        """Stop the pipeline and shut down every still-running collector thread.

        Each collector exits after reading one ``END`` sentinel, so one
        ``b"END"`` is queued per live collector before all of them are joined.
        """
        self.pipeline_manager.stop()

        live = [
            thread
            for thread in (s.get('collector_thread') for s in self.sessions.values())
            if thread is not None and thread.is_alive()
        ]
        send_queue = self.queues_and_events['send_audio_chunks_queue']
        for _ in live:
            send_queue.put(b"END")
        for thread in live:
            thread.join()