Spaces:
Sleeping
Sleeping
File size: 2,726 Bytes
b8c9192 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | """Run DeepSeeNet inference for AREDS simplified score."""
import argparse
import json
import torch
from PIL import Image
from dataloader import DEFAULT_TRANSFORM
from model import DeepSeeNet
# Output-class count for each per-task classifier. load_model() passes the
# value for a task as ``n_classes`` when constructing DeepSeeNet.
# simplified_score() reads the predicted labels as follows:
#   ADVAMD: non-zero -> advanced AMD present
#   PIG:    1 -> pigment finding present
#   DRUS:   1 and 2 are two graded drusen levels
#           (presumably intermediate / large -- confirm against training labels)
N_CLASSES = dict(
    ADVAMD=2,
    DRUS=3,
    PIG=2,
)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for a single inference run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which preserves the original CLI behaviour.
            Passing an explicit list lets callers and tests drive the parser
            without touching ``sys.argv``.

    Returns:
        Namespace with the attributes ``left_image``, ``right_image``,
        ``advamd_checkpoint``, ``drus_checkpoint``, ``pig_checkpoint`` and
        ``backbone``.

    Raises:
        SystemExit: when a required option is missing or ``--help`` is given
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(
        description="Run DeepSeeNet inference for AREDS simplified score.",
    )
    parser.add_argument(
        "--left-image", required=True, help="Path to the left image."
    )
    parser.add_argument(
        "--right-image", required=True, help="Path to the right image."
    )
    parser.add_argument(
        "--advamd-checkpoint",
        required=True,
        help="Checkpoint file for the ADVAMD classifier.",
    )
    parser.add_argument(
        "--drus-checkpoint",
        required=True,
        help="Checkpoint file for the DRUS classifier.",
    )
    parser.add_argument(
        "--pig-checkpoint",
        required=True,
        help="Checkpoint file for the PIG classifier.",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        # load_model() prefers the backbone recorded in the checkpoint's
        # "args"; this value is only the fallback.
        help="Backbone used when a checkpoint does not record one "
        "(default: %(default)s).",
    )
    return parser.parse_args(argv)
def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
    """Construct the DeepSeeNet classifier for ``task`` and restore its weights.

    Args:
        checkpoint_path: File passed to ``torch.load``. It must contain a
            ``"model"`` entry holding the state dict, and may contain an
            ``"args"`` entry.
        task: Key into ``N_CLASSES`` that selects the number of output classes.
        backbone: Fallback backbone name, used only when the checkpoint's
            ``"args"`` does not provide ``"backbone"``.
        device: Target for ``map_location`` and for the constructed model.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        KeyError: if ``task`` is not in ``N_CLASSES`` or the checkpoint has no
            ``"model"`` entry.
    """
    # SECURITY NOTE(review): torch.load deserialises with pickle (unless the
    # installed torch version restricts it via weights_only); only pass
    # trusted checkpoint files.
    state = torch.load(checkpoint_path, map_location=device)
    # NOTE(review): assumes "args" is dict-like (has .get); an
    # argparse.Namespace stored here would raise AttributeError -- confirm
    # against the training script.
    saved_args = state.get("args", {})
    class_count = N_CLASSES[task]
    arch = saved_args.get("backbone", backbone)

    network = DeepSeeNet(
        n_classes=class_count,
        backbone=arch,
        # Weights come from the checkpoint, so no pretrained download is needed.
        pretrained=False,
    )
    network = network.to(device)
    network.load_state_dict(state["model"])
    network.eval()
    return network
def load_image(path: str, device) -> torch.Tensor:
    """Read an image file and return it as a one-sample batch on ``device``.

    Args:
        path: Filesystem path of the image.
        device: Device that receives the returned tensor.

    Returns:
        ``DEFAULT_TRANSFORM`` applied to the RGB-converted image, with a leading
        batch axis of size 1 added (``unsqueeze(0)``), moved to ``device``.
    """
    # Image.open is lazy and holds the file handle. The context manager closes
    # it deterministically; this matters for multi-frame formats, which Pillow
    # does not auto-close after loading. convert() returns a new, fully loaded
    # image, so it stays usable after the source is closed.
    with Image.open(path) as source:
        image = source.convert("RGB")
    return DEFAULT_TRANSFORM(image).unsqueeze(0).to(device)
@torch.no_grad()
def predict(model: DeepSeeNet, image: torch.Tensor) -> int:
    """Return the arg-max class index that ``model`` outputs for ``image``.

    Gradient tracking is disabled via ``torch.no_grad``. ``.item()`` requires
    the arg-max result to hold exactly one element, so ``image`` must be a
    single-sample batch (as produced by ``load_image``).
    """
    logits = model(image)
    # dim=1 is presumably the class axis of a (batch, n_classes) output -- confirm.
    best_class = logits.argmax(dim=1)
    return int(best_class.item())
def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
    """Combine per-eye risk-factor labels into a 0-5 simplified severity score.

    Args:
        scores: Maps ``"ADVAMD"``, ``"DRUS"`` and ``"PIG"`` to a pair of
            predicted labels indexed 0 and 1. ``main()`` builds each pair as
            (left, right).

    Returns:
        5 if either ADVAMD label is non-zero. In that case the other keys are
        never read. Otherwise the sum of:

        * one point per eye whose PIG label is 1,
        * one point per eye whose DRUS label is 2,
        * one point if both DRUS labels are 1.

        The result is capped at 5. Without ADVAMD the reachable maximum is 4,
        because "both DRUS == 1" excludes any DRUS == 2.
    """
    eyes = (0, 1)

    advamd = scores["ADVAMD"]
    if any(advamd[eye] for eye in eyes):
        return 5

    pigment = scores["PIG"]
    drusen = scores["DRUS"]
    points = sum(pigment[eye] == 1 for eye in eyes)
    points += sum(drusen[eye] == 2 for eye in eyes)
    if drusen[0] == 1 and drusen[1] == 1:
        # Bilateral level-1 drusen earn a single extra point.
        points += 1
    return min(points, 5)
def main() -> None:
    """Score one pair of images and print the result as JSON.

    Steps:

    1. Parse the CLI options.
    2. Pick CUDA when available, otherwise CPU.
    3. Load the left and right images.
    4. For each task (ADVAMD, DRUS, PIG, in that order), load its checkpoint
       and predict a label for each image.
    5. Print a JSON object with ``"simplified_score"`` and ``"risk_factors"``.
       The risk factors map each task to its [left, right] labels.
    """
    cli = parse_args()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    left = load_image(cli.left_image, device)
    right = load_image(cli.right_image, device)

    task_checkpoints = (
        ("ADVAMD", cli.advamd_checkpoint),
        ("DRUS", cli.drus_checkpoint),
        ("PIG", cli.pig_checkpoint),
    )
    risk_factors: dict[str, tuple[int, int]] = {}
    for task, path in task_checkpoints:
        # One model at a time: rebinding `net` releases the previous model.
        net = load_model(path, task, cli.backbone, device)
        risk_factors[task] = (predict(net, left), predict(net, right))

    report = {
        "simplified_score": simplified_score(risk_factors),
        "risk_factors": risk_factors,
    }
    print(json.dumps(report, indent=2))
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|