import logging
import os
import warnings

import evaluate
import numpy as np
import pandas as pd
import torch
from transformers import Trainer, TrainingArguments, Wav2Vec2CTCTokenizer

from data import DataCollatorCTCWithPadding, SpeechTokenPhonemeDataset
from models.ctc_model import CTCTransformerConfig, CTCTransformerModel

os.environ["WANDB_PROJECT"] = "speech-phoneme-ctc"
warnings.filterwarnings("ignore")

# Without a handler configured, the logger.info calls below emit nothing.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== DATASET =====
df = pd.read_csv("dataset.csv")

vocab_path = "vocab/vocab.json"
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
vocab = tokenizer.get_vocab()
vocab_inv = {v: k for k, v in vocab.items()}

num_speech_tokens = 6561

# ===== MODEL SETUP =====
config = CTCTransformerConfig(
    vocab_size=num_speech_tokens,
    num_labels=len(tokenizer),
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
    max_position_embeddings=1024,
    label2id=vocab,
    id2label=vocab_inv,
    pad_token_id=tokenizer.pad_token_id,  # output (label) padding token
    src_pad_token_id=num_speech_tokens,  # input (speech token) padding token
)
model = CTCTransformerModel(config)

dataset = SpeechTokenPhonemeDataset(df, tokenizer=tokenizer)
train_valid_dataset = dataset.train_test_split(test_size=0.05, random_state=42)
train_dataset = train_valid_dataset["train"]
eval_dataset = train_valid_dataset["test"]

collator = DataCollatorCTCWithPadding(
    pad_token_id=num_speech_tokens, label_pad_token_id=tokenizer.pad_token_id
)

# ===== METRICS =====
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    label_ids = pred.label_ids
    logits = pred.predictions

    # Greedy decoding: softmax is monotonic, so taking argmax directly on the
    # logits is equivalent to argmax over log-probabilities.
    pred_ids = np.argmax(logits, axis=-1)

    # Replace -100 (ignored positions) with the pad token ID so decoding works
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode predictions (with CTC token grouping) and references
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids, group_tokens=False)

    # Calculate CER
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}


# Check vocabulary compatibility and print more detailed diagnostic info
print(f"Model vocab size: {model.config.vocab_size}")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(
    f"Vocabulary: {list(tokenizer.get_vocab().keys())[:10]}... (showing first 10 tokens)"
)
print("Training dataset size:", len(train_dataset))
print("Evaluation dataset size:", len(eval_dataset))
if model.config.vocab_size != len(tokenizer.get_vocab()):
    print("WARNING: Vocabulary size mismatch between model and tokenizer!")

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=10,
    num_train_epochs=10,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_ratio=0.1,
    logging_steps=100,
    logging_dir="./logs",
    gradient_accumulation_steps=1,
    bf16=True,
    report_to="wandb",
    remove_unused_columns=False,
    dataloader_num_workers=4,
    include_inputs_for_metrics=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {training_args.num_train_epochs}")
logger.info(
    f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
)
logger.info(
    f" Total train batch size (w. parallel, distributed & accumulation) = "
    f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)
logger.info(
    f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}"
)
# max_steps is -1 here because training length is driven by num_train_epochs
logger.info(f" Total optimization steps = {training_args.max_steps}")
logger.info(f" Logging steps = {training_args.logging_steps}")
logger.info(f" Learning rate = {training_args.learning_rate}")
logger.info(f" Weight decay = {training_args.weight_decay}")
logger.info(f" Warmup ratio = {training_args.warmup_ratio}")
logger.info(f" Save total limit = {training_args.save_total_limit}")

train_res = trainer.train()
trainer.save_model()
trainer.save_state()

metrics = train_res.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

metrics = trainer.evaluate()
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

with open("results/train.log", "w") as f:
    for obj in trainer.state.log_history:
        f.write(str(obj))
        f.write("\n")

print("- Training complete.")
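
# --- Optional sanity check: greedy CTC decode of one eval example ---
# A minimal sketch, not part of the original pipeline. It assumes the custom
# collator emits an "input_ids" batch and that the model's forward returns an
# object with a `.logits` field of shape (batch, time, num_labels), matching
# how compute_metrics treats predictions; both depend on the custom classes
# in models/ and data/, so adjust the key and attribute names if they differ.
model.eval()
sample_batch = collator([eval_dataset[0]])
with torch.no_grad():
    sample_logits = model(input_ids=sample_batch["input_ids"].to(model.device)).logits
greedy_ids = sample_logits.argmax(dim=-1)
# batch_decode collapses repeated CTC tokens and strips padding, as in training
print("Sample greedy decode:", tokenizer.batch_decode(greedy_ids)[0])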