# protfunc / server.py
# Author: Sbhat2026
# Commit 3812163: fix(predict): decouple organism inference from GO gating; trust dropdown; kill oxygen-binding FP
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from pydantic import BaseModel
import torch
import torch.nn as nn
import pandas as pd
import joblib
import json
import math
import os
import re
import time
import hashlib
import warnings
import aiohttp
from functools import lru_cache
# Silence all Python warnings process-wide.
# NOTE(review): this also hides warnings that could help when debugging model loading.
warnings.filterwarnings("ignore")
# Cap PyTorch intra-op CPU threads at min(cpu_count, 8) (4 when the CPU count
# is unknown) for faster ESM inference.
torch.set_num_threads(min(os.cpu_count() or 4, 8))
# Directory containing this file; HF artifacts are downloaded here.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Static web assets directory; created at import time if missing.
STATIC_DIR = os.path.join(BASE_DIR, "static")
os.makedirs(STATIC_DIR, exist_ok=True)
# Hugging Face Hub model repository hosting checkpoints and threshold files.
HF_REPO = "Sbhat2026/protfunc-models"
# Files fetched from HF_REPO at startup when they are absent from BASE_DIR.
# Priority order: 35M > unified_v1 > mammal_enriched > v3_fixed > v3 > improved > supp_res2 > baseline
HF_FILES = [
    "unified_35M_v1.pth", "unified_35M_v1_thresholds.json",
    "unified_v1.pth", "unified_v1_recalibrated.json",
    "mammal_enriched.pth", "mammal_enriched_thresholds.json",
    "protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
    "protfunc_v3.pth", "protfunc_v3_thresholds.json",
    "improved_res.pth", "improved_per_label_thresholds.json",
    "supp_res2.pth",
    "baseline_res.pth", "mlb_public_v1.pkl", "go_names.json",
]
# Files whose download failure is tolerated (skipped) rather than aborting
# startup. Files in HF_FILES but not here (unified_v1.pth,
# unified_v1_recalibrated.json, baseline_res.pth, mlb_public_v1.pkl) are required.
OPTIONAL = {
    "go_names.json",
    "unified_35M_v1.pth", "unified_35M_v1_thresholds.json",
    "mammal_enriched.pth", "mammal_enriched_thresholds.json",
    "protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
    "protfunc_v3.pth", "protfunc_v3_thresholds.json",
    "improved_res.pth", "improved_per_label_thresholds.json",
    "supp_res2.pth",
}
# Globals populated during lifespan startup
device = torch.device("cpu") # always CPU on HF Space
model = None  # GO-function classifier; presumably assigned by the startup loader (not in this view)
esm_model = None  # protein language-model encoder (ESM per comments below); set at startup
batch_converter = None  # tokenizer paired with esm_model — presumably from the ESM alphabet; confirm
mlb = None  # label binarizer — presumably unpickled from mlb_public_v1.pkl; TODO confirm
go_map = {}  # GO ID -> display name (cf. load_go_map)
go_defs = {} # GO ID -> definition string (from OBO def: field)
mf_terms = set()  # active molecular_function GO IDs (cf. parse_obo)
go_parents = {} # GO ID -> set of direct parent GO IDs (MF DAG)
go_ancestors = {} # GO ID -> full set of ancestor GO IDs (transitive)
go_depth = {} # GO ID -> min depth from MF root (root = 0)
go_replaced = {} # obsolete GO ID -> replacement GO ID
mf_indices = None  # NOTE(review): presumably model-output indices that are MF terms — confirm with the loader
thresholds = {} # label_idx (str) -> float threshold (mammal-calibrated)
temperature = 1.0 # temperature scaling T for mammal inference (logit/T before sigmoid)
# Insect-specific inference params (flat threshold, no temperature scaling needed)
insect_temperature = 1.0
insect_threshold_default = 0.68
NUM_LABELS = 0  # number of output labels; presumably set once the model is loaded
_ESM_DIM = 320 # updated to 480 when unified_35M_v1 is loaded
# Supplemented model stats (loaded from checkpoint if present)
supp_mu = None # np.ndarray shape (SUPP_DIM,)
supp_sd = None # np.ndarray shape (SUPP_DIM,)
supp_cols = None # list[str]
model_uses_supp = False # True when model expects supp features
# Taxon probe (logistic regression on ESM embeddings, loaded from taxon_probe.json)
taxon_probe = None # dict with scaler_mean, scaler_std, coef, intercept
# Platt scaling (per-label logistic regression on logits, loaded from platt_mammal.json)
platt_params = {} # label_idx str -> [A, B]
# UniProt taxonomy + annotation cache
_uniprot_cache: dict = {}
# Biological complexity filter constants
# NOTE(review): the filter that consumes these is outside this view; presumably
# inputs are rejected when shorter than MIN_SEQ_LENGTH, below MIN_ENTROPY_BITS
# of composition entropy, dominated (> MAX_DOMINANT_FRAC) by one residue, or
# using fewer than MIN_DISTINCT_AA distinct residues — confirm.
MIN_SEQ_LENGTH = 30
MIN_ENTROPY_BITS = 2.5
MAX_DOMINANT_FRAC = 0.60
MIN_DISTINCT_AA = 5
# IUPAC ambiguity / non-standard residue letters (Asx, Xle, Pyl, Sec, any, Glx).
INVALID_AA = set("BJOUXZ")
# Root term of the GO molecular_function namespace (depth 0 in go_depth).
MF_ROOT = "GO:0003674"
# Kyte-Doolittle hydrophobicity scale
_KD = {'A':1.8,'R':-4.5,'N':-3.5,'D':-3.5,'C':2.5,'Q':-3.5,'E':-3.5,
       'G':-0.4,'H':-3.2,'I':4.5,'L':3.8,'K':-3.9,'M':1.9,'F':2.8,
       'P':-1.6,'S':-0.8,'T':-0.7,'W':-0.9,'Y':-1.3,'V':4.2}
# Chou-Fasman helix/sheet propensities
_CF_HELIX = {'A':1.42,'R':0.98,'N':0.67,'D':1.01,'C':0.70,'Q':1.11,'E':1.51,
             'G':0.57,'H':1.00,'I':1.08,'L':1.21,'K':1.16,'M':1.45,'F':1.13,
             'P':0.57,'S':0.77,'T':0.83,'W':1.08,'Y':0.69,'V':1.06}
_CF_SHEET = {'A':0.83,'R':0.93,'N':0.89,'D':0.54,'C':1.19,'Q':1.10,'E':0.37,
             'G':0.75,'H':0.87,'I':1.60,'L':1.30,'K':0.74,'M':1.05,'F':1.38,
             'P':0.55,'S':0.75,'T':1.19,'W':1.37,'Y':1.47,'V':1.70}
# Disorder-promoting residues (Uversky)
_DISORDER_PROMOTING = set("AERSQGKPTD")
# Residues counted as hydrophobic by the TM-window and signal-peptide proxies.
_TM_HYDROPHOBIC = set("AVILMFYW")
# Per-residue formal charge used for the sliding net-charge track.
_CHARGE = {"K": 1.0, "R": 1.0, "D": -1.0, "E": -1.0}
# NOTE(review): not referenced in this part of the file (lineage mapping uses
# TAXON_LINEAGE). INSECT_LINEAGE also includes non-insect arthropod clades.
INSECT_LINEAGE = {"Insecta", "Hexapoda", "Arthropoda", "Chelicerata", "Myriapoda"}
MAMMAL_LINEAGE = {"Mammalia", "Theria", "Eutheria", "Metatheria", "Monotremata"}
# Ordered most-specific → least so overlapping clades resolve correctly
# (e.g. Aves before reptile clades; both sit under Sauropsida).
# lineage_to_taxon_group returns the FIRST group whose clade set intersects a
# UniProt lineage, so order matters. Lineages matching none of these
# (e.g. non-arthropod invertebrates) map to None.
TAXON_LINEAGE = [
    ("insect", {"Insecta", "Hexapoda"}),
    ("mammal", {"Mammalia"}),
    ("bird", {"Aves"}),
    ("amphibian", {"Amphibia"}),
    ("reptile", {"Lepidosauria", "Testudines", "Crocodylia", "Squamata", "Rhynchocephalia"}),
    ("fish", {"Actinopterygii", "Chondrichthyes", "Cyclostomata", "Dipnoi", "Coelacanthimorpha"}),
    ("plant", {"Viridiplantae", "Embryophyta", "Streptophyta"}),
    ("fungus", {"Fungi"}),
    ("bacteria", {"Bacteria", "Archaea"}),
    ("virus", {"Viruses"}),
]
def lineage_to_taxon_group(lineage) -> str | None:
    """Return the display taxon group for a UniProt lineage, or None.

    *lineage* is an iterable of clade names (``scientificName`` strings) or
    None. TAXON_LINEAGE is scanned in order and the first group sharing any
    clade name with the lineage wins, so more specific groups must come first.
    """
    clade_names = set(lineage) if lineage else set()
    return next(
        (group for group, clades in TAXON_LINEAGE if not clade_names.isdisjoint(clades)),
        None,
    )
# Vertebrate groups share the mammal calibration head; insects use the insect
# head. Microbe/plant/fungus/virus are OUT-OF-DISTRIBUTION for the unified
# model — we still emit GO predictions but FLAG them rather than suppress.
# NOTE(review): _VERTEBRATE_GROUPS is not referenced in this part of the file;
# resolve_taxon sends every non-insect group to the mammal head implicitly.
_VERTEBRATE_GROUPS = {"mammal", "bird", "reptile", "amphibian", "fish"}
# Groups reported with ood=True by resolve_taxon.
_OOD_GROUPS = {"plant", "fungus", "bacteria", "virus"}
# Organism-INDEPENDENT per-label threshold floors for GO terms with notorious
# cross-protein false positives — a dense hemoglobin/myoglobin embedding cluster
# in training bleeds "oxygen binding" onto unrelated proteins (e.g. p53). These
# raise the decision bar for those specific labels; they are a calibration pass,
# not a retraining, and apply regardless of inferred organism.
# (The code that applies these floors is outside this view.)
FP_THRESHOLD_OVERRIDES = {
    "GO:0019825": 0.97, # oxygen binding
    "GO:0005344": 0.95, # oxygen carrier activity
    "GO:0019826": 0.95, # oxygen sensor activity
}
def resolve_taxon(req_group, emb_np=None, sequence="", detected_uniprot=None):
    """
    Resolve organism, DECOUPLED from GO-term gating.

    GO terms are properties of molecular function, not of the host organism, so
    organism inference must never semantically suppress GO predictions. This
    helper only (a) picks the insect-vs-mammal *calibration* head and (b) labels
    provenance / flags. Precedence:
        explicit user selection > UniProt lineage > sequence inference.

    The user selection is whitespace-stripped and lower-cased. A blank value,
    "auto", "infer" and "unknown" all mean "no explicit selection". Without the
    strip, " insect " would be taken as a manual choice that silently fell
    through to the mammal calibration head.

    Args:
        req_group: organism chosen in the UI dropdown (may be None).
        emb_np: ESM embedding; used by the taxon probe only when one is loaded.
        sequence: raw sequence for the amino-acid-composition fallback.
        detected_uniprot: taxon group derived from a UniProt lineage, if any.

    Returns: {display, calibration, source, conf, ood, low_confidence}.
      - display        : organism group for the UI (may be bird/fish/bacteria/…)
      - calibration    : "insect" | "mammal" (the only trained heads)
      - source         : manual | uniprot | probe | composition
      - ood            : True if outside the insect+mammal training distribution
      - low_confidence : True if an *inferred* call we shouldn't hard-gate on
    """
    req_group = (req_group or "").strip().lower()
    if req_group not in ("", "auto", "infer", "unknown"):
        display, source, conf = req_group, "manual", 1.0  # trust the dropdown
    elif detected_uniprot:
        display, source, conf = detected_uniprot, "uniprot", 1.0
    elif taxon_probe is not None and emb_np is not None:
        display, conf = _detect_taxon_probe(emb_np)
        source = "probe"
    else:
        display, conf = _detect_taxon_composition(sequence)
        source = "composition"
    # Only two calibration heads exist: every non-insect group (vertebrates
    # and OOD groups alike) uses the mammal head.
    calibration = "insect" if display == "insect" else "mammal"
    ood = display in _OOD_GROUPS
    # Manual and UniProt calls are trusted; only inferred calls can be weak.
    low_confidence = source in ("probe", "composition") and conf < 0.65
    return {"display": display, "calibration": calibration, "source": source,
            "conf": conf, "ood": ood, "low_confidence": low_confidence}
# Anchor sequences for organism inference via ESM-2 cosine similarity
# One reference protein per taxon group; _anchor_infer compares a query's
# embedding against the embedding of each anchor.
# NOTE(review): several anchors look like highly conserved or engineered
# sequences. The 'mammal' entry resembles ubiquitin, which is near-identical
# across eukaryotes. The 'plant' entry contains a GSGSGS… linker typical of
# synthetic constructs. If so, similarity may track protein family more than
# taxon — confirm the anchor choices.
ANCHOR_SEQUENCES = {
    "mammal": "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
    "bird": "KVFGRCELAAAMKRHGLDNYRGYSLGNWVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCAKKIVSDGNGMNAWVAWRNRCKGTDVQAWIRGCRL",
    "reptile": "MVLSAADKTNVKAAWSKVGGHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKAHGKKVADALASAAGHLDDLPGALSALSDLHAHKLRVDPVNFKLLSHCLLVTLACHHPAEFTPAVHASLDKFLASVSTVLTSKYR",
    "fish": "MCDEDETTALVCDNGSGLVKAGFAGDDAPRAVFPSIVGRPRHQGVMVGMGQKDSYVGDEAQSKRGILTLKYPIEHGIVTNWDDMEKIWHHTFYNELR",
    "insect": "MSKGPAVGIDLGTTYSCVGVFQHGKVEIIANDQGNRTTPSYVAFTDTERLIGDAAKNQVAMNPTNTVFDAKRLIGRKFGDPVVQSDMKHWPFQVINDGDKPKVQVSYKGEKKMMKDISKNKRALRRLQEIADEYQGKEDQGAD",
    "plant": "MSPQTETKASVGFKAGVKDYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGTWTTVWTDGLTSLDRYKGRCYRLLNKLLNHSYGRARYVNPEVGDLGDALSTPQDAPINMAFKPFGGLGTPVMRLAHHSGRWFLNAGDWAEANRLAAAKLNLVPVAYKDLPFYISPPELDAVLDRFQKAGSGSGSGSGSGSVLKEVNREIQIAGNFHRYGKPQLTQFVDAMVAQNLGMKPESIAAYTEVHREAFEARAQAAPSS",
    "fungus": "MVKVGVNGFGRIGRLVTRAAFNSGKVDVVAINDPFIDLNYMVYMFQYDSTHGVFKGKVKENGKLVINGNPITIFQERDPSKIKWGDAGAEYVVESTGVFTTMEKAGAHLQGGAKRVIISAPSADAPMFVMGVNHEKYDNSLKIVSNASCTTNCLAPLAKVIHDHFGIVEGLMTTVHAITATQKTVDGPSGKLWRDGRGAAQNIIPASTGAAKAVGKVIPELNGKLTGMAFRVPTANVSVVDLTCRLEKPAKYDDIKKVVKQASEGPLKGILGYTEDQVVSCDFNSATHSSTFDAGAGIALNDHFVKLISWYDNEFGYSNRVVDLMAHMASKE",
    "bacteria": "MSDKIIHLTDDSFDTDVLKADGAILVDFWAEWCGPCKMIAPILDEIADEYQGKLTVAKLNIDQNPGTAPKYGIRGIPTLLLFKNGEVAATKVGALSKGQLKEFLDANLA",
    "virus": "PIVQNLQGQMVHQAISPRTLNAWVKVVEEKAFSPEVIPMFSALSEGATPQDLNTMLNTVGGHQAAMQMLKETINEEAAEWDRLHPVHAGPIAPGQMREPRGSDIAGTTSTLQEQIGWMTHNPPIPVGEIYKRWIILGLNKIVRMYSPTSILDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQNANPDCKTILKALGPAATLEEMMTACQGVGGPGHKARVL",
}
# Filled elsewhere (not in this view); empty means _anchor_infer reports "anchor_unavailable".
_anchor_embeddings: dict = {} # taxon_group -> L2-normalized ndarray (esm_dim,)
def _detect_taxon_composition(seq: str) -> tuple:
    """
    Guess insect vs. mammal origin from amino-acid composition alone.

    A hand-set linear discriminant over seven residue frequencies (a positive
    score leans mammal, a negative score leans insect) is passed through a
    logistic with slope 2.5 to give P(mammal).

    Returns (group, confidence):
      * fewer than 60 residues      -> ("mammal", 0.50)
      * P(mammal) >= 0.68           -> ("mammal", P(mammal) rounded to 3 dp)
      * P(mammal) <= 0.32           -> ("insect", 1 - P(mammal) rounded to 3 dp)
      * otherwise (ambiguous)       -> ("mammal", 0.50)
    A confidence of exactly 0.50 therefore means "no real call"; callers should
    treat it like 'auto'.
    """
    length = len(seq)
    if length < 60:
        return "mammal", 0.50
    upper = seq.upper()
    # (residue, proteome baseline frequency, weight); the weight's sign gives
    # the direction: + pushes toward mammal, - toward insect.
    discriminant = (
        ("K", 0.058, 18.0),   # Lys enriched in mammals
        ("G", 0.073, -12.0),  # Gly enriched in insects
        ("A", 0.072, -9.0),   # Ala enriched in insects
        ("S", 0.077, 7.0),    # Ser enriched in mammals
        ("T", 0.056, 6.0),    # Thr enriched in mammals
        ("P", 0.055, -5.0),   # Pro enriched in insects
        ("R", 0.053, 4.0),    # Arg enriched in mammals
    )
    score = 0.0
    for residue, baseline, weight in discriminant:
        score += (upper.count(residue) / length - baseline) * weight
    p_mammal = 1.0 / (1.0 + math.exp(-score * 2.5))
    if p_mammal >= 0.68:
        return "mammal", round(p_mammal, 3)
    if p_mammal <= 0.32:
        return "insect", round(1.0 - p_mammal, 3)
    return "mammal", 0.50  # uncertain -> default mammal
