"""
Assembly simulation engine.

Implements greedy and stochastic assembly order prediction.
Evaluates against ground-truth orders using Kendall's τ.

Scoring terms:
    S(i → complex) = w_sasa * ΔSASA_norm
                   + w_ppi  * PPI_score(i, complex)
                   + w_esm  * ESM_cosine(i, complex)
                   + w_go   * GO_coherence(i, complex)

Run standalone:
    python assembly_sim.py    # toy 4-subunit demo
"""
|
|
import itertools
import math
import random
from collections import Counter
from dataclasses import dataclass, field
from typing import Callable, Optional

import numpy as np
from scipy.stats import kendalltau
|
|
|
|
| |
| |
| |
|
|
@dataclass
class Complex:
    """Input record describing one protein complex.

    Subunits are addressed by their integer position (0..N-1) in
    ``sequences``; the simulation functions in this module take N from
    ``len(sequences)``.

    Attributes:
        name: Label for the complex.
        sequences: One sequence string per subunit.
        chain_ids: One chain identifier per subunit.
        pdb_path: Optional path to a structure file.
        sasa_matrix, ppi_matrix, esm_matrix, go_matrix: Optional pairwise
            feature matrices, indexed ``[subunit, subunit]`` by the scorer.
            A matrix left as ``None`` is treated as all zeros by the
            simulation functions.
        ground_truth_order: Optional reference assembly order (subunit
            indices) used for evaluation.
    """

    # Identity and raw inputs.
    name: str
    sequences: list[str]
    chain_ids: list[str]
    pdb_path: Optional[str] = None

    # Pairwise score matrices, one per scoring term.
    sasa_matrix: Optional[np.ndarray] = None
    ppi_matrix: Optional[np.ndarray] = None
    esm_matrix: Optional[np.ndarray] = None
    go_matrix: Optional[np.ndarray] = None

    # Reference order for Kendall's tau evaluation, when known.
    ground_truth_order: Optional[list[int]] = None
|
|
|
|
@dataclass
class AssemblyState:
    """Bookkeeping that a single simulation run mutates as it proceeds.

    Attributes:
        assembled: Subunit indices already placed, in placement order.
        unassembled: Subunit indices still waiting to be placed.
        step_scores: Score recorded for each placement step.
    """

    assembled: list[int] = field(default_factory=list)
    unassembled: list[int] = field(default_factory=list)
    step_scores: list[float] = field(default_factory=list)

    @property
    def order(self) -> list[int]:
        """Return a snapshot of the placement order.

        The returned list is a fresh shallow copy, so mutating it does not
        affect ``assembled``.
        """
        return list(self.assembled)
|
|
|
|
@dataclass
class SimulationResult:
    """Outcome of one assembly simulation run.

    Attributes:
        predicted_order: Subunit indices in the order they were placed.
        ground_truth_order: Reference order taken from the complex, if any.
        kendall_tau: Kendall's tau between predicted and reference order,
            or ``None`` when no reference order was available.
        kendall_pval: p-value reported alongside ``kendall_tau``, or
            ``None`` when no reference order was available.
        step_scores: Score recorded at each placement step.
        mode: Name of the strategy that produced the run
            (this module uses ``"greedy"`` and ``"stochastic"``).
    """

    # Prediction and its reference.
    predicted_order: list[int]
    ground_truth_order: Optional[list[int]]

    # Agreement with the reference (both None without ground truth).
    kendall_tau: Optional[float]
    kendall_pval: Optional[float]

    # Per-step diagnostics and provenance.
    step_scores: list[float]
    mode: str
