PRISM / src /simulation /assembly_sim.py
Siddhant Bhat
Initial commit: PRISM protein assembly order prediction GNN
1430181
"""
Assembly simulation engine.
Implements greedy and stochastic assembly order prediction.
Evaluates against ground-truth orders using Kendall's τ.
Scoring terms:
    S(i → complex) = w_sasa * ΔSASA_norm
                   + w_ppi  * PPI_score(i, complex)
                   + w_esm  * ESM_cosine(i, complex)
                   + w_go   * GO_coherence(i, complex)
Run standalone:
python assembly_sim.py # toy 4-subunit demo
"""
import random
import math
import itertools
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, Optional
from scipy.stats import kendalltau
# ──────────────────────────────────────────────────────────────────────────────
# Data structures
# ──────────────────────────────────────────────────────────────────────────────
@dataclass
class Complex:
    """
    Description of one protein complex plus its precomputed pairwise features.

    Subunits are addressed by position 0..N-1, where N = len(sequences).
    Each pairwise matrix is optional; the simulation functions in this module
    substitute an all-zero (N, N) array for any matrix left as None.
    """

    # Identity and per-subunit inputs.
    name: str
    sequences: list[str]  # one sequence per subunit
    chain_ids: list[str]  # one chain label per subunit, e.g. ["A", "B", "C", ...]
    pdb_path: Optional[str] = None

    # Precomputed pairwise (N, N) matrices, filled by build_scoring_matrices.
    sasa_matrix: Optional[np.ndarray] = None  # buried SASA, nm²
    ppi_matrix: Optional[np.ndarray] = None  # STRING score in [0, 1]
    esm_matrix: Optional[np.ndarray] = None  # ESM cosine similarity
    go_matrix: Optional[np.ndarray] = None  # GO coherence

    # Reference assembly order as 0-indexed subunit ids, when known.
    ground_truth_order: Optional[list[int]] = None
@dataclass
class AssemblyState:
    """Bookkeeping for a single simulation run; updated in place as the run proceeds."""

    assembled: list[int] = field(default_factory=list)  # subunits added so far, in order
    unassembled: list[int] = field(default_factory=list)  # subunits still waiting to join
    step_scores: list[float] = field(default_factory=list)  # score of each subunit when added

    @property
    def order(self) -> list[int]:
        """Return a snapshot of the assembly order; mutating it leaves the state untouched."""
        return list(self.assembled)
@dataclass
class SimulationResult:
    """Outcome of one assembly simulation run."""

    predicted_order: list[int]  # subunit indices in the order they were added
    ground_truth_order: Optional[list[int]]  # reference order, or None when unknown
    kendall_tau: Optional[float]  # Kendall's τ against the reference, if evaluated
    kendall_pval: Optional[float]  # p-value reported alongside τ, if evaluated
    step_scores: list[float]  # score of the chosen subunit at each step
    mode: str  # "greedy" | "stochastic"
# ──────────────────────────────────────────────────────────────────────────────
# Scoring
# ──────────────────────────────────────────────────────────────────────────────
def candidate_score(
    candidate: int,
    assembled: list[int],
    sasa_matrix: np.ndarray,
    ppi_matrix: np.ndarray,
    esm_matrix: np.ndarray,
    go_matrix: np.ndarray,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
) -> float:
    """
    Score how well ``candidate`` fits onto the current partial complex.

    For each pairwise matrix the candidate's strongest entry against any
    already-assembled subunit is taken. The SASA term is divided by the
    largest value in the whole SASA matrix so it lands in [0, 1] like the
    other terms, and the four terms are combined as a weighted sum.

    Returns 0.0 when nothing has been assembled yet.
    """
    if not assembled:
        return 0.0

    def strongest_link(matrix: np.ndarray):
        # Best pairwise value between the candidate and any assembled subunit.
        return matrix[candidate, assembled].max()

    # The small epsilon keeps an all-zero SASA matrix from dividing by zero.
    sasa_term = strongest_link(sasa_matrix) / (sasa_matrix.max() + 1e-8)
    ppi_term = strongest_link(ppi_matrix)
    esm_term = strongest_link(esm_matrix)
    go_term = strongest_link(go_matrix)
    return w_sasa * sasa_term + w_ppi * ppi_term + w_esm * esm_term + w_go * go_term
