import logging

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import timm

logger = logging.getLogger(__name__)

# Checkpoint files offered in the UI; the first entry is the default choice.
MODEL_PATHS = ('modelv1.pth', 'modelv2.pth')

# Attributes predicted by the model; these are also the keys of the
# ``label_mappings`` dict stored in each checkpoint and of the forward() output.
ATTRIBUTES = ('object_name', 'material')


class ImprovedMultiOutputModel(nn.Module):
    """Multi-output classifier: one timm backbone feeding two heads.

    The backbone is created with ``num_classes=0`` and its output (of width
    ``backbone.num_features``) is re-weighted by a learned per-feature gate
    before being passed to two independent MLP heads, one for the object name
    and one for the material.

    Args:
        num_object_classes: Output width of the object-name head.
        num_material_classes: Output width of the material head.
        backbone: Model name passed to ``timm.create_model``.
    """

    def __init__(self, num_object_classes, num_material_classes, backbone='efficientnet_b0'):
        super().__init__()

        # Feature extractor without a classification layer.
        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)
        backbone_out_features = self.backbone.num_features

        # Gate in (0, 1) per feature (Sigmoid output), multiplied elementwise
        # with the backbone features in forward().
        self.attention = nn.Sequential(
            nn.Linear(backbone_out_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, backbone_out_features),
            nn.Sigmoid(),
        )

        # Both heads share one architecture and differ only in output width.
        self.object_classifier = self._make_head(backbone_out_features, num_object_classes)
        self.material_classifier = self._make_head(backbone_out_features, num_material_classes)

    @staticmethod
    def _make_head(in_features, num_classes):
        """Build one classification head (Linear/BatchNorm/ReLU/Dropout stack).

        The layer order is fixed so that ``state_dict`` keys such as
        ``object_classifier.0.weight`` stay compatible with saved checkpoints.
        """
        return nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return a dict of raw logits keyed by ``'object_name'`` and ``'material'``.

        Args:
            x: Input batch for the backbone (presumably ``(batch, 3, H, W)``
                image tensors -- confirm against the training code).
        """
        features = self.backbone(x)

        # Re-weight the features with the learned gate.
        features = features * self.attention(features)

        return {
            'object_name': self.object_classifier(features),
            'material': self.material_classifier(features),
        }


def get_val_transforms():
    """Return the deterministic preprocessing pipeline used at inference time.

    Resizes to 256, center-crops to 224, converts to a tensor and normalizes
    with the channel statistics commonly used for ImageNet-pretrained models.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_model(model_path):
    """Load a checkpoint and return ``(model, label_mappings)`` ready for inference.

    The checkpoint must contain ``'label_mappings'`` (with ``'object_name'``
    and ``'material'`` entries, whose sizes determine the head widths) and
    ``'model_state_dict'``.

    Args:
        model_path: Path of the checkpoint file.

    Returns:
        The model in eval mode on the selected device (CUDA if available,
        else CPU) and the checkpoint's label mappings.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: torch.load relies on pickle; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    label_mappings = checkpoint['label_mappings']

    model = ImprovedMultiOutputModel(
        len(label_mappings['object_name']),
        len(label_mappings['material']),
        'efficientnet_b0',
    )

    # Loading stays non-strict (as before), but mismatches are reported instead
    # of being dropped silently: a partially loaded model predicts garbage.
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    missing = getattr(load_result, 'missing_keys', [])
    unexpected = getattr(load_result, 'unexpected_keys', [])
    if missing or unexpected:
        logger.warning(
            "Checkpoint %s did not match the model exactly "
            "(missing keys: %s; unexpected keys: %s)",
            model_path, missing, unexpected,
        )

    model.to(device)
    model.eval()  # required for single-image batches: the heads use BatchNorm1d
    return model, label_mappings


def _invert_mapping(name_to_index):
    """Return an index -> name dict; on duplicate indices the first name wins."""
    index_to_name = {}
    for name, index in name_to_index.items():
        index_to_name.setdefault(index, name)
    return index_to_name


# Load models, keeping each model's own label mappings next to it.
models = {}
model_label_mappings = {}
for _model_path in MODEL_PATHS:
    models[_model_path], model_label_mappings[_model_path] = load_model(_model_path)

# Backward-compatible module-level names.
label_mappings_v1 = model_label_mappings['modelv1.pth']
label_mappings_v2 = model_label_mappings['modelv2.pth']
label_mappings = label_mappings_v1

# Per-model reverse lookups (class index -> label), built once at startup.
_index_to_label = {
    path: {attr: _invert_mapping(mappings[attr]) for attr in ATTRIBUTES}
    for path, mappings in model_label_mappings.items()
}

# The pipeline holds no per-call state, so one instance serves every request.
_VAL_TRANSFORM = get_val_transforms()


def predict(image, model_choice):
    """Classify a PIL image with the selected model and format the result.

    Args:
        image: PIL image, or ``None`` when nothing was uploaded.
        model_choice: Key into ``models`` (a checkpoint file name).

    Returns:
        A two-line string with the predicted object and material, or a short
        message when the image is missing or the model name is unknown.
    """
    if image is None:
        return "Please upload an image."

    model = models.get(model_choice)
    if model is None:
        return f"Unknown model: {model_choice}"

    # Run on whatever device the model was placed on by load_model().
    device = next(model.parameters()).device

    # Normalize() expects exactly three channels; convert RGBA/grayscale input.
    image_tensor = _VAL_TRANSFORM(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)

    # Decode with the mappings of the model that produced the logits.
    names = {}
    for attr in ATTRIBUTES:
        class_index = torch.argmax(outputs[attr], dim=1).item()
        names[attr] = _index_to_label[model_choice][attr].get(
            class_index, f"Unknown (class {class_index})"
        )

    return f"Predicted Object: {names['object_name']}\nPredicted Material: {names['material']}"


# Create Gradio interface using Blocks
with gr.Blocks(title="Artifact Classification Model") as demo:
    gr.Markdown("# Artifact Classification Model")
    gr.Markdown("Upload an image to classify the object name and material.")

    model_selector = gr.Dropdown(
        choices=list(models),
        label="Select Model",
        value=MODEL_PATHS[0],
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload an Image")
        output_text = gr.Textbox(label="Predictions")

    predict_btn = gr.Button("Predict")
    predict_btn.click(fn=predict, inputs=[input_image, model_selector], outputs=output_text)


if __name__ == "__main__":
    demo.launch()