PRISM / analysis_plot.py
Siddhant Bhat
Initial commit: PRISM protein assembly order prediction GNN
1430181
"""
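PRISM Dataset Analysis — N subunits × function × diversity

Reads data/interim/labeled_dataset.tsv and saves a six-panel (A–F) summary
figure of the training corpus to /tmp/prism_corpus_analysis.png.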
"""
import json, re
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.colors import LinearSegmentedColormap
# ── Load and enrich ──────────────────────────────────────────────────────────
# One row per complex. `uniprot_ids` is parsed as a JSON list (presumably the
# subunit accessions — confirm against the dataset builder); a missing value
# becomes an empty list.
df = pd.read_csv("data/interim/labeled_dataset.tsv", sep="\t")
uids = df["uniprot_ids"].map(lambda raw: [] if pd.isna(raw) else json.loads(raw))
df["n_unique"] = uids.map(lambda ids: len(set(ids)))
# Fraction of subunit slots filled by distinct proteins; n_subunits is floored
# at 1 so the division can never be by zero.
df["diversity"] = df["n_unique"] / df["n_subunits"].clip(lower=1)
# Classify repeat structure
def repeat_class(row):
if row.diversity >= 1.0: return "All-unique\n(pure heteromer)"
if row.diversity <= 1 / row.n_subunits + 1e-6 and row.n_unique == 1:
return "All-same\n(pure homomer)"
return "Mixed\n(repeated subunits)"
# Per-complex repetition class (one label per row); these labels are what the
# donut in panel E counts.
df["repeat_class"] = df.apply(repeat_class, axis=1)
# Functional keyword categories. Patterns are searched in the lower-cased
# complex name; dict order is significant because the first hit wins.
FUNC_CATS = {
    "Kinase /\nPhosphatase":     r"kinas|phosphatas|phosphoryl",
    "Transcription /\nRNA Pol":  r"transcri|rna pol|mediator",
    "Protease /\nPeptidase":     r"proteas|peptidas|caspas|trypsin",
    "Transport /\nChannel":      r"transport|channel|pump|import|export",
    # Short GTPase family names must end at a word boundary (not a literal
    # space) so that names ending the string or followed by punctuation —
    # "k-ras", "kras", "ras-gap" — are matched too.
    "Signaling /\nGTPase":       r"signal|gtpas|(?:ras|rho|rab|ran)\b|g.protein",
    "Chaperone /\nFolding":      r"chaperoni|groel|groes|hsp|heat shock",
    "Ubiquitin /\nLigase":       r"ubiquitin|ligase|cullin|skp|ring.*finger",
    "Ribosome /\nTranslation":   r"ribosom|translat|eif|elongation factor",
    "DNA Repair /\nReplication": r"dna rep|repair|pcna|helicase|primase",
}

# Compiled once, in FUNC_CATS order; classify() runs on every dataset row.
_FUNC_PATTERNS = [(cat, re.compile(pat)) for cat, pat in FUNC_CATS.items()]


def classify(name):
    """Return the FUNC_CATS key whose pattern matches *name*, else "Other".

    Non-string input (e.g. NaN for a missing name) is "Other". Matching is
    case-insensitive (the name is lower-cased) and the first category in
    FUNC_CATS order that matches wins.
    """
    if not isinstance(name, str):
        return "Other"
    lowered = name.lower()
    for cat, pattern in _FUNC_PATTERNS:
        if pattern.search(lowered):
            return cat
    return "Other"
df["func_cat"] = df["name"].apply(classify)
# Source labels: raw `source` codes → display names. Codes not listed here map
# to NaN and so appear in none of the per-source panels.
SRC_LABEL = {
    "corum":           "CORUM 5.2",
    "rcsb_search":     "RCSB PDB",
    "complex_portal":  "Complex Portal",
    "experimental_gt": "Experimental GT",
}
df["source_label"] = df.source.map(SRC_LABEL)
# Ground-truth tier number → display label (also the TIER_COLORS keys).
TIER_LABEL = {
    1: "Tier 1 ★\n(Experimental)",
    2: "Tier 2 ◆\n(Structure + BSA)",
    3: "Tier 3 ◇\n(Seq-length BSA)",
}
df["tier_label"] = df.gt_tier.map(TIER_LABEL)
# ── Color palettes ────────────────────────────────────────────────────────────
SRC_COLORS = {"CORUM 5.2": "#4e9af1", "RCSB PDB": "#f1a54e",
              "Complex Portal": "#6bd16b", "Experimental GT": "#e05c5c"}
# Keyed through TIER_LABEL so the label strings are written in one place only.
TIER_COLORS = {TIER_LABEL[tier]: color
               for tier, color in ((1, "#d29922"), (2, "#4e9af1"), (3, "#8b949e"))}
# One tab10 colour per FUNC_CATS key, plus one extra entry for "Other".
FUNC_COLORS = plt.cm.tab10(np.linspace(0, 1, len(FUNC_CATS) + 1))
# ── Figure ────────────────────────────────────────────────────────────────────
fig = plt.figure(figsize=(20, 18), facecolor="#0d1117")
fig.suptitle("PRISM Training Corpus — Complex Composition & Diversity Analysis",
             fontsize=16, fontweight="bold", color="#e6edf3", y=0.98)
# 3×3 grid: A spans row 0 cols 0–1, B is row 0 col 2, C spans all of row 1,
# and D / E / F share row 2.
gs = gridspec.GridSpec(3, 3, figure=fig, hspace=0.52, wspace=0.38,
                       left=0.07, right=0.97, top=0.94, bottom=0.06)
def style(ax, title="", xlabel="", ylabel=""):
    """Apply the shared dark theme to *ax* and set its title and axis labels.

    Sets the panel background, muted axis labels and ticks, dark spines, and a
    faint horizontal grid kept behind the plotted artists.
    """
    muted = "#8b949e"
    edge = "#30363d"
    ax.set_facecolor("#161b22")
    ax.set_title(title, color="#e6edf3", fontsize=11, fontweight="bold", pad=8)
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, color=muted, fontsize=9)
    ax.tick_params(colors=muted, labelsize=8)
    for side in ax.spines.values():
        side.set_edgecolor(edge)
    ax.grid(axis="y", color=edge, linewidth=0.5, alpha=0.6)
    ax.set_axisbelow(True)
# ── A: N_subunits distribution stacked by source ─────────────────────────────
ax_a = fig.add_subplot(gs[0, :2])
style(ax_a, "A — Subunit Count Distribution by Source",
      "Number of subunits per complex", "Complex count")
n_vals = range(2, 13)  # complexes with N outside 2..12 are not drawn
src_order = ["CORUM 5.2", "RCSB PDB", "Complex Portal", "Experimental GT"]
bottom = np.zeros(len(n_vals))
for src in src_order:
    n_of_src = df.loc[df.source_label == src, "n_subunits"]
    counts = np.array([(n_of_src == n).sum() for n in n_vals])
    ax_a.bar(list(n_vals), counts, bottom=bottom,
             color=SRC_COLORS[src], label=src, width=0.7, alpha=0.9)
    bottom += counts
ax_a.set_xticks(list(n_vals))
# Extra headroom so the total above the tallest stack stays inside the panel.
ax_a.margins(y=0.08)
ax_a.legend(loc="upper right", fontsize=8, framealpha=0.2,
            labelcolor="#e6edf3", facecolor="#161b22")
# Annotate total per bar. The gap is given in points, not data units, so it
# looks the same whether the tallest stack holds 10 complexes or 10,000.
for n, total in zip(n_vals, map(int, bottom)):
    if total > 0:
        ax_a.annotate(f"{total:,}", xy=(n, total), xytext=(0, 2),
                      textcoords="offset points", ha="center", va="bottom",
                      color="#8b949e", fontsize=7)
# ── B: Tier breakdown per source ─────────────────────────────────────────────
ax_b = fig.add_subplot(gs[0, 2])
style(ax_b, "B — Tier Distribution", "Source", "% of complexes")
tier_order = list(TIER_LABEL.values())
x = np.arange(len(src_order))
w = 0.22
# Offset of each tier's bar from the tick so the group is centred for any
# number of tiers (gives -w, 0, +w for three tiers).
centre = (len(tier_order) - 1) / 2
for ti, tlabel in enumerate(tier_order):
    vals = []
    for src in src_order:
        sub = df[df.source_label == src]
        # max(..., 1) keeps a source with no complexes at 0% instead of 0/0.
        vals.append(100 * (sub.tier_label == tlabel).sum() / max(len(sub), 1))
    ax_b.bar(x + (ti - centre) * w, vals, width=w,
             color=TIER_COLORS[tlabel], label=tlabel.replace("\n", " "), alpha=0.9)
ax_b.set_xticks(x)
ax_b.set_xticklabels([s.replace(" ", "\n") for s in src_order], fontsize=7)
ax_b.legend(fontsize=6.5, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
ax_b.set_ylim(0, 105)
ax_b.yaxis.set_major_formatter(plt.FuncFormatter(lambda v, _: f"{v:.0f}%"))
# ── C: N_subunits by functional category (violin + mean marker) ──────────────
ax_c = fig.add_subplot(gs[1, :])
style(ax_c, "C — Subunit Count by Functional Category",
      "Functional category (name-keyword classified)", "N subunits")
func_keys = list(FUNC_CATS.keys())
cat_sizes = df.func_cat.value_counts()
# Largest keyword categories first (ties keep FUNC_CATS order), "Other" last.
cat_order = sorted(func_keys, key=lambda c: -cat_sizes.get(c, 0))
cat_order.append("Other")
cat_data = [df.loc[df.func_cat == c, "n_subunits"].values for c in cat_order]
cat_counts = [len(d) for d in cat_data]
cat_labels = [f"{c}\n(n={cnt:,})" for c, cnt in zip(cat_order, cat_counts)]
# violinplot raises on an empty sample, so a category with no complexes keeps
# its tick and "(n=0)" label but gets no violin and no mean marker.
filled = [i for i, cnt in enumerate(cat_counts) if cnt > 0]
if filled:
    parts = ax_c.violinplot([cat_data[i] for i in filled], positions=filled,
                            showmedians=True, showextrema=False, widths=0.7)
    for pc, i in zip(parts["bodies"], filled):
        cat = cat_order[i]
        # Colour by position in FUNC_CATS ("Other" takes the extra last colour),
        # the same rule panel F uses, rather than by this panel's size ranking.
        color_idx = func_keys.index(cat) if cat in FUNC_CATS else len(func_keys)
        pc.set_facecolor(FUNC_COLORS[color_idx % len(FUNC_COLORS)])
        pc.set_alpha(0.75)
        pc.set_edgecolor("#30363d")
    parts["cmedians"].set_color("#e6edf3")
    parts["cmedians"].set_linewidth(1.5)
    # Overlay mean dots
    ax_c.scatter(filled, [np.mean(cat_data[i]) for i in filled],
                 color="white", s=25, zorder=5)
# set_xticks also widens the view so empty end categories remain visible.
ax_c.set_xticks(range(len(cat_order)))
ax_c.set_xticklabels(cat_labels, fontsize=7.5)
ax_c.set_yticks(range(2, 13))
ax_c.set_ylim(1.3, 13)
ax_c.axhline(df.n_subunits.mean(), color="#58a6ff", linewidth=0.8,
             linestyle="--", alpha=0.6, label=f"Overall mean ({df.n_subunits.mean():.2f})")
ax_c.legend(fontsize=8, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
# ── D: Diversity ratio heatmap: n_subunits × n_unique ────────────────────────
ax_d = fig.add_subplot(gs[2, 0])
style(ax_d, "D — Subunit Identity Matrix\n(N total vs N unique)", "N unique proteins", "N total subunits")
# The y-grid added by style() would be drawn across the image cells (grid
# lines sit above images when axisbelow is True), so drop it for the heatmap.
ax_d.grid(False)
max_n = 12
# heat[t - 1, u - 1] = number of complexes with t total and u unique subunits.
# Values are truncated to integers; missing values (NaN compares False) and
# counts outside 1..max_n are skipped.
n_total = np.trunc(df["n_subunits"].to_numpy(dtype=float))
n_uniq = np.trunc(df["n_unique"].to_numpy(dtype=float))
keep = (n_total >= 1) & (n_total <= max_n) & (n_uniq >= 1) & (n_uniq <= max_n)
heat = np.zeros((max_n, max_n))
np.add.at(heat, (n_total[keep].astype(int) - 1, n_uniq[keep].astype(int) - 1), 1)
# Log scale for visibility
heat_log = np.log1p(heat)
cmap = LinearSegmentedColormap.from_list("prism", ["#161b22", "#1f3f6e", "#4e9af1", "#e6edf3"])
im = ax_d.imshow(heat_log, cmap=cmap, aspect="auto", origin="lower",
                 extent=[0.5, max_n + 0.5, 0.5, max_n + 0.5])
# Theme the colorbar like the axes: by default its label and tick labels are
# drawn in black, which disappears against the dark figure background.
cbar = plt.colorbar(im, ax=ax_d, shrink=0.8)
cbar.set_label("log(count+1)", color="#8b949e", fontsize=9)
cbar.ax.tick_params(colors="#8b949e", labelsize=8)
cbar.outline.set_edgecolor("#30363d")
ax_d.plot([0.5, max_n + 0.5], [0.5, max_n + 0.5], "w--", linewidth=0.8, alpha=0.5,
          label="N unique = N total\n(pure heteromer)")
ax_d.legend(fontsize=7, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
ax_d.set_xticks(range(1, max_n + 1))
ax_d.set_yticks(range(1, max_n + 1))
# ── E: Repeat class donut ────────────────────────────────────────────────────
ax_e = fig.add_subplot(gs[2, 1])
ax_e.set_facecolor("#161b22")
ax_e.set_title("E — Subunit Repetition Class", color="#e6edf3",
               fontsize=11, fontweight="bold", pad=8)
# Fixed colour per repetition-class label, so a class keeps its colour
# whichever class happens to be most frequent. Any unexpected label gets grey
# instead of indexing past the end of a fixed-length colour list.
RC_COLORS = {
    "All-unique\n(pure heteromer)": "#4e9af1",
    "All-same\n(pure homomer)": "#f0883e",
    "Mixed\n(repeated subunits)": "#3fb950",
}
rc_counts = df.repeat_class.value_counts()  # most frequent class first
rc_labels = rc_counts.index.tolist()
rc_vals = rc_counts.values
rc_colors = [RC_COLORS.get(label, "#8b949e") for label in rc_labels]
wedges, _texts, autotexts = ax_e.pie(
    rc_vals, labels=None, autopct="%1.1f%%", startangle=90,
    colors=rc_colors,
    wedgeprops=dict(width=0.55, edgecolor="#0d1117"),
    pctdistance=0.78,
)
for at in autotexts:
    at.set_fontsize(8)
    at.set_color("#e6edf3")
ax_e.legend(
    [Patch(facecolor=c) for c in rc_colors],
    [f"{label}: {v:,}" for label, v in zip(rc_labels, rc_vals)],
    loc="lower center", fontsize=7.5, framealpha=0.2,
    labelcolor="#e6edf3", facecolor="#161b22", ncol=1,
    bbox_to_anchor=(0.5, -0.18),
)
# ── F: Mean N subunits per functional category bar ───────────────────────────
ax_f = fig.add_subplot(gs[2, 2])
style(ax_f, "F — Avg Subunits per\nFunctional Category", "Category", "Mean N subunits")
overall_mean = df.n_subunits.mean()
# Per-category means in ascending order; "Other" is excluded for a cleaner view.
func_means = (df.groupby("func_cat").n_subunits.mean()
              .drop("Other", errors="ignore")
              .sort_values(ascending=True))
func_keys = list(FUNC_CATS.keys())
colors_f = [FUNC_COLORS[func_keys.index(c) % len(FUNC_COLORS)]
            if c in FUNC_CATS else "#8b949e" for c in func_means.index]
bars = ax_f.barh(range(len(func_means)), func_means.values, color=colors_f, alpha=0.85)
ax_f.set_yticks(range(len(func_means)))
ax_f.set_yticklabels([c.replace("\n", " ") for c in func_means.index], fontsize=7.5)
ax_f.axvline(overall_mean, color="#58a6ff", linewidth=0.8,
             linestyle="--", alpha=0.7, label=f"Overall μ={overall_mean:.2f}")
ax_f.legend(fontsize=7.5, framealpha=0.2, labelcolor="#e6edf3", facecolor="#161b22")
for bar, val in zip(bars, func_means.values):
    ax_f.text(val + 0.02, bar.get_y() + bar.get_height() / 2,
              f"{val:.2f}", va="center", color="#e6edf3", fontsize=7.5)
# The overall mean also counts "Other" complexes and may exceed every bar, so
# size the axis to keep that reference line in view. With no finite value
# (empty data) the default limits are kept rather than passing NaN.
finite_maxes = [v for v in (func_means.max(), overall_mean) if pd.notna(v)]
if finite_maxes:
    ax_f.set_xlim(0, max(finite_maxes) + 0.5)
# Horizontal bars: grid on x only. Spine/tick colours already come from style().
ax_f.grid(axis="x", color="#30363d", linewidth=0.5)
ax_f.grid(axis="y", visible=False)
# ── Save ──────────────────────────────────────────────────────────────────────
# The Agg backend is non-interactive, so the PNG written here is the only output.
out = "/tmp/prism_corpus_analysis.png"
fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
print(f"Saved → {out}")