---
language: en
license: mit
tags:
- graph-ml
- bioinformatics
- precision-medicine
- explainable-ai
- reinforcement-learning
datasets:
- FuhaiLiAiLab/Target-QA
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: GALAX
  results:
  - task:
      type: text-generation
      name: Target Prioritization
    dataset:
      name: Target-QA
      type: FuhaiLiAiLab/Target-QA
    metrics:
    - type: precision
      value: 0.5472
    - type: recall
      value: 0.5332
    - type: hit@10
      value: 0.8815
    - type: hit@5
      value: 0.9249
---

# GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine
[GitHub](https://github.com/FuhaiLiAiLab/GALAX) · [Hugging Face Model](https://huggingface.co/FuhaiLiAiLab/GALAX) · [Hugging Face Dataset](https://huggingface.co/datasets/FuhaiLiAiLab/Target-QA) · [arXiv](https://arxiv.org/abs/2509.20935)
---

## 🧩 Model Overview

![GALAX Overall Architecture](https://github.com/FuhaiLiAiLab/GALAX/blob/main/Figures/Figure3.png?raw=true)

**GALAX** is a graph-augmented language model designed for explainable target prioritization in precision medicine. It combines three key components:

- **LLaMA3-8B-Instruct** as the language backbone, further adapted with the BioMedGraphica corpus and fine-tuned on Target-QA.
- **Graph Attention Network (GAT)** pretrained on integrated multi-omics data and BioMedGraphica knowledge graphs.
- **A reinforcement-guided subgraph generator** that enables interpretable reasoning by constructing biologically meaningful subgraphs from multi-omics and knowledge-graph signals.

By jointly leveraging **multi-omics features**, **protein–protein interactions**, and **disease–target associations**, GALAX provides an interpretable framework for **CRISPR target prioritization** across diverse cancer cell lines. To support benchmarking and reproducibility, we also introduce the **[Target-QA dataset](https://huggingface.co/datasets/FuhaiLiAiLab/Target-QA)**.

---

## 🚀 How to Use

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import os
import torch

# 1. Load the GALAX language model
model_id = "FuhaiLiAiLab/GALAX"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto"
)

# 2. Access the graph foundation model checkpoint
repo_path = snapshot_download(model_id)
combined_model_path = os.path.join(repo_path, "best_combined_model.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
best_combined_model = torch.load(combined_model_path, map_location=device)
```
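Once loaded, the language model can be queried with the standard `transformers` generation API. The snippet below is a minimal sketch with a purely illustrative prompt; the exact Target-QA prompt format (cell-line context, omics features, knowledge-graph evidence) and the full reinforcement-guided subgraph pipeline are implemented in the [GitHub repository](https://github.com/FuhaiLiAiLab/GALAX).

```python
# Minimal generation sketch. The prompt below is only illustrative; GALAX is
# fine-tuned on structured Target-QA prompts (see the GitHub repo for the format).
prompt = "Identify the top CRISPR knockout targets for a LUAD cell line."

inputs = tokenizer(prompt, return_tensors="pt").to(lm_model.device)
with torch.no_grad():
    output_ids = lm_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding for reproducible output
    )

# Decode only the newly generated tokens (skip the echoed prompt)
new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```

The `best_combined_model.pt` checkpoint holds the pretrained graph components; wiring them into the full reinforcement-guided subgraph reasoning pipeline follows the training and inference scripts in the GitHub repository.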
---

## ⚙️ Experimental Setup

- **Backbone LM:** LLaMA3-8B-Instruct (QA-tuned).
- **Graph Encoder:** BioBERT-v1.1 embeddings + GAT with edge masking.
- **Training:** Adam optimizer on 2× NVIDIA H100 (80GB).
- **Top features per omics modality:** K = 10.
- **Subgraph rollout depth:** L = 5, candidate nodes η = 20.
- **Evaluation:** Precision, Recall, F1, Jaccard, Hit@5, Hit@10.

---

## 📊 Results

GALAX consistently outperforms baselines and ablation variants.

- **Overall Precision:** 0.5472
- **Overall Recall:** 0.5332
- **Hit@10:** 0.8815
- **Hit@5:** 0.9249

**Table 1. Precision and Recall across datasets**

| Model | Overall Precision ↑ | Overall Recall ↑ | LUAD Precision ↑ | LUAD Recall ↑ | BRCA Precision ↑ | BRCA Recall ↑ |
|-------------------------|---------------------|------------------|------------------|---------------|------------------|---------------|
| M2T | 0.0016 | 0.0011 | 0.0020 | 0.0014 | 0.0000 | 0.0000 |
| GAT | 0.0006 ± 0.0000 | 0.0006 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0033 ± 0.0000 | 0.0033 ± 0.0000 |
| L3 + Omics | 0.0071 ± 0.0032 | 0.0013 ± 0.0002 | 0.0079 ± 0.0137 | 0.0005 ± 0.0008 | 0.0020 ± 0.0035 | 0.0017 ± 0.0029 |
| L3 + Omics + KG | 0.0125 ± 0.0032 | 0.0029 ± 0.0003 | 0.0014 ± 0.0025 | 0.0010 ± 0.0017 | 0.0073 ± 0.0068 | 0.0033 ± 0.0029 |
| L3-FT(Med) + Omics | 0.0179 ± 0.0045 | 0.0133 ± 0.0064 | 0.0091 ± 0.0018 | 0.0105 ± 0.0044 | 0.0110 ± 0.0086 | 0.0106 ± 0.0075 |
| L3-FT(Med) + Omics + KG | 0.0158 ± 0.0030 | 0.0058 ± 0.0011 | 0.0081 ± 0.0071 | 0.0024 ± 0.0017 | 0.0149 ± 0.0057 | 0.0050 ± 0.0000 |
| L3-FT(QA) + Omics | 0.5250 ± 0.0282 | 0.4959 ± 0.0435 | 0.5201 ± 0.0408 | 0.4905 ± 0.0532 | 0.5074 ± 0.0498 | 0.4856 ± 0.0570 |
| L3-FT(QA) + Omics + KG | 0.5185 ± 0.0240 | 0.4908 ± 0.0402 | 0.5214 ± 0.0242 | 0.4952 ± 0.0432 | 0.4856 ± 0.0395 | 0.4656 ± 0.0436 |
| G-Retriever + pre-GAT | 0.4763 ± 0.0004 | 0.3929 ± 0.0063 | 0.4642 ± 0.0181 | 0.3881 ± 0.0264 | 0.4414 ± 0.0099 | 0.3772 ± 0.0010 |
| **GALAX** | **0.5472 ± 0.0053** | **0.5332 ± 0.0031** | **0.5345 ± 0.0185** | **0.5157 ± 0.0043** | **0.5608 ± 0.0031** | **0.5533 ± 0.0033** |

**Table 2. Hit@10 and Hit@5 across datasets**

| Model | Overall Hit@10 ↑ | Overall Hit@5 ↑ | LUAD Hit@10 ↑ | LUAD Hit@5 ↑ | BRCA Hit@10 ↑ | BRCA Hit@5 ↑ |
|-------------------------|------------------|-----------------|---------------|--------------|---------------|--------------|
| M2T | 0.0029 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| GAT | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics | 0.0021 ± 0.0037 | 0.0032 ± 0.0055 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics + KG | 0.0122 ± 0.0033 | 0.0085 ± 0.0037 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0056 ± 0.0096 | 0.0111 ± 0.0192 |
| L3-FT(Med) + Omics | 0.0122 ± 0.0072 | 0.0116 ± 0.0097 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(Med) + Omics + KG | 0.0132 ± 0.0040 | 0.0106 ± 0.0048 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(QA) + Omics | 0.8693 ± 0.0157 | 0.8889 ± 0.0168 | 0.8667 ± 0.0218 | 0.8476 ± 0.0165 | 0.8389 ± 0.0096 | 0.8889 ± 0.0509 |
| L3-FT(QA) + Omics + KG | 0.8529 ± 0.0153 | 0.8794 ± 0.0114 | 0.8048 ± 0.0541 | 0.7905 ± 0.0436 | 0.8222 ± 0.0347 | 0.8778 ± 0.0192 |
| G-Retriever + pre-GAT | 0.8550 ± 0.0046 | 0.8804 ± 0.0037 | 0.8524 ± 0.0165 | 0.8857 ± 0.0000 | **0.8667 ± 0.0000** | 0.8667 ± 0.0000 |
| **GALAX** | **0.8815 ± 0.0033** | **0.9249 ± 0.0048** | **0.8810 ± 0.0082** | **0.9238 ± 0.0436** | 0.8500 ± 0.0441 | **0.8889 ± 0.0839** |
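For reference, the set-level metrics reported above can be computed along the following lines. This is a minimal sketch under the assumption that model outputs are parsed into a ranked list of gene symbols and compared against the ground-truth target set; the exact aggregation used in the paper (e.g. averaging across cell lines and seeds) may differ.

```python
from typing import List, Set, Tuple


def precision_recall(pred: Set[str], truth: Set[str]) -> Tuple[float, float]:
    """Set-level precision and recall between predicted and ground-truth genes."""
    tp = len(pred & truth)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(truth) if truth else 0.0
    return precision, recall


def hit_at_k(ranked_pred: List[str], truth: Set[str], k: int) -> float:
    """1.0 if any of the top-k ranked predictions is a true target, else 0.0."""
    return 1.0 if any(gene in truth for gene in ranked_pred[:k]) else 0.0


# Hypothetical predictions for a single cell line (illustrative values only)
ranked = ["EGFR", "KRAS", "TP53", "MYC", "STK11"]
truth = {"EGFR", "STK11", "KEAP1"}

p, r = precision_recall(set(ranked), truth)
print(f"precision={p:.3f} recall={r:.3f} hit@5={hit_at_k(ranked, truth, 5)}")
```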
---

## 🔬 Intended Uses

- **Research use only**
- Benchmarking **graph-language foundation models** in target prioritization
- Target prioritization in **cancer biology**

---

## 📜 Citation

If you use this model, please cite:

```bibtex
@article{zhang2025galax,
  title   = {GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine},
  author  = {Zhang, Heming and Huang, Di and Li, Wenyu and Province, Michael and Chen, Yixin and Payne, Philip and Li, Fuhai},
  journal = {arXiv preprint arXiv:2509.20935},
  year    = {2025},
  doi     = {10.48550/arXiv.2509.20935},
  url     = {https://arxiv.org/abs/2509.20935}
}
```