from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

from visualization.utils import CLS_PREFS, DATASET_NAMES, MODEL_NAMES, REG_PREFS
|
def is_regression(metrics: Dict[str, float]) -> bool:
    """Heuristic task-type detection based on metric key names.

    Returns True if any (non-timing) key contains a regression metric name,
    otherwise False.

    Args:
        metrics: Mapping of metric name -> value for one model/dataset.
    """
    reg = ("spearman", "pearson", "r_squared", "rmse", "mse")
    # Timing entries are bookkeeping, never task metrics.
    keys = {
        k.lower()
        for k in metrics
        if "training_time" not in k.lower() and "time_seconds" not in k.lower()
    }
    # NOTE: the original also probed classification keys ("accuracy", "f1", ...)
    # but both that branch and the fall-through returned False, so only the
    # regression probe is load-bearing.
    return any(r in k for k in keys for r in reg)
|
|
|
|
def pick_metric(metrics: Dict[str, float], prefs: List[Tuple[str, str]]) -> Tuple[str, str]:
    """Return (key, pretty_name) for the first preference present in metrics.

    Raises:
        KeyError: if none of the preferred suffixes match any metric key.
    """
    for suffix, pretty in prefs:
        for metric_key in metrics:
            lowered = metric_key.lower()
            # Timing entries never count as a task metric.
            if "training_time" in lowered or "time_seconds" in lowered:
                continue
            if lowered.endswith(suffix):
                return suffix, pretty
    raise KeyError("No preferred metric found.")
|
|
|
|
def parse_metric_value(value) -> Tuple[float, float]:
    """
    Parse a metric value that may be in 'mean±std' format or a plain number.

    Accepted forms:
      * "0.85±0.02" -> (0.85, 0.02)
      * "0.85"      -> (0.85, 0.0)   # plain numeric strings now parse too
      * "0.85±"     -> (0.85, 0.0)   # missing std defaults to 0
      * 0.85 / 85   -> (float(value), 0.0)

    Returns (mean, std); (nan, 0.0) for anything unparsable.
    """
    if isinstance(value, str):
        # partition() splits on the first '±'; std_part is '' when absent.
        mean_part, _, std_part = value.partition('±')
        try:
            mean_val = float(mean_part)
            std_val = float(std_part) if std_part else 0.0
            return mean_val, std_val
        except ValueError:
            return math.nan, 0.0
    if isinstance(value, (int, float)):
        return float(value), 0.0
    return math.nan, 0.0
|
|
|
|
def get_metric_value(metrics: Dict[str, float], key_suffix: str) -> float:
    """Fetch metric value case-/prefix-insensitively; NaN if absent.
    For mean±std format, returns only the mean value."""
    for key, raw in metrics.items():
        lowered = key.lower()
        # Skip bookkeeping entries and pre-aggregated keys.
        if "training_time" in lowered or "time_seconds" in lowered:
            continue
        if lowered.endswith(("_mean", "_std")):
            continue
        if lowered.endswith(key_suffix):
            return parse_metric_value(raw)[0]
    return math.nan
|
|
|
|
def get_metric_value_with_std(metrics: Dict[str, float], key_suffix: str) -> Tuple[float, float, str]:
    """
    Fetch metric value with std case-/prefix-insensitively.
    Returns (mean, std, display_string) where display_string is formatted for heatmap display.
    """
    for key, raw in metrics.items():
        lowered = key.lower()
        # Skip bookkeeping entries and pre-aggregated keys.
        if "training_time" in lowered or "time_seconds" in lowered:
            continue
        if lowered.endswith(("_mean", "_std")):
            continue
        if not lowered.endswith(key_suffix):
            continue
        mean_val, std_val = parse_metric_value(raw)
        # Show "mean±std" only when a real spread is available.
        if std_val > 0:
            display = f"{mean_val:.2f}±{std_val:.2f}"
        else:
            display = f"{mean_val:.2f}"
        return mean_val, std_val, display
    return math.nan, 0.0, ""
