---
language: en
license: mit
tags:
  - graph-ml
  - bioinformatics
  - precision-medicine
  - explainable-ai
  - reinforcement-learning
datasets:
  - FuhaiLiAiLab/Target-QA
library_name: transformers
pipeline_tag: text-generation
model-index:
  - name: GALAX
    results:
      - task:
          type: text-generation
          name: Target Prioritization
        dataset:
          name: Target-QA
          type: FuhaiLiAiLab/Target-QA
        metrics:
          - type: precision
            value: 0.5472
          - type: recall
            value: 0.5332
          - type: hit@10
            value: 0.8815
          - type: hit@5
            value: 0.9249
---

# GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine

## 🧩 Model Overview

*Figure: GALAX overall architecture.*

GALAX is a graph-augmented language model designed for explainable target prioritization in precision medicine. It combines three key components:

- LLaMA3-8B-Instruct as the language backbone, further adapted on the BioMedGraphica corpus and fine-tuned on Target-QA.
- A Graph Attention Network (GAT) pretrained on integrated multi-omics data and BioMedGraphica knowledge graphs.
- A reinforcement-guided subgraph generator that enables interpretable reasoning by constructing biologically meaningful subgraphs from multi-omics and knowledge-graph signals.

By jointly leveraging multi-omics features, protein–protein interactions, and disease–target associations, GALAX provides an interpretable framework for CRISPR target prioritization across diverse cancer cell lines. To support benchmarking and reproducibility, we also introduce the Target-QA dataset.
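
The released checkpoint pairs the language backbone with the pretrained graph encoder. As a rough illustration of how a GAT-derived graph representation can be projected into an LM's embedding space and used as a soft "graph token", here is a minimal sketch; the class name, dimensions, and mean-pooling choice are illustrative assumptions, not the released GALAX implementation.

```python
# Minimal sketch (not the released GALAX code): pool a GAT encoding of a
# biomedical subgraph and project it into the LM hidden dimension.
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv

class GraphPrefixEncoder(nn.Module):
    def __init__(self, node_dim: int, hidden_dim: int, lm_hidden_dim: int, heads: int = 4):
        super().__init__()
        self.gat1 = GATConv(node_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=1)
        self.proj = nn.Linear(hidden_dim, lm_hidden_dim)  # map to the LM embedding size

    def forward(self, x, edge_index):
        h = torch.relu(self.gat1(x, edge_index))
        h = torch.relu(self.gat2(h, edge_index))
        graph_repr = h.mean(dim=0)      # simple mean pooling over nodes
        return self.proj(graph_repr)    # one soft "graph token" for the LM

# Toy example: 5 nodes with 768-d BioBERT-style features projected to a 4096-d LM space
encoder = GraphPrefixEncoder(node_dim=768, hidden_dim=256, lm_hidden_dim=4096)
x = torch.randn(5, 768)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
graph_token = encoder(x, edge_index)    # shape: (4096,)
```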


## 🚀 How to Use

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import os, torch

# 1. Load the GALAX language model
model_id = "FuhaiLiAiLab/GALAX"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto"
)

# 2. Access the graph foundation model checkpoint
repo_path = snapshot_download(model_id)
combined_model_path = os.path.join(repo_path, "best_combined_model.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
best_combined_model = torch.load(combined_model_path, map_location=device)
```
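
Once the language model is loaded, you can query it directly. The prompt and generation settings below are illustrative only; the exact Target-QA prompt format used for fine-tuning may differ.

```python
# Illustrative prompt; not necessarily the exact Target-QA template.
prompt = "Which genes are the most promising CRISPR targets for this LUAD cell line given its multi-omics profile?"
inputs = tokenizer(prompt, return_tensors="pt").to(lm_model.device)

with torch.no_grad():
    output_ids = lm_model.generate(**inputs, max_new_tokens=256, do_sample=False)

# Decode only the newly generated tokens
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```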

## ⚙️ Experimental Setup

- Backbone LM: LLaMA3-8B-Instruct (QA-tuned).
- Graph encoder: BioBERT-v1.1 embeddings + GAT with edge masking.
- Training: Adam optimizer on 2× NVIDIA H100 (80GB).
- Top features per omics modality: K = 10.
- Subgraph rollout depth: L = 5, candidate nodes η = 20.
- Evaluation: Precision, Recall, F1, Jaccard, Hit@5, Hit@10 (see the metric sketch below).
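
As a rough reference, the sketch below computes these set-based metrics for a predicted gene list against the ground-truth targets; the exact evaluation script may define them differently (e.g. how Hit@K treats ranking).

```python
# Hedged sketch of the set-based metrics; definitions may differ from the paper's evaluation script.
def set_metrics(predicted: list[str], ground_truth: set[str], k_values=(5, 10)):
    pred_set = set(predicted)
    tp = len(pred_set & ground_truth)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(ground_truth) if ground_truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    jaccard = tp / len(pred_set | ground_truth) if (pred_set | ground_truth) else 0.0
    # Hit@K here: at least one of the top-K ranked predictions is a true target.
    hits = {f"hit@{k}": float(any(g in ground_truth for g in predicted[:k])) for k in k_values}
    return {"precision": precision, "recall": recall, "f1": f1, "jaccard": jaccard, **hits}

print(set_metrics(["TP53", "EGFR", "KRAS"], {"EGFR", "ALK"}))
```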

## 📊 Results

GALAX consistently outperforms the baselines and ablation variants on Target-QA.

- Overall Precision: 0.5472
- Overall Recall: 0.5332
- Overall Hit@10: 0.8815
- Overall Hit@5: 0.9249

**Table 1. Precision and Recall across datasets**

| Model | Overall Precision ↑ | Overall Recall ↑ | LUAD Precision ↑ | LUAD Recall ↑ | BRCA Precision ↑ | BRCA Recall ↑ |
|---|---|---|---|---|---|---|
| M2T | 0.0016 | 0.0011 | 0.0020 | 0.0014 | 0.0000 | 0.0000 |
| GAT | 0.0006 ± 0.0000 | 0.0006 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0033 ± 0.0000 | 0.0033 ± 0.0000 |
| L3 + Omics | 0.0071 ± 0.0032 | 0.0013 ± 0.0002 | 0.0079 ± 0.0137 | 0.0005 ± 0.0008 | 0.0020 ± 0.0035 | 0.0017 ± 0.0029 |
| L3 + Omics + KG | 0.0125 ± 0.0032 | 0.0029 ± 0.0003 | 0.0014 ± 0.0025 | 0.0010 ± 0.0017 | 0.0073 ± 0.0068 | 0.0033 ± 0.0029 |
| L3-FT(Med) + Omics | 0.0179 ± 0.0045 | 0.0133 ± 0.0064 | 0.0091 ± 0.0018 | 0.0105 ± 0.0044 | 0.0110 ± 0.0086 | 0.0106 ± 0.0075 |
| L3-FT(Med) + Omics + KG | 0.0158 ± 0.0030 | 0.0058 ± 0.0011 | 0.0081 ± 0.0071 | 0.0024 ± 0.0017 | 0.0149 ± 0.0057 | 0.0050 ± 0.0000 |
| L3-FT(QA) + Omics | 0.5250 ± 0.0282 | 0.4959 ± 0.0435 | 0.5201 ± 0.0408 | 0.4905 ± 0.0532 | 0.5074 ± 0.0498 | 0.4856 ± 0.0570 |
| L3-FT(QA) + Omics + KG | 0.5185 ± 0.0240 | 0.4908 ± 0.0402 | 0.5214 ± 0.0242 | 0.4952 ± 0.0432 | 0.4856 ± 0.0395 | 0.4656 ± 0.0436 |
| G-Retriever + pre-GAT | 0.4763 ± 0.0004 | 0.3929 ± 0.0063 | 0.4642 ± 0.0181 | 0.3881 ± 0.0264 | 0.4414 ± 0.0099 | 0.3772 ± 0.0010 |
| GALAX | 0.5472 ± 0.0053 | 0.5332 ± 0.0031 | 0.5345 ± 0.0185 | 0.5157 ± 0.0043 | 0.5608 ± 0.0031 | 0.5533 ± 0.0033 |

**Table 2. Hit@10 and Hit@5 across datasets**

| Model | Overall Hit@10 ↑ | Overall Hit@5 ↑ | LUAD Hit@10 ↑ | LUAD Hit@5 ↑ | BRCA Hit@10 ↑ | BRCA Hit@5 ↑ |
|---|---|---|---|---|---|---|
| M2T | 0.0029 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| GAT | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics | 0.0021 ± 0.0037 | 0.0032 ± 0.0055 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics + KG | 0.0122 ± 0.0033 | 0.0085 ± 0.0037 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0056 ± 0.0096 | 0.0111 ± 0.0192 |
| L3-FT(Med) + Omics | 0.0122 ± 0.0072 | 0.0116 ± 0.0097 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(Med) + Omics + KG | 0.0132 ± 0.0040 | 0.0106 ± 0.0048 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(QA) + Omics | 0.8693 ± 0.0157 | 0.8889 ± 0.0168 | 0.8667 ± 0.0218 | 0.8476 ± 0.0165 | 0.8389 ± 0.0096 | 0.8889 ± 0.0509 |
| L3-FT(QA) + Omics + KG | 0.8529 ± 0.0153 | 0.8794 ± 0.0114 | 0.8048 ± 0.0541 | 0.7905 ± 0.0436 | 0.8222 ± 0.0347 | 0.8778 ± 0.0192 |
| G-Retriever + pre-GAT | 0.8550 ± 0.0046 | 0.8804 ± 0.0037 | 0.8524 ± 0.0165 | 0.8857 ± 0.0000 | 0.8667 ± 0.0000 | 0.8667 ± 0.0000 |
| GALAX | 0.8815 ± 0.0033 | 0.9249 ± 0.0048 | 0.8810 ± 0.0082 | 0.9238 ± 0.0436 | 0.8500 ± 0.0441 | 0.8889 ± 0.0839 |

## 🔬 Intended Uses

- Research use only
- Benchmarking graph-language foundation models for target prioritization
- Target prioritization in cancer biology

## 📜 Citation

If you use this model, please cite:

```bibtex
@article{zhang2025galax,
  title   = {GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine},
  author  = {Zhang, Heming and Huang, Di and Li, Wenyu and Province, Michael and Chen, Yixin and Payne, Philip and Li, Fuhai},
  journal = {arXiv preprint arXiv:2509.20935},
  year    = {2025},
  doi     = {10.48550/arXiv.2509.20935},
  url     = {https://arxiv.org/abs/2509.20935}
}
```