DeepNAPSI / DummyModel.py
Lukas Folle
Add compatibility function for huggingface_hub download
94e4f13
raw
history blame contribute delete
925 Bytes
import torch
import os
import torch.nn
from huggingface_hub import hf_hub_download
def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
    """Download ``filename`` from ``repo_id`` on the Hub and return its local path.

    Newer ``huggingface_hub`` releases take the access token as ``token=``;
    older ones only accept ``use_auth_token=``. This helper tries the new
    keyword first and falls back to the old one only when the installed
    library rejects ``token`` itself.

    Args:
        repo_id: Hub repository identifier, e.g. ``"user/repo"``.
        filename: Path of the file inside the repository.
        token: Access token forwarded to ``hf_hub_download``.

    Returns:
        The value returned by ``hf_hub_download`` (the local file path).

    Raises:
        TypeError: Re-raised unchanged when it is not caused by the
            ``token`` keyword being unsupported.
    """
    try:
        return hf_hub_download(repo_id, filename, token=token)
    except TypeError as err:
        # Only an "unexpected keyword argument 'token'" style error means we
        # are on an old release. Any other TypeError is a genuine failure and
        # must not be hidden behind a retry with a different keyword.
        if "'token'" not in str(err):
            raise
        # Backward compatibility for older huggingface_hub releases.
        return hf_hub_download(repo_id, filename, use_auth_token=token)
def load_dummy_model(DEBUG):
    """Create a ``DummyModel``, loading its weights from the Hugging Face Hub unless debugging.

    Args:
        DEBUG: When truthy, skip the download and return a freshly
            constructed ``DummyModel``.

    Returns:
        The ``DummyModel`` instance, with the downloaded state dict loaded
        when ``DEBUG`` is falsy.

    Raises:
        KeyError: If ``DEBUG`` is falsy and the ``DeepNAPSIModel``
            environment variable (passed as the Hub access token) is unset.
    """
    model = DummyModel()
    if not DEBUG:
        file_path = _hf_hub_download_compat(
            "lfolle/DeepNAPSIModel", "dummy_model.pth", os.environ["DeepNAPSIModel"]
        )
        # Map to CPU so a checkpoint saved from GPU tensors still loads on a
        # CPU-only host. load_state_dict copies the values into the model's
        # own parameters, so the final model is the same either way.
        # NOTE(review): torch.load unpickles the file. It comes from a fixed
        # remote repo, so it is only as trustworthy as that repo.
        state_dict = torch.load(file_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
class DummyModel(torch.nn.Module):
    """Stand-in model that returns random class probabilities.

    It has no learnable parameters. Calling an instance goes through the
    standard ``torch.nn.Module.__call__``, so registered forward hooks and
    pre-hooks run as for any other module.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: list):
        """Return random probabilities for each item in ``x``.

        Args:
            x: Sequence of inputs. Only its length is used.

        Returns:
            A tuple ``(probs, 0)``. ``probs`` is a ``(len(x), 5)`` tensor whose
            rows are softmax-normalised random scores, each summing to 1.
            NOTE(review): the constant ``0`` is presumably a placeholder for a
            second output the real model produces; confirm with callers.
        """
        scores = torch.rand(len(x), 5)
        return torch.softmax(scores, 1), 0