Spaces:
Sleeping
Sleeping
Commit ·
6c3d8a1
1
Parent(s): f4a27d9
clean up
Browse files- README.md +130 -27
- flare/data/datasets.py +143 -337
- flare/run.sh +17 -3
- flare/test.py +114 -79
- flare/train.py +70 -79
- flare/tune.py +13 -9
- flare/utils/__init__.py +0 -2
- flare/utils/config.py +63 -0
- flare/utils/data.py +16 -83
- flare/utils/loss.py +0 -34
- flare/utils/models.py +0 -7
- flare/utils/mol_search.py +3 -53
README.md
CHANGED
|
@@ -8,55 +8,158 @@ pinned: false
|
|
| 8 |
python_version: 3.11.7
|
| 9 |
---
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
### Authors
|
| 15 |
**Yan Zhou Chen, Soha Hassoun**
|
| 16 |
-
Department of Computer Science, Tufts University
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
---
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
---
|
| 23 |
|
| 24 |
-
##
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
---
|
| 28 |
-
|
| 29 |
-
##
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
git clone https://huggingface.co/spaces/HassounLab/FLARE
|
| 32 |
-
cd
|
| 33 |
-
|
| 34 |
-
### Set up environment and install dependencies
|
| 35 |
-
```
|
| 36 |
conda create -n flare python=3.11
|
| 37 |
conda activate flare
|
| 38 |
pip install -r requirements.txt
|
| 39 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
---
|
| 41 |
-
## 🚀 Usage
|
| 42 |
-
Modify params.yaml as necessary
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
```
|
| 45 |
-
# preprocess data
|
| 46 |
-
python subformula_assign/assign_subformulae.py --spec-files ../data/sample/data.tsv --output-dir ../data/sample/subformulae --labels-file ../data/sample/data.tsv --max-formulae 60
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
```
|
| 54 |
|
|
|
|
|
|
|
| 55 |
---
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
---
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
| 8 |
python_version: 3.11.7
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# FLARE
|
| 12 |
+
|
| 13 |
+
**F**ine-grained **L**earning for **A**lignment of spectra–molecule **RE**presentations
|
| 14 |
+
|
| 15 |
+
### Authors
|
| 16 |
|
|
|
|
| 17 |
**Yan Zhou Chen, Soha Hassoun**
|
| 18 |
+
Department of Computer Science, Tufts University
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Overview
|
| 23 |
+
|
| 24 |
+
FLARE learns a joint embedding space for **MS/MS spectra** (represented as **per-peak chemical formulas** from a subformula assigner) and **molecular graphs**. The default publication model uses **FILIP-style contrastive learning** (`filipContrastive`): fine-grained similarity between spectrum tokens and graph nodes, with a temperature-scaled loss.
|
| 25 |
+
|
| 26 |
+
Use cases:
|
| 27 |
+
|
| 28 |
+
- **Retrieval**: rank a list of candidate SMILES for each query spectrum (MassSpecGym-style evaluation).
|
| 29 |
+
- **Interpretation**: the Streamlit app visualizes **peak-to-node** correspondence for a single spectrum–molecule pair.
|
| 30 |
|
| 31 |
---
|
| 32 |
|
| 33 |
+
## Model (default stack)
|
| 34 |
+
|
| 35 |
+
| Component | Setting (see `params.yaml`) |
|
| 36 |
+
|-----------|----------------------------|
|
| 37 |
+
| Spectrum input | `SpecFormula` — formula peaks from JSON in `subformula_dir_pth` |
|
| 38 |
+
| Formula source | `default` — MIST-compatible JSON (`load_mist_data`); optional `sirius` |
|
| 39 |
+
| Spectrum encoder | `Transformer_Formula` |
|
| 40 |
+
| Molecule encoder | `GNN` (DGL + dgllife GCN), node embeddings for FILIP |
|
| 41 |
+
| Training objective | `filipContrastive` — masked FILIP loss, temperature `contr_temp` |
|
| 42 |
+
| Output | Embeddings for cosine / FILIP similarity at test time |
|
| 43 |
+
|
| 44 |
+
Hyperparameters are split into: **run/logging**, **training loop**, **data paths**, **featurizers**, **encoder widths/depths**, and **evaluation** (`at_ks`, `myopic_mces_kwargs`). Only keys present in `params.yaml` are required; paths can be **relative to the repository root** (recommended) or absolute.
|
| 45 |
|
| 46 |
---
|
| 47 |
|
| 48 |
+
## Repository layout
|
| 49 |
+
|
| 50 |
+
| Path | Role |
|
| 51 |
+
|------|------|
|
| 52 |
+
| `params.yaml` | Canonical training/testing/app hyperparameters |
|
| 53 |
+
| `hparams.yaml` | Symlink to `params.yaml` (Hugging Face Spaces convention) |
|
| 54 |
+
| `flare/` | Training (`train.py`, `test.py`, `tune.py`), models, data pipeline |
|
| 55 |
+
| `massspecgym/` | Vendored MassSpecGym Lightning base classes and utilities |
|
| 56 |
+
| `app.py`, `app_utils/` | Streamlit peak–node visualization |
|
| 57 |
+
| `pretrained_models/` | Place public checkpoints here (e.g. `flare.ckpt`) |
|
| 58 |
+
| `experiments/` | Default output root for new runs (see `flare/definitions.py`) |
|
| 59 |
+
| `archive/` | Older scripts and features **not** part of the slim release (MAGMA, class experiments, legacy YAML, etc.); nothing was deleted |
|
| 60 |
|
| 61 |
---
|
| 62 |
+
|
| 63 |
+
## Environment variables (no hardcoded machine paths)
|
| 64 |
+
|
| 65 |
+
| Variable | Purpose |
|
| 66 |
+
|----------|---------|
|
| 67 |
+
| `FLARE_PARAMS` | Path to YAML params (default: `<repo>/params.yaml`) |
|
| 68 |
+
| `FLARE_CHECKPOINT` | Checkpoint for the app or manual runs |
|
| 69 |
+
| `FLARE_DEBUG_DATASET` | When `debug: true`, TSV path for a tiny local dataset |
|
| 70 |
+
| `FLARE_REPO_ROOT` | Optional; overrides repo root for resolving relative paths in `default_param_path()` |
|
| 71 |
+
| `MASSSPECGYM_ROOT` | Optional extra `sys.path` root if you use an external `massspecgym` checkout |
|
| 72 |
+
| `FLARE_UPLOAD_CKPT`, `HF_REPO_ID`, `HF_REPO_TYPE`, `HF_TOKEN` | See `app_utils/upload_model.py` for HF uploads |
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
## Setup
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
git clone https://huggingface.co/spaces/HassounLab/FLARE
|
| 80 |
+
cd FLARE
|
| 81 |
+
|
|
|
|
|
|
|
| 82 |
conda create -n flare python=3.11
|
| 83 |
conda activate flare
|
| 84 |
pip install -r requirements.txt
|
| 85 |
```
|
| 86 |
+
|
| 87 |
+
Place **MassSpecGym** (or your) spectrum TSV, **candidate JSON**, and **subformula JSON directory** where you want them, then set paths in `params.yaml` (relative paths like `data/MassSpecGym.tsv` resolve from the repo root).
|
| 88 |
+
|
| 89 |
+
---
|
| 90 |
+
|
| 91 |
+
## Data preparation
|
| 92 |
+
|
| 93 |
+
Per-spectrum subformula JSON files (one file per spectrum id, MIST-style) are required for `SpecFormula`. Generate them with the bundled assigner (adapted from MIST):
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
cd flare/subformula_assign
|
| 97 |
+
export SPEC_FILES=/path/to/spectra.tsv
|
| 98 |
+
export OUTPUT_DIR=/path/to/subformulae_out
|
| 99 |
+
export LABELS_FILE=/path/to/spectra.tsv # often same as SPEC_FILES
|
| 100 |
+
export MAX_FORMULAE=60
|
| 101 |
+
bash run.sh
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Defaults in `run.sh` point at `data/sample/` under the repo if you add a small sample there.
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
## Training
|
| 109 |
+
|
| 110 |
+
From the repository root (so `flare` and `massspecgym` import correctly):
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
cd flare
|
| 114 |
+
python train.py # uses FLARE_PARAMS or ../params.yaml
|
| 115 |
+
python train.py --param_pth /path/to/custom.yaml
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
`train.py` creates `experiments/<YYYYMMDD>_<run_name>/`, writes TensorBoard logs there, and saves checkpoints. `df_test_path` defaults to `<experiment_dir>/result.pkl` if unset.
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## Testing (retrieval)
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
cd flare
|
| 126 |
+
python test.py \
|
| 127 |
+
--checkpoint_pth /path/to/epoch=....ckpt \
|
| 128 |
+
--exp_dir /path/to/experiment_dir # optional; else latest matching run_name
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
Useful flags: `--candidates_pth`, `--df_test_pth`, `--external_test` (no positive label in the list). Override params file with `--param_pth` or `FLARE_PARAMS`.
|
| 132 |
+
|
| 133 |
---
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
## Hyperparameter search
|
| 136 |
+
|
| 137 |
+
```bash
|
| 138 |
+
cd flare
|
| 139 |
+
python tune.py --n_trials 20
|
| 140 |
```
|
|
|
|
|
|
|
| 141 |
|
| 142 |
+
Uses Optuna; study database and logs live under `experiments/<date>_<run_name>_optuna/`. Best YAML is written to `best_params.yaml` in that folder.
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
## Streamlit app (peak-to-node visualization)
|
| 147 |
|
| 148 |
+
```bash
|
| 149 |
+
streamlit run app.py
|
| 150 |
```
|
| 151 |
|
| 152 |
+
The app loads architecture settings from `FLARE_PARAMS` (default `params.yaml`) and weights from `FLARE_CHECKPOINT` (default `pretrained_models/flare.ckpt`). Ensure the checkpoint matches the architecture in the YAML.
|
| 153 |
+
|
| 154 |
---
|
| 155 |
+
|
| 156 |
+
## Acknowledgments
|
| 157 |
+
|
| 158 |
+
- **Data**: [MassSpecGym](https://github.com/pluskal-lab/MassSpecGym)
|
| 159 |
+
- **Subformula tooling**: [MIST](https://github.com/samgoldman97/mist/tree/main_v2)
|
| 160 |
|
| 161 |
---
|
| 162 |
+
|
| 163 |
+
## Contact
|
| 164 |
+
|
| 165 |
+
For questions: soha.hassoun@tufts.edu
|
flare/data/datasets.py
CHANGED
|
@@ -1,124 +1,30 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
import json
|
| 3 |
import typing as T
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
import torch
|
|
|
|
| 6 |
import massspecgym.utils as utils
|
| 7 |
-
|
| 8 |
-
from torch.utils.data.dataset import Dataset
|
| 9 |
-
from torch.utils.data.dataloader import default_collate
|
| 10 |
-
import dgl
|
| 11 |
-
from collections import defaultdict
|
| 12 |
-
from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
|
| 13 |
from massspecgym.data.datasets import MassSpecDataset
|
| 14 |
-
|
| 15 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 16 |
from massspecgym.models.base import Stage
|
| 17 |
-
import
|
| 18 |
-
import
|
| 19 |
-
import
|
| 20 |
-
|
| 21 |
-
from rdkit import Chem
|
| 22 |
-
from magma.run_magma import run_magma
|
| 23 |
-
import matchms
|
| 24 |
|
| 25 |
class JESTR1_MassSpecDataset(MassSpecDataset):
|
| 26 |
-
|
| 27 |
-
self,
|
| 28 |
-
spectra_view: str,
|
| 29 |
-
fp_dir_pth: str = None,
|
| 30 |
-
cons_spec_dir_pth: str = None,
|
| 31 |
-
NL_spec_dir_pth: str = None,
|
| 32 |
-
**kwargs
|
| 33 |
-
):
|
| 34 |
-
super().__init__(**kwargs)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
self.use_NL_spec = False
|
| 39 |
self.spectra_view = spectra_view
|
| 40 |
|
| 41 |
-
# load fingerprints
|
| 42 |
-
self._load_fp(fp_dir_pth)
|
| 43 |
-
|
| 44 |
-
# load consensus
|
| 45 |
-
self._load_cons_spec(cons_spec_dir_pth)
|
| 46 |
-
|
| 47 |
-
# load NL specs
|
| 48 |
-
self._load_NL_spec(NL_spec_dir_pth)
|
| 49 |
-
|
| 50 |
-
def _load_fp(self, fp_dir_pth):
|
| 51 |
-
if fp_dir_pth is not None:
|
| 52 |
-
self.use_fp = True
|
| 53 |
-
if fp_dir_pth:
|
| 54 |
-
with open(fp_dir_pth, 'rb') as f:
|
| 55 |
-
self.smiles_to_fp = pickle.load(f)
|
| 56 |
-
else:
|
| 57 |
-
self.smiles_to_fp = {}
|
| 58 |
-
|
| 59 |
-
def _load_cons_spec(self, cons_spec_dir_pth):
|
| 60 |
-
if cons_spec_dir_pth is not None:
|
| 61 |
-
self.use_cons_spec = True
|
| 62 |
-
with open(cons_spec_dir_pth, 'rb') as f:
|
| 63 |
-
cons_specs = pickle.load(f)
|
| 64 |
-
|
| 65 |
-
# Convert spectra to matchms spectra
|
| 66 |
-
matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
|
| 67 |
-
spectra = cons_specs.apply(matchMS_preparer.prepare,axis=1)
|
| 68 |
-
|
| 69 |
-
self.cons_specs = dict(zip(cons_specs['smiles'].tolist(), spectra))
|
| 70 |
-
|
| 71 |
-
def _load_NL_spec(self, NL_spec_dir_pth):
|
| 72 |
-
if NL_spec_dir_pth is not None:
|
| 73 |
-
self.use_NL_spec = True
|
| 74 |
-
with open(NL_spec_dir_pth, 'rb') as f:
|
| 75 |
-
NL_specs = pickle.load(f)
|
| 76 |
-
|
| 77 |
-
# Convert spectra to matchms spectra
|
| 78 |
-
matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
|
| 79 |
-
self.NL_specs = NL_specs.apply(matchMS_preparer.prepare,axis=1)
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
|
| 83 |
-
|
| 84 |
-
spec = self.spectra[i]
|
| 85 |
-
metadata = self.metadata.iloc[i]
|
| 86 |
-
mol = metadata["smiles"] if 'smiles' in metadata else metadata["identifier"]
|
| 87 |
-
|
| 88 |
-
# Apply all transformations to the spectrum
|
| 89 |
-
item = {}
|
| 90 |
-
if transform_spec and self.spec_transform:
|
| 91 |
-
if isinstance(self.spec_transform, dict):
|
| 92 |
-
for key, transform in self.spec_transform.items():
|
| 93 |
-
item[key] = transform(spec) if transform is not None else spec
|
| 94 |
-
else:
|
| 95 |
-
item["spec"] = self.spec_transform(spec)
|
| 96 |
-
|
| 97 |
-
if self.return_mol_freq:
|
| 98 |
-
item["mol_freq"] = metadata["mol_freq"]
|
| 99 |
-
|
| 100 |
-
if self.return_identifier:
|
| 101 |
-
item["identifier"] = metadata["identifier"]
|
| 102 |
-
|
| 103 |
-
if self.use_fp and self.smiles_to_fp:
|
| 104 |
-
item['fp'] = torch.Tensor(self.smiles_to_fp[mol].ToList())
|
| 105 |
-
|
| 106 |
-
if self.use_cons_spec:
|
| 107 |
-
item['cons_spec'] = self.spec_transform[self.spectra_view](self.cons_specs[mol])
|
| 108 |
-
|
| 109 |
-
if self.use_NL_spec:
|
| 110 |
-
item['NL_spec'] = self.spec_transform[self.spectra_view](self.NL_specs[i])
|
| 111 |
-
|
| 112 |
-
# Apply all transformations to the molecule
|
| 113 |
-
if transform_mol and self.mol_transform:
|
| 114 |
-
if isinstance(self.mol_transform, dict):
|
| 115 |
-
for key, transform in self.mol_transform.items():
|
| 116 |
-
item[key] = transform(mol) if transform is not None else mol
|
| 117 |
-
else:
|
| 118 |
-
item["mol"] = self.mol_transform(mol)
|
| 119 |
-
else:
|
| 120 |
-
item["mol"] = mol
|
| 121 |
-
return item
|
| 122 |
|
| 123 |
class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
|
| 124 |
def __init__(
|
|
@@ -128,26 +34,16 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
|
|
| 128 |
mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
|
| 129 |
pth: T.Optional[Path],
|
| 130 |
subformula_dir_pth: str,
|
| 131 |
-
fp_dir_pth: str = None,
|
| 132 |
-
NL_spec_dir_pth: str = None,
|
| 133 |
-
cons_spec_dir_pth: str = None,
|
| 134 |
return_mol_freq: bool = False,
|
| 135 |
return_identifier: bool = True,
|
| 136 |
dtype: T.Type = torch.float32,
|
| 137 |
-
formula_source =
|
| 138 |
-
stage: Stage = Stage.TRAIN
|
| 139 |
):
|
| 140 |
-
"""
|
| 141 |
-
Args:
|
| 142 |
-
"""
|
| 143 |
self.pth = pth
|
| 144 |
self.spec_transform = spec_transform
|
| 145 |
self.mol_transform = mol_transform
|
| 146 |
self.return_mol_freq = return_mol_freq
|
| 147 |
-
self.pred_fp = False
|
| 148 |
-
self.use_fp = False
|
| 149 |
-
self.use_cons_spec = False
|
| 150 |
-
self.use_NL_spec = False
|
| 151 |
self.spectra_view = spectra_view
|
| 152 |
self.formula_source = formula_source
|
| 153 |
self.subformula_dir_pth = subformula_dir_pth
|
|
@@ -155,31 +51,23 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
|
|
| 155 |
if isinstance(self.pth, str):
|
| 156 |
self.pth = Path(self.pth)
|
| 157 |
|
| 158 |
-
self.spectra_view = spectra_view
|
| 159 |
print("Data path: ", self.pth)
|
| 160 |
self.metadata = pd.read_csv(self.pth, sep="\t")
|
| 161 |
|
| 162 |
-
# load subformulas
|
| 163 |
id_to_spec = self._load_id_to_spec(stage)
|
| 164 |
-
|
| 165 |
-
# load fingerprints
|
| 166 |
-
self._load_fp(fp_dir_pth)
|
| 167 |
-
|
| 168 |
-
# load consensus spectra
|
| 169 |
-
self._load_cons_spec(cons_spec_dir_pth)
|
| 170 |
|
| 171 |
-
|
| 172 |
-
self._load_NL_spec(NL_spec_dir_pth)
|
| 173 |
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
|
| 177 |
-
self.metadata = self.metadata.merge(formula_df, on='identifier')
|
| 178 |
-
|
| 179 |
-
# create matchms spectra
|
| 180 |
matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
|
| 181 |
-
self.spectra = self.metadata.apply(matchMS_preparer.prepare,axis=1)
|
| 182 |
-
|
| 183 |
if self.return_mol_freq:
|
| 184 |
if "inchikey" not in self.metadata.columns:
|
| 185 |
self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
|
|
@@ -187,108 +75,104 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
|
|
| 187 |
|
| 188 |
self.return_identifier = return_identifier
|
| 189 |
self.dtype = dtype
|
| 190 |
-
|
| 191 |
def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
|
| 192 |
-
item = super().__getitem__(i, transform_spec, transform_mol
|
| 193 |
-
mol = item[
|
| 194 |
|
| 195 |
-
# transform mol
|
| 196 |
if transform_mol:
|
| 197 |
if isinstance(self.mol_transform, dict):
|
| 198 |
for key, transform in self.mol_transform.items():
|
| 199 |
item[key] = transform(mol) if transform is not None else mol
|
| 200 |
else:
|
| 201 |
item["mol"] = self.mol_transform(mol)
|
|
|
|
|
|
|
| 202 |
|
| 203 |
return item
|
| 204 |
|
| 205 |
def _load_id_to_spec(self, stage):
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
prec_mz_list = self.metadata['precursor_mz'].tolist()
|
| 216 |
id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)
|
| 217 |
|
| 218 |
-
# create subformula spectra if no subformula is available
|
| 219 |
tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
|
| 220 |
-
tmp_df = self.metadata[self.metadata[
|
| 221 |
-
tmp_df[
|
| 222 |
-
id_to_spec.update(dict(zip(tmp_df[
|
| 223 |
|
| 224 |
return id_to_spec
|
| 225 |
|
|
|
|
| 226 |
class ContrastiveDataset(Dataset):
|
| 227 |
-
def __init__(
|
| 228 |
-
self,
|
| 229 |
-
spec_mol_data,
|
| 230 |
-
):
|
| 231 |
super().__init__()
|
| 232 |
-
|
| 233 |
indices = spec_mol_data.indices
|
| 234 |
self.spec_mol_data = spec_mol_data
|
| 235 |
-
self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby(
|
| 236 |
self.smiles_to_spec_couter = defaultdict(int)
|
| 237 |
self.smiles_list = list(self.smiles_to_specmol_ids.keys())
|
| 238 |
|
| 239 |
def __len__(self) -> int:
|
| 240 |
return len(self.smiles_list)
|
| 241 |
-
|
| 242 |
-
def __getitem__(self, i:int) -> dict:
|
| 243 |
mol = self.smiles_list[i]
|
| 244 |
|
| 245 |
-
# select spectrum (iterate through list of spectra)
|
| 246 |
specmol_ids = self.smiles_to_specmol_ids[mol]
|
| 247 |
counter = self.smiles_to_spec_couter[mol]
|
| 248 |
specmol_id = specmol_ids[counter % len(specmol_ids)]
|
| 249 |
|
| 250 |
item = self.spec_mol_data.__getitem__(specmol_id)
|
| 251 |
-
self.smiles_to_spec_couter[mol] = counter+1
|
| 252 |
-
# item['smiles'] = mol
|
| 253 |
-
# item['spec_id'] = specmol_id
|
| 254 |
return item
|
| 255 |
|
| 256 |
@staticmethod
|
| 257 |
-
def collate_fn(
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
require_pad = False
|
| 261 |
-
if
|
| 262 |
require_pad = True
|
| 263 |
-
padding_value=-5 if spec_enc in (
|
| 264 |
non_standard_collate.append(spectra_view)
|
| 265 |
else:
|
| 266 |
-
non_standard_collate.remove(
|
| 267 |
-
non_standard_collate.remove('NL_spec')
|
| 268 |
|
| 269 |
collated_batch = {}
|
| 270 |
-
# standard collate
|
| 271 |
for k in batch[0].keys():
|
| 272 |
if k not in non_standard_collate:
|
| 273 |
try:
|
| 274 |
collated_batch[k] = default_collate([item[k] for item in batch])
|
| 275 |
-
except:
|
| 276 |
print(f"Error in collating key {k}")
|
| 277 |
raise
|
| 278 |
-
|
| 279 |
-
# batch graphs
|
| 280 |
if batch_mol:
|
| 281 |
-
|
| 282 |
-
batch_mol_nodes= []
|
| 283 |
|
| 284 |
for item in batch:
|
| 285 |
-
|
| 286 |
batch_mol_nodes.append(item[mol_key].num_nodes())
|
| 287 |
|
| 288 |
-
collated_batch[mol_key] = dgl.batch(
|
| 289 |
-
collated_batch[
|
| 290 |
-
|
| 291 |
-
# pad peaks/formulas
|
| 292 |
if require_pad:
|
| 293 |
peaks = []
|
| 294 |
n_peaks = []
|
|
@@ -296,54 +180,40 @@ class ContrastiveDataset(Dataset):
|
|
| 296 |
peaks.append(item[spectra_view])
|
| 297 |
n_peaks.append(len(item[spectra_view]))
|
| 298 |
collated_batch[spectra_view] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
|
| 299 |
-
collated_batch[
|
| 300 |
-
|
| 301 |
-
if 'cons_spec' in batch[0]:
|
| 302 |
-
peaks = []
|
| 303 |
-
n_peaks = []
|
| 304 |
-
for item in batch:
|
| 305 |
-
peaks.append(item['cons_spec'])
|
| 306 |
-
n_peaks.append(len(item['cons_spec']))
|
| 307 |
-
collated_batch['cons_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
|
| 308 |
-
collated_batch['cons_n_peaks'] = n_peaks
|
| 309 |
-
|
| 310 |
-
if 'NL_spec' in batch[0]:
|
| 311 |
-
peaks = []
|
| 312 |
-
n_peaks = []
|
| 313 |
-
for item in batch:
|
| 314 |
-
peaks.append(item['NL_spec'])
|
| 315 |
-
n_peaks.append(len(item['NL_spec']))
|
| 316 |
-
collated_batch['NL_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
|
| 317 |
-
collated_batch['NL_n_peaks'] = n_peaks
|
| 318 |
return collated_batch
|
| 319 |
-
|
| 320 |
-
|
| 321 |
|
| 322 |
class ExpandedRetrievalDataset:
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
def __init__(
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
|
| 342 |
self.candidates_pth = candidates_pth
|
|
|
|
| 343 |
self.mol_label_transform = mol_label_transform
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
| 347 |
candidates = json.load(file)
|
| 348 |
|
| 349 |
self.candidates = {}
|
|
@@ -351,130 +221,66 @@ class ExpandedRetrievalDataset:
|
|
| 351 |
clean_cands = []
|
| 352 |
for c in cand:
|
| 353 |
try:
|
| 354 |
-
if
|
| 355 |
clean_cands.append(c)
|
| 356 |
-
except:
|
| 357 |
print(f"Error in processing candidate {c} for smiles {s}")
|
| 358 |
-
|
| 359 |
-
self.candidates[s] = clean_cands
|
| 360 |
-
|
| 361 |
-
self.spec_cand = [] #(spec index, cand_smiles, true_label)
|
| 362 |
-
|
| 363 |
-
# use for external dataset where target smiles is not known
|
| 364 |
-
# self.candidates should be a dict of identifier to candidates
|
| 365 |
-
if 'smiles' not in self.metadata.columns:
|
| 366 |
-
if not isinstance(self.metadata.iloc[0]['identifier'], str):
|
| 367 |
-
self.metadata['smiles'] = self.metadata['identifier'].apply(str)
|
| 368 |
-
else:
|
| 369 |
-
self.metadata['smiles'] = self.metadata['identifier']
|
| 370 |
-
|
| 371 |
-
# keep datapoints where there are candidates
|
| 372 |
-
self.metadata = self.metadata[self.metadata['smiles'].isin(self.candidates.keys())]
|
| 373 |
-
|
| 374 |
-
test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
|
| 375 |
-
test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
|
| 376 |
-
|
| 377 |
-
self.spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
|
| 378 |
-
|
| 379 |
-
for spec_id, s in zip(test_ms_id, test_smiles):
|
| 380 |
-
candidates = self.candidates[s]
|
| 381 |
-
|
| 382 |
-
# mol_label = self.mol_label_transform(s)
|
| 383 |
-
# labels = [self.mol_label_transform(c) == mol_label for c in candidates]
|
| 384 |
-
labels = [c == s for c in candidates]
|
| 385 |
-
if len(candidates) == 0:
|
| 386 |
-
print(f"Skipping {spec_id}; empty candidate set")
|
| 387 |
-
continue
|
| 388 |
-
if not any(labels):
|
| 389 |
-
# print(f"Target smiles not in candidate set")
|
| 390 |
-
pass
|
| 391 |
-
|
| 392 |
-
self.spec_cand.extend([(self.spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
|
| 393 |
-
|
| 394 |
-
def __getattr__(self, name):
|
| 395 |
-
return self.instance.__getattribute__(name)
|
| 396 |
-
|
| 397 |
-
def __len__(self):
|
| 398 |
-
return len(self.spec_cand)
|
| 399 |
-
|
| 400 |
-
def __getitem__(self, i):
|
| 401 |
-
spec_i = self.spec_cand[i][0]
|
| 402 |
-
cand_smiles = self.spec_cand[i][1]
|
| 403 |
-
label = self.spec_cand[i][2]
|
| 404 |
|
| 405 |
-
|
| 406 |
-
item = self.instance.__getitem__(spec_i, transform_mol=False, transform_spec=False)
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
spec_data = run_magma(i, mzs, intensities, cand_smiles, adduct)
|
| 414 |
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
-
|
| 418 |
-
mz = np.array(spec['formula_mzs']),
|
| 419 |
-
intensities = np.array(spec['formula_intensities']),
|
| 420 |
-
metadata = {'precursor_mz': precursor_mz, 'formulas': np.array(spec['formulas'])})
|
| 421 |
-
|
| 422 |
-
if isinstance(self.spec_transform, dict):
|
| 423 |
-
|
| 424 |
-
for key, transform in self.spec_transform.items():
|
| 425 |
-
item[key] = transform(spec) if transform is not None else spec
|
| 426 |
-
else:
|
| 427 |
-
item["spec"] = self.spec_transform(spec)
|
| 428 |
|
|
|
|
|
|
|
| 429 |
else:
|
| 430 |
-
|
| 431 |
|
| 432 |
-
|
| 433 |
-
item['cand_smiles'] = cand_smiles
|
| 434 |
-
item['label'] = label
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
| 438 |
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
def __init__(self,
|
| 444 |
-
use_formulas: bool,
|
| 445 |
-
aug_cands_dir_pth: str,
|
| 446 |
-
aug_cands_size: int,
|
| 447 |
-
**kwargs):
|
| 448 |
-
self.aug_cands_size = aug_cands_size
|
| 449 |
-
self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
|
| 450 |
-
|
| 451 |
-
with open(aug_cands_dir_pth, 'rb') as f:
|
| 452 |
-
aug_cands = pickle.load(f)
|
| 453 |
-
|
| 454 |
-
if self.use_fp:
|
| 455 |
-
self.fpgen = AllChem.GetMorganGenerator(radius=5,fpSize=1024)
|
| 456 |
-
|
| 457 |
-
self.aug_cands = {}
|
| 458 |
-
targets = np.array(list(aug_cands.keys()))
|
| 459 |
-
for smiles, cands in aug_cands.items():
|
| 460 |
-
# sort candidates by tanimoto similarity
|
| 461 |
-
cands.sort(key=lambda x: x[1], reverse=True)
|
| 462 |
-
cands = [c for c in cands if '.' not in c]
|
| 463 |
-
# assert(len(cands) >0)
|
| 464 |
-
if len(cands) <=1: # if no candidates, shuffle from target list
|
| 465 |
-
np.random.shuffle(targets)
|
| 466 |
-
cands = targets
|
| 467 |
-
self.aug_cands[smiles] = itertools.cycle(cands)
|
| 468 |
|
| 469 |
def __getattr__(self, name):
|
| 470 |
return self.instance.__getattribute__(name)
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
| 472 |
def __getitem__(self, i):
|
| 473 |
-
|
|
|
|
|
|
|
| 474 |
|
| 475 |
-
|
| 476 |
-
item[
|
| 477 |
-
item["
|
| 478 |
-
item["
|
| 479 |
|
| 480 |
-
return item
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import typing as T
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import dgl
|
| 7 |
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
import torch
|
| 10 |
+
import flare.utils.data as data_utils
|
| 11 |
import massspecgym.utils as utils
|
| 12 |
+
import matchms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from massspecgym.data.datasets import MassSpecDataset
|
| 14 |
+
from massspecgym.data.transforms import MolTransform, MolToInChIKey, SpecTransform
|
|
|
|
| 15 |
from massspecgym.models.base import Stage
|
| 16 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 17 |
+
from torch.utils.data.dataloader import default_collate
|
| 18 |
+
from torch.utils.data.dataset import Dataset
|
| 19 |
+
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
class JESTR1_MassSpecDataset(MassSpecDataset):
|
| 22 |
+
"""Same as MassSpecDataset; keeps `spectra_view` for API compatibility."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
def __init__(self, spectra_view: str, **kwargs):
|
| 25 |
+
super().__init__(**kwargs)
|
|
|
|
| 26 |
self.spectra_view = spectra_view
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
|
| 30 |
def __init__(
|
|
|
|
| 34 |
mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
|
| 35 |
pth: T.Optional[Path],
|
| 36 |
subformula_dir_pth: str,
|
|
|
|
|
|
|
|
|
|
| 37 |
return_mol_freq: bool = False,
|
| 38 |
return_identifier: bool = True,
|
| 39 |
dtype: T.Type = torch.float32,
|
| 40 |
+
formula_source: str = "default",
|
| 41 |
+
stage: Stage = Stage.TRAIN,
|
| 42 |
):
|
|
|
|
|
|
|
|
|
|
| 43 |
self.pth = pth
|
| 44 |
self.spec_transform = spec_transform
|
| 45 |
self.mol_transform = mol_transform
|
| 46 |
self.return_mol_freq = return_mol_freq
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
self.spectra_view = spectra_view
|
| 48 |
self.formula_source = formula_source
|
| 49 |
self.subformula_dir_pth = subformula_dir_pth
|
|
|
|
| 51 |
if isinstance(self.pth, str):
|
| 52 |
self.pth = Path(self.pth)
|
| 53 |
|
|
|
|
| 54 |
print("Data path: ", self.pth)
|
| 55 |
self.metadata = pd.read_csv(self.pth, sep="\t")
|
| 56 |
|
|
|
|
| 57 |
id_to_spec = self._load_id_to_spec(stage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
self.metadata = self.metadata[self.metadata["identifier"].isin(id_to_spec)]
|
|
|
|
| 60 |
|
| 61 |
+
formula_df = (
|
| 62 |
+
pd.DataFrame.from_dict(id_to_spec, orient="index")
|
| 63 |
+
.reset_index()
|
| 64 |
+
.rename(columns={"index": "identifier"})
|
| 65 |
+
)
|
| 66 |
+
self.metadata = self.metadata.merge(formula_df, on="identifier")
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
|
| 69 |
+
self.spectra = self.metadata.apply(matchMS_preparer.prepare, axis=1)
|
| 70 |
+
|
| 71 |
if self.return_mol_freq:
|
| 72 |
if "inchikey" not in self.metadata.columns:
|
| 73 |
self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
|
|
|
|
| 75 |
|
| 76 |
self.return_identifier = return_identifier
|
| 77 |
self.dtype = dtype
|
| 78 |
+
|
| 79 |
def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
|
| 80 |
+
item = super().__getitem__(i, transform_spec, transform_mol=False)
|
| 81 |
+
mol = item["mol"]
|
| 82 |
|
|
|
|
| 83 |
if transform_mol:
|
| 84 |
if isinstance(self.mol_transform, dict):
|
| 85 |
for key, transform in self.mol_transform.items():
|
| 86 |
item[key] = transform(mol) if transform is not None else mol
|
| 87 |
else:
|
| 88 |
item["mol"] = self.mol_transform(mol)
|
| 89 |
+
else:
|
| 90 |
+
item["mol"] = mol
|
| 91 |
|
| 92 |
return item
|
| 93 |
|
| 94 |
def _load_id_to_spec(self, stage):
|
| 95 |
+
all_spec_ids = self.metadata["identifier"].tolist()
|
| 96 |
+
self.subformulaLoader = data_utils.Subformula_Loader(
|
| 97 |
+
spectra_view=self.spectra_view,
|
| 98 |
+
dir_path=self.subformula_dir_pth,
|
| 99 |
+
formula_source=self.formula_source,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
form_list = self.metadata["formula"].tolist()
|
| 103 |
+
prec_mz_list = self.metadata["precursor_mz"].tolist()
|
|
|
|
| 104 |
id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)
|
| 105 |
|
|
|
|
| 106 |
tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
|
| 107 |
+
tmp_df = self.metadata[self.metadata["identifier"].isin(tmp_ids)]
|
| 108 |
+
tmp_df["spec"] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
|
| 109 |
+
id_to_spec.update(dict(zip(tmp_df["identifier"].tolist(), tmp_df["spec"].tolist())))
|
| 110 |
|
| 111 |
return id_to_spec
|
| 112 |
|
| 113 |
+
|
| 114 |
class ContrastiveDataset(Dataset):
|
| 115 |
+
def __init__(self, spec_mol_data):
|
|
|
|
|
|
|
|
|
|
| 116 |
super().__init__()
|
| 117 |
+
|
| 118 |
indices = spec_mol_data.indices
|
| 119 |
self.spec_mol_data = spec_mol_data
|
| 120 |
+
self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby("smiles").indices
|
| 121 |
self.smiles_to_spec_couter = defaultdict(int)
|
| 122 |
self.smiles_list = list(self.smiles_to_specmol_ids.keys())
|
| 123 |
|
| 124 |
def __len__(self) -> int:
|
| 125 |
return len(self.smiles_list)
|
| 126 |
+
|
| 127 |
+
def __getitem__(self, i: int) -> dict:
|
| 128 |
mol = self.smiles_list[i]
|
| 129 |
|
|
|
|
| 130 |
specmol_ids = self.smiles_to_specmol_ids[mol]
|
| 131 |
counter = self.smiles_to_spec_couter[mol]
|
| 132 |
specmol_id = specmol_ids[counter % len(specmol_ids)]
|
| 133 |
|
| 134 |
item = self.spec_mol_data.__getitem__(specmol_id)
|
| 135 |
+
self.smiles_to_spec_couter[mol] = counter + 1
|
|
|
|
|
|
|
| 136 |
return item
|
| 137 |
|
| 138 |
@staticmethod
|
| 139 |
+
def collate_fn(
|
| 140 |
+
batch: T.Iterable[dict],
|
| 141 |
+
spec_enc: str,
|
| 142 |
+
spectra_view: str,
|
| 143 |
+
stage=None,
|
| 144 |
+
batch_mol: bool = True,
|
| 145 |
+
) -> dict:
|
| 146 |
+
mol_key = "cand" if stage == Stage.TEST else "mol"
|
| 147 |
+
non_standard_collate = ["mol", "cand", "aug_cands"]
|
| 148 |
require_pad = False
|
| 149 |
+
if "Formula" in spectra_view or "Tokens" in spectra_view:
|
| 150 |
require_pad = True
|
| 151 |
+
padding_value = -5 if spec_enc in ("Transformer_Formula", "Formula_BinnedSpec", "Transformer_MzInt") else 0
|
| 152 |
non_standard_collate.append(spectra_view)
|
| 153 |
else:
|
| 154 |
+
non_standard_collate.remove("aug_cands")
|
|
|
|
| 155 |
|
| 156 |
collated_batch = {}
|
|
|
|
| 157 |
for k in batch[0].keys():
|
| 158 |
if k not in non_standard_collate:
|
| 159 |
try:
|
| 160 |
collated_batch[k] = default_collate([item[k] for item in batch])
|
| 161 |
+
except Exception:
|
| 162 |
print(f"Error in collating key {k}")
|
| 163 |
raise
|
| 164 |
+
|
|
|
|
| 165 |
if batch_mol:
|
| 166 |
+
batch_mol_list = []
|
| 167 |
+
batch_mol_nodes = []
|
| 168 |
|
| 169 |
for item in batch:
|
| 170 |
+
batch_mol_list.append(item[mol_key])
|
| 171 |
batch_mol_nodes.append(item[mol_key].num_nodes())
|
| 172 |
|
| 173 |
+
collated_batch[mol_key] = dgl.batch(batch_mol_list)
|
| 174 |
+
collated_batch["mol_n_nodes"] = batch_mol_nodes
|
| 175 |
+
|
|
|
|
| 176 |
if require_pad:
|
| 177 |
peaks = []
|
| 178 |
n_peaks = []
|
|
|
|
| 180 |
peaks.append(item[spectra_view])
|
| 181 |
n_peaks.append(len(item[spectra_view]))
|
| 182 |
collated_batch[spectra_view] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
|
| 183 |
+
collated_batch["n_peaks"] = n_peaks
|
| 184 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return collated_batch
|
| 186 |
+
|
|
|
|
| 187 |
|
| 188 |
class ExpandedRetrievalDataset:
|
| 189 |
+
"""Test-time retrieval over a fixed candidate pool per spectrum/formula."""
|
| 190 |
+
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
use_formulas: bool = True,
|
| 194 |
+
mol_label_transform: MolTransform = MolToInChIKey(),
|
| 195 |
+
candidates_pth: T.Optional[T.Union[Path, str]] = None,
|
| 196 |
+
formula_to_smiles_pth: T.Optional[T.Union[Path, str]] = None,
|
| 197 |
+
external_test: bool = False,
|
| 198 |
+
**kwargs,
|
| 199 |
+
):
|
| 200 |
+
self.external_test = external_test
|
| 201 |
+
|
| 202 |
+
self.instance = (
|
| 203 |
+
MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False, stage=Stage.TEST)
|
| 204 |
+
if use_formulas
|
| 205 |
+
else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
|
| 206 |
+
)
|
| 207 |
|
| 208 |
self.candidates_pth = candidates_pth
|
| 209 |
+
self.formula_to_smiles_pth = formula_to_smiles_pth
|
| 210 |
self.mol_label_transform = mol_label_transform
|
| 211 |
+
|
| 212 |
+
candidate_source_pth = self.formula_to_smiles_pth if self.formula_to_smiles_pth else self.candidates_pth
|
| 213 |
+
if not candidate_source_pth:
|
| 214 |
+
raise ValueError("One of candidates_pth or formula_to_smiles_pth must be provided.")
|
| 215 |
+
|
| 216 |
+
with open(candidate_source_pth, "r") as file:
|
| 217 |
candidates = json.load(file)
|
| 218 |
|
| 219 |
self.candidates = {}
|
|
|
|
| 221 |
clean_cands = []
|
| 222 |
for c in cand:
|
| 223 |
try:
|
| 224 |
+
if "." not in c:
|
| 225 |
clean_cands.append(c)
|
| 226 |
+
except Exception:
|
| 227 |
print(f"Error in processing candidate {c} for smiles {s}")
|
| 228 |
+
self.candidates[s] = clean_cands
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
+
self.spec_cand = []
|
|
|
|
| 231 |
|
| 232 |
+
if "smiles" not in self.metadata.columns:
|
| 233 |
+
if not isinstance(self.metadata.iloc[0]["identifier"], str):
|
| 234 |
+
self.metadata["smiles"] = self.metadata["identifier"].apply(str)
|
| 235 |
+
else:
|
| 236 |
+
self.metadata["smiles"] = self.metadata["identifier"]
|
|
|
|
| 237 |
|
| 238 |
+
if self.formula_to_smiles_pth:
|
| 239 |
+
if "formula" not in self.metadata.columns:
|
| 240 |
+
raise ValueError("formula_to_smiles_pth was provided, but dataset has no 'formula' column.")
|
| 241 |
+
self.metadata["candidate_key"] = self.metadata["formula"].astype(str)
|
| 242 |
+
else:
|
| 243 |
+
self.metadata["candidate_key"] = self.metadata["smiles"].astype(str)
|
| 244 |
|
| 245 |
+
self.metadata = self.metadata[self.metadata["candidate_key"].isin(self.candidates.keys())]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
+
if "fold" in self.metadata.columns:
|
| 248 |
+
test_metadata = self.metadata[self.metadata["fold"] == "test"]
|
| 249 |
else:
|
| 250 |
+
test_metadata = self.metadata
|
| 251 |
|
| 252 |
+
self.spec_id_to_index = dict(zip(self.metadata["identifier"], self.metadata.index))
|
|
|
|
|
|
|
| 253 |
|
| 254 |
+
for _, row in test_metadata.iterrows():
|
| 255 |
+
spec_id = row["identifier"]
|
| 256 |
+
candidate_key = row["candidate_key"]
|
| 257 |
+
cands = self.candidates[candidate_key]
|
| 258 |
|
| 259 |
+
if self.external_test:
|
| 260 |
+
labels = [False for _ in cands]
|
| 261 |
+
else:
|
| 262 |
+
target_smiles = row["smiles"]
|
| 263 |
+
labels = [c == target_smiles for c in cands]
|
| 264 |
+
if len(cands) == 0:
|
| 265 |
+
print(f"Skipping {spec_id}; empty candidate set")
|
| 266 |
+
continue
|
| 267 |
|
| 268 |
+
self.spec_cand.extend([(self.spec_id_to_index[spec_id], cands[j], k) for j, k in enumerate(labels)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
def __getattr__(self, name):
|
| 271 |
return self.instance.__getattribute__(name)
|
| 272 |
+
|
| 273 |
+
def __len__(self):
|
| 274 |
+
return len(self.spec_cand)
|
| 275 |
+
|
| 276 |
def __getitem__(self, i):
|
| 277 |
+
spec_i = self.spec_cand[i][0]
|
| 278 |
+
cand_smiles = self.spec_cand[i][1]
|
| 279 |
+
label = self.spec_cand[i][2]
|
| 280 |
|
| 281 |
+
item = self.instance.__getitem__(spec_i, transform_mol=False)
|
| 282 |
+
item["cand"] = self.mol_transform(cand_smiles)
|
| 283 |
+
item["cand_smiles"] = cand_smiles
|
| 284 |
+
item["label"] = label
|
| 285 |
|
| 286 |
+
return item
|
flare/run.sh
CHANGED
|
@@ -1,3 +1,17 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Example: run from the `flare/` directory.
|
| 3 |
+
# conda activate flare
|
| 4 |
+
# python train.py
|
| 5 |
+
# python test.py --checkpoint_choice val
|
| 6 |
+
#
|
| 7 |
+
# Optional overrides:
|
| 8 |
+
# export FLARE_PARAMS=/path/to/params.yaml
|
| 9 |
+
# export CANDIDATES_JSON="$PWD/../data/MassSpecGym_retrieval_candidates_formula.json"
|
| 10 |
+
# python test.py --candidates_pth "${CANDIDATES_JSON}"
|
| 11 |
+
|
| 12 |
+
set -euo pipefail
|
| 13 |
+
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
| 14 |
+
export FLARE_REPO_ROOT="${FLARE_REPO_ROOT:-$REPO_ROOT}"
|
| 15 |
+
|
| 16 |
+
# python train.py
|
| 17 |
+
# python test.py --checkpoint_choice val
|
flare/test.py
CHANGED
|
@@ -1,128 +1,163 @@
|
|
| 1 |
import argparse
|
| 2 |
import datetime
|
|
|
|
| 3 |
import sys
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
from rdkit import RDLogger
|
| 8 |
import pytorch_lightning as pl
|
| 9 |
from pytorch_lightning import Trainer
|
|
|
|
| 10 |
from massspecgym.models.base import Stage
|
| 11 |
-
import os
|
| 12 |
|
| 13 |
from flare.data.data_module import TestDataModule
|
| 14 |
from flare.data.datasets import ContrastiveDataset
|
| 15 |
-
from flare.
|
|
|
|
|
|
|
| 16 |
from flare.utils.models import get_model
|
| 17 |
|
| 18 |
-
from flare.definitions import TEST_RESULTS_DIR
|
| 19 |
-
import yaml
|
| 20 |
-
from functools import partial
|
| 21 |
-
# Suppress RDKit warnings and errors
|
| 22 |
lg = RDLogger.logger()
|
| 23 |
lg.setLevel(RDLogger.CRITICAL)
|
| 24 |
|
| 25 |
parser = argparse.ArgumentParser()
|
| 26 |
-
parser.add_argument(
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
def main(params):
|
| 35 |
-
# Seed everything
|
| 36 |
-
pl.seed_everything(params['seed'])
|
| 37 |
-
|
| 38 |
-
# Init paths to data files
|
| 39 |
-
if params['debug']:
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
data_module = TestDataModule(
|
| 54 |
dataset=dataset,
|
| 55 |
collate_fn=collate_fn,
|
| 56 |
-
split_pth=params[
|
| 57 |
-
batch_size=params[
|
| 58 |
-
num_workers=params[
|
| 59 |
)
|
| 60 |
|
| 61 |
-
model = get_model(params[
|
| 62 |
-
model.df_test_path = params[
|
| 63 |
-
model.external_test = params[
|
| 64 |
-
|
| 65 |
-
# Init trainer
|
| 66 |
trainer = Trainer(
|
| 67 |
-
accelerator=params[
|
| 68 |
-
devices=params[
|
| 69 |
-
default_root_dir=params[
|
| 70 |
)
|
| 71 |
|
| 72 |
-
# Prepare data module to test
|
| 73 |
data_module.prepare_data()
|
| 74 |
data_module.setup(stage="test")
|
| 75 |
-
|
| 76 |
-
# Test
|
| 77 |
trainer.test(model, datamodule=data_module)
|
| 78 |
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
| 81 |
args = parser.parse_args([] if "__file__" not in globals() else None)
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# Experiment directory
|
| 88 |
if args.exp_dir:
|
| 89 |
exp_dir = args.exp_dir
|
| 90 |
else:
|
| 91 |
-
run_name = params[
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
params[
|
| 102 |
-
|
| 103 |
-
# Checkpoint path
|
| 104 |
if args.checkpoint_pth:
|
| 105 |
-
params[
|
| 106 |
-
|
| 107 |
-
if not params
|
| 108 |
-
print("No checkpoint
|
| 109 |
-
for f in os.listdir(exp_dir):
|
| 110 |
if f.endswith("ckpt") and f.startswith("epoch") and args.checkpoint_choice in f:
|
| 111 |
-
|
| 112 |
-
params['checkpoint_pth'] = checkpoint_path
|
| 113 |
break
|
| 114 |
-
assert
|
| 115 |
|
| 116 |
if args.external_test:
|
| 117 |
-
params[
|
| 118 |
else:
|
| 119 |
-
params[
|
| 120 |
-
|
| 121 |
if args.candidates_pth:
|
| 122 |
-
params[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
if args.df_test_pth:
|
| 124 |
-
params[
|
| 125 |
-
if not params
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
main(params)
|
|
|
|
| 1 |
import argparse
|
| 2 |
import datetime
|
| 3 |
+
import os
|
| 4 |
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _add_local_dependency_paths() -> None:
|
| 9 |
+
flare_repo_root = Path(__file__).resolve().parents[1]
|
| 10 |
+
candidate_roots = [flare_repo_root, flare_repo_root / "massspecgym"]
|
| 11 |
+
for env_var in ("MASSSPECGYM_ROOT",):
|
| 12 |
+
env_value = os.environ.get(env_var)
|
| 13 |
+
if env_value:
|
| 14 |
+
candidate_roots.append(Path(env_value).expanduser())
|
| 15 |
+
for root in candidate_roots:
|
| 16 |
+
if root.exists():
|
| 17 |
+
root_str = str(root)
|
| 18 |
+
if root_str not in sys.path:
|
| 19 |
+
sys.path.insert(0, root_str)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
_add_local_dependency_paths()
|
| 23 |
+
|
| 24 |
+
from functools import partial
|
| 25 |
|
|
|
|
| 26 |
import pytorch_lightning as pl
|
| 27 |
from pytorch_lightning import Trainer
|
| 28 |
+
from rdkit import RDLogger
|
| 29 |
from massspecgym.models.base import Stage
|
|
|
|
| 30 |
|
| 31 |
from flare.data.data_module import TestDataModule
|
| 32 |
from flare.data.datasets import ContrastiveDataset
|
| 33 |
+
from flare.definitions import DATA_DIR, TEST_RESULTS_DIR
|
| 34 |
+
from flare.utils.config import default_param_path, load_param_file
|
| 35 |
+
from flare.utils.data import get_mol_featurizer, get_spec_featurizer, get_test_ms_dataset
|
| 36 |
from flare.utils.models import get_model
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
lg = RDLogger.logger()
|
| 39 |
lg.setLevel(RDLogger.CRITICAL)
|
| 40 |
|
| 41 |
parser = argparse.ArgumentParser()
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--param_pth",
|
| 44 |
+
type=str,
|
| 45 |
+
default=None,
|
| 46 |
+
help="YAML hyperparameters (default: FLARE_PARAMS or repo params.yaml)",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument("--checkpoint_pth", type=str, default="")
|
| 49 |
+
parser.add_argument("--checkpoint_choice", type=str, default="train", choices=["train", "val"])
|
| 50 |
+
parser.add_argument("--df_test_pth", type=str, help="result file name under experiment_dir")
|
| 51 |
+
parser.add_argument("--exp_dir", type=str, help="experiment directory (overrides auto-detect)")
|
| 52 |
+
parser.add_argument("--candidates_pth", type=str, help="override candidates JSON path")
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--external_test",
|
| 55 |
+
action="store_true",
|
| 56 |
+
help="external data without ground-truth labels in the candidate list",
|
| 57 |
+
)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
def main(params):
|
| 61 |
+
pl.seed_everything(params["seed"])
|
|
|
|
| 62 |
|
| 63 |
+
if params.get("debug"):
|
| 64 |
+
dbg = os.environ.get("FLARE_DEBUG_DATASET")
|
| 65 |
+
if dbg:
|
| 66 |
+
params["dataset_pth"] = dbg
|
| 67 |
+
else:
|
| 68 |
+
sample_tsv = DATA_DIR / "sample" / "data.tsv"
|
| 69 |
+
if sample_tsv.is_file():
|
| 70 |
+
params["dataset_pth"] = str(sample_tsv)
|
| 71 |
+
params["split_pth"] = None
|
| 72 |
+
params["df_test_path"] = os.path.join(params["experiment_dir"], "debug_result.pkl")
|
| 73 |
|
| 74 |
+
spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
|
| 75 |
+
mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
|
| 76 |
+
dataset = get_test_ms_dataset(
|
| 77 |
+
params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params
|
| 78 |
+
)
|
| 79 |
|
| 80 |
+
collate_fn = partial(
|
| 81 |
+
ContrastiveDataset.collate_fn,
|
| 82 |
+
spec_enc=params["spec_enc"],
|
| 83 |
+
spectra_view=params["spectra_view"],
|
| 84 |
+
stage=Stage.TEST,
|
| 85 |
+
)
|
| 86 |
data_module = TestDataModule(
|
| 87 |
dataset=dataset,
|
| 88 |
collate_fn=collate_fn,
|
| 89 |
+
split_pth=params["split_pth"],
|
| 90 |
+
batch_size=params["batch_size"],
|
| 91 |
+
num_workers=params["num_workers"],
|
| 92 |
)
|
| 93 |
|
| 94 |
+
model = get_model(params["model"], params)
|
| 95 |
+
model.df_test_path = params["df_test_path"]
|
| 96 |
+
model.external_test = params["external_test"]
|
| 97 |
+
|
|
|
|
| 98 |
trainer = Trainer(
|
| 99 |
+
accelerator=params["accelerator"],
|
| 100 |
+
devices=params["devices"],
|
| 101 |
+
default_root_dir=params["experiment_dir"],
|
| 102 |
)
|
| 103 |
|
|
|
|
| 104 |
data_module.prepare_data()
|
| 105 |
data_module.setup(stage="test")
|
|
|
|
|
|
|
| 106 |
trainer.test(model, datamodule=data_module)
|
| 107 |
|
| 108 |
|
| 109 |
if __name__ == "__main__":
|
| 110 |
args = parser.parse_args([] if "__file__" not in globals() else None)
|
| 111 |
|
| 112 |
+
param_path = args.param_pth or str(default_param_path())
|
| 113 |
+
params = load_param_file(param_path)
|
| 114 |
+
|
| 115 |
+
exp_dir = None
|
|
|
|
| 116 |
if args.exp_dir:
|
| 117 |
exp_dir = args.exp_dir
|
| 118 |
else:
|
| 119 |
+
run_name = params["run_name"]
|
| 120 |
+
if TEST_RESULTS_DIR.is_dir():
|
| 121 |
+
for exp in sorted(os.listdir(TEST_RESULTS_DIR), reverse=True):
|
| 122 |
+
if exp.endswith("_" + run_name):
|
| 123 |
+
exp_dir = str(TEST_RESULTS_DIR / exp)
|
| 124 |
+
break
|
| 125 |
+
if exp_dir is None:
|
| 126 |
+
today_str = datetime.datetime.now().strftime("%Y%m%d")
|
| 127 |
+
exp_dir = str(TEST_RESULTS_DIR / f"{today_str}_{run_name}")
|
| 128 |
+
os.makedirs(exp_dir, exist_ok=True)
|
| 129 |
+
params["experiment_dir"] = exp_dir
|
| 130 |
+
|
|
|
|
| 131 |
if args.checkpoint_pth:
|
| 132 |
+
params["checkpoint_pth"] = args.checkpoint_pth
|
| 133 |
+
|
| 134 |
+
if not params.get("checkpoint_pth"):
|
| 135 |
+
print("No checkpoint in params; searching experiment_dir for a .ckpt file")
|
| 136 |
+
for f in sorted(os.listdir(exp_dir)):
|
| 137 |
if f.endswith("ckpt") and f.startswith("epoch") and args.checkpoint_choice in f:
|
| 138 |
+
params["checkpoint_pth"] = os.path.join(exp_dir, f)
|
|
|
|
| 139 |
break
|
| 140 |
+
assert params.get("checkpoint_pth"), "No checkpoint found; pass --checkpoint_pth"
|
| 141 |
|
| 142 |
if args.external_test:
|
| 143 |
+
params["external_test"] = True
|
| 144 |
else:
|
| 145 |
+
params["external_test"] = params.get("external_test", False)
|
| 146 |
+
|
| 147 |
if args.candidates_pth:
|
| 148 |
+
params["candidates_pth"] = args.candidates_pth
|
| 149 |
+
from flare.utils.config import resolve_repo_paths
|
| 150 |
+
|
| 151 |
+
resolve_repo_paths(params)
|
| 152 |
+
|
| 153 |
if args.df_test_pth:
|
| 154 |
+
params["df_test_path"] = os.path.join(exp_dir, args.df_test_pth)
|
| 155 |
+
if not params.get("df_test_path"):
|
| 156 |
+
cand = params.get("candidates_pth") or "candidates.json"
|
| 157 |
+
stem = Path(cand).stem
|
| 158 |
+
params["df_test_path"] = os.path.join(exp_dir, f"result_{stem}.pkl")
|
| 159 |
+
|
| 160 |
+
print("DF TEST PATH: ", params["df_test_path"])
|
| 161 |
+
print("EXP DIR: ", exp_dir)
|
| 162 |
+
|
| 163 |
main(params)
|
flare/train.py
CHANGED
|
@@ -1,137 +1,128 @@
|
|
| 1 |
import argparse
|
| 2 |
import datetime
|
| 3 |
-
|
| 4 |
import os
|
| 5 |
import sys
|
| 6 |
-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
import pytorch_lightning as pl
|
|
|
|
| 10 |
from pytorch_lightning import Trainer
|
| 11 |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
| 12 |
-
|
| 13 |
|
| 14 |
from flare.data.data_module import ContrastiveDataModule
|
| 15 |
-
|
| 16 |
-
from flare.definitions import TEST_RESULTS_DIR
|
| 17 |
-
import yaml
|
| 18 |
from flare.data.datasets import ContrastiveDataset
|
| 19 |
-
from
|
| 20 |
-
|
| 21 |
-
from flare.utils.data import get_ms_dataset,
|
| 22 |
from flare.utils.models import get_model
|
| 23 |
-
|
| 24 |
lg = RDLogger.logger()
|
| 25 |
lg.setLevel(RDLogger.CRITICAL)
|
| 26 |
|
| 27 |
parser = argparse.ArgumentParser()
|
| 28 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def main(params):
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
data_module = ContrastiveDataModule(
|
| 48 |
dataset=dataset,
|
| 49 |
collate_fn=collate_fn,
|
| 50 |
-
split_pth=params[
|
| 51 |
-
batch_size=params[
|
| 52 |
-
num_workers=params[
|
| 53 |
)
|
| 54 |
|
| 55 |
-
model = get_model(params[
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
dir=params['experiment_dir'],
|
| 64 |
-
log_dir=params['experiment_dir'],
|
| 65 |
-
name=params['run_name'],
|
| 66 |
-
project=params['project_name'],
|
| 67 |
-
log_model=False,
|
| 68 |
-
config=model.hparams
|
| 69 |
-
)
|
| 70 |
|
| 71 |
-
|
| 72 |
-
callbacks = [pl.callbacks.ModelCheckpoint(save_last=False) ]
|
| 73 |
for i, monitor in enumerate(model.get_checkpoint_monitors()):
|
| 74 |
-
monitor_name = monitor[
|
| 75 |
checkpoint = pl.callbacks.ModelCheckpoint(
|
| 76 |
monitor=monitor_name,
|
| 77 |
save_top_k=1,
|
| 78 |
-
mode=monitor[
|
| 79 |
-
dirpath=params[
|
| 80 |
-
filename=f
|
| 81 |
-
# filename='{epoch}-{val_loss:.2f}-{train_loss:.2f}',
|
| 82 |
auto_insert_metric_name=True,
|
| 83 |
-
# save_last=(i == 0)
|
| 84 |
)
|
| 85 |
callbacks.append(checkpoint)
|
| 86 |
-
if monitor.get(
|
| 87 |
early_stopping = EarlyStopping(
|
| 88 |
monitor=monitor_name,
|
| 89 |
-
mode=monitor[
|
| 90 |
verbose=True,
|
| 91 |
-
patience=params[
|
| 92 |
)
|
| 93 |
callbacks.append(early_stopping)
|
| 94 |
|
| 95 |
-
# Init trainer
|
| 96 |
trainer = Trainer(
|
| 97 |
-
accelerator=params[
|
| 98 |
-
devices=params[
|
| 99 |
-
max_epochs=params[
|
| 100 |
logger=logger,
|
| 101 |
-
log_every_n_steps=params[
|
| 102 |
-
val_check_interval=params[
|
| 103 |
callbacks=callbacks,
|
| 104 |
-
default_root_dir=params[
|
| 105 |
)
|
| 106 |
|
| 107 |
-
# Prepare data module to validate or test before training
|
| 108 |
data_module.prepare_data()
|
| 109 |
data_module.setup()
|
| 110 |
|
| 111 |
-
|
| 112 |
-
# Validate before training
|
| 113 |
trainer.validate(model, datamodule=data_module)
|
| 114 |
-
|
| 115 |
-
# Train
|
| 116 |
trainer.fit(model, datamodule=data_module)
|
| 117 |
-
|
| 118 |
|
| 119 |
|
| 120 |
if __name__ == "__main__":
|
| 121 |
args = parser.parse_args([] if "__file__" not in globals() else None)
|
| 122 |
|
| 123 |
-
|
|
|
|
|
|
|
| 124 |
now = datetime.datetime.now()
|
| 125 |
now_formatted = now.strftime("%Y%m%d")
|
| 126 |
-
|
| 127 |
-
# Load
|
| 128 |
-
with open(args.param_pth) as f:
|
| 129 |
-
params = yaml.load(f, Loader=yaml.FullLoader)
|
| 130 |
-
|
| 131 |
experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
|
| 132 |
-
params[
|
| 133 |
|
| 134 |
-
if not params
|
| 135 |
-
params[
|
| 136 |
|
| 137 |
main(params)
|
|
|
|
| 1 |
import argparse
|
| 2 |
import datetime
|
|
|
|
| 3 |
import os
|
| 4 |
import sys
|
|
|
|
| 5 |
|
| 6 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 7 |
+
|
| 8 |
+
from functools import partial
|
| 9 |
+
|
| 10 |
import pytorch_lightning as pl
|
| 11 |
+
import yaml
|
| 12 |
from pytorch_lightning import Trainer
|
| 13 |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
| 14 |
+
from rdkit import RDLogger
|
| 15 |
|
| 16 |
from flare.data.data_module import ContrastiveDataModule
|
|
|
|
|
|
|
|
|
|
| 17 |
from flare.data.datasets import ContrastiveDataset
|
| 18 |
+
from flare.definitions import DATA_DIR, TEST_RESULTS_DIR
|
| 19 |
+
from flare.utils.config import default_param_path, load_param_file
|
| 20 |
+
from flare.utils.data import get_ms_dataset, get_mol_featurizer, get_spec_featurizer
|
| 21 |
from flare.utils.models import get_model
|
| 22 |
+
|
| 23 |
lg = RDLogger.logger()
|
| 24 |
lg.setLevel(RDLogger.CRITICAL)
|
| 25 |
|
| 26 |
parser = argparse.ArgumentParser()
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--param_pth",
|
| 29 |
+
type=str,
|
| 30 |
+
default=None,
|
| 31 |
+
help="YAML hyperparameters (default: FLARE_PARAMS env or repo params.yaml)",
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
|
| 35 |
def main(params):
|
| 36 |
+
pl.seed_everything(params["seed"])
|
| 37 |
+
|
| 38 |
+
if params.get("debug"):
|
| 39 |
+
dbg = os.environ.get("FLARE_DEBUG_DATASET")
|
| 40 |
+
if dbg:
|
| 41 |
+
params["dataset_pth"] = dbg
|
| 42 |
+
else:
|
| 43 |
+
sample_tsv = DATA_DIR / "sample" / "data.tsv"
|
| 44 |
+
if sample_tsv.is_file():
|
| 45 |
+
params["dataset_pth"] = str(sample_tsv)
|
| 46 |
+
params["candidates_pth"] = None
|
| 47 |
+
params["split_pth"] = None
|
| 48 |
+
|
| 49 |
+
spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
|
| 50 |
+
mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
|
| 51 |
+
dataset = get_ms_dataset(
|
| 52 |
+
params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
collate_fn = partial(
|
| 56 |
+
ContrastiveDataset.collate_fn, spec_enc=params["spec_enc"], spectra_view=params["spectra_view"]
|
| 57 |
+
)
|
| 58 |
data_module = ContrastiveDataModule(
|
| 59 |
dataset=dataset,
|
| 60 |
collate_fn=collate_fn,
|
| 61 |
+
split_pth=params["split_pth"],
|
| 62 |
+
batch_size=params["batch_size"],
|
| 63 |
+
num_workers=params["num_workers"],
|
| 64 |
)
|
| 65 |
|
| 66 |
+
model = get_model(params["model"], params)
|
| 67 |
+
|
| 68 |
+
tb_logger = pl.loggers.TensorBoardLogger(
|
| 69 |
+
save_dir=params["experiment_dir"],
|
| 70 |
+
name="",
|
| 71 |
+
version="",
|
| 72 |
+
)
|
| 73 |
+
logger = tb_logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
|
|
|
|
| 76 |
for i, monitor in enumerate(model.get_checkpoint_monitors()):
|
| 77 |
+
monitor_name = monitor["monitor"]
|
| 78 |
checkpoint = pl.callbacks.ModelCheckpoint(
|
| 79 |
monitor=monitor_name,
|
| 80 |
save_top_k=1,
|
| 81 |
+
mode=monitor["mode"],
|
| 82 |
+
dirpath=params["experiment_dir"],
|
| 83 |
+
filename=f"{{epoch}}-{{{monitor_name}:.2f}}",
|
|
|
|
| 84 |
auto_insert_metric_name=True,
|
|
|
|
| 85 |
)
|
| 86 |
callbacks.append(checkpoint)
|
| 87 |
+
if monitor.get("early_stopping", False):
|
| 88 |
early_stopping = EarlyStopping(
|
| 89 |
monitor=monitor_name,
|
| 90 |
+
mode=monitor["mode"],
|
| 91 |
verbose=True,
|
| 92 |
+
patience=params["early_stopping_patience"],
|
| 93 |
)
|
| 94 |
callbacks.append(early_stopping)
|
| 95 |
|
|
|
|
| 96 |
trainer = Trainer(
|
| 97 |
+
accelerator=params["accelerator"],
|
| 98 |
+
devices=params["devices"],
|
| 99 |
+
max_epochs=params["max_epochs"],
|
| 100 |
logger=logger,
|
| 101 |
+
log_every_n_steps=params["log_every_n_steps"],
|
| 102 |
+
val_check_interval=params["val_check_interval"],
|
| 103 |
callbacks=callbacks,
|
| 104 |
+
default_root_dir=params["experiment_dir"],
|
| 105 |
)
|
| 106 |
|
|
|
|
| 107 |
data_module.prepare_data()
|
| 108 |
data_module.setup()
|
| 109 |
|
|
|
|
|
|
|
| 110 |
trainer.validate(model, datamodule=data_module)
|
|
|
|
|
|
|
| 111 |
trainer.fit(model, datamodule=data_module)
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
if __name__ == "__main__":
|
| 115 |
args = parser.parse_args([] if "__file__" not in globals() else None)
|
| 116 |
|
| 117 |
+
param_path = args.param_pth or str(default_param_path())
|
| 118 |
+
params = load_param_file(param_path)
|
| 119 |
+
|
| 120 |
now = datetime.datetime.now()
|
| 121 |
now_formatted = now.strftime("%Y%m%d")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
|
| 123 |
+
params["experiment_dir"] = experiment_dir
|
| 124 |
|
| 125 |
+
if not params.get("df_test_path"):
|
| 126 |
+
params["df_test_path"] = os.path.join(experiment_dir, "result.pkl")
|
| 127 |
|
| 128 |
main(params)
|
flare/tune.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
import datetime
|
| 3 |
import os
|
| 4 |
import sys
|
|
@@ -20,6 +21,7 @@ from flare.data.datasets import ContrastiveDataset
|
|
| 20 |
from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
|
| 21 |
from flare.utils.models import get_model
|
| 22 |
from flare.definitions import TEST_RESULTS_DIR
|
|
|
|
| 23 |
from functools import partial
|
| 24 |
from rdkit import RDLogger
|
| 25 |
from massspecgym.models.base import Stage
|
|
@@ -29,7 +31,12 @@ lg = RDLogger.logger()
|
|
| 29 |
lg.setLevel(RDLogger.CRITICAL)
|
| 30 |
|
| 31 |
parser = argparse.ArgumentParser()
|
| 32 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
parser.add_argument("--n_trials", type=int, default=20)
|
| 34 |
|
| 35 |
class EpochLossTracker(Callback):
|
|
@@ -112,7 +119,7 @@ def save_trial_result(base_dir, trial, params, duration):
|
|
| 112 |
|
| 113 |
def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
|
| 114 |
start_time = time.time()
|
| 115 |
-
params =
|
| 116 |
|
| 117 |
try:
|
| 118 |
# Training-related params
|
|
@@ -160,8 +167,6 @@ def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_tri
|
|
| 160 |
ContrastiveDataset.collate_fn,
|
| 161 |
spec_enc=params["spec_enc"],
|
| 162 |
spectra_view=params["spectra_view"],
|
| 163 |
-
mask_peak_ratio=params["mask_peak_ratio"],
|
| 164 |
-
aug_cands=params["aug_cands"],
|
| 165 |
)
|
| 166 |
|
| 167 |
data_module = ContrastiveDataModule(
|
|
@@ -226,12 +231,11 @@ def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_tri
|
|
| 226 |
|
| 227 |
|
| 228 |
def main(args):
|
| 229 |
-
|
| 230 |
-
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
base_dir = "../experiments/20250916_simple_model_optuna"
|
| 235 |
os.makedirs(base_dir, exist_ok=True)
|
| 236 |
params["experiment_dir"] = base_dir
|
| 237 |
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import copy
|
| 3 |
import datetime
|
| 4 |
import os
|
| 5 |
import sys
|
|
|
|
| 21 |
from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
|
| 22 |
from flare.utils.models import get_model
|
| 23 |
from flare.definitions import TEST_RESULTS_DIR
|
| 24 |
+
from flare.utils.config import default_param_path, load_param_file
|
| 25 |
from functools import partial
|
| 26 |
from rdkit import RDLogger
|
| 27 |
from massspecgym.models.base import Stage
|
|
|
|
| 31 |
lg.setLevel(RDLogger.CRITICAL)
|
| 32 |
|
| 33 |
parser = argparse.ArgumentParser()
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--param_pth",
|
| 36 |
+
type=str,
|
| 37 |
+
default=None,
|
| 38 |
+
help="Base YAML (default: FLARE_PARAMS or repo params.yaml)",
|
| 39 |
+
)
|
| 40 |
parser.add_argument("--n_trials", type=int, default=20)
|
| 41 |
|
| 42 |
class EpochLossTracker(Callback):
|
|
|
|
| 119 |
|
| 120 |
def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
|
| 121 |
start_time = time.time()
|
| 122 |
+
params = copy.deepcopy(base_params)
|
| 123 |
|
| 124 |
try:
|
| 125 |
# Training-related params
|
|
|
|
| 167 |
ContrastiveDataset.collate_fn,
|
| 168 |
spec_enc=params["spec_enc"],
|
| 169 |
spectra_view=params["spectra_view"],
|
|
|
|
|
|
|
| 170 |
)
|
| 171 |
|
| 172 |
data_module = ContrastiveDataModule(
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
def main(args):
|
| 234 |
+
param_path = args.param_pth or str(default_param_path())
|
| 235 |
+
params = load_param_file(param_path)
|
| 236 |
|
| 237 |
+
now = datetime.datetime.now().strftime("%Y%m%d")
|
| 238 |
+
base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
|
|
|
|
| 239 |
os.makedirs(base_dir, exist_ok=True)
|
| 240 |
params["experiment_dir"] = base_dir
|
| 241 |
|
flare/utils/__init__.py
CHANGED
|
@@ -1,3 +1 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
|
| 3 |
from massspecgym.utils import *
|
|
|
|
|
|
|
|
|
|
| 1 |
from massspecgym.utils import *
|
flare/utils/config.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load YAML hyperparameters and resolve filesystem paths relative to the repository root."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from flare.definitions import REPO_DIR
|
| 11 |
+
|
| 12 |
+
# Keys that may hold filesystem paths (relative paths are resolved against REPO_DIR).
|
| 13 |
+
_PATH_KEYS = frozenset(
|
| 14 |
+
{
|
| 15 |
+
"dataset_pth",
|
| 16 |
+
"candidates_pth",
|
| 17 |
+
"subformula_dir_pth",
|
| 18 |
+
"split_pth",
|
| 19 |
+
"checkpoint_pth",
|
| 20 |
+
"df_test_path",
|
| 21 |
+
"formula_to_smiles_pth",
|
| 22 |
+
}
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def resolve_repo_paths(params: dict[str, Any]) -> None:
|
| 27 |
+
"""In-place: turn repo-relative path strings into absolute paths."""
|
| 28 |
+
root = REPO_DIR
|
| 29 |
+
for key in _PATH_KEYS:
|
| 30 |
+
val = params.get(key)
|
| 31 |
+
if not val or not isinstance(val, str):
|
| 32 |
+
continue
|
| 33 |
+
p = Path(val)
|
| 34 |
+
if not p.is_absolute():
|
| 35 |
+
params[key] = str((root / p).resolve())
|
| 36 |
+
else:
|
| 37 |
+
params[key] = str(p.resolve())
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_param_file(path: str | Path) -> dict[str, Any]:
|
| 41 |
+
"""Load a YAML parameter file and resolve path fields."""
|
| 42 |
+
p = Path(path)
|
| 43 |
+
if not p.is_file():
|
| 44 |
+
raise FileNotFoundError(f"Parameter file not found: {p}")
|
| 45 |
+
with open(p, encoding="utf-8") as f:
|
| 46 |
+
params = yaml.load(f, Loader=yaml.FullLoader)
|
| 47 |
+
if params is None:
|
| 48 |
+
params = {}
|
| 49 |
+
if not isinstance(params, dict):
|
| 50 |
+
raise TypeError(f"Expected mapping at top level of {p}, got {type(params)}")
|
| 51 |
+
resolve_repo_paths(params)
|
| 52 |
+
return params
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def default_param_path() -> Path:
|
| 56 |
+
"""Path to the default params file (overridable with FLARE_PARAMS)."""
|
| 57 |
+
override = os.environ.get("FLARE_PARAMS")
|
| 58 |
+
if override:
|
| 59 |
+
return Path(override).expanduser()
|
| 60 |
+
env_root = os.environ.get("FLARE_REPO_ROOT")
|
| 61 |
+
if env_root:
|
| 62 |
+
return Path(env_root).expanduser() / "params.yaml"
|
| 63 |
+
return REPO_DIR / "params.yaml"
|
flare/utils/data.py
CHANGED
|
@@ -21,6 +21,11 @@ class Subformula_Loader:
|
|
| 21 |
self.dir_path = dir_path
|
| 22 |
self.use_prec_mz = use_prec_mz
|
| 23 |
self.formula_source = formula_source
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
if spectra_view == 'SpecFormula':
|
| 25 |
self.load = self.load_subformula_data
|
| 26 |
elif spectra_view == "SpecFormulaMz":
|
|
@@ -63,77 +68,6 @@ class Subformula_Loader:
|
|
| 63 |
except:
|
| 64 |
return None
|
| 65 |
|
| 66 |
-
def load_magma_data(self, data, curr_form, curr_prec_mz):
|
| 67 |
-
|
| 68 |
-
np.random.seed(42)
|
| 69 |
-
|
| 70 |
-
formula_to_intensity = {}
|
| 71 |
-
formula_to_mz = {}
|
| 72 |
-
|
| 73 |
-
# data is None
|
| 74 |
-
if data is None:
|
| 75 |
-
if self.use_prec_mz:
|
| 76 |
-
return {'formulas': [curr_form], 'formula_mzs': [curr_prec_mz], 'formula_intensities': [PRECURSOR_INTENSITY]}
|
| 77 |
-
else:
|
| 78 |
-
return {'formulas': [], 'formula_mzs': [], 'formula_intensities': []}
|
| 79 |
-
|
| 80 |
-
# randomly choose 1 formula for each peak, keep largest intensity for each formula
|
| 81 |
-
if self.formula_source.endswith('1'):
|
| 82 |
-
for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):
|
| 83 |
-
|
| 84 |
-
if not f:
|
| 85 |
-
continue
|
| 86 |
-
selected_f = np.random.choice(f)
|
| 87 |
-
if selected_f in formula_to_intensity:
|
| 88 |
-
if i > formula_to_intensity[selected_f]:
|
| 89 |
-
formula_to_intensity[selected_f] = i
|
| 90 |
-
formula_to_mz[selected_f] = m
|
| 91 |
-
else:
|
| 92 |
-
formula_to_intensity[selected_f] = i
|
| 93 |
-
formula_to_mz[selected_f] = m
|
| 94 |
-
|
| 95 |
-
# take all formulas, divide intensity by number of formulas, keep largest intensity for each formula
|
| 96 |
-
elif self.formula_source.endswith('all'):
|
| 97 |
-
for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):
|
| 98 |
-
|
| 99 |
-
if not f:
|
| 100 |
-
continue
|
| 101 |
-
for fi in f:
|
| 102 |
-
if fi in formula_to_intensity:
|
| 103 |
-
if i/len(f) > formula_to_intensity[fi]:
|
| 104 |
-
formula_to_intensity[fi] = i/len(f)
|
| 105 |
-
formula_to_mz[fi] = m
|
| 106 |
-
else:
|
| 107 |
-
formula_to_intensity[fi] = i/len(f)
|
| 108 |
-
formula_to_mz[fi] = m
|
| 109 |
-
else:
|
| 110 |
-
raise Exception(f"Formula source not supported: {self.formula_source}")
|
| 111 |
-
|
| 112 |
-
mzs = list(formula_to_mz.values())
|
| 113 |
-
formulas = list(formula_to_mz.keys())
|
| 114 |
-
intensities = list(formula_to_intensity.values())
|
| 115 |
-
|
| 116 |
-
# add precursor mz
|
| 117 |
-
if self.use_prec_mz:
|
| 118 |
-
if curr_form in formulas:
|
| 119 |
-
intensities[formulas.index(curr_form)] = PRECURSOR_INTENSITY
|
| 120 |
-
else:
|
| 121 |
-
formulas.append(curr_form)
|
| 122 |
-
intensities.append(PRECURSOR_INTENSITY)
|
| 123 |
-
mzs.append(curr_prec_mz)
|
| 124 |
-
|
| 125 |
-
# sort by mzs
|
| 126 |
-
mzs = np.array(mzs)
|
| 127 |
-
formulas = np.array(formulas)
|
| 128 |
-
intensities = np.array(intensities)
|
| 129 |
-
|
| 130 |
-
ind = mzs.argsort()
|
| 131 |
-
mzs = mzs[ind]
|
| 132 |
-
formulas = formulas[ind]
|
| 133 |
-
intensities = intensities[ind]
|
| 134 |
-
|
| 135 |
-
return {'formulas': formulas, 'formula_mzs': mzs, 'formula_intensities': intensities}
|
| 136 |
-
|
| 137 |
def load_sirius_data(self, data):
|
| 138 |
try:
|
| 139 |
|
|
@@ -167,8 +101,6 @@ class Subformula_Loader:
|
|
| 167 |
data = json.load(f)
|
| 168 |
if self.formula_source == 'sirius':
|
| 169 |
return self.load_sirius_data(data)
|
| 170 |
-
elif self.formula_source.startswith('magma'):
|
| 171 |
-
return self.load_magma_data(data, curr_form, curr_prec_mz)
|
| 172 |
else:
|
| 173 |
return self.load_mist_data(data, curr_form, curr_prec_mz)
|
| 174 |
|
|
@@ -263,17 +195,18 @@ def get_test_ms_dataset(spectra_view: T.Union[str, T.List[str]],
|
|
| 263 |
else: views.extend(v)
|
| 264 |
views = frozenset(views)
|
| 265 |
|
| 266 |
-
dataset_params = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
if "SpecFormula" in views or "SpecFormulaMz" in views:
|
| 268 |
-
dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], '
|
| 269 |
-
use_formulas = True
|
| 270 |
-
|
| 271 |
-
# if params['use_cons_spec']:
|
| 272 |
-
# dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
|
| 273 |
-
# if 'use_NL_spec' in params and params['use_NL_spec']:
|
| 274 |
-
# dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
|
| 275 |
-
# if params['pred_fp'] or params['use_fp']:
|
| 276 |
-
# dataset_params.update({'fp_dir_pth': '', 'fp_size': params['fp_size'], 'fp_radius': params['fp_radius']})
|
| 277 |
|
| 278 |
return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
|
| 279 |
|
|
|
|
| 21 |
self.dir_path = dir_path
|
| 22 |
self.use_prec_mz = use_prec_mz
|
| 23 |
self.formula_source = formula_source
|
| 24 |
+
if str(formula_source).startswith('magma'):
|
| 25 |
+
raise ValueError(
|
| 26 |
+
"MAGMA formula sources are not supported in this release (see archive/magma/). "
|
| 27 |
+
"Use 'default' (MIST-style JSON) or 'sirius'."
|
| 28 |
+
)
|
| 29 |
if spectra_view == 'SpecFormula':
|
| 30 |
self.load = self.load_subformula_data
|
| 31 |
elif spectra_view == "SpecFormulaMz":
|
|
|
|
| 68 |
except:
|
| 69 |
return None
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def load_sirius_data(self, data):
|
| 72 |
try:
|
| 73 |
|
|
|
|
| 101 |
data = json.load(f)
|
| 102 |
if self.formula_source == 'sirius':
|
| 103 |
return self.load_sirius_data(data)
|
|
|
|
|
|
|
| 104 |
else:
|
| 105 |
return self.load_mist_data(data, curr_form, curr_prec_mz)
|
| 106 |
|
|
|
|
| 195 |
else: views.extend(v)
|
| 196 |
views = frozenset(views)
|
| 197 |
|
| 198 |
+
dataset_params = {
|
| 199 |
+
'spectra_view': spectra_view,
|
| 200 |
+
'pth': params['dataset_pth'],
|
| 201 |
+
'spec_transform': spectra_featurizer,
|
| 202 |
+
'mol_transform': mol_featurizer,
|
| 203 |
+
"candidates_pth": params.get('candidates_pth'),
|
| 204 |
+
"formula_to_smiles_pth": params.get('formula_to_smiles_pth'),
|
| 205 |
+
"external_test": params.get('external_test', False)
|
| 206 |
+
}
|
| 207 |
if "SpecFormula" in views or "SpecFormulaMz" in views:
|
| 208 |
+
dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'formula_source': params['formula_source']})
|
| 209 |
+
use_formulas = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
|
| 212 |
|
flare/utils/loss.py
CHANGED
|
@@ -43,40 +43,6 @@ def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
|
|
| 43 |
|
| 44 |
return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
|
| 45 |
|
| 46 |
-
def cand_spec_sim_loss(spec_enc, cand_enc):
|
| 47 |
-
cand_enc = torch.transpose(cand_enc, 0, 1) # C x B x d
|
| 48 |
-
spec_enc = spec_enc.unsqueeze(0) # 1 x B x d
|
| 49 |
-
|
| 50 |
-
sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)
|
| 51 |
-
loss = torch.mean(sim)
|
| 52 |
-
|
| 53 |
-
return loss
|
| 54 |
-
|
| 55 |
-
class cons_spec_loss:
|
| 56 |
-
def __init__(self, loss_type) -> None:
|
| 57 |
-
self.loss_compute = {'cosine': self.cos_loss,
|
| 58 |
-
'l2':torch.nn.MSELoss()}[loss_type]
|
| 59 |
-
def __call__(self,cons_spec, ind_spec):
|
| 60 |
-
return self.loss_compute(cons_spec, ind_spec)
|
| 61 |
-
|
| 62 |
-
def cos_loss(self, cons_spec, ind_spec):
|
| 63 |
-
sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
|
| 64 |
-
loss = 1-torch.mean(sim)
|
| 65 |
-
return loss
|
| 66 |
-
|
| 67 |
-
class fp_loss:
|
| 68 |
-
def __init__(self, loss_type) -> None:
|
| 69 |
-
self.loss_compute = {'cosine': self.fp_loss_cos,
|
| 70 |
-
'bce': nn.BCELoss()}[loss_type]
|
| 71 |
-
|
| 72 |
-
def __call__(self, predicted_fp, target_fp):
|
| 73 |
-
return self.loss_compute(predicted_fp, target_fp)
|
| 74 |
-
|
| 75 |
-
def fp_loss_cos(self, predicted_fp, target_fp):
|
| 76 |
-
sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
|
| 77 |
-
return 1 - torch.mean(sim)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
# ---------- Utility ----------
|
| 81 |
def _safe_divide(num, denom, eps=1e-8):
|
| 82 |
return num / (denom + eps)
|
|
|
|
| 43 |
|
| 44 |
return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
# ---------- Utility ----------
|
| 47 |
def _safe_divide(num, denom, eps=1e-8):
|
| 48 |
return num / (denom + eps)
|
flare/utils/models.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
|
| 2 |
from flare.models.mol_encoder import MolEnc
|
| 3 |
-
from flare.models.encoders import MLP
|
| 4 |
from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive, FilipGlobalContrastive
|
| 5 |
|
| 6 |
def get_spec_encoder(spec_enc:str, args):
|
|
@@ -13,12 +12,6 @@ def get_spec_encoder(spec_enc:str, args):
|
|
| 13 |
def get_mol_encoder(mol_enc: str, args):
|
| 14 |
return {'GNN': MolEnc}[mol_enc](args, in_dim=78)
|
| 15 |
|
| 16 |
-
def get_fp_pred_model(args):
|
| 17 |
-
return MLP(in_dim=args.final_embedding_dim, hidden_dims=[args.fp_size], final_activation='sigmoid', dropout=args.fp_dropout)
|
| 18 |
-
|
| 19 |
-
def get_fp_enc_model(args):
|
| 20 |
-
return MLP(in_dim=args.fp_size, hidden_dims=[args.final_embedding_dim,args.final_embedding_dim*2,args.final_embedding_dim,], final_activation=None, dropout=0.0)
|
| 21 |
-
|
| 22 |
def get_model(model:str,
|
| 23 |
params):
|
| 24 |
|
|
|
|
| 1 |
from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
|
| 2 |
from flare.models.mol_encoder import MolEnc
|
|
|
|
| 3 |
from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive, FilipGlobalContrastive
|
| 4 |
|
| 5 |
def get_spec_encoder(spec_enc:str, args):
|
|
|
|
| 12 |
def get_mol_encoder(mol_enc: str, args):
|
| 13 |
return {'GNN': MolEnc}[mol_enc](args, in_dim=78)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def get_model(model:str,
|
| 16 |
params):
|
| 17 |
|
flare/utils/mol_search.py
CHANGED
|
@@ -311,57 +311,7 @@ class SpectraMoleculeRetriever:
|
|
| 311 |
|
| 312 |
|
| 313 |
if __name__ == "__main__":
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
from flare.utils.data import get_spec_featurizer, get_mol_featurizer
|
| 318 |
-
from flare.utils.models import get_model
|
| 319 |
-
from flare.utils.mol_search import SpectraMoleculeRetriever
|
| 320 |
-
from flare.utils.general import filip_similarity_single
|
| 321 |
-
import yaml
|
| 322 |
-
|
| 323 |
-
metadata = {
|
| 324 |
-
"class": {
|
| 325 |
-
"lipid": ["mol1", "mol2"],
|
| 326 |
-
"peptide": ["mol3"]
|
| 327 |
-
},
|
| 328 |
-
"pathway": {
|
| 329 |
-
"beta-oxidation": ["mol1"],
|
| 330 |
-
"glycolysis": ["mol2", "mol3"]
|
| 331 |
-
}
|
| 332 |
-
}
|
| 333 |
-
|
| 334 |
-
smiles_dict = {
|
| 335 |
-
"mol1": "CCO",
|
| 336 |
-
"mol2": "CCN",
|
| 337 |
-
"mol3": "CCC"
|
| 338 |
-
}
|
| 339 |
-
|
| 340 |
-
# Load model and data
|
| 341 |
-
param_pth = '/data/yzhouc01/cancer/flare.yaml'
|
| 342 |
-
with open(param_pth) as f:
|
| 343 |
-
params = yaml.load(f, Loader=yaml.FullLoader)
|
| 344 |
-
|
| 345 |
-
spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
|
| 346 |
-
mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
# load model
|
| 350 |
-
checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
|
| 351 |
-
params['checkpoint_pth'] = checkpoint_pth
|
| 352 |
-
model = get_model(params['model'], params)
|
| 353 |
-
|
| 354 |
-
specMolRetriever = SpectraMoleculeRetriever(
|
| 355 |
-
molecule_encoder=model.mol_enc_model,
|
| 356 |
-
spectra_encoder=model.spec_enc_model,
|
| 357 |
-
fine_similarity_fn=filip_similarity_single,
|
| 358 |
-
smiles_preprocess=mol_featurizer
|
| 359 |
)
|
| 360 |
-
|
| 361 |
-
specMolRetriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)
|
| 362 |
-
|
| 363 |
-
# Filter search to molecules in a specific pathway
|
| 364 |
-
# results = specMolRetriever.search(spectrum, subset={"pathway": "beta-oxidation"})
|
| 365 |
-
|
| 366 |
-
# for mol_id, score in results[:10]:
|
| 367 |
-
# print(f"{mol_id}: {score:.3f}")
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
if __name__ == "__main__":
|
| 314 |
+
raise SystemExit(
|
| 315 |
+
"SpectraMoleculeRetriever is a library class; configure paths via FLARE_PARAMS / "
|
| 316 |
+
"FLARE_CHECKPOINT and import it from application code (see README)."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|