---
language: en
license: mit
tags:
- graph-ml
- bioinformatics
- precision-medicine
- explainable-ai
- reinforcement-learning
datasets:
- FuhaiLiAiLab/Target-QA
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: GALAX
results:
- task:
type: text-generation
name: Target Prioritization
dataset:
name: Target-QA
type: FuhaiLiAiLab/Target-QA
metrics:
- type: precision
value: 0.5472
- type: recall
value: 0.5332
- type: hit@10
value: 0.8815
- type: hit@5
value: 0.9249
---
# GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine
---
## 🧩 Model Overview

**GALAX** is a graph-augmented language model designed for explainable target prioritization in precision medicine. It combines three key components:
- **LLaMA3-8B-Instruct** as the language backbone, further adapted with the BioMedGraphica corpus and fine-tuned on Target-QA.
- **Graph Attention Network (GAT)** pretrained on integrated multi-omics data and BioMedGraphica knowledge graphs.
- **A reinforcement-guided subgraph generator** that enables interpretable reasoning by constructing biologically meaningful subgraphs from multi-omics and knowledge graph signals.

By jointly leveraging **multi-omics features**, **protein–protein interactions**, and **disease–target associations**, GALAX provides an interpretable framework for **CRISPR target prioritization** across diverse cancer cell lines. To support benchmarking and reproducibility, we also introduce the **[Target-QA dataset](https://huggingface.co/datasets/FuhaiLiAiLab/Target-QA)**.
---
## 🚀 How to Use
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import os
import torch

# 1. Load the GALAX language model
model_id = "FuhaiLiAiLab/GALAX"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)

# 2. Download the repository and load the pretrained graph foundation model checkpoint
repo_path = snapshot_download(model_id)
combined_model_path = os.path.join(repo_path, "best_combined_model.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
best_combined_model = torch.load(combined_model_path, map_location=device)
```
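The model answers Target-QA prompts as free text, so the ranked targets have to be parsed out of the generated string. The exact answer format depends on the Target-QA prompt template; the helper below is a hedged sketch that assumes a comma- or newline-separated list of uppercase HGNC-style gene symbols, and the function name `parse_ranked_targets` is illustrative rather than part of the released code:

```python
import re


def parse_ranked_targets(answer: str, top_k: int = 10) -> list[str]:
    """Extract an ordered, de-duplicated list of gene symbols from a generated answer.

    Assumes the model emits HGNC-style symbols (e.g. "TP53, KRAS, EGFR");
    adapt the pattern to the actual Target-QA output format.
    """
    # Match uppercase alphanumeric tokens of length 2-10 (typical gene symbols)
    symbols = re.findall(r"\b[A-Z][A-Z0-9-]{1,9}\b", answer)
    seen: set[str] = set()
    ranked: list[str] = []
    for s in symbols:
        if s not in seen:
            seen.add(s)
            ranked.append(s)
    return ranked[:top_k]
```

For example, `parse_ranked_targets("Top targets: TP53, KRAS, EGFR")` yields `["TP53", "KRAS", "EGFR"]`.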
---
## ⚙️ Experimental Setup
- **Backbone LM:** LLaMA3-8B-Instruct (QA-tuned).
- **Graph Encoder:** BioBERT-v1.1 embeddings + GAT with edge masking.
- **Training:** Adam optimizer on 2× NVIDIA H100 (80GB).
- **Top features per omics modality:** K = 10.
- **Subgraph rollout depth:** L = 5, candidate nodes η = 20.
- **Evaluation:** Precision, Recall, F1, Jaccard, Hit@5, Hit@10.
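
For reference, the set-level metrics listed above can be computed from a predicted ranked list and a ground-truth target set roughly as follows. This is a minimal sketch, not the official evaluation script, which may differ in normalization and tie handling:

```python
def target_metrics(ranked_preds: list[str], truth: set[str], k: int = 10) -> dict:
    """Precision/Recall/F1/Jaccard over the full prediction set, plus Hit@k.

    Hit@k is 1.0 if any of the top-k ranked predictions is a true target.
    """
    pred_set = set(ranked_preds)
    tp = len(pred_set & truth)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(truth) if truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    jaccard = tp / len(pred_set | truth) if pred_set | truth else 0.0
    hit_at_k = float(any(g in truth for g in ranked_preds[:k]))
    return {"precision": precision, "recall": recall, "f1": f1,
            "jaccard": jaccard, f"hit@{k}": hit_at_k}
```

For example, with predictions `["TP53", "KRAS", "BRAF", "MYC"]` against the truth set `{"TP53", "KRAS", "EGFR"}`, precision is 0.5 and recall is 2/3.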
---
## 📊 Results
GALAX consistently outperforms baselines and ablation variants.
- **Overall Precision:** 0.5472
- **Overall Recall:** 0.5332
- **Hit@10:** 0.8815
- **Hit@5:** 0.9249

**Table 1. Precision and Recall across datasets**

| Model | Overall Precision ↑ | Overall Recall ↑ | LUAD Precision ↑ | LUAD Recall ↑ | BRCA Precision ↑ | BRCA Recall ↑ |
|-------------------------|---------------------|------------------|------------------|---------------|------------------|---------------|
| M2T | 0.0016 | 0.0011 | 0.0020 | 0.0014 | 0.0000 | 0.0000 |
| GAT | 0.0006 ± 0.0000 | 0.0006 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0033 ± 0.0000 | 0.0033 ± 0.0000 |
| L3 + Omics | 0.0071 ± 0.0032 | 0.0013 ± 0.0002 | 0.0079 ± 0.0137 | 0.0005 ± 0.0008 | 0.0020 ± 0.0035 | 0.0017 ± 0.0029 |
| L3 + Omics + KG | 0.0125 ± 0.0032 | 0.0029 ± 0.0003 | 0.0014 ± 0.0025 | 0.0010 ± 0.0017 | 0.0073 ± 0.0068 | 0.0033 ± 0.0029 |
| L3-FT(Med) + Omics | 0.0179 ± 0.0045 | 0.0133 ± 0.0064 | 0.0091 ± 0.0018 | 0.0105 ± 0.0044 | 0.0110 ± 0.0086 | 0.0106 ± 0.0075 |
| L3-FT(Med) + Omics + KG | 0.0158 ± 0.0030 | 0.0058 ± 0.0011 | 0.0081 ± 0.0071 | 0.0024 ± 0.0017 | 0.0149 ± 0.0057 | 0.0050 ± 0.0000 |
| L3-FT(QA) + Omics | 0.5250 ± 0.0282 | 0.4959 ± 0.0435 | 0.5201 ± 0.0408 | 0.4905 ± 0.0532 | 0.5074 ± 0.0498 | 0.4856 ± 0.0570 |
| L3-FT(QA) + Omics + KG | 0.5185 ± 0.0240 | 0.4908 ± 0.0402 | 0.5214 ± 0.0242 | 0.4952 ± 0.0432 | 0.4856 ± 0.0395 | 0.4656 ± 0.0436 |
| G-Retriever + pre-GAT | 0.4763 ± 0.0004 | 0.3929 ± 0.0063 | 0.4642 ± 0.0181 | 0.3881 ± 0.0264 | 0.4414 ± 0.0099 | 0.3772 ± 0.0010 |
| **GALAX** | **0.5472 ± 0.0053** | **0.5332 ± 0.0031** | **0.5345 ± 0.0185** | **0.5157 ± 0.0043** | **0.5608 ± 0.0031** | **0.5533 ± 0.0033** |

**Table 2. Hit@10 and Hit@5 across datasets**

| Model | Overall Hit@10 ↑ | Overall Hit@5 ↑ | LUAD Hit@10 ↑ | LUAD Hit@5 ↑ | BRCA Hit@10 ↑ | BRCA Hit@5 ↑ |
|-------------------------|------------------|-----------------|---------------|--------------|---------------|--------------|
| M2T | 0.0029 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| GAT | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics | 0.0021 ± 0.0037 | 0.0032 ± 0.0055 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics + KG | 0.0122 ± 0.0033 | 0.0085 ± 0.0037 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0056 ± 0.0096 | 0.0111 ± 0.0192 |
| L3-FT(Med) + Omics | 0.0122 ± 0.0072 | 0.0116 ± 0.0097 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(Med) + Omics + KG | 0.0132 ± 0.0040 | 0.0106 ± 0.0048 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(QA) + Omics | 0.8693 ± 0.0157 | 0.8889 ± 0.0168 | 0.8667 ± 0.0218 | 0.8476 ± 0.0165 | 0.8389 ± 0.0096 | 0.8889 ± 0.0509 |
| L3-FT(QA) + Omics + KG | 0.8529 ± 0.0153 | 0.8794 ± 0.0114 | 0.8048 ± 0.0541 | 0.7905 ± 0.0436 | 0.8222 ± 0.0347 | 0.8778 ± 0.0192 |
| G-Retriever + pre-GAT | 0.8550 ± 0.0046 | 0.8804 ± 0.0037 | 0.8524 ± 0.0165 | 0.8857 ± 0.0000 | **0.8667 ± 0.0000** | 0.8667 ± 0.0000 |
| **GALAX** | **0.8815 ± 0.0033** | **0.9249 ± 0.0048** | **0.8810 ± 0.0082** | **0.9238 ± 0.0436** | 0.8500 ± 0.0441 | **0.8889 ± 0.0839** |
---
## 🔬 Intended Uses
- **Research use only**
- Benchmarking **graph-language foundation models** on target prioritization
- Target prioritization in **cancer biology**
---
## 📜 Citation
If you use this model, please cite:
```bibtex
@article{zhang2025galax,
title = {GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine},
author = {Zhang, Heming and Huang, Di and Li, Wenyu and Province, Michael and Chen, Yixin and Payne, Philip and Li, Fuhai},
journal = {arXiv preprint arXiv:2509.20935},
year = {2025},
doi = {10.48550/arXiv.2509.20935},
url = {https://arxiv.org/abs/2509.20935}
}
```