IlPakoZ commited on
Commit
3912a9f
·
verified ·
1 Parent(s): a857e83

Upload 18 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && apt-get install -y \
8
+ git \
9
+ wget \
10
+ curl \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements first for better caching
14
+ COPY requirements.txt .
15
+
16
+ # Install Python dependencies
17
+ RUN pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Copy application files
20
+ COPY . .
21
+
22
+ # Create necessary directories
23
+ RUN mkdir -p /app/models /app/cache
24
+
25
+ # Set environment variables
26
+ ENV TRANSFORMERS_CACHE=/app/cache
27
+ ENV HF_HOME=/app/cache
28
+ ENV GRADIO_SERVER_NAME=0.0.0.0
29
+ ENV GRADIO_SERVER_PORT=7860
30
+
31
+ # Expose the port
32
+ EXPOSE 7860
33
+
34
+ # Run the application
35
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,13 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: DLRNA BERTADemo
3
- emoji: 👀
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.44.1
8
- app_file: app.py
9
- pinned: false
10
- short_description: Demo of DLRNA-BERTA
11
- ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Drug-Target Interaction Prediction Model
2
+
3
+ ## Model Description
4
+
5
+ This model predicts drug-target interactions using a novel architecture that combines:
6
+ - Target RNA sequence encoding
7
+ - Drug SMILES molecular representation
8
+ - Cross-attention mechanism for interaction modeling
9
+ - Regression head for binding affinity prediction
10
+
11
+ ## Architecture
12
+
13
+ The model consists of several key components:
14
+
15
+ 1. **Target Encoder**: Processes RNA sequences of target
16
+ 2. **Drug Encoder**: Processes molecular SMILES representations
17
+ 3. **Cross-Attention Layer**: Models interactions between drug and target representations
18
+ 4. **Regression Head**: Predicts binding affinity scores
19
+
20
+ ## Usage
21
+
22
+ ### Using the Gradio Interface
23
+
24
+ ```python
25
+ import gradio as gr
26
+ from app import demo
27
+
28
+ # Launch the interactive interface
29
+ demo.launch()
30
+ ```
31
+
32
+ ### Programmatic Usage
33
+
34
+ ```python
35
+ from modeling_dlmberta import InteractionModelATTNForRegression
36
+ from configuration_dlmberta import InteractionModelATTNConfig
37
+
38
+ # Load model
39
+ config = InteractionModelATTNConfig.from_pretrained("path/to/model")
40
+ model = InteractionModelATTNForRegression.from_pretrained("path/to/model", config=config)
41
+
42
+ # Make predictions
43
+ target_sequence = "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU"
44
+ drug_smiles = "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
45
+
46
+ prediction = model.predict_interaction(target_sequence, drug_smiles)
47
+ ```
48
+
49
+ ## Model Inputs
50
+
51
+ - **Target Sequence**: RNA sequence of the target (string)
52
+ - **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)
53
+
54
+ ## Model Outputs
55
+
56
+ - **Binding Affinity**: Predicted binding affinity score (float)
57
+ - **Attention Weights**: Cross-attention weights for interpretability (optional)
58
+
59
+ ## Training Data
60
+
61
+ The model was trained on drug-target interaction datasets with binding affinity labels. Training details include:
62
+
63
+ - Target sequences: RNA sequence from various families
64
+ - Drug molecules: Diverse chemical compounds represented as SMILES
65
+ - Labels: Experimental binding affinity measurements
66
+
67
+ ## Evaluation Metrics
68
+
69
+ The model performance is evaluated using:
70
+ - Mean Squared Error (MSE)
71
+ - Root Mean Squared Error (RMSE)
72
+ - Pearson Correlation Coefficient
73
+ - Concordance Index (C-Index)
74
+
75
+ ## Limitations
76
+
77
+ - Model performance depends on the quality and diversity of training data
78
+ - May not generalize well to novel RNA classs or chemical scaffolds not seen during training
79
+ - Computational requirements scale with sequence length
80
+ - SMILES representations may not capture all relevant molecular properties
81
+
82
+ ## Interpretability Features
83
+
84
+ The model includes interpretability features:
85
+ - Cross-attention visualization showing drug-target interaction patterns
86
+ - Token-level attention weights
87
+ - Layer-wise contribution analysis
88
+
89
+ ## Citation
90
+
91
+ If you use this model, please cite:
92
+
93
+ ```bibtex
94
+ @article{your_paper,
95
+ title={Drug-Target Interaction Prediction with Cross-Attention},
96
+ author={Your Name},
97
+ journal={Your Journal},
98
+ year={2024}
99
+ }
100
+ ```
101
+
102
+ ## License
103
+
104
+ This model is released under the MIT License. See LICENSE file for details.
105
+
106
+ ## Contact
107
+
108
+ For questions or issues, please contact: your.email@example.com
109
+
110
  ---
 
 
 
 
 
 
 
 
 
 
111
 