# ──────────────────────────────────────────────────────────────────────────────
# Simulation modes
# ──────────────────────────────────────────────────────────────────────────────
def greedy_assembly(
    cplx: Complex,
    seed_subunit: int = 0,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
) -> SimulationResult:
    """
    Deterministic greedy assembly: always add the highest-scoring candidate.

    Starting from ``seed_subunit``, every remaining subunit is scored against
    the current partial complex with ``candidate_score`` and the best one is
    appended, until none is left. Ties go to the lowest remaining subunit
    index, because ``np.argmax`` returns the first maximum.

    Args:
        cplx: Complex to assemble; any missing pairwise matrix counts as zeros.
        seed_subunit: Index (0..N-1) of the subunit the assembly starts from.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four scoring terms.

    Returns:
        A SimulationResult with mode "greedy"; ``step_scores`` holds one score
        per added subunit (the seed itself has no score).

    Raises:
        IndexError: If ``seed_subunit`` is not a valid index in 0..N-1.
    """
    N = len(cplx.sequences)
    # A negative seed would be accepted by NumPy indexing but never removed
    # from the unassembled list, so one subunit would be assembled twice.
    if not 0 <= seed_subunit < N:
        raise IndexError(
            f"seed_subunit {seed_subunit} is out of range for a complex with {N} subunits"
        )

    state = AssemblyState(
        assembled=[seed_subunit],
        unassembled=[i for i in range(N) if i != seed_subunit],
    )
    scoring = dict(
        sasa_matrix=_ensure(cplx.sasa_matrix, N),
        ppi_matrix=_ensure(cplx.ppi_matrix, N),
        esm_matrix=_ensure(cplx.esm_matrix, N),
        go_matrix=_ensure(cplx.go_matrix, N),
        w_sasa=w_sasa, w_ppi=w_ppi, w_esm=w_esm, w_go=w_go,
    )

    while state.unassembled:
        scores = [
            candidate_score(c, state.assembled, **scoring)
            for c in state.unassembled
        ]
        best_idx = int(np.argmax(scores))
        # Store plain floats, matching what the stochastic mode records.
        state.step_scores.append(float(scores[best_idx]))
        state.assembled.append(state.unassembled.pop(best_idx))

    return _make_result(state, cplx, "greedy")
def stochastic_assembly(
    cplx: Complex,
    seed_subunit: int = 0,
    temperature: float = 1.0,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
    rng: Optional[random.Random] = None,
) -> SimulationResult:
    """
    Stochastic assembly: sample the next subunit from softmax(scores / T).

    temperature → 0 approaches the greedy choice; temperature → ∞ approaches
    uniform random sampling. Temperatures below 1e-8 are clamped to 1e-8.

    Args:
        cplx: Complex to assemble; any missing pairwise matrix counts as zeros.
        seed_subunit: Index (0..N-1) of the subunit the assembly starts from.
        temperature: Softmax temperature T.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four scoring terms.
        rng: Random source used for sampling; a fresh, unseeded
            ``random.Random`` is created when omitted.

    Returns:
        A SimulationResult with mode "stochastic"; ``step_scores`` holds the
        raw (untempered) score of each sampled subunit.

    Raises:
        IndexError: If ``seed_subunit`` is not a valid index in 0..N-1.
    """
    if rng is None:
        rng = random.Random()

    N = len(cplx.sequences)
    # A negative seed would be accepted by NumPy indexing but never removed
    # from the unassembled list, so one subunit would be assembled twice.
    if not 0 <= seed_subunit < N:
        raise IndexError(
            f"seed_subunit {seed_subunit} is out of range for a complex with {N} subunits"
        )

    state = AssemblyState(
        assembled=[seed_subunit],
        unassembled=[i for i in range(N) if i != seed_subunit],
    )
    scoring = dict(
        sasa_matrix=_ensure(cplx.sasa_matrix, N),
        ppi_matrix=_ensure(cplx.ppi_matrix, N),
        esm_matrix=_ensure(cplx.esm_matrix, N),
        go_matrix=_ensure(cplx.go_matrix, N),
        w_sasa=w_sasa, w_ppi=w_ppi, w_esm=w_esm, w_go=w_go,
    )
    # Clamp so a zero or negative temperature cannot divide by zero.
    effective_t = max(temperature, 1e-8)

    while state.unassembled:
        scores = np.array([
            candidate_score(c, state.assembled, **scoring)
            for c in state.unassembled
        ])
        # Softmax sampling; subtracting the max keeps exp() from overflowing.
        logits = scores / effective_t
        logits -= logits.max()
        probs = np.exp(logits)
        probs /= probs.sum()
        idx = rng.choices(range(len(state.unassembled)), weights=probs.tolist())[0]
        state.step_scores.append(float(scores[idx]))
        state.assembled.append(state.unassembled.pop(idx))

    return _make_result(state, cplx, "stochastic")
