---
license: apache-2.0
language:
- en
library_name: transformers
pipeline_tag: feature-extraction
tags:
- medical
- cardiovascular
- ecg
- ecg-text representation learning
- ecg-foundation-model
- pytorch
---
<div align="center" style="font-size: 2em;">
<strong>Boosting Masked ECG-Text Auto-Encoders as Discriminative Learners</strong>
</div>
<div align="center" style="font-size: 2em;">
<strong>(ICML 2025)</strong>
</div>
<div align="center">
<a href="https://manhph2211.github.io/D-BETA/" style="display:inline-block;">
<img src="https://img.shields.io/badge/Website-DBETA%20WebPage-blue?style=for-the-badge">
</a>
<a href="https://arxiv.org/pdf/2410.02131" style="display:inline-block;">
<img src="https://img.shields.io/badge/arxiv-Paper-red?style=for-the-badge">
</a>
<a href="https://huggingface.co/Manhph2211" style="display:inline-block;">
<img src="https://img.shields.io/badge/Checkpoint-%F0%9F%A4%97%20Hugging%20Face-White?style=for-the-badge">
</a>
</div>
<div align="center">
<a href="https://maxph2211.dev/" target="_blank">Hung&nbsp;Manh&nbsp;Pham</a> &emsp;
<a href="https://aqibsaeed.github.io/" target="_blank">Aaqib&nbsp;Saeed</a> &emsp;
<a href="https://www.dongma.info/" target="_blank">Dong&nbsp;Ma</a> &emsp;
</div>
<br>
## Load with `transformers==4.36.2`
```python
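import torch
from transformers import AutoModel

# trust_remote_code=True lets transformers run the custom model code hosted in the Hub repo
model = AutoModel.from_pretrained("Manhph2211/D-BETA", trust_remote_code=True)
model.eval()  # switch to inference mode (e.g. disables dropout)

ecgs = torch.randn(2, 12, 5000)  # dummy input: [batch, leads, length]
with torch.no_grad():
    output = model(ecgs)
    ecg_features = output.pooler_output  # one pooled embedding per ECG

print(ecg_features.shape)  # torch.Size([2, 768])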
```
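The snippet above feeds a random tensor. Real recordings must be stacked into the same `[batch, leads, length]` layout first. Below is a minimal NumPy sketch, assuming each recording is a `(samples, leads)` array with 12 leads. The helper `to_model_input`, the crop/zero-pad to 5000 samples, and the per-lead z-score are illustrative assumptions, not D-BETA's documented preprocessing. Resampling to the rate the model was trained on is not handled, so check the repository's data pipeline before relying on this.

```python
import numpy as np


def to_model_input(recordings, target_len=5000):
    """Hypothetical helper: stack ECGs into a float32 [batch, leads, length] array.

    Each recording is assumed to be a (samples, leads) array. Resampling is not
    handled; the normalization below is an assumption, not D-BETA's own pipeline.
    """
    batch = []
    for rec in recordings:
        x = np.asarray(rec, dtype=np.float32).T  # (samples, leads) -> (leads, samples)
        x = x[:, :target_len]  # crop recordings longer than target_len
        mean = x.mean(axis=1, keepdims=True)
        std = x.std(axis=1, keepdims=True)
        x = (x - mean) / (std + 1e-8)  # per-lead z-score; epsilon guards flat leads
        if x.shape[1] < target_len:
            # zero-pad short recordings at the end (0 equals the per-lead mean after z-scoring)
            x = np.pad(x, ((0, 0), (0, target_len - x.shape[1])))
        batch.append(x)
    return np.stack(batch).astype(np.float32)


# Usage with two synthetic 12-lead recordings of different lengths
rng = np.random.default_rng(0)
recs = [rng.normal(size=(5000, 12)), rng.normal(size=(4000, 12))]
batch = to_model_input(recs)
print(batch.shape)  # (2, 12, 5000)
# ecgs = torch.from_numpy(batch) can then be passed to the model as above
```

Normalizing before padding keeps the padded tail at the per-lead mean instead of letting the zeros skew the statistics.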
## Load with the GitHub repo
Clone the project and prepare the environment:
```bash
git clone https://github.com/manhph2211/D-BETA.git && cd D-BETA
conda create -n dbeta python=3.9
conda activate dbeta
pip install -r requirements.txt
```
Then load the pretrained checkpoint from `checkpoints/pytorch_model.bin` and extract ECG features:
```python
import torch
from models.processor import get_model, get_ecg_feats
model = get_model(config_path='configs/config.json', checkpoint_path='checkpoints/pytorch_model.bin')
ecgs = torch.randn(2, 12, 5000) # [batch, leads, length]
ecg_features = get_ecg_feats(model, ecgs)
print(ecg_features.shape) # (2, 768)
```
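Since the model is tagged for feature extraction, the pooled `ecg_features` from either loader can be compared directly, for example to retrieve the most similar recordings. Below is a minimal sketch, assuming the features have been moved to NumPy (e.g. `ecg_features.cpu().numpy()` if they are a torch tensor). `cosine_similarity_matrix` and `nearest_neighbors` are hypothetical helpers, not part of the D-BETA code. Whether cosine similarity between these embeddings reflects clinical similarity is something to verify on your own data.

```python
import numpy as np


def cosine_similarity_matrix(feats):
    """Pairwise cosine similarity between the rows of an [n, dim] array."""
    feats = np.asarray(feats, dtype=np.float64)
    norms = np.linalg.norm(feats, axis=1, keepdims=True)
    unit = feats / np.clip(norms, 1e-12, None)  # L2-normalize each embedding
    return unit @ unit.T


def nearest_neighbors(feats, k=1):
    """For each row, indices of the k most similar other rows (itself excluded)."""
    sim = cosine_similarity_matrix(feats)
    np.fill_diagonal(sim, -np.inf)  # prevent a row from matching itself
    return np.argsort(-sim, axis=1)[:, :k]


# Usage with stand-in 768-d features; with real data use e.g. ecg_features.cpu().numpy()
rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 768))
feats[3] = 2.0 * feats[0]  # same direction as row 0, so their cosine similarity is 1
print(nearest_neighbors(feats, k=1).ravel())
```

Normalizing once and taking a single matrix product yields all pairwise similarities together, which suits a few thousand recordings. Larger collections would call for an approximate nearest-neighbor index.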
## Citation
If you find this work useful, please consider citing our paper:
```bibtex
@inproceedings{hung2025boosting,
title={Boosting Masked {ECG}-Text Auto-Encoders as Discriminative Learners},
author={Manh Pham Hung and Aaqib Saeed and Dong Ma},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
}
```