File size: 1,533 Bytes
b8c9192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""PyTorch datasets and dataloaders for AREDS fundus images."""

from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import pandas as pd
import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms


# Names of the label columns that may be selected as the prediction target.
# AREDSDataset upper-cases its ``task`` argument and validates it against
# this tuple.
TASKS = ("ADVAMD", "DRUS", "PIG")


# Preprocessing used by AREDSDataset when the caller supplies no transform:
# resize, center-crop to 224, convert to a tensor, then normalize per channel.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# channel statistics — confirm they match the backbone these images feed.
DEFAULT_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)


class AREDSDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a CSV index.

    Every CSV row is expected to provide a ``pathname`` column (the image
    path, joined onto ``image_root``) and a column named after the selected
    task whose value is convertible to ``int``.

    Args:
        csv_path: CSV file listing the images and their labels.
        image_root: Directory the ``pathname`` entries are joined onto.
        task: Label column to read; case-insensitive, must be in ``TASKS``.
        transform: Callable applied to the RGB PIL image. ``None`` selects
            ``DEFAULT_TRANSFORM``.

    Raises:
        ValueError: If ``task`` (upper-cased) is not one of ``TASKS``.
    """

    def __init__(
        self,
        csv_path: Union[str, Path],
        image_root: Union[str, Path],
        task: str,
        transform: Optional[Callable[[Image.Image], Tensor]] = None,
    ) -> None:
        super().__init__()
        task = task.upper()
        if task not in TASKS:
            raise ValueError(f"task must be one of {TASKS}")
        self.image_root = Path(image_root)
        self.task = task
        # Compare against None explicitly: a valid transform may be falsy
        # (e.g. an empty ``nn.Sequential``) and must not be swapped for the
        # default behind the caller's back.
        self.transform = DEFAULT_TRANSFORM if transform is None else transform
        self.data = pd.read_csv(csv_path)

    def __len__(self) -> int:
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load and transform the image at ``index`` and return it with its label.

        Returns:
            A tuple ``(image, label)`` where ``image`` is the output of the
            transform and ``label`` is a scalar ``torch.long`` tensor taken
            from the task column.
        """
        row = self.data.iloc[index]
        image_path = self.image_root / row.pathname
        # ``convert`` produces an independent in-memory image, so the file can
        # be closed right away instead of leaking a handle per sample until
        # garbage collection (which matters with many DataLoader workers).
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image)
        label = torch.tensor(int(row[self.task]), dtype=torch.long)
        return image, label