# sleep_recognition / model_loader.py
# Uploaded by BalaAndegue with huggingface_hub (commit 3bed08c, verified).
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
def get_model(model_path, device='cpu'):
    """
    Load a MobileNetV2 binary classifier with weights from ``model_path``.

    The torchvision MobileNetV2 is built without pretrained weights and its
    final classifier layer is replaced by ``nn.Linear(in_features, 1)`` so it
    matches a checkpoint that has a single output unit (binary classification).

    Args:
        model_path: Path or file-like object passed to ``torch.load``.
        device: Used as ``map_location`` when loading and as the target
            device for the returned model.

    Returns:
        The model moved to ``device`` and switched to eval mode.

    Raises:
        Any error raised by ``torch.load`` if the checkpoint cannot be read.
        The original ``load_state_dict`` error if the checkpoint is neither a
        plain state dict nor a dict wrapping one under ``'model'`` or
        ``'state_dict'``.
    """
    # 1. Load the base architecture (random init; weights come from the file).
    model = models.mobilenet_v2(weights=None)

    # 2. Replace the classifier head with a single output unit (binary).
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, 1)

    # 3. Read the checkpoint outside the try block. A read failure must
    #    surface as-is instead of reaching the fallback with `state_dict`
    #    still unbound, which would raise a misleading UnboundLocalError.
    state_dict = torch.load(model_path, map_location=device)
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        # Fallback: the weights may be wrapped inside a larger checkpoint dict.
        print(f"Error loading state_dict: {e}")
        wrapper_key = None
        if isinstance(state_dict, dict):
            # 'model' is checked first to preserve the original preference.
            wrapper_key = next(
                (key for key in ('model', 'state_dict') if key in state_dict),
                None,
            )
        if wrapper_key is None:
            raise
        model.load_state_dict(state_dict[wrapper_key])

    model.to(device)
    model.eval()
    return model
# Standard MobileNetV2 preprocessing (resize, center-crop to 224, ImageNet
# mean/std normalization). The pipeline holds no per-call state, so it is
# built once at import time instead of on every predict() call.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(model, image, device='cpu'):
    """
    Run binary inference on one image and return per-class confidences.

    Args:
        model: Module called on a batch of shape (1, 3, 224, 224). Its
            ``output[0][0]`` is treated as the single logit for the image.
        image: PIL image. Non-RGB modes (e.g. 'L', 'RGBA', 'P') are converted
            to RGB first, because the normalization step expects exactly
            three channels.
        device: Device the input batch is moved to. It should match the
            device the model lives on.

    Returns:
        A dict ``{"DROWSY": p_drowsy, "NATURAL": p_natural}`` whose two values
        are complementary probabilities (``p_drowsy = 1 - p_natural``).
    """
    # Normalize applies a 3-channel mean/std, so 1- or 4-channel inputs would
    # fail with a shape error. Convert them to RGB up front.
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')

    input_batch = _PREPROCESS(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # A single output unit means sigmoid gives P(class 1).
    probability = torch.sigmoid(output[0][0]).item()

    # Label mapping assumed: 1 = NATURAL, 0 = DROWSY. It was swapped from a
    # previous version based on user test results; confirm against the
    # training labels.
    return {
        "DROWSY": 1 - probability,
        "NATURAL": probability,
    }