| | """ |
| | Utility functions for training and evaluation |
| | """ |
| | import numpy as np |
| | from sklearn.metrics import ( |
| | accuracy_score, |
| | precision_recall_fscore_support, |
| | confusion_matrix, |
| | classification_report |
| | ) |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from typing import Dict, Tuple, List, Optional |
| | import os |
| |
|
| |
|
def compute_metrics(eval_pred, id2label: Optional[Dict[int, str]] = None) -> Dict[str, float]:
    """
    Compute comprehensive metrics for evaluation.

    Args:
        eval_pred: Tuple of (predictions, labels), where predictions are logits
        id2label: Optional mapping from label IDs to label names for per-class metrics

    Returns:
        Dictionary of metrics including overall and per-class metrics
    """
    predictions, labels = eval_pred
    # Convert logits to predicted class indices
    predictions = np.argmax(predictions, axis=1)

    # Overall accuracy
    accuracy = accuracy_score(labels, predictions)

    # Weighted average: per-class scores weighted by class support
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='weighted',
        zero_division=0
    )

    # Macro average: unweighted mean over classes
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='macro',
        zero_division=0
    )

    # Micro average: computed over all samples pooled together
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='micro',
        zero_division=0
    )

    metrics = {
        'accuracy': accuracy,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
        'f1_micro': f1_micro,
    }

    # Per-class metrics, keyed by label name
    if id2label is not None:
        num_classes = len(id2label)
        precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
            labels,
            predictions,
            labels=list(range(num_classes)),
            average=None,
            zero_division=0
        )

        for i in range(num_classes):
            label_name = id2label[i]
            metrics[f'precision_{label_name}'] = float(precision_per_class[i])
            metrics[f'recall_{label_name}'] = float(recall_per_class[i])
            metrics[f'f1_{label_name}'] = float(f1_per_class[i])
            metrics[f'support_{label_name}'] = int(support[i])

    return metrics

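
# Illustrative usage (a minimal sketch; the logits, labels, and label names below are made up):
#
#     logits = np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 3.0]])
#     labels = np.array([0, 1, 0])
#     metrics = compute_metrics((logits, labels), id2label={0: "negative", 1: "positive"})
#     print(metrics["accuracy"], metrics["f1_weighted"])
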
|
def compute_metrics_factory(id2label: Optional[Dict[int, str]] = None):
    """
    Factory function to create compute_metrics with label mapping.

    Args:
        id2label: Mapping from label IDs to label names

    Returns:
        Function compatible with HuggingFace Trainer
    """
    def compute_metrics_fn(eval_pred):
        return compute_metrics(eval_pred, id2label)

    return compute_metrics_fn

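
# Illustrative Trainer wiring (a sketch; assumes `transformers` is installed and that
# `model`, `training_args`, `train_dataset`, and `eval_dataset` are defined elsewhere):
#
#     from transformers import Trainer
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         compute_metrics=compute_metrics_factory(model.config.id2label),
#     )
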
|
def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    save_path: str = "confusion_matrix.png",
    normalize: bool = False,
    figsize: Tuple[int, int] = (10, 8)
) -> None:
    """
    Plot and save confusion matrix with optional normalization.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        save_path: Path to save the plot
        normalize: If True, normalize each row of the confusion matrix to proportions
        figsize: Figure size (width, height)
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    if normalize:
        # Row-normalize so each row sums to 1 (per-class breakdown of predictions)
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
        title = 'Normalized Confusion Matrix'
    else:
        fmt = 'd'
        title = 'Confusion Matrix'

    plt.figure(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={'label': 'Proportion' if normalize else 'Count'}
    )
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    # Create the output directory if needed before saving
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Confusion matrix saved to {save_path}")

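
# Illustrative usage (a sketch; the label names and arrays are placeholders):
#
#     plot_confusion_matrix(
#         y_true=np.array([0, 1, 1, 0]),
#         y_pred=np.array([0, 1, 0, 0]),
#         labels=["negative", "positive"],
#         save_path="./results/confusion_matrix.png",
#         normalize=True,
#     )
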
|
def print_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    output_dict: bool = False
) -> Optional[Dict]:
    """
    Print detailed classification report.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        output_dict: If True, return report as dictionary instead of printing

    Returns:
        Classification report as dictionary if output_dict=True, else None
    """
    report = classification_report(
        y_true,
        y_pred,
        # Report all classes, even if one is absent from y_true/y_pred
        labels=list(range(len(labels))),
        target_names=labels,
        digits=4,
        output_dict=output_dict,
        zero_division=0
    )

    if output_dict:
        return report

    print("\nClassification Report:")
    print("=" * 60)
    print(report)
    return None

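
# Illustrative usage (a sketch; label names are placeholders):
#
#     report = print_classification_report(y_true, y_pred, ["negative", "positive"], output_dict=True)
#     print(report["weighted avg"]["f1-score"])
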
|
def plot_training_curves(
    train_losses: List[float],
    eval_losses: List[float],
    eval_metrics: Dict[str, List[float]],
    save_path: str = "./results/training_curves.png"
) -> None:
    """
    Plot training and evaluation curves.

    Args:
        train_losses: List of training losses per step/epoch
        eval_losses: List of evaluation losses per step/epoch
        eval_metrics: Dictionary of metric names to lists of values
        save_path: Path to save the plot
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Top-left: training vs. validation loss
    axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
    axes[0, 0].plot(eval_losses, label='Eval Loss', color='red')
    axes[0, 0].set_xlabel('Step/Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Top-right: validation accuracy
    if 'accuracy' in eval_metrics:
        axes[0, 1].plot(eval_metrics['accuracy'], label='Accuracy', color='green')
        axes[0, 1].set_xlabel('Step/Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_title('Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

    # Bottom-left: validation weighted F1
    if 'f1_weighted' in eval_metrics:
        axes[1, 0].plot(eval_metrics['f1_weighted'], label='F1 (weighted)', color='purple')
        axes[1, 0].set_xlabel('Step/Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].set_title('Validation F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

    # Bottom-right: validation precision and recall (weighted)
    if 'precision_weighted' in eval_metrics and 'recall_weighted' in eval_metrics:
        axes[1, 1].plot(eval_metrics['precision_weighted'], label='Precision', color='orange')
        axes[1, 1].plot(eval_metrics['recall_weighted'], label='Recall', color='cyan')
        axes[1, 1].set_xlabel('Step/Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title('Validation Precision and Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    # Create the output directory if needed before saving
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Training curves saved to {save_path}")

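
# Minimal smoke test (an illustrative sketch, not part of the training pipeline; it exercises
# the helpers above with small synthetic arrays and placeholder label names, writing plots
# to ./results/):
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_samples, n_classes = 200, 3
    true_labels = rng.integers(0, n_classes, size=n_samples)
    # Fake logits biased toward the true class so the metrics are non-trivial
    logits = rng.normal(size=(n_samples, n_classes))
    logits[np.arange(n_samples), true_labels] += 2.0
    pred_labels = np.argmax(logits, axis=1)

    id2label = {0: "negative", 1: "neutral", 2: "positive"}  # placeholder class names
    label_names = [id2label[i] for i in range(n_classes)]

    print(compute_metrics((logits, true_labels), id2label))
    print_classification_report(true_labels, pred_labels, label_names)
    plot_confusion_matrix(true_labels, pred_labels, label_names,
                          save_path="./results/confusion_matrix.png", normalize=True)
    plot_training_curves(
        train_losses=[1.0, 0.7, 0.5, 0.4],
        eval_losses=[0.9, 0.65, 0.55, 0.5],
        eval_metrics={
            "accuracy": [0.6, 0.7, 0.75, 0.78],
            "f1_weighted": [0.58, 0.69, 0.74, 0.77],
            "precision_weighted": [0.6, 0.7, 0.75, 0.78],
            "recall_weighted": [0.6, 0.7, 0.75, 0.78],
        },
        save_path="./results/training_curves.png",
    )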
|