import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import models
from typing import Dict, Tuple
import os


class MultiOutputModel(nn.Module):
    """Multi-output model for artifact classification (matches the UI)."""

    def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
        super(MultiOutputModel, self).__init__()
        # Use a pre-trained ResNet-50 as the backbone
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Remove the final classification layer
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
        # Freeze all but the last four backbone parameter tensors for transfer learning
        for param in list(self.backbone.parameters())[:-4]:
            param.requires_grad = False
        # One classification head per attribute (ResNet-50 features are 2048-d)
        self.object_classifier = nn.Linear(2048, num_object_classes)
        self.material_classifier = nn.Linear(2048, num_material_classes)

    def forward(self, x):
        # Extract features using the backbone
        features = self.backbone(x)
        features = features.view(features.size(0), -1)
        # Get logits for each attribute
        object_pred = self.object_classifier(features)
        material_pred = self.material_classifier(features)
        return {
            'object_name': object_pred,
            'material': material_pred,
        }

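# load_model() below expects a checkpoint saved roughly in this shape (a sketch
# inferred from the keys it reads; the actual training script may differ, and
# the label names here are purely illustrative):
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'label_mappings': {
#           'object_name': {'vase': 0, 'statue': 1, ...},  # label -> id
#           'material': {'ceramic': 0, 'bronze': 1, ...},
#       },
#   }, 'model.pth')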
def load_model(model_path: str) -> Tuple[torch.nn.Module, Dict[str, Dict[int, str]]]:
    """Load the model from a checkpoint and return the model and id-to-label mappings."""
    print(f"Loading model from {model_path}...")
    checkpoint = torch.load(model_path, map_location="cpu")

    # Use the label mappings to determine the number of classes
    label_mappings = checkpoint.get('label_mappings', {})
    num_object_classes = len(label_mappings.get('object_name', {}))
    num_material_classes = len(label_mappings.get('material', {}))
    if num_object_classes == 0:
        print("Warning: No label mappings found, using fallback class counts")
        num_object_classes, num_material_classes = 1018, 192

    # Only the v1 architecture (MultiOutputModel with a ResNet backbone) is supported
    print("Loading v1 model (MultiOutputModel) with ResNet backbone")
    model = MultiOutputModel(num_object_classes, num_material_classes)

    # Load the state dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Warning: No model_state_dict found in checkpoint")

    # Create reverse mappings (id2label)
    reverse_mappings = {}
    for attr, mapping in label_mappings.items():
        reverse_mappings[attr] = {int(v): str(k) for k, v in mapping.items()}
        print(f"Loaded {attr} mappings: {len(reverse_mappings[attr])} classes")
    return model, reverse_mappings

def run_inference(model: torch.nn.Module, pixel_values: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run inference on pixel_values and return predictions and confidences for both object_name and material."""
    model.eval()
    model.to(device)
    pixel_values = pixel_values.to(device)
    with torch.no_grad():
        outputs = model(pixel_values)

    # The model must return a dict with one logits tensor per head
    if not isinstance(outputs, dict):
        raise ValueError("Expected dict output with 'object_name' and 'material' keys")
    if 'object_name' not in outputs or 'material' not in outputs:
        raise ValueError("Expected 'object_name' and 'material' in model outputs")
    logits_obj = outputs['object_name']
    logits_mat = outputs['material']

    preds_obj = torch.argmax(logits_obj, dim=-1)
    probs_obj = torch.softmax(logits_obj, dim=-1)
    max_probs_obj = torch.max(probs_obj, dim=-1)[0]
    preds_mat = torch.argmax(logits_mat, dim=-1)
    probs_mat = torch.softmax(logits_mat, dim=-1)
    max_probs_mat = torch.max(probs_mat, dim=-1)[0]
    return preds_obj.cpu(), max_probs_obj.cpu(), preds_mat.cpu(), max_probs_mat.cpu()

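# Shape sketch for a batch of N images (follows directly from the code above):
#   pixel_values:         (N, 3, 224, 224)
#   preds_obj, preds_mat: (N,) argmax class ids
#   confidences:          (N,) softmax probability of the top class per head
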
# Global variables for the model, label mappings, and device
model = None
label_mappings = None
device = None

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Preprocess an image for model inference."""
    # Define transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Apply transforms
    image = image.convert('RGB')
    tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return tensor

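# The mean/std above are the standard ImageNet statistics, matching what the
# pretrained ResNet-50 backbone expects. Example usage (hypothetical file name,
# for illustration only):
#   pixel_values = preprocess_image(Image.open("artifact.jpg"))
#   pixel_values.shape  # torch.Size([1, 3, 224, 224])
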
def predict_artifact(image: Image.Image) -> Tuple[str, float, str, float]:
    """Predict object type and material from an image."""
    global model, label_mappings, device
    if model is None:
        raise ValueError("Model not loaded. Please restart the application.")
    # Preprocess the image
    pixel_values = preprocess_image(image)
    # Run inference
    preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)
    # Get the single-image predictions
    object_pred_id = preds_obj[0].item()
    material_pred_id = preds_mat[0].item()
    object_conf = confs_obj[0].item()
    material_conf = confs_mat[0].item()
    # Convert ids to labels, falling back to a placeholder for unknown ids
    object_name = label_mappings['object_name'].get(object_pred_id, f"class_{object_pred_id}")
    material_name = label_mappings['material'].get(material_pred_id, f"class_{material_pred_id}")
    return object_name, object_conf, material_name, material_conf

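# For a single image this returns, e.g., ('vase', 0.87, 'ceramic', 0.91):
# one label string and top-class confidence per head (values illustrative,
# not real model output).
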
def gradio_predict(image):
    """Gradio interface function."""
    if image is None:
        return "Please upload an image", "", "", ""
    try:
        object_name, object_conf, material_name, material_conf = predict_artifact(image)
        # Format the results
        object_result = f"**{object_name}** ({object_conf:.1%} confidence)"
        material_result = f"**{material_name}** ({material_conf:.1%} confidence)"
        return object_result, material_result, f"{object_conf:.3f}", f"{material_conf:.3f}"
    except Exception as e:
        return f"Error: {str(e)}", "", "", ""

def load_model_on_startup():
    """Load the model when the application starts."""
    global model, label_mappings, device
    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the model from model.pth
    model_path = "model.pth"
    if not os.path.exists(model_path):
        print(f"Warning: Model file not found at {model_path}")
        print("Please ensure the model.pth file exists in the current directory before running the application.")
        return
    try:
        model, label_mappings = load_model(model_path)
        print("Model loaded successfully!")
        print(f"Object classes: {len(label_mappings.get('object_name', {}))}")
        print(f"Material classes: {len(label_mappings.get('material', {}))}")
    except Exception as e:
        print(f"Error loading model: {e}")


# Load the model at import time so it is ready before the UI is built
load_model_on_startup()

# Create the Gradio interface
with gr.Blocks(title="Artifact Classification v1", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏺 Artifact Classification Model v1")
    gr.Markdown("Upload an image of an artifact to classify its **object type** and **material composition**.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Artifact Image", type="pil")
            submit_btn = gr.Button("🔍 Classify Artifact", variant="primary")
        with gr.Column():
            gr.Markdown("### 📊 Classification Results")
            object_output = gr.Markdown(label="**Object Type**")
            material_output = gr.Markdown(label="**Material**")
            with gr.Accordion("📈 Confidence Scores", open=False):
                object_conf = gr.Textbox(label="Object Confidence", interactive=False)
                material_conf = gr.Textbox(label="Material Confidence", interactive=False)

    # Connect the interface
    submit_btn.click(
        fn=gradio_predict,
        inputs=image_input,
        outputs=[object_output, material_output, object_conf, material_conf]
    )

    # Example images
    gr.Examples(
        examples=[
            # Add example image paths here if available
        ],
        inputs=image_input,
        outputs=[object_output, material_output, object_conf, material_conf],
        fn=gradio_predict,
        cache_examples=False
    )

| gr.Markdown(""" | |
| ### βΉοΈ About | |
| This model uses a ResNet-50 backbone to classify museum artifacts into object types (vase, statue, pottery, etc.) | |
| and material compositions (ceramic, bronze, stone, etc.). | |
| **Model**: MultiOutputModel with ResNet-50 backbone | |
| **Training Data**: Oriental Museum artifacts dataset | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch() | |