|
|
|
|
| |
| |
| |
|
|
def candidate_score(
    candidate: int,
    assembled: list[int],
    sasa_matrix: np.ndarray,
    ppi_matrix: np.ndarray,
    esm_matrix: np.ndarray,
    go_matrix: np.ndarray,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
) -> float:
    """Score how favourably ``candidate`` joins the current partial complex.

    For each feature matrix, the candidate's strongest pairwise value
    against any already-assembled subunit is taken. The SASA value is
    additionally divided by the global maximum of ``sasa_matrix`` (plus a
    small epsilon) before weighting; the other terms are used as-is.

    Args:
        candidate: Index of the subunit being considered.
        assembled: Indices of subunits already in the partial complex.
        sasa_matrix, ppi_matrix, esm_matrix, go_matrix: Pairwise feature
            matrices indexed ``[candidate, partner]``.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four terms.

    Returns:
        The weighted sum of the four terms, or ``0.0`` when nothing has
        been assembled yet.
    """
    # With no partners there is nothing to compare against.
    if not assembled:
        return 0.0

    def strongest(matrix):
        # Best pairwise value between the candidate and any placed subunit.
        return matrix[candidate, assembled].max()

    # Normalise SASA by the matrix-wide maximum; epsilon avoids 0/0 on an
    # all-zero matrix.
    sasa_scale = sasa_matrix.max() + 1e-8
    sasa_term = w_sasa * (strongest(sasa_matrix) / sasa_scale)
    ppi_term = w_ppi * strongest(ppi_matrix)
    esm_term = w_esm * strongest(esm_matrix)
    go_term = w_go * strongest(go_matrix)

    return sasa_term + ppi_term + esm_term + go_term
|
|
|
|
| |
| |
| |
|
|
def greedy_assembly(
    cplx: Complex,
    seed_subunit: int = 0,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
) -> SimulationResult:
    """Predict an assembly order by always adding the best-scoring subunit.

    Starting from ``seed_subunit``, each step scores every remaining
    subunit against the partial complex with ``candidate_score`` and
    appends the highest-scoring one. Ties go to the candidate that appears
    first in the remaining list (``np.argmax`` semantics).

    Args:
        cplx: Complex to simulate. Feature matrices left as ``None`` are
            replaced by zero matrices.
        seed_subunit: Index of the subunit the assembly starts from.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four scoring terms.

    Returns:
        A ``SimulationResult`` with ``mode="greedy"``.

    Raises:
        ValueError: If ``seed_subunit`` is not a valid subunit index
            (which includes any seed for a complex with no subunits).
    """
    N = len(cplx.sequences)
    # An unchecked seed silently produced invalid orders: an empty complex
    # yielded [0], and a negative seed placed the same subunit twice.
    if not 0 <= seed_subunit < N:
        raise ValueError(
            f"seed_subunit {seed_subunit} is out of range for a complex "
            f"with {N} subunit(s)"
        )

    state = AssemblyState(
        assembled=[seed_subunit],
        unassembled=[i for i in range(N) if i != seed_subunit],
    )

    scoring = dict(
        sasa_matrix=_ensure(cplx.sasa_matrix, N),
        ppi_matrix=_ensure(cplx.ppi_matrix, N),
        esm_matrix=_ensure(cplx.esm_matrix, N),
        go_matrix=_ensure(cplx.go_matrix, N),
        w_sasa=w_sasa, w_ppi=w_ppi, w_esm=w_esm, w_go=w_go,
    )

    while state.unassembled:
        scores = [
            candidate_score(c, state.assembled, **scoring)
            for c in state.unassembled
        ]
        best_idx = int(np.argmax(scores))
        # Store plain floats, as the stochastic mode does.
        state.step_scores.append(float(scores[best_idx]))
        # Indices are unique, so popping by position equals remove(value).
        state.assembled.append(state.unassembled.pop(best_idx))

    return _make_result(state, cplx, "greedy")
