Spaces:
Running
Running
Upload 2 files
Browse files- app.py +3 -0
- modeling_dlmberta.py +9 -0
app.py
CHANGED
|
@@ -97,6 +97,7 @@ class DrugTargetInteractionApp:
|
|
| 97 |
).to(self.device)
|
| 98 |
|
| 99 |
# Make prediction
|
|
|
|
| 100 |
with torch.no_grad():
|
| 101 |
prediction = self.model(target_inputs, drug_inputs)
|
| 102 |
|
|
@@ -143,6 +144,7 @@ class DrugTargetInteractionApp:
|
|
| 143 |
max_length=512,
|
| 144 |
return_tensors="pt"
|
| 145 |
).to(self.device)
|
|
|
|
| 146 |
|
| 147 |
# Make prediction and extract visualization data
|
| 148 |
with torch.no_grad():
|
|
@@ -163,6 +165,7 @@ class DrugTargetInteractionApp:
|
|
| 163 |
b = self.model.model.b
|
| 164 |
scaler = self.model.model.scaler
|
| 165 |
|
|
|
|
| 166 |
# Generate visualizations
|
| 167 |
try:
|
| 168 |
# 1. Cross-attention heatmap
|
|
|
|
| 97 |
).to(self.device)
|
| 98 |
|
| 99 |
# Make prediction
|
| 100 |
+
self.model.INTERPR_DISABLE_MODE()
|
| 101 |
with torch.no_grad():
|
| 102 |
prediction = self.model(target_inputs, drug_inputs)
|
| 103 |
|
|
|
|
| 144 |
max_length=512,
|
| 145 |
return_tensors="pt"
|
| 146 |
).to(self.device)
|
| 147 |
+
self.model.INTERPR_ENABLE_MODE()
|
| 148 |
|
| 149 |
# Make prediction and extract visualization data
|
| 150 |
with torch.no_grad():
|
|
|
|
| 165 |
b = self.model.model.b
|
| 166 |
scaler = self.model.model.scaler
|
| 167 |
|
| 168 |
+
logger.info(target_inputs, drug_inputs)
|
| 169 |
# Generate visualizations
|
| 170 |
try:
|
| 171 |
# 1. Cross-attention heatmap
|
modeling_dlmberta.py
CHANGED
|
@@ -52,6 +52,8 @@ class InteractionModelATTNForRegression(PreTrainedModel):
|
|
| 52 |
def INTERPR_ENABLE_MODE(self):
|
| 53 |
self.model.INTERPR_ENABLE_MODE()
|
| 54 |
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def INTERPR_OVERRIDE_ATTN(self, new_weights):
|
| 57 |
self.model.INTERPR_OVERRIDE_ATTN(new_weights)
|
|
@@ -295,6 +297,13 @@ class InteractionModelATTN(nn.Module):
|
|
| 295 |
raise RuntimeError("Cannot enable interpretability mode while the model is training.")
|
| 296 |
self.INTERPR_MODE = True
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
def INTERPR_OVERRIDE_ATTN(self, new_weights):
|
| 300 |
self.replace_weights = new_weights
|
|
|
|
| 52 |
def INTERPR_ENABLE_MODE(self):
|
| 53 |
self.model.INTERPR_ENABLE_MODE()
|
| 54 |
|
| 55 |
+
def INTERPR_DISABLE_MODE(self):
|
| 56 |
+
self.model.INTERPR_DISABLE_MODE()
|
| 57 |
|
| 58 |
def INTERPR_OVERRIDE_ATTN(self, new_weights):
|
| 59 |
self.model.INTERPR_OVERRIDE_ATTN(new_weights)
|
|
|
|
| 297 |
raise RuntimeError("Cannot enable interpretability mode while the model is training.")
|
| 298 |
self.INTERPR_MODE = True
|
| 299 |
|
| 300 |
+
def INTERPR_DISABLE_MODE(self):
|
| 301 |
+
"""
|
| 302 |
+
Disables the interpretability mode for the model.
|
| 303 |
+
"""
|
| 304 |
+
if self.training:
|
| 305 |
+
raise RuntimeError("Cannot disable interpretability mode while the model is training.")
|
| 306 |
+
self.INTERPR_MODE = False
|
| 307 |
|
| 308 |
def INTERPR_OVERRIDE_ATTN(self, new_weights):
|
| 309 |
self.replace_weights = new_weights
|