| import gradio as gr |
| import os |
| import re |
| import subprocess |
| from SenseVoiceAx import SenseVoiceAx |
| import numpy as np |
|
|
# Directory holding the compiled model and its companion asset files.
model_root = "../sensevoice_ax650"
# Forwarded to SenseVoiceAx as ``max_seq_len``.
max_seq_len = 256
model_path = os.path.join(model_root, "sensevoice.axmodel")

# Fail fast with a real exception: an ``assert`` is stripped when Python runs
# with -O, which would let a missing model file slip past this check.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"model {model_path} not exist")

cmvn_file = os.path.join(model_root, "am.mvn")
bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
token_file = os.path.join(model_root, "tokens.txt")

# Single recognizer instance, created once at import time and shared by
# every call to speech_to_text().
model = SenseVoiceAx(
    model_path,
    cmvn_file,
    token_file,
    bpe_model,
    max_seq_len=max_seq_len,
    beam_size=3,
    hot_words=None,
    streaming=False,
)
|
|
| |
def speech_to_text(audio_input, lang):
    """Transcribe one audio clip with the module-level ``model``.

    Args:
        audio_input: A ``(sample_rate_hz, samples)`` tuple where ``samples``
            is a numpy array, or a falsy value (e.g. ``None``) when no audio
            is available.
        lang: Language code: "auto", "zh", "en", "yue", "ja" or "ko".

    Returns:
        The value returned by ``model.infer`` for the clip, or the
        placeholder string "无音频" when there is nothing to recognise.
    """
    if not audio_input:
        return "无音频"

    sr, audio_data = audio_input

    # A zero-length recording carries no samples to decode.
    if audio_data.size == 0:
        return "无音频"

    if audio_data.dtype != np.float32:
        if np.issubdtype(audio_data.dtype, np.integer):
            # Integer PCM: scale by the dtype's maximum value.
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        else:
            # np.iinfo() raises ValueError for non-integer dtypes such as
            # float64, so those are only cast.
            # NOTE(review): assumes float input is already normalised — confirm.
            audio_data = audio_data.astype(np.float32)

    return model.infer((audio_data, sr), lang, print_rtf=False)
|
|
|
|
def get_all_local_ips():
    """Return the non-loopback IPv4 addresses reported by ``ip a``.

    Returns:
        A list of dotted-quad strings in the order they appear in the
        command output, excluding any that start with "127.". The list is
        empty when the ``ip`` command cannot be executed.
    """
    try:
        result = subprocess.run(["ip", "a"], capture_output=True, text=True)
    except OSError:
        # ``ip`` is missing or not executable on this host. The addresses
        # are only used for informational output, so degrade to "none
        # found" instead of aborting.
        return []

    # Each IPv4 address appears in the output as "inet <a.b.c.d>".
    ips = re.findall(r"inet (\d+\.\d+\.\d+\.\d+)", result.stdout)
    return [ip for ip in ips if not ip.startswith("127.")]
|
|
|
|
def main():
    """Build the Gradio interface and serve it over HTTPS on port 7861."""
    with gr.Blocks() as app:
        # First row: text box that receives the transcription.
        with gr.Row():
            result_box = gr.Textbox(label="识别结果", lines=5)

        # Second row: audio source and language selector.
        with gr.Row():
            audio_widget = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="录制或上传音频",
                format="wav",
            )
            language_choice = gr.Dropdown(
                choices=["auto", "zh", "en", "yue", "ja", "ko"],
                value="auto",
                label="选择音频语言",
            )

        # Re-run recognition whenever the audio component's value changes.
        audio_widget.change(
            fn=speech_to_text,
            inputs=[audio_widget, language_choice],
            outputs=result_box,
        )

    port = 7861
    # Print one URL per local address found.
    for address in get_all_local_ips():
        print(f"* Running on local URL: https://{address}:{port}")

    app.launch(
        server_name="0.0.0.0",
        server_port=port,
        ssl_certfile="./cert.pem",
        ssl_keyfile="./key.pem",
        ssl_verify=False,
    )
|
|
|
|
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|