# saudi-date-classifier / src/data_setup.py
# Author: Rashidbm — "Initial deployment" (commit 6276d4c)
"""
Data Setup: Scan raw dataset folders, build manifest CSVs, stratified train/val/test split.
Usage:
python -m src.data_setup
Expects dataset extracted to data/raw/ with structure:
data/raw/
├── Ajwa/
│ ├── img001.jpg
│ └── ...
├── Galaxy/
├── Medjool/
└── ...
"""
import sys
from pathlib import Path
from collections import Counter
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image
from src.utils import load_config, seed_everything
# File suffixes accepted as images during the raw-dataset scan; compared
# against Path.suffix.lower() in scan_dataset, so the match is case-insensitive.
VALID_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
def scan_dataset(raw_dir: str) -> pd.DataFrame:
    """
    Walk the raw dataset directory and assemble a manifest DataFrame.

    Each row describes one readable image: its file path, the variety name
    (taken from the folder it lives in), and an integer class index derived
    from the sorted folder names. Unreadable/corrupted files are counted and
    skipped. Exits the process with status 1 when the directory is missing
    or contains no variety folders.
    """
    root = Path(raw_dir)
    if not root.exists():
        print(f"ERROR: Raw data directory not found: {root.resolve()}")
        print("\nPlease download the dataset from Kaggle:")
        print(" https://www.kaggle.com/datasets/wadhasnalhamdan/date-fruit-image-dataset-in-controlled-environment")
        print(f"\nExtract the ZIP so variety folders are directly inside: {root.resolve()}/")
        sys.exit(1)

    # Each immediate, non-hidden subdirectory is treated as one variety.
    variety_dirs = sorted(
        d for d in root.iterdir() if d.is_dir() and not d.name.startswith(".")
    )
    if not variety_dirs:
        print(f"ERROR: No variety folders found in {root.resolve()}")
        print("Expected folders like: Ajwa/, Galaxy/, Medjool/, etc.")
        sys.exit(1)

    # Sorted folder order fixes the class index assignment deterministically.
    class_names = [d.name for d in variety_dirs]
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    print(f"Found {len(class_names)} varieties: {class_names}")

    records = []
    skipped = 0
    for variety_dir in variety_dirs:
        variety = variety_dir.name
        # rglob handles datasets where images sit in nested subfolders.
        for candidate in sorted(variety_dir.rglob("*")):
            if candidate.suffix.lower() not in VALID_EXTENSIONS:
                continue
            try:
                # verify() detects truncated/corrupted files without a full decode.
                with Image.open(candidate) as img:
                    img.verify()
            except Exception:
                skipped += 1
            else:
                records.append({
                    "image_path": str(candidate),
                    "variety": variety,
                    "label_idx": class_to_idx[variety],
                })

    if skipped:
        print(f"Warning: Skipped {skipped} corrupted/unreadable images")
    manifest = pd.DataFrame(records)
    print(f"Total valid images: {len(manifest)}")
    return manifest