112
+ ## Files in this Repository
113
+
114
+ - `modeling_dlmberta.py`: Main model implementation
115
+ - `configuration_dlmberta.py`: Model configuration class
116
+ - `chemberta.py`: Custom tokenizer for chemical SMILES
117
+ - `app.py`: Gradio application interface
118
+ - `requirements.txt`: Python dependencies
119
+ - `Dockerfile`: Container configuration
120
+ - `config.json`: Model configuration file
121
+
122
+ ## Installation
123
+
124
+ 1. Clone the repository:
125
+ ```bash
126
+ git clone https://huggingface.co/your-username/your-model-name
127
+ cd your-model-name
128
+ ```
129
+
130
+ 2. Install dependencies:
131
+ ```bash
132
+ pip install -r requirements.txt
133
+ ```
134
+
135
+ 3. Run the application:
136
+ ```bash
137
+ python app.py
138
+ ```
139
+
140
+ ## Docker Usage
141
+
142
+ Build and run with Docker:
143
+
144
+ ```bash
145
+ docker build -t drug-target-model .
146
+ docker run -p 7860:7860 drug-target-model
147
+ ```
148
+
149
+ ## Model Performance
150
+
151
+ | Metric | Value |
152
+ |--------|-------|
153
+ | RMSE | 0.85 |
154
+ | Pearson R | 0.72 |
155
+ | C-Index | 0.68 |
156
+
157
+ *Note: Replace with actual performance metrics from your evaluation*
158
+
159
+ ## Updates
160
+
161
+ - **v1.0**: Initial model release
162
+ - **v1.1**: Added interpretability features
163
+ - **v1.2**: Improved tokenization and preprocessing
README_spaces.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Drug-Target Interaction Predictor
3
+ emoji: 🧬
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Drug-Target Interaction Predictor
14
+
15
+ An interactive application for predicting drug-target interactions using deep learning. This model uses a novel cross-attention architecture to model the interactions between drug molecules (represented as SMILES) and target RNA sequences.
16
+
17
+ ## Features
18
+
19
+ - 🔮 **Prediction Interface**: Input RNA sequences and drug SMILES to get binding affinity predictions
20
+ - ⚙️ **Model Management**: Load and configure different model checkpoints
21
+ - 📊 **Interpretability**: Visualize attention weights to understand model decisions
22
+ - 🧬 **Scientific Accuracy**: Based on state-of-the-art deep learning architectures
23
+
24
+ ## How to Use
25
+
26
+ 1. **Load Model**: Go to the "Model Settings" tab and specify the path to your trained model
27
+ 2. **Make Predictions**:
28
+ - Enter a target RNA sequence
29
+ - Enter a drug SMILES string
30
+ - Click "Predict Interaction" to get binding affinity score
31
+ 3. **Explore Examples**: Try the provided examples to see the model in action
32
+
33
+ ## Model Architecture
34
+
35
+ The model combines:
36
+ - Target protein encoder for processing amino acid sequences
37
+ - Drug encoder for processing molecular SMILES representations
38
+ - Cross-attention mechanism to capture drug-target interactions
39
+ - Regression head for binding affinity prediction
40
+
41
+ ## Input Format
42
+
43
+ - **Target Sequence**: Standard amino acid single-letter codes (e.g., "AUGCUAGCUAGUACGUA...")
44
+ - **Drug SMILES**: Simplified Molecular Input Line Entry System notation (e.g., "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O")
45
+
46
+ ## Example Usage
47
+
48
+ Try these example inputs:
49
+ - **Target**: `AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU`
50
+ - **Drug**: `C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2`
51
+
52
+ ## Technical Details
53
+
54
+ - Built with Transformers and PyTorch
55
+ - Uses Gradio for the interactive interface
56
+ - Supports GPU acceleration when available
57
+ - Includes attention visualization for model interpretability
58
+
59
+ For more details, see the [model documentation](https://huggingface.co/IlPakoZ/DLRNA-BERTa9700).
app.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
5
+ from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
6
+ from configuration_dlmberta import InteractionModelATTNConfig
7
+ from chemberta import ChembertaTokenizer
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ import logging
12
+
13
+ # Configure logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class DrugTargetInteractionApp:
18
+ def __init__(self):
19
+ self.model = None
20
+ self.target_tokenizer = None
21
+ self.drug_tokenizer = None
22
+ self.scaler = None
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ def load_model(self, model_path="./"):
26
+ """Load the pre-trained model and tokenizers"""
27
+ try:
28
+ # Load configuration
29
+ config = InteractionModelATTNConfig.from_pretrained(model_path)
30
+
31
+ # Load drug encoder (ChemBERTa)
32
+ drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
33
+ drug_encoder_config.pooler = None
34
+ drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)
35
+
36
+ # Load target encoder
37
+ target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")
38
+
39
+ # Load scaler if exists
40
+ scaler_path = os.path.join(model_path, "scaler.config")
41
+ scaler = None
42
+ if os.path.exists(scaler_path):
43
+ scaler = StdScaler()
44
+ scaler.load(model_path)
45
+
46
+ self.model = InteractionModelATTNForRegression.from_pretrained(
47
+ model_path,
48
+ config=config,
49
+ target_encoder=target_encoder,
50
+ drug_encoder=drug_encoder,
51
+ scaler=scaler
52
+ )
53
+
54
+ self.model.to(self.device)
55
+ self.model.eval()
56
+
57
+ # Load tokenizers
58
+ self.target_tokenizer = AutoTokenizer.from_pretrained(
59
+ os.path.join(model_path, "target_tokenizer")
60
+ )
61
+
62
+ # Load drug tokenizer (ChemBERTa)
63
+ vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
64
+ self.drug_tokenizer = ChembertaTokenizer(vocab_file)
65
+
66
+ logger.info("Model and tokenizers loaded successfully!")
67
+ return True
68
+
69
+ except Exception as e:
70
+ logger.error(f"Error loading model: {str(e)}")
71
+ return False
72
+
73
+ def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
74
+ """Predict drug-target interaction"""
75
+ if self.model is None:
76
+ return "Error: Model not loaded. Please load a model first."
77
+
78
+ try:
79
+ # Tokenize inputs
80
+ target_inputs = self.target_tokenizer(
81
+ target_sequence,
82
+ padding=True,
83
+ truncation=True,
84
+ max_length=max_length,
85
+ return_tensors="pt"
86
+ ).to(self.device)
87
+
88
+ drug_inputs = self.drug_tokenizer(
89
+ drug_smiles,
90
+ padding=True,
91
+ truncation=True,
92
+ max_length=max_length,
93
+ return_tensors="pt"
94
+ ).to(self.device)
95
+
96
+ # Make prediction
97
+ with torch.no_grad():
98
+ prediction = self.model(target_inputs, drug_inputs)
99
+
100
+ # Unscale if scaler exists
101
+ if self.model.scaler is not None:
102
+ prediction = self.model.unscale(prediction)
103
+
104
+ prediction_value = prediction.cpu().numpy()[0][0]
105
+
106
+ return f"Predicted Binding Affinity: {prediction_value:.4f}"
107
+
108
+ except Exception as e:
109
+ logger.error(f"Prediction error: {str(e)}")
110
+ return f"Error during prediction: {str(e)}"
111
+
112
+ def get_attention_visualization(self, target_sequence, drug_smiles, max_length=512):
113
+ """Get attention weights for visualization"""
114
+ if self.model is None:
115
+ return None, "Model not loaded"
116
+
117
+ try:
118
+ # Enable interpretation mode
119
+ self.model.INTERPR_ENABLE_MODE()
120
+
121
+ # Tokenize inputs
122
+ target_inputs = self.target_tokenizer(
123
+ target_sequence,
124
+ padding=True,
125
+ truncation=True,
126
+ max_length=max_length,
127
+ return_tensors="pt"
128
+ ).to(self.device)
129
+
130
+ drug_inputs = self.drug_tokenizer(
131
+ drug_smiles,
132
+ padding=True,
133
+ truncation=True,
134
+ max_length=max_length,
135
+ return_tensors="pt"
136
+ ).to(self.device)
137
+
138
+ # Make prediction to get attention weights
139
+ with torch.no_grad():
140
+ _ = self.model(target_inputs, drug_inputs)
141
+
142
+ # Get attention weights
143
+ attention_weights = self.model.model.crossattention_weights
144
+ if attention_weights is not None:
145
+ # Convert to numpy for visualization
146
+ attention_weights = attention_weights.cpu().numpy()
147
+
148
+ # Get tokens for visualization
149
+ target_tokens = self.target_tokenizer.convert_ids_to_tokens(
150
+ target_inputs["input_ids"][0], skip_special_tokens=True
151
+ )
152
+ drug_tokens = self.drug_tokenizer.convert_ids_to_tokens(
153
+ drug_inputs["input_ids"][0], skip_special_tokens=True
154
+ )
155
+
156
+ return attention_weights, target_tokens, drug_tokens, "Attention visualization ready"
157
+ else:
158
+ return None, None, None, "No attention weights available"
159
+
160
+ except Exception as e:
161
+ logger.error(f"Attention visualization error: {str(e)}")
162
+ return None, None, None, f"Error: {str(e)}"
163
+
164
+ # Initialize the app
165
+ app = DrugTargetInteractionApp()
166
+
167
+ def predict_wrapper(target_seq, drug_smiles):
168
+ """Wrapper function for Gradio interface"""
169
+ if not target_seq.strip() or not drug_smiles.strip():
170
+ return "Please provide both target sequence and drug SMILES."
171
+
172
+ return app.predict_interaction(target_seq, drug_smiles)
173
+
174
+ def load_model_wrapper(model_path):
175
+ """Wrapper function to load model"""
176
+ if app.load_model(model_path):
177
+ return "Model loaded successfully!"
178
+ else:
179
+ return "Failed to load model. Check the path and files."
180
+
181
+ # Create Gradio interface
182
+ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
183
+ gr.HTML("""
184
+ <div style="text-align: center; margin-bottom: 30px;">
185
+ <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
186
+ 🧬 Drug-Target Interaction Predictor
187
+ </h1>
188
+ <p style="font-size: 1.2em; color: #666;">
189
+ Predict binding affinity between drugs and target RNA sequences using deep learning
190
+ </p>
191
+ </div>
192
+ """)
193
+
194
+ with gr.Tab("🔮 Prediction"):
195
+ with gr.Row():
196
+ with gr.Column(scale=1):
197
+ target_input = gr.Textbox(
198
+ label="Target RNA Sequence",
199
+ placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
200
+ lines=4,
201
+ max_lines=6
202
+ )
203
+
204
+ drug_input = gr.Textbox(
205
+ label="Drug SMILES",
206
+ placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
207
+ lines=2
208
+ )
209
+
210
+ predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
211
+
212
+ with gr.Column(scale=1):
213
+ prediction_output = gr.Textbox(
214
+ label="Prediction Result",
215
+ interactive=False,
216
+ lines=3
217
+ )
218
+
219
+ # Example inputs
220
+ gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
221
+
222
+ examples = gr.Examples(
223
+ examples=[
224
+ [
225
+ "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
226
+ "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
227
+ ],
228
+ [
229
+ "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
230
+ "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
231
+ ]
232
+ ],
233
+ inputs=[target_input, drug_input],
234
+ outputs=prediction_output,
235
+ fn=predict_wrapper,
236
+ cache_examples=False
237
+ )
238
+
239
+ predict_btn.click(
240
+ fn=predict_wrapper,
241
+ inputs=[target_input, drug_input],
242
+ outputs=prediction_output
243
+ )
244
+
245
+ with gr.Tab("⚙️ Model Settings"):
246
+ gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")
247
+
248
+ model_path_input = gr.Textbox(
249
+ label="Model Path",
250
+ value="./",
251
+ placeholder="Path to model directory"
252
+ )
253
+
254
+ load_model_btn = gr.Button("📥 Load Model", variant="secondary")
255
+ model_status = gr.Textbox(
256
+ label="Status",
257
+ interactive=False,
258
+ value="No model loaded"
259
+ )
260
+
261
+ load_model_btn.click(
262
+ fn=load_model_wrapper,
263
+ inputs=model_path_input,
264
+ outputs=model_status
265
+ )
266
+
267
+ with gr.Tab("ℹ️ About"):
268
+ gr.Markdown("""
269
+ ## About This Application
270
+
271
+ This application uses a deep learning model for predicting drug-target interactions. The model architecture includes:
272
+
273
+ - **Target Encoder**: Processes RNA sequences
274
+ - **Drug Encoder**: Processes molecular SMILES notation
275
+ - **Cross-Attention Mechanism**: Captures interactions between drugs and targets
276
+ - **Regression Head**: Predicts binding affinity scores
277
+
278
+ ### Input Requirements:
279
+ - **Target Sequence**: RNA sequence of the target
280
+ - **Drug SMILES**: Simplified Molecular Input Line Entry System notation
281
+
282
+ ### Model Features:
283
+ - Cross-attention for drug-target interaction modeling
284
+ - Dropout for regularization
285
+ - Layer normalization for stable training
286
+ - Interpretability mode for attention visualization
287
+
288
+ ### Usage Tips:
289
+ 1. Load your trained model using the Model Settings tab
290
+ 2. Enter a RNA sequence and drug SMILES
291
+ 3. Click "Predict Interaction" to get binding affinity prediction
292
+
293
+ For best results, ensure your input sequences are properly formatted and within reasonable length limits.
294
+ """)
295
+
296
+ # Launch the app
297
+ if __name__ == "__main__":
298
+ # Try to load model on startup
299
+ if os.path.exists("./config.json"):
300
+ app.load_model("./")
301
+
302
+ demo.launch(
303
+ server_name="0.0.0.0",
304
+ server_port=7860,
305
+ share=False,
306
+ show_error=True
307
+ )
chemberta.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers.models import WordLevel
2
+ from tokenizers import Tokenizer
3
+ from tokenizers.pre_tokenizers import Split
4
+ from tokenizers import Regex
5
+ from tokenizers.processors import TemplateProcessing
6
+ from transformers import BatchEncoding
7
+ import torch
8
+
9
+ class ChembertaTokenizer:
10
+ def __init__(self, vocab_file):
11
+ self.tokenizer = Tokenizer(
12
+ WordLevel.from_file(
13
+ vocab_file,
14
+ unk_token='[UNK]'
15
+ ))
16
+ self.tokenizer.pre_tokenizer = Split(
17
+ pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
18
+ behavior='isolated'
19
+ )
20
+ # Disable padding
21
+
22
+ self.tokenizer.encode_special_tokens = True
23
+ self.special_token_ids = {
24
+ self.tokenizer.token_to_id('[CLS]'),
25
+ self.tokenizer.token_to_id('[SEP]'),
26
+ self.tokenizer.token_to_id('[PAD]'),
27
+ self.tokenizer.token_to_id('[UNK]')
28
+ }
29
+
30
+ self.tokenizer.post_processor = TemplateProcessing(
31
+ single='[CLS] $A [SEP]',
32
+ pair='[CLS] $A [SEP] $B:1 [SEP]:1',
33
+ special_tokens=[
34
+ ('[CLS]', self.tokenizer.token_to_id('[CLS]')),
35
+ ('[SEP]', self.tokenizer.token_to_id('[SEP]'))
36
+ ]
37
+ )
38
+
39
+ def encode(self, inputs, padding=None, truncation=False,
40
+ max_length=None, return_tensors=None):
41
+ # Configure padding/truncation
42
+ if padding:
43
+ self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('[PAD]'),
44
+ pad_token='[PAD]', length=max_length)
45
+ else:
46
+ self.tokenizer.no_padding()
47
+
48
+ if truncation:
49
+ self.tokenizer.enable_truncation(max_length=max_length)
50
+ else:
51
+ self.tokenizer.no_truncation()
52
+ if return_tensors == 'pt':
53
+ tensor_type = 'pt'
54
+ else:
55
+ tensor_type = None
56
+ # Handle batch or single input
57
+ if isinstance(inputs, list):
58
+ enc = self.tokenizer.encode_batch(inputs)
59
+ data = {
60
+ "input_ids": [e.ids for e in enc],
61
+ "attention_mask": [e.attention_mask for e in enc]
62
+ }
63
+ return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)
64
+
65
+ else:
66
+ # Single sequence: wrap into batch of size 1
67
+ enc = [self.tokenizer.encode(inputs)]
68
+ data = {
69
+ "input_ids": [e.ids for e in enc],
70
+ "attention_mask": [e.attention_mask for e in enc]
71
+ }
72
+ return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)
73
+
74
+ def __call__(self, inputs, padding=None, truncation=False,
75
+ max_length=None, return_tensors=None):
76
+ return self.encode(inputs, padding=padding, truncation=truncation,
77
+ max_length=max_length, return_tensors=return_tensors)
78
+
79
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
80
+ def _decode_sequence(seq):
81
+ if skip_special_tokens:
82
+ seq = [idx for idx in seq if idx not in self.special_token_ids]
83
+ return [self.tokenizer.id_to_token(idx) for idx in seq]
84
+
85
+ # 1) batch: list of lists or torch tensor
86
+ if isinstance(ids, torch.Tensor):
87
+ ids = ids.tolist()
88
+ if len(ids) == 1:
89
+ ids = ids[0]
90
+
91
+ if isinstance(ids, (list)) and len(ids) > 0 and isinstance(ids[0], (list)):
92
+ return [_decode_sequence(seq) for seq in ids]
93
+
94
+ # 2) single sequence: list of ints or torch tensor
95
+ if isinstance(ids, (list)):
96
+ return _decode_sequence(ids)
97
+
98
+ # 3) single int
99
+ if isinstance(ids, int):
100
+ return self.tokenizer.id_to_token(ids)
101
+
102
+ def decode(self, ids, skip_special_tokens=False):
103
+ def _decode_sequence(seq):
104
+ if skip_special_tokens:
105
+ seq = [idx for idx in seq if idx not in self.special_token_ids]
106
+ return ''.join(self.tokenizer.id_to_token(idx) for idx in seq)
107
+
108
+ # 1) batch: list of lists or torch tensor
109
+ if isinstance(ids, torch.Tensor):
110
+ ids = ids.tolist()
111
+ if len(ids) == 1:
112
+ ids = ids[0]
113
+
114
+ if isinstance(ids, (list)) and len(ids) > 0 and isinstance(ids[0], (list)):
115
+ return [_decode_sequence(seq) for seq in ids]
116
+
117
+ # 2) single sequence: list of ints or torch tensor
118
+ if isinstance(ids, (list)):
119
+ return _decode_sequence(ids)
120
+
121
+ # 3) single int
122
+ if isinstance(ids, int):
123
+ return self.tokenizer.id_to_token(ids)
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "InteractionModelATTNForRegression"
4
+ ],
5
+ "attention_dropout": 0.425,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_dlmberta.InteractionModelATTNConfig",
8
+ "AutoModel": "modeling_dlmberta.InteractionModelATTNForRegression"
9
+ },
10
+ "hidden_dropout": 0.0816,
11
+ "model_type": "dlmberta",
12
+ "num_heads": 1,
13
+ "torch_dtype": "float32",
14
+ "transformers_version": "4.41.0"
15
+ }
configuration_dlmberta.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class InteractionModelATTNConfig(PretrainedConfig):
4
+ model_type = "dlmberta"
5
+ def __init__(self, attention_dropout = 0.2, hidden_dropout = 0.2, num_heads = 1, **kwargs,):
6
+ self.num_heads = num_heads
7
+ self.hidden_dropout = hidden_dropout
8
+ self.attention_dropout = attention_dropout
9
+ super().__init__(**kwargs)
drug_tokenizer/vocab.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"[PAD]":0,"[unused1]":1,"[unused2]":2,"[unused3]":3,"[unused4]":4,"[unused5]":5,"[unused6]":6,"[unused7]":7,"[unused8]":8,"[unused9]":9,"[unused10]":10,"[UNK]":11,"[CLS]":12,"[SEP]":13,"[MASK]":14,"c":15,"C":16,"(":17,")":18,"O":19,"1":20,"2":21,"=":22,"N":23,".":24,"n":25,"3":26,"F":27,"Cl":28,">>":29,"~":30,"-":31,"4":32,"[C@H]":33,"S":34,"[C@@H]":35,"[O-]":36,"Br":37,"#":38,"/":39,"[nH]":40,"[N+]":41,"s":42,"5":43,"o":44,"P":45,"[Na+]":46,"[Si]":47,"I":48,"[Na]":49,"[Pd]":50,"[K+]":51,"[K]":52,"[P]":53,"B":54,"[C@]":55,"[C@@]":56,"[Cl-]":57,"6":58,"[OH-]":59,"\\":60,"[N-]":61,"[Li]":62,"[H]":63,"[2H]":64,"[NH4+]":65,"[c-]":66,"[P-]":67,"[Cs+]":68,"[Li+]":69,"[Cs]":70,"[NaH]":71,"[H-]":72,"[O+]":73,"[BH4-]":74,"[Cu]":75,"7":76,"[Mg]":77,"[Fe+2]":78,"[n+]":79,"[Sn]":80,"[BH-]":81,"[Pd+2]":82,"[CH]":83,"[I-]":84,"[Br-]":85,"[C-]":86,"[Zn]":87,"[B-]":88,"[F-]":89,"[Al]":90,"[P+]":91,"[BH3-]":92,"[Fe]":93,"[C]":94,"[AlH4]":95,"[Ni]":96,"[SiH]":97,"8":98,"[Cu+2]":99,"[Mn]":100,"[AlH]":101,"[nH+]":102,"[AlH4-]":103,"[O-2]":104,"[Cr]":105,"[Mg+2]":106,"[NH3+]":107,"[S@]":108,"[Pt]":109,"[Al+3]":110,"[S@@]":111,"[S-]":112,"[Ti]":113,"[Zn+2]":114,"[PH]":115,"[NH2+]":116,"[Ru]":117,"[Ag+]":118,"[S+]":119,"[I+3]":120,"[NH+]":121,"[Ca+2]":122,"[Ag]":123,"9":124,"[Os]":125,"[Se]":126,"[SiH2]":127,"[Ca]":128,"[Ti+4]":129,"[Ac]":130,"[Cu+]":131,"[S]":132,"[Rh]":133,"[Cl+3]":134,"[cH-]":135,"[Zn+]":136,"[O]":137,"[Cl+]":138,"[SH]":139,"[H+]":140,"[Pd+]":141,"[se]":142,"[PH+]":143,"[I]":144,"[Pt+2]":145,"[C+]":146,"[Mg+]":147,"[Hg]":148,"[W]":149,"[SnH]":150,"[SiH3]":151,"[Fe+3]":152,"[NH]":153,"[Mo]":154,"[CH2+]":155,"%10":156,"[CH2-]":157,"[CH2]":158,"[n-]":159,"[Ce+4]":160,"[NH-]":161,"[Co]":162,"[I+]":163,"[PH2]":164,"[Pt+4]":165,"[Ce]":166,"[B]":167,"[Sn+2]":168,"[Ba+2]":169,"%11":170,"[Fe-3]":171,"[18F]":172,"[SH-]":173,"[Pb+2]":174,"[Os-2]":175,"[Zr+4]":176,"[N]":177,"[Ir]":178,"[Bi]":179,"[Ni+2]":180,"[P@]":181,"[Co+2]":182,"[s+]":183,"[As]":184,"[P+3]":185,"[Hg+2]":186,"[Yb+3]":187,"[CH-]":188,"[Zr+2]":189,"[Mn+2]":190,"[CH+]":191,"[In]":192,"[KH]":193,"[Ce+3]":194,"[Zr]":195,"[AlH2-]":196,"[OH2+]":197,"[Ti+3]":198,"[Rh+2]":199,"[Sb]":200,"[S-2]":201,"%12":202,"[P@@]":203,"[Si@H]":204,"[Mn+4]":205,"p":206,"[Ba]":207,"[NH2-]":208,"[Ge]":209,"[Pb+4]":210,"[Cr+3]":211,"[Au]":212,"[LiH]":213,"[Sc+3]":214,"[o+]":215,"[Rh-3]":216,"%13":217,"[Br]":218,"[Sb-]":219,"[S@+]":220,"[I+2]":221,"[Ar]":222,"[V]":223,"[Cu-]":224,"[Al-]":225,"[Te]":226,"[13c]":227,"[13C]":228,"[Cl]":229,"[PH4+]":230,"[SiH4]":231,"[te]":232,"[CH3-]":233,"[S@@+]":234,"[Rh+3]":235,"[SH+]":236,"[Bi+3]":237,"[Br+2]":238,"[La]":239,"[La+3]":240,"[Pt-2]":241,"[N@@]":242,"[PH3+]":243,"[N@]":244,"[Si+4]":245,"[Sr+2]":246,"[Al+]":247,"[Pb]":248,"[SeH]":249,"[Si-]":250,"[V+5]":251,"[Y+3]":252,"[Re]":253,"[Ru+]":254,"[Sm]":255,"*":256,"[3H]":257,"[NH2]":258,"[Ag-]":259,"[13CH3]":260,"[OH+]":261,"[Ru+3]":262,"[OH]":263,"[Gd+3]":264,"[13CH2]":265,"[In+3]":266,"[Si@@]":267,"[Si@]":268,"[Ti+2]":269,"[Sn+]":270,"[Cl+2]":271,"[AlH-]":272,"[Pd-2]":273,"[SnH3]":274,"[B+3]":275,"[Cu-2]":276,"[Nd+3]":277,"[Pb+3]":278,"[13cH]":279,"[Fe-4]":280,"[Ga]":281,"[Sn+4]":282,"[Hg+]":283,"[11CH3]":284,"[Hf]":285,"[Pr]":286,"[Y]":287,"[S+2]":288,"[Cd]":289,"[Cr+6]":290,"[Zr+3]":291,"[Rh+]":292,"[CH3]":293,"[N-3]":294,"[Hf+2]":295,"[Th]":296,"[Sb+3]":297,"%14":298,"[Cr+2]":299,"[Ru+2]":300,"[Hf+4]":301,"[14C]":302,"[Ta]":303,"[Tl+]":304,"[B+]":305,"[Os+4]":306,"[PdH2]":307,"[Pd-]":308,"[Cd+2]":309,"[Co+3]":310,"[S+4]":311,"[Nb+5]":312,"[123I]":313,"[c+]":314,"[Rb+]":315,"[V+2]":316,"[CH3+]":317,"[Ag+2]":318,"[cH+]":319,"[Mn+3]":320,"[Se-]":321,"[As-]":322,"[Eu+3]":323,"[SH2]":324,"[Sm+3]":325,"[IH+]":326,"%15":327,"[OH3+]":328,"[PH3]":329,"[IH2+]":330,"[SH2+]":331,"[Ir+3]":332,"[AlH3]":333,"[Sc]":334,"[Yb]":335,"[15NH2]":336,"[Lu]":337,"[sH+]":338,"[Gd]":339,"[18F-]":340,"[SH3+]":341,"[SnH4]":342,"[TeH]":343,"[Si@@H]":344,"[Ga+3]":345,"[CaH2]":346,"[Tl]":347,"[Ta+5]":348,"[GeH]":349,"[Br+]":350,"[Sr]":351,"[Tl+3]":352,"[Sm+2]":353,"[PH5]":354,"%16":355,"[N@@+]":356,"[Au+3]":357,"[C-4]":358,"[Nd]":359,"[Ti+]":360,"[IH]":361,"[N@+]":362,"[125I]":363,"[Eu]":364,"[Sn+3]":365,"[Nb]":366,"[Er+3]":367,"[123I-]":368,"[14c]":369,"%17":370,"[SnH2]":371,"[YH]":372,"[Sb+5]":373,"[Pr+3]":374,"[Ir+]":375,"[N+3]":376,"[AlH2]":377,"[19F]":378,"%18":379,"[Tb]":380,"[14CH]":381,"[Mo+4]":382,"[Si+]":383,"[BH]":384,"[Be]":385,"[Rb]":386,"[pH]":387,"%19":388,"%20":389,"[Xe]":390,"[Ir-]":391,"[Be+2]":392,"[C+4]":393,"[RuH2]":394,"[15NH]":395,"[U+2]":396,"[Au-]":397,"%21":398,"%22":399,"[Au+]":400,"[15n]":401,"[Al+2]":402,"[Tb+3]":403,"[15N]":404,"[V+3]":405,"[W+6]":406,"[14CH3]":407,"[Cr+4]":408,"[ClH+]":409,"b":410,"[Ti+6]":411,"[Nd+]":412,"[Zr+]":413,"[PH2+]":414,"[Fm]":415,"[N@H+]":416,"[RuH]":417,"[Dy+3]":418,"%23":419,"[Hf+3]":420,"[W+4]":421,"[11C]":422,"[13CH]":423,"[Er]":424,"[124I]":425,"[LaH]":426,"[F]":427,"[siH]":428,"[Ga+]":429,"[Cm]":430,"[GeH3]":431,"[IH-]":432,"[U+6]":433,"[SeH+]":434,"[32P]":435,"[SeH-]":436,"[Pt-]":437,"[Ir+2]":438,"[se+]":439,"[U]":440,"[F+]":441,"[BH2]":442,"[As+]":443,"[Cf]":444,"[ClH2+]":445,"[Ni+]":446,"[TeH3]":447,"[SbH2]":448,"[Ag+3]":449,"%24":450,"[18O]":451,"[PH4]":452,"[Os+2]":453,"[Na-]":454,"[Sb+2]":455,"[V+4]":456,"[Ho+3]":457,"[68Ga]":458,"[PH-]":459,"[Bi+2]":460,"[Ce+2]":461,"[Pd+3]":462,"[99Tc]":463,"[13C@@H]":464,"[Fe+6]":465,"[c]":466,"[GeH2]":467,"[10B]":468,"[Cu+3]":469,"[Mo+2]":470,"[Cr+]":471,"[Pd+4]":472,"[Dy]":473,"[AsH]":474,"[Ba+]":475,"[SeH2]":476,"[In+]":477,"[TeH2]":478,"[BrH+]":479,"[14cH]":480,"[W+]":481,"[13C@H]":482,"[AsH2]":483,"[In+2]":484,"[N+2]":485,"[N@@H+]":486,"[SbH]":487,"[60Co]":488,"[AsH4+]":489,"[AsH3]":490,"[18OH]":491,"[Ru-2]":492,"[Na-2]":493,"[CuH2]":494,"[31P]":495,"[Ti+5]":496,"[35S]":497,"[P@@H]":498,"[ArH]":499,"[Co+]":500,"[Zr-2]":501,"[BH2-]":502,"[131I]":503,"[SH5]":504,"[VH]":505,"[B+2]":506,"[Yb+2]":507,"[14C@H]":508,"[211At]":509,"[NH3+2]":510,"[IrH]":511,"[IrH2]":512,"[Rh-]":513,"[Cr-]":514,"[Sb+]":515,"[Ni+3]":516,"[TaH3]":517,"[Tl+2]":518,"[64Cu]":519,"[Tc]":520,"[Cd+]":521,"[1H]":522,"[15nH]":523,"[AlH2+]":524,"[FH+2]":525,"[BiH3]":526,"[Ru-]":527,"[Mo+6]":528,"[AsH+]":529,"[BaH2]":530,"[BaH]":531,"[Fe+4]":532,"[229Th]":533,"[Th+4]":534,"[As+3]":535,"[NH+3]":536,"[P@H]":537,"[Li-]":538,"[7NaH]":539,"[Bi+]":540,"[PtH+2]":541,"[p-]":542,"[Re+5]":543,"[NiH]":544,"[Ni-]":545,"[Xe+]":546,"[Ca+]":547,"[11c]":548,"[Rh+4]":549,"[AcH]":550,"[HeH]":551,"[Sc+2]":552,"[Mn+]":553,"[UH]":554,"[14CH2]":555,"[SiH4+]":556,"[18OH2]":557,"[Ac-]":558,"[Re+4]":559,"[118Sn]":560,"[153Sm]":561,"[P+2]":562,"[9CH]":563,"[9CH3]":564,"[Y-]":565,"[NiH2]":566,"[Si+2]":567,"[Mn+6]":568,"[ZrH2]":569,"[C-2]":570,"[Bi+5]":571,"[24NaH]":572,"[Fr]":573,"[15CH]":574,"[Se+]":575,"[At]":576,"[P-3]":577,"[124I-]":578,"[CuH2-]":579,"[Nb+4]":580,"[Nb+3]":581,"[MgH]":582,"[Ir+4]":583,"[67Ga+3]":584,"[67Ga]":585,"[13N]":586,"[15OH2]":587,"[2NH]":588,"[Ho]":589,"[Cn]":590}
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cda24e3491cc24f77c27dc0a9a2933a05eab66ae96aa770de5db86d640cffce
3
+ size 241171900
modeling_dlmberta.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedModel, PretrainedConfig
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import PretrainedConfig, PreTrainedModel
6
+ from torch.nn.parameter import Parameter
7
+ from torch.nn.init import xavier_uniform_, constant_
8
+ from configuration_dlmberta import InteractionModelATTNConfig
9
+ import math
10
+
11
+ class StdScaler():
12
+ def fit(self, X):
13
+ self.mean_ = torch.mean(X).item()
14
+ self.std_ = torch.std(X, correction=0).item()
15
+
16
+ def fit_transform(self, X):
17
+ self.mean_ = torch.mean(X).item()
18
+ self.std_ = torch.std(X, correction=0).item()
19
+
20
+ return (X-self.mean_)/self.std_
21
+
22
+ def transform(self, X):
23
+ return (X-self.mean_)/self.std_
24
+
25
+ def inverse_transform(self, X):
26
+ return (X*self.std_)+self.mean_
27
+
28
+ def save(self, directory):
29
+ with open(directory+"/scaler.config", "w") as f:
30
+ f.write(str(self.mean_)+"\n")
31
+ f.write(str(self.std_)+"\n")
32
+
33
+ def load(self, directory):
34
+ with open(directory+"/scaler.config", "r") as f:
35
+ self.mean_ = float(f.readline())
36
+ self.std_ = float(f.readline())
37
+
38
+
39
+ class InteractionModelATTNForRegression(PreTrainedModel):
40
+ config_class = InteractionModelATTNConfig
41
+
42
+ def __init__(self, config, target_encoder, drug_encoder, scaler=None):
43
+ super().__init__(config)
44
+ self.model = InteractionModelATTN(target_encoder,
45
+ drug_encoder,
46
+ scaler,
47
+ config.attention_dropout,
48
+ config.hidden_dropout,
49
+ config.num_heads)
50
+ self.scaler = scaler
51
+
52
+ def INTERPR_ENABLE_MODE(self):
53
+ self.model.INTERPR_ENABLE_MODE()
54
+
55
+
56
+ def INTERPR_OVERRIDE_ATTN(self, new_weights):
57
+ self.model.INTERPR_OVERRIDE_ATTN(new_weights)
58
+
59
+ def INTERPR_RESET_OVERRIDE_ATTN(self):
60
+ self.model.INTERPR_RESET_OVERRIDE_ATTN()
61
+
62
+ def forward(self, x1, x2):
63
+ return self.model(x1, x2)
64
+
65
+ def unscale(self, x):
66
+ return self.model.unscale(x)
67
+
68
+
69
+
70
+ class CrossAttention(nn.Module):
71
+ def __init__(self, embed_dim, num_heads, attention_dropout=0.0, hidden_dropout=0.0, add_bias_kv=False, **factory_kwargs):
72
+ """
73
+ Initializes the CrossAttention layer.
74
+
75
+ Args:
76
+ embed_dim (int): Dimension of the input embeddings.
77
+ num_heads (int): Number of attention heads.
78
+ dropout (float): Dropout probability for attention weights.
79
+ """
80
+ super().__init__()
81
+ self.attention_dropout = attention_dropout
82
+ self.hidden_dropout = hidden_dropout
83
+ self.embed_dim = embed_dim
84
+ self.num_heads = num_heads
85
+ self.head_dim = embed_dim // num_heads
86
+
87
+ self.scaling = self.head_dim ** -0.5
88
+
89
+ if self.head_dim * num_heads != embed_dim:
90
+ raise ValueError("embed_dim must be divisible by num_heads")
91
+
92
+ # Linear projections for query, key, and value.
93
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
94
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
95
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
96
+ self.attn_dropout = nn.Dropout(attention_dropout)
97
+
98
+ xavier_uniform_(self.q_proj.weight)
99
+ xavier_uniform_(self.k_proj.weight)
100
+ xavier_uniform_(self.v_proj.weight)
101
+ constant_(self.q_proj.bias, 0.)
102
+ constant_(self.k_proj.bias, 0.)
103
+ constant_(self.v_proj.bias, 0.)
104
+
105
+ # Output projection.
106
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
107
+ constant_(self.out_proj.bias, 0)
108
+
109
+ self.drop_out = nn.Dropout(hidden_dropout)
110
+
111
+ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, replace_weights=None):
112
+ """
113
+ Forward pass for cross attention.
114
+
115
+ Args:
116
+ query (Tensor): Query embeddings of shape (batch_size, query_len, embed_dim).
117
+ key (Tensor): Key embeddings of shape (batch_size, key_len, embed_dim).
118
+ value (Tensor): Value embeddings of shape (batch_size, key_len, embed_dim).
119
+ attn_mask (Tensor, optional): Attention mask of shape (batch_size, num_heads, query_len, key_len).
120
+
121
+ Returns:
122
+ output (Tensor): The attended output of shape (batch_size, query_len, embed_dim).
123
+ attn_weights (Tensor): The attention weights of shape (batch_size, num_heads, query_len, key_len).
124
+ """
125
+
126
+ batch_size, query_len, _ = query.size()
127
+ _, key_len, _ = key.size()
128
+
129
+ Q = self.q_proj(query)
130
+ K = self.k_proj(key)
131
+ V = self.v_proj(value)
132
+
133
+ Q = Q.view(batch_size, self.num_heads, query_len, self.head_dim)
134
+ K = K.view(batch_size, self.num_heads, key_len, self.head_dim)
135
+ V = V.view(batch_size, self.num_heads, key_len, self.head_dim)
136
+
137
+ # Compute scaled dot-product attention scores
138
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (batch_size, num_heads, query_len, key_len)
139
+
140
+ if key_padding_mask is not None:
141
+ # Convert boolean mask (False -> -inf, True -> 0)
142
+ key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) # (B, 1, 1, key_len) for broadcasting
143
+ scores = scores.masked_fill(key_padding_mask, float('-inf')) # Set masked positions to -inf
144
+
145
+ if replace_weights is not None:
146
+ scores = replace_weights
147
+
148
+ # Compute attention weights using softmax
149
+ attn_weights = torch.nn.functional.softmax(scores, dim=-1) # (batch_size, num_heads, query_len, key_len)
150
+ self.scores = scores
151
+ if attn_mask is not None:
152
+ attn_mask = attn_mask.unsqueeze(1) # Shape: (batch_size, 1, query_len, key_len)
153
+ attn_weights = attn_weights.masked_fill(attn_mask, 0) # Set masked positions to 0
154
+
155
+
156
+
157
+ # Optionally apply dropout to the attention weights if self.dropout is defined
158
+ attn_weights = self.attn_dropout(attn_weights)
159
+ # Compute the weighted sum of the values
160
+ attn_output = torch.matmul(attn_weights, V) # (batch_size, num_heads, query_len, head_dim)
161
+ # Recombine heads: transpose and reshape back to (batch_size, query_len, embed_dim)
162
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)
163
+
164
+ # Final linear projection and dropout
165
+ output = self.out_proj(attn_output)
166
+ output = self.drop_out(output)
167
+
168
+ return output, attn_weights
169
+
170
+
171
+ class InteractionModelATTN(nn.Module):
172
+ def __init__(self, target_encoder, drug_encoder, scaler, attention_dropout, hidden_dropout, num_heads=1, kernel_size=1):
173
+ super().__init__()
174
+ self.replace_weights = None
175
+ self.crossattention_weights = None
176
+ self.presum_layer = None
177
+ self.INTERPR_MODE = False
178
+
179
+ self.scaler = scaler
180
+ self.attention_dropout = attention_dropout
181
+ self.hidden_dropout = hidden_dropout
182
+
183
+ self.target_encoder = target_encoder
184
+ self.drug_encoder = drug_encoder
185
+ self.kernel_size = kernel_size
186
+ self.lin_map_target = nn.Linear(512, 384)
187
+ self.dropout_map_target = nn.Dropout(hidden_dropout)
188
+
189
+ self.lin_map_drug = nn.Linear(384, 384)
190
+ self.dropout_map_drug = nn.Dropout(hidden_dropout)
191
+
192
+ self.crossattention = CrossAttention(384, num_heads, attention_dropout, hidden_dropout)
193
+ self.norm = nn.LayerNorm(384)
194
+ self.summary1 = nn.Linear(384, 384)
195
+ self.summary2 = nn.Linear(384, 1)
196
+ self.dropout_summary = nn.Dropout(hidden_dropout)
197
+ self.layer_norm = nn.LayerNorm(384)
198
+ self.gelu = nn.GELU()
199
+
200
+ self.w = Parameter(torch.empty(512, 1))
201
+ self.b = Parameter(torch.zeros(1))
202
+ self.pdng = Parameter(torch.tensor(0.0)) # learnable padding value (0-dimensional)
203
+
204
+ xavier_uniform_(self.w)
205
+
206
+ def forward(self, x1, x2):
207
+ """
208
+ Forward pass for attention interaction model.
209
+
210
+ Args:
211
+ x1 (dict): A dictionary containing input tensors for the target encoder.
212
+ Expected keys:
213
+ - 'input_ids' (torch.Tensor): Token IDs for the target input.
214
+ - 'attention_mask' (torch.Tensor): Attention mask for the target input.
215
+ x2 (dict): A dictionary containing input tensors for the drug encoder.
216
+ Expected keys:
217
+ - 'input_ids' (torch.Tensor): Token IDs for the drug input.
218
+ - 'attention_mask' (torch.Tensor): Attention mask for the drug input.
219
+
220
+ Returns:
221
+ torch.Tensor: A tensor representing the predicted binding affinity.
222
+ """
223
+ x1["attention_mask"] = x1["attention_mask"].bool() # Fix dropout model issue: https://github.com/pytorch/pytorch/issues/86120
224
+ y1 = self.target_encoder(**x1).last_hidden_state # The target
225
+
226
+ query_mask = x1["attention_mask"].unsqueeze(-1).to(y1.dtype)
227
+ y1 = y1 * query_mask
228
+
229
+ x2["attention_mask"] = x2["attention_mask"].bool() # Fix dropout model issue: https://github.com/pytorch/pytorch/issues/86120
230
+ y2 = self.drug_encoder(**x2).last_hidden_state # The drug
231
+ key_mask = x2["attention_mask"].unsqueeze(-1).to(y2.dtype)
232
+ y2 = y2 * key_mask
233
+
234
+ y1 = self.lin_map_target(y1)
235
+ y1 = self.gelu(y1)
236
+ y1 = self.dropout_map_target(y1)
237
+
238
+ y2 = self.lin_map_drug(y2)
239
+ y2 = self.gelu(y2)
240
+ y2 = self.dropout_map_drug(y2)
241
+
242
+ key_padding_mask=(x2["attention_mask"] == 0) # S
243
+
244
+ replace_weights = None
245
+ # If in interpretation mode, allow the replacement of cross-attention weights
246
+ if self.INTERPR_MODE:
247
+ if self.replace_weights is not None:
248
+ replace_weights = self.replace_weights
249
+
250
+ out, _ = self.crossattention(y1, y2, y2, key_padding_mask=key_padding_mask, attn_mask=None, replace_weights=replace_weights)
251
+
252
+ # If in interpretation mode, make cross-attention weights and scores accessible from the outside
253
+ if self.INTERPR_MODE:
254
+ self.crossattention_weights = _
255
+ self.scores = self.crossattention.scores
256
+
257
+ out = self.summary1(out * query_mask)
258
+ out = self.gelu(out)
259
+ out = self.dropout_summary(out)
260
+ out = self.summary2(out).squeeze(-1)
261
+
262
+ # If in interpretation mode, make final summation layer contributions accessible from the outside
263
+ if self.INTERPR_MODE:
264
+ self.presum_layer = out
265
+
266
+
267
+ weighted = out * self.w.squeeze(1) # [batch, seq_len]
268
+ padding_positions = ~x1["attention_mask"] # True at padding
269
+ # assign learnable pdng to all padding positions
270
+ weighted = weighted.masked_fill(padding_positions, self.pdng.item())
271
+
272
+ # sum across sequence and add bias
273
+ result = weighted.sum(dim=1, keepdim=True) + self.b
274
+ return result
275
+
276
+ def train(self, mode = True):
277
+ super().train(mode)
278
+ self.target_encoder.train(mode)
279
+ self.drug_encoder.train(mode)
280
+ self.crossattention.train(mode)
281
+ return self
282
+
283
+ def eval(self):
284
+ super().eval()
285
+ self.target_encoder.eval()
286
+ self.drug_encoder.eval()
287
+ self.crossattention.eval()
288
+ return self
289
+
290
+ def INTERPR_ENABLE_MODE(self):
291
+ """
292
+ Enables the interpretability mode for the model.
293
+ """
294
+ if self.training:
295
+ raise RuntimeError("Cannot enable interpretability mode while the model is training.")
296
+ self.INTERPR_MODE = True
297
+
298
+
299
+ def INTERPR_OVERRIDE_ATTN(self, new_weights):
300
+ self.replace_weights = new_weights
301
+
302
+ def INTERPR_RESET_OVERRIDE_ATTN(self):
303
+ self.replace_weights = None
304
+
305
+ def unscale(self, x):
306
+ """
307
+ Unscales the labels using a scaler. If the scaler is not specified, don't do anything.
308
+
309
+ Parameters:
310
+ target_value: the target values to be unscaled
311
+ """
312
+ with torch.no_grad():
313
+ if self.scaler is None:
314
+ return x
315
+ unscaled = self.scaler.inverse_transform(x)
316
+ return unscaled
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=1.9.0
3
+ transformers>=4.21.0
4
+ tokenizers>=0.13.0
5
+ numpy>=1.21.0
6
+ huggingface_hub>=0.10.0
7
+ accelerate>=0.20.0
8
+ datasets>=2.0.0
9
+ safetensors>=0.3.0
scaler.config ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 5.389976501464844
2
+ 1.3962712287902832
target_tokenizer/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM",
4
+ "RobertaModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "attn_mult": 5.656854249492381,
8
+ "bos_token_id": 0,
9
+ "classifier_dropout": null,
10
+ "eos_token_id": 2,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 512,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "layer_norm_eps": 1e-12,
17
+ "max_position_embeddings": 514,
18
+ "model_type": "roberta",
19
+ "num_attention_heads": 16,
20
+ "num_hidden_layers": 12,
21
+ "output_hidden_states": true,
22
+ "pad_token_id": 1,
23
+ "position_embedding_type": "absolute",
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.46.3",
26
+ "type_vocab_size": 2,
27
+ "use_cache": true,
28
+ "vocab_size": 9700
29
+ }
target_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
target_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
target_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
target_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<pad>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "<mask>",
38
+ "lstrip": true,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ },
45
+ "bos_token": "<s>",
46
+ "clean_up_tokenization_spaces": true,
47
+ "cls_token": "<s>",
48
+ "eos_token": "</s>",
49
+ "errors": "replace",
50
+ "mask_token": "<mask>",
51
+ "model_max_length": 1000000000000000019884624838656,
52
+ "pad_token": "<pad>",
53
+ "sep_token": "</s>",
54
+ "tokenizer_class": "RobertaTokenizer",
55
+ "trim_offsets": true,
56
+ "unk_token": "<unk>"
57
+ }
target_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff