Spaces:
Running
Running
| """ | |
| Data Setup: Scan raw dataset folders, build manifest CSVs, stratified train/val/test split. | |
| Usage: | |
| python -m src.data_setup | |
| Expects dataset extracted to data/raw/ with structure: | |
| data/raw/ | |
| ├── Ajwa/ | |
| │ ├── img001.jpg | |
| │ └── ... | |
| ├── Galaxy/ | |
| ├── Medjool/ | |
| └── ... | |
| """ | |
import sys
from collections import Counter
from pathlib import Path

import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split

from src.utils import load_config, seed_everything

VALID_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
def scan_dataset(raw_dir: str) -> pd.DataFrame:
    """
    Scan the raw dataset directory and build a manifest DataFrame.

    Each variety is expected to be a subdirectory of ``raw_dir`` containing
    image files; nested subfolders are handled via a recursive walk. Every
    candidate image is opened and verified so corrupt files never reach
    training.

    Args:
        raw_dir: Path to the extracted dataset root (e.g. ``data/raw``).

    Returns:
        DataFrame with columns ``image_path``, ``variety``, ``label_idx``.
        Label indices follow the sorted order of the variety folder names,
        so the mapping is stable across runs.

    Exits the process with status 1 (after printing download instructions)
    when the directory is missing or contains no variety folders.
    """
    raw_path = Path(raw_dir)
    if not raw_path.exists():
        print(f"ERROR: Raw data directory not found: {raw_path.resolve()}")
        print("\nPlease download the dataset from Kaggle:")
        print("  https://www.kaggle.com/datasets/wadhasnalhamdan/date-fruit-image-dataset-in-controlled-environment")
        print(f"\nExtract the ZIP so variety folders are directly inside: {raw_path.resolve()}/")
        sys.exit(1)

    # Discover variety folders (hidden dirs such as .ipynb_checkpoints are skipped)
    variety_dirs = sorted(
        d for d in raw_path.iterdir()
        if d.is_dir() and not d.name.startswith(".")
    )
    if not variety_dirs:
        print(f"ERROR: No variety folders found in {raw_path.resolve()}")
        print("Expected folders like: Ajwa/, Galaxy/, Medjool/, etc.")
        sys.exit(1)

    # Build class-to-index mapping; sorted folder order keeps indices deterministic.
    class_names = [d.name for d in variety_dirs]
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    print(f"Found {len(class_names)} varieties: {class_names}")

    # Collect all image paths
    records = []
    skipped = 0
    for variety_dir in variety_dirs:
        variety = variety_dir.name
        # Walk recursively to handle nested structures
        for img_path in sorted(variety_dir.rglob("*")):
            # BUGFIX: rglob("*") yields directories too; a folder named e.g.
            # "thumbs.jpg" would previously be opened as an image, fail, and
            # be miscounted as a corrupted file. Skip non-files explicitly.
            if not img_path.is_file():
                continue
            if img_path.suffix.lower() not in VALID_EXTENSIONS:
                continue
            # Verify image is readable; verify() detects truncated/corrupt data.
            try:
                with Image.open(img_path) as img:
                    img.verify()
                records.append({
                    "image_path": str(img_path),
                    "variety": variety,
                    "label_idx": class_to_idx[variety],
                })
            except Exception:
                # Best-effort: unreadable files are counted and reported below.
                skipped += 1

    if skipped > 0:
        print(f"Warning: Skipped {skipped} corrupted/unreadable images")

    df = pd.DataFrame(records)
    print(f"Total valid images: {len(df)}")
    return df
