Commit 6c3d8a1 by yzhouchen001 · 1 parent: f4a27d9
README.md CHANGED
@@ -8,55 +8,158 @@ pinned: false
8
  python_version: 3.11.7
9
  ---
10
 
11
- # 🔥 FLARE
12
- Fine-grained Learning for Aligment of spectra-molecule REpresentation
13
 
14
- ### Authors
15
  **Yan Zhou Chen, Soha Hassoun**
16
- Department of Computer Science, Tufts University
17
 
18
  ---
19
 
20
- FLARE is a framework for **ranking molecular candidates given a mass spectrum**. Beyond candidate ranking, FLARE provides **visualization of peak-to-node attribution**, enabling deeper insights into how spectral peaks correspond to molecular graph nodes.
21
 
22
  ---
23
 
24
- ## 🌐 Visualize Peak-to-Node Correspondence
25
- Explore our interactive [app](https://huggingface.co/spaces/HassounLab/FLARE) to visualize peak-to-node attributes in real time.
26
 
27
  ---
28
- ## 🛠 Set up
29
- ### Clone repository
30
- ```
31
  git clone https://huggingface.co/spaces/HassounLab/FLARE
32
- cd flare
33
- ```
34
- ### Set up environment and install dependencies
35
- ```
36
  conda create -n flare python=3.11
37
  conda activate flare
38
  pip install -r requirements.txt
39
  ```
40
  ---
41
- ## 🚀 Usage
42
- Modify params.yaml as necessary
43
 
44
  ```
45
- # preprocess data
46
- python subformula_assign/assign_subformulae.py --spec-files ../data/sample/data.tsv --output-dir ../data/sample/subformulae --labels-file ../data/sample/data.tsv --max-formulae 60
47
 
48
- # train
49
- python train.py
 
 
 
50
 
51
- # test
52
- python test.py
53
  ```
54
 
 
 
55
  ---
56
- ## 🙏 Acknowledgments
57
- - **Training Data**: [MassSpecGym](https://github.com/pluskal-lab/MassSpecGym)
58
- - **Subformula Assigner Code**: [MIST](https://github.com/samgoldman97/mist/tree/main_v2)
 
 
59
 
60
  ---
61
- ## 📧 Contact
62
- For questions, reach out to soha.hassoun@tufts.edu
 
 
 
8
  python_version: 3.11.7
9
  ---
10
 
11
+ # FLARE
12
+
13
+ **F**ine-grained **L**earning for **A**lignment of spectra–molecule **RE**presentations
14
+
15
+ ### Authors
16
 
 
17
  **Yan Zhou Chen, Soha Hassoun**
18
+ Department of Computer Science, Tufts University
19
+
20
+ ---
21
+
22
+ ## Overview
23
+
24
+ FLARE learns a joint embedding space for **MS/MS spectra** (represented as **per-peak chemical formulas** from a subformula assigner) and **molecular graphs**. The default publication model uses **FILIP-style contrastive learning** (`filipContrastive`): fine-grained similarity between spectrum tokens and graph nodes, with a temperature-scaled loss.
25
+
26
+ Use cases:
27
+
28
+ - **Retrieval**: rank a list of candidate SMILES for each query spectrum (MassSpecGym-style evaluation).
29
+ - **Interpretation**: the Streamlit app visualizes **peak-to-node** correspondence for a single spectrum–molecule pair.
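
The FILIP-style scoring described in the overview can be sketched in a few lines. This is a minimal illustration, not FLARE's implementation: `cosine` and `filip_scores` are hypothetical names, the embeddings are toy vectors, and masking of padded peaks is left out. Each peak takes its best-matching node (and each node its best-matching peak), and the per-token maxima are averaged. The per-peak argmax is one plausible way to read off a peak-to-node correspondence.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def filip_scores(peak_embs, node_embs):
    """Return (peak->node score, node->peak score, best node index per peak)."""
    # Fine-grained similarity matrix: one row per peak, one column per node.
    sim = [[cosine(p, n) for n in node_embs] for p in peak_embs]
    # Each peak keeps its best node; average over peaks.
    peak_to_node = sum(max(row) for row in sim) / len(sim)
    # Each node keeps its best peak; average over nodes.
    node_to_peak = sum(max(col) for col in zip(*sim)) / len(node_embs)
    best_node = [row.index(max(row)) for row in sim]
    return peak_to_node, node_to_peak, best_node


# Toy example: two peak embeddings, three node embeddings (made-up values).
peaks = [[1.0, 0.0], [0.0, 1.0]]
nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
p2n, n2p, best = filip_scores(peaks, nodes)
print(p2n, n2p, best)
```

In a contrastive setup, scores like these are typically divided by a temperature (here that would be `contr_temp` from `params.yaml`) before the softmax over the batch.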
30
 
31
  ---
32
 
33
+ ## Model (default stack)
34
+
35
+ | Component | Setting (see `params.yaml`) |
36
+ |-----------|----------------------------|
37
+ | Spectrum input | `SpecFormula` — formula peaks from JSON in `subformula_dir_pth` |
38
+ | Formula source | `default` — MIST-compatible JSON (`load_mist_data`); optional `sirius` |
39
+ | Spectrum encoder | `Transformer_Formula` |
40
+ | Molecule encoder | `GNN` (DGL + dgllife GCN), node embeddings for FILIP |
41
+ | Training objective | `filipContrastive` — masked FILIP loss, temperature `contr_temp` |
42
+ | Output | Embeddings for cosine / FILIP similarity at test time |
43
+
44
+ Hyperparameters are split into: **run/logging**, **training loop**, **data paths**, **featurizers**, **encoder widths/depths**, and **evaluation** (`at_ks`, `myopic_mces_kwargs`). Only keys present in `params.yaml` are required; paths can be **relative to the repository root** (recommended) or absolute.
45
 
46
  ---
47
 
48
+ ## Repository layout
49
+
50
+ | Path | Role |
51
+ |------|------|
52
+ | `params.yaml` | Canonical training/testing/app hyperparameters |
53
+ | `hparams.yaml` | Symlink to `params.yaml` (Hugging Face Spaces convention) |
54
+ | `flare/` | Training (`train.py`, `test.py`, `tune.py`), models, data pipeline |
55
+ | `massspecgym/` | Vendored MassSpecGym Lightning base classes and utilities |
56
+ | `app.py`, `app_utils/` | Streamlit peak–node visualization |
57
+ | `pretrained_models/` | Place public checkpoints here (e.g. `flare.ckpt`) |
58
+ | `experiments/` | Default output root for new runs (see `flare/definitions.py`) |
59
+ | `archive/` | Older scripts and features **not** part of the slim release (MAGMA, class experiments, legacy YAML, etc.); nothing was deleted |
60
 
61
  ---
62
+
63
+ ## Environment variables (no hardcoded machine paths)
64
+
65
+ | Variable | Purpose |
66
+ |----------|---------|
67
+ | `FLARE_PARAMS` | Path to YAML params (default: `<repo>/params.yaml`) |
68
+ | `FLARE_CHECKPOINT` | Checkpoint for the app or manual runs |
69
+ | `FLARE_DEBUG_DATASET` | When `debug: true`, TSV path for a tiny local dataset |
70
+ | `FLARE_REPO_ROOT` | Optional; overrides repo root for resolving relative paths in `default_param_path()` |
71
+ | `MASSSPECGYM_ROOT` | Optional extra `sys.path` root if you use an external `massspecgym` checkout |
72
+ | `FLARE_UPLOAD_CKPT`, `HF_REPO_ID`, `HF_REPO_TYPE`, `HF_TOKEN` | See `app_utils/upload_model.py` for HF uploads |
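
As a rough sketch of how the first rows of this table could interact, the snippet below resolves a params file from the environment. `resolve_params_path` is a hypothetical helper, not the repository's `default_param_path()`, and resolving a relative `FLARE_PARAMS` against the repo root is an assumption made for illustration.

```python
import os
from pathlib import Path


def resolve_params_path(repo_root, env=None):
    """Pick the params YAML: FLARE_PARAMS if set, else <repo>/params.yaml."""
    env = os.environ if env is None else env
    # FLARE_REPO_ROOT, when set, replaces the detected repository root.
    root = Path(env.get("FLARE_REPO_ROOT") or repo_root)
    override = env.get("FLARE_PARAMS")
    if override:
        # Assumption: relative overrides are resolved against the repo root.
        # Joining an absolute path onto root returns the absolute path as-is.
        return root / override
    return root / "params.yaml"


print(resolve_params_path("/repo", env={}))
print(resolve_params_path("/repo", env={"FLARE_REPO_ROOT": "/other"}))
print(resolve_params_path("/repo", env={"FLARE_PARAMS": "configs/custom.yaml"}))
```

Passing `env` explicitly keeps the helper easy to test without touching the real process environment.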
73
+
74
+ ---
75
+
76
+ ## Setup
77
+
78
+ ```bash
79
  git clone https://huggingface.co/spaces/HassounLab/FLARE
80
+ cd FLARE
81
+
 
 
82
  conda create -n flare python=3.11
83
  conda activate flare
84
  pip install -r requirements.txt
85
  ```
86
+
87
+ Place the **MassSpecGym** (or your own) spectrum TSV, **candidate JSON**, and **subformula JSON directory** wherever you like, then set their paths in `params.yaml` (relative paths such as `data/MassSpecGym.tsv` resolve from the repo root).
88
+
89
+ ---
90
+
91
+ ## Data preparation
92
+
93
+ Per-spectrum subformula JSON files (one file per spectrum id, MIST-style) are required for `SpecFormula`. Generate them with the bundled assigner (adapted from MIST):
94
+
95
+ ```bash
96
+ cd flare/subformula_assign
97
+ export SPEC_FILES=/path/to/spectra.tsv
98
+ export OUTPUT_DIR=/path/to/subformulae_out
99
+ export LABELS_FILE=/path/to/spectra.tsv # often same as SPEC_FILES
100
+ export MAX_FORMULAE=60
101
+ bash run.sh
102
+ ```
103
+
104
+ The defaults in `run.sh` point at `data/sample/` under the repository root, so they only work once you have added a small sample there.
105
+
106
+ ---
107
+
108
+ ## Training
109
+
110
+ From the repository root (so `flare` and `massspecgym` import correctly):
111
+
112
+ ```bash
113
+ cd flare
114
+ python train.py # uses FLARE_PARAMS or ../params.yaml
115
+ python train.py --param_pth /path/to/custom.yaml
116
+ ```
117
+
118
+ `train.py` creates `experiments/<YYYYMMDD>_<run_name>/`, writes TensorBoard logs there, and saves checkpoints. `df_test_path` defaults to `<experiment_dir>/result.pkl` if unset.
119
+
120
+ ---
121
+
122
+ ## Testing (retrieval)
123
+
124
+ ```bash
125
+ cd flare
126
+ python test.py \
127
+ --checkpoint_pth /path/to/epoch=....ckpt \
128
+ --exp_dir /path/to/experiment_dir # optional; else latest matching run_name
129
+ ```
130
+
131
+ Useful flags: `--candidates_pth`, `--df_test_pth`, `--external_test` (no positive label in the list). Override params file with `--param_pth` or `FLARE_PARAMS`.
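
To make the retrieval metric concrete, here is a small sketch of a hit-rate-at-k computation for one query spectrum. `hit_at_ks` is a hypothetical helper, and the cutoffs are placeholders standing in for whatever `at_ks` is set to in `params.yaml`; the repository's own evaluation code may differ.

```python
def hit_at_ks(scores, labels, ks=(1, 5, 20)):
    """For one spectrum: is the true molecule among the top-k ranked candidates?"""
    # Rank candidate indices by similarity score, best first.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked_labels = [labels[i] for i in order]
    return {k: any(ranked_labels[:k]) for k in ks}


# Three candidates; the true molecule (label True) has the lowest score.
print(hit_at_ks([0.2, 0.9, 0.5], [True, False, False], ks=(1, 3)))

# With an external test set there may be no positive label at all,
# so every cutoff reports a miss and only the ranking itself is useful.
print(hit_at_ks([0.2, 0.9, 0.5], [False, False, False], ks=(1, 3)))
```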
132
+
133
  ---
 
 
134
 
135
+ ## Hyperparameter search
136
+
137
+ ```bash
138
+ cd flare
139
+ python tune.py --n_trials 20
140
  ```
 
 
141
 
142
+ The search uses Optuna. The study database and logs live under `experiments/<date>_<run_name>_optuna/`, and the best configuration is written to `best_params.yaml` in that folder.
143
+
144
+ ---
145
+
146
+ ## Streamlit app (peak-to-node visualization)
147
 
148
+ ```bash
149
+ streamlit run app.py
150
  ```
151
 
152
+ The app loads architecture settings from `FLARE_PARAMS` (default `params.yaml`) and weights from `FLARE_CHECKPOINT` (default `pretrained_models/flare.ckpt`). Ensure the checkpoint matches the architecture in the YAML.
153
+
154
  ---
155
+
156
+ ## Acknowledgments
157
+
158
+ - **Data**: [MassSpecGym](https://github.com/pluskal-lab/MassSpecGym)
159
+ - **Subformula tooling**: [MIST](https://github.com/samgoldman97/mist/tree/main_v2)
160
 
161
  ---
162
+
163
+ ## Contact
164
+
165
+ For questions: soha.hassoun@tufts.edu
flare/data/datasets.py CHANGED
@@ -1,124 +1,30 @@
1
- import pandas as pd
2
  import json
3
  import typing as T
 
 
 
 
4
  import numpy as np
 
5
  import torch
 
6
  import massspecgym.utils as utils
7
- from pathlib import Path
8
- from torch.utils.data.dataset import Dataset
9
- from torch.utils.data.dataloader import default_collate
10
- import dgl
11
- from collections import defaultdict
12
- from massspecgym.data.transforms import SpecTransform, MolTransform, MolToInChIKey
13
  from massspecgym.data.datasets import MassSpecDataset
14
- import flare.utils.data as data_utils
15
- from torch.nn.utils.rnn import pad_sequence
16
  from massspecgym.models.base import Stage
17
- import pickle
18
- import math
19
- import itertools
20
- from rdkit.Chem import AllChem
21
- from rdkit import Chem
22
- from magma.run_magma import run_magma
23
- import matchms
24
 
25
  class JESTR1_MassSpecDataset(MassSpecDataset):
26
- def __init__(
27
- self,
28
- spectra_view: str,
29
- fp_dir_pth: str = None,
30
- cons_spec_dir_pth: str = None,
31
- NL_spec_dir_pth: str = None,
32
- **kwargs
33
- ):
34
- super().__init__(**kwargs)
35
 
36
- self.use_fp = False
37
- self.use_cons_spec = False
38
- self.use_NL_spec = False
39
  self.spectra_view = spectra_view
40
 
41
- # load fingerprints
42
- self._load_fp(fp_dir_pth)
43
-
44
- # load consensus
45
- self._load_cons_spec(cons_spec_dir_pth)
46
-
47
- # load NL specs
48
- self._load_NL_spec(NL_spec_dir_pth)
49
-
50
- def _load_fp(self, fp_dir_pth):
51
- if fp_dir_pth is not None:
52
- self.use_fp = True
53
- if fp_dir_pth:
54
- with open(fp_dir_pth, 'rb') as f:
55
- self.smiles_to_fp = pickle.load(f)
56
- else:
57
- self.smiles_to_fp = {}
58
-
59
- def _load_cons_spec(self, cons_spec_dir_pth):
60
- if cons_spec_dir_pth is not None:
61
- self.use_cons_spec = True
62
- with open(cons_spec_dir_pth, 'rb') as f:
63
- cons_specs = pickle.load(f)
64
-
65
- # Convert spectra to matchms spectra
66
- matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
67
- spectra = cons_specs.apply(matchMS_preparer.prepare,axis=1)
68
-
69
- self.cons_specs = dict(zip(cons_specs['smiles'].tolist(), spectra))
70
-
71
- def _load_NL_spec(self, NL_spec_dir_pth):
72
- if NL_spec_dir_pth is not None:
73
- self.use_NL_spec = True
74
- with open(NL_spec_dir_pth, 'rb') as f:
75
- NL_specs = pickle.load(f)
76
-
77
- # Convert spectra to matchms spectra
78
- matchMS_preparer = data_utils.PrepMatchMS(self.spectra_view)
79
- self.NL_specs = NL_specs.apply(matchMS_preparer.prepare,axis=1)
80
-
81
-
82
- def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
83
-
84
- spec = self.spectra[i]
85
- metadata = self.metadata.iloc[i]
86
- mol = metadata["smiles"] if 'smiles' in metadata else metadata["identifier"]
87
-
88
- # Apply all transformations to the spectrum
89
- item = {}
90
- if transform_spec and self.spec_transform:
91
- if isinstance(self.spec_transform, dict):
92
- for key, transform in self.spec_transform.items():
93
- item[key] = transform(spec) if transform is not None else spec
94
- else:
95
- item["spec"] = self.spec_transform(spec)
96
-
97
- if self.return_mol_freq:
98
- item["mol_freq"] = metadata["mol_freq"]
99
-
100
- if self.return_identifier:
101
- item["identifier"] = metadata["identifier"]
102
-
103
- if self.use_fp and self.smiles_to_fp:
104
- item['fp'] = torch.Tensor(self.smiles_to_fp[mol].ToList())
105
-
106
- if self.use_cons_spec:
107
- item['cons_spec'] = self.spec_transform[self.spectra_view](self.cons_specs[mol])
108
-
109
- if self.use_NL_spec:
110
- item['NL_spec'] = self.spec_transform[self.spectra_view](self.NL_specs[i])
111
-
112
- # Apply all transformations to the molecule
113
- if transform_mol and self.mol_transform:
114
- if isinstance(self.mol_transform, dict):
115
- for key, transform in self.mol_transform.items():
116
- item[key] = transform(mol) if transform is not None else mol
117
- else:
118
- item["mol"] = self.mol_transform(mol)
119
- else:
120
- item["mol"] = mol
121
- return item
122
 
123
  class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
124
  def __init__(
@@ -128,26 +34,16 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
128
  mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
129
  pth: T.Optional[Path],
130
  subformula_dir_pth: str,
131
- fp_dir_pth: str = None,
132
- NL_spec_dir_pth: str = None,
133
- cons_spec_dir_pth: str = None,
134
  return_mol_freq: bool = False,
135
  return_identifier: bool = True,
136
  dtype: T.Type = torch.float32,
137
- formula_source = 'default',
138
- stage: Stage = Stage.TRAIN
139
  ):
140
- """
141
- Args:
142
- """
143
  self.pth = pth
144
  self.spec_transform = spec_transform
145
  self.mol_transform = mol_transform
146
  self.return_mol_freq = return_mol_freq
147
- self.pred_fp = False
148
- self.use_fp = False
149
- self.use_cons_spec = False
150
- self.use_NL_spec = False
151
  self.spectra_view = spectra_view
152
  self.formula_source = formula_source
153
  self.subformula_dir_pth = subformula_dir_pth
@@ -155,31 +51,23 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
155
  if isinstance(self.pth, str):
156
  self.pth = Path(self.pth)
157
 
158
- self.spectra_view = spectra_view
159
  print("Data path: ", self.pth)
160
  self.metadata = pd.read_csv(self.pth, sep="\t")
161
 
162
- # load subformulas
163
  id_to_spec = self._load_id_to_spec(stage)
164
-
165
- # load fingerprints
166
- self._load_fp(fp_dir_pth)
167
-
168
- # load consensus spectra
169
- self._load_cons_spec(cons_spec_dir_pth)
170
 
171
- # load NL specs
172
- self._load_NL_spec(NL_spec_dir_pth)
173
 
174
- self.metadata = self.metadata[self.metadata['identifier'].isin(id_to_spec)]
175
 
176
- formula_df = pd.DataFrame.from_dict(id_to_spec, orient='index').reset_index().rename(columns={'index': 'identifier'})
177
- self.metadata = self.metadata.merge(formula_df, on='identifier')
178
-
179
- # create matchms spectra
180
  matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
181
- self.spectra = self.metadata.apply(matchMS_preparer.prepare,axis=1)
182
-
183
  if self.return_mol_freq:
184
  if "inchikey" not in self.metadata.columns:
185
  self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
@@ -187,108 +75,104 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
187
 
188
  self.return_identifier = return_identifier
189
  self.dtype = dtype
190
-
191
  def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
192
- item = super().__getitem__(i, transform_spec, transform_mol = False)
193
- mol = item['mol'] #smiles
194
 
195
- # transform mol
196
  if transform_mol:
197
  if isinstance(self.mol_transform, dict):
198
  for key, transform in self.mol_transform.items():
199
  item[key] = transform(mol) if transform is not None else mol
200
  else:
201
  item["mol"] = self.mol_transform(mol)
 
 
202
 
203
  return item
204
 
205
  def _load_id_to_spec(self, stage):
206
- # if stage == Stage.TRAIN:
207
- # self.metadata = self.metadata[self.metadata['fold'] != Stage.TEST.value]
208
- # else:
209
- # self.metadata = self.metadata[self.metadata['fold'] == Stage.TEST.value]
210
-
211
- all_spec_ids = self.metadata['identifier'].tolist()
212
- self.subformulaLoader = data_utils.Subformula_Loader(spectra_view=self.spectra_view, dir_path=self.subformula_dir_pth, formula_source=self.formula_source)
213
-
214
- form_list = self.metadata['formula'].tolist()
215
- prec_mz_list = self.metadata['precursor_mz'].tolist()
216
  id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)
217
 
218
- # create subformula spectra if no subformula is available
219
  tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
220
- tmp_df = self.metadata[self.metadata['identifier'].isin(tmp_ids)]
221
- tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
222
- id_to_spec.update(dict(zip(tmp_df['identifier'].tolist(), tmp_df['spec'].tolist())))
223
 
224
  return id_to_spec
225
 
 
226
  class ContrastiveDataset(Dataset):
227
- def __init__(
228
- self,
229
- spec_mol_data,
230
- ):
231
  super().__init__()
232
-
233
  indices = spec_mol_data.indices
234
  self.spec_mol_data = spec_mol_data
235
- self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby('smiles').indices
236
  self.smiles_to_spec_couter = defaultdict(int)
237
  self.smiles_list = list(self.smiles_to_specmol_ids.keys())
238
 
239
  def __len__(self) -> int:
240
  return len(self.smiles_list)
241
-
242
- def __getitem__(self, i:int) -> dict:
243
  mol = self.smiles_list[i]
244
 
245
- # select spectrum (iterate through list of spectra)
246
  specmol_ids = self.smiles_to_specmol_ids[mol]
247
  counter = self.smiles_to_spec_couter[mol]
248
  specmol_id = specmol_ids[counter % len(specmol_ids)]
249
 
250
  item = self.spec_mol_data.__getitem__(specmol_id)
251
- self.smiles_to_spec_couter[mol] = counter+1
252
- # item['smiles'] = mol
253
- # item['spec_id'] = specmol_id
254
  return item
255
 
256
  @staticmethod
257
- def collate_fn(batch: T.Iterable[dict], spec_enc: str, spectra_view: str, stage=None, batch_mol: bool = True) -> dict:
258
- mol_key = 'cand' if stage == Stage.TEST else 'mol'
259
- non_standard_collate = ['mol', 'cand', 'aug_cands', 'cons_spec', 'aug_cands_fp', 'NL_spec']
260
  require_pad = False
261
- if 'Formula' in spectra_view or 'Tokens' in spectra_view:
262
  require_pad = True
263
- padding_value=-5 if spec_enc in ('Transformer_Formula', 'Formula_BinnedSpec', 'Transformer_MzInt') else 0
264
  non_standard_collate.append(spectra_view)
265
  else:
266
- non_standard_collate.remove('cons_spec')
267
- non_standard_collate.remove('NL_spec')
268
 
269
  collated_batch = {}
270
- # standard collate
271
  for k in batch[0].keys():
272
  if k not in non_standard_collate:
273
  try:
274
  collated_batch[k] = default_collate([item[k] for item in batch])
275
- except:
276
  print(f"Error in collating key {k}")
277
  raise
278
-
279
- # batch graphs
280
  if batch_mol:
281
- batch_mol = []
282
- batch_mol_nodes= []
283
 
284
  for item in batch:
285
- batch_mol.append(item[mol_key])
286
  batch_mol_nodes.append(item[mol_key].num_nodes())
287
 
288
- collated_batch[mol_key] = dgl.batch(batch_mol)
289
- collated_batch['mol_n_nodes'] = batch_mol_nodes
290
-
291
- # pad peaks/formulas
292
  if require_pad:
293
  peaks = []
294
  n_peaks = []
@@ -296,54 +180,40 @@ class ContrastiveDataset(Dataset):
296
  peaks.append(item[spectra_view])
297
  n_peaks.append(len(item[spectra_view]))
298
  collated_batch[spectra_view] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
299
- collated_batch['n_peaks'] = n_peaks
300
-
301
- if 'cons_spec' in batch[0]:
302
- peaks = []
303
- n_peaks = []
304
- for item in batch:
305
- peaks.append(item['cons_spec'])
306
- n_peaks.append(len(item['cons_spec']))
307
- collated_batch['cons_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
308
- collated_batch['cons_n_peaks'] = n_peaks
309
-
310
- if 'NL_spec' in batch[0]:
311
- peaks = []
312
- n_peaks = []
313
- for item in batch:
314
- peaks.append(item['NL_spec'])
315
- n_peaks.append(len(item['NL_spec']))
316
- collated_batch['NL_spec'] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
317
- collated_batch['NL_n_peaks'] = n_peaks
318
  return collated_batch
319
-
320
-
321
 
322
  class ExpandedRetrievalDataset:
323
- '''Used for testing only
324
- Assumes 'fold' column defines the split'''
325
- def __init__(self,
326
- use_formulas: bool = True,
327
- mol_label_transform: MolTransform = MolToInChIKey(),
328
- candidates_pth: T.Optional[T.Union[Path, str]] = None,
329
- fp_size: int = None,
330
- fp_radius: int = None,
331
- use_magma = False,
332
- **kwargs):
333
-
334
-
335
- self.use_magma = use_magma
336
-
337
- self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False, stage = Stage.TEST) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
338
-
339
- if self.use_fp:
340
- self.fpgen = AllChem.GetMorganGenerator(radius=fp_radius,fpSize=fp_size)
341
 
342
  self.candidates_pth = candidates_pth
 
343
  self.mol_label_transform = mol_label_transform
344
-
345
- # Read candidates_pth from json to dict: SMILES -> respective candidate SMILES
346
- with open(self.candidates_pth, "r") as file:
 
 
 
347
  candidates = json.load(file)
348
 
349
  self.candidates = {}
@@ -351,130 +221,66 @@ class ExpandedRetrievalDataset:
351
  clean_cands = []
352
  for c in cand:
353
  try:
354
- if '.' not in c:
355
  clean_cands.append(c)
356
- except:
357
  print(f"Error in processing candidate {c} for smiles {s}")
358
- pass
359
- self.candidates[s] = clean_cands
360
-
361
- self.spec_cand = [] #(spec index, cand_smiles, true_label)
362
-
363
- # use for external dataset where target smiles is not known
364
- # self.candidates should be a dict of identifier to candidates
365
- if 'smiles' not in self.metadata.columns:
366
- if not isinstance(self.metadata.iloc[0]['identifier'], str):
367
- self.metadata['smiles'] = self.metadata['identifier'].apply(str)
368
- else:
369
- self.metadata['smiles'] = self.metadata['identifier']
370
-
371
- # keep datapoints where there are candidates
372
- self.metadata = self.metadata[self.metadata['smiles'].isin(self.candidates.keys())]
373
-
374
- test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
375
- test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
376
-
377
- self.spec_id_to_index = dict(zip(self.metadata['identifier'], self.metadata.index))
378
-
379
- for spec_id, s in zip(test_ms_id, test_smiles):
380
- candidates = self.candidates[s]
381
-
382
- # mol_label = self.mol_label_transform(s)
383
- # labels = [self.mol_label_transform(c) == mol_label for c in candidates]
384
- labels = [c == s for c in candidates]
385
- if len(candidates) == 0:
386
- print(f"Skipping {spec_id}; empty candidate set")
387
- continue
388
- if not any(labels):
389
- # print(f"Target smiles not in candidate set")
390
- pass
391
-
392
- self.spec_cand.extend([(self.spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
393
-
394
- def __getattr__(self, name):
395
- return self.instance.__getattribute__(name)
396
-
397
- def __len__(self):
398
- return len(self.spec_cand)
399
-
400
- def __getitem__(self, i):
401
- spec_i = self.spec_cand[i][0]
402
- cand_smiles = self.spec_cand[i][1]
403
- label = self.spec_cand[i][2]
404
 
405
- if self.use_magma:
406
- item = self.instance.__getitem__(spec_i, transform_mol=False, transform_spec=False)
407
 
408
- mzs = np.array([float(x) for x in self.metadata.iloc[spec_i]['mzs'].split(',')])
409
- intensities = np.array([float(x) for x in self.metadata.iloc[spec_i]['intensities'].split(',')])
410
- adduct = self.metadata.iloc[spec_i]['adduct']
411
- precursor_mz = self.metadata.iloc[spec_i]['precursor_mz']
412
- formula = self.metadata.iloc[spec_i]['formula']
413
- spec_data = run_magma(i, mzs, intensities, cand_smiles, adduct)
414
 
415
- spec = self.subformulaLoader.load_magma_data(spec_data, formula, precursor_mz)
416
 
417
- spec = matchms.Spectrum(
418
- mz = np.array(spec['formula_mzs']),
419
- intensities = np.array(spec['formula_intensities']),
420
- metadata = {'precursor_mz': precursor_mz, 'formulas': np.array(spec['formulas'])})
421
-
422
- if isinstance(self.spec_transform, dict):
423
-
424
- for key, transform in self.spec_transform.items():
425
- item[key] = transform(spec) if transform is not None else spec
426
- else:
427
- item["spec"] = self.spec_transform(spec)
428
 
 
 
429
  else:
430
- item = self.instance.__getitem__(spec_i, transform_mol=False)
431
 
432
- item['cand'] = self.mol_transform(cand_smiles)
433
- item['cand_smiles'] = cand_smiles
434
- item['label'] = label
435
 
436
- if self.use_fp:
437
- item['fp'] = torch.Tensor(self.fpgen.GetFingerprint(Chem.MolFromSmiles(cand_smiles)).ToList())
 
 
438
 
439
- return item
440
 
441
- class MassSpecDataset_Candidates:
442
-
443
- def __init__(self,
444
- use_formulas: bool,
445
- aug_cands_dir_pth: str,
446
- aug_cands_size: int,
447
- **kwargs):
448
- self.aug_cands_size = aug_cands_size
449
- self.instance = MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False) if use_formulas else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
450
-
451
- with open(aug_cands_dir_pth, 'rb') as f:
452
- aug_cands = pickle.load(f)
453
-
454
- if self.use_fp:
455
- self.fpgen = AllChem.GetMorganGenerator(radius=5,fpSize=1024)
456
-
457
- self.aug_cands = {}
458
- targets = np.array(list(aug_cands.keys()))
459
- for smiles, cands in aug_cands.items():
460
- # sort candidates by tanimoto similarity
461
- cands.sort(key=lambda x: x[1], reverse=True)
462
- cands = [c for c in cands if '.' not in c]
463
- # assert(len(cands) >0)
464
- if len(cands) <=1: # if no candidates, shuffle from target list
465
- np.random.shuffle(targets)
466
- cands = targets
467
- self.aug_cands[smiles] = itertools.cycle(cands)
468
 
469
  def __getattr__(self, name):
470
  return self.instance.__getattribute__(name)
471
-
 
 
 
472
  def __getitem__(self, i):
473
- item = self.instance.__getitem__(i,transform_mol=False)
 
 
474
 
475
- aug_cands = [next(self.aug_cands[item['mol']]) for _ in range(self.aug_cands_size)]
476
- item['aug_cands_fp'] = [self.fpgen.GetFingerprint(Chem.MolFromSmiles(c)).ToList() for c in aug_cands]
477
- item["aug_cands"] = [self.mol_transform(c) for c in aug_cands]
478
- item["mol"] = self.mol_transform(item["mol"])
479
 
480
- return item
 
 
1
  import json
2
  import typing as T
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+
6
+ import dgl
7
  import numpy as np
8
+ import pandas as pd
9
  import torch
10
+ import flare.utils.data as data_utils
11
  import massspecgym.utils as utils
12
+ import matchms
13
  from massspecgym.data.datasets import MassSpecDataset
14
+ from massspecgym.data.transforms import MolTransform, MolToInChIKey, SpecTransform
 
15
  from massspecgym.models.base import Stage
16
+ from torch.nn.utils.rnn import pad_sequence
17
+ from torch.utils.data.dataloader import default_collate
18
+ from torch.utils.data.dataset import Dataset
19
+
 
 
 
20
 
21
  class JESTR1_MassSpecDataset(MassSpecDataset):
22
+ """Same as MassSpecDataset; keeps `spectra_view` for API compatibility."""
23
 
24
+ def __init__(self, spectra_view: str, **kwargs):
25
+ super().__init__(**kwargs)
 
26
  self.spectra_view = spectra_view
27
 
28
 
29
  class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
30
  def __init__(
 
34
  mol_transform: T.Optional[T.Union[MolTransform, T.Dict[str, MolTransform]]],
35
  pth: T.Optional[Path],
36
  subformula_dir_pth: str,
 
 
 
37
  return_mol_freq: bool = False,
38
  return_identifier: bool = True,
39
  dtype: T.Type = torch.float32,
40
+ formula_source: str = "default",
41
+ stage: Stage = Stage.TRAIN,
42
  ):
 
 
 
43
  self.pth = pth
44
  self.spec_transform = spec_transform
45
  self.mol_transform = mol_transform
46
  self.return_mol_freq = return_mol_freq
47
  self.spectra_view = spectra_view
48
  self.formula_source = formula_source
49
  self.subformula_dir_pth = subformula_dir_pth
 
51
  if isinstance(self.pth, str):
52
  self.pth = Path(self.pth)
53
 
 
54
  print("Data path: ", self.pth)
55
  self.metadata = pd.read_csv(self.pth, sep="\t")
56
 
 
57
  id_to_spec = self._load_id_to_spec(stage)
58
 
59
+ self.metadata = self.metadata[self.metadata["identifier"].isin(id_to_spec)]
 
60
 
61
+ formula_df = (
62
+ pd.DataFrame.from_dict(id_to_spec, orient="index")
63
+ .reset_index()
64
+ .rename(columns={"index": "identifier"})
65
+ )
66
+ self.metadata = self.metadata.merge(formula_df, on="identifier")
67
 
68
  matchMS_preparer = data_utils.PrepMatchMS(spectra_view=spectra_view)
69
+ self.spectra = self.metadata.apply(matchMS_preparer.prepare, axis=1)
70
+
71
  if self.return_mol_freq:
72
  if "inchikey" not in self.metadata.columns:
73
  self.metadata["inchikey"] = self.metadata["smiles"].apply(utils.smiles_to_inchi_key)
 
75
 
76
  self.return_identifier = return_identifier
77
  self.dtype = dtype
78
+
79
  def __getitem__(self, i, transform_spec: bool = True, transform_mol: bool = True):
80
+ item = super().__getitem__(i, transform_spec, transform_mol=False)
81
+ mol = item["mol"]
82
 
 
83
  if transform_mol:
84
  if isinstance(self.mol_transform, dict):
85
  for key, transform in self.mol_transform.items():
86
  item[key] = transform(mol) if transform is not None else mol
87
  else:
88
  item["mol"] = self.mol_transform(mol)
89
+ else:
90
+ item["mol"] = mol
91
 
92
  return item
93
 
94
  def _load_id_to_spec(self, stage):
95
+ all_spec_ids = self.metadata["identifier"].tolist()
96
+ self.subformulaLoader = data_utils.Subformula_Loader(
97
+ spectra_view=self.spectra_view,
98
+ dir_path=self.subformula_dir_pth,
99
+ formula_source=self.formula_source,
100
+ )
101
+
102
+ form_list = self.metadata["formula"].tolist()
103
+ prec_mz_list = self.metadata["precursor_mz"].tolist()
 
104
  id_to_spec = self.subformulaLoader(all_spec_ids, form_list, prec_mz_list)
105
 
 
106
  tmp_ids = [spec_id for spec_id in all_spec_ids if spec_id not in id_to_spec]
107
+ tmp_df = self.metadata[self.metadata["identifier"].isin(tmp_ids)].copy()  # copy so the column assignment below does not write to a slice of self.metadata
108
+ tmp_df["spec"] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)
109
+ id_to_spec.update(dict(zip(tmp_df["identifier"].tolist(), tmp_df["spec"].tolist())))
110
 
111
  return id_to_spec
112
 
113
+
114
  class ContrastiveDataset(Dataset):
115
+ def __init__(self, spec_mol_data):
 
 
 
116
  super().__init__()
117
+
118
  indices = spec_mol_data.indices
119
  self.spec_mol_data = spec_mol_data
120
+ self.smiles_to_specmol_ids = spec_mol_data.dataset.metadata.loc[indices].groupby("smiles").indices
121
  self.smiles_to_spec_couter = defaultdict(int)
122
  self.smiles_list = list(self.smiles_to_specmol_ids.keys())
123
 
124
  def __len__(self) -> int:
125
  return len(self.smiles_list)
126
+
127
+ def __getitem__(self, i: int) -> dict:
128
  mol = self.smiles_list[i]
129
 
 
130
  specmol_ids = self.smiles_to_specmol_ids[mol]
131
  counter = self.smiles_to_spec_couter[mol]
132
  specmol_id = specmol_ids[counter % len(specmol_ids)]
133
 
134
  item = self.spec_mol_data.__getitem__(specmol_id)
135
+ self.smiles_to_spec_couter[mol] = counter + 1
 
 
136
  return item
137
 
138
  @staticmethod
139
+ def collate_fn(
140
+ batch: T.Iterable[dict],
141
+ spec_enc: str,
142
+ spectra_view: str,
143
+ stage=None,
144
+ batch_mol: bool = True,
145
+ ) -> dict:
146
+ mol_key = "cand" if stage == Stage.TEST else "mol"
147
+ non_standard_collate = ["mol", "cand", "aug_cands"]
148
  require_pad = False
149
+ if "Formula" in spectra_view or "Tokens" in spectra_view:
150
  require_pad = True
151
+ padding_value = -5 if spec_enc in ("Transformer_Formula", "Formula_BinnedSpec", "Transformer_MzInt") else 0
152
  non_standard_collate.append(spectra_view)
153
  else:
154
+ non_standard_collate.remove("aug_cands")
 
155
 
156
  collated_batch = {}
 
157
  for k in batch[0].keys():
158
  if k not in non_standard_collate:
159
  try:
160
  collated_batch[k] = default_collate([item[k] for item in batch])
161
+ except Exception:
162
  print(f"Error in collating key {k}")
163
  raise
164
+
 
165
  if batch_mol:
166
+ batch_mol_list = []
167
+ batch_mol_nodes = []
168
 
169
  for item in batch:
170
+ batch_mol_list.append(item[mol_key])
171
  batch_mol_nodes.append(item[mol_key].num_nodes())
172
 
173
+ collated_batch[mol_key] = dgl.batch(batch_mol_list)
174
+ collated_batch["mol_n_nodes"] = batch_mol_nodes
175
+
 
176
  if require_pad:
177
  peaks = []
178
  n_peaks = []
 
180
  peaks.append(item[spectra_view])
181
  n_peaks.append(len(item[spectra_view]))
182
  collated_batch[spectra_view] = pad_sequence(peaks, batch_first=True, padding_value=padding_value)
183
+ collated_batch["n_peaks"] = n_peaks
184
+
185
  return collated_batch
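
The padding branch of `collate_fn` can be illustrated without PyTorch. The sketch below is a plain-Python stand-in for `pad_sequence(..., batch_first=True)`, simplified to one value per peak rather than a feature vector per peak. `pad_peaks` is a hypothetical helper, not part of FLARE; the `-5` fill value mirrors the value the diff selects for the transformer and formula encoders.

```python
from typing import List, Sequence, Tuple


def pad_peaks(
    peaks: Sequence[Sequence[float]], padding_value: float = -5.0
) -> Tuple[List[List[float]], List[int]]:
    """Right-pad every peak list to the longest length in the batch.

    Returns the padded batch plus the original lengths, which play the
    role of the `n_peaks` entry in the collated batch.
    """
    n_peaks = [len(p) for p in peaks]
    max_len = max(n_peaks, default=0)
    padded = [list(p) + [padding_value] * (max_len - len(p)) for p in peaks]
    return padded, n_peaks


batch = [[101.1, 202.2, 303.3], [55.0]]
padded, n_peaks = pad_peaks(batch)
print(padded)
print(n_peaks)
```

Keeping the original lengths next to the padded batch lets a downstream model build an attention mask instead of relying on the fill value alone.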
186
+
 
187
 
188
  class ExpandedRetrievalDataset:
189
+ """Test-time retrieval over a fixed candidate pool per spectrum/formula."""
190
+
191
+ def __init__(
192
+ self,
193
+ use_formulas: bool = True,
194
+ mol_label_transform: MolTransform = MolToInChIKey(),
195
+ candidates_pth: T.Optional[T.Union[Path, str]] = None,
196
+ formula_to_smiles_pth: T.Optional[T.Union[Path, str]] = None,
197
+ external_test: bool = False,
198
+ **kwargs,
199
+ ):
200
+ self.external_test = external_test
201
+
202
+ self.instance = (
203
+ MassSpecDataset_PeakFormulas(**kwargs, return_mol_freq=False, stage=Stage.TEST)
204
+ if use_formulas
205
+ else JESTR1_MassSpecDataset(**kwargs, return_mol_freq=False)
206
+ )
207
 
208
  self.candidates_pth = candidates_pth
209
+ self.formula_to_smiles_pth = formula_to_smiles_pth
210
  self.mol_label_transform = mol_label_transform
211
+
212
+ candidate_source_pth = self.formula_to_smiles_pth if self.formula_to_smiles_pth else self.candidates_pth
213
+ if not candidate_source_pth:
214
+ raise ValueError("One of candidates_pth or formula_to_smiles_pth must be provided.")
215
+
216
+ with open(candidate_source_pth, "r") as file:
217
  candidates = json.load(file)
218
 
219
  self.candidates = {}
 
221
  clean_cands = []
222
  for c in cand:
223
  try:
224
+ if "." not in c:
225
  clean_cands.append(c)
226
+ except Exception:
227
  print(f"Error in processing candidate {c} for smiles {s}")
228
+ self.candidates[s] = clean_cands
229
 
230
+ self.spec_cand = []
 
231
 
232
+ if "smiles" not in self.metadata.columns:
233
+ if not isinstance(self.metadata.iloc[0]["identifier"], str):
234
+ self.metadata["smiles"] = self.metadata["identifier"].apply(str)
235
+ else:
236
+ self.metadata["smiles"] = self.metadata["identifier"]
 
237
 
238
+ if self.formula_to_smiles_pth:
239
+ if "formula" not in self.metadata.columns:
240
+ raise ValueError("formula_to_smiles_pth was provided, but dataset has no 'formula' column.")
241
+ self.metadata["candidate_key"] = self.metadata["formula"].astype(str)
242
+ else:
243
+ self.metadata["candidate_key"] = self.metadata["smiles"].astype(str)
244
 
245
+ self.metadata = self.metadata[self.metadata["candidate_key"].isin(self.candidates.keys())]
 
 
 
 
 
 
 
 
 
 
246
 
247
+ if "fold" in self.metadata.columns:
248
+ test_metadata = self.metadata[self.metadata["fold"] == "test"]
249
  else:
250
+ test_metadata = self.metadata
251
 
252
+ self.spec_id_to_index = dict(zip(self.metadata["identifier"], self.metadata.index))
 
 
253
 
254
+ for _, row in test_metadata.iterrows():
255
+ spec_id = row["identifier"]
256
+ candidate_key = row["candidate_key"]
257
+ cands = self.candidates[candidate_key]
258
 
259
+ if self.external_test:
260
+ labels = [False for _ in cands]
261
+ else:
262
+ target_smiles = row["smiles"]
263
+ labels = [c == target_smiles for c in cands]
264
+ if len(cands) == 0:
265
+ print(f"Skipping {spec_id}; empty candidate set")
266
+ continue
267
 
268
+ self.spec_cand.extend([(self.spec_id_to_index[spec_id], cands[j], k) for j, k in enumerate(labels)])
269
 
270
  def __getattr__(self, name):
271
  return self.instance.__getattribute__(name)
272
+
273
+ def __len__(self):
274
+ return len(self.spec_cand)
275
+
276
  def __getitem__(self, i):
277
+ spec_i = self.spec_cand[i][0]
278
+ cand_smiles = self.spec_cand[i][1]
279
+ label = self.spec_cand[i][2]
280
 
281
+ item = self.instance.__getitem__(spec_i, transform_mol=False)
282
+ item["cand"] = self.mol_transform(cand_smiles)
283
+ item["cand_smiles"] = cand_smiles
284
+ item["label"] = label
285
 
286
+ return item
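
As a rough illustration of how `ExpandedRetrievalDataset` flattens spectra and candidate pools into `(spectrum index, candidate, label)` rows, here is a minimal sketch under simplifying assumptions: candidates are keyed by the target SMILES only (the diff also supports keying by formula), and `expand_candidates` with its arguments is a hypothetical name rather than FLARE's API.

```python
from typing import Dict, List, Tuple


def expand_candidates(
    targets: Dict[int, str],
    candidates: Dict[str, List[str]],
    external_test: bool = False,
) -> List[Tuple[int, str, bool]]:
    """Flatten {spectrum index -> target SMILES} into (index, candidate, label) rows."""
    rows: List[Tuple[int, str, bool]] = []
    for spec_idx, target in targets.items():
        # Drop multi-fragment SMILES (those containing "."), as the diff does.
        cands = [c for c in candidates.get(target, []) if "." not in c]
        if not cands:
            continue  # nothing to rank for this spectrum
        for cand in cands:
            # External data has no ground truth, so every label is False.
            label = False if external_test else cand == target
            rows.append((spec_idx, cand, label))
    return rows


targets = {0: "CCO", 1: "CC"}
pool = {"CCO": ["CCO", "COC", "CCO.Cl"], "CC": []}
print(expand_candidates(targets, pool))
```

With one row per candidate, `__len__` is the total number of spectrum and candidate pairs, so ranking metrics have to regroup rows by spectrum index afterwards.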
flare/run.sh CHANGED
@@ -1,3 +1,17 @@
1
- # python train.py --param_pth params_filipGlobal.yaml
2
- # python test.py --param_pth params_filipGlobal.yaml
3
- python test.py --param_pth params_filipGlobal.yaml --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
+ #!/usr/bin/env bash
+ # Example: run from the `flare/` directory.
+ # conda activate flare
+ # python train.py
+ # python test.py --checkpoint_choice val
+ #
+ # Optional overrides:
+ # export FLARE_PARAMS=/path/to/params.yaml
+ # export CANDIDATES_JSON="$PWD/../data/MassSpecGym_retrieval_candidates_formula.json"
+ # python test.py --candidates_pth "${CANDIDATES_JSON}"
+
+ set -euo pipefail
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
+ export FLARE_REPO_ROOT="${FLARE_REPO_ROOT:-$REPO_ROOT}"
+
+ # python train.py
+ # python test.py --checkpoint_choice val
flare/test.py CHANGED
@@ -1,128 +1,163 @@
1
  import argparse
2
  import datetime
 
3
  import sys
4
- sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
5
- sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- from rdkit import RDLogger
8
  import pytorch_lightning as pl
9
  from pytorch_lightning import Trainer
 
10
  from massspecgym.models.base import Stage
11
- import os
12
 
13
  from flare.data.data_module import TestDataModule
14
  from flare.data.datasets import ContrastiveDataset
15
- from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
 
 
16
  from flare.utils.models import get_model
17
 
18
- from flare.definitions import TEST_RESULTS_DIR
19
- import yaml
20
- from functools import partial
21
- # Suppress RDKit warnings and errors
22
  lg = RDLogger.logger()
23
  lg.setLevel(RDLogger.CRITICAL)
24
 
25
  parser = argparse.ArgumentParser()
26
- parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
27
- parser.add_argument('--checkpoint_pth', type=str, default='')
28
- parser.add_argument('--checkpoint_choice', type=str, default='train', choices=['train', 'val'])
29
- parser.add_argument('--df_test_pth', type=str, help='result file name')
30
- parser.add_argument('--exp_dir', type=str)
31
- parser.add_argument('--candidates_pth', type=str)
32
- parser.add_argument('--external_test', action='store_true', help='whether the test set is external data without labels')
 
 
 
 
 
 
 
 
 
33
 
34
- def main(params):
35
- # Seed everything
36
- pl.seed_everything(params['seed'])
37
-
38
- # Init paths to data files
39
- if params['debug']:
40
 
41
- params['dataset_pth'] = "/data/yzhouc01/MVP/data/sample/data.tsv"
42
- params['split_pth']=None
43
- params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')
44
 
45
- # Load dataset
46
- spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
 
 
 
 
 
 
 
 
47
 
48
- mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
49
- dataset = get_test_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
 
 
 
50
 
51
- # Init data module
52
- collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'], stage=Stage.TEST)
 
 
 
 
53
  data_module = TestDataModule(
54
  dataset=dataset,
55
  collate_fn=collate_fn,
56
- split_pth=params['split_pth'],
57
- batch_size=params['batch_size'],
58
- num_workers=params['num_workers']
59
  )
60
 
61
- model = get_model(params['model'], params)
62
- model.df_test_path = params['df_test_path']
63
- model.external_test = params['external_test']
64
-
65
- # Init trainer
66
  trainer = Trainer(
67
- accelerator=params['accelerator'],
68
- devices=params['devices'],
69
- default_root_dir=params['experiment_dir']
70
  )
71
 
72
- # Prepare data module to test
73
  data_module.prepare_data()
74
  data_module.setup(stage="test")
75
-
76
- # Test
77
  trainer.test(model, datamodule=data_module)
78
 
79
 
80
  if __name__ == "__main__":
81
  args = parser.parse_args([] if "__file__" not in globals() else None)
82
 
83
- # Load
84
- with open(args.param_pth) as f:
85
- params = yaml.load(f, Loader=yaml.FullLoader)
86
-
87
- # Experiment directory
88
  if args.exp_dir:
89
  exp_dir = args.exp_dir
90
  else:
91
- run_name = params['run_name']
92
- for exp in os.listdir(TEST_RESULTS_DIR): # find exp dir with matching run_name
93
- if exp.endswith("_"+run_name):
94
- exp_dir = str(TEST_RESULTS_DIR / exp)
95
- break
96
- if not exp_dir:
97
- now = datetime.datetime.now().strftime("%Y%m%d")
98
- exp_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}")
99
- os.makedirs(exp_dir, exist_ok=True)
100
- print("EXPERIMENT directory: ",exp_dir)
101
- params['experiment_dir'] = exp_dir
102
-
103
- # Checkpoint path
104
  if args.checkpoint_pth:
105
- params['checkpoint_pth'] = args.checkpoint_pth
106
-
107
- if not params['checkpoint_pth']:
108
- print("No checkpoint provided. Using the checkpoint in the experiment directory")
109
- for f in os.listdir(exp_dir):
110
  if f.endswith("ckpt") and f.startswith("epoch") and args.checkpoint_choice in f:
111
- checkpoint_path = os.path.join(exp_dir, f)
112
- params['checkpoint_pth'] = checkpoint_path
113
  break
114
- assert(params['checkpoint_pth'] != '')
115
 
116
  if args.external_test:
117
- params['external_test'] = True
118
  else:
119
- params['external_test'] = False
120
-
121
  if args.candidates_pth:
122
- params['candidates_pth'] = args.candidates_pth
 
 
 
 
123
  if args.df_test_pth:
124
- params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
125
- if not params['df_test_path']:
126
- params['df_test_path'] = os.path.join(exp_dir, f"result_{params['candidates_pth'].split('/')[-1].split('.')[0]}.pkl")
127
-
 
 
 
 
 
128
  main(params)
 
1
  import argparse
2
  import datetime
3
+ import os
4
  import sys
5
+ from pathlib import Path
6
+
7
+
8
+ def _add_local_dependency_paths() -> None:
9
+ flare_repo_root = Path(__file__).resolve().parents[1]
10
+ candidate_roots = [flare_repo_root, flare_repo_root / "massspecgym"]
11
+ for env_var in ("MASSSPECGYM_ROOT",):
12
+ env_value = os.environ.get(env_var)
13
+ if env_value:
14
+ candidate_roots.append(Path(env_value).expanduser())
15
+ for root in candidate_roots:
16
+ if root.exists():
17
+ root_str = str(root)
18
+ if root_str not in sys.path:
19
+ sys.path.insert(0, root_str)
20
+
21
+
22
+ _add_local_dependency_paths()
23
+
24
+ from functools import partial
25
 
 
26
  import pytorch_lightning as pl
27
  from pytorch_lightning import Trainer
28
+ from rdkit import RDLogger
29
  from massspecgym.models.base import Stage
 
30
 
31
  from flare.data.data_module import TestDataModule
32
  from flare.data.datasets import ContrastiveDataset
33
+ from flare.definitions import DATA_DIR, TEST_RESULTS_DIR
34
+ from flare.utils.config import default_param_path, load_param_file
35
+ from flare.utils.data import get_mol_featurizer, get_spec_featurizer, get_test_ms_dataset
36
  from flare.utils.models import get_model
37
 
 
 
 
 
38
  lg = RDLogger.logger()
39
  lg.setLevel(RDLogger.CRITICAL)
40
 
41
  parser = argparse.ArgumentParser()
42
+ parser.add_argument(
43
+ "--param_pth",
44
+ type=str,
45
+ default=None,
46
+ help="YAML hyperparameters (default: FLARE_PARAMS or repo params.yaml)",
47
+ )
48
+ parser.add_argument("--checkpoint_pth", type=str, default="")
49
+ parser.add_argument("--checkpoint_choice", type=str, default="train", choices=["train", "val"])
50
+ parser.add_argument("--df_test_pth", type=str, help="result file name under experiment_dir")
51
+ parser.add_argument("--exp_dir", type=str, help="experiment directory (overrides auto-detect)")
52
+ parser.add_argument("--candidates_pth", type=str, help="override candidates JSON path")
53
+ parser.add_argument(
54
+ "--external_test",
55
+ action="store_true",
56
+ help="external data without ground-truth labels in the candidate list",
57
+ )
58
 
 
 
 
 
 
 
59
 
60
+ def main(params):
61
+ pl.seed_everything(params["seed"])
 
62
 
63
+ if params.get("debug"):
64
+ dbg = os.environ.get("FLARE_DEBUG_DATASET")
65
+ if dbg:
66
+ params["dataset_pth"] = dbg
67
+ else:
68
+ sample_tsv = DATA_DIR / "sample" / "data.tsv"
69
+ if sample_tsv.is_file():
70
+ params["dataset_pth"] = str(sample_tsv)
71
+ params["split_pth"] = None
72
+ params["df_test_path"] = os.path.join(params["experiment_dir"], "debug_result.pkl")
73
 
74
+ spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
75
+ mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
76
+ dataset = get_test_ms_dataset(
77
+ params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params
78
+ )
79
 
80
+ collate_fn = partial(
81
+ ContrastiveDataset.collate_fn,
82
+ spec_enc=params["spec_enc"],
83
+ spectra_view=params["spectra_view"],
84
+ stage=Stage.TEST,
85
+ )
86
  data_module = TestDataModule(
87
  dataset=dataset,
88
  collate_fn=collate_fn,
89
+ split_pth=params["split_pth"],
90
+ batch_size=params["batch_size"],
91
+ num_workers=params["num_workers"],
92
  )
93
 
94
+ model = get_model(params["model"], params)
95
+ model.df_test_path = params["df_test_path"]
96
+ model.external_test = params["external_test"]
97
+
 
98
  trainer = Trainer(
99
+ accelerator=params["accelerator"],
100
+ devices=params["devices"],
101
+ default_root_dir=params["experiment_dir"],
102
  )
103
 
 
104
  data_module.prepare_data()
105
  data_module.setup(stage="test")
 
 
106
  trainer.test(model, datamodule=data_module)
107
 
108
 
109
  if __name__ == "__main__":
110
  args = parser.parse_args([] if "__file__" not in globals() else None)
111
 
112
+ param_path = args.param_pth or str(default_param_path())
113
+ params = load_param_file(param_path)
114
+
115
+ exp_dir = None
 
116
  if args.exp_dir:
117
  exp_dir = args.exp_dir
118
  else:
119
+ run_name = params["run_name"]
120
+ if TEST_RESULTS_DIR.is_dir():
121
+ for exp in sorted(os.listdir(TEST_RESULTS_DIR), reverse=True):
122
+ if exp.endswith("_" + run_name):
123
+ exp_dir = str(TEST_RESULTS_DIR / exp)
124
+ break
125
+ if exp_dir is None:
126
+ today_str = datetime.datetime.now().strftime("%Y%m%d")
127
+ exp_dir = str(TEST_RESULTS_DIR / f"{today_str}_{run_name}")
128
+ os.makedirs(exp_dir, exist_ok=True)
129
+ params["experiment_dir"] = exp_dir
130
+
 
131
  if args.checkpoint_pth:
132
+ params["checkpoint_pth"] = args.checkpoint_pth
133
+
134
+ if not params.get("checkpoint_pth"):
135
+ print("No checkpoint in params; searching experiment_dir for a .ckpt file")
136
+ for f in sorted(os.listdir(exp_dir)):
137
  if f.endswith("ckpt") and f.startswith("epoch") and args.checkpoint_choice in f:
138
+ params["checkpoint_pth"] = os.path.join(exp_dir, f)
 
139
  break
140
+ assert params.get("checkpoint_pth"), "No checkpoint found; pass --checkpoint_pth"
141
 
142
  if args.external_test:
143
+ params["external_test"] = True
144
  else:
145
+ params["external_test"] = params.get("external_test", False)
146
+
147
  if args.candidates_pth:
148
+ params["candidates_pth"] = args.candidates_pth
149
+ from flare.utils.config import resolve_repo_paths
150
+
151
+ resolve_repo_paths(params)
152
+
153
  if args.df_test_pth:
154
+ params["df_test_path"] = os.path.join(exp_dir, args.df_test_pth)
155
+ if not params.get("df_test_path"):
156
+ cand = params.get("candidates_pth") or "candidates.json"
157
+ stem = Path(cand).stem
158
+ params["df_test_path"] = os.path.join(exp_dir, f"result_{stem}.pkl")
159
+
160
+ print("DF TEST PATH: ", params["df_test_path"])
161
+ print("EXP DIR: ", exp_dir)
162
+
163
  main(params)
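
The experiment-directory lookup in the new `test.py` relies on the `YYYYMMDD_<run_name>` naming used by `train.py`: sorting directory names in reverse puts the most recent date first. A minimal sketch of that lookup, assuming this naming convention; `find_latest_run_dir` is a hypothetical helper and the directory names are made up for the demonstration.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def find_latest_run_dir(results_dir: Path, run_name: str) -> Optional[Path]:
    """Return the lexicographically last entry ending in '_<run_name>', if any."""
    if not results_dir.is_dir():
        return None
    # Reverse lexicographic order equals newest-first for YYYYMMDD prefixes.
    for entry in sorted(os.listdir(results_dir), reverse=True):
        if entry.endswith("_" + run_name):
            return results_dir / entry
    return None


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("20250101_flare", "20250916_flare", "20250920_other"):
        (root / name).mkdir()
    latest = find_latest_run_dir(root, "flare")
    print(latest.name)
```

Returning `None` when nothing matches leaves the caller free to create a fresh dated directory, which is what the diff does when `exp_dir is None`.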
flare/train.py CHANGED
@@ -1,137 +1,128 @@
1
  import argparse
2
  import datetime
3
-
4
  import os
5
  import sys
6
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
 
8
- from rdkit import RDLogger
 
 
 
9
  import pytorch_lightning as pl
 
10
  from pytorch_lightning import Trainer
11
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
12
-
13
 
14
  from flare.data.data_module import ContrastiveDataModule
15
-
16
- from flare.definitions import TEST_RESULTS_DIR
17
- import yaml
18
  from flare.data.datasets import ContrastiveDataset
19
- from functools import partial
20
-
21
- from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
22
  from flare.utils.models import get_model
23
- # Suppress RDKit warnings and errors
24
  lg = RDLogger.logger()
25
  lg.setLevel(RDLogger.CRITICAL)
26
 
27
  parser = argparse.ArgumentParser()
28
- parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
 
 
 
 
 
 
29
 
30
  def main(params):
31
- # Seed everything
32
- pl.seed_everything(params['seed'])
33
-
34
- # Init paths to data files
35
- if params['debug']:
36
- params['dataset_pth'] = "/data/yzhouc01/MVP/data/sample/data.tsv"
37
- params['candidates_pth'] =None
38
- params['split_pth']=None
39
-
40
- # Load dataset
41
- spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
42
- mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
43
- dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)
44
-
45
- # Init data module
46
- collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'])
 
 
 
 
 
 
47
  data_module = ContrastiveDataModule(
48
  dataset=dataset,
49
  collate_fn=collate_fn,
50
- split_pth=params['split_pth'],
51
- batch_size=params['batch_size'],
52
- num_workers=params['num_workers'],
53
  )
54
 
55
- model = get_model(params['model'], params)
56
-
57
- # Init logger
58
- if params['no_wandb']:
59
- logger = None
60
- else:
61
- logger = pl.loggers.WandbLogger(
62
- save_dir=params['experiment_dir'],
63
- dir=params['experiment_dir'],
64
- log_dir=params['experiment_dir'],
65
- name=params['run_name'],
66
- project=params['project_name'],
67
- log_model=False,
68
- config=model.hparams
69
- )
70
 
71
- # Init callbacks for checkpointing and early stopping
72
- callbacks = [pl.callbacks.ModelCheckpoint(save_last=False) ]
73
  for i, monitor in enumerate(model.get_checkpoint_monitors()):
74
- monitor_name = monitor['monitor']
75
  checkpoint = pl.callbacks.ModelCheckpoint(
76
  monitor=monitor_name,
77
  save_top_k=1,
78
- mode=monitor['mode'],
79
- dirpath=params['experiment_dir'],
80
- filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
81
- # filename='{epoch}-{val_loss:.2f}-{train_loss:.2f}',
82
  auto_insert_metric_name=True,
83
- # save_last=(i == 0)
84
  )
85
  callbacks.append(checkpoint)
86
- if monitor.get('early_stopping', False):
87
  early_stopping = EarlyStopping(
88
  monitor=monitor_name,
89
- mode=monitor['mode'],
90
  verbose=True,
91
- patience=params['early_stopping_patience'],
92
  )
93
  callbacks.append(early_stopping)
94
 
95
- # Init trainer
96
  trainer = Trainer(
97
- accelerator=params['accelerator'],
98
- devices=params['devices'],
99
- max_epochs=params['max_epochs'],
100
  logger=logger,
101
- log_every_n_steps=params['log_every_n_steps'],
102
- val_check_interval=params['val_check_interval'],
103
  callbacks=callbacks,
104
- default_root_dir=params['experiment_dir'],
105
  )
106
 
107
- # Prepare data module to validate or test before training
108
  data_module.prepare_data()
109
  data_module.setup()
110
 
111
-
112
- # Validate before training
113
  trainer.validate(model, datamodule=data_module)
114
-
115
- # Train
116
  trainer.fit(model, datamodule=data_module)
117
-
118
 
119
 
120
  if __name__ == "__main__":
121
  args = parser.parse_args([] if "__file__" not in globals() else None)
122
 
123
- # Get current time
 
 
124
  now = datetime.datetime.now()
125
  now_formatted = now.strftime("%Y%m%d")
126
-
127
- # Load
128
- with open(args.param_pth) as f:
129
- params = yaml.load(f, Loader=yaml.FullLoader)
130
-
131
  experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
132
- params['experiment_dir'] = experiment_dir
133
 
134
- if not params['df_test_path']:
135
- params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")
136
 
137
  main(params)
 
1
  import argparse
2
  import datetime
 
3
  import os
4
  import sys
 
5
 
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
7
+
8
+ from functools import partial
9
+
10
  import pytorch_lightning as pl
11
+ import yaml
12
  from pytorch_lightning import Trainer
13
  from pytorch_lightning.callbacks.early_stopping import EarlyStopping
14
+ from rdkit import RDLogger
15
 
16
  from flare.data.data_module import ContrastiveDataModule
 
 
 
17
  from flare.data.datasets import ContrastiveDataset
18
+ from flare.definitions import DATA_DIR, TEST_RESULTS_DIR
19
+ from flare.utils.config import default_param_path, load_param_file
20
+ from flare.utils.data import get_ms_dataset, get_mol_featurizer, get_spec_featurizer
21
  from flare.utils.models import get_model
22
+
23
  lg = RDLogger.logger()
24
  lg.setLevel(RDLogger.CRITICAL)
25
 
26
  parser = argparse.ArgumentParser()
27
+ parser.add_argument(
28
+ "--param_pth",
29
+ type=str,
30
+ default=None,
31
+ help="YAML hyperparameters (default: FLARE_PARAMS env or repo params.yaml)",
32
+ )
33
+
34
 
35
  def main(params):
36
+ pl.seed_everything(params["seed"])
37
+
38
+ if params.get("debug"):
39
+ dbg = os.environ.get("FLARE_DEBUG_DATASET")
40
+ if dbg:
41
+ params["dataset_pth"] = dbg
42
+ else:
43
+ sample_tsv = DATA_DIR / "sample" / "data.tsv"
44
+ if sample_tsv.is_file():
45
+ params["dataset_pth"] = str(sample_tsv)
46
+ params["candidates_pth"] = None
47
+ params["split_pth"] = None
48
+
49
+ spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
50
+ mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
51
+ dataset = get_ms_dataset(
52
+ params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params
53
+ )
54
+
55
+ collate_fn = partial(
56
+ ContrastiveDataset.collate_fn, spec_enc=params["spec_enc"], spectra_view=params["spectra_view"]
57
+ )
58
  data_module = ContrastiveDataModule(
59
  dataset=dataset,
60
  collate_fn=collate_fn,
61
+ split_pth=params["split_pth"],
62
+ batch_size=params["batch_size"],
63
+ num_workers=params["num_workers"],
64
  )
65
 
66
+ model = get_model(params["model"], params)
67
+
68
+ tb_logger = pl.loggers.TensorBoardLogger(
69
+ save_dir=params["experiment_dir"],
70
+ name="",
71
+ version="",
72
+ )
73
+ logger = tb_logger
 
 
 
 
 
 
 
74
 
75
+ callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
 
76
  for i, monitor in enumerate(model.get_checkpoint_monitors()):
77
+ monitor_name = monitor["monitor"]
78
  checkpoint = pl.callbacks.ModelCheckpoint(
79
  monitor=monitor_name,
80
  save_top_k=1,
81
+ mode=monitor["mode"],
82
+ dirpath=params["experiment_dir"],
83
+ filename=f"{{epoch}}-{{{monitor_name}:.2f}}",
 
84
  auto_insert_metric_name=True,
 
85
  )
86
  callbacks.append(checkpoint)
87
+ if monitor.get("early_stopping", False):
88
  early_stopping = EarlyStopping(
89
  monitor=monitor_name,
90
+ mode=monitor["mode"],
91
  verbose=True,
92
+ patience=params["early_stopping_patience"],
93
  )
94
  callbacks.append(early_stopping)
95
 
 
96
  trainer = Trainer(
97
+ accelerator=params["accelerator"],
98
+ devices=params["devices"],
99
+ max_epochs=params["max_epochs"],
100
  logger=logger,
101
+ log_every_n_steps=params["log_every_n_steps"],
102
+ val_check_interval=params["val_check_interval"],
103
  callbacks=callbacks,
104
+ default_root_dir=params["experiment_dir"],
105
  )
106
 
 
107
  data_module.prepare_data()
108
  data_module.setup()
109
 
 
 
110
  trainer.validate(model, datamodule=data_module)
 
 
111
  trainer.fit(model, datamodule=data_module)
 
112
 
113
 
114
  if __name__ == "__main__":
115
  args = parser.parse_args([] if "__file__" not in globals() else None)
116
 
117
+ param_path = args.param_pth or str(default_param_path())
118
+ params = load_param_file(param_path)
119
+
120
  now = datetime.datetime.now()
121
  now_formatted = now.strftime("%Y%m%d")
 
 
 
 
 
122
  experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
123
+ params["experiment_dir"] = experiment_dir
124
 
125
+ if not params.get("df_test_path"):
126
+ params["df_test_path"] = os.path.join(experiment_dir, "result.pkl")
127
 
128
  main(params)
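
The `filename` argument passed to `ModelCheckpoint` is built with an f-string whose doubled braces survive as literal braces, leaving a template for the callback to fill in later. The sketch below shows only that escaping. The monitor name `val_loss` is hypothetical, and plain `str.format` stands in for Lightning's own substitution, which may render names differently (for example when `auto_insert_metric_name=True`).

```python
monitor_name = "val_loss"  # hypothetical metric name, for illustration only

# "{{" and "}}" become literal braces; "{monitor_name}" is interpolated now.
template = f"{{epoch}}-{{{monitor_name}:.2f}}"
print(template)

# Stand-in for the later substitution performed by the checkpoint callback.
rendered = template.format(epoch=3, val_loss=0.12345)
print(rendered)
```

Because the metric name ends up in the file name, `test.py` can select a checkpoint by checking whether the name starts with `epoch` and contains the `--checkpoint_choice` value.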
flare/tune.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
  import datetime
3
  import os
4
  import sys
@@ -20,6 +21,7 @@ from flare.data.datasets import ContrastiveDataset
20
  from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
21
  from flare.utils.models import get_model
22
  from flare.definitions import TEST_RESULTS_DIR
 
23
  from functools import partial
24
  from rdkit import RDLogger
25
  from massspecgym.models.base import Stage
@@ -29,7 +31,12 @@ lg = RDLogger.logger()
29
  lg.setLevel(RDLogger.CRITICAL)
30
 
31
  parser = argparse.ArgumentParser()
32
- parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
 
 
 
 
 
33
  parser.add_argument("--n_trials", type=int, default=20)
34
 
35
  class EpochLossTracker(Callback):
@@ -112,7 +119,7 @@ def save_trial_result(base_dir, trial, params, duration):
112
 
113
  def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
114
  start_time = time.time()
115
- params = base_params.copy()
116
 
117
  try:
118
  # Training-related params
@@ -160,8 +167,6 @@ def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_tri
160
  ContrastiveDataset.collate_fn,
161
  spec_enc=params["spec_enc"],
162
  spectra_view=params["spectra_view"],
163
- mask_peak_ratio=params["mask_peak_ratio"],
164
- aug_cands=params["aug_cands"],
165
  )
166
 
167
  data_module = ContrastiveDataModule(
@@ -226,12 +231,11 @@ def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_tri
226
 
227
 
228
  def main(args):
229
- with open(args.param_pth) as f:
230
- params = yaml.load(f, Loader=yaml.FullLoader)
231
 
232
- # now = datetime.datetime.now().strftime("%Y%m%d")
233
- # base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
234
- base_dir = "../experiments/20250916_simple_model_optuna"
235
  os.makedirs(base_dir, exist_ok=True)
236
  params["experiment_dir"] = base_dir
237
 
 
1
  import argparse
2
+ import copy
3
  import datetime
4
  import os
5
  import sys
 
21
  from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
22
  from flare.utils.models import get_model
23
  from flare.definitions import TEST_RESULTS_DIR
24
+ from flare.utils.config import default_param_path, load_param_file
25
  from functools import partial
26
  from rdkit import RDLogger
27
  from massspecgym.models.base import Stage
 
31
  lg.setLevel(RDLogger.CRITICAL)
32
 
33
  parser = argparse.ArgumentParser()
34
+ parser.add_argument(
35
+ "--param_pth",
36
+ type=str,
37
+ default=None,
38
+ help="Base YAML (default: FLARE_PARAMS or repo params.yaml)",
39
+ )
40
  parser.add_argument("--n_trials", type=int, default=20)
41
 
42
  class EpochLossTracker(Callback):
 
119
 
120
  def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
121
  start_time = time.time()
122
+ params = copy.deepcopy(base_params)
123
 
124
  try:
125
  # Training-related params
 
167
  ContrastiveDataset.collate_fn,
168
  spec_enc=params["spec_enc"],
169
  spectra_view=params["spectra_view"],
 
 
170
  )
171
 
172
  data_module = ContrastiveDataModule(
 
231
 
232
 
233
  def main(args):
234
+ param_path = args.param_pth or str(default_param_path())
235
+ params = load_param_file(param_path)
236
 
237
+ now = datetime.datetime.now().strftime("%Y%m%d")
238
+ base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
 
239
  os.makedirs(base_dir, exist_ok=True)
240
  params["experiment_dir"] = base_dir
241
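
The switch from `base_params.copy()` to `copy.deepcopy(base_params)` in `objective` matters when the parameter dictionary contains nested containers, since a shallow copy shares them across trials. A small sketch of the difference; the keys `lr`, `model` and `hidden_dim` are hypothetical, not FLARE's actual parameter names.

```python
import copy

base_params = {"lr": 1e-3, "model": {"hidden_dim": 256}}

# Shallow copy: the nested "model" dict is shared with base_params.
shallow = base_params.copy()
shallow["model"]["hidden_dim"] = 512
leaked = base_params["model"]["hidden_dim"]
print(leaked)

# Reset, then repeat with a deep copy: the nested dict is duplicated.
base_params["model"]["hidden_dim"] = 256
deep = copy.deepcopy(base_params)
deep["model"]["hidden_dim"] = 512
print(base_params["model"]["hidden_dim"])
```

With a deep copy, one trial's sampled hyperparameters cannot bleed into the base configuration used by later trials.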
 
flare/utils/__init__.py CHANGED
@@ -1,3 +1 @@
1
- import sys
2
- sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
3
  from massspecgym.utils import *
 
 
 
1
  from massspecgym.utils import *
flare/utils/config.py ADDED
@@ -0,0 +1,63 @@
+ """Load YAML hyperparameters and resolve filesystem paths relative to the repository root."""
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+ from typing import Any
+
+ import yaml
+
+ from flare.definitions import REPO_DIR
+
+ # Keys that may hold filesystem paths (relative paths are resolved against REPO_DIR).
+ _PATH_KEYS = frozenset(
+     {
+         "dataset_pth",
+         "candidates_pth",
+         "subformula_dir_pth",
+         "split_pth",
+         "checkpoint_pth",
+         "df_test_path",
+         "formula_to_smiles_pth",
+     }
+ )
+
+
+ def resolve_repo_paths(params: dict[str, Any]) -> None:
+     """In-place: turn repo-relative path strings into absolute paths."""
+     root = REPO_DIR
+     for key in _PATH_KEYS:
+         val = params.get(key)
+         if not val or not isinstance(val, str):
+             continue
+         p = Path(val)
+         if not p.is_absolute():
+             params[key] = str((root / p).resolve())
+         else:
+             params[key] = str(p.resolve())
+
+
+ def load_param_file(path: str | Path) -> dict[str, Any]:
+     """Load a YAML parameter file and resolve path fields."""
+     p = Path(path)
+     if not p.is_file():
+         raise FileNotFoundError(f"Parameter file not found: {p}")
+     with open(p, encoding="utf-8") as f:
+         params = yaml.load(f, Loader=yaml.FullLoader)
+     if params is None:
+         params = {}
+     if not isinstance(params, dict):
+         raise TypeError(f"Expected mapping at top level of {p}, got {type(params)}")
+     resolve_repo_paths(params)
+     return params
+
+
+ def default_param_path() -> Path:
+     """Path to the default params file (overridable with FLARE_PARAMS)."""
+     override = os.environ.get("FLARE_PARAMS")
+     if override:
+         return Path(override).expanduser()
+     env_root = os.environ.get("FLARE_REPO_ROOT")
+     if env_root:
+         return Path(env_root).expanduser() / "params.yaml"
+     return REPO_DIR / "params.yaml"
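
The two behaviours of `flare/utils/config.py` (precedence of the default parameter file, and anchoring relative paths at the repository root) can be exercised in isolation. The sketch below is a simplified stand-in, not the module itself: it takes the environment as an explicit mapping instead of reading `os.environ`, skips `.resolve()`, and uses the hypothetical names `pick_param_path` and `resolve_against`.

```python
from pathlib import Path
from typing import Mapping


def pick_param_path(env: Mapping[str, str], repo_dir: Path) -> Path:
    """FLARE_PARAMS wins, then FLARE_REPO_ROOT/params.yaml, then repo_dir/params.yaml."""
    override = env.get("FLARE_PARAMS")
    if override:
        return Path(override).expanduser()
    env_root = env.get("FLARE_REPO_ROOT")
    if env_root:
        return Path(env_root).expanduser() / "params.yaml"
    return repo_dir / "params.yaml"


def resolve_against(root: Path, value: str) -> str:
    """Leave absolute paths alone; anchor relative ones at `root`."""
    p = Path(value)
    return str(p if p.is_absolute() else root / p)


repo = Path.cwd() / "repo"  # placeholder root for the demonstration
print(pick_param_path({}, repo))
print(resolve_against(repo, "data/sample/data.tsv"))
```

Passing the environment in as a mapping keeps the precedence logic testable without mutating `os.environ`.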
flare/utils/data.py CHANGED
@@ -21,6 +21,11 @@ class Subformula_Loader:
21
  self.dir_path = dir_path
22
  self.use_prec_mz = use_prec_mz
23
  self.formula_source = formula_source
 
 
 
 
 
24
  if spectra_view == 'SpecFormula':
25
  self.load = self.load_subformula_data
26
  elif spectra_view == "SpecFormulaMz":
@@ -63,77 +68,6 @@ class Subformula_Loader:
63
  except:
64
  return None
65
 
66
- def load_magma_data(self, data, curr_form, curr_prec_mz):
67
-
68
- np.random.seed(42)
69
-
70
- formula_to_intensity = {}
71
- formula_to_mz = {}
72
-
73
- # data is None
74
- if data is None:
75
- if self.use_prec_mz:
76
- return {'formulas': [curr_form], 'formula_mzs': [curr_prec_mz], 'formula_intensities': [PRECURSOR_INTENSITY]}
77
- else:
78
- return {'formulas': [], 'formula_mzs': [], 'formula_intensities': []}
79
-
80
- # randomly choose 1 formula for each peak, keep largest intensity for each formula
81
- if self.formula_source.endswith('1'):
82
- for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):
83
-
84
- if not f:
85
- continue
86
- selected_f = np.random.choice(f)
87
- if selected_f in formula_to_intensity:
88
- if i > formula_to_intensity[selected_f]:
89
- formula_to_intensity[selected_f] = i
90
- formula_to_mz[selected_f] = m
91
- else:
92
- formula_to_intensity[selected_f] = i
93
- formula_to_mz[selected_f] = m
94
-
95
- # take all formulas, divide intensity by number of formulas, keep largest intensity for each formula
96
- elif self.formula_source.endswith('all'):
97
- for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):
98
-
99
- if not f:
100
- continue
101
- for fi in f:
102
- if fi in formula_to_intensity:
103
- if i/len(f) > formula_to_intensity[fi]:
104
- formula_to_intensity[fi] = i/len(f)
105
- formula_to_mz[fi] = m
106
- else:
107
- formula_to_intensity[fi] = i/len(f)
108
- formula_to_mz[fi] = m
109
- else:
110
- raise Exception(f"Formula source not supported: {self.formula_source}")
111
-
112
- mzs = list(formula_to_mz.values())
113
- formulas = list(formula_to_mz.keys())
114
- intensities = list(formula_to_intensity.values())
115
-
116
- # add precursor mz
117
- if self.use_prec_mz:
118
- if curr_form in formulas:
119
- intensities[formulas.index(curr_form)] = PRECURSOR_INTENSITY
120
- else:
121
- formulas.append(curr_form)
122
- intensities.append(PRECURSOR_INTENSITY)
123
- mzs.append(curr_prec_mz)
124
-
125
- # sort by mzs
126
- mzs = np.array(mzs)
127
- formulas = np.array(formulas)
128
- intensities = np.array(intensities)
129
-
130
- ind = mzs.argsort()
131
- mzs = mzs[ind]
132
- formulas = formulas[ind]
133
- intensities = intensities[ind]
134
-
135
- return {'formulas': formulas, 'formula_mzs': mzs, 'formula_intensities': intensities}
136
-
137
  def load_sirius_data(self, data):
138
  try:
139
 
@@ -167,8 +101,6 @@ class Subformula_Loader:
167
  data = json.load(f)
168
  if self.formula_source == 'sirius':
169
  return self.load_sirius_data(data)
170
- elif self.formula_source.startswith('magma'):
171
- return self.load_magma_data(data, curr_form, curr_prec_mz)
172
  else:
173
  return self.load_mist_data(data, curr_form, curr_prec_mz)
174
 
@@ -263,17 +195,18 @@ def get_test_ms_dataset(spectra_view: T.Union[str, T.List[str]],
263
  else: views.extend(v)
264
  views = frozenset(views)
265
 
266
- dataset_params = {'spectra_view': spectra_view, 'pth': params['dataset_pth'], 'spec_transform': spectra_featurizer, 'mol_transform': mol_featurizer, "candidates_pth": params['candidates_pth']}
267
  if "SpecFormula" in views or "SpecFormulaMz" in views:
268
- dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'use_magma': params['formula_source'].startswith('magma'), 'formula_source':params['formula_source']})
269
- use_formulas = True
270
-
271
- # if params['use_cons_spec']:
272
- # dataset_params.update({'cons_spec_dir_pth': params['cons_spec_dir_pth']})
273
- # if 'use_NL_spec' in params and params['use_NL_spec']:
274
- # dataset_params.update({'NL_spec_dir_pth': params['NL_spec_dir_pth']})
275
- # if params['pred_fp'] or params['use_fp']:
276
- # dataset_params.update({'fp_dir_pth': '', 'fp_size': params['fp_size'], 'fp_radius': params['fp_radius']})
277
 
278
  return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
279
 
 
21
  self.dir_path = dir_path
22
  self.use_prec_mz = use_prec_mz
23
  self.formula_source = formula_source
24
+ if str(formula_source).startswith('magma'):
25
+ raise ValueError(
26
+ "MAGMA formula sources are not supported in this release (see archive/magma/). "
27
+ "Use 'default' (MIST-style JSON) or 'sirius'."
28
+ )
29
  if spectra_view == 'SpecFormula':
30
  self.load = self.load_subformula_data
31
  elif spectra_view == "SpecFormulaMz":
 
68
  except:
69
  return None
70
71
  def load_sirius_data(self, data):
72
  try:
73
 
 
101
  data = json.load(f)
102
  if self.formula_source == 'sirius':
103
  return self.load_sirius_data(data)
104
  else:
105
  return self.load_mist_data(data, curr_form, curr_prec_mz)
106
 
 
195
  else: views.extend(v)
196
  views = frozenset(views)
197
 
198
+ dataset_params = {
199
+ 'spectra_view': spectra_view,
200
+ 'pth': params['dataset_pth'],
201
+ 'spec_transform': spectra_featurizer,
202
+ 'mol_transform': mol_featurizer,
203
+ "candidates_pth": params.get('candidates_pth'),
204
+ "formula_to_smiles_pth": params.get('formula_to_smiles_pth'),
205
+ "external_test": params.get('external_test', False)
206
+ }
207
  if "SpecFormula" in views or "SpecFormulaMz" in views:
208
+ dataset_params.update({'subformula_dir_pth': params['subformula_dir_pth'], 'formula_source': params['formula_source']})
209
+ use_formulas = True
210
 
211
  return jestr_datasets.ExpandedRetrievalDataset(use_formulas=use_formulas, **dataset_params)
212
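The two behavioural changes in this file, rejecting MAGMA sources early and reading optional keys with `.get()`, can be sketched as below. Both function names are hypothetical helpers written for illustration; the repository performs these steps inline in `Subformula_Loader.__init__` and `get_test_ms_dataset`, and only a subset of the keys is shown.

```python
from typing import Any, Dict


def check_formula_source(formula_source: Any) -> None:
    """Hypothetical helper: fail fast on MAGMA sources, as the new constructor guard does."""
    if str(formula_source).startswith("magma"):
        raise ValueError(f"Unsupported formula source: {formula_source!r}")


def build_dataset_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: required keys use [], optional keys use .get()."""
    return {
        "pth": params["dataset_pth"],  # required: KeyError if missing
        "candidates_pth": params.get("candidates_pth"),  # optional: None if missing
        "formula_to_smiles_pth": params.get("formula_to_smiles_pth"),
        "external_test": params.get("external_test", False),  # optional with default
    }
```

With this split, a params file that omits `candidates_pth` no longer fails with `KeyError` at dataset construction, while a missing `dataset_pth` still does.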
 
flare/utils/loss.py CHANGED
@@ -43,40 +43,6 @@ def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
43
 
44
  return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
45
 
46
- def cand_spec_sim_loss(spec_enc, cand_enc):
47
- cand_enc = torch.transpose(cand_enc, 0, 1) # C x B x d
48
- spec_enc = spec_enc.unsqueeze(0) # 1 x B x d
49
-
50
- sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)
51
- loss = torch.mean(sim)
52
-
53
- return loss
54
-
55
- class cons_spec_loss:
56
- def __init__(self, loss_type) -> None:
57
- self.loss_compute = {'cosine': self.cos_loss,
58
- 'l2':torch.nn.MSELoss()}[loss_type]
59
- def __call__(self,cons_spec, ind_spec):
60
- return self.loss_compute(cons_spec, ind_spec)
61
-
62
- def cos_loss(self, cons_spec, ind_spec):
63
- sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
64
- loss = 1-torch.mean(sim)
65
- return loss
66
-
67
- class fp_loss:
68
- def __init__(self, loss_type) -> None:
69
- self.loss_compute = {'cosine': self.fp_loss_cos,
70
- 'bce': nn.BCELoss()}[loss_type]
71
-
72
- def __call__(self, predicted_fp, target_fp):
73
- return self.loss_compute(predicted_fp, target_fp)
74
-
75
- def fp_loss_cos(self, predicted_fp, target_fp):
76
- sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
77
- return 1 - torch.mean(sim)
78
-
79
-
80
  # ---------- Utility ----------
81
  def _safe_divide(num, denom, eps=1e-8):
82
  return num / (denom + eps)
 
43
 
44
  return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
45
 
46
  # ---------- Utility ----------
47
  def _safe_divide(num, denom, eps=1e-8):
48
  return num / (denom + eps)
flare/utils/models.py CHANGED
@@ -1,6 +1,5 @@
1
  from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
2
  from flare.models.mol_encoder import MolEnc
3
- from flare.models.encoders import MLP
4
  from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive, FilipGlobalContrastive
5
 
6
  def get_spec_encoder(spec_enc:str, args):
@@ -13,12 +12,6 @@ def get_spec_encoder(spec_enc:str, args):
13
  def get_mol_encoder(mol_enc: str, args):
14
  return {'GNN': MolEnc}[mol_enc](args, in_dim=78)
15
 
16
- def get_fp_pred_model(args):
17
- return MLP(in_dim=args.final_embedding_dim, hidden_dims=[args.fp_size], final_activation='sigmoid', dropout=args.fp_dropout)
18
-
19
- def get_fp_enc_model(args):
20
- return MLP(in_dim=args.fp_size, hidden_dims=[args.final_embedding_dim,args.final_embedding_dim*2,args.final_embedding_dim,], final_activation=None, dropout=0.0)
21
-
22
  def get_model(model:str,
23
  params):
24
 
 
1
  from flare.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaEncMLP, SpecFormulaTransformer,SpecFormula_mz_Encoder, SpecMzIntTokenTransformer
2
  from flare.models.mol_encoder import MolEnc
3
  from flare.models.contrastive import ContrastiveModel, CrossAttenContrastive, FilipContrastive, FilipGlobalContrastive
4
 
5
  def get_spec_encoder(spec_enc:str, args):
 
12
  def get_mol_encoder(mol_enc: str, args):
13
  return {'GNN': MolEnc}[mol_enc](args, in_dim=78)
14
 
15
  def get_model(model:str,
16
  params):
17
 
flare/utils/mol_search.py CHANGED
@@ -311,57 +311,7 @@ class SpectraMoleculeRetriever:
311
 
312
 
313
  if __name__ == "__main__":
314
- import sys
315
- sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
316
-
317
- from flare.utils.data import get_spec_featurizer, get_mol_featurizer
318
- from flare.utils.models import get_model
319
- from flare.utils.mol_search import SpectraMoleculeRetriever
320
- from flare.utils.general import filip_similarity_single
321
- import yaml
322
-
323
- metadata = {
324
- "class": {
325
- "lipid": ["mol1", "mol2"],
326
- "peptide": ["mol3"]
327
- },
328
- "pathway": {
329
- "beta-oxidation": ["mol1"],
330
- "glycolysis": ["mol2", "mol3"]
331
- }
332
- }
333
-
334
- smiles_dict = {
335
- "mol1": "CCO",
336
- "mol2": "CCN",
337
- "mol3": "CCC"
338
- }
339
-
340
- # Load model and data
341
- param_pth = '/data/yzhouc01/cancer/flare.yaml'
342
- with open(param_pth) as f:
343
- params = yaml.load(f, Loader=yaml.FullLoader)
344
-
345
- spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
346
- mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
347
-
348
-
349
- # load model
350
- checkpoint_pth = "/data/yzhouc01/FILIP-MS/experiments/20250930_optimized_flare_42/epoch=1959-train_loss=0.08.ckpt"
351
- params['checkpoint_pth'] = checkpoint_pth
352
- model = get_model(params['model'], params)
353
-
354
- specMolRetriever = SpectraMoleculeRetriever(
355
- molecule_encoder=model.mol_enc_model,
356
- spectra_encoder=model.spec_enc_model,
357
- fine_similarity_fn=filip_similarity_single,
358
- smiles_preprocess=mol_featurizer
359
  )
360
-
361
- specMolRetriever.build_database(smiles_dict, metadata=metadata, cache_nodes=True)
362
-
363
- # Filter search to molecules in a specific pathway
364
- # results = specMolRetriever.search(spectrum, subset={"pathway": "beta-oxidation"})
365
-
366
- # for mol_id, score in results[:10]:
367
- # print(f"{mol_id}: {score:.3f}")
 
311
 
312
 
313
  if __name__ == "__main__":
314
+ raise SystemExit(
315
+ "SpectraMoleculeRetriever is a library class; configure paths via FLARE_PARAMS / "
316
+ "FLARE_CHECKPOINT and import it from application code (see README)."
317
  )
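For reference, the effect of the new `__main__` block can be sketched with a small stand-in. `library_only_guard` is a hypothetical name and the message is abbreviated. When `SystemExit` is raised with a string and not caught, Python prints that string to stderr and exits with status 1.

```python
def library_only_guard() -> None:
    """Hypothetical stand-in for the new __main__ body in mol_search.py."""
    # A string argument becomes the exit message; the process exit status is 1.
    raise SystemExit(
        "SpectraMoleculeRetriever is a library class; import it from application code."
    )


def guard_message() -> str:
    """Call the guard and return the message it would have exited with."""
    try:
        library_only_guard()
    except SystemExit as exc:
        return str(exc.code)
    return ""
```

In the module itself the raise sits directly under `if __name__ == "__main__":`, so importing the module is unaffected and only direct execution stops.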