IlPakoZ commited on
Commit
13d264b
·
verified ·
1 Parent(s): eaf5193

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. modeling_dlmberta.py +9 -0
app.py CHANGED
@@ -97,6 +97,7 @@ class DrugTargetInteractionApp:
97
  ).to(self.device)
98
 
99
  # Make prediction
 
100
  with torch.no_grad():
101
  prediction = self.model(target_inputs, drug_inputs)
102
 
@@ -143,6 +144,7 @@ class DrugTargetInteractionApp:
143
  max_length=512,
144
  return_tensors="pt"
145
  ).to(self.device)
 
146
 
147
  # Make prediction and extract visualization data
148
  with torch.no_grad():
@@ -163,6 +165,7 @@ class DrugTargetInteractionApp:
163
  b = self.model.model.b
164
  scaler = self.model.model.scaler
165
 
 
166
  # Generate visualizations
167
  try:
168
  # 1. Cross-attention heatmap
 
97
  ).to(self.device)
98
 
99
  # Make prediction
100
+ self.model.INTERPR_DISABLE_MODE()
101
  with torch.no_grad():
102
  prediction = self.model(target_inputs, drug_inputs)
103
 
 
144
  max_length=512,
145
  return_tensors="pt"
146
  ).to(self.device)
147
+ self.model.INTERPR_ENABLE_MODE()
148
 
149
  # Make prediction and extract visualization data
150
  with torch.no_grad():
 
165
  b = self.model.model.b
166
  scaler = self.model.model.scaler
167
 
168
+ logger.info(target_inputs, drug_inputs)
169
  # Generate visualizations
170
  try:
171
  # 1. Cross-attention heatmap
modeling_dlmberta.py CHANGED
@@ -52,6 +52,8 @@ class InteractionModelATTNForRegression(PreTrainedModel):
52
  def INTERPR_ENABLE_MODE(self):
53
  self.model.INTERPR_ENABLE_MODE()
54
 
 
 
55
 
56
  def INTERPR_OVERRIDE_ATTN(self, new_weights):
57
  self.model.INTERPR_OVERRIDE_ATTN(new_weights)
@@ -295,6 +297,13 @@ class InteractionModelATTN(nn.Module):
295
  raise RuntimeError("Cannot enable interpretability mode while the model is training.")
296
  self.INTERPR_MODE = True
297
 
 
 
 
 
 
 
 
298
 
299
  def INTERPR_OVERRIDE_ATTN(self, new_weights):
300
  self.replace_weights = new_weights
 
52
  def INTERPR_ENABLE_MODE(self):
53
  self.model.INTERPR_ENABLE_MODE()
54
 
55
+ def INTERPR_DISABLE_MODE(self):
56
+ self.model.INTERPR_DISABLE_MODE()
57
 
58
  def INTERPR_OVERRIDE_ATTN(self, new_weights):
59
  self.model.INTERPR_OVERRIDE_ATTN(new_weights)
 
297
  raise RuntimeError("Cannot enable interpretability mode while the model is training.")
298
  self.INTERPR_MODE = True
299
 
300
+ def INTERPR_DISABLE_MODE(self):
301
+ """
302
+ Disables the interpretability mode for the model.
303
+ """
304
+ if self.training:
305
+ raise RuntimeError("Cannot disable interpretability mode while the model is training.")
306
+ self.INTERPR_MODE = False
307
 
308
  def INTERPR_OVERRIDE_ATTN(self, new_weights):
309
  self.replace_weights = new_weights