---
language:
- vi
- en
license: apache-2.0
tags:
- asr
- automatic-speech-recognition
- transformer
- vietnamese
- english
- bilingual
datasets:
- Cong123779/AI2Text-Bilingual-ASR-Dataset
metrics:
- wer
- cer
---

# AI2Text – Bilingual ASR (Vietnamese + English)

A **~30M-parameter** Transformer Seq2Seq Automatic Speech Recognition model
trained on **~224k** bilingual (Vietnamese + English) audio samples.

## Model Description

| | Attribute | Value | |
| |---|---| |
| | Architecture | Encoder-Decoder Transformer | |
| | Parameters | ~30,325,164 | |
| | d_model | 256 | |
| | Encoder layers | 14 (RoPE + Flash Attention) | |
| | Decoder layers | 6 (causal, cross-attention) | |
| | Vocabulary size | 3,500 (SentencePiece BPE) | |
| | Language embedding | Yes (Vietnamese=0, English=1) | |
| | Normalization | RMSNorm | |
| | Activation | SiLU (Swish) | |
| | Positional encoding | Rotary (RoPE) | |
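
The language embedding row means each utterance carries a language ID (0 = Vietnamese, 1 = English) that conditions generation. A minimal sketch of the idea, assuming the common pattern of adding a learned per-language vector to the decoder inputs (names and wiring here are hypothetical; the repo's actual implementation may differ):

```python
import torch
import torch.nn as nn

d_model = 256

# Hypothetical sketch: one learned embedding per language, added to every
# decoder input position so the model knows which language to emit.
lang_embed = nn.Embedding(2, d_model)  # 0 = Vietnamese, 1 = English

token_states = torch.randn(1, 10, d_model)  # (batch, seq, d_model)
language_ids = torch.tensor([0])            # Vietnamese
conditioned = token_states + lang_embed(language_ids).unsqueeze(1)
print(conditioned.shape)  # torch.Size([1, 10, 256])
```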

### Modern Components
- **RMSNorm** – more efficient than LayerNorm (no mean subtraction)
- **SiLU (Swish)** activation
- **Rotary Positional Embeddings (RoPE)** – better length generalization
- **Flash Attention (SDPA)** – memory-efficient attention
- **Hybrid CTC / Attention loss** – helps the encoder learn alignment
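
To make the first bullet concrete, here is a minimal RMSNorm sketch (not the repo's code): it rescales by the root-mean-square of the features instead of subtracting the mean and dividing by the standard deviation, which is what saves work versus LayerNorm:

```python
import torch

class RMSNorm(torch.nn.Module):
    """Normalize by root-mean-square only; no mean subtraction as in LayerNorm."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean square over the feature dimension
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight

x = torch.randn(2, 5, 256)
y = RMSNorm(256)(x)
print(y.shape)  # torch.Size([2, 5, 256])
```

After normalization (and before the learned gain moves away from 1), each feature vector has unit RMS.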
## Training Data

Trained on `Cong123779/AI2Text-Bilingual-ASR-Dataset`:
- **Train**: ~194,167 samples (77% Vietnamese, 23% English)
- **Validation**: ~30,123 samples

Audio format: 16 kHz mono WAV, 80-dim Mel-spectrogram features.

## Training Configuration

| | Hyperparameter | Value | |
| |---|---| |
| Batch size | 32 (effective 128 w/ grad-accum ×4) |
| | Learning rate | 3e-4 | |
| | Epochs | 50 | |
| | Warmup | 3% of training steps | |
| | Mixed precision | bfloat16 (AMP) | |
| | Gradient clipping | 0.5 | |
| | CTC weight | 0.2 | |
| Scheduled sampling | 1.0 → 0.5 (linear) |
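
The warmup and scheduled-sampling rows can be made concrete with a little arithmetic derived from the table (the step counts are illustrative; the actual training script may count steps differently):

```python
# Illustrative schedule math based on the table above.
train_samples = 194_167
effective_batch = 128            # 32 × grad-accum 4
epochs = 50

steps_per_epoch = train_samples // effective_batch   # 1516
total_steps = steps_per_epoch * epochs               # 75800
warmup_steps = round(0.03 * total_steps)             # 2274 (3% warmup)

def scheduled_sampling_prob(step: int) -> float:
    """Teacher-forcing probability, decayed linearly from 1.0 to 0.5."""
    frac = min(step / total_steps, 1.0)
    return 1.0 - 0.5 * frac

print(warmup_steps)                           # 2274
print(scheduled_sampling_prob(0))             # 1.0
print(scheduled_sampling_prob(total_steps))   # 0.5
```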

## Usage

```python
import sys
import torch

# Clone the repo and add it to the path
sys.path.insert(0, "AI2Text")

from models.asr_base import ASRModel
from preprocessing.sentencepiece_tokenizer import SentencePieceTokenizer
from preprocessing.audio_processing import AudioProcessor

# Load tokenizer
tokenizer = SentencePieceTokenizer("models/tokenizer_vi_en_3500.model")

# Load model
checkpoint = torch.load("best_model.pt", map_location="cpu")
config = checkpoint.get("config", {})

model = ASRModel(
    input_dim=80,
    vocab_size=3500,
    d_model=256,
    num_encoder_layers=14,
    num_decoder_layers=6,
    num_heads=8,
    d_ff=2048,
    num_languages=2,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Transcribe
audio_processor = AudioProcessor(sample_rate=16000, n_mels=80)
features = audio_processor.process("audio.wav")  # (time, 80)
features = features.unsqueeze(0)                 # (1, time, 80)
lengths = torch.tensor([features.size(1)])

with torch.no_grad():
    tokens = model.generate(
        features,
        lengths=lengths,
        language_ids=torch.tensor([0]),  # 0 = Vietnamese, 1 = English
        max_len=128,
        sos_token_id=tokenizer.sos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
text = tokenizer.decode(tokens[0].tolist())
print(text)
```

## Framework
Built with PyTorch. Optimized for **RTX 5060 Ti 16GB / Ryzen 9 9990X / 64GB RAM**.

## License
Apache 2.0