# saudi-date-classifier/src/ensemble.py
# Rashidbm
# Initial deployment
# 6276d4c
"""
Ensemble: Combines ResNet50, EfficientNet-B0, and ViT using soft voting.
Downloads checkpoints from HuggingFace, runs all three on the test set,
averages softmax probabilities, and reports the ensemble accuracy.
Usage:
python -m src.ensemble
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from transformers import ViTForImageClassification
from huggingface_hub import hf_hub_download
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
from src.dataset import DateFruitDataset, get_val_transforms
from src.utils import load_config, get_device, seed_everything
# Hugging Face Hub repository the checkpoint files are downloaded from.
HF_REPO_ID = "Rashidbm/saudi-date-classifier"

# Ensemble member name -> checkpoint filename inside HF_REPO_ID.
CHECKPOINTS = dict(
    resnet="arabic_dates_resnet50_best_V2.pth",
    efficientnet="efficientnet_best.pth",
    vit="vit_best_model.pth",
)

# Display names of the nine classes, used as `target_names` when printing the
# classification report.
# NOTE(review): presumably ordered by integer label index -- confirm against
# the label encoding used by the dataset.
CLASS_NAMES = [
    "Ajwa",
    "Galaxy",
    "Medjool",
    "Meneifi",
    "Nabtat Ali",
    "Rutab",
    "Shaishe",
    "Sokari",
    "Sugaey",
]
def build_resnet50(num_classes=9, dropout=0.3):
    """Build a ResNet50 with a dropout + linear classification head.

    The network is created with torchvision's default pretrained weights and
    its ``fc`` layer is replaced by ``Dropout(dropout)`` followed by a
    ``Linear`` layer producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new head.
        dropout: Dropout probability applied before the linear layer.

    Returns:
        The modified ResNet50 module.
    """
    net = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    # Read the feature width before the original head is discarded.
    feature_dim = net.fc.in_features
    head = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(feature_dim, num_classes),
    )
    net.fc = head
    return net
def build_efficientnet(num_classes=9, dropout=0.3):
    """Build an EfficientNet-B0 with a dropout + linear classification head.

    The network is created with torchvision's default pretrained weights and
    its ``classifier`` is replaced by ``Dropout(p=dropout)`` followed by a
    ``Linear`` layer producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new head.
        dropout: Dropout probability applied before the linear layer.

    Returns:
        The modified EfficientNet-B0 module.
    """
    net = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
    # Index 1 of the stock classifier is the linear layer whose input width
    # the replacement head must match.
    feature_dim = net.classifier[1].in_features
    head_layers = [
        nn.Dropout(p=dropout),
        nn.Linear(feature_dim, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
class PretrainedViTClassifier(nn.Module):
    """Thin wrapper around a Hugging Face ViT classifier that returns raw logits.

    The wrapped model is stored as ``self.backbone``; ``forward`` unwraps the
    ``.logits`` field of its output so the module yields a plain tensor.
    """

    def __init__(self, num_classes=9):
        """Instantiate the pretrained ViT with ``num_classes`` output labels."""
        super().__init__()
        pretrained_name = "google/vit-base-patch16-224-in21k"
        # ignore_mismatched_sizes=True is passed so loading does not fail when
        # the pretrained head size differs from `num_classes`.
        self.backbone = ViTForImageClassification.from_pretrained(
            pretrained_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Run the backbone on ``x`` and return only its logits."""
        outputs = self.backbone(x)
        return outputs.logits