def _detect_taxon_probe(emb_np, probe=None) -> tuple:
    """
    Classify insect vs. mammal with the logistic taxon probe on an ESM embedding.

    Args:
        emb_np: embedding vector (array-like, flattened to 1-D); its length must
            match the probe's coefficient vector.
        probe: optional dict with keys ``coef``, ``intercept``, ``scaler_mean``
            and ``scaler_std``. Defaults to the module-level ``taxon_probe``.

    Returns:
        (None, 0.0) when no probe is available. Otherwise, with p = P(mammal):
        ("mammal", p) if p >= 0.65, ("insect", 1 - p) if p <= 0.35, and
        ("mammal", 0.50) in between (values rounded to 3 dp).
    """
    import numpy as np
    if probe is None:
        probe = taxon_probe
    if probe is None:
        return None, 0.0
    # Coerce everything to float arrays. The probe is loaded from JSON, so its
    # fields may be plain lists (``list + 1e-12`` would raise TypeError) or
    # sklearn-shaped (2-D coef, 1-element intercept).
    w = np.asarray(probe["coef"], dtype=np.float64).ravel()
    b = float(np.asarray(probe["intercept"], dtype=np.float64).ravel()[0])
    mu = np.asarray(probe["scaler_mean"], dtype=np.float64).ravel()
    sd = np.asarray(probe["scaler_std"], dtype=np.float64).ravel()
    x = (np.asarray(emb_np, dtype=np.float64).ravel() - mu) / (sd + 1e-12)
    logit = float(np.dot(x, w)) + b
    # Clamp so math.exp cannot overflow on extreme logits; the sigmoid is
    # already saturated to 0/1 far inside this range.
    logit = min(max(logit, -500.0), 500.0)
    p_mammal = 1.0 / (1.0 + math.exp(-logit))
    if p_mammal >= 0.65:
        return "mammal", round(p_mammal, 3)
    if p_mammal <= 0.35:
        return "insect", round(1.0 - p_mammal, 3)
    return "mammal", 0.50
def _anchor_infer(sequence: str) -> dict:
    """
    Guess the taxon group by cosine similarity between the sequence's ESM
    embedding and one precomputed anchor embedding per group.

    Each group is represented by a single reference protein, so raw cosine
    similarity is a weak and easily over-confident signal (e.g. a GAPDH homolog
    lands on the 'fungus' anchor whatever its true origin). The reported
    confidence is therefore the softmax mass (temperature 0.05) of the best
    anchor. The call is labelled method="approximate" when the top-1 vs top-2
    similarity margin is below 0.04 or that confidence is below 0.30, so the UI
    does not present it as authoritative.

    Only the first 500 residues are embedded.

    Returns a dict with ``taxon_group``, ``confidence`` and ``method``, plus
    ``margin`` on success or ``error`` on failure. ``taxon_group`` is "unknown"
    when no anchors are loaded, the embedding has zero norm, or embedding fails.
    """
    import numpy as np
    if not _anchor_embeddings:
        return {"taxon_group": "unknown", "confidence": 0.0, "method": "anchor_unavailable"}
    try:
        vec = _get_esm_embedding(sequence[:500]).detach().cpu().numpy().astype(np.float32)
        length = np.linalg.norm(vec)
        if length < 1e-12:
            return {"taxon_group": "unknown", "confidence": 0.0, "method": "zero_embedding"}
        unit = vec / length
        names = list(_anchor_embeddings.keys())
        similarities = np.array([float(np.dot(unit, _anchor_embeddings[name])) for name in names])
        ranked = similarities.argsort()[::-1]
        best = ranked[0]
        runner_up = ranked[1] if len(ranked) > 1 else ranked[0]
        margin = float(similarities[best] - similarities[runner_up])
        # Temperature softmax over all anchors: confidence reflects competition.
        weights = np.exp((similarities - similarities.max()) / 0.05)
        confidence = float((weights / weights.sum())[best])
        thin_call = margin < 0.04 or confidence < 0.30
        return {
            "taxon_group": names[best],
            "confidence": round(confidence, 3),
            "method": "approximate" if thin_call else "sequence_similarity",
            "margin": round(margin, 3),
        }
    except Exception as exc:
        return {"taxon_group": "unknown", "confidence": 0.0, "method": "error", "error": str(exc)[:100]}
def infer_organism(sequence: str, user_selection: str | None = None) -> dict:
"""
Infer organism taxon group from sequence or accept user selection.
Returns {"taxon_group": str, "confidence": float, "method": str}.
taxon_group is one of: mammal, bird, reptile, amphibian, fish, insect,
plant, fungus, bacteria, virus, unknown.
NOTE: this is the sequence-only fallback. Prefer identify_protein(), which
grounds the organism in the real UniProt entry when an accession resolves.
"""
if user_selection and user_selection.lower() not in ("auto", "infer", ""):
return {"taxon_group": user_selection.lower(), "confidence": 1.0, "method": "user_specified"}
return _anchor_infer(sequence)
# UniProtKB accession, per UniProt's published format: either 6 characters
# starting with O/P/Q, or 6 or 10 characters starting with A-N / R-Z. The old
# leading class [A-NR-Z0-9] rejected O/P/Q (e.g. P04637, Q9Y6K9) and accepted
# a leading digit.
_ACC_RE = re.compile(
    r'^(?:[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9](?:[A-Z][A-Z0-9]{2}[0-9]){1,2})$'
)
def _clean_seq(seq: str) -> str:
    """Upper-case *seq* and drop every non-alphabetic character ('' for None)."""
    return "".join(filter(str.isalpha, (seq or "").upper()))
async def _peptide_search(sess, seq: str, sp_only: bool, max_candidates: int = 6) -> list:
    """
    Submit *seq* to UniProt's asynchronous peptide-search service and return up
    to *max_candidates* candidate accessions.

    Best effort: any failure yields an empty list. Failures include a network
    error, a missing job URL, or the job not finishing within 12 polls.
    Sequences shorter than 8 residues are not searched.

    Args:
        sess: an open aiohttp client session.
        seq: cleaned, upper-case protein sequence.
        sp_only: restrict the search to Swiss-Prot (reviewed) entries.
        max_candidates: maximum number of accessions returned.
    """
    import asyncio
    if len(seq) < 8:
        return []
    # Long sequences are searched by their central 80-residue window, which is
    # specific enough to pin the protein but much faster than the full chain.
    if len(seq) > 80:
        mid = len(seq) // 2
        query = seq[mid - 40: mid + 40]
    else:
        query = seq
    form = {"peps": query, "spOnly": "on" if sp_only else "off", "lEQi": "off"}
    headers = {"User-Agent": "FABLE/1.0", "Accept": "*/*"}
    try:
        async with sess.post("https://peptidesearch.uniprot.org/asyncrest/", data=form,
                             headers=headers, allow_redirects=False,
                             timeout=aiohttp.ClientTimeout(total=15)) as submitted:
            job_url = submitted.headers.get("Location")
        if not job_url:
            return []
        job_url = job_url.replace("http://", "https://")
        # Poll 12 times, 1.5 s apart (~18 s plus request time).
        for _attempt in range(12):
            async with sess.get(job_url, headers=headers, allow_redirects=False,
                                timeout=aiohttp.ClientTimeout(total=15)) as polled:
                if polled.status == 200:
                    body = (await polled.text()).strip()
                    return [acc for acc in body.split(",") if acc][:max_candidates]
            await asyncio.sleep(1.5)  # 202/303 -> job still running
    except Exception:
        return []
    return []
async def _fetch_fasta_seq(sess, acc: str):
    """
    Download the UniProtKB FASTA record for *acc* and return its sequence as a
    single upper-case string. Returns None on a non-200 reply or any error.
    """
    try:
        url = f"https://rest.uniprot.org/uniprotkb/{acc}.fasta"
        async with sess.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
            if resp.status != 200:
                return None
            fasta = await resp.text()
            residue_lines = (row.strip() for row in fasta.splitlines() if not row.startswith(">"))
            return "".join(residue_lines).upper()
    except Exception:
        return None
async def _fetch_organism(sess, acc: str) -> dict:
    """
    Look up organism metadata for UniProtKB accession *acc*.

    Returns {"organism", "common", "lineage", "reviewed"} on success.
    ``lineage`` is a list of clade-name strings. ``reviewed`` is True when the
    entry type starts with "UniProtKB reviewed". Returns {} on a non-200 reply
    or any error.
    """
    try:
        url = ("https://rest.uniprot.org/uniprotkb/"
               f"{acc}.json?fields=organism_name,lineage,reviewed,protein_name")
        headers = {"Accept": "application/json", "User-Agent": "FABLE/1.0"}
        async with sess.get(url, headers=headers,
                            timeout=aiohttp.ClientTimeout(total=10)) as resp:
            if resp.status != 200:
                return {}
            entry = await resp.json()
            organism = entry.get("organism", {})
            # The `lineage` field yields plain strings; `organism_lineage`
            # yields {"scientificName": ...} dicts. Normalise both to strings.
            clades = []
            for node in organism.get("lineage", []):
                clades.append(node if isinstance(node, str) else node.get("scientificName", ""))
            return {
                "organism": organism.get("scientificName", ""),
                "common": organism.get("commonName", ""),
                "lineage": clades,
                "reviewed": entry.get("entryType", "").startswith("UniProtKB reviewed"),
            }
    except Exception:
        return {}
async def identify_protein(sequence: str, user_selection: str | None = None,
accession: str | None = None) -> dict:
"""
High-value organism + accession resolver.
Resolution order:
1. Explicit organism selection from the user → trusted.
2. A known accession (typed, extracted from a FASTA header, or found via
UniProt peptide search + exact-sequence verification) → REAL organism
and lineage from UniProt. This replaces the weak anchor-cosine guess.
3. Fallback: calibrated anchor-cosine similarity (clearly flagged).
Returns: {taxon_group, organism, confidence, method, identity, accession,
reviewed, common}.
"""
seq = _clean_seq(sequence)
if user_selection and user_selection.lower() not in ("auto", "infer", ""):
return {"taxon_group": user_selection.lower(), "confidence": 1.0,
"method": "user_specified", "accession": (accession or "").upper().strip()}
if not seq:
return {"taxon_group": "unknown", "confidence": 0.0, "method": "empty_sequence"}
acc = (accession or "").upper().strip()
try:
async with aiohttp.ClientSession() as sess:
# Candidate accessions: trust a supplied one, else peptide-search.
if acc and _ACC_RE.match(acc):
candidates = [acc]
else:
candidates = await _peptide_search(sess, seq, sp_only=True)
if not candidates:
candidates = await _peptide_search(sess, seq, sp_only=False)
exact, similar = None, None
for a in candidates:
cseq = await _fetch_fasta_seq(sess, a)
if not cseq:
continue
if cseq == seq:
exact = a
break
if len(seq) >= 25 and (seq in cseq or cseq in seq):
similar = similar or a
chosen = exact or similar
if chosen:
org = await _fetch_organism(sess, chosen)
tg = lineage_to_taxon_group(org.get("lineage", [])) or "unknown"
return {
"taxon_group": tg,
"organism": org.get("organism", ""),
"common": org.get("common", ""),
"accession": chosen,
"identity": "exact" if exact else "similar",
"reviewed": bool(org.get("reviewed")),
"confidence": 0.99 if exact else 0.65,
"method": "uniprot_exact" if exact else "uniprot_similar",
}
except Exception:
pass
# No UniProt grounding → calibrated sequence-similarity fallback.
return _anchor_infer(seq)
def write_saliency_to_bfactor(pdb_str: str, saliency: list) -> str:
    """
    Write per-residue saliency (x100) into the B-factor field of ATOM/HETATM records.

    The B-factor occupies PDB columns 61-66 (0-based slice [60:66]). The residue
    number is read from columns 23-26 ([22:26]) and is assumed to be 1-based and
    aligned with *saliency*: residue k maps to saliency[k-1].

    Records are left unchanged when the residue number is unparsable, when it
    falls outside *saliency*, or when the scaled score is non-finite. All other
    lines pass through untouched.

    The following previously produced malformed records:
      * lines shorter than 66 characters are space-padded, so the value always
        lands in the B-factor columns;
      * the score is clamped to [-99.99, 999.99], so it never overflows the
        6-character field and shifts the element column;
      * a trailing CR (CRLF input) is kept out of the column arithmetic.
    """
    out = []
    for line in pdb_str.split("\n"):
        if line.startswith(("ATOM", "HETATM")):
            body = line.rstrip("\r")
            eol = line[len(body):]
            try:
                res_idx = int(body[22:26].strip()) - 1
                if 0 <= res_idx < len(saliency):
                    score = float(saliency[res_idx]) * 100
                    if math.isfinite(score):
                        score = min(max(score, -99.99), 999.99)
                        body = body.ljust(66)
                        line = body[:60] + f"{score:6.2f}" + body[66:] + eol
            except (ValueError, IndexError, TypeError):
                pass  # unparsable record or saliency value: leave line as-is
        out.append(line)
    return "\n".join(out)
async def get_structure_pdb(sequence: str, uniprot_id: str | None = None) -> tuple:
    """
    Fetch a PDB-format structure for a protein.

    Strategies are tried in order:
      1. the AlphaFold DB model for *uniprot_id* (only when one is given);
      2. an ESMFold prediction for *sequence* from the ESM Atlas API.
    Network/parse errors inside a strategy are swallowed so the next one is tried.

    Returns (pdb_text, "alphafold" | "esmfold") on success, otherwise
    (None, "unavailable").
    """
    import aiohttp

    async def from_alphafold(accession):
        # Resolve the model's pdbUrl from the prediction API, then download it.
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"https://alphafold.ebi.ac.uk/api/prediction/{accession}",
                headers={"Accept": "application/json", "User-Agent": "FABLE/1.0"},
                timeout=aiohttp.ClientTimeout(total=10),
            ) as meta_resp:
                if meta_resp.status != 200:
                    return None
                predictions = await meta_resp.json()
            pdb_url = predictions[0].get("pdbUrl", "")
            if not pdb_url:
                return None
            async with session.get(pdb_url, timeout=aiohttp.ClientTimeout(total=15)) as pdb_resp:
                if pdb_resp.status != 200:
                    return None
                return await pdb_resp.text()

    async def from_esmfold():
        # Fold the raw sequence; accept the reply only if it looks like PDB text.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.esmatlas.com/foldSequence/v1/pdb/",
                data=sequence,
                headers={"Content-Type": "text/plain"},
                timeout=aiohttp.ClientTimeout(total=60),
            ) as fold_resp:
                if fold_resp.status != 200:
                    return None
                text = await fold_resp.text()
                looks_like_pdb = text.strip().startswith("ATOM") or "ATOM" in text[:500]
                return text if looks_like_pdb else None

    if uniprot_id:
        accession = uniprot_id.upper().strip()
        try:
            pdb_text = await from_alphafold(accession)
        except Exception:
            pdb_text = None
        if pdb_text is not None:
            return pdb_text, "alphafold"
    try:
        pdb_text = await from_esmfold()
    except Exception:
        pdb_text = None
    if pdb_text is not None:
        return pdb_text, "esmfold"
    return None, "unavailable"
