---
license: mit
---
# Speech to Text
Fine-tune a [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) acoustic model on the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset using CTC, then export it to ONNX for inference.
## Requirements
- Python >= 3.10
- A CUDA-capable GPU is recommended for training
Install dependencies:
```bash
pip install -e .
```
## Training
Fine-tune `facebook/wav2vec2-base` on LJSpeech, holding out 5% of the data for evaluation. Training runs for about 10 epochs by default and writes checkpoints to `wav2vec2-ljspeech/`.
```bash
python train.py
```
Key settings live at the top of `train.py`:
| Constant | Default | Purpose |
| --- | --- | --- |
| `MODEL_ID` | `facebook/wav2vec2-base` | Pre-trained wav2vec2 checkpoint |
| `DATASET_ID` | `lj_speech` | Hugging Face dataset ID |
Training hyperparameters (batch size, epochs, learning rate, etc.) are configured through `TrainingArguments` inside `train.py`.
Monitor progress with TensorBoard:
```bash
tensorboard --logdir wav2vec2-ljspeech
```
## ONNX Export
Export the trained checkpoint to ONNX and validate it with ONNX Runtime:
```bash
python export_onnx.py
```
Options:
```
--model-dir   Checkpoint directory (default: wav2vec2-ljspeech)
--output      Output ONNX path (default: wav2vec2-ljspeech.onnx)
--opset       ONNX opset version (default: 17)
```
The exported model uses dynamic axes on batch and time, so it accepts audio of any length.
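The options listed above map naturally onto an `argparse` parser. The following is a minimal sketch of how they could be declared, using only the flag names and defaults documented here; the `build_parser` helper and the description string are hypothetical, and the actual `export_onnx.py` may be wired differently.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented export options."""
    parser = argparse.ArgumentParser(
        description="Export a fine-tuned Wav2Vec2 checkpoint to ONNX."
    )
    parser.add_argument("--model-dir", default="wav2vec2-ljspeech",
                        help="Checkpoint directory")
    parser.add_argument("--output", default="wav2vec2-ljspeech.onnx",
                        help="Output ONNX path")
    parser.add_argument("--opset", type=int, default=17,
                        help="ONNX opset version")
    return parser


# Parse an example command line instead of sys.argv so the snippet runs standalone.
args = build_parser().parse_args(["--opset", "14", "--output", "model.onnx"])
print(args.model_dir, args.output, args.opset)
```

Any flag left off the command line falls back to its documented default, so `--model-dir` resolves to `wav2vec2-ljspeech` in this example.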
## Inference
```python
import numpy as np
import onnxruntime as ort
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor.from_pretrained("wav2vec2-ljspeech")
session = ort.InferenceSession("wav2vec2-ljspeech.onnx")
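# audio_array: 16 kHz mono float32 numpy array.
# Placeholder (one second of silence); replace with your own audio.
audio_array = np.zeros(16000, dtype=np.float32)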
inputs = processor(audio_array, sampling_rate=16000, return_tensors="np")
logits = session.run(None, {"input_values": inputs.input_values})[0]
text = processor.tokenizer.batch_decode(np.argmax(logits, axis=-1))[0]
print(text)
```
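The `np.argmax` plus `batch_decode` step in the snippet above amounts to greedy CTC decoding: take the best token per frame, collapse consecutive repeats, then drop the blank token. The sketch below illustrates that idea in plain Python. The toy vocabulary, the blank id, and the `greedy_ctc_decode` and `one_hot` helpers are all hypothetical; in Wav2Vec2 tokenizers the pad token typically serves as the CTC blank and `|` as the word delimiter, but the real mapping comes from the checkpoint's tokenizer files.

```python
from itertools import groupby

# Hypothetical toy vocabulary; the real one comes from the tokenizer's vocab file.
VOCAB = ["<pad>", "|", "c", "a", "t"]
BLANK_ID = 0  # assumed: the pad token doubles as the CTC blank


def greedy_ctc_decode(logits, vocab=VOCAB, blank_id=BLANK_ID):
    """Decode one utterance from a (frames x vocab) list of score rows."""
    # 1. Best-scoring token id per frame (argmax over the vocab axis).
    frame_ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    # 2. Collapse runs of identical ids.
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 3. Drop blanks, map ids to characters, turn the word delimiter into a space.
    chars = [vocab[i] for i in collapsed if i != blank_id]
    return "".join(chars).replace("|", " ").strip()


def one_hot(index, size=len(VOCAB)):
    """Build a score row whose argmax is `index` (for the demo only)."""
    return [1.0 if i == index else 0.0 for i in range(size)]


# Frame ids: c c <pad> a a t <pad> t |
demo_logits = [one_hot(i) for i in (2, 2, 0, 3, 3, 4, 0, 4, 1)]
decoded = greedy_ctc_decode(demo_logits)
print(decoded)  # prints "catt"
```

The blank between the two `t` frames is what keeps the doubled letter: without it, the repeat-collapsing step would merge them into a single `t`.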
Notes:
- Audio must be **16 kHz mono float32**.
- The `Wav2Vec2Processor` handles waveform normalization and tokenization; always pass audio through it before the ONNX session.
- This exports the **acoustic model only**. Add an external LM (e.g. KenLM) for language-model-rescored decoding if needed.
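To make the 16 kHz mono float requirement concrete, here is a minimal standard-library sketch that loads a 16-bit PCM WAV file, rejects other sample rates, downmixes to mono, and scales samples to the range [-1, 1). The `load_wav_16k_mono` helper is hypothetical and is not part of this repository; it does not resample, so audio at other rates would need to be converted with a separate tool first.

```python
import array
import sys
import wave


def load_wav_16k_mono(path):
    """Return samples as floats in [-1.0, 1.0) from a 16-bit PCM WAV file."""
    with wave.open(path, "rb") as wav:
        if wav.getframerate() != 16000:
            raise ValueError(
                f"expected 16000 Hz, got {wav.getframerate()} Hz; resample first"
            )
        if wav.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM")
        channels = wav.getnchannels()
        pcm = array.array("h", wav.readframes(wav.getnframes()))
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    # Average the interleaved channels down to mono, then scale int16 to [-1, 1).
    frames = (pcm[i:i + channels] for i in range(0, len(pcm), channels))
    return [sum(frame) / channels / 32768.0 for frame in frames]
```

The returned list can be converted with `np.asarray(samples, dtype=np.float32)` before being handed to the processor.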
## Project Layout
```
speech-to-text/
├── train.py # Wav2Vec2 + CTC fine-tuning on LJSpeech
├── export_onnx.py # ONNX export and ONNX Runtime validation
├── main.py # Placeholder entry point
├── pyproject.toml # Project metadata and dependencies
└── README.md
```