File size: 2,726 Bytes
b8c9192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Run DeepSeeNet inference for AREDS simplified score."""

import argparse
import json

import torch
from PIL import Image

from dataloader import DEFAULT_TRANSFORM
from model import DeepSeeNet


# Size of each task's classification head; load_model passes the value to
# DeepSeeNet as ``n_classes``.
# NOTE(review): class meanings are not defined in this file; simplified_score
# treats ADVAMD != 0, PIG == 1, DRUS == 1 and DRUS == 2 as risk findings —
# confirm the label encoding against the training data.
N_CLASSES = dict(
    ADVAMD=2,
    DRUS=3,
    PIG=2,
)


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for a single inference run.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before this
            parameter existed; passing a list makes the parser testable.

    Returns:
        Namespace with ``left_image``, ``right_image``, ``advamd_checkpoint``,
        ``drus_checkpoint``, ``pig_checkpoint`` and ``backbone``.

    Raises:
        SystemExit: if a required option is missing or ``--help`` is given
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--left-image", required=True, help="Image of the left eye.")
    parser.add_argument("--right-image", required=True, help="Image of the right eye.")
    parser.add_argument(
        "--advamd-checkpoint",
        required=True,
        help="Checkpoint of the advanced-AMD (ADVAMD) classifier.",
    )
    parser.add_argument(
        "--drus-checkpoint",
        required=True,
        help="Checkpoint of the drusen (DRUS) classifier.",
    )
    parser.add_argument(
        "--pig-checkpoint",
        required=True,
        help="Checkpoint of the pigmentary-abnormality (PIG) classifier.",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help="Backbone used when a checkpoint does not record one (default: %(default)s).",
    )
    return parser.parse_args(argv)


def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
    """Build the DeepSeeNet classifier for *task* and load its trained weights.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``. It must hold
            the model ``state_dict`` under ``"model"`` and may hold the
            training options under ``"args"`` (a dict, an
            ``argparse.Namespace``, or ``None``).
        task: Key into ``N_CLASSES``; selects the size of the output head.
        backbone: Backbone name used only when the checkpoint's ``"args"``
            does not record one.
        device: Device the checkpoint tensors and the model are mapped to.

    Returns:
        The model moved to *device* and switched to eval mode.

    Raises:
        KeyError: if *task* is not in ``N_CLASSES`` or the checkpoint has no
            ``"model"`` entry.
    """
    # NOTE(review): torch.load unpickles the file. On torch versions where
    # weights_only defaults to False (< 2.6), a malicious checkpoint can run
    # arbitrary code — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # "args" may be absent, stored as None, or saved as an argparse.Namespace;
    # normalize to a mapping so the .get() below does not raise AttributeError.
    checkpoint_args = checkpoint.get("args") or {}
    if isinstance(checkpoint_args, argparse.Namespace):
        checkpoint_args = vars(checkpoint_args)

    model = DeepSeeNet(
        n_classes=N_CLASSES[task],
        # The training-time backbone wins over the CLI value: the saved
        # weights only fit the architecture they were trained with.
        backbone=checkpoint_args.get("backbone", backbone),
        # Weights come from the checkpoint, so skip any pretrained download.
        pretrained=False,
    ).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def load_image(path: str, device) -> torch.Tensor:
    """Load one eye's image as a single-item batch on *device*.

    Args:
        path: Path to an image file readable by Pillow.
        device: Device the returned tensor is moved to.

    Returns:
        ``DEFAULT_TRANSFORM`` applied to the RGB-converted image, with a
        leading batch dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once convert() has decoded the pixels into a new image.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    return DEFAULT_TRANSFORM(image).unsqueeze(0).to(device)


@torch.no_grad()
def predict(model: DeepSeeNet, image: torch.Tensor) -> int:
    """Return the index of the highest-scoring class for one input.

    Runs under ``torch.no_grad`` so no autograd graph is recorded.

    Args:
        model: Classifier called on *image*; its output is reduced with
            argmax over dim 1.
        image: Input batch, e.g. as returned by ``load_image``.

    Returns:
        The predicted class index as a plain Python ``int``.
    """
    outputs = model(image)
    best_class = torch.argmax(outputs, dim=1)
    # .item() only succeeds when exactly one prediction remains (batch of one).
    return int(best_class.item())


def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
    """Combine per-eye predictions into the AREDS simplified severity score.

    Args:
        scores: Maps "ADVAMD", "DRUS" and "PIG" to the predicted class for the
            (left, right) eye, in that order.

    Returns:
        5 if either eye has a truthy ADVAMD prediction. Otherwise one point per
        eye with PIG == 1, one point per eye with DRUS == 2, and one more point
        when both eyes have DRUS == 1; the total is capped at 5.
    """
    advanced = scores["ADVAMD"]
    if advanced[0] or advanced[1]:
        return 5

    pigment = scores["PIG"]
    drusen = scores["DRUS"]
    both_eyes = (0, 1)

    risk_factors = sum(pigment[eye] == 1 for eye in both_eyes)
    # Presumably DRUS == 2 is large and DRUS == 1 intermediate drusen (AREDS);
    # confirm against the training labels.
    risk_factors += sum(drusen[eye] == 2 for eye in both_eyes)
    if drusen[0] == 1 and drusen[1] == 1:
        risk_factors += 1
    return min(risk_factors, 5)


def main() -> None:
    """CLI entry point: score both eyes with the three task models.

    Loads the left and right images, runs the ADVAMD, DRUS and PIG models on
    each (one model loaded at a time), and prints a JSON report with the
    combined ``simplified_score`` and the per-task ``risk_factors`` as
    [left, right] predictions.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    left = load_image(args.left_image, device)
    right = load_image(args.right_image, device)

    task_checkpoints = (
        ("ADVAMD", args.advamd_checkpoint),
        ("DRUS", args.drus_checkpoint),
        ("PIG", args.pig_checkpoint),
    )

    scores = {}
    for task, checkpoint_path in task_checkpoints:
        model = load_model(checkpoint_path, task, args.backbone, device)
        scores[task] = (predict(model, left), predict(model, right))

    # json.dumps renders the (left, right) tuples as two-element lists.
    report = {
        "simplified_score": simplified_score(scores),
        "risk_factors": scores,
    }
    print(json.dumps(report, indent=2))


# Run the CLI only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()