# UI metadata (key, label, description, plot colour) for the 11 sequence-derived
# supplementary features. The keys match those returned by compute_seq_features
# and compute_per_residue_features. Descriptions refer to the per-residue tracks
# (e.g. the signal-peptide "linear decay" is the per-residue map, whereas the
# global feature counts hydrophobic residues in the first 30 aa).
FEATURE_META = [
    {"key": "f_seq_len", "label": "Sequence Length", "desc": "Global protein length (uniform per-residue)", "color": "#888888"},
    {"key": "f_mean_hydro", "label": "Hydrophobicity", "desc": "Kyte-Doolittle hydrophobicity (window=5)", "color": "#f4a261"},
    {"key": "f_net_charge", "label": "Net Charge", "desc": "Local charge balance K+R−D−E (window=9)", "color": "#457b9d"},
    {"key": "f_uversky_disorder", "label": "Disorder Score", "desc": "Uversky charge-hydrophobicity disorder criterion (window=11)", "color": "#9b5de5"},
    {"key": "f_idr_frac_proxy", "label": "IDR Residues", "desc": "Disorder-promoting residues: A,E,R,S,Q,G,K,P,T,D", "color": "#00b4d8"},
    {"key": "f_lowcomp_proxy", "label": "Low Complexity", "desc": "Repetitive amino acid runs (length ≥5)", "color": "#adb5bd"},
    {"key": "f_tm_frac_proxy", "label": "TM Helix Windows", "desc": "Transmembrane helix windows (≥17/20 hydrophobic residues)", "color": "#e63946"},
    {"key": "f_tm_any_proxy", "label": "TM Present", "desc": "Presence of any transmembrane window", "color": "#c1121f"},
    {"key": "f_signal_peptide_proxy", "label": "Signal Peptide", "desc": "N-terminal hydrophobic signal (linear decay, first 30 aa)", "color": "#2d6a4f"},
    {"key": "f_cf_helix_mean", "label": "α-Helix Propensity", "desc": "Chou-Fasman α-helix propensity per residue", "color": "#4361ee"},
    {"key": "f_cf_sheet_mean", "label": "β-Sheet Propensity", "desc": "Chou-Fasman β-sheet propensity per residue", "color": "#e76f51"},
]
def compute_seq_features(seq: str) -> dict:
    """
    Build the supplementary feature dict for one protein sequence.

    Only the 11 sequence-derived features are actually computed. The
    AlphaFold-derived entries (f_afdb_has_model, f_plddt_*, f_distbin_*,
    f_pae_*, f_af_present) are always 0.0 and f_seqfeat_present is always 1.0.
    The design intent is that these zeros z-score to near-zero, because >99% of
    training proteins also lacked AF data.

    *seq* must be non-empty; an empty string raises ZeroDivisionError. Residues
    are upper-cased. Letters absent from a lookup table contribute 0.0 to
    hydrophobicity and 1.0 to the Chou-Fasman propensities.
    """
    residues = seq.upper()
    length = len(residues)

    hydro_mean = sum(_KD.get(r, 0.0) for r in residues) / length
    charge_mean = (residues.count('R') + residues.count('K')
                   - residues.count('D') - residues.count('E')) / length
    # Project's simplified Uversky-style flag: |H| - |charge| < 0.06.
    disorder_flag = float(abs(hydro_mean) - abs(charge_mean) < 0.06)
    idr_share = sum(1 for r in residues if r in _DISORDER_PROMOTING) / length

    # Count residues at position >= 5 inside a homopolymer run.
    in_long_runs = 0
    previous, streak = '', 0
    for r in residues:
        streak = streak + 1 if r == previous else 1
        previous = r
        if streak >= 5:
            in_long_runs += 1

    # Transmembrane proxy: 20-residue windows holding >= 17 hydrophobic residues.
    tm_hits = sum(
        1 for start in range(length - 19)
        if sum(1 for r in residues[start:start + 20] if r in _TM_HYDROPHOBIC) >= 17
    )
    # Signal-peptide proxy: >= 8 hydrophobic residues among the first 30.
    sp_flag = float(sum(1 for r in residues[:30] if r in _TM_HYDROPHOBIC) >= 8)
    helix = sum(_CF_HELIX.get(r, 1.0) for r in residues) / length
    sheet = sum(_CF_SHEET.get(r, 1.0) for r in residues) / length

    features = {
        "f_seq_len": float(length),
        "f_mean_hydro": float(hydro_mean),
        "f_net_charge": float(charge_mean),
        "f_uversky_disorder": float(disorder_flag),
        "f_idr_frac_proxy": float(idr_share),
        "f_lowcomp_proxy": float(in_long_runs / length),
        "f_tm_frac_proxy": float(tm_hits / max(length - 19, 1)),
        "f_tm_any_proxy": float(float(tm_hits > 0)),
        "f_signal_peptide_proxy": float(sp_flag),
        "f_cf_helix_mean": float(helix),
        "f_cf_sheet_mean": float(sheet),
    }
    # AF-derived features are absent at inference -> 0.0 (imputed to the mean).
    absent_af = [
        "f_afdb_has_model", "f_plddt_mean", "f_plddt_std",
        "f_plddt_q10", "f_plddt_q50", "f_plddt_q90",
        "f_plddt_frac_gt90", "f_plddt_frac_gt70", "f_plddt_frac_lt50",
        "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3",
        "f_distbin_4", "f_distbin_5", "f_distbin_6", "f_distbin_7",
        "f_distbin_8", "f_distbin_9",
        "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95",
        "f_pae_frac_lt5", "f_pae_frac_lt10", "f_pae_frac_gt20",
    ]
    features.update(dict.fromkeys(absent_af, 0.0))
    features["f_seqfeat_present"] = 1.0
    features["f_af_present"] = 0.0
    return features
def compute_per_residue_features(seq: str) -> dict:
    """
    Return per-residue profiles for the 11 sequence-derived supp features.

    Each value is a list with one entry per residue of ``seq.upper()``, rounded
    to 3 decimals. Continuous tracks are min-max normalised to [0, 1]; a
    constant track becomes all 0.5. The low-complexity and TM tracks are 0/1
    indicators.

    An empty sequence yields an empty list for every key. Previously it raised
    ValueError from ``min([])`` inside normalize().
    """
    seq_u = seq.upper()
    n = len(seq_u)
    if n == 0:
        return {key: [] for key in (
            "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder",
            "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy",
            "f_tm_any_proxy", "f_signal_peptide_proxy", "f_cf_helix_mean",
            "f_cf_sheet_mean",
        )}

    def smooth(arr, w):
        """Centred moving average of window *w*, truncated at the sequence ends."""
        hw = w // 2
        out = []
        for i in range(n):
            lo, hi = max(0, i - hw), min(n, i + hw + 1)
            out.append(sum(arr[lo:hi]) / (hi - lo))
        return out

    def normalize(arr):
        """Min-max scale to [0, 1]; a constant track maps to all 0.5."""
        mn, mx = min(arr), max(arr)
        if mx == mn:
            return [0.5] * n
        return [(v - mn) / (mx - mn) for v in arr]

    def rounded(arr):
        """Round every value to 3 decimals for a compact JSON payload."""
        return [round(v, 3) for v in arr]

    kd_raw = [_KD.get(aa, 0.0) for aa in seq_u]
    charge_raw = [_CHARGE.get(aa, 0.0) for aa in seq_u]
    # f_seq_len: uniform, no per-residue signal
    r_seq_len = [1.0] * n
    # f_mean_hydro: KD per residue smoothed window=5
    r_mean_hydro = normalize(smooth(kd_raw, 5))
    # f_net_charge: sliding charge window=9
    r_net_charge = normalize(smooth(charge_raw, 9))
    # f_uversky_disorder: window=11, high = more disordered tendency
    uv_raw = []
    for i in range(n):
        lo, hi = max(0, i - 5), min(n, i + 6)
        wkd = sum(kd_raw[lo:hi]) / (hi - lo)
        wch = sum(charge_raw[lo:hi]) / (hi - lo)
        uv_raw.append(max(0.0, 0.20 - (abs(wkd) - abs(wch))))
    r_uversky = normalize(uv_raw)
    # f_idr_frac_proxy: binary disorder residue indicator, smoothed
    idr_raw = [1.0 if aa in _DISORDER_PROMOTING else 0.0 for aa in seq_u]
    r_idr = normalize(smooth(idr_raw, 7))
    # f_lowcomp_proxy: mark every residue of a homopolymer run of length >= 5.
    # When a run reaches 5, its first five residues are marked at once; later
    # residues of the same run are marked as they are seen. This is linear
    # time, whereas the old version re-marked the whole run at each step.
    lowcomp_raw = [0.0] * n
    prev, run = '', 0
    for j, aa in enumerate(seq_u):
        run = run + 1 if aa == prev else 1
        prev = aa
        if run == 5:
            for k in range(j - 4, j + 1):
                lowcomp_raw[k] = 1.0
        elif run > 5:
            lowcomp_raw[j] = 1.0
    # f_tm_frac_proxy: residues covered by TM windows (>=17/20 hydrophobic)
    tm_raw = [0.0] * n
    if n >= 20:
        for i in range(n - 19):
            if sum(1 for aa in seq_u[i:i + 20] if aa in _TM_HYDROPHOBIC) >= 17:
                for k in range(i, i + 20):
                    tm_raw[k] = 1.0
    # f_signal_peptide_proxy: linear decay x hydrophobicity over the first 30 aa;
    # falls back to the bare decay when no residue there is hydrophobic.
    sp_mod = []
    for i in range(n):
        weight = max(0.0, 1.0 - i / 30.0)
        sp_mod.append(weight * max(0.0, kd_raw[i] / 4.5))
    r_sp = normalize(sp_mod) if any(v > 0 for v in sp_mod) else [max(0.0, 1.0 - i / 30.0) for i in range(n)]
    # f_cf_helix_mean / f_cf_sheet_mean: propensity per residue smoothed
    r_cf_helix = normalize(smooth([_CF_HELIX.get(aa, 1.0) for aa in seq_u], 3))
    r_cf_sheet = normalize(smooth([_CF_SHEET.get(aa, 1.0) for aa in seq_u], 3))
    return {
        "f_seq_len": rounded(r_seq_len),
        "f_mean_hydro": rounded(r_mean_hydro),
        "f_net_charge": rounded(r_net_charge),
        "f_uversky_disorder": rounded(r_uversky),
        "f_idr_frac_proxy": rounded(r_idr),
        "f_lowcomp_proxy": rounded(lowcomp_raw),
        "f_tm_frac_proxy": rounded(tm_raw),
        # f_tm_any_proxy: at residue level this is the same TM coverage map.
        "f_tm_any_proxy": rounded(tm_raw),
        "f_signal_peptide_proxy": rounded(r_sp),
        "f_cf_helix_mean": rounded(r_cf_helix),
        "f_cf_sheet_mean": rounded(r_cf_sheet),
    }
def build_ancestor_cache(go_parents_map: dict) -> dict:
    """
    Map every GO term to the set of all its transitive ancestors.

    *go_parents_map* maps a GO ID to its direct parent IDs. The result has an
    entry for every key of the input, and also for every parent reached while
    walking up; parents absent from the map get an empty set. Results are
    memoised, so shared ancestry is computed once. Recursion depth follows the
    DAG depth; the input is assumed acyclic (a cycle raises RecursionError).
    """
    memo = {}

    def ancestors_of(term):
        if term in memo:
            return memo[term]
        direct = go_parents_map.get(term, set())
        found = set(direct)
        for parent in direct:
            found.update(ancestors_of(parent))
        memo[term] = found
        return found

    for term in go_parents_map:
        ancestors_of(term)
    return memo
def _download_with_retry(fname):
    """
    Download *fname* from HF_REPO into BASE_DIR, retrying with exponential backoff.

    Up to 6 attempts are made, sleeping 2, 4, 8, 16 and then 32 s between them.
    For files listed in OPTIONAL:
      * a "does not exist / not accessible" reply from the Hub skips the file
        immediately. This covers EntryNotFoundError, RepositoryNotFoundError,
        RevisionNotFoundError and GatedRepoError, matched by class name so the
        check does not depend on the huggingface_hub version;
      * any other error (typically a transient network failure) is retried
        like a required file, and the file is skipped only after all attempts
        fail.
    Previously every error on an optional file skipped it on the first try, so
    a single network blip could silently drop the preferred 35M checkpoint.

    Raises:
        RuntimeError: a required file could not be downloaded.
    """
    from huggingface_hub import hf_hub_download
    max_attempts = 6
    optional = fname in OPTIONAL
    not_available = {"EntryNotFoundError", "RepositoryNotFoundError",
                     "RevisionNotFoundError", "GatedRepoError"}
    for attempt in range(1, max_attempts + 1):
        try:
            print(f" [{attempt}/{max_attempts}] Downloading {fname}...")
            path = hf_hub_download(
                repo_id=HF_REPO, filename=fname,
                local_dir=BASE_DIR, repo_type="model",
                token=os.environ.get("HF_TOKEN"),
            )
            print(f" saved -> {path}")
            return
        except Exception as e:
            if optional and type(e).__name__ in not_available:
                print(f" {fname} is optional, skipping ({e})")
                return
            if attempt == max_attempts:
                if optional:
                    print(f" {fname} is optional, skipping after {max_attempts} attempts ({e})")
                    return
                raise RuntimeError(
                    f"Could not download '{fname}' after {max_attempts} attempts: {e}"
                ) from e
            wait = 2 ** attempt
            print(f" Network error, retrying in {wait}s... ({e})")
            time.sleep(wait)
def ensure_model_files():
    """
    Ensure every file in HF_FILES exists in BASE_DIR, downloading the missing
    ones in HF_FILES order via _download_with_retry.
    """
    absent = []
    for name in HF_FILES:
        if not os.path.exists(os.path.join(BASE_DIR, name)):
            absent.append(name)
    if absent:
        print(f"Downloading {len(absent)} file(s) from HuggingFace Hub...")
        for name in absent:
            _download_with_retry(name)
    else:
        print("All model files already present.")
def load_go_map():
    """
    Load a GO ID -> display-name mapping from ``go_annotations_fixed.csv`` in BASE_DIR.

    IDs come from the "GO Annotation" column. Names are the text before the
    first " [" in "Gene Ontology (molecular function)". Rows with a missing GO
    ID are skipped. A missing name, or a missing name column, becomes "Unknown";
    previously an empty cell became the literal string "nan", and a missing ID
    produced a bogus "nan" key.

    Any read/parse error is printed and an empty dict is returned, so startup
    continues without names.
    """
    try:
        df = pd.read_csv(os.path.join(BASE_DIR, "go_annotations_fixed.csv"))
        name_col = "Gene Ontology (molecular function)"
        names = df[name_col] if name_col in df.columns else [None] * len(df)
        mapping = {}
        for raw_id, raw_name in zip(df["GO Annotation"], names):
            if pd.isna(raw_id):
                continue
            go_id = str(raw_id).strip()
            name = "Unknown" if pd.isna(raw_name) else str(raw_name)
            mapping[go_id] = name.split(" [")[0].strip()
        print(f"GO map: {len(mapping)} labels loaded")
        return mapping
    except Exception as e:
        print(f"GO map load error: {e}")
        return {}
