from io import BytesIO
from typing import Optional, Dict, List, Set, Union, Tuple

import os
import time
import asyncio

from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from dotenv import load_dotenv
import gradio as gr

load_dotenv()

API_TOKEN: Optional[str] = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")

security = HTTPBearer()

app = FastAPI(
    title="SenseVoice API",
    description="Speech-to-text API service",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load SenseVoiceSmall with the FSMN VAD front end; fall back to CPU when CUDA is unavailable
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda" if torch.cuda.is_available() else "cpu"
)


emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Every language tag collapses to one sentinel so transcripts can be split per language
lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def format_text_with_emotion(text: str) -> str:
    """Replace SenseVoice tags with emojis and append the dominant emotion."""
    token_count: Dict[str, int] = {}
    original_text = text
    for token in emoji_dict:
        token_count[token] = text.count(token)

    # Pick the most frequent emotion tag, defaulting to neutral
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion

    # Prefix the text with an emoji for every event that occurred
    text = original_text
    for event in event_dict:
        if token_count[event] > 0:
            text = event_dict[event] + text

    # Swap the remaining tags for their emoji (or drop them)
    for token in emoji_dict:
        text = text.replace(token, emoji_dict[token])

    text = text + emotion_dict[dominant_emotion]

    # Tighten whitespace around emojis
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)
    return text.strip()
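
# A worked illustration (hypothetical raw model output; the tag layout follows
# SenseVoice's rich-transcription format):
#
#   format_text_with_emotion("<|en|><|HAPPY|><|Speech|><|withitn|>Hello.")
#   -> "😊Hello.😊"
#
# The inline <|HAPPY|> tag becomes an emoji and the dominant emotion is appended.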


def format_text_advanced(text: str) -> str:
    """Format multilingual output: split on language tags, then merge the
    segments while de-duplicating repeated event and emotion emojis."""
    def get_emotion(segment: str) -> Optional[str]:
        return segment[-1] if segment and segment[-1] in emo_set else None

    def get_event(segment: str) -> Optional[str]:
        return segment[0] if segment and segment[0] in event_set else None

    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        text = text.replace(lang, "<|lang|>")

    # Format each language segment independently
    text_segments: List[str] = [format_text_with_emotion(segment).strip() for segment in text.split("<|lang|>")]
    formatted_text = " " + text_segments[0]
    current_event = get_event(formatted_text)

    for i in range(1, len(text_segments)):
        if not text_segments[i]:
            continue

        # Drop the leading event emoji if it repeats the previous segment's event
        if get_event(text_segments[i]) == current_event and get_event(text_segments[i]) is not None:
            text_segments[i] = text_segments[i][1:]
        current_event = get_event(text_segments[i])

        # Drop the trailing emotion emoji if the next segment ends with the same one
        if get_emotion(text_segments[i]) is not None and get_emotion(text_segments[i]) == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += text_segments[i].strip()

    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()
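
# A worked illustration (hypothetical two-language model output):
#
#   format_text_advanced("<|en|><|NEUTRAL|><|Speech|>Hello.<|zh|><|NEUTRAL|><|Speech|>你好。")
#   -> "Hello.你好。"
#
# Both language tags collapse to <|lang|>, each segment is formatted on its
# own, and the neutral emotion contributes no emoji.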


async def audio_stt(audio: torch.Tensor, sample_rate: int, language: str = "auto") -> str:
    """Process an audio tensor and perform speech-to-text conversion.

    Args:
        audio: Input audio tensor
        sample_rate: Audio sample rate in Hz
        language: Target language code (auto/zh/en/yue/ja/ko/nospeech)

    Returns:
        str: Transcribed and formatted text
    """
    try:
        # Convert integer PCM to float32 in [-1.0, 1.0]
        if audio.dtype != torch.float32:
            if audio.dtype == torch.int16:
                audio = audio.float() / torch.iinfo(torch.int16).max
            elif audio.dtype == torch.int32:
                audio = audio.float() / torch.iinfo(torch.int32).max
            else:
                audio = audio.float()

        # Rescale if the signal clips
        if audio.abs().max() > 1.0:
            audio = audio / audio.abs().max()

        # Downmix multi-channel (channels-first) audio to mono
        if len(audio.shape) > 1:
            audio = audio.mean(dim=0)
        audio = audio.squeeze()

        # SenseVoice expects 16 kHz input
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=16000
            )
            audio = resampler(audio.unsqueeze(0)).squeeze(0)

        text = model.generate(
            input=audio,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        result = text[0]["text"]
        return format_text_advanced(result)

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed in audio_stt: {str(e)}"
        )


async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Decode raw audio bytes and return the transcription.

    Args:
        audio_data: Raw audio data in bytes
        language: Target language code

    Returns:
        str: Transcribed and formatted text

    Raises:
        HTTPException: If audio processing fails
    """
    try:
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(
            audio_buffer,
            normalize=True,
            channels_first=True
        )
        return await audio_stt(waveform, sample_rate, language)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed: {str(e)}"
        )
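
# Local smoke test (a sketch; assumes a "speech.wav" file is available):
#
#   with open("speech.wav", "rb") as f:
#       print(asyncio.run(process_audio(f.read(), language="auto")))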


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify Bearer Token authentication"""
    if credentials.credentials != API_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: str = "FunAudioLLM/SenseVoiceSmall",
    language: str = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint.

    Args:
        file: Audio file (supports mp3, wav, flac, ogg, m4a)
        model: Model name
        language: Language code
        token: Authentication token

    Returns:
        Dict containing the transcription result and metadata
    """
    start_time = time.time()

    try:
        if not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported audio format",
                "process_time": time.time() - start_time
            }

        # Note: this parameter shadows the global AutoModel inside this function
        if model != "FunAudioLLM/SenseVoiceSmall":
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported model",
                "process_time": time.time() - start_time
            }

        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported language",
                "process_time": time.time() - start_time
            }

        content = await file.read()
        text = await process_audio(content, language)

        return {
            "text": text,
            "error_code": 0,
            "error_msg": "",
            "process_time": time.time() - start_time
        }

    except Exception as e:
        return {
            "text": "",
            "error_code": 500,
            "error_msg": str(e),
            "process_time": time.time() - start_time
        }
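
# Example request (a sketch; assumes the server is running on localhost:7860 and
# that API_TOKEN matches the value loaded from .env; `language` binds as a query
# parameter because it is a plain function argument, not a Form field):
#
#   curl -X POST "http://localhost:7860/v1/audio/transcriptions?language=auto" \
#        -H "Authorization: Bearer $API_TOKEN" \
#        -F "file=@speech.wav"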


def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio interface for audio transcription"""
    try:
        if audio is None:
            return "Please upload an audio file"

        sample_rate, input_wav = audio

        # Gradio delivers (samples, channels) for stereo; downmix here since
        # audio_stt assumes channels-first layout
        if input_wav.ndim > 1:
            input_wav = input_wav.mean(axis=1)

        # Gradio usually delivers int16 PCM; only rescale integer input
        if np.issubdtype(input_wav.dtype, np.integer):
            input_wav = input_wav.astype(np.float32) / np.iinfo(input_wav.dtype).max
        else:
            input_wav = input_wav.astype(np.float32)

        input_wav = torch.from_numpy(input_wav)
        return asyncio.run(audio_stt(input_wav, sample_rate, language))
    except Exception as e:
        return f"Processing failed: {str(e)}"


demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multilingual speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)

# Serve the Gradio UI at the web root
app = gr.mount_gradio_app(app, demo, path="/")


@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
    <head>
        <title>SenseVoice API Documentation</title>
        <meta http-equiv="refresh" content="0;url=/docs/" />
    </head>
    <body>
        <p>Redirecting to API documentation...</p>
    </body>
    </html>
    """)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)