def ensemble_assembly(
    cplx: Complex,
    n_runs: int = 100,
    temperature: float = 0.5,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
    seed: int = 42,
) -> tuple[list[SimulationResult], dict]:
    """
    Run a stochastic ensemble and summarise the sampled assembly pathways.

    All runs share one ``random.Random(seed)`` stream, so the whole ensemble
    is reproducible for a given seed.

    Args:
        cplx: Complex to assemble.
        n_runs: Number of stochastic runs; zero or less yields empty results.
        temperature: Softmax temperature passed to every run.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four scoring terms.
        seed: Seed of the shared random stream.

    Returns:
        ``(results, stats)`` where ``results`` is the list of per-run
        SimulationResult objects and ``stats`` has the keys:

        - "mean_tau": mean Kendall's τ over runs that have one, else None.
        - "order_distribution": (order tuple, count) pairs, most frequent first.
        - "step_entropies": Shannon entropy in bits of the subunit chosen at
          each step 1..N-1 (the seed position, step 0, is skipped).
        - "max_entropy_step": 1-based step with the highest entropy, or None
          when there are no steps.
        - "n_unique_pathways": number of distinct orders observed.
    """
    rng = random.Random(seed)
    results = [
        stochastic_assembly(
            cplx, temperature=temperature,
            w_sasa=w_sasa, w_ppi=w_ppi, w_esm=w_esm, w_go=w_go,
            rng=rng,
        )
        for _ in range(n_runs)
    ]

    # Pathway frequency
    order_counts: dict[tuple, int] = {}
    for r in results:
        key = tuple(r.predicted_order)
        order_counts[key] = order_counts.get(key, 0) + 1

    # Mean Kendall τ
    taus = [r.kendall_tau for r in results if r.kendall_tau is not None]
    mean_tau = float(np.mean(taus)) if taus else None

    # Step entropy: diversity of the subunit picked at each assembly step
    N = len(cplx.sequences)
    step_entropies: list[float] = []
    for step in range(1, N):
        picks = np.asarray([r.predicted_order[step] for r in results], dtype=np.intp)
        counts = np.bincount(picks, minlength=N)
        total = counts.sum()
        if total == 0:
            # No runs: report zero entropy instead of dividing by zero (NaN).
            step_entropies.append(0.0)
            continue
        probs = counts[counts > 0] / total
        # "+ 0.0" turns the -0.0 of a fully deterministic step into 0.0.
        step_entropies.append(float(-np.sum(probs * np.log2(probs))) + 0.0)

    max_entropy_step = (int(np.argmax(step_entropies)) + 1) if step_entropies else None
    stats = {
        "mean_tau": mean_tau,
        "order_distribution": sorted(order_counts.items(), key=lambda x: -x[1]),
        "step_entropies": step_entropies,
        "max_entropy_step": max_entropy_step,
        "n_unique_pathways": len(order_counts),
    }
    return results, stats
# ──────────────────────────────────────────────────────────────────────────────
# Evaluation
# ──────────────────────────────────────────────────────────────────────────────
def evaluate_kendall_tau(
    predicted: list[int], ground_truth: list[int]
) -> tuple[float, float]:
    """
    Kendall's τ between predicted and ground-truth assembly orders.

    Each subunit's position in ``predicted`` is paired with its position in
    ``ground_truth`` and the two rank sequences are handed to
    ``scipy.stats.kendalltau``.

    Args:
        predicted: Subunit indices in predicted assembly order.
        ground_truth: The same subunits in reference assembly order.

    Returns:
        ``(tau, p_value)`` as plain floats. Both are NaN when SciPy cannot
        define the statistic (e.g. fewer than two subunits).

    Raises:
        ValueError: If the two orders differ in length or are not
            permutations of the same set of unique subunits.
    """
    n = len(predicted)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if n != len(ground_truth):
        raise ValueError(
            f"order lengths differ: predicted has {n}, "
            f"ground_truth has {len(ground_truth)}"
        )

    # Position of every subunit in the reference order.
    gt_rank = {subunit: pos for pos, subunit in enumerate(ground_truth)}
    if len(gt_rank) != n or set(predicted) != set(gt_rank):
        raise ValueError(
            "predicted and ground_truth must be permutations of the same unique subunits"
        )

    pred_ranks = list(range(n))
    gt_ranks = [gt_rank[subunit] for subunit in predicted]
    tau, pval = kendalltau(pred_ranks, gt_ranks)
    return float(tau), float(pval)
