File size: 2,010 Bytes
b8c9192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
augmentations.py

Simple camera-style augmentations for color fundus photography (CFP)
classification.

Expected input:
    RGB NumPy image, shape (H, W, 3)

Dependencies:
    pip install albumentations opencv-python torch
"""

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2


# Default per-channel normalization statistics, used as the default `mean`
# and `std` arguments of the transform builders in this module and forwarded
# to A.Normalize. One value per channel of the 3-channel input image.
# NOTE(review): the names indicate these are the usual ImageNet statistics;
# confirm they match the pretrained backbone in use.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def get_train_transforms(
    image_size=1024,
    mean=IMAGENET_MEAN,
    std=IMAGENET_STD,
):
    """
    Training transforms: resize, mild augmentation, normalize, to tensor.

    The augmentations are intentionally weak. Hue shifts are left out
    on purpose because they are not realistic for fundus physiology.

    Args:
        image_size: Value passed as both height and width to A.Resize.
        mean: Per-channel mean forwarded to A.Normalize.
        std: Per-channel standard deviation forwarded to A.Normalize.

    Returns:
        An A.Compose pipeline whose last step is ToTensorV2.
    """
    to_square = A.Resize(image_size, image_size)

    # --- Geometry: flip plus a small shift/scale/rotate jitter ---
    mirror = A.HorizontalFlip(p=0.5)
    jitter = A.ShiftScaleRotate(
        shift_limit=0.02,
        scale_limit=0.03,
        rotate_limit=5,
        # Constant border filled with 0 for pixels exposed by the warp.
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
        p=0.3,
    )

    # --- Photometry: narrow brightness/contrast and gamma ranges ---
    exposure = A.RandomBrightnessContrast(
        brightness_limit=0.08,
        contrast_limit=0.08,
        p=0.3,
    )
    gamma = A.RandomGamma(gamma_limit=(95, 105), p=0.2)

    # --- Image quality: at most one mild degradation per sample ---
    degradations = [
        A.GaussianBlur(blur_limit=(3, 5)),
        A.Downscale(
            scale_min=0.85,
            scale_max=0.95,
            interpolation=cv2.INTER_LINEAR,
        ),
        A.ImageCompression(quality_lower=80, quality_upper=100),
    ]
    quality = A.OneOf(degradations, p=0.15)

    # --- Output: normalize, then convert to a tensor ---
    normalize = A.Normalize(mean=mean, std=std)
    to_tensor = ToTensorV2()

    stages = [
        to_square,
        mirror,
        jitter,
        exposure,
        gamma,
        quality,
        normalize,
        to_tensor,
    ]
    return A.Compose(stages)


def get_val_transforms(
    image_size=1024,
    mean=IMAGENET_MEAN,
    std=IMAGENET_STD,
):
    """
    Validation/test transforms.

    Deterministic pipeline with no augmentation: resize, normalize,
    convert to tensor.

    Args:
        image_size: Value passed as both height and width to A.Resize.
        mean: Per-channel mean forwarded to A.Normalize.
        std: Per-channel standard deviation forwarded to A.Normalize.

    Returns:
        An A.Compose pipeline whose last step is ToTensorV2.
    """
    stages = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
    return A.Compose(stages)