---
language: en
license: mit
tags:
- graph-ml
- bioinformatics
- precision-medicine
- explainable-ai
- reinforcement-learning
datasets:
- FuhaiLiAiLab/Target-QA
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: GALAX
  results:
  - task:
      type: text-generation
      name: Target Prioritization
    dataset:
      name: Target-QA
      type: FuhaiLiAiLab/Target-QA
    metrics:
    - type: precision
      value: 0.5472
    - type: recall
      value: 0.5332
    - type: hit@10
      value: 0.8815
    - type: hit@5
      value: 0.9249
---
# GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine

## 🧩 Model Overview
GALAX is a graph-augmented language model designed for explainable target prioritization in precision medicine. It combines three key components:
- LLaMA3-8B-Instruct as the language backbone, further adapted with the BioMedGraphica corpus and fine-tuned on Target-QA.
- Graph Attention Network (GAT) pretrained on integrated multi-omics data and BioMedGraphica knowledge graphs.
- A reinforcement-guided subgraph generator that enables interpretable reasoning by constructing biologically meaningful subgraphs from multi-omics and knowledge graph signals.
By jointly leveraging multi-omics features, protein–protein interactions, and disease–target associations, GALAX provides an interpretable framework for CRISPR target prioritization across diverse cancer cell lines. To support benchmarking and reproducibility, we also introduce the Target-QA dataset.
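To make the division of labor between the three components concrete, the sketch below mocks the flow in plain Python. It is not the released implementation: every function, signature, and gene name is an illustrative placeholder.

```python
# Conceptual sketch of the three-stage GALAX pipeline described above.
# NOT the released implementation — all names below are placeholders.

def encode_graph(omics_features, kg_edges):
    """Stage 1 stand-in: a pretrained GAT would embed nodes from omics + KG signals."""
    return {node: float(len(neighbors)) for node, neighbors in kg_edges.items()}

def build_subgraph(embeddings, n_candidates=20):
    """Stage 2 stand-in: the RL-guided generator would grow an explanatory subgraph."""
    return sorted(embeddings, key=embeddings.get, reverse=True)[:n_candidates]

def rank_targets(question, subgraph):
    """Stage 3 stand-in: the QA-tuned LM would reason over the subgraph."""
    return subgraph  # the real model emits an explained, prioritized target list

kg = {"TP53": ["MDM2", "EGFR"], "MDM2": ["TP53"], "EGFR": ["TP53"]}
embeddings = encode_graph({"TP53": [0.1, 0.4]}, kg)
print(rank_targets("Prioritize CRISPR targets for this cell line.",
                   build_subgraph(embeddings)))
```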
## 🚀 How to Use

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import os
import torch

# 1. Load the GALAX language model
model_id = "FuhaiLiAiLab/GALAX"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)

# 2. Download the repo snapshot and load the graph foundation model checkpoint
repo_path = snapshot_download(model_id)
combined_model_path = os.path.join(repo_path, "best_combined_model.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
best_combined_model = torch.load(combined_model_path, map_location=device)
```
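Once loaded, the LM can be queried like any `transformers` causal LM. The snippet below is a minimal, hypothetical usage sketch: the prompt is an illustrative placeholder rather than the Target-QA prompt format, and `best_combined_model` above is a raw checkpoint whose structure (full module vs. state dict) should be inspected before use.

```python
# Hypothetical usage sketch — the actual Target-QA prompt template may differ.
prompt = "Prioritize CRISPR targets for a LUAD cell line given the evidence below."
inputs = tokenizer(prompt, return_tensors="pt").to(lm_model.device)
outputs = lm_model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```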
## ⚙️ Experimental Setup
- Backbone LM: LLaMA3-8B-Instruct (QA-tuned).
- Graph Encoder: BioBERT-v1.1 embeddings + GAT with edge masking.
- Training: Adam optimizer on 2× NVIDIA H100 (80GB).
- Top features per omics modality: K = 10.
- Subgraph rollout depth: L = 5, candidate nodes η = 20.
- Evaluation: Precision, Recall, F1, Jaccard, Hit@5, Hit@10.
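As a reference for the metric definitions, the sketch below computes the listed set-based scores over a ranked prediction list and a ground-truth set. This is a generic implementation under those assumptions, not necessarily the paper's exact evaluation protocol.

```python
# Generic sketch of the listed metrics. Assumes predictions are a ranked list
# of gene symbols and the ground truth is a set; the paper's exact protocol
# (e.g., tie handling, per-cell-line averaging) may differ.
def evaluate(ranked_preds, truth):
    pred_set = set(ranked_preds)
    tp = len(pred_set & truth)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(truth) if truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    jaccard = tp / len(pred_set | truth) if pred_set | truth else 0.0

    def hit_at(k):  # 1.0 if any of the top-k predictions is a true target
        return float(any(g in truth for g in ranked_preds[:k]))

    return {"precision": precision, "recall": recall, "f1": f1,
            "jaccard": jaccard, "hit@5": hit_at(5), "hit@10": hit_at(10)}

print(evaluate(["KRAS", "TP53", "EGFR"], {"KRAS", "MYC"}))
```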
## 📊 Results
GALAX consistently outperforms baselines and ablation variants.
- Overall Precision: 0.5472
- Overall Recall: 0.5332
- Hit@10: 0.8815
- Hit@5: 0.9249
**Table 1. Precision and Recall across datasets**

| Model | Overall Precision ↑ | Overall Recall ↑ | LUAD Precision ↑ | LUAD Recall ↑ | BRCA Precision ↑ | BRCA Recall ↑ |
|---|---|---|---|---|---|---|
| M2T | 0.0016 | 0.0011 | 0.0020 | 0.0014 | 0.0000 | 0.0000 |
| GAT | 0.0006 ± 0.0000 | 0.0006 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0033 ± 0.0000 | 0.0033 ± 0.0000 |
| L3 + Omics | 0.0071 ± 0.0032 | 0.0013 ± 0.0002 | 0.0079 ± 0.0137 | 0.0005 ± 0.0008 | 0.0020 ± 0.0035 | 0.0017 ± 0.0029 |
| L3 + Omics + KG | 0.0125 ± 0.0032 | 0.0029 ± 0.0003 | 0.0014 ± 0.0025 | 0.0010 ± 0.0017 | 0.0073 ± 0.0068 | 0.0033 ± 0.0029 |
| L3-FT(Med) + Omics | 0.0179 ± 0.0045 | 0.0133 ± 0.0064 | 0.0091 ± 0.0018 | 0.0105 ± 0.0044 | 0.0110 ± 0.0086 | 0.0106 ± 0.0075 |
| L3-FT(Med) + Omics + KG | 0.0158 ± 0.0030 | 0.0058 ± 0.0011 | 0.0081 ± 0.0071 | 0.0024 ± 0.0017 | 0.0149 ± 0.0057 | 0.0050 ± 0.0000 |
| L3-FT(QA) + Omics | 0.5250 ± 0.0282 | 0.4959 ± 0.0435 | 0.5201 ± 0.0408 | 0.4905 ± 0.0532 | 0.5074 ± 0.0498 | 0.4856 ± 0.0570 |
| L3-FT(QA) + Omics + KG | 0.5185 ± 0.0240 | 0.4908 ± 0.0402 | 0.5214 ± 0.0242 | 0.4952 ± 0.0432 | 0.4856 ± 0.0395 | 0.4656 ± 0.0436 |
| G-Retriever + pre-GAT | 0.4763 ± 0.0004 | 0.3929 ± 0.0063 | 0.4642 ± 0.0181 | 0.3881 ± 0.0264 | 0.4414 ± 0.0099 | 0.3772 ± 0.0010 |
| GALAX | 0.5472 ± 0.0053 | 0.5332 ± 0.0031 | 0.5345 ± 0.0185 | 0.5157 ± 0.0043 | 0.5608 ± 0.0031 | 0.5533 ± 0.0033 |
**Table 2. Hit@10 and Hit@5 across datasets**

| Model | Overall Hit@10 ↑ | Overall Hit@5 ↑ | LUAD Hit@10 ↑ | LUAD Hit@5 ↑ | BRCA Hit@10 ↑ | BRCA Hit@5 ↑ |
|---|---|---|---|---|---|---|
| M2T | 0.0029 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| GAT | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics | 0.0021 ± 0.0037 | 0.0032 ± 0.0055 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics + KG | 0.0122 ± 0.0033 | 0.0085 ± 0.0037 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0056 ± 0.0096 | 0.0111 ± 0.0192 |
| L3-FT(Med) + Omics | 0.0122 ± 0.0072 | 0.0116 ± 0.0097 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(Med) + Omics + KG | 0.0132 ± 0.0040 | 0.0106 ± 0.0048 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(QA) + Omics | 0.8693 ± 0.0157 | 0.8889 ± 0.0168 | 0.8667 ± 0.0218 | 0.8476 ± 0.0165 | 0.8389 ± 0.0096 | 0.8889 ± 0.0509 |
| L3-FT(QA) + Omics + KG | 0.8529 ± 0.0153 | 0.8794 ± 0.0114 | 0.8048 ± 0.0541 | 0.7905 ± 0.0436 | 0.8222 ± 0.0347 | 0.8778 ± 0.0192 |
| G-Retriever + pre-GAT | 0.8550 ± 0.0046 | 0.8804 ± 0.0037 | 0.8524 ± 0.0165 | 0.8857 ± 0.0000 | 0.8667 ± 0.0000 | 0.8667 ± 0.0000 |
| GALAX | 0.8815 ± 0.0033 | 0.9249 ± 0.0048 | 0.8810 ± 0.0082 | 0.9238 ± 0.0436 | 0.8500 ± 0.0441 | 0.8889 ± 0.0839 |
## 🔬 Intended Uses

- Research use only
- Benchmarking graph-language foundation models on target prioritization
- Target prioritization in cancer biology
## 📜 Citation

If you use this model, please cite:

```bibtex
@article{zhang2025galax,
  title   = {GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine},
  author  = {Zhang, Heming and Huang, Di and Li, Wenyu and Province, Michael and Chen, Yixin and Payne, Philip and Li, Fuhai},
  journal = {arXiv preprint arXiv:2509.20935},
  year    = {2025},
  doi     = {10.48550/arXiv.2509.20935},
  url     = {https://arxiv.org/abs/2509.20935}
}
```
