Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
def get_model(model_path, device='cpu'):
    """
    Build a MobileNetV2 binary classifier and load its weights from model_path.

    The stock classifier head is replaced by a single-output linear layer, so
    the checkpoint is expected to come from a model with 1 output unit
    (binary classification).

    Args:
        model_path: Checkpoint location, passed straight to ``torch.load``.
        device: Device the tensors are mapped to on load and the model is
            moved to afterwards.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        Any error raised by ``torch.load`` when the checkpoint cannot be read.
        Any error raised by ``load_state_dict`` when the checkpoint does not
        match the model and contains no ``'model'`` entry to fall back to.
    """
    # 1. Build the base architecture without pretrained weights.
    model = models.mobilenet_v2(weights=None)

    # 2. Swap the classifier head for a single output unit (binary).
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, 1)

    # 3. Read the checkpoint. This is deliberately outside the try block:
    #    if reading fails there is nothing to fall back to, and the original
    #    error must reach the caller unchanged.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    checkpoint = torch.load(model_path, map_location=device)

    # 4. Load the weights, falling back to a wrapped {'model': state_dict}.
    try:
        model.load_state_dict(checkpoint)
    except Exception as e:
        # Fallback if the state_dict is wrapped or has different keys
        print(f"Error loading state_dict: {e}")
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
        else:
            # Bare raise keeps the original traceback intact.
            raise

    model.to(device)
    model.eval()
    return model
def predict(model, image, device='cpu'):
    """
    Run inference on a PIL image and return per-class confidences.

    The model is expected to emit a single logit per image; a sigmoid turns
    it into a probability.

    Args:
        model: Model called on a batch of one preprocessed image.
        image: PIL image. Any mode is accepted; it is converted to RGB
            before preprocessing.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        dict with keys ``"DROWSY"`` and ``"NATURAL"`` whose values sum to 1.
    """
    # Normalize below is configured for exactly three channels, so grayscale,
    # palette and RGBA inputs must be converted first or it will fail.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Standard MobileNetV2 preprocessing
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = preprocess(image)
    # Add the batch dimension the model expects and move to the target device.
    input_batch = input_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # For a single output unit, we use sigmoid
    probability = torch.sigmoid(output[0][0]).item()

    # Assuming 1 = NATURAL, 0 = DROWSY
    # Swapped from previous version based on user test results
    # NOTE(review): this mapping depends on the training label encoding,
    # which is not visible here - confirm against the training code.
    natural_prob = probability
    drowsy_prob = 1 - probability

    confidences = {
        "DROWSY": drowsy_prob,
        "NATURAL": natural_prob
    }
    return confidences