def ablation_table(cplx: Complex, n_stochastic: int = 50) -> list[dict]:
    """
    Evaluate every ablation condition on ``cplx`` and return the table rows.

    Each condition fixes the four scoring weights, runs one greedy assembly
    and one stochastic ensemble of ``n_stochastic`` runs, and contributes one
    row with the keys "condition", "greedy_tau", "mean_stochastic_tau" and
    "n_unique_pathways". The condition list is meant to mirror the
    ABLATION_CONDITIONS of the research plan.
    """
    weight_keys = ("w_sasa", "w_ppi", "w_esm", "w_go")
    # (label, weights in weight_keys order)
    conditions = (
        ("SASA only (Marsh baseline)", (1.0, 0.0, 0.0, 0.0)),
        ("PPI only", (0.0, 1.0, 0.0, 0.0)),
        ("ESM compat only", (0.0, 0.0, 1.0, 0.0)),
        ("GO coherence only", (0.0, 0.0, 0.0, 1.0)),
        ("PPI + ESM (no GO)", (0.0, 0.5, 0.5, 0.0)),
        ("All terms (equal weights)", (0.25, 0.25, 0.25, 0.25)),
    )

    table = []
    for label, values in conditions:
        weights = dict(zip(weight_keys, values))
        greedy_result = greedy_assembly(cplx, **weights)
        ensemble_stats = ensemble_assembly(cplx, n_runs=n_stochastic, **weights)[1]
        row = {
            "condition": label,
            "greedy_tau": greedy_result.kendall_tau,
            "mean_stochastic_tau": ensemble_stats["mean_tau"],
            "n_unique_pathways": ensemble_stats["n_unique_pathways"],
        }
        table.append(row)
    return table
# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _ensure(mat: Optional[np.ndarray], N: int) -> np.ndarray:
return mat if mat is not None else np.zeros((N, N))
def _make_result(state: AssemblyState, cplx: Complex, mode: str) -> SimulationResult:
    """
    Package a finished run as a SimulationResult.

    Kendall's τ and its p-value are computed only when the complex has a
    non-empty ground-truth order; otherwise both are None.
    """
    reference = cplx.ground_truth_order
    if reference:
        tau, pval = evaluate_kendall_tau(state.assembled, reference)
    else:
        tau, pval = None, None

    return SimulationResult(
        predicted_order=state.assembled,
        ground_truth_order=reference,
        kendall_tau=tau,
        kendall_pval=pval,
        step_scores=state.step_scores,
        mode=mode,
    )
# ──────────────────────────────────────────────────────────────────────────────
# Toy demo
# ──────────────────────────────────────────────────────────────────────────────
def _toy_demo():
    """Smoke test: run every mode on a 4-subunit complex with a synthetic SASA matrix."""
    rng = np.random.default_rng(0)
    N = 4
    # Synthetic symmetric buried-area matrix (larger buried area = earlier).
    # With SASA-only greedy scoring and subunit 0 as the seed, the predicted
    # order is 0, 1, 3, 2: subunit 1 buries the most against 0 (8.0), then
    # 3 against 1 (4.0) beats 2 against 0 (3.0). The ground truth below is
    # 0, 1, 2, 3, so the greedy τ is deliberately below 1.
    sasa = np.array([
        [0, 8.0, 3.0, 1.0],
        [8.0, 0, 2.0, 4.0],
        [3.0, 2.0, 0, 5.0],
        [1.0, 4.0, 5.0, 0],
    ])
    cplx = Complex(
        name="toy",
        sequences=["ACDEF"] * N,
        chain_ids=["A", "B", "C", "D"],
        sasa_matrix=sasa,
        ppi_matrix=sasa / 10,
        esm_matrix=rng.uniform(0, 1, (N, N)),
        go_matrix=rng.uniform(0, 1, (N, N)),
        ground_truth_order=[0, 1, 2, 3],
    )

    print("=== Greedy (SASA only) ===")
    r = greedy_assembly(cplx, w_sasa=1.0)
    print(f" Order: {r.predicted_order} τ={r.kendall_tau:.3f}")

    print("=== Ensemble (all terms) ===")
    _, stats = ensemble_assembly(cplx, n_runs=20, w_sasa=0.25, w_ppi=0.25, w_esm=0.25, w_go=0.25)
    print(f" Mean τ={stats['mean_tau']:.3f} Unique pathways: {stats['n_unique_pathways']}")

    print("=== Ablation table ===")
    rows = ablation_table(cplx, n_stochastic=20)
    header = f"{'Condition':<35} {'Greedy τ':>10} {'Mean τ':>10} {'Paths':>7}"
    print(header)
    print("-" * len(header))
    for row in rows:
        # τ is None when a run had no ground truth to compare against.
        tau_g = f"{row['greedy_tau']:.3f}" if row['greedy_tau'] is not None else " N/A"
        tau_s = f"{row['mean_stochastic_tau']:.3f}" if row['mean_stochastic_tau'] is not None else " N/A"
        print(f"{row['condition']:<35} {tau_g:>10} {tau_s:>10} {row['n_unique_pathways']:>7}")
# Script entry point: run the toy demo only when executed directly, not on import.
if __name__ == "__main__":
    _toy_demo()