"""Placeholder DeepNAPSI model plus a loader that can fetch its weights.

``DummyModel`` produces random class probabilities, which lets the surrounding
pipeline be exercised without a trained network.
"""
import os

import torch
import torch.nn
from huggingface_hub import hf_hub_download


def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
    """Download ``filename`` from the hub repo ``repo_id`` and return its local path.

    The auth token is first passed as ``token=``; if the installed
    ``huggingface_hub`` rejects that call with a ``TypeError``, the download is
    retried with the older ``use_auth_token=`` keyword.
    """
    try:
        return hf_hub_download(repo_id, filename, token=token)
    except TypeError:
        # Backward compatibility for older huggingface_hub releases that only
        # know ``use_auth_token``. Any other TypeError from the first call also
        # lands here; if the retry fails as well, Python reports both errors.
        return hf_hub_download(repo_id, filename, use_auth_token=token)


def load_dummy_model(DEBUG):
    """Build a ``DummyModel`` and, unless ``DEBUG`` is truthy, load hub weights.

    Args:
        DEBUG: When truthy, skip the download and return a freshly built model.

    Returns:
        The ``DummyModel`` instance.

    Raises:
        KeyError: If weights are requested and the ``DeepNAPSIModel``
            environment variable (used as the Hugging Face token) is unset.
    """
    model = DummyModel()
    if not DEBUG:
        file_path = _hf_hub_download_compat(
            "lfolle/DeepNAPSIModel", "dummy_model.pth", os.environ["DeepNAPSIModel"]
        )
        # map_location="cpu": the model above is built on the CPU, and without
        # it a checkpoint saved from CUDA tensors cannot be deserialized on a
        # machine without a GPU.
        # NOTE(review): torch.load unpickles the file and the file comes from a
        # remote repo; consider weights_only=True where the installed torch
        # version supports it.
        model.load_state_dict(torch.load(file_path, map_location="cpu"))
    return model


class DummyModel(torch.nn.Module):
    """Stand-in model whose predictions are random.

    It defines no parameters. Calls go through the inherited
    ``torch.nn.Module.__call__``, so registered forward hooks run as they
    would for any other module.
    """

    def forward(self, x: list):
        """Return random class probabilities for each item in ``x``.

        Only ``len(x)`` is used; the contents of ``x`` are ignored.

        Returns:
            A tuple ``(probs, 0)``. ``probs`` is a ``(len(x), 5)`` float tensor
            whose rows are softmax-normalised random values, so each row sums
            to 1. The constant ``0`` is presumably a placeholder for a second
            output of the real model — TODO confirm against callers.
        """
        return torch.softmax(torch.rand(len(x), 5), 1), 0