"""Run DeepSeeNet inference for the AREDS simplified severity score.

One classifier per task (keys ``ADVAMD``, ``DRUS`` and ``PIG`` in
``N_CLASSES``) is applied to a left-eye and a right-eye image. The per-eye
predictions are combined by ``simplified_score`` into one score, which is
printed as JSON together with the raw per-eye predictions.
"""

import argparse
import json

import torch
from PIL import Image

from dataloader import DEFAULT_TRANSFORM
from model import DeepSeeNet

# Number of output classes of each task-specific classifier.
N_CLASSES = {
    "ADVAMD": 2,
    "DRUS": 3,
    "PIG": 2,
}


def parse_args() -> argparse.Namespace:
    """Parse command-line options: the two eye images and one checkpoint per task."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("--left-image", required=True, help="Path to the left-eye image.")
    parser.add_argument("--right-image", required=True, help="Path to the right-eye image.")
    parser.add_argument(
        "--advamd-checkpoint", required=True, help="Checkpoint of the ADVAMD classifier."
    )
    parser.add_argument(
        "--drus-checkpoint", required=True, help="Checkpoint of the DRUS classifier."
    )
    parser.add_argument(
        "--pig-checkpoint", required=True, help="Checkpoint of the PIG classifier."
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help="Backbone name, used only when a checkpoint does not record its own.",
    )
    return parser.parse_args()


def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
    """Build a DeepSeeNet classifier for *task* and load its trained weights.

    Args:
        checkpoint_path: Path to a checkpoint dict holding a ``"model"`` state
            dict and, optionally, an ``"args"`` entry with training options.
        task: A key of ``N_CLASSES``; selects the number of output classes.
        backbone: Fallback backbone name, used only when the checkpoint's
            ``"args"`` has no ``"backbone"`` entry.
        device: Device the checkpoint tensors and the model are moved to.

    Returns:
        The model with weights loaded, on *device*, in eval mode.

    Raises:
        KeyError: If *task* is unknown or the checkpoint lacks ``"model"``.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    saved_args = checkpoint.get("args") or {}
    # Training scripts may store the argparse Namespace itself rather than a dict.
    if isinstance(saved_args, argparse.Namespace):
        saved_args = vars(saved_args)
    model = DeepSeeNet(
        n_classes=N_CLASSES[task],
        # The recorded backbone takes precedence so the architecture matches the weights.
        backbone=saved_args.get("backbone", backbone),
        pretrained=False,
    ).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def load_image(path: str, device) -> torch.Tensor:
    """Load *path* as RGB, apply ``DEFAULT_TRANSFORM`` and add a leading batch axis.

    Returns:
        The transformed image with a batch dimension of size 1, on *device*.
    """
    # The context manager releases the file handle; convert() returns a loaded
    # copy, so it stays valid after the file is closed.
    with Image.open(path) as image:
        rgb = image.convert("RGB")
    return DEFAULT_TRANSFORM(rgb).unsqueeze(0).to(device)


@torch.no_grad()
def predict(model: DeepSeeNet, image: torch.Tensor) -> int:
    """Return the argmax class index of *model* for a single-image batch."""
    return int(model(image).argmax(dim=1).item())


def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
    """Combine per-eye predictions into the simplified severity score.

    Args:
        scores: Maps each task name to ``(left_class, right_class)``.

    Returns:
        5 if either eye is predicted positive for ``ADVAMD``. Otherwise the
        sum of one point per eye with ``PIG == 1``, one point per eye with
        ``DRUS == 2``, and one point if both eyes have ``DRUS == 1``, capped
        at 5.
    """
    if scores["ADVAMD"][0] or scores["ADVAMD"][1]:
        return 5
    score = 0
    score += scores["PIG"][0] == 1
    score += scores["PIG"][1] == 1
    score += scores["DRUS"][0] == 2
    score += scores["DRUS"][1] == 2
    # The bilateral DRUS == 1 point only applies when neither eye has DRUS == 2.
    score += scores["DRUS"][0] == 1 and scores["DRUS"][1] == 1
    return min(score, 5)


def main() -> None:
    """Run all task classifiers on both eyes and print the score as JSON."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    images = {
        "left": load_image(args.left_image, device),
        "right": load_image(args.right_image, device),
    }
    checkpoints = {
        "ADVAMD": args.advamd_checkpoint,
        "DRUS": args.drus_checkpoint,
        "PIG": args.pig_checkpoint,
    }
    scores = {}
    for task, checkpoint in checkpoints.items():
        model = load_model(checkpoint, task, args.backbone, device)
        scores[task] = (
            predict(model, images["left"]),
            predict(model, images["right"]),
        )
        # Drop the reference before the next task's model is built, so only
        # one model needs to be resident at a time.
        del model
    print(
        json.dumps(
            {
                "simplified_score": simplified_score(scores),
                "risk_factors": scores,
            },
            indent=2,
        )
    )


if __name__ == "__main__":
    main()