Spaces:
Sleeping
Sleeping
| """Run DeepSeeNet inference for AREDS simplified score.""" | |
| import argparse | |
| import json | |
| import torch | |
| from PIL import Image | |
| from dataloader import DEFAULT_TRANSFORM | |
| from model import DeepSeeNet | |
# Number of output classes per classification task. The value is looked up by
# task name and passed as ``n_classes`` when the task's DeepSeeNet is built.
N_CLASSES: dict[str, int] = {
    # Any nonzero ADVAMD prediction makes simplified_score return 5.
    "ADVAMD": 2,
    # simplified_score distinguishes DRUS grades 1 and 2 from grade 0.
    "DRUS": 3,
    # simplified_score counts a PIG prediction of exactly 1.
    "PIG": 2,
}
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for one left/right inference run.

    The two image paths and the three per-task checkpoint paths are
    mandatory; ``--backbone`` is optional and defaults to ``"inception_v3"``.

    Returns:
        The namespace ``argparse`` builds from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    required_flags = (
        "--left-image",
        "--right-image",
        "--advamd-checkpoint",
        "--drus-checkpoint",
        "--pig-checkpoint",
    )
    # Registered in this order so the generated usage/help text is stable.
    for flag in required_flags:
        cli.add_argument(flag, required=True)
    cli.add_argument("--backbone", default="inception_v3")
    return cli.parse_args()
def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
    """Rebuild the DeepSeeNet classifier for ``task`` from a saved checkpoint.

    Args:
        checkpoint_path: File handed to ``torch.load``.
        task: Key into ``N_CLASSES`` selecting the number of output classes.
        backbone: Fallback backbone name, used only when the checkpoint's
            ``"args"`` entry carries no ``"backbone"`` value.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model with the checkpoint's ``"model"`` state dict loaded and
        switched to evaluation mode.

    Raises:
        KeyError: If ``task`` is not in ``N_CLASSES`` or the checkpoint has
            no ``"model"`` entry.
    """
    # NOTE(review): torch.load deserializes via pickle (how restricted that is
    # depends on the installed torch version's ``weights_only`` default), so
    # only point this at checkpoints from a trusted source.
    saved = torch.load(checkpoint_path, map_location=device)

    # Assumes the saved "args" entry is dict-like (supports ``.get``) --
    # TODO confirm against the training script.
    saved_args = saved.get("args", {})

    num_classes = N_CLASSES[task]
    # The backbone recorded at training time wins over the CLI fallback.
    backbone_name = saved_args.get("backbone", backbone)

    network = DeepSeeNet(
        n_classes=num_classes,
        backbone=backbone_name,
        pretrained=False,
    )
    network = network.to(device)
    network.load_state_dict(saved["model"])
    network.eval()
    return network
def load_image(path: str, device) -> torch.Tensor:
    """Load an image file as a single-item batch tensor on ``device``.

    The image is converted to RGB, passed through ``DEFAULT_TRANSFORM``, given
    a leading dimension of size 1 via ``unsqueeze(0)``, and moved to
    ``device``.

    Args:
        path: Path of the image file to open.
        device: Target device for the returned tensor.

    Returns:
        The transformed image with one extra leading dimension. The remaining
        shape is whatever ``DEFAULT_TRANSFORM`` produces.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle deterministically. ``convert`` returns an
    # independent copy, so it stays valid after the source image is closed.
    with Image.open(path) as source:
        image = source.convert("RGB")
    return DEFAULT_TRANSFORM(image).unsqueeze(0).to(device)
def predict(model: torch.nn.Module, image: torch.Tensor) -> int:
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Callable classifier (e.g. a loaded DeepSeeNet) whose output has
            the class scores along dimension 1.
        image: Input batch holding exactly one image; ``.item()`` below fails
            for larger batches.

    Returns:
        The arg-max class index along dimension 1, as a plain ``int``.
    """
    # Inference only: disabling autograd avoids building a graph and keeping
    # intermediate activations alive for a backward pass that never happens.
    with torch.no_grad():
        logits = model(image)
    return int(logits.argmax(dim=1).item())
def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
    """Combine per-eye predictions into a single score from 0 to 5.

    Args:
        scores: Maps ``"ADVAMD"``, ``"DRUS"`` and ``"PIG"`` to a pair of
            predicted class indices; positions 0 and 1 are the two eyes.

    Returns:
        5 as soon as either ``"ADVAMD"`` entry is nonzero. Otherwise one point
        for each ``"PIG"`` entry equal to 1, one point for each ``"DRUS"``
        entry equal to 2, and one point when both ``"DRUS"`` entries equal 1,
        capped at 5.
    """
    advanced = scores["ADVAMD"]
    if advanced[0] or advanced[1]:
        # Short-circuit: the other risk factors are not consulted at all.
        return 5

    pigment = scores["PIG"]
    points = int(pigment[0] == 1) + int(pigment[1] == 1)

    drusen = scores["DRUS"]
    first_drusen, second_drusen = drusen[0], drusen[1]
    points += int(first_drusen == 2) + int(second_drusen == 2)
    # Grade 1 only scores when it is present in both eyes.
    if first_drusen == 1 and second_drusen == 1:
        points += 1

    return min(points, 5)
def main() -> None:
    """Run all three task models on both images and print a JSON report.

    The report holds the combined ``"simplified_score"`` and the raw per-task
    predictions under ``"risk_factors"``, each as a (left, right) pair.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Both images are preprocessed once and reused for every task.
    left_eye = load_image(args.left_image, device)
    right_eye = load_image(args.right_image, device)

    task_checkpoints = (
        ("ADVAMD", args.advamd_checkpoint),
        ("DRUS", args.drus_checkpoint),
        ("PIG", args.pig_checkpoint),
    )

    risk_factors = {}
    for task, checkpoint_path in task_checkpoints:
        # One model at a time: each task has its own checkpoint.
        model = load_model(checkpoint_path, task, args.backbone, device)
        left_grade = predict(model, left_eye)
        right_grade = predict(model, right_eye)
        risk_factors[task] = (left_grade, right_grade)

    report = {
        "simplified_score": simplified_score(risk_factors),
        "risk_factors": risk_factors,
    }
    print(json.dumps(report, indent=2))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()