SpyC0der77's picture
Update app.py
7168a17 verified
import warnings

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import timm
class ImprovedMultiOutputModel(nn.Module):
    """Two-headed image classifier built on a timm backbone.

    The backbone is created with ``num_classes=0`` so it has no classifier
    of its own and returns a feature vector of width ``num_features``. A
    small gating MLP maps those features to weights in (0, 1) (sigmoid
    output) that rescale them element-wise. Two independent MLP heads then
    produce logits for the object class and for the material class.

    Args:
        num_object_classes: number of logits produced by the object head.
        num_material_classes: number of logits produced by the material head.
        backbone: timm model name passed to ``timm.create_model``.
    """

    def __init__(self, num_object_classes, num_material_classes, backbone='efficientnet_b0'):
        super().__init__()
        # pretrained=True asks timm for its published weights for this
        # architecture (may require network access on first use).
        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)
        feat_dim = self.backbone.num_features

        # Feature gate: same width in and out, squashed to (0, 1).
        self.attention = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, feat_dim),
            nn.Sigmoid(),
        )

        # Both heads share one architecture. The layer order inside each
        # nn.Sequential fixes the state_dict key names (e.g. "...classifier.0.weight"),
        # so it must stay stable for saved checkpoints to keep loading.
        self.object_classifier = self._make_head(feat_dim, num_object_classes)
        self.material_classifier = self._make_head(feat_dim, num_material_classes)

    @staticmethod
    def _make_head(in_features, out_features):
        """Build a two-hidden-layer MLP (Linear/BN/ReLU/Dropout) ending in logits."""
        return nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, out_features),
        )

    def forward(self, x):
        """Return ``{'object_name': logits, 'material': logits}`` for the batch ``x``."""
        feats = self.backbone(x)
        gated = feats * self.attention(feats)
        return {
            'object_name': self.object_classifier(gated),
            'material': self.material_classifier(gated),
        }
def get_val_transforms():
    """Return the deterministic preprocessing pipeline used at inference time.

    Steps: resize so the shorter side is 256, center-crop to 224x224,
    convert to a float tensor, then normalize each channel with the
    standard ImageNet mean/std.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def load_model(model_path, backbone='efficientnet_b0'):
    """Load a trained ImprovedMultiOutputModel checkpoint for inference.

    Args:
        model_path: path to a ``torch.save``-d dict that contains at least
            ``'model_state_dict'`` and ``'label_mappings'``.
        backbone: timm backbone name the checkpoint was trained with.

    Returns:
        ``(model, label_mappings)``:
        - ``model`` is in eval mode, on CUDA when available and otherwise
          on CPU.
        - ``label_mappings`` is the checkpoint's dict. The lengths of its
          ``'object_name'`` and ``'material'`` entries size the two heads.

    Warns:
        RuntimeWarning: if the checkpoint's state dict has missing or
        unexpected keys relative to the constructed model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles arbitrary Python objects unless
    # weights_only=True is in effect -- only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    label_mappings = checkpoint['label_mappings']
    num_object_classes = len(label_mappings['object_name'])
    num_material_classes = len(label_mappings['material'])
    model = ImprovedMultiOutputModel(num_object_classes, num_material_classes, backbone)
    # strict=False tolerates key mismatches, but any layer that fails to load
    # keeps its fresh initialisation. Surface that instead of silently serving
    # a partially loaded model.
    incompatible = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"{model_path}: checkpoint does not fully match the model "
            f"(missing keys: {incompatible.missing_keys}; "
            f"unexpected keys: {incompatible.unexpected_keys})",
            RuntimeWarning,
            stacklevel=2,
        )
    model.to(device)
    model.eval()
    return model, label_mappings
# Load models
# Checkpoints offered in the UI; each is loaded once, at import time.
MODEL_FILES = ('modelv1.pth', 'modelv2.pth')

models = {}
# Label mappings are kept per checkpoint. Decoding one model's predictions
# with another model's vocabulary would silently mislabel them whenever the
# two were trained with different class sets or orderings.
label_mappings_by_model = {}
for _model_file in MODEL_FILES:
    models[_model_file], label_mappings_by_model[_model_file] = load_model(_model_file)

# Module-level names retained from the original script.
label_mappings_v1 = label_mappings_by_model['modelv1.pth']
label_mappings_v2 = label_mappings_by_model['modelv2.pth']
label_mappings = label_mappings_v1

# Output heads decoded by predict(). Each name is a key in both the model's
# output dict and every checkpoint's label_mappings.
_HEADS = ('object_name', 'material')


def _invert_mapping(name_to_id):
    """Return an id -> name dict; if several names share an id, the first wins."""
    id_to_name = {}
    for name, idx in name_to_id.items():
        id_to_name.setdefault(idx, name)
    return id_to_name


# Precomputed id -> name tables per model and head, so each request does an
# O(1) lookup instead of scanning the whole mapping.
_id_to_name = {
    model_file: {head: _invert_mapping(mappings[head]) for head in _HEADS}
    for model_file, mappings in label_mappings_by_model.items()
}

# The validation pipeline holds no per-image state, so build it once.
_val_transform = get_val_transforms()


def predict(image, model_choice):
    """Classify an image's object and material with the chosen checkpoint.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when
            nothing was uploaded.
        model_choice: key into ``models`` (a checkpoint filename).

    Returns:
        A human-readable string with the predicted object and material.
        It is a prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."
    model = models[model_choice]
    # Use whichever device load_model placed this model on.
    device = next(model.parameters()).device
    # Normalize expects exactly 3 channels; RGBA or greyscale uploads would
    # otherwise raise inside the transform.
    image_tensor = _val_transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
    pred_obj = torch.argmax(outputs['object_name'], dim=1).item()
    pred_mat = torch.argmax(outputs['material'], dim=1).item()
    # Map IDs back to names using this model's own vocabulary.
    names = _id_to_name[model_choice]
    obj_name = names['object_name'].get(pred_obj, f"<unknown id {pred_obj}>")
    mat_name = names['material'].get(pred_mat, f"<unknown id {pred_mat}>")
    return f"Predicted Object: {obj_name}\nPredicted Material: {mat_name}"
# Create Gradio interface using Blocks
with gr.Blocks(title="Artifact Classification Model") as demo:
    gr.Markdown("# Artifact Classification Model")
    gr.Markdown("Upload an image to classify the object name and material.")
    # Offer exactly the checkpoints that were loaded, so the dropdown can
    # never name a model that predict() would fail to find in `models`.
    model_choices = list(models)
    model_selector = gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0])
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload an Image")
        output_text = gr.Textbox(label="Predictions")
    predict_btn = gr.Button("Predict")
    predict_btn.click(fn=predict, inputs=[input_image, model_selector], outputs=output_text)

if __name__ == "__main__":
    demo.launch()