def stratified_split(
    df: pd.DataFrame,
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split a manifest DataFrame into train/val/test subsets, stratified by
    the "variety" column so every split keeps the same class balance.

    Args:
        df: Manifest with at least a "variety" column.
        train_ratio: Fraction of rows for training.
        val_ratio: Fraction of rows for validation.
        test_ratio: Fraction of rows for testing.
        seed: Random state passed to sklearn for reproducible shuffling.

    Returns:
        (train_df, val_df, test_df) DataFrames.

    Raises:
        ValueError: If the ratios do not sum to 1.0 (within 1e-6) or any
            ratio is not strictly positive.
    """
    # Validate with real exceptions: `assert` is stripped under `python -O`,
    # which would silently disable these checks in optimized runs.
    if abs(train_ratio + val_ratio + test_ratio - 1.0) >= 1e-6:
        raise ValueError("Split ratios must sum to 1.0")
    if min(train_ratio, val_ratio, test_ratio) <= 0:
        # Also guards the division by (val_ratio + test_ratio) below.
        raise ValueError("Split ratios must all be positive")

    # First cut: carve off train; the holdout contains both val and test.
    train_df, holdout_df = train_test_split(
        df,
        test_size=(val_ratio + test_ratio),
        stratify=df["variety"],
        random_state=seed,
    )
    # Second cut: split the holdout. The test fraction must be re-expressed
    # relative to the holdout size, not the full dataset.
    relative_test_ratio = test_ratio / (val_ratio + test_ratio)
    val_df, test_df = train_test_split(
        holdout_df,
        test_size=relative_test_ratio,
        stratify=holdout_df["variety"],
        random_state=seed,
    )
    return train_df, val_df, test_df
def print_split_summary(train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: pd.DataFrame) -> None:
    """
    Print overall counts (with percentages) and a per-variety breakdown
    of the train/val/test splits. Output only; returns None.
    """
    total = len(train_df) + len(val_df) + len(test_df)
    print("\n" + "=" * 60)
    print("DATASET SPLIT SUMMARY")
    print("=" * 60)
    # BUG FIX: the original computed len(train)/len(train) + len(val) + len(test)
    # (missing parentheses around the denominator), printing a nonsense
    # percentage. Parenthesize via `total` and guard against an empty dataset.
    if total > 0:
        print(f" Train: {len(train_df):>5} images ({len(train_df) / total:.0%})")
        print(f" Val:   {len(val_df):>5} images ({len(val_df) / total:.0%})")
        print(f" Test:  {len(test_df):>5} images ({len(test_df) / total:.0%})")
    else:
        print(" (no images in any split)")
    print(f" Total: {total:>5} images")
    print(f"\n{'Variety':<15} {'Train':>6} {'Val':>6} {'Test':>6} {'Total':>6}")
    print("-" * 45)
    # Union across all three splits so a variety absent from train (possible
    # with tiny classes) is still reported; the original used train_df only.
    all_varieties = sorted(
        set(train_df["variety"]) | set(val_df["variety"]) | set(test_df["variety"])
    )
    for variety in all_varieties:
        n_train = len(train_df[train_df["variety"] == variety])
        n_val = len(val_df[val_df["variety"] == variety])
        n_test = len(test_df[test_df["variety"] == variety])
        n_total = n_train + n_val + n_test
        print(f" {variety:<13} {n_train:>6} {n_val:>6} {n_test:>6} {n_total:>6}")
    print("=" * 60)
def main():
    """Run the full data-setup pipeline: scan, split, save CSVs, summarize."""
    cfg = load_config()
    seed_everything(cfg["data"]["seed"])

    # Step 1: build the manifest from the raw image folders.
    print("Step 1: Scanning dataset...")
    manifest = scan_dataset(cfg["data"]["raw_dir"])

    # Step 2: stratified train/val/test split using configured ratios.
    print("\nStep 2: Splitting dataset...")
    ratios = cfg["data"]["splits"]
    train_df, val_df, test_df = stratified_split(
        manifest,
        train_ratio=ratios[0],
        val_ratio=ratios[1],
        test_ratio=ratios[2],
        seed=cfg["data"]["seed"],
    )

    # Step 3: persist one manifest CSV per split under data/.
    out_dir = Path("data")
    out_dir.mkdir(parents=True, exist_ok=True)
    train_df.to_csv(out_dir / "train.csv", index=False)
    val_df.to_csv(out_dir / "val.csv", index=False)
    test_df.to_csv(out_dir / "test.csv", index=False)
    print(f"\nSaved: data/train.csv ({len(train_df)} rows)")
    print(f"Saved: data/val.csv ({len(val_df)} rows)")
    print(f"Saved: data/test.csv ({len(test_df)} rows)")

    # Step 4: per-split statistics for a quick sanity check.
    print_split_summary(train_df, val_df, test_df)

    # Echo the name -> index mapping (sorted names, matching scan_dataset).
    class_map = {
        name: idx
        for idx, name in enumerate(sorted(manifest["variety"].unique()))
    }
    print(f"\nClass mapping: {class_map}")
    print("\nData setup complete. Ready for training.")
if __name__ == "__main__":
    main()