|
|
|
|
def radar_factory(n_axes: int):
    """Create a polar figure/axes with `n_axes` evenly spaced spokes.

    Returns (fig, ax, theta) where theta holds the spoke angles in radians,
    oriented so the first spoke points up and angles advance clockwise.
    """
    angles = np.linspace(0.0, 2.0 * np.pi, n_axes, endpoint=False)
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"polar": True})
    ax.set_theta_offset(np.pi / 2)   # first spoke at 12 o'clock
    ax.set_theta_direction(-1)       # clockwise
    return fig, ax, angles
|
|
|
|
def plot_radar(*,
               categories: List[str],
               models: List[str],
               data: List[List[float]],
               title: str,
               outfile: Path,
               normalize: bool = False):
    """Draw a radar (spider) chart with one trace per model.

    `data` is row-per-model, column-per-category. An extra "Avg" spoke holding
    each model's mean is appended automatically. With normalize=True the scores
    are first min-max scaled per category across models.
    """
    spoke_labels = [DATASET_NAMES.get(c, c) for c in categories]
    trace_labels = [MODEL_NAMES.get(m, m) for m in models]

    if normalize:
        arr = np.asarray(data)
        span = np.ptp(arr, axis=0)
        span = np.where(span == 0, 1, span)  # guard 0/0 on constant columns
        data = ((arr - arr.min(0)) / span).tolist()

    # Extra "Avg" axis summarising each model across all categories.
    spoke_labels = spoke_labels + ["Avg"]
    data = [row + [np.nanmean(row)] for row in data]

    fig, ax, theta = radar_factory(len(spoke_labels))
    ax.set_thetagrids(np.degrees(theta), spoke_labels, fontsize=11)
    ax.set_ylim(0, 1.0)
    ax.set_yticks(np.linspace(0, 1, 11))

    palette = [plt.cm.tab20(i / len(trace_labels)) for i in range(len(trace_labels))]
    for idx, (label, vals) in enumerate(zip(trace_labels, data)):
        # Close each polygon by repeating its first vertex.
        angles = np.concatenate([theta, [theta[0]]])
        values = np.concatenate([vals, [vals[0]]])
        ax.plot(angles, values, lw=2, label=label, color=palette[idx])
        ax.fill(angles, values, alpha=.25, color=palette[idx])

    ax.grid(True)
    plt.title(title, pad=20)
    plt.legend(bbox_to_anchor=(1.25, 1.05))
    plt.tight_layout()
    plt.savefig(outfile, dpi=450, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def bar_plot(datasets: List[str],
             models: List[str],
             data: List[List[float]],
             metric_name: str,
             outfile: Path):
    """Grouped bar chart: one group per dataset, one bar per model."""
    records = []
    for model, scores in zip(models, data):
        model_label = MODEL_NAMES.get(model, model)
        for dataset, score in zip(datasets, scores):
            records.append({"Dataset": DATASET_NAMES.get(dataset, dataset),
                            "Model": model_label,
                            "Score": score})
    dfp = pd.DataFrame(records)
    # Widen the figure with the number of dataset groups.
    plt.figure(figsize=(max(12, .8 * len(datasets)), 8))
    sns.barplot(dfp, x="Dataset", y="Score", hue="Model")
    plt.title(f"{metric_name} across datasets (Cls→F1, Reg→Spearman)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(outfile, dpi=450, bbox_inches="tight")
    plt.close()
|
|
|
|
def normalize_per_dataset(arr: np.ndarray) -> np.ndarray:
    """
    Normalize array per dataset (row-wise min-max scaling).

    Args:
        arr: Array of shape (num_datasets, num_models)

    Returns:
        Float array of the same shape with each row scaled to [0, 1].
        Constant rows map to 0; NaNs propagate unchanged.
    """
    # Force float: the original used np.zeros_like(arr), which inherited an
    # integer dtype for int input and silently truncated every result.
    arr = np.asarray(arr, dtype=float)
    row_min = np.nanmin(arr, axis=1, keepdims=True)
    row_max = np.nanmax(arr, axis=1, keepdims=True)
    denom = row_max - row_min
    denom[denom == 0] = 1  # constant rows: avoid division by zero
    return (arr - row_min) / denom
|
|
def heatmap_plot(datasets: List[str],
                 models: List[str],
                 data: List[List[float]],
                 metric_name: str,
                 outfile: Path,
                 normalize: bool = False,
                 display_strings: List[List[str]] = None,
                 no_std: bool = False):
    """
    Create a heatmap plot (datasets as rows, models as columns, plus an
    'Average' row). Cell colors are always min-max scaled per row so the
    best/worst model per dataset stands out.

    Args:
        datasets: List of dataset names
        models: List of model names
        data: List of lists of mean values, one inner list per model
        metric_name: Name of the metric being plotted
        outfile: Output file path
        normalize: If True, annotate cells with per-dataset normalized values
            instead of raw values (display_strings are then ignored).
        display_strings: Optional list of lists of display strings (e.g., "0.85±0.01").
            If provided, these are used for annotations instead of raw values.
        no_std: If True, ignore display_strings and annotate raw values.
    """
    arr = np.array(data).T  # -> (num_datasets, num_models)
    # Append a per-model average row.
    avg_row = np.nanmean(arr, axis=0, keepdims=True)
    arr_with_avg = np.vstack([arr, avg_row])
    datasets_plus_avg = datasets + ['Average']

    clean_model_names = [MODEL_NAMES.get(m, m) for m in models]
    clean_dataset_names = [DATASET_NAMES.get(d, d) for d in datasets_plus_avg]
    # (Removed two leftover debug prints of the dataset-name lists.)

    # Annotation strings: prefer caller-supplied "mean±std" strings.
    if display_strings is not None and not no_std:
        display_arr = np.array(display_strings).T.tolist()
        # The synthetic 'Average' row gets plain means (no std available).
        avg_display = []
        for j in range(len(models)):
            model_vals = [arr[i, j] for i in range(arr.shape[0]) if not math.isnan(arr[i, j])]
            avg_display.append(f"{np.mean(model_vals):.4f}" if model_vals else "")
        display_arr.append(avg_display)
    else:
        display_arr = None

    if normalize:
        normalized_data = normalize_per_dataset(arr)
        avg_row_norm = np.nanmean(normalized_data, axis=0, keepdims=True)
        annot_arr = np.vstack([normalized_data, avg_row_norm])
        annot_label = 'Normalized Performance (0-1)'
        display_arr = None  # normalized annotations supersede mean±std strings
    else:
        annot_arr = arr_with_avg
        annot_label = metric_name

    # Colors: min-max scale each row independently (regardless of `normalize`).
    color_arr = np.zeros_like(arr_with_avg)
    for i in range(arr_with_avg.shape[0]):
        row_min = np.nanmin(arr_with_avg[i, :])
        row_max = np.nanmax(arr_with_avg[i, :])
        denom = row_max - row_min
        if denom == 0 or np.isnan(denom):
            color_arr[i, :] = 0.5  # flat / all-NaN row: neutral color
        else:
            color_arr[i, :] = (arr_with_avg[i, :] - row_min) / denom

    has_std = display_arr is not None

    # Scale the figure with the table dimensions.
    cell_width = 1.3
    cell_height = 0.7
    fig_width = max(8, cell_width * len(clean_model_names))
    fig_height = max(6, cell_height * len(clean_dataset_names) + 1)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Blue -> light blue -> gold ramp for worst -> best.
    from matplotlib.colors import LinearSegmentedColormap
    colors = ['#3498db', '#85c1e9', '#FFD700']
    cmap = LinearSegmentedColormap.from_list('blue_yellow', colors, N=100)

    im = ax.imshow(color_arr, cmap=cmap, aspect='auto', vmin=0, vmax=1)

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Worst to Best', fontsize=16)
    cbar.set_ticks([0, 0.5, 1])
    cbar.set_ticklabels(['Worst', 'Mid', 'Best'], fontsize=11)

    ax.set_xticks(np.arange(len(clean_model_names)))
    ax.set_yticks(np.arange(len(clean_dataset_names)))
    ax.set_xticklabels(clean_model_names, rotation=45, ha='right', fontsize=16)
    ax.set_yticklabels(clean_dataset_names, rotation=0, fontsize=16)

    # Annotate each cell; use a smaller font when mean±std strings are shown.
    font_size = 10 if has_std else 16
    for i in range(annot_arr.shape[0]):
        for j in range(annot_arr.shape[1]):
            if display_arr is not None and i < len(display_arr) and j < len(display_arr[i]):
                text_str = display_arr[i][j]
            elif i == annot_arr.shape[0] - 1:
                text_str = f'{annot_arr[i, j]:.4f}'  # average row: extra precision
            else:
                text_str = f'{annot_arr[i, j]:.2f}'
            ax.text(j, i, text_str,
                    ha="center", va="center", color="black", fontsize=font_size)

    # Outline the best model in each row.
    for i in range(color_arr.shape[0]):
        if not np.all(np.isnan(color_arr[i, :])):
            best_idx = np.nanargmax(color_arr[i, :])
            ax.add_patch(plt.Rectangle((best_idx - 0.5, i - 0.5), 1, 1,
                                       fill=False, edgecolor='black', lw=3))

    # Both branches of the original if/else produced this identical title, so
    # the redundant conditional was removed.
    title = f'{annot_label} Heatmap (Cls→F1, Reg→Spearman)\nColors normalized per dataset'
    plt.title(title, pad=20, fontsize=21)
    plt.ylabel('Dataset', fontsize=17)
    plt.xlabel('Model', fontsize=17)
    plt.tight_layout()
    plt.savefig(outfile, dpi=450, bbox_inches='tight')
    plt.close()
|
|
|
|
def load_tsv(tsv: Path) -> pd.DataFrame:
    """Read a metrics TSV; every non-'dataset' column holds JSON-encoded dicts."""
    frame = pd.read_csv(tsv, sep="\t")
    for column in frame.columns:
        if column == "dataset":
            continue
        frame[column] = frame[column].apply(json.loads)
    return frame
|
|
|
|
def create_plots(tsv: str, outdir: str, no_std: bool = False):
    """Generate radar, bar, and heatmap plots (raw and normalized variants)
    for every dataset/model in `tsv`, writing PNGs under `outdir/<tsv stem>/`.

    Args:
        tsv: TSV file with a 'dataset' column plus one JSON-metrics column per model.
        outdir: Root output directory.
        no_std: If True, heatmap annotations show plain means (no ±std).

    Raises:
        RuntimeError: if no dataset yields a plottable metric.
    """
    tsv, outdir = Path(tsv), Path(outdir)
    df = load_tsv(tsv)
    models = [c for c in df.columns if c != "dataset"]

    # Collect one preferred metric (F1 or Spearman) per dataset per model.
    datasets, scores_by_model = [], {m: [] for m in models}
    display_by_model = {m: [] for m in models}
    dataset_types = []

    for _, row in df.iterrows():
        name = row["dataset"]
        metrics0 = row[models[0]]
        task = "regression" if is_regression(metrics0) else "classification"
        prefs = REG_PREFS if task == "regression" else CLS_PREFS

        try:
            suffix, _pretty = pick_metric(metrics0, prefs)
        except KeyError:
            print(f"[WARN] {name}: no suitable metric – skipped.")
            continue

        # BUGFIX: append the task type only for datasets that are kept.
        # Previously it was appended before the `continue`, so one skipped
        # dataset desynchronized dataset_types from datasets and the
        # re-ordering below read the wrong task type.
        dataset_types.append(task)
        datasets.append(name)
        for m in models:
            # The original duplicated this call in both no_std branches;
            # only the display string differs, so fetch once and override.
            mean_val, std_val, display_str = get_metric_value_with_std(row[m], suffix)
            if no_std:
                display_str = f"{mean_val:.2f}"
            scores_by_model[m].append(mean_val)
            display_by_model[m].append(display_str)

    if not datasets:
        raise RuntimeError("No plottable datasets found.")

    only_classification = all(t == "classification" for t in dataset_types)
    only_regression = all(t == "regression" for t in dataset_types)

    # Re-order datasets to follow DATASET_NAMES order, then any leftovers.
    ordered_datasets = []
    ordered_scores = {m: [] for m in models}
    ordered_display = {m: [] for m in models}
    ordered_types = []

    for ds in DATASET_NAMES.keys():
        if ds in datasets:
            idx = datasets.index(ds)
            ordered_datasets.append(ds)
            ordered_types.append(dataset_types[idx])
            for m in models:
                ordered_scores[m].append(scores_by_model[m][idx])
                ordered_display[m].append(display_by_model[m][idx])

    for ds in datasets:
        if ds not in ordered_datasets:
            ordered_datasets.append(ds)
            idx = datasets.index(ds)
            ordered_types.append(dataset_types[idx])
            for m in models:
                ordered_scores[m].append(scores_by_model[m][idx])
                ordered_display[m].append(display_by_model[m][idx])

    datasets = ordered_datasets
    scores_by_model = ordered_scores
    display_by_model = ordered_display
    dataset_types = ordered_types

    # Matrices are row-per-model, column-per-dataset.
    plot_matrix = [scores_by_model[m] for m in models]
    display_matrix = [display_by_model[m] for m in models]

    # Sort models by raw average score (ascending) for bar/heatmap plots.
    model_avgs = [np.nanmean(scores) for scores in plot_matrix]
    sorted_indices = np.argsort(model_avgs)
    sorted_models = [models[i] for i in sorted_indices]
    sorted_plot_matrix = [plot_matrix[i] for i in sorted_indices]
    sorted_display_matrix = [display_matrix[i] for i in sorted_indices]

    # Separate ordering by *normalized* average for the normalized heatmap.
    arr_for_norm = np.array(plot_matrix).T
    normalized_data = normalize_per_dataset(arr_for_norm)
    normalized_model_avgs = np.nanmean(normalized_data, axis=0)
    sorted_indices_norm = np.argsort(normalized_model_avgs)
    sorted_models_norm = [models[i] for i in sorted_indices_norm]
    sorted_plot_matrix_norm = [plot_matrix[i] for i in sorted_indices_norm]
    sorted_display_matrix_norm = [display_matrix[i] for i in sorted_indices_norm]

    fig_tag = tsv.stem
    outdir = outdir / fig_tag
    outdir.mkdir(parents=True, exist_ok=True)

    radar_path = outdir / f"{fig_tag}_radar_all.png"
    radar_path_norm = outdir / f"{fig_tag}_radar_all_normalized.png"
    bar_path = outdir / f"{fig_tag}_bar_all.png"
    bar_path_norm = outdir / f"{fig_tag}_bar_all_normalized.png"
    heatmap_path = outdir / f"{fig_tag}_heatmap_all.png"
    heatmap_path_norm = outdir / f"{fig_tag}_heatmap_all_normalized.png"

    if only_classification:
        subtitle = "Classification datasets plot F1"
        metric_name = "F1"
    elif only_regression:
        subtitle = "Regression datasets plot Spearman rho"
        metric_name = "Spearman rho"
    else:
        subtitle = "Classification datasets plot F1; Regression datasets plot Spearman rho"
        metric_name = "F1 / Spearman rho"

    plot_radar(categories=datasets,
               models=models,
               data=plot_matrix,
               title=subtitle,
               outfile=radar_path,
               normalize=False)
    plot_radar(categories=datasets,
               models=models,
               data=plot_matrix,
               title=subtitle + " (Normalized)",
               outfile=radar_path_norm,
               normalize=True)

    bar_plot(datasets, sorted_models, sorted_plot_matrix, metric_name, bar_path)

    # Min-max normalize across models (per dataset column) for the bar plot.
    arr = np.asarray(sorted_plot_matrix)
    rng = np.where(np.ptp(arr, axis=0) == 0, 1, np.ptp(arr, axis=0))
    arr_norm = (arr - arr.min(0)) / rng
    bar_plot(datasets, sorted_models, arr_norm.tolist(), metric_name + " (Normalized)", bar_path_norm)

    heatmap_plot(datasets, sorted_models, sorted_plot_matrix, metric_name, heatmap_path,
                 normalize=False, display_strings=sorted_display_matrix, no_std=no_std)
    heatmap_plot(datasets, sorted_models_norm, sorted_plot_matrix_norm, metric_name, heatmap_path_norm,
                 normalize=True, display_strings=sorted_display_matrix_norm, no_std=no_std)

    print(f"Radar saved to {radar_path}")
    print(f"Radar (normalized) saved to {radar_path_norm}")
    print(f"Bar saved to {bar_path}")
    print(f"Bar (normalized) saved to {bar_path_norm}")
    print(f"Heatmap saved to {heatmap_path}")
    print(f"Heatmap (normalized) saved to {heatmap_path_norm}")
|
|
|
|
def main() -> None:
    """Command-line entry point: parse arguments and generate all plots."""
    parser = argparse.ArgumentParser(
        description="Generate radar, bar, and heatmap plots for all datasets."
                    " Always saves both normalized and unnormalized versions.")
    parser.add_argument("--input", required=True, help="TSV file with metrics")
    parser.add_argument("--output_dir", default="plots", help="Directory for plots")
    parser.add_argument("--no_std", action="store_true",
                        help="Do not display standard deviation in heatmap plots")
    ns = parser.parse_args()
    create_plots(Path(ns.input), Path(ns.output_dir), no_std=ns.no_std)
    print("Finished.")
|
|
|
|
if __name__ == "__main__":
    import sys

    if "--input" in sys.argv:
        # Real invocation: defer to the argparse-driven entry point.
        main()
    else:
        # No CLI args: smoke-test every plotting helper with a tiny fixture.
        print("\nRunning plot function tests...")
        from pathlib import Path

        tmpdir = Path("plots/test_plots")
        tmpdir.mkdir(parents=True, exist_ok=True)

        categories = ["A", "B", "C"]
        models = ["Model1", "Model2"]
        data = [
            [0.8, 0.6, 0.7],
            [0.5, 0.9, 0.4],
        ]

        radar_path = tmpdir / "test_radar.png"
        plot_radar(categories=categories, models=models, data=data,
                   title="Test Radar", outfile=radar_path)
        assert radar_path.exists(), "Radar plot not created!"
        print(f"Radar plot test passed: {radar_path}")

        radar_path_norm = tmpdir / "test_radar_normalized.png"
        plot_radar(categories=categories, models=models, data=data,
                   title="Test Radar (Normalized)", outfile=radar_path_norm,
                   normalize=True)
        assert radar_path_norm.exists(), "Normalized radar plot not created!"
        print(f"Normalized radar plot test passed: {radar_path_norm}")

        bar_path = tmpdir / "test_bar.png"
        bar_plot(categories, models, data, "Test Metric", bar_path)
        assert bar_path.exists(), "Bar plot not created!"
        print(f"Bar plot test passed: {bar_path}")

        # Normalized bar variant mirrors create_plots' per-dataset scaling.
        arr = np.asarray(data)
        rng = np.where(np.ptp(arr, axis=0) == 0, 1, np.ptp(arr, axis=0))
        arr_norm = (arr - arr.min(0)) / rng
        bar_path_norm = tmpdir / "test_bar_normalized.png"
        bar_plot(categories, models, arr_norm.tolist(),
                 "Test Metric (Normalized)", bar_path_norm)
        assert bar_path_norm.exists(), "Normalized bar plot not created!"
        print(f"Normalized bar plot test passed: {bar_path_norm}")

        heatmap_path = tmpdir / "test_heatmap.png"
        heatmap_plot(categories, models, data, "Test Metric", heatmap_path)
        assert heatmap_path.exists(), "Heatmap plot not created!"
        print(f"Heatmap plot test passed: {heatmap_path}")

        heatmap_path_norm = tmpdir / "test_heatmap_normalized.png"
        heatmap_plot(categories, models, data, "Test Metric",
                     heatmap_path_norm, normalize=True)
        assert heatmap_path_norm.exists(), "Normalized heatmap plot not created!"
        print(f"Normalized heatmap plot test passed: {heatmap_path_norm}")

        display_strings = [
            ["0.80±0.02", "0.60±0.01", "0.70±0.03"],
            ["0.50±0.05", "0.90±0.02", "0.40±0.01"],
        ]
        heatmap_path_std = tmpdir / "test_heatmap_with_std.png"
        heatmap_plot(categories, models, data, "Test Metric",
                     heatmap_path_std, display_strings=display_strings)
        assert heatmap_path_std.exists(), "Heatmap with std not created!"
        print(f"Heatmap with std test passed: {heatmap_path_std}")
        print("All plot function tests passed!\n")