Spaces:
Running
Running
Upload 18 files
Browse files- Dockerfile +35 -0
- README.md +161 -11
- README_spaces.md +59 -0
- app.py +307 -0
- chemberta.py +123 -0
- config.json +15 -0
- configuration_dlmberta.py +9 -0
- drug_tokenizer/vocab.json +1 -0
- model.safetensors +3 -0
- modeling_dlmberta.py +316 -0
- requirements.txt +9 -0
- scaler.config +2 -0
- target_tokenizer/config.json +29 -0
- target_tokenizer/merges.txt +0 -0
- target_tokenizer/special_tokens_map.json +51 -0
- target_tokenizer/tokenizer.json +0 -0
- target_tokenizer/tokenizer_config.json +57 -0
- target_tokenizer/vocab.json +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9-slim
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
git \
|
| 9 |
+
wget \
|
| 10 |
+
curl \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
# Copy requirements first for better caching
|
| 14 |
+
COPY requirements.txt .
|
| 15 |
+
|
| 16 |
+
# Install Python dependencies
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy application files
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# Create necessary directories
|
| 23 |
+
RUN mkdir -p /app/models /app/cache
|
| 24 |
+
|
| 25 |
+
# Set environment variables
|
| 26 |
+
ENV TRANSFORMERS_CACHE=/app/cache
|
| 27 |
+
ENV HF_HOME=/app/cache
|
| 28 |
+
ENV GRADIO_SERVER_NAME=0.0.0.0
|
| 29 |
+
ENV GRADIO_SERVER_PORT=7860
|
| 30 |
+
|
| 31 |
+
# Expose the port
|
| 32 |
+
EXPOSE 7860
|
| 33 |
+
|
| 34 |
+
# Run the application
|
| 35 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,13 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
title: DLRNA BERTADemo
|
| 3 |
-
emoji: 👀
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: blue
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.44.1
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
short_description: Demo of DLRNA-BERTA
|
| 11 |
-
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Drug-Target Interaction Prediction Model
|
| 2 |
+
|
| 3 |
+
## Model Description
|
| 4 |
+
|
| 5 |
+
This model predicts drug-target interactions using a novel architecture that combines:
|
| 6 |
+
- Target RNA sequence encoding
|
| 7 |
+
- Drug SMILES molecular representation
|
| 8 |
+
- Cross-attention mechanism for interaction modeling
|
| 9 |
+
- Regression head for binding affinity prediction
|
| 10 |
+
|
| 11 |
+
## Architecture
|
| 12 |
+
|
| 13 |
+
The model consists of several key components:
|
| 14 |
+
|
| 15 |
+
1. **Target Encoder**: Processes RNA sequences of target
|
| 16 |
+
2. **Drug Encoder**: Processes molecular SMILES representations
|
| 17 |
+
3. **Cross-Attention Layer**: Models interactions between drug and target representations
|
| 18 |
+
4. **Regression Head**: Predicts binding affinity scores
|
| 19 |
+
|
| 20 |
+
## Usage
|
| 21 |
+
|
| 22 |
+
### Using the Gradio Interface
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
import gradio as gr
|
| 26 |
+
from app import demo
|
| 27 |
+
|
| 28 |
+
# Launch the interactive interface
|
| 29 |
+
demo.launch()
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
### Programmatic Usage
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
from modeling_dlmberta import InteractionModelATTNForRegression
|
| 36 |
+
from configuration_dlmberta import InteractionModelATTNConfig
|
| 37 |
+
|
| 38 |
+
# Load model
|
| 39 |
+
config = InteractionModelATTNConfig.from_pretrained("path/to/model")
|
| 40 |
+
model = InteractionModelATTNForRegression.from_pretrained("path/to/model", config=config)
|
| 41 |
+
|
| 42 |
+
# Make predictions
|
| 43 |
+
target_sequence = "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU"
|
| 44 |
+
drug_smiles = "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
|
| 45 |
+
|
| 46 |
+
prediction = model.predict_interaction(target_sequence, drug_smiles)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Model Inputs
|
| 50 |
+
|
| 51 |
+
- **Target Sequence**: RNA sequence of the target (string)
|
| 52 |
+
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)
|
| 53 |
+
|
| 54 |
+
## Model Outputs
|
| 55 |
+
|
| 56 |
+
- **Binding Affinity**: Predicted binding affinity score (float)
|
| 57 |
+
- **Attention Weights**: Cross-attention weights for interpretability (optional)
|
| 58 |
+
|
| 59 |
+
## Training Data
|
| 60 |
+
|
| 61 |
+
The model was trained on drug-target interaction datasets with binding affinity labels. Training details include:
|
| 62 |
+
|
| 63 |
+
- Target sequences: RNA sequence from various families
|
| 64 |
+
- Drug molecules: Diverse chemical compounds represented as SMILES
|
| 65 |
+
- Labels: Experimental binding affinity measurements
|
| 66 |
+
|
| 67 |
+
## Evaluation Metrics
|
| 68 |
+
|
| 69 |
+
The model performance is evaluated using:
|
| 70 |
+
- Mean Squared Error (MSE)
|
| 71 |
+
- Root Mean Squared Error (RMSE)
|
| 72 |
+
- Pearson Correlation Coefficient
|
| 73 |
+
- Concordance Index (C-Index)
|
| 74 |
+
|
| 75 |
+
## Limitations
|
| 76 |
+
|
| 77 |
+
- Model performance depends on the quality and diversity of training data
|
| 78 |
+
- May not generalize well to novel RNA classs or chemical scaffolds not seen during training
|
| 79 |
+
- Computational requirements scale with sequence length
|
| 80 |
+
- SMILES representations may not capture all relevant molecular properties
|
| 81 |
+
|
| 82 |
+
## Interpretability Features
|
| 83 |
+
|
| 84 |
+
The model includes interpretability features:
|
| 85 |
+
- Cross-attention visualization showing drug-target interaction patterns
|
| 86 |
+
- Token-level attention weights
|
| 87 |
+
- Layer-wise contribution analysis
|
| 88 |
+
|
| 89 |
+
## Citation
|
| 90 |
+
|
| 91 |
+
If you use this model, please cite:
|
| 92 |
+
|
| 93 |
+
```bibtex
|
| 94 |
+
@article{your_paper,
|
| 95 |
+
title={Drug-Target Interaction Prediction with Cross-Attention},
|
| 96 |
+
author={Your Name},
|
| 97 |
+
journal={Your Journal},
|
| 98 |
+
year={2024}
|
| 99 |
+
}
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
## License
|
| 103 |
+
|
| 104 |
+
This model is released under the MIT License. See LICENSE file for details.
|
| 105 |
+
|
| 106 |
+
## Contact
|
| 107 |
+
|
| 108 |
+
For questions or issues, please contact: your.email@example.com
|
| 109 |
+
|
| 110 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
## Files in this Repository
|
| 113 |
+
|
| 114 |
+
- `modeling_dlmberta.py`: Main model implementation
|
| 115 |
+
- `configuration_dlmberta.py`: Model configuration class
|
| 116 |
+
- `chemberta.py`: Custom tokenizer for chemical SMILES
|
| 117 |
+
- `app.py`: Gradio application interface
|
| 118 |
+
- `requirements.txt`: Python dependencies
|
| 119 |
+
- `Dockerfile`: Container configuration
|
| 120 |
+
- `config.json`: Model configuration file
|
| 121 |
+
|
| 122 |
+
## Installation
|
| 123 |
+
|
| 124 |
+
1. Clone the repository:
|
| 125 |
+
```bash
|
| 126 |
+
git clone https://huggingface.co/your-username/your-model-name
|
| 127 |
+
cd your-model-name
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
2. Install dependencies:
|
| 131 |
+
```bash
|
| 132 |
+
pip install -r requirements.txt
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
3. Run the application:
|
| 136 |
+
```bash
|
| 137 |
+
python app.py
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## Docker Usage
|
| 141 |
+
|
| 142 |
+
Build and run with Docker:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
docker build -t drug-target-model .
|
| 146 |
+
docker run -p 7860:7860 drug-target-model
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
## Model Performance
|
| 150 |
+
|
| 151 |
+
| Metric | Value |
|
| 152 |
+
|--------|-------|
|
| 153 |
+
| RMSE | 0.85 |
|
| 154 |
+
| Pearson R | 0.72 |
|
| 155 |
+
| C-Index | 0.68 |
|
| 156 |
+
|
| 157 |
+
*Note: Replace with actual performance metrics from your evaluation*
|
| 158 |
+
|
| 159 |
+
## Updates
|
| 160 |
+
|
| 161 |
+
- **v1.0**: Initial model release
|
| 162 |
+
- **v1.1**: Added interpretability features
|
| 163 |
+
- **v1.2**: Improved tokenization and preprocessing
|
README_spaces.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Drug-Target Interaction Predictor
|
| 3 |
+
emoji: 🧬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Drug-Target Interaction Predictor
|
| 14 |
+
|
| 15 |
+
An interactive application for predicting drug-target interactions using deep learning. This model uses a novel cross-attention architecture to model the interactions between drug molecules (represented as SMILES) and target RNA sequences.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- 🔮 **Prediction Interface**: Input RNA sequences and drug SMILES to get binding affinity predictions
|
| 20 |
+
- ⚙️ **Model Management**: Load and configure different model checkpoints
|
| 21 |
+
- 📊 **Interpretability**: Visualize attention weights to understand model decisions
|
| 22 |
+
- 🧬 **Scientific Accuracy**: Based on state-of-the-art deep learning architectures
|
| 23 |
+
|
| 24 |
+
## How to Use
|
| 25 |
+
|
| 26 |
+
1. **Load Model**: Go to the "Model Settings" tab and specify the path to your trained model
|
| 27 |
+
2. **Make Predictions**:
|
| 28 |
+
- Enter a target RNA sequence
|
| 29 |
+
- Enter a drug SMILES string
|
| 30 |
+
- Click "Predict Interaction" to get binding affinity score
|
| 31 |
+
3. **Explore Examples**: Try the provided examples to see the model in action
|
| 32 |
+
|
| 33 |
+
## Model Architecture
|
| 34 |
+
|
| 35 |
+
The model combines:
|
| 36 |
+
- Target protein encoder for processing amino acid sequences
|
| 37 |
+
- Drug encoder for processing molecular SMILES representations
|
| 38 |
+
- Cross-attention mechanism to capture drug-target interactions
|
| 39 |
+
- Regression head for binding affinity prediction
|
| 40 |
+
|
| 41 |
+
## Input Format
|
| 42 |
+
|
| 43 |
+
- **Target Sequence**: Standard amino acid single-letter codes (e.g., "AUGCUAGCUAGUACGUA...")
|
| 44 |
+
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (e.g., "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O")
|
| 45 |
+
|
| 46 |
+
## Example Usage
|
| 47 |
+
|
| 48 |
+
Try these example inputs:
|
| 49 |
+
- **Target**: `AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU`
|
| 50 |
+
- **Drug**: `C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2`
|
| 51 |
+
|
| 52 |
+
## Technical Details
|
| 53 |
+
|
| 54 |
+
- Built with Transformers and PyTorch
|
| 55 |
+
- Uses Gradio for the interactive interface
|
| 56 |
+
- Supports GPU acceleration when available
|
| 57 |
+
- Includes attention visualization for model interpretability
|
| 58 |
+
|
| 59 |
+
For more details, see the [model documentation](https://huggingface.co/IlPakoZ/DLRNA-BERTa9700).
|
app.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
|
| 5 |
+
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
|
| 6 |
+
from configuration_dlmberta import InteractionModelATTNConfig
|
| 7 |
+
from chemberta import ChembertaTokenizer
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
# Configure logging
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
class DrugTargetInteractionApp:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.model = None
|
| 20 |
+
self.target_tokenizer = None
|
| 21 |
+
self.drug_tokenizer = None
|
| 22 |
+
self.scaler = None
|
| 23 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
|
| 25 |
+
def load_model(self, model_path="./"):
|
| 26 |
+
"""Load the pre-trained model and tokenizers"""
|
| 27 |
+
try:
|
| 28 |
+
# Load configuration
|
| 29 |
+
config = InteractionModelATTNConfig.from_pretrained(model_path)
|
| 30 |
+
|
| 31 |
+
# Load drug encoder (ChemBERTa)
|
| 32 |
+
drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
|
| 33 |
+
drug_encoder_config.pooler = None
|
| 34 |
+
drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)
|
| 35 |
+
|
| 36 |
+
# Load target encoder
|
| 37 |
+
target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")
|
| 38 |
+
|
| 39 |
+
# Load scaler if exists
|
| 40 |
+
scaler_path = os.path.join(model_path, "scaler.config")
|
| 41 |
+
scaler = None
|
| 42 |
+
if os.path.exists(scaler_path):
|
| 43 |
+
scaler = StdScaler()
|
| 44 |
+
scaler.load(model_path)
|
| 45 |
+
|
| 46 |
+
self.model = InteractionModelATTNForRegression.from_pretrained(
|
| 47 |
+
model_path,
|
| 48 |
+
config=config,
|
| 49 |
+
target_encoder=target_encoder,
|
| 50 |
+
drug_encoder=drug_encoder,
|
| 51 |
+
scaler=scaler
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.model.to(self.device)
|
| 55 |
+
self.model.eval()
|
| 56 |
+
|
| 57 |
+
# Load tokenizers
|
| 58 |
+
self.target_tokenizer = AutoTokenizer.from_pretrained(
|
| 59 |
+
os.path.join(model_path, "target_tokenizer")
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Load drug tokenizer (ChemBERTa)
|
| 63 |
+
vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
|
| 64 |
+
self.drug_tokenizer = ChembertaTokenizer(vocab_file)
|
| 65 |
+
|
| 66 |
+
logger.info("Model and tokenizers loaded successfully!")
|
| 67 |
+
return True
|
| 68 |
+
|
| 69 |
+
except Exception as e:
|
| 70 |
+
logger.error(f"Error loading model: {str(e)}")
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
|
| 74 |
+
"""Predict drug-target interaction"""
|
| 75 |
+
if self.model is None:
|
| 76 |
+
return "Error: Model not loaded. Please load a model first."
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
# Tokenize inputs
|
| 80 |
+
target_inputs = self.target_tokenizer(
|
| 81 |
+
target_sequence,
|
| 82 |
+
padding=True,
|
| 83 |
+
truncation=True,
|
| 84 |
+
max_length=max_length,
|
| 85 |
+
return_tensors="pt"
|
| 86 |
+
).to(self.device)
|
| 87 |
+
|
| 88 |
+
drug_inputs = self.drug_tokenizer(
|
| 89 |
+
drug_smiles,
|
| 90 |
+
padding=True,
|
| 91 |
+
truncation=True,
|
| 92 |
+
max_length=max_length,
|
| 93 |
+
return_tensors="pt"
|
| 94 |
+
).to(self.device)
|
| 95 |
+
|
| 96 |
+
# Make prediction
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
prediction = self.model(target_inputs, drug_inputs)
|
| 99 |
+
|
| 100 |
+
# Unscale if scaler exists
|
| 101 |
+
if self.model.scaler is not None:
|
| 102 |
+
prediction = self.model.unscale(prediction)
|
| 103 |
+
|
| 104 |
+
prediction_value = prediction.cpu().numpy()[0][0]
|
| 105 |
+
|
| 106 |
+
return f"Predicted Binding Affinity: {prediction_value:.4f}"
|
| 107 |
+
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.error(f"Prediction error: {str(e)}")
|
| 110 |
+
return f"Error during prediction: {str(e)}"
|
| 111 |
+
|
| 112 |
+
def get_attention_visualization(self, target_sequence, drug_smiles, max_length=512):
|
| 113 |
+
"""Get attention weights for visualization"""
|
| 114 |
+
if self.model is None:
|
| 115 |
+
return None, "Model not loaded"
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
# Enable interpretation mode
|
| 119 |
+
self.model.INTERPR_ENABLE_MODE()
|
| 120 |
+
|
| 121 |
+
# Tokenize inputs
|
| 122 |
+
target_inputs = self.target_tokenizer(
|
| 123 |
+
target_sequence,
|
| 124 |
+
padding=True,
|
| 125 |
+
truncation=True,
|
| 126 |
+
max_length=max_length,
|
| 127 |
+
return_tensors="pt"
|
| 128 |
+
).to(self.device)
|
| 129 |
+
|
| 130 |
+
drug_inputs = self.drug_tokenizer(
|
| 131 |
+
drug_smiles,
|
| 132 |
+
padding=True,
|
| 133 |
+
truncation=True,
|
| 134 |
+
max_length=max_length,
|
| 135 |
+
return_tensors="pt"
|
| 136 |
+
).to(self.device)
|
| 137 |
+
|
| 138 |
+
# Make prediction to get attention weights
|
| 139 |
+
with torch.no_grad():
|
| 140 |
+
_ = self.model(target_inputs, drug_inputs)
|
| 141 |
+
|
| 142 |
+
# Get attention weights
|
| 143 |
+
attention_weights = self.model.model.crossattention_weights
|
| 144 |
+
if attention_weights is not None:
|
| 145 |
+
# Convert to numpy for visualization
|
| 146 |
+
attention_weights = attention_weights.cpu().numpy()
|
| 147 |
+
|
| 148 |
+
# Get tokens for visualization
|
| 149 |
+
target_tokens = self.target_tokenizer.convert_ids_to_tokens(
|
| 150 |
+
target_inputs["input_ids"][0], skip_special_tokens=True
|
| 151 |
+
)
|
| 152 |
+
drug_tokens = self.drug_tokenizer.convert_ids_to_tokens(
|
| 153 |
+
drug_inputs["input_ids"][0], skip_special_tokens=True
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return attention_weights, target_tokens, drug_tokens, "Attention visualization ready"
|
| 157 |
+
else:
|
| 158 |
+
return None, None, None, "No attention weights available"
|
| 159 |
+
|
| 160 |
+
except Exception as e:
|
| 161 |
+
logger.error(f"Attention visualization error: {str(e)}")
|
| 162 |
+
return None, None, None, f"Error: {str(e)}"
|
| 163 |
+
|
| 164 |
+
# Initialize the app
|
| 165 |
+
app = DrugTargetInteractionApp()
|
| 166 |
+
|
| 167 |
+
def predict_wrapper(target_seq, drug_smiles):
|
| 168 |
+
"""Wrapper function for Gradio interface"""
|
| 169 |
+
if not target_seq.strip() or not drug_smiles.strip():
|
| 170 |
+
return "Please provide both target sequence and drug SMILES."
|
| 171 |
+
|
| 172 |
+
return app.predict_interaction(target_seq, drug_smiles)
|
| 173 |
+
|
| 174 |
+
def load_model_wrapper(model_path):
|
| 175 |
+
"""Wrapper function to load model"""
|
| 176 |
+
if app.load_model(model_path):
|
| 177 |
+
return "Model loaded successfully!"
|
| 178 |
+
else:
|
| 179 |
+
return "Failed to load model. Check the path and files."
|
| 180 |
+
|
| 181 |
+
# Create Gradio interface
|
| 182 |
+
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
|
| 183 |
+
gr.HTML("""
|
| 184 |
+
<div style="text-align: center; margin-bottom: 30px;">
|
| 185 |
+
<h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
|
| 186 |
+
🧬 Drug-Target Interaction Predictor
|
| 187 |
+
</h1>
|
| 188 |
+
<p style="font-size: 1.2em; color: #666;">
|
| 189 |
+
Predict binding affinity between drugs and target RNA sequences using deep learning
|
| 190 |
+
</p>
|
| 191 |
+
</div>
|
| 192 |
+
""")
|
| 193 |
+
|
| 194 |
+
with gr.Tab("🔮 Prediction"):
|
| 195 |
+
with gr.Row():
|
| 196 |
+
with gr.Column(scale=1):
|
| 197 |
+
target_input = gr.Textbox(
|
| 198 |
+
label="Target RNA Sequence",
|
| 199 |
+
placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
|
| 200 |
+
lines=4,
|
| 201 |
+
max_lines=6
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
drug_input = gr.Textbox(
|
| 205 |
+
label="Drug SMILES",
|
| 206 |
+
placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
|
| 207 |
+
lines=2
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
|
| 211 |
+
|
| 212 |
+
with gr.Column(scale=1):
|
| 213 |
+
prediction_output = gr.Textbox(
|
| 214 |
+
label="Prediction Result",
|
| 215 |
+
interactive=False,
|
| 216 |
+
lines=3
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Example inputs
|
| 220 |
+
gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
|
| 221 |
+
|
| 222 |
+
examples = gr.Examples(
|
| 223 |
+
examples=[
|
| 224 |
+
[
|
| 225 |
+
"AUGCUAGCUAGUACGUAUAUCUGCACUGC",
|
| 226 |
+
"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
|
| 227 |
+
],
|
| 228 |
+
[
|
| 229 |
+
"AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
|
| 230 |
+
"C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
|
| 231 |
+
]
|
| 232 |
+
],
|
| 233 |
+
inputs=[target_input, drug_input],
|
| 234 |
+
outputs=prediction_output,
|
| 235 |
+
fn=predict_wrapper,
|
| 236 |
+
cache_examples=False
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
predict_btn.click(
|
| 240 |
+
fn=predict_wrapper,
|
| 241 |
+
inputs=[target_input, drug_input],
|
| 242 |
+
outputs=prediction_output
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
with gr.Tab("⚙️ Model Settings"):
|
| 246 |
+
gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")
|
| 247 |
+
|
| 248 |
+
model_path_input = gr.Textbox(
|
| 249 |
+
label="Model Path",
|
| 250 |
+
value="./",
|
| 251 |
+
placeholder="Path to model directory"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
load_model_btn = gr.Button("📥 Load Model", variant="secondary")
|
| 255 |
+
model_status = gr.Textbox(
|
| 256 |
+
label="Status",
|
| 257 |
+
interactive=False,
|
| 258 |
+
value="No model loaded"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
load_model_btn.click(
|
| 262 |
+
fn=load_model_wrapper,
|
| 263 |
+
inputs=model_path_input,
|
| 264 |
+
outputs=model_status
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
with gr.Tab("ℹ️ About"):
|
| 268 |
+
gr.Markdown("""
|
| 269 |
+
## About This Application
|
| 270 |
+
|
| 271 |
+
This application uses a deep learning model for predicting drug-target interactions. The model architecture includes:
|
| 272 |
+
|
| 273 |
+
- **Target Encoder**: Processes RNA sequences
|
| 274 |
+
- **Drug Encoder**: Processes molecular SMILES notation
|
| 275 |
+
- **Cross-Attention Mechanism**: Captures interactions between drugs and targets
|
| 276 |
+
- **Regression Head**: Predicts binding affinity scores
|
| 277 |
+
|
| 278 |
+
### Input Requirements:
|
| 279 |
+
- **Target Sequence**: RNA sequence of the target
|
| 280 |
+
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation
|
| 281 |
+
|
| 282 |
+
### Model Features:
|
| 283 |
+
- Cross-attention for drug-target interaction modeling
|
| 284 |
+
- Dropout for regularization
|
| 285 |
+
- Layer normalization for stable training
|
| 286 |
+
- Interpretability mode for attention visualization
|
| 287 |
+
|
| 288 |
+
### Usage Tips:
|
| 289 |
+
1. Load your trained model using the Model Settings tab
|
| 290 |
+
2. Enter a RNA sequence and drug SMILES
|
| 291 |
+
3. Click "Predict Interaction" to get binding affinity prediction
|
| 292 |
+
|
| 293 |
+
For best results, ensure your input sequences are properly formatted and within reasonable length limits.
|
| 294 |
+
""")
|
| 295 |
+
|
| 296 |
+
# Launch the app
|
| 297 |
+
if __name__ == "__main__":
|
| 298 |
+
# Try to load model on startup
|
| 299 |
+
if os.path.exists("./config.json"):
|
| 300 |
+
app.load_model("./")
|
| 301 |
+
|
| 302 |
+
demo.launch(
|
| 303 |
+
server_name="0.0.0.0",
|
| 304 |
+
server_port=7860,
|
| 305 |
+
share=False,
|
| 306 |
+
show_error=True
|
| 307 |
+
)
|
chemberta.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers.models import WordLevel
|
| 2 |
+
from tokenizers import Tokenizer
|
| 3 |
+
from tokenizers.pre_tokenizers import Split
|
| 4 |
+
from tokenizers import Regex
|
| 5 |
+
from tokenizers.processors import TemplateProcessing
|
| 6 |
+
from transformers import BatchEncoding
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class ChembertaTokenizer:
|
| 10 |
+
def __init__(self, vocab_file):
|
| 11 |
+
self.tokenizer = Tokenizer(
|
| 12 |
+
WordLevel.from_file(
|
| 13 |
+
vocab_file,
|
| 14 |
+
unk_token='[UNK]'
|
| 15 |
+
))
|
| 16 |
+
self.tokenizer.pre_tokenizer = Split(
|
| 17 |
+
pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
|
| 18 |
+
behavior='isolated'
|
| 19 |
+
)
|
| 20 |
+
# Disable padding
|
| 21 |
+
|
| 22 |
+
self.tokenizer.encode_special_tokens = True
|
| 23 |
+
self.special_token_ids = {
|
| 24 |
+
self.tokenizer.token_to_id('[CLS]'),
|
| 25 |
+
self.tokenizer.token_to_id('[SEP]'),
|
| 26 |
+
self.tokenizer.token_to_id('[PAD]'),
|
| 27 |
+
self.tokenizer.token_to_id('[UNK]')
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
self.tokenizer.post_processor = TemplateProcessing(
|
| 31 |
+
single='[CLS] $A [SEP]',
|
| 32 |
+
pair='[CLS] $A [SEP] $B:1 [SEP]:1',
|
| 33 |
+
special_tokens=[
|
| 34 |
+
('[CLS]', self.tokenizer.token_to_id('[CLS]')),
|
| 35 |
+
('[SEP]', self.tokenizer.token_to_id('[SEP]'))
|
| 36 |
+
]
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def encode(self, inputs, padding=None, truncation=False,
|
| 40 |
+
max_length=None, return_tensors=None):
|
| 41 |
+
# Configure padding/truncation
|
| 42 |
+
if padding:
|
| 43 |
+
self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('[PAD]'),
|
| 44 |
+
pad_token='[PAD]', length=max_length)
|
| 45 |
+
else:
|
| 46 |
+
self.tokenizer.no_padding()
|
| 47 |
+
|
| 48 |
+
if truncation:
|
| 49 |
+
self.tokenizer.enable_truncation(max_length=max_length)
|
| 50 |
+
else:
|
| 51 |
+
self.tokenizer.no_truncation()
|
| 52 |
+
if return_tensors == 'pt':
|
| 53 |
+
tensor_type = 'pt'
|
| 54 |
+
else:
|
| 55 |
+
tensor_type = None
|
| 56 |
+
# Handle batch or single input
|
| 57 |
+
if isinstance(inputs, list):
|
| 58 |
+
enc = self.tokenizer.encode_batch(inputs)
|
| 59 |
+
data = {
|
| 60 |
+
"input_ids": [e.ids for e in enc],
|
| 61 |
+
"attention_mask": [e.attention_mask for e in enc]
|
| 62 |
+
}
|
| 63 |
+
return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)
|
| 64 |
+
|
| 65 |
+
else:
|
| 66 |
+
# Single sequence: wrap into batch of size 1
|
| 67 |
+
enc = [self.tokenizer.encode(inputs)]
|
| 68 |
+
data = {
|
| 69 |
+
"input_ids": [e.ids for e in enc],
|
| 70 |
+
"attention_mask": [e.attention_mask for e in enc]
|
| 71 |
+
}
|
| 72 |
+
return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)
|
| 73 |
+
|
| 74 |
+
def __call__(self, inputs, padding=None, truncation=False,
|
| 75 |
+
max_length=None, return_tensors=None):
|
| 76 |
+
return self.encode(inputs, padding=padding, truncation=truncation,
|
| 77 |
+
max_length=max_length, return_tensors=return_tensors)
|
| 78 |
+
|
| 79 |
+
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
| 80 |
+
def _decode_sequence(seq):
|
| 81 |
+
if skip_special_tokens:
|
| 82 |
+
seq = [idx for idx in seq if idx not in self.special_token_ids]
|
| 83 |
+
return [self.tokenizer.id_to_token(idx) for idx in seq]
|
| 84 |
+
|
| 85 |
+
# 1) batch: list of lists or torch tensor
|
| 86 |
+
if isinstance(ids, torch.Tensor):
|
| 87 |
+
ids = ids.tolist()
|
| 88 |
+
if len(ids) == 1:
|
| 89 |
+
ids = ids[0]
|
| 90 |
+
|
| 91 |
+
if isinstance(ids, (list)) and len(ids) > 0 and isinstance(ids[0], (list)):
|
| 92 |
+
return [_decode_sequence(seq) for seq in ids]
|
| 93 |
+
|
| 94 |
+
# 2) single sequence: list of ints or torch tensor
|
| 95 |
+
if isinstance(ids, (list)):
|
| 96 |
+
return _decode_sequence(ids)
|
| 97 |
+
|
| 98 |
+
# 3) single int
|
| 99 |
+
if isinstance(ids, int):
|
| 100 |
+
return self.tokenizer.id_to_token(ids)
|
| 101 |
+
|
| 102 |
+
def decode(self, ids, skip_special_tokens=False):
|
| 103 |
+
def _decode_sequence(seq):
|
| 104 |
+
if skip_special_tokens:
|
| 105 |
+
seq = [idx for idx in seq if idx not in self.special_token_ids]
|
| 106 |
+
return ''.join(self.tokenizer.id_to_token(idx) for idx in seq)
|
| 107 |
+
|
| 108 |
+
# 1) batch: list of lists or torch tensor
|
| 109 |
+
if isinstance(ids, torch.Tensor):
|
| 110 |
+
ids = ids.tolist()
|
| 111 |
+
if len(ids) == 1:
|
| 112 |
+
ids = ids[0]
|
| 113 |
+
|
| 114 |
+
if isinstance(ids, (list)) and len(ids) > 0 and isinstance(ids[0], (list)):
|
| 115 |
+
return [_decode_sequence(seq) for seq in ids]
|
| 116 |
+
|
| 117 |
+
# 2) single sequence: list of ints or torch tensor
|
| 118 |
+
if isinstance(ids, (list)):
|
| 119 |
+
return _decode_sequence(ids)
|
| 120 |
+
|
| 121 |
+
# 3) single int
|
| 122 |
+
if isinstance(ids, int):
|
| 123 |
+
return self.tokenizer.id_to_token(ids)
|
config.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"InteractionModelATTNForRegression"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.425,
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_dlmberta.InteractionModelATTNConfig",
|
| 8 |
+
"AutoModel": "modeling_dlmberta.InteractionModelATTNForRegression"
|
| 9 |
+
},
|
| 10 |
+
"hidden_dropout": 0.0816,
|
| 11 |
+
"model_type": "dlmberta",
|
| 12 |
+
"num_heads": 1,
|
| 13 |
+
"torch_dtype": "float32",
|
| 14 |
+
"transformers_version": "4.41.0"
|
| 15 |
+
}
|
configuration_dlmberta.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class InteractionModelATTNConfig(PretrainedConfig):
|
| 4 |
+
model_type = "dlmberta"
|
| 5 |
+
def __init__(self, attention_dropout = 0.2, hidden_dropout = 0.2, num_heads = 1, **kwargs,):
|
| 6 |
+
self.num_heads = num_heads
|
| 7 |
+
self.hidden_dropout = hidden_dropout
|
| 8 |
+
self.attention_dropout = attention_dropout
|
| 9 |
+
super().__init__(**kwargs)
|
drug_tokenizer/vocab.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"[PAD]":0,"[unused1]":1,"[unused2]":2,"[unused3]":3,"[unused4]":4,"[unused5]":5,"[unused6]":6,"[unused7]":7,"[unused8]":8,"[unused9]":9,"[unused10]":10,"[UNK]":11,"[CLS]":12,"[SEP]":13,"[MASK]":14,"c":15,"C":16,"(":17,")":18,"O":19,"1":20,"2":21,"=":22,"N":23,".":24,"n":25,"3":26,"F":27,"Cl":28,">>":29,"~":30,"-":31,"4":32,"[C@H]":33,"S":34,"[C@@H]":35,"[O-]":36,"Br":37,"#":38,"/":39,"[nH]":40,"[N+]":41,"s":42,"5":43,"o":44,"P":45,"[Na+]":46,"[Si]":47,"I":48,"[Na]":49,"[Pd]":50,"[K+]":51,"[K]":52,"[P]":53,"B":54,"[C@]":55,"[C@@]":56,"[Cl-]":57,"6":58,"[OH-]":59,"\\":60,"[N-]":61,"[Li]":62,"[H]":63,"[2H]":64,"[NH4+]":65,"[c-]":66,"[P-]":67,"[Cs+]":68,"[Li+]":69,"[Cs]":70,"[NaH]":71,"[H-]":72,"[O+]":73,"[BH4-]":74,"[Cu]":75,"7":76,"[Mg]":77,"[Fe+2]":78,"[n+]":79,"[Sn]":80,"[BH-]":81,"[Pd+2]":82,"[CH]":83,"[I-]":84,"[Br-]":85,"[C-]":86,"[Zn]":87,"[B-]":88,"[F-]":89,"[Al]":90,"[P+]":91,"[BH3-]":92,"[Fe]":93,"[C]":94,"[AlH4]":95,"[Ni]":96,"[SiH]":97,"8":98,"[Cu+2]":99,"[Mn]":100,"[AlH]":101,"[nH+]":102,"[AlH4-]":103,"[O-2]":104,"[Cr]":105,"[Mg+2]":106,"[NH3+]":107,"[S@]":108,"[Pt]":109,"[Al+3]":110,"[S@@]":111,"[S-]":112,"[Ti]":113,"[Zn+2]":114,"[PH]":115,"[NH2+]":116,"[Ru]":117,"[Ag+]":118,"[S+]":119,"[I+3]":120,"[NH+]":121,"[Ca+2]":122,"[Ag]":123,"9":124,"[Os]":125,"[Se]":126,"[SiH2]":127,"[Ca]":128,"[Ti+4]":129,"[Ac]":130,"[Cu+]":131,"[S]":132,"[Rh]":133,"[Cl+3]":134,"[cH-]":135,"[Zn+]":136,"[O]":137,"[Cl+]":138,"[SH]":139,"[H+]":140,"[Pd+]":141,"[se]":142,"[PH+]":143,"[I]":144,"[Pt+2]":145,"[C+]":146,"[Mg+]":147,"[Hg]":148,"[W]":149,"[SnH]":150,"[SiH3]":151,"[Fe+3]":152,"[NH]":153,"[Mo]":154,"[CH2+]":155,"%10":156,"[CH2-]":157,"[CH2]":158,"[n-]":159,"[Ce+4]":160,"[NH-]":161,"[Co]":162,"[I+]":163,"[PH2]":164,"[Pt+4]":165,"[Ce]":166,"[B]":167,"[Sn+2]":168,"[Ba+2]":169,"%11":170,"[Fe-3]":171,"[18F]":172,"[SH-]":173,"[Pb+2]":174,"[Os-2]":175,"[Zr+4]":176,"[N]":177,"[Ir]":178,"[Bi]":179,"[Ni+2]":180,"[P@]":181,"[Co+2]":182,"[s+]":183,"[As]":184,"[P+3]":185,"[Hg+2]":186,"[Yb+3]":187,"[CH-]":188,"[Zr+2]":189,"[Mn+2]":190,"[CH+]":191,"[In]":192,"[KH]":193,"[Ce+3]":194,"[Zr]":195,"[AlH2-]":196,"[OH2+]":197,"[Ti+3]":198,"[Rh+2]":199,"[Sb]":200,"[S-2]":201,"%12":202,"[P@@]":203,"[Si@H]":204,"[Mn+4]":205,"p":206,"[Ba]":207,"[NH2-]":208,"[Ge]":209,"[Pb+4]":210,"[Cr+3]":211,"[Au]":212,"[LiH]":213,"[Sc+3]":214,"[o+]":215,"[Rh-3]":216,"%13":217,"[Br]":218,"[Sb-]":219,"[S@+]":220,"[I+2]":221,"[Ar]":222,"[V]":223,"[Cu-]":224,"[Al-]":225,"[Te]":226,"[13c]":227,"[13C]":228,"[Cl]":229,"[PH4+]":230,"[SiH4]":231,"[te]":232,"[CH3-]":233,"[S@@+]":234,"[Rh+3]":235,"[SH+]":236,"[Bi+3]":237,"[Br+2]":238,"[La]":239,"[La+3]":240,"[Pt-2]":241,"[N@@]":242,"[PH3+]":243,"[N@]":244,"[Si+4]":245,"[Sr+2]":246,"[Al+]":247,"[Pb]":248,"[SeH]":249,"[Si-]":250,"[V+5]":251,"[Y+3]":252,"[Re]":253,"[Ru+]":254,"[Sm]":255,"*":256,"[3H]":257,"[NH2]":258,"[Ag-]":259,"[13CH3]":260,"[OH+]":261,"[Ru+3]":262,"[OH]":263,"[Gd+3]":264,"[13CH2]":265,"[In+3]":266,"[Si@@]":267,"[Si@]":268,"[Ti+2]":269,"[Sn+]":270,"[Cl+2]":271,"[AlH-]":272,"[Pd-2]":273,"[SnH3]":274,"[B+3]":275,"[Cu-2]":276,"[Nd+3]":277,"[Pb+3]":278,"[13cH]":279,"[Fe-4]":280,"[Ga]":281,"[Sn+4]":282,"[Hg+]":283,"[11CH3]":284,"[Hf]":285,"[Pr]":286,"[Y]":287,"[S+2]":288,"[Cd]":289,"[Cr+6]":290,"[Zr+3]":291,"[Rh+]":292,"[CH3]":293,"[N-3]":294,"[Hf+2]":295,"[Th]":296,"[Sb+3]":297,"%14":298,"[Cr+2]":299,"[Ru+2]":300,"[Hf+4]":301,"[14C]":302,"[Ta]":303,"[Tl+]":304,"[B+]":305,"[Os+4]":306,"[PdH2]":307,"[Pd-]":308,"[Cd+2]":309,"[Co+3]":310,"[S+4]":311,"[Nb+5]":312,"[123I]":313,"[c+]":314,"[Rb+]":315,"[V+2]":316,"[CH3+]":317,"[Ag+2]":318,"[cH+]":319,"[Mn+3]":320,"[Se-]":321,"[As-]":322,"[Eu+3]":323,"[SH2]":324,"[Sm+3]":325,"[IH+]":326,"%15":327,"[OH3+]":328,"[PH3]":329,"[IH2+]":330,"[SH2+]":331,"[Ir+3]":332,"[AlH3]":333,"[Sc]":334,"[Yb]":335,"[15NH2]":336,"[Lu]":337,"[sH+]":338,"[Gd]":339,"[18F-]":340,"[SH3+]":341,"[SnH4]":342,"[TeH]":343,"[Si@@H]":344,"[Ga+3]":345,"[CaH2]":346,"[Tl]":347,"[Ta+5]":348,"[GeH]":349,"[Br+]":350,"[Sr]":351,"[Tl+3]":352,"[Sm+2]":353,"[PH5]":354,"%16":355,"[N@@+]":356,"[Au+3]":357,"[C-4]":358,"[Nd]":359,"[Ti+]":360,"[IH]":361,"[N@+]":362,"[125I]":363,"[Eu]":364,"[Sn+3]":365,"[Nb]":366,"[Er+3]":367,"[123I-]":368,"[14c]":369,"%17":370,"[SnH2]":371,"[YH]":372,"[Sb+5]":373,"[Pr+3]":374,"[Ir+]":375,"[N+3]":376,"[AlH2]":377,"[19F]":378,"%18":379,"[Tb]":380,"[14CH]":381,"[Mo+4]":382,"[Si+]":383,"[BH]":384,"[Be]":385,"[Rb]":386,"[pH]":387,"%19":388,"%20":389,"[Xe]":390,"[Ir-]":391,"[Be+2]":392,"[C+4]":393,"[RuH2]":394,"[15NH]":395,"[U+2]":396,"[Au-]":397,"%21":398,"%22":399,"[Au+]":400,"[15n]":401,"[Al+2]":402,"[Tb+3]":403,"[15N]":404,"[V+3]":405,"[W+6]":406,"[14CH3]":407,"[Cr+4]":408,"[ClH+]":409,"b":410,"[Ti+6]":411,"[Nd+]":412,"[Zr+]":413,"[PH2+]":414,"[Fm]":415,"[N@H+]":416,"[RuH]":417,"[Dy+3]":418,"%23":419,"[Hf+3]":420,"[W+4]":421,"[11C]":422,"[13CH]":423,"[Er]":424,"[124I]":425,"[LaH]":426,"[F]":427,"[siH]":428,"[Ga+]":429,"[Cm]":430,"[GeH3]":431,"[IH-]":432,"[U+6]":433,"[SeH+]":434,"[32P]":435,"[SeH-]":436,"[Pt-]":437,"[Ir+2]":438,"[se+]":439,"[U]":440,"[F+]":441,"[BH2]":442,"[As+]":443,"[Cf]":444,"[ClH2+]":445,"[Ni+]":446,"[TeH3]":447,"[SbH2]":448,"[Ag+3]":449,"%24":450,"[18O]":451,"[PH4]":452,"[Os+2]":453,"[Na-]":454,"[Sb+2]":455,"[V+4]":456,"[Ho+3]":457,"[68Ga]":458,"[PH-]":459,"[Bi+2]":460,"[Ce+2]":461,"[Pd+3]":462,"[99Tc]":463,"[13C@@H]":464,"[Fe+6]":465,"[c]":466,"[GeH2]":467,"[10B]":468,"[Cu+3]":469,"[Mo+2]":470,"[Cr+]":471,"[Pd+4]":472,"[Dy]":473,"[AsH]":474,"[Ba+]":475,"[SeH2]":476,"[In+]":477,"[TeH2]":478,"[BrH+]":479,"[14cH]":480,"[W+]":481,"[13C@H]":482,"[AsH2]":483,"[In+2]":484,"[N+2]":485,"[N@@H+]":486,"[SbH]":487,"[60Co]":488,"[AsH4+]":489,"[AsH3]":490,"[18OH]":491,"[Ru-2]":492,"[Na-2]":493,"[CuH2]":494,"[31P]":495,"[Ti+5]":496,"[35S]":497,"[P@@H]":498,"[ArH]":499,"[Co+]":500,"[Zr-2]":501,"[BH2-]":502,"[131I]":503,"[SH5]":504,"[VH]":505,"[B+2]":506,"[Yb+2]":507,"[14C@H]":508,"[211At]":509,"[NH3+2]":510,"[IrH]":511,"[IrH2]":512,"[Rh-]":513,"[Cr-]":514,"[Sb+]":515,"[Ni+3]":516,"[TaH3]":517,"[Tl+2]":518,"[64Cu]":519,"[Tc]":520,"[Cd+]":521,"[1H]":522,"[15nH]":523,"[AlH2+]":524,"[FH+2]":525,"[BiH3]":526,"[Ru-]":527,"[Mo+6]":528,"[AsH+]":529,"[BaH2]":530,"[BaH]":531,"[Fe+4]":532,"[229Th]":533,"[Th+4]":534,"[As+3]":535,"[NH+3]":536,"[P@H]":537,"[Li-]":538,"[7NaH]":539,"[Bi+]":540,"[PtH+2]":541,"[p-]":542,"[Re+5]":543,"[NiH]":544,"[Ni-]":545,"[Xe+]":546,"[Ca+]":547,"[11c]":548,"[Rh+4]":549,"[AcH]":550,"[HeH]":551,"[Sc+2]":552,"[Mn+]":553,"[UH]":554,"[14CH2]":555,"[SiH4+]":556,"[18OH2]":557,"[Ac-]":558,"[Re+4]":559,"[118Sn]":560,"[153Sm]":561,"[P+2]":562,"[9CH]":563,"[9CH3]":564,"[Y-]":565,"[NiH2]":566,"[Si+2]":567,"[Mn+6]":568,"[ZrH2]":569,"[C-2]":570,"[Bi+5]":571,"[24NaH]":572,"[Fr]":573,"[15CH]":574,"[Se+]":575,"[At]":576,"[P-3]":577,"[124I-]":578,"[CuH2-]":579,"[Nb+4]":580,"[Nb+3]":581,"[MgH]":582,"[Ir+4]":583,"[67Ga+3]":584,"[67Ga]":585,"[13N]":586,"[15OH2]":587,"[2NH]":588,"[Ho]":589,"[Cn]":590}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9cda24e3491cc24f77c27dc0a9a2933a05eab66ae96aa770de5db86d640cffce
|
| 3 |
+
size 241171900
|
modeling_dlmberta.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
| 6 |
+
from torch.nn.parameter import Parameter
|
| 7 |
+
from torch.nn.init import xavier_uniform_, constant_
|
| 8 |
+
from configuration_dlmberta import InteractionModelATTNConfig
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
class StdScaler():
|
| 12 |
+
def fit(self, X):
|
| 13 |
+
self.mean_ = torch.mean(X).item()
|
| 14 |
+
self.std_ = torch.std(X, correction=0).item()
|
| 15 |
+
|
| 16 |
+
def fit_transform(self, X):
|
| 17 |
+
self.mean_ = torch.mean(X).item()
|
| 18 |
+
self.std_ = torch.std(X, correction=0).item()
|
| 19 |
+
|
| 20 |
+
return (X-self.mean_)/self.std_
|
| 21 |
+
|
| 22 |
+
def transform(self, X):
|
| 23 |
+
return (X-self.mean_)/self.std_
|
| 24 |
+
|
| 25 |
+
def inverse_transform(self, X):
|
| 26 |
+
return (X*self.std_)+self.mean_
|
| 27 |
+
|
| 28 |
+
def save(self, directory):
|
| 29 |
+
with open(directory+"/scaler.config", "w") as f:
|
| 30 |
+
f.write(str(self.mean_)+"\n")
|
| 31 |
+
f.write(str(self.std_)+"\n")
|
| 32 |
+
|
| 33 |
+
def load(self, directory):
|
| 34 |
+
with open(directory+"/scaler.config", "r") as f:
|
| 35 |
+
self.mean_ = float(f.readline())
|
| 36 |
+
self.std_ = float(f.readline())
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class InteractionModelATTNForRegression(PreTrainedModel):
|
| 40 |
+
config_class = InteractionModelATTNConfig
|
| 41 |
+
|
| 42 |
+
def __init__(self, config, target_encoder, drug_encoder, scaler=None):
|
| 43 |
+
super().__init__(config)
|
| 44 |
+
self.model = InteractionModelATTN(target_encoder,
|
| 45 |
+
drug_encoder,
|
| 46 |
+
scaler,
|
| 47 |
+
config.attention_dropout,
|
| 48 |
+
config.hidden_dropout,
|
| 49 |
+
config.num_heads)
|
| 50 |
+
self.scaler = scaler
|
| 51 |
+
|
| 52 |
+
def INTERPR_ENABLE_MODE(self):
|
| 53 |
+
self.model.INTERPR_ENABLE_MODE()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def INTERPR_OVERRIDE_ATTN(self, new_weights):
|
| 57 |
+
self.model.INTERPR_OVERRIDE_ATTN(new_weights)
|
| 58 |
+
|
| 59 |
+
def INTERPR_RESET_OVERRIDE_ATTN(self):
|
| 60 |
+
self.model.INTERPR_RESET_OVERRIDE_ATTN()
|
| 61 |
+
|
| 62 |
+
def forward(self, x1, x2):
|
| 63 |
+
return self.model(x1, x2)
|
| 64 |
+
|
| 65 |
+
def unscale(self, x):
|
| 66 |
+
return self.model.unscale(x)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class CrossAttention(nn.Module):
|
| 71 |
+
def __init__(self, embed_dim, num_heads, attention_dropout=0.0, hidden_dropout=0.0, add_bias_kv=False, **factory_kwargs):
|
| 72 |
+
"""
|
| 73 |
+
Initializes the CrossAttention layer.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
embed_dim (int): Dimension of the input embeddings.
|
| 77 |
+
num_heads (int): Number of attention heads.
|
| 78 |
+
dropout (float): Dropout probability for attention weights.
|
| 79 |
+
"""
|
| 80 |
+
super().__init__()
|
| 81 |
+
self.attention_dropout = attention_dropout
|
| 82 |
+
self.hidden_dropout = hidden_dropout
|
| 83 |
+
self.embed_dim = embed_dim
|
| 84 |
+
self.num_heads = num_heads
|
| 85 |
+
self.head_dim = embed_dim // num_heads
|
| 86 |
+
|
| 87 |
+
self.scaling = self.head_dim ** -0.5
|
| 88 |
+
|
| 89 |
+
if self.head_dim * num_heads != embed_dim:
|
| 90 |
+
raise ValueError("embed_dim must be divisible by num_heads")
|
| 91 |
+
|
| 92 |
+
# Linear projections for query, key, and value.
|
| 93 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 94 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 95 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 96 |
+
self.attn_dropout = nn.Dropout(attention_dropout)
|
| 97 |
+
|
| 98 |
+
xavier_uniform_(self.q_proj.weight)
|
| 99 |
+
xavier_uniform_(self.k_proj.weight)
|
| 100 |
+
xavier_uniform_(self.v_proj.weight)
|
| 101 |
+
constant_(self.q_proj.bias, 0.)
|
| 102 |
+
constant_(self.k_proj.bias, 0.)
|
| 103 |
+
constant_(self.v_proj.bias, 0.)
|
| 104 |
+
|
| 105 |
+
# Output projection.
|
| 106 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
| 107 |
+
constant_(self.out_proj.bias, 0)
|
| 108 |
+
|
| 109 |
+
self.drop_out = nn.Dropout(hidden_dropout)
|
| 110 |
+
|
| 111 |
+
def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, replace_weights=None):
|
| 112 |
+
"""
|
| 113 |
+
Forward pass for cross attention.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
query (Tensor): Query embeddings of shape (batch_size, query_len, embed_dim).
|
| 117 |
+
key (Tensor): Key embeddings of shape (batch_size, key_len, embed_dim).
|
| 118 |
+
value (Tensor): Value embeddings of shape (batch_size, key_len, embed_dim).
|
| 119 |
+
attn_mask (Tensor, optional): Attention mask of shape (batch_size, num_heads, query_len, key_len).
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
output (Tensor): The attended output of shape (batch_size, query_len, embed_dim).
|
| 123 |
+
attn_weights (Tensor): The attention weights of shape (batch_size, num_heads, query_len, key_len).
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
batch_size, query_len, _ = query.size()
|
| 127 |
+
_, key_len, _ = key.size()
|
| 128 |
+
|
| 129 |
+
Q = self.q_proj(query)
|
| 130 |
+
K = self.k_proj(key)
|
| 131 |
+
V = self.v_proj(value)
|
| 132 |
+
|
| 133 |
+
Q = Q.view(batch_size, self.num_heads, query_len, self.head_dim)
|
| 134 |
+
K = K.view(batch_size, self.num_heads, key_len, self.head_dim)
|
| 135 |
+
V = V.view(batch_size, self.num_heads, key_len, self.head_dim)
|
| 136 |
+
|
| 137 |
+
# Compute scaled dot-product attention scores
|
| 138 |
+
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (batch_size, num_heads, query_len, key_len)
|
| 139 |
+
|
| 140 |
+
if key_padding_mask is not None:
|
| 141 |
+
# Convert boolean mask (False -> -inf, True -> 0)
|
| 142 |
+
key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) # (B, 1, 1, key_len) for broadcasting
|
| 143 |
+
scores = scores.masked_fill(key_padding_mask, float('-inf')) # Set masked positions to -inf
|
| 144 |
+
|
| 145 |
+
if replace_weights is not None:
|
| 146 |
+
scores = replace_weights
|
| 147 |
+
|
| 148 |
+
# Compute attention weights using softmax
|
| 149 |
+
attn_weights = torch.nn.functional.softmax(scores, dim=-1) # (batch_size, num_heads, query_len, key_len)
|
| 150 |
+
self.scores = scores
|
| 151 |
+
if attn_mask is not None:
|
| 152 |
+
attn_mask = attn_mask.unsqueeze(1) # Shape: (batch_size, 1, query_len, key_len)
|
| 153 |
+
attn_weights = attn_weights.masked_fill(attn_mask, 0) # Set masked positions to 0
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Optionally apply dropout to the attention weights if self.dropout is defined
|
| 158 |
+
attn_weights = self.attn_dropout(attn_weights)
|
| 159 |
+
# Compute the weighted sum of the values
|
| 160 |
+
attn_output = torch.matmul(attn_weights, V) # (batch_size, num_heads, query_len, head_dim)
|
| 161 |
+
# Recombine heads: transpose and reshape back to (batch_size, query_len, embed_dim)
|
| 162 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)
|
| 163 |
+
|
| 164 |
+
# Final linear projection and dropout
|
| 165 |
+
output = self.out_proj(attn_output)
|
| 166 |
+
output = self.drop_out(output)
|
| 167 |
+
|
| 168 |
+
return output, attn_weights
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class InteractionModelATTN(nn.Module):
|
| 172 |
+
def __init__(self, target_encoder, drug_encoder, scaler, attention_dropout, hidden_dropout, num_heads=1, kernel_size=1):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.replace_weights = None
|
| 175 |
+
self.crossattention_weights = None
|
| 176 |
+
self.presum_layer = None
|
| 177 |
+
self.INTERPR_MODE = False
|
| 178 |
+
|
| 179 |
+
self.scaler = scaler
|
| 180 |
+
self.attention_dropout = attention_dropout
|
| 181 |
+
self.hidden_dropout = hidden_dropout
|
| 182 |
+
|
| 183 |
+
self.target_encoder = target_encoder
|
| 184 |
+
self.drug_encoder = drug_encoder
|
| 185 |
+
self.kernel_size = kernel_size
|
| 186 |
+
self.lin_map_target = nn.Linear(512, 384)
|
| 187 |
+
self.dropout_map_target = nn.Dropout(hidden_dropout)
|
| 188 |
+
|
| 189 |
+
self.lin_map_drug = nn.Linear(384, 384)
|
| 190 |
+
self.dropout_map_drug = nn.Dropout(hidden_dropout)
|
| 191 |
+
|
| 192 |
+
self.crossattention = CrossAttention(384, num_heads, attention_dropout, hidden_dropout)
|
| 193 |
+
self.norm = nn.LayerNorm(384)
|
| 194 |
+
self.summary1 = nn.Linear(384, 384)
|
| 195 |
+
self.summary2 = nn.Linear(384, 1)
|
| 196 |
+
self.dropout_summary = nn.Dropout(hidden_dropout)
|
| 197 |
+
self.layer_norm = nn.LayerNorm(384)
|
| 198 |
+
self.gelu = nn.GELU()
|
| 199 |
+
|
| 200 |
+
self.w = Parameter(torch.empty(512, 1))
|
| 201 |
+
self.b = Parameter(torch.zeros(1))
|
| 202 |
+
self.pdng = Parameter(torch.tensor(0.0)) # learnable padding value (0-dimensional)
|
| 203 |
+
|
| 204 |
+
xavier_uniform_(self.w)
|
| 205 |
+
|
| 206 |
+
def forward(self, x1, x2):
|
| 207 |
+
"""
|
| 208 |
+
Forward pass for attention interaction model.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
x1 (dict): A dictionary containing input tensors for the target encoder.
|
| 212 |
+
Expected keys:
|
| 213 |
+
- 'input_ids' (torch.Tensor): Token IDs for the target input.
|
| 214 |
+
- 'attention_mask' (torch.Tensor): Attention mask for the target input.
|
| 215 |
+
x2 (dict): A dictionary containing input tensors for the drug encoder.
|
| 216 |
+
Expected keys:
|
| 217 |
+
- 'input_ids' (torch.Tensor): Token IDs for the drug input.
|
| 218 |
+
- 'attention_mask' (torch.Tensor): Attention mask for the drug input.
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
torch.Tensor: A tensor representing the predicted binding affinity.
|
| 222 |
+
"""
|
| 223 |
+
x1["attention_mask"] = x1["attention_mask"].bool() # Fix dropout model issue: https://github.com/pytorch/pytorch/issues/86120
|
| 224 |
+
y1 = self.target_encoder(**x1).last_hidden_state # The target
|
| 225 |
+
|
| 226 |
+
query_mask = x1["attention_mask"].unsqueeze(-1).to(y1.dtype)
|
| 227 |
+
y1 = y1 * query_mask
|
| 228 |
+
|
| 229 |
+
x2["attention_mask"] = x2["attention_mask"].bool() # Fix dropout model issue: https://github.com/pytorch/pytorch/issues/86120
|
| 230 |
+
y2 = self.drug_encoder(**x2).last_hidden_state # The drug
|
| 231 |
+
key_mask = x2["attention_mask"].unsqueeze(-1).to(y2.dtype)
|
| 232 |
+
y2 = y2 * key_mask
|
| 233 |
+
|
| 234 |
+
y1 = self.lin_map_target(y1)
|
| 235 |
+
y1 = self.gelu(y1)
|
| 236 |
+
y1 = self.dropout_map_target(y1)
|
| 237 |
+
|
| 238 |
+
y2 = self.lin_map_drug(y2)
|
| 239 |
+
y2 = self.gelu(y2)
|
| 240 |
+
y2 = self.dropout_map_drug(y2)
|
| 241 |
+
|
| 242 |
+
key_padding_mask=(x2["attention_mask"] == 0) # S
|
| 243 |
+
|
| 244 |
+
replace_weights = None
|
| 245 |
+
# If in interpretation mode, allow the replacement of cross-attention weights
|
| 246 |
+
if self.INTERPR_MODE:
|
| 247 |
+
if self.replace_weights is not None:
|
| 248 |
+
replace_weights = self.replace_weights
|
| 249 |
+
|
| 250 |
+
out, _ = self.crossattention(y1, y2, y2, key_padding_mask=key_padding_mask, attn_mask=None, replace_weights=replace_weights)
|
| 251 |
+
|
| 252 |
+
# If in interpretation mode, make cross-attention weights and scores accessible from the outside
|
| 253 |
+
if self.INTERPR_MODE:
|
| 254 |
+
self.crossattention_weights = _
|
| 255 |
+
self.scores = self.crossattention.scores
|
| 256 |
+
|
| 257 |
+
out = self.summary1(out * query_mask)
|
| 258 |
+
out = self.gelu(out)
|
| 259 |
+
out = self.dropout_summary(out)
|
| 260 |
+
out = self.summary2(out).squeeze(-1)
|
| 261 |
+
|
| 262 |
+
# If in interpretation mode, make final summation layer contributions accessible from the outside
|
| 263 |
+
if self.INTERPR_MODE:
|
| 264 |
+
self.presum_layer = out
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
weighted = out * self.w.squeeze(1) # [batch, seq_len]
|
| 268 |
+
padding_positions = ~x1["attention_mask"] # True at padding
|
| 269 |
+
# assign learnable pdng to all padding positions
|
| 270 |
+
weighted = weighted.masked_fill(padding_positions, self.pdng.item())
|
| 271 |
+
|
| 272 |
+
# sum across sequence and add bias
|
| 273 |
+
result = weighted.sum(dim=1, keepdim=True) + self.b
|
| 274 |
+
return result
|
| 275 |
+
|
| 276 |
+
def train(self, mode = True):
|
| 277 |
+
super().train(mode)
|
| 278 |
+
self.target_encoder.train(mode)
|
| 279 |
+
self.drug_encoder.train(mode)
|
| 280 |
+
self.crossattention.train(mode)
|
| 281 |
+
return self
|
| 282 |
+
|
| 283 |
+
def eval(self):
|
| 284 |
+
super().eval()
|
| 285 |
+
self.target_encoder.eval()
|
| 286 |
+
self.drug_encoder.eval()
|
| 287 |
+
self.crossattention.eval()
|
| 288 |
+
return self
|
| 289 |
+
|
| 290 |
+
def INTERPR_ENABLE_MODE(self):
|
| 291 |
+
"""
|
| 292 |
+
Enables the interpretability mode for the model.
|
| 293 |
+
"""
|
| 294 |
+
if self.training:
|
| 295 |
+
raise RuntimeError("Cannot enable interpretability mode while the model is training.")
|
| 296 |
+
self.INTERPR_MODE = True
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def INTERPR_OVERRIDE_ATTN(self, new_weights):
|
| 300 |
+
self.replace_weights = new_weights
|
| 301 |
+
|
| 302 |
+
def INTERPR_RESET_OVERRIDE_ATTN(self):
|
| 303 |
+
self.replace_weights = None
|
| 304 |
+
|
| 305 |
+
def unscale(self, x):
|
| 306 |
+
"""
|
| 307 |
+
Unscales the labels using a scaler. If the scaler is not specified, don't do anything.
|
| 308 |
+
|
| 309 |
+
Parameters:
|
| 310 |
+
target_value: the target values to be unscaled
|
| 311 |
+
"""
|
| 312 |
+
with torch.no_grad():
|
| 313 |
+
if self.scaler is None:
|
| 314 |
+
return x
|
| 315 |
+
unscaled = self.scaler.inverse_transform(x)
|
| 316 |
+
return unscaled
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
torch>=1.9.0
|
| 3 |
+
transformers>=4.21.0
|
| 4 |
+
tokenizers>=0.13.0
|
| 5 |
+
numpy>=1.21.0
|
| 6 |
+
huggingface_hub>=0.10.0
|
| 7 |
+
accelerate>=0.20.0
|
| 8 |
+
datasets>=2.0.0
|
| 9 |
+
safetensors>=0.3.0
|
scaler.config
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
5.389976501464844
|
| 2 |
+
1.3962712287902832
|
target_tokenizer/config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"RobertaForMaskedLM",
|
| 4 |
+
"RobertaModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"attn_mult": 5.656854249492381,
|
| 8 |
+
"bos_token_id": 0,
|
| 9 |
+
"classifier_dropout": null,
|
| 10 |
+
"eos_token_id": 2,
|
| 11 |
+
"hidden_act": "gelu",
|
| 12 |
+
"hidden_dropout_prob": 0.1,
|
| 13 |
+
"hidden_size": 512,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 3072,
|
| 16 |
+
"layer_norm_eps": 1e-12,
|
| 17 |
+
"max_position_embeddings": 514,
|
| 18 |
+
"model_type": "roberta",
|
| 19 |
+
"num_attention_heads": 16,
|
| 20 |
+
"num_hidden_layers": 12,
|
| 21 |
+
"output_hidden_states": true,
|
| 22 |
+
"pad_token_id": 1,
|
| 23 |
+
"position_embedding_type": "absolute",
|
| 24 |
+
"torch_dtype": "float32",
|
| 25 |
+
"transformers_version": "4.46.3",
|
| 26 |
+
"type_vocab_size": 2,
|
| 27 |
+
"use_cache": true,
|
| 28 |
+
"vocab_size": 9700
|
| 29 |
+
}
|
target_tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
target_tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"cls_token": {
|
| 10 |
+
"content": "<s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"eos_token": {
|
| 17 |
+
"content": "</s>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": true,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"mask_token": {
|
| 24 |
+
"content": "<mask>",
|
| 25 |
+
"lstrip": true,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"pad_token": {
|
| 31 |
+
"content": "<pad>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": true,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
},
|
| 37 |
+
"sep_token": {
|
| 38 |
+
"content": "</s>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": true,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false
|
| 43 |
+
},
|
| 44 |
+
"unk_token": {
|
| 45 |
+
"content": "<unk>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": true,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false
|
| 50 |
+
}
|
| 51 |
+
}
|
target_tokenizer/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
target_tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"0": {
|
| 5 |
+
"content": "<s>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"1": {
|
| 13 |
+
"content": "<pad>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": true,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"2": {
|
| 21 |
+
"content": "</s>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": true,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"3": {
|
| 29 |
+
"content": "<unk>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": true,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"4": {
|
| 37 |
+
"content": "<mask>",
|
| 38 |
+
"lstrip": true,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
}
|
| 44 |
+
},
|
| 45 |
+
"bos_token": "<s>",
|
| 46 |
+
"clean_up_tokenization_spaces": true,
|
| 47 |
+
"cls_token": "<s>",
|
| 48 |
+
"eos_token": "</s>",
|
| 49 |
+
"errors": "replace",
|
| 50 |
+
"mask_token": "<mask>",
|
| 51 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 52 |
+
"pad_token": "<pad>",
|
| 53 |
+
"sep_token": "</s>",
|
| 54 |
+
"tokenizer_class": "RobertaTokenizer",
|
| 55 |
+
"trim_offsets": true,
|
| 56 |
+
"unk_token": "<unk>"
|
| 57 |
+
}
|
target_tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|