def load_thresholds():
    """
    Load calibration thresholds from the first threshold file found in BASE_DIR.

    Candidate files are tried newest-model first (35M, 8M recalibrated, …,
    legacy artifacts/). Accepted JSON layouts:
        {"0": 0.68, ...}                                            plain float
        {"0": {"threshold": 0.68, "tier": 0, "temperature": 3.69}}  rich dict
        {"_meta": {...}, "0": 0.68, ...}                            with metadata
    "_meta" may set temperature, insect_temperature and insect_global_t. In the
    rich layout a per-label "temperature" overrides the mammal temperature, and
    the last such label wins. Plain entries that are not numeric are ignored.

    NOTE(review): the file is chosen independently of which checkpoint was
    loaded; confirm the two cannot disagree.

    Returns:
        (flat, mammal_T, ins_T, ins_t_default), where flat maps label index
        (str) -> mammal threshold. Returns ({}, 1.0, 1.0, 0.68) when no file
        exists.
    """
    candidates = [
        os.path.join(BASE_DIR, "unified_35M_v1_thresholds.json"),  # 35M model thresholds (latest)
        os.path.join(BASE_DIR, "unified_v1_recalibrated.json"),    # 8M recalibrated: T=3.85, precision-tuned
        os.path.join(BASE_DIR, "unified_v1_thresholds.json"),
        os.path.join(BASE_DIR, "mammal_enriched_thresholds.json"),
        os.path.join(BASE_DIR, "protfunc_v3_fixed_thresholds.json"),
        os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
        os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
        os.path.join(BASE_DIR, "per_label_thresholds.json"),
        os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"),
    ]
    for path in candidates:
        if not os.path.exists(path):
            continue
        print(f"Thresholds loaded from {path}")
        with open(path) as fh:
            payload = json.load(fh)
        mammal_T, ins_T, ins_t_def = 1.0, 1.0, 0.68
        meta = payload.pop("_meta", None)
        if meta:
            mammal_T = float(meta.get("temperature", mammal_T))
            ins_T = float(meta.get("insect_temperature", ins_T))
            ins_t_def = float(meta.get("insect_global_t", ins_t_def))
        flat = {}
        for label, entry in payload.items():
            if isinstance(entry, dict):
                flat[label] = float(entry.get("threshold", 0.5))
                mammal_T = float(entry.get("temperature", mammal_T))
                continue
            try:
                flat[label] = float(entry)
            except (TypeError, ValueError):
                pass  # non-numeric entry: not a threshold
        print(f" {len(flat)} mammal thresholds | mammal_T={mammal_T:.4f} | "
              f"insect_T={ins_T:.4f} insect_t={ins_t_def:.2f}")
        return flat, mammal_T, ins_T, ins_t_def
    print("Thresholds not found, using defaults")
    return {}, 1.0, 1.0, 0.68
def parse_obo(path):
    """
    Parse go-basic.obo and return the molecular-function (MF) slice of GO.

    Returns a 6-tuple ``(mf_terms, go_parents, go_names_ob, go_replaced, go_depth, go_defs)``:
      mf_terms    : set of active (non-obsolete) GO IDs whose namespace is molecular_function
      go_parents  : dict GO ID -> set of direct parent GO IDs, taken from ``is_a`` plus the
                    ``part_of`` / ``regulates`` / ``positively_regulates`` /
                    ``negatively_regulates`` relationships; parents outside MF are dropped
      go_names_ob : dict GO ID -> canonical name from the OBO file (authoritative)
      go_replaced : dict obsolete GO ID -> its ``replaced_by`` GO ID (any namespace)
      go_depth    : dict GO ID -> minimum number of parent edges from MF_ROOT (root = 0);
                    MF terms not reachable from MF_ROOT are absent
      go_defs     : dict MF GO ID -> definition text (surrounding quotes, backslash
                    escapes and the trailing ``[xref]`` list removed)

    Only ``[Term]`` stanzas are read. ``alt_id`` mappings are collected but only
    reported in the summary log line, not returned.
    """
    # A ``def:`` value is an OBO quoted string followed by an xref list. The quoted
    # body may contain backslash escapes such as \" — a plain find('"') would stop
    # at the escaped quote and truncate the definition.
    quoted_re = re.compile(r'"((?:[^"\\]|\\.)*)"')
    escape_re = re.compile(r'\\(.)')
    escape_map = {"n": "\n", "t": "\t", "W": " "}

    def unquote_def(value):
        """Return the human-readable text of a ``def:`` value."""
        if not value.startswith('"'):
            return value.split("[")[0].strip()
        m = quoted_re.match(value)
        if m is None:
            return value  # unterminated quote: keep the raw value
        return escape_re.sub(lambda e: escape_map.get(e.group(1), e.group(1)), m.group(1))

    ns_map = {}     # id -> namespace
    par_map = {}    # id -> {parent ids}
    name_map = {}   # id -> canonical name
    def_map = {}    # id -> definition string
    rep_map = {}    # obsolete id -> replaced_by id
    alt_map = {}    # alt_id -> canonical id
    obs_set = set()
    cur_id = None
    cur_ns = None
    cur_nm = None
    cur_df = None
    cur_par = set()
    cur_rep = None
    cur_obs = False
    cur_alt = []
    in_term = False

    def flush():
        """Commit the stanza accumulated in the cur_* variables, then reset them."""
        nonlocal cur_id, cur_ns, cur_nm, cur_df, cur_par, cur_rep, cur_obs, cur_alt
        if cur_id:
            if cur_obs:
                obs_set.add(cur_id)
                if cur_rep:
                    rep_map[cur_id] = cur_rep
            else:
                ns_map[cur_id] = cur_ns or ""
                par_map[cur_id] = cur_par
                name_map[cur_id] = cur_nm or cur_id
                if cur_df:
                    def_map[cur_id] = cur_df
            for a in cur_alt:
                alt_map[a] = cur_id
        cur_id = None; cur_ns = None; cur_nm = None; cur_df = None
        cur_par = set(); cur_rep = None; cur_obs = False; cur_alt = []

    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line == "[Term]":
                flush(); in_term = True; continue
            if line.startswith("[") and line != "[Term]":
                # [Typedef] / [Instance] stanzas end the current term and are ignored.
                flush(); in_term = False; continue
            if not in_term:
                continue
            if line.startswith("id:"):
                cur_id = line.split("id:", 1)[1].strip().split()[0]
            elif line.startswith("name:"):
                cur_nm = line.split("name:", 1)[1].strip()
            elif line.startswith("namespace:"):
                cur_ns = line.split("namespace:", 1)[1].strip()
            elif line.startswith("alt_id:"):
                cur_alt.append(line.split("alt_id:", 1)[1].strip().split()[0])
            elif line.startswith("def:"):
                # def: "description text" [source] — keep only the description
                cur_df = unquote_def(line.split("def:", 1)[1].strip())
            elif line.startswith("is_obsolete:") and "true" in line:
                cur_obs = True
            elif line.startswith("replaced_by:"):
                cur_rep = line.split("replaced_by:", 1)[1].strip().split()[0]
            elif line.startswith("is_a:"):
                # split()[0] drops the trailing "! name" comment
                parent = line.split("is_a:", 1)[1].strip().split()[0]
                cur_par.add(parent)
            elif line.startswith("relationship:"):
                parts = line.split("relationship:", 1)[1].strip().split()
                if len(parts) >= 2 and parts[0] in ("part_of", "regulates",
                                                    "positively_regulates",
                                                    "negatively_regulates"):
                    cur_par.add(parts[1])
    flush()

    mf = {gid for gid, n in ns_map.items() if n == "molecular_function"}
    go_parents_mf = {gid: (par_map[gid] & mf) for gid in mf}
    go_names_ob = {gid: name_map[gid] for gid in mf}
    go_defs_mf = {gid: def_map[gid] for gid in mf if gid in def_map}
    n_edges = sum(len(v) for v in go_parents_mf.values())
    print(f"OBO parsed: {len(mf)} MF terms, {n_edges} parent edges, "
          f"{len(rep_map)} replacements, {len(alt_map)} alt-ids, "
          f"{len(go_defs_mf)} definitions")

    # Level-order BFS from the MF root: the first time a term is reached gives
    # its minimum depth.
    go_depth: dict = {}
    go_depth[MF_ROOT] = 0
    # Children map for the BFS (reverse of parents)
    children: dict = {gid: set() for gid in mf}
    for gid, parents in go_parents_mf.items():
        for p in parents:
            if p in children:
                children[p].add(gid)
    queue = [MF_ROOT]
    while queue:
        nxt = []
        for gid in queue:
            d = go_depth[gid]
            for child in children.get(gid, ()):
                if child not in go_depth:
                    go_depth[child] = d + 1
                    nxt.append(child)
        queue = nxt
    print(f"Depth computed: {len(go_depth)} MF terms, "
          f"max depth={max(go_depth.values(), default=0)}")
    return mf, go_parents_mf, go_names_ob, rep_map, go_depth, go_defs_mf
def compute_dynamic_cap(sorted_probs: list, seq_len: int) -> int:
    """
    Return a protein-proportional cap on the number of direct predictions.

    Combines three signals:
      1. Sequence-length prior (log2 scaling), with an absolute ceiling of 15.
      2. Probability-gap detection — the largest relative drop within budget,
         biased towards cuts near the length prior.
      3. Diffuse-activation penalty — when predictions are bunched at low
         confidence with no clear outlier, the cap is tightened. This prevents
         the model outputting many near-threshold terms for proteins it is
         uncertain about (e.g. sparse mammal annotations). The penalty never
         loosens the cap beyond the absolute ceiling.

    Args:
        sorted_probs: candidate probabilities, sorted in descending order.
        seq_len: length of the query sequence in residues.

    Returns:
        How many leading entries of ``sorted_probs`` to keep: 0 for empty
        input, all of them when there are at most 3, otherwise at least 3.
    """
    n = len(sorted_probs)
    if n <= 3:
        return n  # includes n == 0

    # Length prior (log2 scaling); sequences <= 50 aa get the minimum of 2.
    length_prior = max(2, int(2.5 * math.log2(max(seq_len, 50) / 50) + 2))
    abs_cap = min(15, length_prior * 2)

    # ── Diffuse-activation penalty ──────────────────────────────────────────
    # If the top prediction is below 0.75 AND the spread across all predictions
    # is narrow (< 0.12), we're seeing uniform noise rather than clear signal.
    top_prob = sorted_probs[0]
    spread = sorted_probs[0] - sorted_probs[-1]
    if top_prob < 0.75 and spread < 0.12:
        # Tight cluster at low confidence: cut cap to length_prior (conservative).
        # min() keeps very long sequences (length_prior > 15) from *raising* the cap.
        abs_cap = max(3, min(abs_cap, length_prior))
    elif top_prob < 0.72:
        # Moderately uncertain: mild tightening
        abs_cap = max(3, min(abs_cap, length_prior + 2))

    # abs_cap >= 3 and n >= 4 here, so the gap search always has candidates.
    search_end = min(n, abs_cap)
    best_score = -1.0
    best_idx = search_end  # default: keep everything up to the cap
    for i in range(1, search_end):
        prev = sorted_probs[i - 1]
        curr = sorted_probs[i]
        rel_gap = (prev - curr) / (prev + 1e-6)
        # Penalise cuts far from the length prior (up to -25% of the score).
        dist = abs(i - length_prior) / max(length_prior, 1)
        score = rel_gap * (1.0 - 0.25 * min(dist, 1.0))
        if score > best_score and rel_gap >= 0.08:
            best_score = score
            best_idx = i
    return max(3, best_idx)
def propagate_and_filter(preds, go_parents_map, go_ancestors_map, prob_map):
    """
    Expand direct GO-MF predictions with their ancestors and drop orphans.

    Steps:
      * Every MF ancestor (other than MF_ROOT) of a direct prediction that is not
        itself predicted becomes an "implied" term, scored with the highest
        probability among the direct predictions beneath it.
      * A direct prediction is suppressed (``reason = "no_visible_parent"``)
        unless it is MF_ROOT, has no recorded parents, or at least one parent is
        among the direct or implied terms.
      * Implied terms with depth < 2 (root-adjacent, e.g. "binding") are kept
        out of the visible list; every returned entry carries a ``depth``.

    Args:
        preds: list of prediction dicts with at least "go_id" and "prob". The
            dicts are mutated in place ("depth" added; "reason" on suppressed).
        go_parents_map: GO ID -> set of direct MF parent IDs.
        go_ancestors_map: GO ID -> set of transitive MF ancestor IDs.
        prob_map: unused; kept for call compatibility.

    Returns:
        (visible, suppressed). ``visible`` holds kept direct predictions plus
        displayable implied terms, ordered by descending prob then descending
        depth. When ``go_ancestors_map`` is empty, ``preds`` is returned as-is
        (with depths) and nothing is suppressed.
    """
    if not go_ancestors_map:
        for entry in preds:
            entry["depth"] = go_depth.get(entry["go_id"], -1)
        return preds, []

    direct_ids = {entry["go_id"] for entry in preds}

    # Best supporting child probability for each non-predicted ancestor.
    ancestor_prob = {}
    for entry in preds:
        for anc in go_ancestors_map.get(entry["go_id"], set()):
            if anc == MF_ROOT or anc in direct_ids:
                continue
            ancestor_prob[anc] = max(ancestor_prob.get(anc, 0.0), entry["prob"])

    reachable = direct_ids | set(ancestor_prob)

    kept = []
    suppressed = []
    for entry in preds:
        gid = entry["go_id"]
        entry["depth"] = go_depth.get(gid, -1)
        parents = go_parents_map.get(gid, set())
        anchored = gid == MF_ROOT or not parents or bool(parents & reachable)
        if anchored:
            kept.append(entry)
        else:
            entry["reason"] = "no_visible_parent"
            suppressed.append(entry)

    # Root-adjacent ancestors remain implicitly true but carry no specificity,
    # so only depth >= 2 implied terms are displayed.
    min_display_depth = 2
    implied_rows = [
        {
            "go_id": gid,
            "name": go_map.get(gid, gid),
            "prob": round(best_prob, 3),
            "implied": True,
            "depth": go_depth.get(gid, -1),
        }
        for gid, best_prob in ancestor_prob.items()
        if go_depth.get(gid, -1) >= min_display_depth
    ]

    visible = kept + implied_rows
    visible.sort(key=lambda row: (-row["prob"], -row.get("depth", 0)))
    return visible, suppressed
def sequence_entropy(seq):
    """
    Shannon entropy, in bits per residue, of the case-folded composition of ``seq``.

    Residues are compared after ``str.upper()``. An empty string yields 0.
    """
    residues = seq.upper()
    total = len(residues)
    # dict.fromkeys keeps first-occurrence order, so terms are summed in the
    # same order as a running tally would produce.
    freqs = [residues.count(r) / total for r in dict.fromkeys(residues)]
    return -sum(p * math.log2(p) for p in freqs)
