Spaces:
Runtime error
Runtime error
Update modeling_metalatte.py
Browse files- modeling_metalatte.py +21 -1
modeling_metalatte.py
CHANGED
|
@@ -14,7 +14,7 @@ import gc
|
|
| 14 |
from torch.optim.lr_scheduler import _LRScheduler
|
| 15 |
from transformers import EsmModel, PreTrainedModel
|
| 16 |
from configuration import MetaLATTEConfig
|
| 17 |
-
|
| 18 |
seed_everything(42)
|
| 19 |
|
| 20 |
class GELU(nn.Module):
|
|
@@ -218,6 +218,26 @@ class MultitaskProteinModel(PreTrainedModel):
|
|
| 218 |
|
| 219 |
# Initialize weights and apply final processing
|
| 220 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
def forward(self, input_ids, attention_mask=None):
|
| 223 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
|
|
|
| 14 |
from torch.optim.lr_scheduler import _LRScheduler
|
| 15 |
from transformers import EsmModel, PreTrainedModel
|
| 16 |
from configuration import MetaLATTEConfig
|
| 17 |
+
from urllib.parse import urljoin
|
| 18 |
seed_everything(42)
|
| 19 |
|
| 20 |
class GELU(nn.Module):
|
|
|
|
| 218 |
|
| 219 |
# Initialize weights and apply final processing
|
| 220 |
self.post_init()
|
| 221 |
+
|
| 222 |
+
@classmethod
|
| 223 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 224 |
+
config = kwargs.pop("config", None)
|
| 225 |
+
if config is None:
|
| 226 |
+
config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
|
| 227 |
+
|
| 228 |
+
model = cls(config)
|
| 229 |
+
#state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
|
| 230 |
+
try:
|
| 231 |
+
state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
|
| 232 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 233 |
+
state_dict_url,
|
| 234 |
+
map_location=torch.device('cpu')
|
| 235 |
+
)['state_dict']
|
| 236 |
+
model.load_state_dict(state_dict, strict=False)
|
| 237 |
+
except Exception as e:
|
| 238 |
+
raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
|
| 239 |
+
|
| 240 |
+
return model
|
| 241 |
|
| 242 |
def forward(self, input_ids, attention_mask=None):
|
| 243 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|