def load_checkpoint(model, path, device):
    """Load weights from a checkpoint file into ``model`` and prepare it for inference.

    Three checkpoint layouts are accepted:

    * a dict holding the weights under ``"model_state_dict"``,
    * a dict holding the weights under ``"state_dict"``,
    * a raw state dict (parameter name -> tensor).

    Args:
        model: Module whose parameters are overwritten in place.
        path: Filesystem path of the checkpoint file.
        device: Device the tensors are mapped to on load and the model is
            moved to afterwards.

    Returns:
        The same ``model`` object, moved to ``device`` and switched to eval mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match ``model``.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code embedded in the file. Only use this with
    # checkpoints from a trusted source; switch to weights_only=True once the
    # checkpoints are confirmed to contain nothing but tensors and primitives.
    ckpt = torch.load(path, map_location=device, weights_only=False)

    state_dict = ckpt
    if isinstance(ckpt, dict):
        # Training scripts commonly wrap the weights next to metadata
        # (epoch, optimizer state, ...); unwrap the first known key found.
        for key in ("model_state_dict", "state_dict"):
            wrapped = ckpt.get(key)
            if isinstance(wrapped, dict):
                state_dict = wrapped
                break

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def load_all_models(device):
    """Fetch the three checkpoints from the Hugging Face Hub and load each model.

    Args:
        device: Device passed on to ``load_checkpoint`` for every model.

    Returns:
        Dict with keys ``"resnet"``, ``"efficientnet"`` and ``"vit"`` (in that
        order) mapping to the loaded models.
    """
    print("Downloading checkpoints from HuggingFace...")
    paths = {}
    for name, fname in CHECKPOINTS.items():
        paths[name] = hf_hub_download(repo_id=HF_REPO_ID, filename=fname)

    print("Loading models...")
    # Architectures are built lazily, one at a time, right before their
    # checkpoint is loaded.
    builders = (
        ("resnet", lambda: build_resnet50(num_classes=9)),
        ("efficientnet", lambda: build_efficientnet(num_classes=9)),
        ("vit", lambda: PretrainedViTClassifier(num_classes=9)),
    )
    models_dict = {
        name: load_checkpoint(build(), paths[name], device)
        for name, build in builders
    }

    print("All models loaded.")
    return models_dict
@torch.no_grad()
def evaluate_single(model, loader, device, name):
    """Run one model over ``loader`` and collect its softmax outputs.

    Each batch from ``loader`` is unpacked as a 3-tuple; the third element is
    ignored.

    Args:
        model: Callable mapping an image batch to logits.
        loader: Iterable of ``(images, labels, extra)`` batches.
        device: Device the images are moved to before the forward pass.
        name: Label shown in the progress bar.

    Returns:
        Tuple ``(accuracy, all_probs, all_labels)``: accuracy as a percentage,
        the concatenated per-sample probabilities (on CPU), and the
        concatenated labels.
    """
    prob_batches, label_batches = [], []
    progress = tqdm(loader, desc=f"Evaluating {name}")
    for inputs, targets, _ in progress:
        outputs = model(inputs.to(device))
        # Softmax over dim 1 turns logits into per-class probabilities.
        prob_batches.append(torch.softmax(outputs, dim=1).cpu())
        # NOTE(review): labels are not moved to CPU here, yet `.numpy()` is
        # called on them below -- assumes the loader yields CPU tensors.
        label_batches.append(targets)

    all_probs = torch.cat(prob_batches)
    all_labels = torch.cat(label_batches)
    predicted = all_probs.argmax(dim=1)
    correct_fraction = accuracy_score(all_labels.numpy(), predicted.numpy())
    return correct_fraction * 100, all_probs, all_labels
def main():
    """Evaluate every model and their soft-voting ensemble on the test set.

    Loads the config, seeds the RNGs, loads the three models, runs each one
    over ``data/test.csv`` and prints per-model accuracy, ensemble accuracy,
    the ensemble classification report and the confusion matrix.
    """
    config = load_config()
    seed_everything(42)
    device = get_device()
    print(f"Device: {device}")

    # Load all models
    models_dict = load_all_models(device)

    # Load test set
    transform = get_val_transforms(config)
    test_dataset = DateFruitDataset("data/test.csv", transform=transform)
    test_loader = DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=0
    )
    print(f"\nTest set: {len(test_dataset)} images")

    # Evaluate each model
    results = {}
    for name, model in models_dict.items():
        acc, probs, labels = evaluate_single(model, test_loader, device, name)
        results[name] = {"accuracy": acc, "probs": probs, "labels": labels}

    # Ensemble (soft voting - average of softmax probabilities)
    ensemble_probs = sum(r["probs"] for r in results.values()) / len(results)
    ensemble_preds = ensemble_probs.argmax(dim=1).numpy()
    # The loader is not shuffled, so every model saw the samples in the same
    # order; any model's label tensor can serve as ground truth.
    true_labels = results["vit"]["labels"].numpy()
    ensemble_acc = accuracy_score(true_labels, ensemble_preds) * 100

    print(f"\n{'='*50}")
    print("INDIVIDUAL vs ENSEMBLE")
    print(f"{'='*50}")
    for name, r in results.items():
        print(f" {name.upper():<15} {r['accuracy']:>6.2f}%")
    print(f" {'ENSEMBLE':<15} {ensemble_acc:>6.2f}%")

    # Pass the full label set explicitly. Without it, scikit-learn infers the
    # classes from the data: classification_report then raises ValueError when
    # a class is absent (its count no longer matches target_names), and the
    # confusion matrix silently changes shape.
    # Predictions are argmax indices, so labels are taken to be
    # 0..len(CLASS_NAMES)-1.
    label_ids = list(range(len(CLASS_NAMES)))

    print("\nEnsemble Classification Report:")
    print(
        classification_report(
            true_labels,
            ensemble_preds,
            labels=label_ids,
            target_names=CLASS_NAMES,
        )
    )
    print("Confusion Matrix:")
    print(confusion_matrix(true_labels, ensemble_preds, labels=label_ids))


# Run the evaluation only when executed as a script (python -m src.ensemble).
if __name__ == "__main__":
    main()