def validate_sequence(name, seq):
    """
    Return a user-facing error string if ``seq`` should be rejected, else None.

    Checks run in this order and the first failure is reported:
      1. length below MIN_SEQ_LENGTH;
      2. any character that is not an ASCII letter;
      3. nucleotide look-alike (>85% A/T/C/G/U with at most 6 distinct letters);
      4. characters listed in INVALID_AA (compared upper-cased);
      5. fewer than MIN_DISTINCT_AA distinct residues;
      6. one residue making up more than MAX_DOMINANT_FRAC of the sequence;
      7. Shannon entropy below MIN_ENTROPY_BITS.

    Args:
        name: label for the sequence, quoted in the error message.
        seq: the sequence text; letters of either case are accepted.
    """
    if len(seq) < MIN_SEQ_LENGTH:
        return (f"'{name}' is too short ({len(seq)} aa — minimum {MIN_SEQ_LENGTH} aa). "
                f"Sequences this short are unlikely to fold into a stable domain.")
    # Reject anything that is not an ASCII letter (digits, spaces, symbols).
    # str.isalpha() alone is Unicode-aware and would let letters such as "é",
    # "α" or "ß" through ("ß".upper() is even two characters, "SS").
    non_letter = sorted({c for c in seq if not (c.isascii() and c.isalpha())})
    if non_letter:
        display = ", ".join(f"'{c}'" for c in non_letter[:5])
        return (f"'{name}' contains non-amino-acid characters: {display}. "
                f"Only single-letter amino acid codes are accepted.")
    seq_upper = seq.upper()
    residue_set = set(seq_upper)
    # Detect DNA/RNA: >85% A/T/C/G/U and at most 6 distinct letters
    # (the five nucleotide codes plus one extra, e.g. an ambiguity code).
    nucleotide_frac = sum(seq_upper.count(c) for c in "ATCGU") / len(seq)
    if nucleotide_frac > 0.85 and len(residue_set) <= 6:
        return (f"'{name}' appears to be a nucleotide sequence (DNA/RNA), not a protein. "
                f"Please enter an amino acid sequence in single-letter code.")
    bad = sorted(c for c in residue_set if c in INVALID_AA)
    if bad:
        return (f"'{name}' contains invalid amino acid character(s): "
                f"{', '.join(bad)}. These ambiguity codes are not accepted.")
    counts = {}
    for aa in seq_upper:
        counts[aa] = counts.get(aa, 0) + 1
    if len(counts) < MIN_DISTINCT_AA:
        return (f"'{name}' uses only {len(counts)} distinct residue type(s). "
                f"Real proteins require at least {MIN_DISTINCT_AA}.")
    dominant_frac = max(counts.values()) / len(seq)
    if dominant_frac > MAX_DOMINANT_FRAC:
        dominant_aa = max(counts, key=counts.get)
        return (f"'{name}' is dominated by a single residue "
                f"({dominant_aa} = {dominant_frac:.0%}). "
                f"Low-complexity sequences produce unreliable embeddings.")
    H = sequence_entropy(seq)
    if H < MIN_ENTROPY_BITS:
        return (f"'{name}' has very low sequence complexity "
                f"(Shannon entropy {H:.2f} bits, minimum {MIN_ENTROPY_BITS:.1f} bits). "
                f"Repetitive or artificially constructed sequences are not accepted.")
    return None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook: load every model/ontology artifact at startup.

    Populates the module-level globals declared below, in order:
      1. download missing artifacts (ensure_model_files)
      2. GO ID -> name map (load_go_map, then go_names.json if present)
      3. MultiLabelBinarizer -> NUM_LABELS
      4. go-basic.obo -> MF terms, parent DAG, ancestors, depth, definitions
      5. MF-only label whitelist (mf_indices, also stored on app.state)
      6. per-label thresholds and temperatures
      7. classifier: the first checkpoint, in priority order, whose architecture
         is recognised, whose input width is consistent with its metadata and
         which loads with strict=True
      8. ESM-2 encoder matching the classifier's ESM width (480 -> 35M, else 8M)
      9-11. optional taxon probe, Platt parameters and temperature override
      12. L2-normalised anchor embeddings for organism inference
    Then yields to serve requests, and prints a message on shutdown.

    Raises:
        RuntimeError: if no classifier checkpoint can be loaded.
    """
    global device, model, esm_model, batch_converter
    global mlb, go_map, go_defs, mf_terms, go_parents, go_ancestors, go_depth, go_replaced
    global mf_indices, thresholds, temperature, insect_temperature, insect_threshold_default, NUM_LABELS, _ESM_DIM
    global supp_mu, supp_sd, supp_cols, model_uses_supp
    global taxon_probe, platt_params, _anchor_embeddings

    # Step 1: download missing files
    ensure_model_files()

    # Step 2: GO name map
    go_map = load_go_map()
    go_names_path = os.path.join(BASE_DIR, "go_names.json")
    if os.path.exists(go_names_path):
        with open(go_names_path) as f:
            go_map.update(json.load(f))
        print(f"Canonical GO names loaded: {len(go_map)} total entries")

    # Step 3: MLB — load BEFORE anything references mlb.classes_
    mlb = joblib.load(os.path.join(BASE_DIR, "mlb_public_v1.pkl"))
    NUM_LABELS = len(mlb.classes_)
    print(f"MLB loaded: {NUM_LABELS} labels")

    # Step 4: OBO — parse MF namespace, parent DAG, names, depth, replacements
    obo_path = os.path.join(BASE_DIR, "go-basic.obo")
    if os.path.exists(obo_path):
        mf_terms, go_parents, go_names_obo, go_replaced, go_depth, go_defs_obo = parse_obo(obo_path)
        go_defs.update(go_defs_obo)
        # OBO canonical names are the most authoritative — merge over CSV names
        go_map.update(go_names_obo)
        print(f"OBO names merged: {len(go_names_obo)} MF term names")
        mf_in_mlb = sum(1 for gid in mlb.classes_ if gid in mf_terms)
        rep_in_mlb = sum(1 for gid in mlb.classes_ if gid in go_replaced)
        print(f"OBO cross-check: {mf_in_mlb}/{NUM_LABELS} active MF, "
              f"{rep_in_mlb} replaced/obsolete labels remapped")
        # Build full transitive ancestor cache for parental propagation
        go_ancestors = build_ancestor_cache(go_parents)
        print(f"Ancestor cache built: {len(go_ancestors)} terms")
    else:
        print("WARNING: go-basic.obo not found — hierarchy filtering disabled. "
              "Download from https://current.geneontology.org/ontology/go-basic.obo "
              "and place it alongside server.py.")

    # Step 5: MF-only whitelist — OBO namespace is authoritative, CSV is fallback
    if mf_terms:
        mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_terms]
        print(f"MF whitelist (OBO): {len(mf_indices)} active indices")
    else:
        mf_go_ids = {
            go_id for go_id, name in go_map.items()
            if name and name != go_id and not name.startswith("GO:")
        }
        mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_go_ids] or list(range(NUM_LABELS))
        print(f"MF whitelist (CSV fallback): {len(mf_indices)} active indices")
    app.state.mf_indices = mf_indices

    # Step 6: per-label thresholds (mammal-calibrated) + insect fallback params
    thresholds, temperature, insect_temperature, insect_threshold_default = load_thresholds()

    # Step 7: classifier — auto-detect architecture from checkpoint keys
    class ResBlock(nn.Module):
        """Pre-activation residual block with BatchNorm (improved model)."""
        def __init__(self, dim, dropout=0.2):
            super().__init__()
            self.net = nn.Sequential(
                nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
                nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
            )

        def forward(self, x):
            return x + self.net(x)

    class ImprovedResidualMLP(nn.Module):
        """4-block, hidden=2048, BatchNorm — trained by train_improved.py."""
        def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=2048, n_blocks=4, dropout=0.2):
            super().__init__()
            self.fc_in = nn.Linear(in_dim, hidden)
            self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)])
            self.fc_out = nn.Sequential(
                nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(hidden, out_dim),
            )

        def forward(self, x):
            h = self.fc_in(x)
            for blk in self.blocks:
                h = blk(h)
            return self.fc_out(h)

    class ResidualMLP(nn.Module):
        """Original 2-block notebook model (fallback)."""
        def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2):
            super().__init__()
            self.fc_in = nn.Linear(in_dim, hidden)
            self.block1 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden))
            self.block2 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden))
            self.fc_out = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim))

        def forward(self, x):
            h = self.fc_in(x)
            h = torch.relu(h)
            h = h + self.block1(h)
            h = h + self.block2(h)
            return self.fc_out(h)

    class RecoveredBaselineModel(nn.Module):
        """Earlier server-side architecture — retained for backward compatibility."""
        def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2):
            super().__init__()
            self.fc1 = nn.Linear(in_dim, hidden)
            self.proj = nn.Linear(in_dim, hidden)
            self.fc2 = nn.Linear(hidden, hidden)
            self.out = nn.Linear(hidden, out_dim)
            self.relu = nn.ReLU()
            self.drop = nn.Dropout(dropout)

        def forward(self, x):
            h = self.relu(self.fc1(x))
            h = h + self.proj(x)
            h = self.relu(self.fc2(h))
            h = self.drop(h)
            return self.out(h)

    import numpy as np
    device = torch.device("cpu")
    # Prefer checkpoints in priority order: 35M > unified_v1 > mammal_enriched > v3_fixed > improved > v3 > supp_res2 > baseline
    ckpt_candidates = [
        os.path.join(BASE_DIR, "unified_35M_v1.pth"),
        os.path.join(BASE_DIR, "unified_v1.pth"),
        os.path.join(BASE_DIR, "mammal_enriched.pth"),
        os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"),
        os.path.join(BASE_DIR, "improved_res.pth"),
        os.path.join(BASE_DIR, "protfunc_v3.pth"),
        os.path.join(BASE_DIR, "supp_res2.pth"),
        os.path.join(BASE_DIR, "baseline_res.pth"),
    ]
    _ESM_DIM = 320  # updated after checkpoint load if esm_dim present
    _model = None
    for ckpt_path in ckpt_candidates:
        if not os.path.exists(ckpt_path):
            continue
        print(f"Trying classifier: {os.path.basename(ckpt_path)}")
        try:
            # weights_only=False unpickles arbitrary objects (and can run code);
            # acceptable only because these files come from our own model repo
            # / BASE_DIR — never point this at user-supplied files.
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        except Exception as e:
            print(f" Failed to load {os.path.basename(ckpt_path)}: {e} — skipping")
            continue
        sd = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
        keys = set(sd.keys())
        # Reset supp globals for each candidate
        _c_supp_mu = None
        _c_supp_sd = None
        _c_supp_cols = None
        _c_uses_supp = False
        if isinstance(ckpt, dict) and "supp_mu" in ckpt:
            _c_supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32)
            _c_supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32)
            _c_supp_cols = ckpt["supp_cols"]
            _c_uses_supp = True
        # ESM width this candidate was trained on. The global _ESM_DIM is only
        # updated once a checkpoint is committed, so validating against it would
        # compare a 480-dim (35M) checkpoint with the 320-dim default and wrongly
        # skip it as corrupted.
        _c_esm_dim = (int(ckpt["esm_dim"])
                      if isinstance(ckpt, dict) and "esm_dim" in ckpt else _ESM_DIM)
        # Detect architecture and validate in_dim against supp metadata.
        # Checkpoints that store explicit "in_dim" may use a truncated supp feature
        # set (e.g. mammal_enriched uses SUPP_COLS[:11] → in_dim=331, but stores
        # all 39 supp_cols for reference). Trust fc_in.weight shape in that case.
        _ckpt_has_explicit_in_dim = isinstance(ckpt, dict) and "in_dim" in ckpt
        # Detect out_dim from checkpoint to avoid mismatch with MLB size
        _out_dim_ckpt = None
        for _out_key in ("fc_out.3.weight", "fc_out.2.weight", "out.weight"):
            if _out_key in sd:
                _out_dim_ckpt = sd[_out_key].shape[0]
                break
        if "blocks.0.net.0.weight" in keys:
            hidden_dim = sd["fc_in.weight"].shape[0]
            n_blocks = sum(1 for k in keys if k.startswith("blocks.") and k.endswith(".net.0.weight"))
            in_dim_ckpt = sd["fc_in.weight"].shape[1]
            if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim:
                # ESM embedding + supp features + 1 missingness flag
                expected = _c_esm_dim + len(_c_supp_cols) + 1
                if in_dim_ckpt != expected:
                    print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint")
                    continue
            out_dim_use = _out_dim_ckpt or NUM_LABELS
            _model = ImprovedResidualMLP(in_dim=in_dim_ckpt, hidden=hidden_dim, n_blocks=n_blocks, out_dim=out_dim_use).to(device)
            print(f" ImprovedResidualMLP (in={in_dim_ckpt} hidden={hidden_dim} blocks={n_blocks} out={out_dim_use})")
        elif any(k.startswith("fc_in") for k in keys):
            in_dim_ckpt = sd["fc_in.weight"].shape[1]
            if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim:
                expected = _c_esm_dim + len(_c_supp_cols) + 1
                if in_dim_ckpt != expected:
                    print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint")
                    continue
            out_dim_use = _out_dim_ckpt or NUM_LABELS
            _model = ResidualMLP(in_dim=in_dim_ckpt, out_dim=out_dim_use).to(device)
            print(f" ResidualMLP (in={in_dim_ckpt} out={out_dim_use})")
        elif any(k.startswith("fc1") for k in keys):
            # Legacy baseline never consumes supplementary features.
            _c_uses_supp = False
            out_dim_use = _out_dim_ckpt or NUM_LABELS
            _model = RecoveredBaselineModel(out_dim=out_dim_use).to(device)
            print(f" RecoveredBaselineModel (legacy out={out_dim_use})")
        else:
            print(f" SKIP: unrecognised architecture — keys: {sorted(keys)[:8]}")
            continue
        try:
            _model.load_state_dict(sd, strict=True)
        except Exception as e:
            print(f" SKIP: load_state_dict failed: {e}")
            _model = None
            continue
        # Commit globals and break
        supp_mu = _c_supp_mu
        supp_sd = _c_supp_sd
        supp_cols = _c_supp_cols
        model_uses_supp = _c_uses_supp
        if isinstance(ckpt, dict) and "val_fmax" in ckpt:
            print(f" val_fmax={ckpt['val_fmax']:.4f} epoch={ckpt.get('epoch','?')}")
        if _c_uses_supp:
            print(f" Supplemented model: {len(_c_supp_cols)} extra features")
        # Detect ESM dim from checkpoint metadata
        if isinstance(ckpt, dict) and "esm_dim" in ckpt:
            _ESM_DIM = _c_esm_dim
            print(f" ESM dim from checkpoint: {_ESM_DIM}")
        print(f"Classifier loaded: {os.path.basename(ckpt_path)}")
        break
    if _model is None:
        raise RuntimeError("No valid classifier checkpoint found.")
    _model.eval()
    model = _model

    # Step 8: ESM-2 — choose model based on detected esm_dim
    import esm as esm_lib
    if _ESM_DIM == 480:
        _esm_model, alphabet = esm_lib.pretrained.esm2_t12_35M_UR50D()
        print("ESM-2 (35M, 480-dim) loaded OK")
    else:
        _esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
        print("ESM-2 (8M, 320-dim) loaded OK")
    esm_model = _esm_model.to(device).eval()
    batch_converter = alphabet.get_batch_converter()

    # Step 9: Taxon probe (optional, generated by calibrate_server.py / calibrate_probe_35M.py)
    probe_path = os.path.join(BASE_DIR, "taxon_probe.json")
    if os.path.exists(probe_path):
        with open(probe_path) as f:
            taxon_probe = json.load(f)
        acc = taxon_probe.get("train_accuracy", 0)
        probe_esm_dim = taxon_probe.get("esm_dim", 320)
        if probe_esm_dim != _ESM_DIM:
            print(f"Taxon probe ESM dim mismatch ({probe_esm_dim} vs {_ESM_DIM}) — disabling probe")
            taxon_probe = None
        else:
            for k in ("coef", "intercept", "scaler_mean", "scaler_std"):
                if k in taxon_probe:
                    taxon_probe[k] = np.asarray(taxon_probe[k], dtype=np.float32)
            print(f"Taxon probe loaded (train_acc={acc:.4f}, esm_dim={probe_esm_dim})")
    else:
        print("Taxon probe not found — using composition heuristic for auto-detection")

    # Step 10: Platt scaling (optional, generated by calibrate_server.py)
    platt_path = os.path.join(BASE_DIR, "platt_mammal.json")
    if os.path.exists(platt_path):
        with open(platt_path) as f:
            platt_params = json.load(f)
        print(f"Platt scaling loaded: {len(platt_params)} labels calibrated")
    else:
        print("Platt params not found — using temperature scaling only")

    # Step 11: Override temperature from calibration sweep if available
    # (only when it differs from the thresholds-file value by more than 0.1)
    temp_path = os.path.join(BASE_DIR, "temperature_best.json")
    if os.path.exists(temp_path):
        with open(temp_path) as f:
            temp_data = json.load(f)
        new_T = float(temp_data.get("optimal_T", temperature))
        if abs(new_T - temperature) > 0.1:
            print(f"Temperature updated by calibration sweep: {temperature:.4f} → {new_T:.4f}")
            temperature = new_T
        else:
            print(f"Temperature unchanged by sweep: {temperature:.4f}")

    # Step 12: Pre-compute anchor embeddings for organism inference
    import numpy as _np
    _anchor_embeddings = {}
    for _tg, _seq in ANCHOR_SEQUENCES.items():
        try:
            _emb = _get_esm_embedding(_seq[:500]).detach().cpu().numpy().astype(_np.float32)
            _norm = _np.linalg.norm(_emb)
            if _norm > 1e-12:
                _anchor_embeddings[_tg] = _emb / _norm
        except Exception as _e:
            print(f" Anchor embedding for {_tg} failed: {_e}")
    print(f"Anchor embeddings computed: {list(_anchor_embeddings.keys())}")
    yield
    print("Shutting down.")
# ASGI application; the `lifespan` hook performs all model/ontology loading at startup.
app = FastAPI(lifespan=lifespan)
# Permissive CORS: any origin, method and header may call the API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Files under STATIC_DIR are served at /static/...
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.get("/")
async def root():
    """Serve the single-page UI (static/interface.html) with browser caching disabled."""
    page = os.path.join(STATIC_DIR, "interface.html")
    no_cache = {"Cache-Control": "no-store"}
    return FileResponse(page, headers=no_cache)
@app.get("/api/model/info")
async def model_info():
    """
    Report model metadata and configuration for the UI.

    The display name / version / model id come from the first matching rule in
    the table below. Rules combine which checkpoint files exist in BASE_DIR with
    properties of the loaded classifier (supplemented features, ESM width);
    they do not consult which file was actually loaded.
    """
    def has_file(filename):
        return os.path.exists(os.path.join(BASE_DIR, filename))

    # (condition, (display name, version, model id)) — first truthy condition wins.
    rules = (
        (has_file("unified_35M_v1.pth") and model_uses_supp and _ESM_DIM == 480,
         ("FABLE v5.0 (35M ESM, GOA-enriched, insect+mammal)", "5.0.0", "unified_35M_v1")),
        (has_file("unified_v1.pth") and model_uses_supp,
         ("FABLE v4.0 (unified insect+mammal, CAFA5-evaluated)", "4.0.0", "unified_v1")),
        (has_file("mammal_enriched.pth") and model_uses_supp,
         ("FABLE v3.2 (mammal-enriched, CAFA5-evaluated)", "3.2.0", "mammal_enriched")),
        (has_file("protfunc_v3_fixed.pth") and model_uses_supp,
         ("FABLE v3-fixed (ablation best, CAFA-correct)", "3.1.0", "protfunc_v3_fixed")),
        (model_uses_supp,
         ("FABLE v3 (supplemented + mammal)", "3.0.0", "protfunc_v3")),
        (has_file("improved_res.pth"),
         ("FABLE Enhanced", "2.1.0", "improved")),
        (True,
         ("FABLE", "2.0.0", "baseline")),
    )
    name, version, active = next(ident for matched, ident in rules if matched)

    esm_name = "esm2_t12_35M_UR50D" if _ESM_DIM == 480 else "esm2_t6_8M_UR50D"
    return {
        "model_name": name,
        "model": active,
        "version": version,
        "esm_model": esm_name,
        "embed_dim": _ESM_DIM,
        "num_labels": NUM_LABELS,
        "supported_namespaces": ["molecular_function"],
        "max_sequence_length": 1500,
        "thresholds_loaded": len(thresholds) > 0,
        "temperature_scaling": temperature != 1.0,
        "temperature": round(temperature, 4),
        "insect_temperature": round(insect_temperature, 4),
        "insect_threshold_default": round(insect_threshold_default, 4),
        "supported_taxa": ["insect", "mammal"],
        "taxon_routing": True,
        "hierarchy_filtering": len(go_parents) > 0,
        "parental_propagation": len(go_ancestors) > 0,
        "depth_annotation": len(go_depth) > 0,
        "mf_terms_loaded": len(mf_terms) if mf_terms else 0,
        "mf_max_depth": max(go_depth.values(), default=0) if go_depth else 0,
        "supplemented_features": model_uses_supp,
        "supp_feature_count": len(supp_cols) if supp_cols else 0,
    }
@app.get("/api/generalization")
async def get_generalization():
    """
    Serve cross-taxon generalization results produced by eval_generalization.py.

    Looks for artifacts/generalization/generalization_results.json first, then
    generalization_results.json in BASE_DIR. The first file found is returned as
    {"available": True, "taxa": <its top-level keys>, "results": <its contents>};
    if neither exists, {"available": False, "taxa": [], "results": {}}.
    """
    search_paths = (
        os.path.join(BASE_DIR, "artifacts", "generalization", "generalization_results.json"),
        os.path.join(BASE_DIR, "generalization_results.json"),
    )
    found = next((p for p in search_paths if os.path.exists(p)), None)
    if found is None:
        return {"available": False, "taxa": [], "results": {}}
    with open(found) as fh:
        results = json.load(fh)
    return {
        "available": True,
        "taxa": list(results.keys()),
        "results": results,
    }
@app.get("/api/structure")
async def get_structure(uniprot_id: str):
    """
    Look up the AlphaFold DB prediction for a UniProt accession.

    The urllib request is blocking, so it runs in the default thread-pool
    executor; calling it directly would stall the event loop (and every other
    request) for up to the 8 s timeout.

    Returns:
        {"found": True, ...entry metadata and URLs...} on success;
        {"found": False, "accession", "uniprot_url"} when AlphaFold has no entry
        (HTTP 404 or an empty result list);
        {"found": False, "error": ...} for a malformed accession or other failures.
    """
    import asyncio
    import urllib.request, urllib.error

    uid = uniprot_id.upper().strip()
    if not re.match(r'^[A-Z0-9]{4,10}$', uid):
        return {"found": False, "error": "Invalid accession format"}

    def _fetch():
        req = urllib.request.Request(
            f"https://alphafold.ebi.ac.uk/api/prediction/{uid}",
            headers={"Accept": "application/json", "User-Agent": "FABLE/1.0"},
        )
        with urllib.request.urlopen(req, timeout=8) as resp:
            return json.loads(resp.read())

    not_found = {"found": False, "accession": uid,
                 "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}"}
    try:
        entries = await asyncio.get_running_loop().run_in_executor(None, _fetch)
        if not entries:
            # AlphaFold answered but has no model for this accession.
            return not_found
        d = entries[0]
        return {
            "found": True,
            "accession": uid,
            "organism": d.get("organismScientificName", ""),
            "gene": d.get("gene", ""),
            "cif_url": d.get("cifUrl", ""),
            "pae_image_url": d.get("paeImageUrl", ""),
            "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}",
            "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}",
            "model_version": d.get("latestVersion", 4),
        }
    except urllib.error.HTTPError as e:
        if e.code == 404:
            return not_found
        return {"found": False, "error": f"HTTP {e.code}"}
    except Exception as e:
        return {"found": False, "error": str(e)[:100]}
@app.get("/api/uniprot/annotations")
async def get_uniprot_annotations(uniprot_id: str):
    """
    Fetch GO-MF annotations and organism info for an accession from the UniProt REST API.

    Returns the known molecular-function annotations (one entry per GO ID,
    keeping the best evidence tier, sorted by tier then name) for comparison
    with predictions, together with protein/gene names, organism, the last five
    lineage entries and a coarse taxon ("insect" | "mammal" | "auto") derived
    from the lineage. Successful lookups are memoised in ``_uniprot_cache``;
    failures return ``{"found": False, ...}`` and are not cached.

    The urllib request is blocking, so it runs in the default thread-pool
    executor rather than stalling the event loop for up to the 10 s timeout.
    """
    import asyncio
    import urllib.request, urllib.error

    uid = uniprot_id.upper().strip()
    if not re.match(r'^[A-Z0-9]{4,10}$', uid):
        return {"found": False, "error": "Invalid accession format"}
    if uid in _uniprot_cache:
        return _uniprot_cache[uid]

    def _fetch():
        url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage"
        req = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "FABLE/1.0"})
        with urllib.request.urlopen(req, timeout=10) as resp:
            return json.loads(resp.read())

    try:
        data = await asyncio.get_running_loop().run_in_executor(None, _fetch)
    except urllib.error.HTTPError as e:
        if e.code == 404:
            return {"found": False, "accession": uid, "error": "Not found in UniProt"}
        return {"found": False, "error": f"HTTP {e.code}"}
    except Exception as e:
        return {"found": False, "error": str(e)[:100]}

    def _taxon_name(node):
        # NOTE(review): lineage items were assumed to be dicts carrying
        # "scientificName"; UniProt may return plain strings instead. Accept
        # both so neither shape raises AttributeError (an HTTP 500 here).
        if isinstance(node, str):
            return node
        if isinstance(node, dict):
            return node.get("scientificName", "")
        return ""

    # Parse organism lineage for taxon detection
    organism = data.get("organism", {})
    org_name = organism.get("scientificName", "")
    lineage = [_taxon_name(x) for x in organism.get("lineage", [])]
    if any(t in lineage for t in INSECT_LINEAGE):
        detected_taxon = "insect"
    elif any(t in lineage for t in MAMMAL_LINEAGE):
        detected_taxon = "mammal"
    else:
        detected_taxon = "auto"

    # Parse protein name (recommended name first, then first submission name)
    pn_block = data.get("proteinDescription", {})
    prot_name = ""
    if "recommendedName" in pn_block:
        prot_name = pn_block["recommendedName"].get("fullName", {}).get("value", "")
    elif pn_block.get("submissionNames"):
        # .get() guard: an empty submissionNames list used to raise IndexError
        prot_name = pn_block["submissionNames"][0].get("fullName", {}).get("value", "")

    gene_names = [gn["geneName"]["value"] for gn in data.get("genes", []) if "geneName" in gn]

    # Evidence-code tiers (hoisted out of the cross-reference loop)
    exp_codes = {"EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "HTP", "HDA", "HMP", "HGI", "HEP"}
    comp_codes = {"ISS", "ISO", "ISA", "ISM", "IGC", "IBA", "IBD", "IKR", "IRD", "RCA"}
    tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3}

    # Parse GO-MF annotations, keeping the best evidence tier per GO ID
    # (first-seen order is preserved for equal tiers).
    best = {}
    for xref in data.get("uniProtKBCrossReferences", []):
        if xref.get("database") != "GO":
            continue
        go_id = xref.get("id", "")
        props = {p["key"]: p["value"] for p in xref.get("properties", [])}
        term_str = props.get("GoTerm", "")
        if not term_str.startswith("F:"):
            continue  # only molecular function
        evidence = props.get("GoEvidenceType", "")
        ev_code = evidence.split(":")[0] if ":" in evidence else evidence
        go_name = term_str[2:].strip()
        # Resolve canonical name from our go_map if available
        display_name = go_map.get(go_id, go_name)
        if ev_code in exp_codes:
            ev_tier = "experimental"
        elif ev_code in comp_codes:
            ev_tier = "computational"
        elif ev_code == "IEA":
            ev_tier = "electronic"
        else:
            ev_tier = "other"
        if go_id not in best or tier_rank[ev_tier] < tier_rank[best[go_id]["ev_tier"]]:
            best[go_id] = {
                "go_id": go_id,
                "name": display_name,
                "evidence": ev_code,
                "ev_tier": ev_tier,
            }
    go_mf = sorted(best.values(), key=lambda x: (tier_rank[x["ev_tier"]], x["name"]))

    result = {
        "found": True,
        "accession": uid,
        "protein_name": prot_name,
        "gene_names": gene_names,
        "organism": org_name,
        "lineage": lineage[-5:] if lineage else [],
        "detected_taxon": detected_taxon,
        "go_mf": go_mf,
        "n_experimental": sum(1 for e in go_mf if e["ev_tier"] == "experimental"),
        "n_total": len(go_mf),
    }
    _uniprot_cache[uid] = result
    return result
@app.get("/api/health")
async def health_check():
    """Monitoring probe: report whether the classifier and ESM encoder are loaded, plus the label count."""
    classifier_ready = model is not None
    encoder_ready = esm_model is not None
    return dict(
        status="healthy",
        model_loaded=classifier_ready,
        esm_loaded=encoder_ready,
        labels=NUM_LABELS,
    )
# Request body for POST /predict.
class ProteinRequest(BaseModel):
    # Raw input text; predict splits it with parse_sequences (FASTA or one sequence per line).
    sequence: str
    taxon: str = "auto"  # "auto" | "insect" | "mammal"
    # Optional UniProt accession (originally documented as being for structure lookup).
    # predict fetches its UniProt record, using the lineage for taxon detection and
    # the GO cross-references as known annotations.
    uniprot_id: str = ""
# Request body for per-residue saliency: the input sequence plus optional
# accession/taxon hints, mirroring ProteinRequest.
class SaliencyRequest(BaseModel):
    sequence: str
    uniprot_id: str = ""
    taxon: str = "auto"  # presumably "auto" | "insect" | "mammal" as in ProteinRequest — confirm
    top_k: int = 20  # top-k predicted labels to use as saliency objective
# Request body for prediction explanations; same fields as SaliencyRequest,
# with a smaller default top_k (10).
class ExplainRequest(BaseModel):
    sequence: str
    uniprot_id: str = ""
    taxon: str = "auto"
    top_k: int = 10  # presumably the number of top predictions to explain — confirm in the handler
# Request body for batch prediction over several named sequences.
class BatchPredictRequest(BaseModel):
    sequences: list  # List of {name: str, sequence: str} objects
    threshold: float = 0.5  # presumably a probability cut-off for reported labels — confirm in the batch handler
    include_suppressed: bool = True  # presumably whether hierarchy-suppressed terms are returned — confirm
# Request body for organism inference on a single sequence.
class OrganismRequest(BaseModel):
    sequence: str
    user_selection: str = ""  # empty or "auto" → infer; otherwise taxon_group label
# Request body for identification: like OrganismRequest, plus an optional
# accession supplied by the user or parsed from a FASTA header.
class IdentifyRequest(BaseModel):
    sequence: str
    user_selection: str = ""  # empty or "auto" → infer; otherwise taxon_group label
    accession: str = ""  # optional: typed or FASTA-header accession to trust
def parse_sequences(text):
    """
    Split user input into ``(name, sequence)`` tuples.

    FASTA input (first non-blank character ``>``): each header line names the
    record, and all whitespace is removed from the following body. Records with
    an empty body are dropped.

    Otherwise every non-blank line is its own sequence, named "Sequence 1",
    "Sequence 2", ... in order.
    """
    stripped = text.strip()
    if not stripped.startswith(">"):
        lines = (ln.strip() for ln in stripped.splitlines())
        return [(f"Sequence {idx}", ln) for idx, ln in enumerate(filter(None, lines), start=1)]

    # With one capture group, re.split yields [prefix, header1, body1, header2, body2, ...].
    pieces = re.split(r"(>.*)", stripped)
    records = []
    for header, body in zip(pieces[1::2], pieces[2::2]):
        residues = re.sub(r"\s+", "", body)
        if residues:
            records.append((header[1:].strip(), residues))
    return records
def _build_model_input(emb: torch.Tensor, sequence: str) -> torch.Tensor:
    """
    Build the batch-of-one classifier input for a single sequence.

    Base models (``model_uses_supp`` false or ``supp_mu`` unset): the ESM-2
    embedding alone, unsqueezed to shape (1, D).

    Supplemented models: the embedding followed by the supplementary features
    in ``supp_cols`` order, z-scored with the checkpoint's ``supp_mu`` /
    ``supp_sd``, and a trailing missingness flag of 1.0 (sequence features are
    always computable). Handles two supplemented feature sets:
      - v3 models (in_dim=360): 39 sequence-derived features + 1 missingness flag
      - supp_res2 (in_dim=709): 388 features including f_Dim_* (ESM dims), AF
        structural features (zeros at inference), and sequence features

    If the loaded model exposes ``fc_in``, the result is truncated to its input
    width: some checkpoints (e.g. mammal_enriched) store every supp_col for
    reference but were trained on only the first N.

    Args:
        emb: 1-D ESM-2 embedding (it is concatenated along dim 0).
        sequence: amino-acid sequence used to compute supplementary features.

    Returns:
        Float tensor of shape (1, in_dim).

    Raises:
        RuntimeError: if the features that can be built are narrower than the
            model's ``fc_in`` expects. Any narrower tensor — including the bare
            embedding — would fail in the first Linear layer anyway, so fail
            early with a clear message instead.
    """
    import numpy as np
    if not model_uses_supp or supp_mu is None:
        return emb.unsqueeze(0)
    feats = compute_seq_features(sequence)
    # For models that include f_Dim_* features (e.g. supp_res2), populate them
    # from the ESM embedding rather than defaulting to 0.
    dim_cols = [c for c in supp_cols if c.startswith('f_Dim_')]
    if dim_cols:
        emb_np = emb.detach().cpu().numpy()
        for c in dim_cols:
            try:
                dim_idx = int(c.split('_')[-1])
                if dim_idx < len(emb_np):
                    feats[c] = float(emb_np[dim_idx])
            except (ValueError, IndexError):
                pass  # malformed or out-of-range column name: keep default
    s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32)
    s_z = (s_vec - supp_mu) / (supp_sd + 1e-12)
    # missingness flag: 1 = feature data available (seq features always computable)
    flag = np.array([1.0], dtype=np.float32)
    extra = torch.from_numpy(np.concatenate([s_z, flag]))
    full_input = torch.cat([emb, extra], dim=0).unsqueeze(0)
    if model is not None and hasattr(model, 'fc_in'):
        expected_dim = model.fc_in.weight.shape[1]
        built_dim = full_input.shape[1]
        if built_dim > expected_dim:
            full_input = full_input[:, :expected_dim]
        elif built_dim < expected_dim:
            raise RuntimeError(
                f"Classifier expects {expected_dim} input features but only "
                f"{built_dim} could be built (ESM dim {emb.shape[0]} + "
                f"{len(supp_cols)} supplementary + 1 flag)."
            )
    return full_input
# Bounded LRU cache of ESM embeddings, keyed by the MD5 of the sequence, so
# repeated/identical sequences skip the encoder. Dicts keep insertion order and a
# hit re-inserts its key at the end, so the first key is always the least
# recently used one and is the one evicted.
_ESM_CACHE: dict = {}
_ESM_CACHE_MAX = 256
def _get_esm_embedding(sequence: str) -> torch.Tensor:
    """
    Return the mean-pooled ESM-2 embedding of ``sequence`` as a 1-D CPU tensor.

    Residue representations are averaged over the sequence positions (the
    leading BOS token is skipped). Results are memoised in ``_ESM_CACHE``;
    a cache hit returns the same tensor object that was stored.
    """
    key = hashlib.md5(sequence.encode()).hexdigest()
    cached = _ESM_CACHE.pop(key, None)
    if cached is not None:
        _ESM_CACHE[key] = cached  # move to the most-recently-used position
        return cached
    _, _, tokens = batch_converter([("p", sequence)])
    with torch.no_grad():
        # NOTE(review): layer 6 is the last layer of esm2_t6_8M, but only the
        # midpoint of the 12-layer esm2_t12_35M that lifespan loads for 480-dim
        # checkpoints. Confirm this matches how the 35M classifier's training
        # embeddings were extracted.
        rep = esm_model(tokens.to(device), repr_layers=[6])["representations"][6]
    emb = rep[0, 1:len(sequence) + 1].mean(0).cpu()
    while len(_ESM_CACHE) >= _ESM_CACHE_MAX:
        _ESM_CACHE.pop(next(iter(_ESM_CACHE)))  # evict least recently used
    _ESM_CACHE[key] = emb
    return emb
async def _fetch_uniprot_go_mf(uid: str):
    """
    Best-effort fetch of organism and GO molecular-function annotations from UniProt.

    Args:
        uid: upper-cased accession, already validated by the caller.

    Returns:
        ``(uniprot_data, detected_taxon)``.
        - ``detected_taxon`` is "insect" or "mammal" when the organism lineage
          contains a name from INSECT_LINEAGE / MAMMAL_LINEAGE, otherwise None.
          It is preserved even if GO parsing fails afterwards.
        - ``uniprot_data`` is the annotation summary, or None on any failure
          (network error, HTTP error status, unexpected payload).

    Uses aiohttp rather than urllib so the event loop is not blocked (up to the
    8 s timeout) while UniProt responds.
    """
    detected_taxon = None
    url = (f"https://rest.uniprot.org/uniprotkb/{uid}.json"
           "?fields=go,organism,protein_name,gene_names,organism_lineage")
    headers = {"Accept": "application/json", "User-Agent": "FABLE/1.0"}
    try:
        timeout = aiohttp.ClientTimeout(total=8)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as resp:
                # urllib raised on 4xx/5xx; keep treating those as failures.
                resp.raise_for_status()
                raw = await resp.json(content_type=None)
        organism = raw.get("organism", {})
        lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])]
        if any(t in lineage for t in INSECT_LINEAGE):
            detected_taxon = "insect"
        elif any(t in lineage for t in MAMMAL_LINEAGE):
            detected_taxon = "mammal"
        # Parse GO-MF annotations, keeping the strongest evidence tier per term.
        tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3}
        exp_ev = {"EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "HTP", "HDA", "HMP", "HGI", "HEP"}
        comp_ev = {"ISS", "ISO", "ISA", "ISM", "IGC", "IBA", "IBD", "IKR", "IRD", "RCA"}
        go_mf_known = {}
        for xref in raw.get("uniProtKBCrossReferences", []):
            if xref.get("database") != "GO":
                continue
            props = {p["key"]: p["value"] for p in xref.get("properties", [])}
            if not props.get("GoTerm", "").startswith("F:"):
                continue
            go_id = xref.get("id", "")
            ev_code = props.get("GoEvidenceType", "").split(":")[0]
            if ev_code in exp_ev:
                ev_tier = "experimental"
            elif ev_code in comp_ev:
                ev_tier = "computational"
            elif ev_code == "IEA":
                ev_tier = "electronic"
            else:
                ev_tier = "other"
            entry = {"go_id": go_id, "name": go_map.get(go_id, props["GoTerm"][2:]),
                     "evidence": ev_code, "ev_tier": ev_tier}
            if go_id not in go_mf_known or tier_rank[ev_tier] < tier_rank[go_mf_known[go_id]["ev_tier"]]:
                go_mf_known[go_id] = entry
        uniprot_data = {
            "go_mf_known": sorted(go_mf_known.values(),
                                  key=lambda x: (tier_rank[x["ev_tier"]], x["name"])),
            "organism": organism.get("scientificName", ""),
            "detected_taxon": detected_taxon,
            "accession": uid,
        }
        return uniprot_data, detected_taxon
    except Exception:
        # Annotation enrichment is optional: predictions proceed without it.
        return None, detected_taxon


def _logits_to_calibrated_probs(logits: torch.Tensor, calibration: str,
                                t_apply: float) -> torch.Tensor:
    """
    Convert raw model logits into a 1-D tensor of per-label probabilities.

    When Platt parameters are loaded and the calibration is not "insect":
    - labels with an (A, B) pair in ``platt_params`` (keyed by str(label index))
      get sigmoid(A * logit + B);
    - all other labels get sigmoid(logit / t_apply).
    Otherwise every label gets sigmoid(logit / t_apply).

    The Platt path is computed in float64 with torch.sigmoid. Unlike
    ``1 / (1 + math.exp(-z))`` it cannot raise OverflowError for
    large-magnitude logits, and it is vectorised rather than a per-label
    Python loop.
    """
    if not (platt_params and calibration != "insect"):
        prob = torch.sigmoid(logits / t_apply)
        return prob.unsqueeze(0) if prob.dim() == 0 else prob
    flat = logits.reshape(-1).double()
    z = flat / t_apply
    idx, a_vals, b_vals = [], [], []
    for i in range(flat.numel()):
        ab = platt_params.get(str(i))
        if ab is not None:
            idx.append(i)
            a_vals.append(float(ab[0]))
            b_vals.append(float(ab[1]))
    if idx:
        sel = torch.tensor(idx, dtype=torch.long)
        a_t = torch.tensor(a_vals, dtype=torch.float64)
        b_t = torch.tensor(b_vals, dtype=torch.float64)
        z[sel] = a_t * flat[sel] + b_t
    return torch.sigmoid(z).float()


@app.post("/predict")
async def predict(request: ProteinRequest):
    """
    Predict GO molecular-function terms for every sequence in ``request.sequence``.

    Steps for each parsed sequence:
      1. validation;
      2. ESM-2 embedding;
      3. organism resolution via resolve_taxon (a UniProt lineage is passed as a
         hint when an accession is given);
      4. forward pass and calibrated probabilities;
      5. per-label thresholds, raised by the organism floor and by
         FP_THRESHOLD_OVERRIDES;
      6. dynamic cap, then GO-hierarchy filtering.

    A sequence that fails is reported as a ``{"name", "error"}`` entry rather
    than failing the request.

    Returns:
        ``{"results": [...]}`` with one entry per parsed sequence.
    """
    entries = parse_sequences(request.sequence)
    results = []
    device_cpu = torch.device("cpu")
    mf_idx = app.state.mf_indices
    uid = (request.uniprot_id or "").upper().strip()
    req_taxon = (request.taxon or "auto").lower()
    # Fetch UniProt annotations (if an accession is given) once for the whole
    # batch, without blocking the event loop.
    uniprot_data = None
    detected_taxon_from_uniprot = None
    if uid and re.match(r'^[A-Z0-9]{4,10}$', uid):
        uniprot_data, detected_taxon_from_uniprot = await _fetch_uniprot_go_mf(uid)
    for name, sequence in entries:
        err = validate_sequence(name, sequence)
        if err:
            results.append({"name": name, "error": err})
            continue
        if len(sequence) > 1500:
            results.append({"name": name, "error": "Sequence too long (max 1500 aa)"})
            continue
        try:
            emb = _get_esm_embedding(sequence).to(device_cpu)
            emb_np = emb.detach().cpu().numpy()
            # ── Organism resolution (DECOUPLED from GO gating) ──────────────
            tx = resolve_taxon(req_taxon, emb_np=emb_np, sequence=sequence,
                               detected_uniprot=detected_taxon_from_uniprot)
            calibration = tx["calibration"]
            taxon_conf = tx["conf"]
            if calibration == "insect":
                t_apply = insect_temperature
                thresh_lookup = {}
                t_default = insect_threshold_default
                base_floor = 0.0
            else:
                t_apply = temperature
                thresh_lookup = thresholds
                t_default = 0.68
                base_floor = 0.56
            # DECOUPLE: a blanket organism-conditioned floor must not zero out
            # semantically-valid GO terms when the organism is only a guess or the
            # protein is out-of-distribution. In those cases gate on the per-label
            # confidence thresholds alone and FLAG instead of suppressing (#1).
            mammal_floor = 0.0 if (tx["low_confidence"] or tx["ood"]) else base_floor
            # ── Forward pass ────────────────────────────────────────────────
            with torch.no_grad():
                inp = _build_model_input(emb, sequence)
                logits = model(inp).squeeze()
            # Platt scaling per label (if available, mammal calibration only),
            # otherwise temperature scaling.
            prob = _logits_to_calibrated_probs(logits, calibration, t_apply)
            # ── Threshold + collect ─────────────────────────────────────────
            raw_preds = []
            prob_map = {}
            for i in mf_idx:
                pv = float(prob[i])
                go_id = mlb.classes_[i]
                display_id = go_replaced.get(go_id, go_id)
                label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor)
                # Organism-independent false-positive floor (e.g. oxygen binding).
                fp_floor = FP_THRESHOLD_OVERRIDES.get(display_id) or FP_THRESHOLD_OVERRIDES.get(go_id)
                if fp_floor:
                    label_thresh = max(label_thresh, fp_floor)
                if pv >= label_thresh:
                    display_nm = go_map.get(display_id, go_map.get(go_id, go_id))
                    entry = {"go_id": display_id, "name": display_nm,
                             "prob": round(pv, 4), "depth": go_depth.get(display_id, -1)}
                    if display_id != go_id:
                        entry["original_id"] = go_id
                    raw_preds.append(entry)
                    prob_map[display_id] = pv
            raw_preds.sort(key=lambda x: x["prob"], reverse=True)
            cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence))
            raw_preds = raw_preds[:cap]
            for rp in raw_preds:
                prob_map[rp["go_id"]] = rp["prob"]
            visible, suppressed = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map)
            result = {
                "name": name,
                "sequence_length": len(sequence),
                "predictions": visible,
                "suppressed": suppressed,
                "n_above_threshold": len(raw_preds),
                "n_implied_parents": sum(1 for p in visible if p.get("implied")),
                "taxon_applied": tx["display"],
                "taxon_calibration": calibration,
                "taxon_source": tx["source"],
                "taxon_confidence": round(taxon_conf, 3),
                "ood": tx["ood"],
                "low_confidence": tx["low_confidence"],
                "temperature_applied": round(t_apply, 4),
                "platt_applied": bool(platt_params) and calibration != "insect",
            }
            # FLAG rather than suppress: surface why predictions may be unreliable
            # instead of silently zeroing them on an organism guess (#1).
            top_prob = raw_preds[0]["prob"] if raw_preds else 0.0
            if tx["ood"]:
                result["warning"] = (
                    f"{tx['display'].title()} proteins are outside FABLE's training "
                    f"distribution (insect + mammal). GO terms are shown unfiltered — "
                    f"treat confidence scores with caution.")
            elif tx["low_confidence"]:
                result["warning"] = (
                    "Organism could not be confidently inferred, so GO terms are gated "
                    "on prediction confidence alone. Select the organism above for "
                    "calibrated results.")
            elif not raw_preds or top_prob < 0.55:
                result["warning"] = (
                    "Low-confidence predictions — this sequence may be outside FABLE's "
                    "training distribution. Interpret results with caution.")
            if uniprot_data:
                result["uniprot"] = uniprot_data
            results.append(result)
        except Exception as e:
            results.append({"name": name, "error": str(e)})
    return {"results": results}
@app.post("/api/predict/batch")
async def predict_batch(request: BatchPredictRequest):
    """
    Predict GO molecular-function terms for several sequences in one request.

    Each item of ``request.sequences`` is read with ``.get("name")`` and
    ``.get("sequence")``. Items are processed one after another; the shared
    ESM embedding cache avoids recomputing duplicate sequences.

    Probabilities use the global temperature only (no Platt scaling or taxon
    calibration). A label is reported when its probability reaches a threshold
    built from:
      - its tuned per-label threshold, or ``request.threshold`` for labels
        without one;
      - raised to any organism-independent false-positive floor in
        FP_THRESHOLD_OVERRIDES, the same floor /predict and /api/go_terms
        apply, so known false positives such as oxygen binding do not leak
        through the batch endpoint.

    Suppressed terms are dropped unless ``request.include_suppressed`` is set.
    Per-item failures are reported as ``{"name", "error"}`` entries.
    """
    results = []
    mf_idx = app.state.mf_indices
    custom_threshold = request.threshold
    for item in request.sequences:
        name = item.get("name", "Unknown")
        sequence = item.get("sequence", "")
        # Validate sequence
        err = validate_sequence(name, sequence)
        if err:
            results.append({"name": name, "error": err})
            continue
        if len(sequence) > 1500:
            results.append({"name": name, "error": "Sequence too long (max 1500 aa)"})
            continue
        try:
            emb = _get_esm_embedding(sequence).to(device)
            with torch.no_grad():
                inp = _build_model_input(emb, sequence)
                prob = torch.sigmoid(model(inp) / temperature).squeeze()
            if prob.dim() == 0:
                prob = prob.unsqueeze(0)
            raw_preds = []
            prob_map = {}
            for i in mf_idx:
                pv = float(prob[i])
                go_id = mlb.classes_[i]
                display_id = go_replaced.get(go_id, go_id)
                thresh = float(thresholds.get(str(i), custom_threshold))
                # Organism-independent false-positive floor (e.g. oxygen binding).
                fp_floor = FP_THRESHOLD_OVERRIDES.get(display_id) or FP_THRESHOLD_OVERRIDES.get(go_id)
                if fp_floor:
                    thresh = max(thresh, fp_floor)
                if pv < thresh:
                    continue
                display_nm = go_map.get(display_id, go_map.get(go_id, go_id))
                entry = {
                    "go_id": display_id,
                    "name": display_nm,
                    "prob": round(pv, 4),
                    "depth": go_depth.get(display_id, -1),
                }
                if display_id != go_id:
                    entry["original_id"] = go_id
                raw_preds.append(entry)
                prob_map[display_id] = pv
            raw_preds.sort(key=lambda x: x["prob"], reverse=True)
            cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence))
            raw_preds = raw_preds[:cap]
            for rp in raw_preds:
                prob_map[rp["go_id"]] = rp["prob"]
            visible, suppressed = propagate_and_filter(
                raw_preds, go_parents, go_ancestors, prob_map
            )
            if not request.include_suppressed:
                suppressed = []
            results.append({
                "name": name,
                "sequence_length": len(sequence),
                "predictions": visible,
                "suppressed": suppressed,
                "n_above_threshold": len(raw_preds),
                "n_implied_parents": sum(1 for p in visible if p.get("implied")),
            })
        except Exception as e:
            results.append({"name": name, "error": str(e)})
    return {
        "results": results,
        "total": len(results),
        "successful": sum(1 for r in results if "error" not in r),
    }
class GoTermsRequest(BaseModel):
    # Request body for POST /api/go_terms.
    sequence: str  # amino-acid sequence; surrounding whitespace is stripped by the endpoint
    uniprot_id: str = ""  # optional accession; the endpoint only uses it as part of its cache key
    taxon: str = "auto"  # organism selection forwarded to resolve_taxon ("auto" presumably = infer)
    top_k: int = 20  # maximum number of predictions returned
    include_implied: bool = False  # when True, run propagate_and_filter before top_k (entries flagged "implied" are dropped)
    min_prob: float = 0.0  # when > 0, drop predictions whose prob is below this value
# Memoised /api/go_terms responses. Keys cover every request field that changes
# the response; the size is bounded and the oldest entry is evicted first.
_go_terms_cache: dict = {}
_GO_TERMS_CACHE_MAX = 512
@app.post("/api/go_terms")
async def get_go_terms(request: GoTermsRequest):
    """
    Lightweight GO-MF prediction endpoint for programmatic/pipeline use.

    Returns only the predicted term list (no suppressed terms, no taxon
    explanation text) plus a few calibration fields. Roughly 2x faster than
    /predict for downstream tools.

    Output shaping:
      - ``include_implied`` True: the capped predictions go through
        propagate_and_filter, and entries flagged ``implied`` are dropped.
      - ``include_implied`` False: the capped predictions are used as-is.
      - Then at most ``top_k`` entries are kept (negative values clamp to 0).
      - Then ``min_prob`` is applied when positive.

    Successful results are cached by
    (sequence hash, uniprot_id, taxon, top_k, include_implied, min_prob).
    Errors are returned as ``{"error": ..., "predictions": []}`` and are not
    cached.
    """
    seq = request.sequence.strip()
    # Normalise the taxon exactly like /predict does.
    taxon = (request.taxon or "auto").lower()
    top_k = max(0, request.top_k)
    # Every field that influences the response must be part of the key; otherwise
    # a repeat call with different top_k / include_implied / min_prob gets a stale hit.
    cache_key = (
        hashlib.md5(seq.encode()).hexdigest(),
        (request.uniprot_id or "").upper(),
        taxon,
        top_k,
        request.include_implied,
        request.min_prob,
    )
    cached = _go_terms_cache.get(cache_key)
    if cached is not None:
        return cached
    err = validate_sequence("seq", seq)
    if err:
        return {"error": err, "predictions": []}
    if len(seq) > 1500:
        return {"error": "Sequence too long (max 1500 aa)", "predictions": []}
    try:
        mf_idx = app.state.mf_indices
        emb = _get_esm_embedding(seq).to(device)
        emb_np = emb.detach().cpu().numpy()
        tx = resolve_taxon(taxon, emb_np=emb_np, sequence=seq)
        calibration = tx["calibration"]
        taxon_conf = tx["conf"]
        if calibration == "insect":
            t_apply = insect_temperature
            thresh_lookup = {}
            t_default = insect_threshold_default
            base_floor = 0.0
        else:
            t_apply = temperature
            thresh_lookup = thresholds
            t_default = 0.68
            base_floor = 0.56
        # Decouple gating from an uncertain/OOD organism guess (#1).
        mammal_floor = 0.0 if (tx["low_confidence"] or tx["ood"]) else base_floor
        with torch.no_grad():
            inp = _build_model_input(emb, seq)
            logits = model(inp).squeeze()
        platt_applied = bool(platt_params) and calibration != "insect"
        if platt_applied:
            # Vectorised Platt scaling in float64. torch.sigmoid cannot raise
            # OverflowError the way 1 / (1 + math.exp(-z)) does for large logits.
            flat = logits.reshape(-1).double()
            z = flat / t_apply
            idx = [i for i in range(flat.numel()) if str(i) in platt_params]
            if idx:
                sel = torch.tensor(idx, dtype=torch.long)
                a_t = torch.tensor([float(platt_params[str(i)][0]) for i in idx], dtype=torch.float64)
                b_t = torch.tensor([float(platt_params[str(i)][1]) for i in idx], dtype=torch.float64)
                z[sel] = a_t * flat[sel] + b_t
            prob = torch.sigmoid(z).float()
        else:
            prob = torch.sigmoid(logits / t_apply)
        if prob.dim() == 0:
            prob = prob.unsqueeze(0)
        raw_preds = []
        for i in mf_idx:
            pv = float(prob[i])
            go_id = mlb.classes_[i]
            display_id = go_replaced.get(go_id, go_id)
            label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor)
            # Organism-independent false-positive floor (e.g. oxygen binding).
            fp_floor = FP_THRESHOLD_OVERRIDES.get(display_id) or FP_THRESHOLD_OVERRIDES.get(go_id)
            if fp_floor:
                label_thresh = max(label_thresh, fp_floor)
            if pv >= label_thresh:
                display_nm = go_map.get(display_id, go_map.get(go_id, go_id))
                raw_preds.append({
                    "go_id": display_id,
                    "name": display_nm,
                    "prob": round(pv, 4),
                    "depth": go_depth.get(display_id, -1),
                })
        raw_preds.sort(key=lambda x: x["prob"], reverse=True)
        cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(seq))
        raw_preds = raw_preds[:cap]
        if not request.include_implied:
            predictions = raw_preds[:top_k]
        else:
            prob_map = {p["go_id"]: p["prob"] for p in raw_preds}
            visible, _ = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map)
            # NOTE(review): entries flagged "implied" are removed here despite the
            # flag's name — confirm this is the intended contract.
            predictions = [p for p in visible if not p.get("implied")][:top_k]
        if request.min_prob > 0:
            predictions = [p for p in predictions if p["prob"] >= request.min_prob]
        result = {
            "predictions": predictions,
            "n_predicted": len(predictions),
            "taxon_applied": tx["display"],
            "taxon_calibration": calibration,
            "taxon_conf": round(taxon_conf, 3),
            "taxon_source_method": tx["source"],
            "ood": tx["ood"],
            "low_confidence": tx["low_confidence"],
            "platt_applied": platt_applied,
            "threshold_default": t_default,
        }
        # Bound the cache: evict the oldest entries before inserting.
        while _go_terms_cache and len(_go_terms_cache) >= _GO_TERMS_CACHE_MAX:
            _go_terms_cache.pop(next(iter(_go_terms_cache)))
        _go_terms_cache[cache_key] = result
        return result
    except Exception as e:
        return {"error": str(e)[:300], "predictions": []}
@app.get("/api/explain_terms")
async def explain_terms(ids: str = ""):
    """
    Look up the display name and definition for comma-separated GO IDs.

    Used by the frontend "Why?" panel to show term descriptions inline.
    Only the first 50 comma-separated items are considered. Blank items are
    skipped. An ID missing from ``go_map`` echoes back as its own name, and an
    ID missing from ``go_defs`` gets an empty definition.
    """
    if not ids:
        return {"terms": []}
    candidates = (piece.strip() for piece in ids.split(",")[:50])
    terms = [
        {"id": term_id, "name": go_map.get(term_id, term_id), "definition": go_defs.get(term_id, "")}
        for term_id in candidates
        if term_id
    ]
    return {"terms": terms}
@app.post("/api/saliency")
async def compute_saliency(request: SaliencyRequest):
    """
    Compute per-residue gradient saliency for a protein sequence.

    Objective: the sum of the top-k predicted probabilities. Its gradient with
    respect to the ESM layer-6 residue representations is reduced to one L2
    norm per residue, then min-max normalised to [0, 1].

    The gradient is taken with torch.autograd.grad against those
    representations only, so nothing accumulates in the model parameters'
    ``.grad`` buffers across requests. If get_structure_pdb yields a structure
    (optionally using uniprot_id), the scores are also written into its
    B-factor column.

    Returns an ``{"error": ...}`` dict instead of raising.
    """
    sequence = re.sub(r"\s+", "", request.sequence.upper())
    if not sequence:
        return {"error": "Empty sequence"}
    if len(sequence) > 1200:
        return {"error": "Sequence too long for saliency (max 1200 aa)"}
    err = validate_sequence("query", sequence)
    if err:
        return {"error": err}
    taxon = (request.taxon or "auto").lower()
    t_apply = insect_temperature if taxon == "insect" else temperature
    try:
        _, _, tokens = batch_converter([("p", sequence)])
        tokens = tokens.to(device)
        L = len(sequence)
        # ESM parameters are frozen; run forward without grad, then re-attach
        # residue reps as a leaf tensor so we can compute d(obj)/d(reps).
        with torch.no_grad():
            out = esm_model(tokens, repr_layers=[6])
        residue_reps = out["representations"][6].detach().clone().requires_grad_(True)
        with torch.enable_grad():
            # Mean-pool (stays in graph from residue_reps leaf)
            emb = residue_reps[0, 1:L + 1].mean(0)
            # Build the MLP input in-graph (gradient-safe twin of _build_model_input).
            if model_uses_supp and supp_mu is not None:
                feats = compute_seq_features(sequence)
                n_emb = emb.shape[0]
                cols = []
                for j, c in enumerate(supp_cols):
                    mu_j = float(supp_mu[j])
                    sd_j = float(supp_sd[j]) + 1e-12
                    dim_val = None
                    # Like _build_model_input: f_Dim_<k> carries ESM dimension k,
                    # taken from the in-graph embedding so gradients flow through it.
                    if c.startswith('f_Dim_'):
                        try:
                            dim_idx = int(c.split('_')[-1])
                            if dim_idx < n_emb:
                                dim_val = emb[dim_idx]
                        except (ValueError, IndexError):
                            pass  # malformed column name: keep the default value
                    if dim_val is None:
                        cols.append(torch.tensor((feats.get(c, 0.0) - mu_j) / sd_j,
                                                 dtype=torch.float32, device=device))
                    else:
                        cols.append((dim_val - mu_j) / sd_j)
                s_vec = torch.stack(cols) if cols else torch.zeros(0, dtype=torch.float32, device=device)
                # flag tensor (no grad needed)
                flag = torch.ones(1, dtype=torch.float32, device=device)
                inp_full = torch.cat([emb, s_vec, flag]).unsqueeze(0)
                expected = model.fc_in.weight.shape[1]
                inp = inp_full[:, :expected]
            else:
                inp = emb.unsqueeze(0)
            logits = model(inp) / t_apply
            probs = torch.sigmoid(logits[0])
            # Objective: sum of top-k predicted probabilities (at least 5 terms,
            # never more than the model has labels).
            n_labels = probs.numel()
            k = min(request.top_k, int((probs > 0.3).sum().item()), n_labels)
            k = min(max(k, 5), n_labels)
            objective = probs.topk(k).values.sum()
            (reps_grad,) = torch.autograd.grad(objective, residue_reps, allow_unused=True)
        scores = [0.0] * L
        if reps_grad is not None:
            grad = reps_grad[0, 1:L + 1].norm(dim=-1).detach().cpu().numpy()
            mn, mx = grad.min(), grad.max()
            scores = ((grad - mn) / (mx - mn + 1e-8)).tolist()
        # Fetch PDB structure and write saliency into B-factors
        uid = (request.uniprot_id or "").upper().strip()
        uid_clean = uid if re.match(r'^[A-Z0-9]{4,10}$', uid) else None
        pdb_str, struct_source = await get_structure_pdb(sequence, uid_clean)
        pdb_with_saliency = None
        if pdb_str:
            pdb_with_saliency = write_saliency_to_bfactor(pdb_str, scores)
        return {
            "sequence_length": L,
            "per_residue_scores": scores,
            "taxon": taxon,
            "pdb_with_saliency": pdb_with_saliency,
            "structure_source": struct_source,
        }
    except Exception as e:
        return {"error": str(e)[:200]}
@app.post("/api/explainability")
async def compute_explainability(request: ExplainRequest):
    """
    Attribute the top-k prediction mass to the first (up to) 11 supplementary
    features via gradient x input.

    The ESM embedding is held constant (detached); only the first ``n_tracked``
    z-scored features receive gradients. The gradient is taken with
    torch.autograd.grad, so nothing accumulates in the model parameters'
    ``.grad`` buffers across requests.

    The response also includes:
      - per-residue feature maps (compute_per_residue_features) for 3D
        structure colouring;
      - AlphaFold entry metadata when ``uniprot_id`` looks like an accession,
        fetched with aiohttp so the event loop is not blocked.

    Returns an ``{"error": ...}`` dict instead of raising.
    """
    import numpy as np
    sequence = re.sub(r"\s+", "", request.sequence.upper())
    if not sequence:
        return {"error": "Empty sequence"}
    if len(sequence) > 1200:
        return {"error": "Sequence too long for explainability (max 1200 aa)"}
    err = validate_sequence("query", sequence)
    if err:
        return {"error": err}
    if not model_uses_supp or supp_mu is None:
        return {"error": "Feature importance requires a supplemented model (unified_v1 or later)"}
    taxon = (request.taxon or "auto").lower()
    t_apply = insect_temperature if taxon == "insect" else temperature
    try:
        emb = _get_esm_embedding(sequence).to(device)  # detached; not differentiated
        # Copy: the dict is filled in below and may be a shared/cached object.
        feats = dict(compute_seq_features(sequence))
        # Like _build_model_input: f_Dim_<k> columns carry ESM dimension k, so the
        # explained input is the same one the prediction endpoints score.
        emb_np = emb.detach().cpu().numpy()
        for c in supp_cols:
            if not c.startswith('f_Dim_'):
                continue
            try:
                dim_idx = int(c.split('_')[-1])
                if dim_idx < len(emb_np):
                    feats[c] = float(emb_np[dim_idx])
            except (ValueError, IndexError):
                pass  # malformed column name: keep the default value
        s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32)
        s_z = (s_vec - supp_mu) / (supp_sd + 1e-12)
        n_tracked = min(11, len(supp_cols))
        with torch.enable_grad():
            s_z_11 = torch.tensor(s_z[:n_tracked], requires_grad=True,
                                  dtype=torch.float32, device=device)
            s_z_rest = torch.tensor(s_z[n_tracked:], dtype=torch.float32, device=device)
            flag = torch.ones(1, dtype=torch.float32, device=device)
            inp_full = torch.cat([emb.detach(), s_z_11, s_z_rest, flag]).unsqueeze(0)
            expected = model.fc_in.weight.shape[1]
            inp = inp_full[:, :expected]
            logits = model(inp) / t_apply
            probs = torch.sigmoid(logits[0])
            # Top-k objective: at least 5 terms, never more than the model has labels.
            n_labels = probs.numel()
            k = min(request.top_k, int((probs > 0.3).sum().item()), n_labels)
            k = min(max(k, 5), n_labels)
            objective = probs.topk(k).values.sum()
            (feat_grad,) = torch.autograd.grad(objective, s_z_11, allow_unused=True)
        if feat_grad is None:
            return {"error": "Gradient computation failed — no gradient on supp features"}
        grad_np = feat_grad.detach().cpu().numpy()
        s_z_11np = s_z_11.detach().cpu().numpy()
        attribution = grad_np * s_z_11np  # signed grad × input
        per_residue_maps = compute_per_residue_features(sequence)
        feat_meta_by_key = {fm["key"]: fm for fm in FEATURE_META}
        features = []
        for i in range(n_tracked):
            col = supp_cols[i]
            meta = feat_meta_by_key.get(col, {"key": col, "label": col, "desc": col, "color": "#888888"})
            attr = float(attribution[i])
            features.append({
                "key": col,
                "label": meta["label"],
                "desc": meta["desc"],
                "color": meta["color"],
                "importance": round(attr, 4),
                "abs_importance": round(abs(attr), 4),
                "per_residue": per_residue_maps.get(col, [0.5] * len(sequence)),
            })
        features.sort(key=lambda x: x["abs_importance"], reverse=True)
        # Fetch AlphaFold structure metadata (best-effort, non-blocking).
        structure = None
        uid = (request.uniprot_id or "").upper().strip()
        if uid and re.match(r'^[A-Z0-9]{4,10}$', uid):
            try:
                timeout = aiohttp.ClientTimeout(total=8)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.get(
                        f"https://alphafold.ebi.ac.uk/api/prediction/{uid}",
                        headers={"Accept": "application/json", "User-Agent": "FABLE/1.0"},
                    ) as resp:
                        resp.raise_for_status()
                        payload = await resp.json(content_type=None)
                d = payload[0]
                structure = {
                    "found": True,
                    "accession": uid,
                    "cif_url": d.get("cifUrl", ""),
                    "organism": d.get("organismScientificName", ""),
                    "gene": d.get("gene", ""),
                    "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}",
                    "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}",
                }
            except Exception:
                structure = {"found": False, "accession": uid}
        return {
            "sequence_length": len(sequence),
            "features": features,
            "top_feature": features[0]["key"] if features else None,
            "structure": structure,
            "taxon": taxon,
        }
    except Exception as e:
        return {"error": str(e)[:300]}
@app.post("/api/infer_organism")
async def api_infer_organism(request: OrganismRequest):
    """
    Infer the organism taxon group of a sequence, or honour an explicit choice.

    Whitespace is removed, the sequence is upper-cased, and only its first 500
    residues are considered. A blank selection or "auto" counts as
    "no selection", and infer_organism receives None.
    """
    cleaned = re.sub(r"\s+", "", request.sequence.upper())[:500]
    if not cleaned:
        return {"taxon_group": "unknown", "confidence": 0.0, "method": "empty_sequence"}
    choice = (request.user_selection or "").strip().lower()
    override = None if choice in ("", "auto") else choice
    return infer_organism(cleaned, override)
@app.post("/api/identify")
async def api_identify(request: IdentifyRequest):
    """
    Resolve organism + UniProt accession from a sequence.

    Grounds the organism in the real UniProt entry when possible, either from
    the accession or from a peptide search plus exact-sequence verification.
    Otherwise it falls back to a calibrated sequence-similarity estimate.

    Powers both the organism badge and the auto-filled accession field. The
    sequence is cleaned and truncated to 1500 residues, and a blank/"auto"
    selection means no override. Best-effort: any failure from
    identify_protein is reported as a ``method="error"`` payload, never raised.
    """
    cleaned = _clean_seq(request.sequence)[:1500]
    if not cleaned:
        return {"taxon_group": "unknown", "confidence": 0.0, "method": "empty_sequence"}
    choice = (request.user_selection or "").strip().lower()
    override = None if choice in ("", "auto") else choice
    try:
        return await identify_protein(cleaned, override, request.accession or None)
    except Exception as exc:
        return {"taxon_group": "unknown", "confidence": 0.0,
                "method": "error", "error": str(exc)[:120]}
if __name__ == "__main__":
    import uvicorn
    # 7860 is the port the Hugging Face Space serves on. An optional PORT
    # environment variable overrides it for local or alternative deployments.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))