File size: 1,157 Bytes
ec42e29
b56dbfa
761b08f
 
 
94e4f13
 
 
 
 
 
 
 
761b08f
ec42e29
 
 
 
 
70583d3
94e4f13
 
70583d3
 
 
761b08f
 
 
 
 
ec42e29
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import inspect
import os

from huggingface_hub import hf_hub_download
from nail_classification.inference import Inference


def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
    """Download ``filename`` from the Hub repo ``repo_id`` and return its local path.

    Newer ``huggingface_hub`` releases take the auth token as ``token=``; older
    releases only understand ``use_auth_token=``. The keyword is chosen from the
    signature of ``hf_hub_download`` rather than by retrying after a
    ``TypeError``. The old approach also retried when an unrelated ``TypeError``
    came from inside the download, which hid the real error behind a confusing
    "unexpected keyword 'use_auth_token'" failure.

    Args:
        repo_id: Hub repository id, e.g. ``"owner/name"``.
        filename: File inside the repository to download.
        token: Access token forwarded to ``hf_hub_download``.

    Returns:
        Whatever ``hf_hub_download`` returns (the path of the downloaded file).
    """
    try:
        params = inspect.signature(hf_hub_download).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: assume the modern keyword.
        params = {}
    if "token" in params or "use_auth_token" not in params:
        return hf_hub_download(repo_id, filename, token=token)
    # Backward compatibility for older huggingface_hub releases that only
    # accept ``use_auth_token``.
    return hf_hub_download(repo_id, filename, use_auth_token=token)


class Model:
    """Callable wrapper around ``Inference`` built from several DeepNAPSI checkpoints.

    The checkpoints (``version_10`` … ``version_14``) come from one of two places.
    When ``DEBUG`` is truthy they are read from a hard-coded local directory.
    Otherwise they are downloaded from the ``lfolle/DeepNAPSIModel`` Hugging Face
    Hub repository.
    """

    # Checkpoint versions shared by the local and the Hub code paths, so the
    # two branches cannot drift apart.
    _CHECKPOINT_VERSIONS = (10, 11, 12, 13, 14)
    _LOCAL_CHECKPOINT_DIR = r"C:\Users\follels\Documents\hand-ki-model-weights\DeepNAPSIModel\inference_checkpoints_v1"
    _HUB_REPO_ID = "lfolle/DeepNAPSIModel"
    # Environment variable holding the Hub access token for the download path.
    _TOKEN_ENV_VAR = "DeepNAPSIModel"

    def __init__(self, DEBUG):
        """Resolve the checkpoint paths and build the ``Inference`` object.

        Args:
            DEBUG: If truthy, use the local checkpoint directory. Otherwise
                download the checkpoints from the Hub.

        Raises:
            KeyError: If ``DEBUG`` is falsy and the token environment variable
                is not set.
        """
        if DEBUG:
            # NOTE(review): local paths carry no ".ckpt" suffix while the Hub
            # filenames do -- confirm this matches the local directory layout.
            file_paths = [
                os.path.join(self._LOCAL_CHECKPOINT_DIR, f"version_{v}")
                for v in self._CHECKPOINT_VERSIONS
            ]
        else:
            # Read the token once, up front, so a missing variable fails fast
            # with an actionable message instead of a bare KeyError.
            try:
                token = os.environ[self._TOKEN_ENV_VAR]
            except KeyError:
                raise KeyError(
                    f"Environment variable {self._TOKEN_ENV_VAR!r} must contain a "
                    f"Hugging Face access token for {self._HUB_REPO_ID!r}."
                ) from None
            file_paths = [
                _hf_hub_download_compat(self._HUB_REPO_ID, f"version_{v}.ckpt", token)
                for v in self._CHECKPOINT_VERSIONS
            ]
        self.inference = Inference(file_paths)

    def predict(self, x):
        """Run ``self.inference.predict`` on ``x``.

        Returns:
            A ``(y_hat, uncertainty)`` tuple. The unpacking below requires the
            underlying ``predict`` to yield exactly two values.
        """
        y_hat, uncertainty = self.inference.predict(x)
        return y_hat, uncertainty

    def __call__(self, x):
        """Alias for :meth:`predict`."""
        return self.predict(x)