|
|
|
|
def stochastic_assembly(
    cplx: Complex,
    seed_subunit: int = 0,
    temperature: float = 1.0,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
    rng: Optional[random.Random] = None,
) -> SimulationResult:
    """Predict an assembly order by sampling each next subunit.

    At every step the remaining subunits are scored with
    ``candidate_score`` and one is drawn with probability proportional to
    ``softmax(scores / temperature)``. As temperature → 0 the choice
    approaches the greedy one; as temperature → ∞ it approaches uniform
    random.

    Args:
        cplx: Complex to simulate. Feature matrices left as ``None`` are
            replaced by zero matrices.
        seed_subunit: Index of the subunit the assembly starts from.
        temperature: Softmax temperature; values below ``1e-8`` are
            clamped to ``1e-8``.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four scoring terms.
        rng: Random generator to draw from. A fresh, unseeded
            ``random.Random`` is created when omitted.

    Returns:
        A ``SimulationResult`` with ``mode="stochastic"``.

    Raises:
        ValueError: If ``seed_subunit`` is not a valid subunit index
            (which includes any seed for a complex with no subunits).
    """
    N = len(cplx.sequences)
    # An unchecked seed silently produced invalid orders: an empty complex
    # yielded [0], and a negative seed placed the same subunit twice.
    if not 0 <= seed_subunit < N:
        raise ValueError(
            f"seed_subunit {seed_subunit} is out of range for a complex "
            f"with {N} subunit(s)"
        )

    if rng is None:
        rng = random.Random()

    state = AssemblyState(
        assembled=[seed_subunit],
        unassembled=[i for i in range(N) if i != seed_subunit],
    )

    scoring = dict(
        sasa_matrix=_ensure(cplx.sasa_matrix, N),
        ppi_matrix=_ensure(cplx.ppi_matrix, N),
        esm_matrix=_ensure(cplx.esm_matrix, N),
        go_matrix=_ensure(cplx.go_matrix, N),
        w_sasa=w_sasa, w_ppi=w_ppi, w_esm=w_esm, w_go=w_go,
    )

    # Loop-invariant: clamp once so a zero/negative temperature cannot
    # divide by zero or flip the sign of the scores.
    effective_t = max(temperature, 1e-8)

    while state.unassembled:
        scores = np.array([
            candidate_score(c, state.assembled, **scoring)
            for c in state.unassembled
        ])

        # Softmax with the maximum subtracted first for numerical stability.
        scaled = scores / effective_t
        scaled -= scaled.max()
        probs = np.exp(scaled)
        probs /= probs.sum()

        idx = rng.choices(range(len(state.unassembled)), weights=probs.tolist())[0]
        state.step_scores.append(float(scores[idx]))
        # Indices are unique, so popping by position equals remove(value).
        state.assembled.append(state.unassembled.pop(idx))

    return _make_result(state, cplx, "stochastic")
|
|
|
|
def ensemble_assembly(
    cplx: Complex,
    n_runs: int = 100,
    temperature: float = 0.5,
    w_sasa: float = 1.0,
    w_ppi: float = 0.0,
    w_esm: float = 0.0,
    w_go: float = 0.0,
    seed: int = 42,
    seed_subunit: int = 0,
) -> tuple[list[SimulationResult], dict]:
    """Run a stochastic ensemble and summarise the sampled pathways.

    All runs draw from one ``random.Random(seed)`` stream, so the ensemble
    is reproducible for a given ``seed``.

    Args:
        cplx: Complex to simulate.
        n_runs: Number of stochastic runs. Zero or a negative value gives
            an empty ensemble.
        temperature: Softmax temperature passed to every run.
        w_sasa, w_ppi, w_esm, w_go: Weights of the four scoring terms.
        seed: Seed for the shared random generator.
        seed_subunit: Subunit every run starts from.

    Returns:
        ``(results, stats)`` where ``results`` is the list of per-run
        ``SimulationResult`` objects and ``stats`` is a dict with:

        - ``"mean_tau"``: mean Kendall's tau over runs that have one,
          else ``None``.
        - ``"order_distribution"``: ``(order_tuple, count)`` pairs, most
          frequent first (ties keep first-seen order).
        - ``"step_entropies"``: Shannon entropy (bits) of the subunit
          chosen at each step 1..N-1; step 0 is always the seed.
        - ``"max_entropy_step"``: 1-based step with the highest entropy,
          or ``None`` when there are no entropies.
        - ``"n_unique_pathways"``: number of distinct orders sampled.
    """
    rng = random.Random(seed)
    results = [
        stochastic_assembly(
            cplx, seed_subunit=seed_subunit, temperature=temperature,
            w_sasa=w_sasa, w_ppi=w_ppi, w_esm=w_esm, w_go=w_go,
            rng=rng,
        )
        for _ in range(n_runs)
    ]

    # How often each complete pathway was sampled.
    order_counts = Counter(tuple(r.predicted_order) for r in results)

    # Runs without ground truth carry tau=None and are skipped.
    taus = [r.kendall_tau for r in results if r.kendall_tau is not None]
    mean_tau = float(np.mean(taus)) if taus else None

    # Per-step entropy of the chosen subunit. Skipped for an empty
    # ensemble, where normalising the frequencies would divide 0 by 0.
    N = len(cplx.sequences)
    step_entropies: list[float] = []
    if results:
        for step in range(1, N):
            picks = [r.predicted_order[step] for r in results]
            freqs = np.bincount(picks, minlength=N).astype(float)
            freqs /= freqs.sum()
            observed = freqs[freqs > 0]
            entropy = float(-np.sum(observed * np.log2(observed)))
            # "+ 0.0" turns the -0.0 of a deterministic step into 0.0.
            step_entropies.append(entropy + 0.0)

    stats = {
        "mean_tau": mean_tau,
        "order_distribution": order_counts.most_common(),
        "step_entropies": step_entropies,
        "max_entropy_step": int(np.argmax(step_entropies)) + 1 if step_entropies else None,
        "n_unique_pathways": len(order_counts),
    }
    return results, stats
