| """ |
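PRISM Dataset Analysis — N subunits × function × diversity

Reads data/interim/labeled_dataset.tsv and writes a six-panel summary figure
(panels A–F) to /tmp/prism_corpus_analysis.png.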
| """ |
|
|
| import json, re |
| import numpy as np |
| import pandas as pd |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.gridspec as gridspec |
| from matplotlib.patches import Patch |
| from matplotlib.colors import LinearSegmentedColormap |
|
|
| |
|
|
df = pd.read_csv("data/interim/labeled_dataset.tsv", sep="\t")

# Each `uniprot_ids` cell is JSON text (presumably an array of UniProt
# accessions — TODO confirm); an empty cell is treated as no IDs at all.
uids = df["uniprot_ids"].apply(lambda raw: [] if pd.isna(raw) else json.loads(raw))
df["n_unique"] = uids.apply(lambda ids: len(set(ids)))

# diversity = distinct proteins / total subunits. The denominator is floored
# at 1 so a zero subunit count cannot divide by zero.
df["diversity"] = df["n_unique"] / df["n_subunits"].clip(lower=1)
|
|
| |
def repeat_class(row):
    """Bucket a complex by how its subunits repeat.

    ``row`` must expose ``diversity``, ``n_subunits`` and ``n_unique``
    attributes (a DataFrame row passed by ``df.apply(..., axis=1)`` works).

    Returns one of three two-line display labels:

    * diversity >= 1                         -> "All-unique\\n(pure heteromer)"
    * n_unique == 1 and
      diversity <= 1 / n_subunits + 1e-6     -> "All-same\\n(pure homomer)"
    * anything else (incl. n_unique == 0)    -> "Mixed\\n(repeated subunits)"
    """
    if row.diversity >= 1.0:
        return "All-unique\n(pure heteromer)"
    # Small tolerance so float rounding of n_unique / n_subunits still counts.
    homomer_ceiling = 1 / row.n_subunits + 1e-6
    is_homomer = row.n_unique == 1 and row.diversity <= homomer_ceiling
    if is_homomer:
        return "All-same\n(pure homomer)"
    return "Mixed\n(repeated subunits)"
|
|
# Row-wise: every complex gets one of repeat_class()'s three labels.
df["repeat_class"] = df.apply(repeat_class, axis="columns")
|
|
| |
# Keyword patterns searched (re.search) in the lower-cased complex name.
# Dict order matters: classify() returns the FIRST category whose pattern
# matches, so a name that hits several categories goes to the earliest one.
FUNC_CATS = {
    "Kinase /\nPhosphatase":     r"kinas|phosphatas|phosphoryl",
    "Transcription /\nRNA Pol":  r"transcri|rna pol|mediator",
    "Protease /\nPeptidase":     r"proteas|peptidas|caspas|trypsin",
    "Transport /\nChannel":      r"transport|channel|pump|import|export",
    "Signaling /\nGTPase":       r"signal|gtpas|ras |rho |rab |ran |g.protein",
    # Stem "chaperon" matches both "chaperone" and "chaperonin".
    "Chaperone /\nFolding":      r"chaperon|groel|groes|hsp|heat shock",
    "Ubiquitin /\nLigase":       r"ubiquitin|ligase|cullin|skp|ring.*finger",
    "Ribosome /\nTranslation":   r"ribosom|translat|eif|elongation factor",
    "DNA Repair /\nReplication": r"dna rep|repair|pcna|helicase|primase",
}


def classify(name):
    """Return the functional category of a complex from its name.

    Args:
        name: Complex name. Non-string values (e.g. NaN from a missing
            cell) are accepted.

    Returns:
        The first ``FUNC_CATS`` key whose pattern occurs in the lower-cased
        name, or ``"Other"`` when nothing matches or ``name`` is not a str.
    """
    if not isinstance(name, str):
        return "Other"
    lowered = name.lower()
    for cat, pattern in FUNC_CATS.items():
        if re.search(pattern, lowered):
            return cat
    return "Other"
|
|
# Element-wise keyword classification of each complex name (NaN -> "Other").
df["func_cat"] = df["name"].map(classify)
|
|
| |
# Display name for each raw `source` value. Sources not listed here map to
# NaN and so never match any entry in the per-source panels.
SRC_LABEL = dict(
    corum="CORUM 5.2",
    rcsb_search="RCSB PDB",
    complex_portal="Complex Portal",
    experimental_gt="Experimental GT",
)
df["source_label"] = df["source"].map(SRC_LABEL)
|
|
# Ground-truth tier (gt_tier) -> two-line display label; unmapped tiers -> NaN.
TIER_LABEL = {
    1: "Tier 1\n(Experimental)",
    2: "Tier 2\n(Structure + BSA)",
    3: "Tier 3\n(Seq-length BSA)",
}
df["tier_label"] = df.gt_tier.map(TIER_LABEL)


SRC_COLORS = {"CORUM 5.2": "#4e9af1", "RCSB PDB": "#f1a54e",
              "Complex Portal": "#6bd16b", "Experimental GT": "#e05c5c"}
# Keyed through TIER_LABEL so the colour table always matches the label text
# (panel B looks colours up by label).
TIER_COLORS = {
    TIER_LABEL[1]: "#d29922",
    TIER_LABEL[2]: "#4e9af1",
    TIER_LABEL[3]: "#8b949e",
}
# One colour per FUNC_CATS entry plus one spare (presumably for "Other").
FUNC_COLORS = plt.cm.tab10(np.linspace(0, 1, len(FUNC_CATS) + 1))
|
|
| |
|
|
fig = plt.figure(figsize=(20, 18), facecolor="#0d1117")
fig.suptitle("PRISM Training Corpus — Complex Composition & Diversity Analysis",
             fontsize=16, fontweight="bold", color="#e6edf3", y=0.98)

# 3x3 grid: row 0 = panel A (two columns) + B; row 1 = panel C (full width);
# row 2 = panels D, E, F.
gs = gridspec.GridSpec(3, 3, figure=fig, hspace=0.52, wspace=0.38,
                       left=0.07, right=0.97, top=0.94, bottom=0.06)
|
|
# Dark-theme palette shared by the Cartesian panels; style() reads from it.
ax_style = {
    "facecolor": "#161b22",
    "labelcolor": "#8b949e",
    "tick_params": {"colors": "#8b949e"},
}


def style(ax, title="", xlabel="", ylabel=""):
    """Apply the dark panel theme to ``ax``.

    Sets the background, a bold light title, muted axis labels and ticks,
    dim spines, and a faint horizontal grid drawn beneath the data.
    """
    ax.set_facecolor(ax_style["facecolor"])
    ax.set_title(title, color="#e6edf3", fontsize=11, fontweight="bold", pad=8)
    muted = ax_style["labelcolor"]
    ax.set_xlabel(xlabel, color=muted, fontsize=9)
    ax.set_ylabel(ylabel, color=muted, fontsize=9)
    ax.tick_params(labelsize=8, **ax_style["tick_params"])
    for edge in ax.spines.values():
        edge.set_edgecolor("#30363d")
    ax.grid(axis="y", color="#30363d", linewidth=0.5, alpha=0.6)
    ax.set_axisbelow(True)
|
|
| |
|
|
ax_a = fig.add_subplot(gs[0, :2])
style(ax_a, "A — Subunit Count Distribution by Source",
      "Number of subunits per complex", "Complex count")

# Stacked bars: one layer per source, one bar per subunit count in n_vals.
# Complexes whose count is outside n_vals, or whose source is unmapped, are
# not drawn.
n_vals = range(2, 13)
src_order = ["CORUM 5.2", "RCSB PDB", "Complex Portal", "Experimental GT"]
bottom = np.zeros(len(n_vals))
for src in src_order:
    src_sizes = df.loc[df.source_label == src, "n_subunits"]
    counts = np.array([(src_sizes == n).sum() for n in n_vals])
    ax_a.bar(list(n_vals), counts, bottom=bottom,
             color=SRC_COLORS[src], label=src, width=0.7, alpha=0.9)
    bottom += counts

ax_a.set_xticks(list(n_vals))
ax_a.legend(loc="upper right", fontsize=8, framealpha=0.2,
            labelcolor="#e6edf3", facecolor="#161b22")

# Stack totals above each bar. The gap is in points, not data units, so it
# looks the same whatever the magnitude of the counts; the extra top margin
# leaves headroom for the tallest stack's label.
ax_a.margins(y=0.08)
for n, stack_total in zip(n_vals, bottom):
    total = int(stack_total)
    if total > 0:
        ax_a.annotate(f"{total:,}", xy=(n, total), xytext=(0, 2),
                      textcoords="offset points", ha="center", va="bottom",
                      color="#8b949e", fontsize=7)
|
|
| |
|
|
ax_b = fig.add_subplot(gs[0, 2])
style(ax_b, "B — Tier Distribution", "Source", "% of complexes")

# Grouped bars: for each source, the percentage of its complexes in each tier.
tier_order = list(TIER_LABEL.values())
x = np.arange(len(src_order))
w = 0.22
# Offset that centres the group of len(tier_order) bars on each x tick.
group_center = (len(tier_order) - 1) / 2
for ti, tlabel in enumerate(tier_order):
    vals = []
    for src in src_order:
        sub = df[df.source_label == src]
        # max(..., 1) avoids dividing by zero for a source with no complexes.
        vals.append(100 * (sub.tier_label == tlabel).sum() / max(len(sub), 1))
    ax_b.bar(x + (ti - group_center) * w, vals, width=w,
             color=TIER_COLORS[tlabel], label=tlabel.replace("\n", " "), alpha=0.9)

ax_b.set_xticks(x)
ax_b.set_xticklabels([s.replace(" ", "\n") for s in src_order], fontsize=7)
ax_b.legend(fontsize=6.5, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
ax_b.set_ylim(0, 105)
ax_b.yaxis.set_major_formatter(plt.FuncFormatter(lambda v, _: f"{v:.0f}%"))
|
|
| |
|
|
ax_c = fig.add_subplot(gs[1, :])
style(ax_c, "C — Subunit Count by Functional Category",
      "Functional category (name-keyword classified)", "N subunits")

# Categories ordered largest-first with "Other" last. Categories with no
# complexes are dropped: violinplot() cannot draw an empty dataset.
cat_sizes = df.func_cat.value_counts()
cat_order = sorted(FUNC_CATS.keys(), key=lambda c: -cat_sizes.get(c, 0))
cat_order.append("Other")
cat_order = [c for c in cat_order if cat_sizes.get(c, 0) > 0]
cat_data = [df.loc[df.func_cat == c, "n_subunits"].values for c in cat_order]
cat_counts = [len(d) for d in cat_data]
cat_labels = [f"{c}\n(n={cnt:,})" for c, cnt in zip(cat_order, cat_counts)]

# Colour by position in FUNC_CATS (the spare last colour for "Other"), the
# same indexing panel F uses, so a category keeps its colour across panels.
cat_index = {c: i for i, c in enumerate(FUNC_CATS)}
cat_colors = [FUNC_COLORS[cat_index.get(c, len(FUNC_CATS)) % len(FUNC_COLORS)]
              for c in cat_order]

positions = range(len(cat_order))
if cat_data:
    parts = ax_c.violinplot(cat_data, positions=positions,
                            showmedians=True, showextrema=False, widths=0.7)
    for body, color in zip(parts["bodies"], cat_colors):
        body.set_facecolor(color)
        body.set_alpha(0.75)
        body.set_edgecolor("#30363d")
    parts["cmedians"].set_color("#e6edf3")
    parts["cmedians"].set_linewidth(1.5)

    # White dots mark each category's mean (the light line is the median).
    means = [np.mean(d) for d in cat_data]
    ax_c.scatter(positions, means, color="white", s=25, zorder=5)

ax_c.set_xticks(positions)
ax_c.set_xticklabels(cat_labels, fontsize=7.5)
ax_c.set_yticks(range(2, 13))
ax_c.set_ylim(1.3, 13)
ax_c.axhline(df.n_subunits.mean(), color="#58a6ff", linewidth=0.8,
             linestyle="--", alpha=0.6, label=f"Overall mean ({df.n_subunits.mean():.2f})")
ax_c.legend(fontsize=8, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
|
|
| |
|
|
ax_d = fig.add_subplot(gs[2, 0])
style(ax_d, "D — Subunit Identity Matrix\n(N total vs N unique)",
      "N unique proteins", "N total subunits")

# heat[t - 1, u - 1] = number of complexes with t total and u unique subunits,
# for 1 <= t, u <= n_max. Accumulated vectorised with np.add.at, which counts
# repeated index pairs correctly. Rows with a missing or non-numeric count
# are skipped.
n_max = 12
heat = np.zeros((n_max, n_max))
n_total = pd.to_numeric(df["n_subunits"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
n_uniq = pd.to_numeric(df["n_unique"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
known = np.isfinite(n_total) & np.isfinite(n_uniq)
# Truncate toward zero, as int() would.
nt = np.trunc(n_total[known]).astype(int)
nu = np.trunc(n_uniq[known]).astype(int)
in_range = (nt >= 1) & (nt <= n_max) & (nu >= 1) & (nu <= n_max)
np.add.at(heat, (nt[in_range] - 1, nu[in_range] - 1), 1)

# Log scale so the dominant cells do not wash out rare combinations.
heat_log = np.log1p(heat)
cmap = LinearSegmentedColormap.from_list("prism", ["#161b22", "#1f3f6e", "#4e9af1", "#e6edf3"])
im = ax_d.imshow(heat_log, cmap=cmap, aspect="auto", origin="lower",
                 extent=[0.5, n_max + 0.5, 0.5, n_max + 0.5])

# Colour-bar text is styled explicitly; otherwise it uses matplotlib's
# default black text, which is unreadable on the dark figure background.
cbar = plt.colorbar(im, ax=ax_d, shrink=0.8)
cbar.set_label("log(count+1)", color="#8b949e", fontsize=9)
cbar.ax.tick_params(colors="#8b949e", labelsize=7)
cbar.outline.set_edgecolor("#30363d")

ax_d.plot([0.5, n_max + 0.5], [0.5, n_max + 0.5], "w--", linewidth=0.8, alpha=0.5,
          label="N unique = N total\n(pure heteromer)")
ax_d.legend(fontsize=7, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
ax_d.set_xticks(range(1, n_max + 1))
ax_d.set_yticks(range(1, n_max + 1))
|
|
| |
|
|
ax_e = fig.add_subplot(gs[2, 1])
ax_e.set_facecolor("#161b22")
ax_e.set_title("E — Subunit Repetition Class", color="#e6edf3",
               fontsize=11, fontweight="bold", pad=8)

# Donut of repeat-class shares; value_counts() orders classes largest-first,
# and colours are assigned in that order.
rc_counts = df.repeat_class.value_counts()
rc_labels = rc_counts.index.tolist()
rc_vals = rc_counts.values
rc_colors = ["#4e9af1", "#f0883e", "#3fb950"]
wedges, _, autotexts = ax_e.pie(
    rc_vals, labels=None, autopct="%1.1f%%", startangle=90,
    colors=rc_colors[:len(rc_labels)],
    wedgeprops=dict(width=0.55, edgecolor="#0d1117"),
    pctdistance=0.78,
)
for at in autotexts:
    at.set_fontsize(8)
    at.set_color("#e6edf3")

# Legend swatches are taken from the drawn wedges, so they always match the
# slice colours however many classes are present.
ax_e.legend(
    [Patch(facecolor=wedge.get_facecolor()) for wedge in wedges],
    [f"{label}: {count:,}" for label, count in zip(rc_labels, rc_vals)],
    loc="lower center", fontsize=7.5, framealpha=0.2,
    labelcolor="#e6edf3", facecolor="#161b22", ncol=1,
    bbox_to_anchor=(0.5, -0.18),
)
|
|
| |
|
|
ax_f = fig.add_subplot(gs[2, 2])
style(ax_f, "F — Avg Subunits per\nFunctional Category", "Category", "Mean N subunits")

# Mean subunit count per named category, ascending (largest bar on top);
# the "Other" bucket is excluded.
func_means = df.groupby("func_cat").n_subunits.mean().sort_values(ascending=True)
func_means = func_means[func_means.index != "Other"]
cat_index = {c: i for i, c in enumerate(FUNC_CATS)}
colors_f = [FUNC_COLORS[cat_index[c] % len(FUNC_COLORS)] if c in cat_index else "#8b949e"
            for c in func_means.index]
overall_mean = df.n_subunits.mean()

bars = ax_f.barh(range(len(func_means)), func_means.values, color=colors_f, alpha=0.85)
ax_f.set_yticks(range(len(func_means)))
ax_f.set_yticklabels([c.replace("\n", " ") for c in func_means.index], fontsize=7.5)
ax_f.axvline(overall_mean, color="#58a6ff", linewidth=0.8,
             linestyle="--", alpha=0.7, label=f"Overall μ={overall_mean:.2f}")
ax_f.legend(fontsize=7.5, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
for bar, val in zip(bars, func_means.values):
    ax_f.text(val + 0.02, bar.get_y() + bar.get_height() / 2,
              f"{val:.2f}", va="center", color="#e6edf3", fontsize=7.5)

# The right edge must clear the overall-mean line too: it can exceed every
# plotted mean when the excluded "Other" bucket has the largest mean. NaNs
# (no data) are ignored because set_xlim rejects NaN limits.
finite_vals = [v for v in (*func_means.values, overall_mean) if np.isfinite(v)]
ax_f.set_xlim(0, max(finite_vals, default=0) + 0.5)

# Bars are horizontal, so swap style()'s horizontal grid for a vertical one.
ax_f.grid(axis="x", color="#30363d", linewidth=0.5)
ax_f.grid(axis="y", visible=False)
|
|
| |
|
|
out = "/tmp/prism_corpus_analysis.png"
# facecolor passed explicitly so the saved image keeps the dark background.
fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
print(f"Saved → {out}")
|
|