def stratified_split(
    df: pd.DataFrame,
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split dataset into train/val/test with stratification by variety.

    Args:
        df: Manifest DataFrame; must contain a ``variety`` column.
        train_ratio: Fraction of rows assigned to the training split.
        val_ratio: Fraction assigned to the validation split.
        test_ratio: Fraction assigned to the test split.
        seed: Random state used for both splits, for reproducibility.

    Returns:
        ``(train_df, val_df, test_df)`` with per-variety proportions
        preserved in each split.

    Raises:
        ValueError: If the three ratios do not sum to 1.0.
    """
    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would silently accept inconsistent ratios.
    if abs(train_ratio + val_ratio + test_ratio - 1.0) >= 1e-6:
        raise ValueError("Split ratios must sum to 1.0")

    # First split: train vs (val + test)
    train_df, temp_df = train_test_split(
        df,
        test_size=(val_ratio + test_ratio),
        stratify=df["variety"],
        random_state=seed,
    )

    # Second split: val vs test. The ratio is re-expressed relative to the
    # size of temp_df (e.g. 0.15/0.30 -> 0.5 of the held-out pool).
    relative_test_ratio = test_ratio / (val_ratio + test_ratio)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=relative_test_ratio,
        stratify=temp_df["variety"],
        random_state=seed,
    )
    return train_df, val_df, test_df
def print_split_summary(train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: pd.DataFrame) -> None:
    """
    Print detailed split statistics to stdout.

    Shows overall split sizes with percentage shares, followed by a
    per-variety breakdown table across the three splits.

    Args:
        train_df, val_df, test_df: Split manifests; each must contain a
            ``variety`` column.
    """
    total = len(train_df) + len(val_df) + len(test_df)

    def _share(n: int) -> str:
        # Guard against an empty dataset to avoid ZeroDivisionError.
        return f"{n / total:.0%}" if total else "n/a"

    print("\n" + "=" * 60)
    print("DATASET SPLIT SUMMARY")
    print("=" * 60)
    # BUGFIX: the original f-string read `len(train)/len(train)+len(val)+len(test)`,
    # which by operator precedence computed `1.0 + len(val) + len(test)` instead of
    # the train share, and also printed a stray "... )" literal.
    print(f"  Train: {len(train_df):>5} images ({_share(len(train_df))})")
    print(f"  Val:   {len(val_df):>5} images ({_share(len(val_df))})")
    print(f"  Test:  {len(test_df):>5} images ({_share(len(test_df))})")
    print(f"  Total: {total:>5} images")

    print(f"\n{'Variety':<15} {'Train':>6} {'Val':>6} {'Test':>6} {'Total':>6}")
    print("-" * 45)
    # Count once per split (O(n)) instead of filtering the frame per variety.
    train_counts = train_df["variety"].value_counts()
    val_counts = val_df["variety"].value_counts()
    test_counts = test_df["variety"].value_counts()
    # Union of varieties so a class absent from train still shows up.
    all_varieties = sorted(
        set(train_df["variety"]) | set(val_df["variety"]) | set(test_df["variety"])
    )
    for variety in all_varieties:
        n_train = int(train_counts.get(variety, 0))
        n_val = int(val_counts.get(variety, 0))
        n_test = int(test_counts.get(variety, 0))
        n_total = n_train + n_val + n_test
        print(f"  {variety:<13} {n_train:>6} {n_val:>6} {n_test:>6} {n_total:>6}")
    print("=" * 60)
def main():
    """Entry point: scan raw images, split them, and write manifest CSVs."""
    config = load_config()
    seed_everything(config["data"]["seed"])

    # Step 1: build the manifest from the raw image folders.
    print("Step 1: Scanning dataset...")
    manifest = scan_dataset(config["data"]["raw_dir"])

    # Step 2: stratified train/val/test split with the configured ratios.
    print("\nStep 2: Splitting dataset...")
    ratios = config["data"]["splits"]
    train_df, val_df, test_df = stratified_split(
        manifest,
        train_ratio=ratios[0],
        val_ratio=ratios[1],
        test_ratio=ratios[2],
        seed=config["data"]["seed"],
    )

    # Step 3: persist one CSV per split under data/.
    out_dir = Path("data")
    out_dir.mkdir(parents=True, exist_ok=True)
    for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
        split_df.to_csv(out_dir / f"{split_name}.csv", index=False)
    print(f"\nSaved: data/train.csv ({len(train_df)} rows)")
    print(f"Saved: data/val.csv ({len(val_df)} rows)")
    print(f"Saved: data/test.csv ({len(test_df)} rows)")

    # Step 4: human-readable summary plus the class-index mapping
    # (sorted variety names, matching the ordering used by scan_dataset).
    print_split_summary(train_df, val_df, test_df)
    class_map = {
        name: idx for idx, name in enumerate(sorted(manifest["variety"].unique()))
    }
    print(f"\nClass mapping: {class_map}")
    print("\nData setup complete. Ready for training.")


if __name__ == "__main__":
    main()