|
|
|
|
| |
| |
| |
|
|
def evaluate_kendall_tau(
    predicted: list[int], ground_truth: list[int]
) -> tuple[float, float]:
    """Compute Kendall's τ between a predicted and a reference order.

    Each subunit in ``predicted`` is mapped to its position in
    ``ground_truth``; that sequence of reference ranks is then compared
    with the identity ranking ``0..n-1`` of the prediction.

    Args:
        predicted: Subunit indices in predicted assembly order.
        ground_truth: The same subunits in reference assembly order.

    Returns:
        ``(tau, p_value)`` as plain floats, as reported by
        ``scipy.stats.kendalltau``. For fewer than two subunits the
        statistic is not meaningful (SciPy is expected to report NaN —
        confirm for the installed version).

    Raises:
        ValueError: If the two orders differ in length.
        KeyError: If ``predicted`` contains a subunit that is absent from
            ``ground_truth``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let mismatched orders be scored silently.
    if len(predicted) != len(ground_truth):
        raise ValueError(
            f"predicted and ground_truth must have the same length "
            f"(got {len(predicted)} and {len(ground_truth)})"
        )

    gt_rank = {subunit: pos for pos, subunit in enumerate(ground_truth)}
    pred_ranks = list(range(len(predicted)))
    gt_ranks = [gt_rank[subunit] for subunit in predicted]
    tau, pval = kendalltau(pred_ranks, gt_ranks)
    return float(tau), float(pval)
|
|
|
|
def ablation_table(cplx: Complex, n_stochastic: int = 50) -> list[dict]:
    """Evaluate every ablation condition on ``cplx``.

    Each condition is one weighting of the four scoring terms. For every
    condition a greedy run and a stochastic ensemble of ``n_stochastic``
    runs are executed. The condition list is intended to mirror
    ABLATION_CONDITIONS from the research plan (not visible from this
    file).

    Args:
        cplx: Complex to evaluate.
        n_stochastic: Number of runs in each stochastic ensemble.

    Returns:
        One dict per condition, in table order, with keys ``"condition"``,
        ``"greedy_tau"``, ``"mean_stochastic_tau"`` and
        ``"n_unique_pathways"``.
    """
    weight_names = ("w_sasa", "w_ppi", "w_esm", "w_go")
    # (label, (w_sasa, w_ppi, w_esm, w_go))
    table = (
        ("SASA only (Marsh baseline)", (1.0, 0.0, 0.0, 0.0)),
        ("PPI only", (0.0, 1.0, 0.0, 0.0)),
        ("ESM compat only", (0.0, 0.0, 1.0, 0.0)),
        ("GO coherence only", (0.0, 0.0, 0.0, 1.0)),
        ("PPI + ESM (no GO)", (0.0, 0.5, 0.5, 0.0)),
        ("All terms (equal weights)", (0.25, 0.25, 0.25, 0.25)),
    )

    def evaluate(label, values):
        # Run both predictors under one weighting and collect the row.
        weights = dict(zip(weight_names, values))
        greedy_run = greedy_assembly(cplx, **weights)
        _, summary = ensemble_assembly(cplx, n_runs=n_stochastic, **weights)
        return {
            "condition": label,
            "greedy_tau": greedy_run.kendall_tau,
            "mean_stochastic_tau": summary["mean_tau"],
            "n_unique_pathways": summary["n_unique_pathways"],
        }

    return [evaluate(label, values) for label, values in table]
|
|
|
|
| |
| |
| |
|
|
| def _ensure(mat: Optional[np.ndarray], N: int) -> np.ndarray: |
| return mat if mat is not None else np.zeros((N, N)) |
|
|
|
|
def _make_result(state: AssemblyState, cplx: Complex, mode: str) -> SimulationResult:
    """Package a finished run as a ``SimulationResult``.

    Kendall's tau and its p-value are computed only when the complex has a
    non-empty ground-truth order; otherwise both are ``None``. The result
    references the state's ``assembled`` and ``step_scores`` lists
    directly (no copies are made).
    """
    truth = cplx.ground_truth_order

    tau = None
    pval = None
    if truth:
        tau, pval = evaluate_kendall_tau(state.assembled, truth)

    return SimulationResult(
        predicted_order=state.assembled,
        ground_truth_order=truth,
        kendall_tau=tau,
        kendall_pval=pval,
        step_scores=state.step_scores,
        mode=mode,
    )
|
|
|
|
| |
| |
| |
|
|
def _toy_demo():
    """Smoke-test the module on a synthetic 4-subunit complex.

    Builds a complex with a hand-written SASA matrix, a PPI matrix derived
    from it, and random ESM/GO matrices (fixed NumPy seed), then prints
    the greedy result, ensemble statistics and the ablation table.
    """
    rng = np.random.default_rng(0)
    N = 4

    sasa = np.array([
        [0, 8.0, 3.0, 1.0],
        [8.0, 0, 2.0, 4.0],
        [3.0, 2.0, 0, 5.0],
        [1.0, 4.0, 5.0, 0],
    ])
    cplx = Complex(
        name="toy",
        sequences=["ACDEF"] * N,
        chain_ids=["A", "B", "C", "D"],
        sasa_matrix=sasa,
        ppi_matrix=sasa / 10,
        esm_matrix=rng.uniform(0, 1, (N, N)),
        go_matrix=rng.uniform(0, 1, (N, N)),
        ground_truth_order=[0, 1, 2, 3],
    )

    # The printed labels previously showed a mis-encoded character where
    # the Greek letter tau was intended.
    print("=== Greedy (SASA only) ===")
    r = greedy_assembly(cplx, w_sasa=1.0)
    print(f" Order: {r.predicted_order} τ={r.kendall_tau:.3f}")

    print("=== Ensemble (all terms) ===")
    _, stats = ensemble_assembly(cplx, n_runs=20, w_sasa=0.25, w_ppi=0.25, w_esm=0.25, w_go=0.25)
    print(f" Mean τ={stats['mean_tau']:.3f} Unique pathways: {stats['n_unique_pathways']}")

    print("=== Ablation table ===")
    rows = ablation_table(cplx, n_stochastic=20)
    header = f"{'Condition':<35} {'Greedy τ':>10} {'Mean τ':>10} {'Paths':>7}"
    print(header)
    print("-" * len(header))
    for row in rows:
        # Tau is None when no ground truth is available; show a placeholder.
        tau_g = f"{row['greedy_tau']:.3f}" if row['greedy_tau'] is not None else " N/A"
        tau_s = f"{row['mean_stochastic_tau']:.3f}" if row['mean_stochastic_tau'] is not None else " N/A"
        print(f"{row['condition']:<35} {tau_g:>10} {tau_s:>10} {row['n_unique_pathways']:>7}")
|
|
|
|
# Run the toy demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    _toy_demo()
|
|