Speech to Text

Fine-tune a Wav2Vec2 acoustic model on the LJSpeech dataset using CTC, then export it to ONNX for inference.

Requirements

  • Python >= 3.10
  • A CUDA-capable GPU is recommended for training

Install dependencies:

pip install -e .

Training

Fine-tune facebook/wav2vec2-base on LJSpeech, holding out 5% of the data for evaluation. By default, training runs for about 10 epochs and writes checkpoints to wav2vec2-ljspeech/.

python train.py
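LJSpeech contains 13,100 clips, so a 5% hold-out leaves 655 clips for evaluation and 12,445 for training. The sketch below shows a seeded random hold-out split over clip indices in plain Python. The helper name `holdout_split` and the seed are hypothetical, and train.py may well rely on the `datasets` library's own splitting instead.

```python
import random


def holdout_split(num_examples: int, eval_fraction: float = 0.05, seed: int = 42):
    """Shuffle example indices with a fixed seed and hold out a fraction for eval.

    Hypothetical helper for illustration; not taken from train.py.
    """
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # seeded shuffle -> reproducible split
    num_eval = round(num_examples * eval_fraction)  # nearest whole number of clips
    eval_idx = sorted(indices[:num_eval])
    train_idx = sorted(indices[num_eval:])
    return train_idx, eval_idx


# LJSpeech 1.1 has 13,100 audio clips.
train_idx, eval_idx = holdout_split(13_100)
print(len(train_idx), len(eval_idx))  # 12445 655
```

Fixing the seed keeps the evaluation set identical across runs, so eval metrics from different training runs stay comparable.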

Key settings live at the top of train.py:

Constant     Default                   Purpose
MODEL_ID     facebook/wav2vec2-base    Pre-trained wav2vec2 checkpoint
DATASET_ID   lj_speech                 Hugging Face dataset ID

Training hyperparameters (batch size, epochs, learning rate, etc.) are configured through TrainingArguments inside train.py.

Monitor progress with TensorBoard:

tensorboard --logdir wav2vec2-ljspeech

ONNX Export

Export the trained checkpoint to ONNX and validate it with ONNX Runtime:

python export_onnx.py

Options:

--model-dir   Checkpoint directory (default: wav2vec2-ljspeech)
--output      Output ONNX path   (default: wav2vec2-ljspeech.onnx)
--opset       ONNX opset version (default: 17)
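Validating an export typically means running the same input through the PyTorch checkpoint and the ONNX Runtime session and checking that the logits agree. The sketch below covers only that comparison step, on two NumPy arrays. The helper name `compare_logits`, the 1e-4 tolerance, and the array dimensions are assumptions for illustration, not values taken from export_onnx.py.

```python
import numpy as np


def compare_logits(reference: np.ndarray, candidate: np.ndarray, atol: float = 1e-4) -> dict:
    """Compare reference (e.g. PyTorch) logits with candidate (e.g. ONNX Runtime) logits.

    Hypothetical helper; the default tolerance is an illustrative assumption.
    """
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    # Greedy CTC predictions usually still agree when logits differ only slightly.
    same_argmax = bool(np.array_equal(reference.argmax(axis=-1), candidate.argmax(axis=-1)))
    return {
        "max_abs_diff": max_abs_diff,
        "within_tolerance": max_abs_diff <= atol,
        "same_argmax": same_argmax,
    }


rng = np.random.default_rng(0)
ref = rng.standard_normal((1, 49, 32)).astype(np.float32)  # (batch, frames, vocab), arbitrary sizes
onnx_out = ref + 1e-6  # simulate tiny numerical drift from the exported graph
print(compare_logits(ref, onnx_out))
```

Checking the argmax as well as the raw difference is useful because greedy decoding only depends on which token wins each frame.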

The exported model uses dynamic axes on batch and time, so it accepts any batch size and variable-length audio.
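Because the time axis is dynamic, the number of output frames depends on the input length. Assuming the stock wav2vec2-base convolutional feature encoder (seven unpadded 1-D convolutions with kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2, the `Wav2Vec2Config` defaults), each layer maps a length L to (L - kernel) // stride + 1. The sketch below applies that formula; `num_output_frames` is a hypothetical helper name.

```python
# Default conv feature-encoder geometry of wav2vec2-base (Wav2Vec2Config defaults).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples: int) -> int:
    """Number of logit frames produced for a clip of `num_samples` audio samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Output length of an unpadded 1-D convolution.
        length = (length - kernel) // stride + 1
    # Clips shorter than the 400-sample receptive field produce no frames.
    return max(length, 0)


for seconds in (1, 5, 10):
    samples = seconds * 16_000  # 16 kHz input
    print(f"{seconds:>2} s -> {num_output_frames(samples)} frames")
#  1 s -> 49 frames
#  5 s -> 249 frames
# 10 s -> 499 frames
```

The total stride is 5 × 2^6 = 320 samples, so at 16 kHz the model emits roughly one logit frame every 20 ms.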

Inference

import numpy as np
import onnxruntime as ort
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("wav2vec2-ljspeech")
session = ort.InferenceSession("wav2vec2-ljspeech.onnx")

# audio_array: 16 kHz mono float32 numpy array
inputs = processor(audio_array, sampling_rate=16000, return_tensors="np")
logits = session.run(None, {"input_values": inputs.input_values})[0]
text = processor.tokenizer.batch_decode(np.argmax(logits, axis=-1))[0]
print(text)
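`np.argmax` picks the most likely token for each frame, and the tokenizer's decode step then turns that frame sequence into text. For a CTC model this means collapsing consecutive repeats and dropping the blank token; Wav2Vec2 CTC tokenizers use the pad token as the blank and `|` as the word delimiter. The sketch below reimplements that greedy decoding in plain Python over a tiny made-up vocabulary. The vocabulary and the function name are hypothetical; it is meant to explain the behaviour, not to replace `batch_decode`.

```python
from itertools import groupby

# Toy vocabulary (hypothetical): index 0 is the CTC blank / pad token,
# and "|" marks word boundaries as in Wav2Vec2 CTC tokenizers.
VOCAB = ["<pad>", "|", "E", "H", "L", "O"]
BLANK_ID = 0
WORD_DELIMITER = "|"


def greedy_ctc_decode(frame_ids: list[int]) -> str:
    """Collapse repeated frame predictions, drop blanks, and map '|' to spaces."""
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]  # merge consecutive repeats
    chars = [VOCAB[token_id] for token_id in collapsed if token_id != BLANK_ID]
    text = "".join(chars).replace(WORD_DELIMITER, " ")
    return " ".join(text.split())  # trim and squeeze whitespace


# A blank between the two L runs keeps the double letter.
print(greedy_ctc_decode([3, 3, 2, 0, 4, 4, 0, 4, 5, 5]))  # HELLO
# Without a blank, adjacent repeats merge into one character.
print(greedy_ctc_decode([3, 2, 4, 4, 4, 5]))  # HELO
```

This is also why the acoustic model alone can misspell words: greedy decoding takes the best token per frame with no notion of which words are plausible.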

Notes:

  • Audio must be 16 kHz mono float32.
  • The Wav2Vec2Processor handles waveform normalization before the ONNX session and token decoding after it; always pass audio through the processor before running the session.
  • This exports the acoustic model only. Add an external LM (e.g. KenLM) for language-model-rescored decoding if needed.
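LJSpeech itself is 22,050 Hz mono, and other recordings may be stereo or 16-bit integer PCM, so audio usually needs converting before it reaches the processor. The sketch below shows that conversion plus the zero-mean, unit-variance normalization that Wav2Vec2FeatureExtractor applies when `do_normalize=True` (assumed here to match this checkpoint's processor config). It uses NumPy linear interpolation as a deliberately simple stand-in resampler with no anti-aliasing; a proper band-limited resampler (for example from torchaudio or librosa) is the better choice for real audio. All function names are hypothetical.

```python
import numpy as np

TARGET_SR = 16_000


def to_mono_float32(audio: np.ndarray) -> np.ndarray:
    """Convert (samples,) or (samples, channels) signed-int or float audio to mono float32."""
    if np.issubdtype(audio.dtype, np.signedinteger):
        # Scale signed PCM (e.g. int16) into [-1.0, 1.0).
        audio = audio.astype(np.float32) / float(-np.iinfo(audio.dtype).min)
    elif np.issubdtype(audio.dtype, np.floating):
        audio = audio.astype(np.float32)
    else:
        raise TypeError(f"unsupported dtype: {audio.dtype}")
    if audio.ndim == 2:
        audio = audio.mean(axis=1)  # average channels -> mono
    return audio.astype(np.float32)


def resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int = TARGET_SR) -> np.ndarray:
    """Naive linear-interpolation resampler (no anti-aliasing filter)."""
    if orig_sr == target_sr:
        return audio
    num_out = int(round(len(audio) * target_sr / orig_sr))
    # Positions of the output samples, measured in input-sample units.
    positions = np.arange(num_out) * (orig_sr / target_sr)
    return np.interp(positions, np.arange(len(audio)), audio).astype(np.float32)


def zero_mean_unit_var(audio: np.ndarray) -> np.ndarray:
    """Per-clip normalization as done by Wav2Vec2FeatureExtractor with do_normalize=True."""
    return ((audio - audio.mean()) / np.sqrt(audio.var() + 1e-7)).astype(np.float32)


# One second of a synthetic 440 Hz stereo int16 tone at LJSpeech's 22,050 Hz.
sr_in = 22_050
t = np.arange(sr_in) / sr_in
tone = (0.5 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
stereo = np.stack([tone, tone], axis=1)

audio = resample_linear(to_mono_float32(stereo), sr_in)
audio = zero_mean_unit_var(audio)
print(audio.shape, audio.dtype)  # (16000,) float32
```

In the inference pipeline above, only the mono/float32/16 kHz conversion is your responsibility; the normalization step is shown to explain what the processor does, so there is no need to apply it yourself before calling the processor.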

Project Layout

speech-to-text/
├── train.py          # Wav2Vec2 + CTC fine-tuning on LJSpeech
├── export_onnx.py    # ONNX export and ONNX Runtime validation
├── main.py           # Placeholder entry point
├── pyproject.toml    # Project metadata and dependencies
└── README.md