SpyC0der77's picture
Update app.py
623fea8 verified
raw
history blame
11.6 kB
#!/usr/bin/env python3
"""
Gradio web interface for artifact classification
"""
import os
# Fix SSL issue on Windows
os.environ['SSL_CERT_FILE'] = ''
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import json
from pathlib import Path
# Define the model architecture directly (standalone)
import torch
import torch.nn as nn
from torchvision import models
class MultiOutputModel(nn.Module):
    """Multi-output model for artifact classification.

    A ResNet-50 backbone produces one shared feature vector that feeds two
    independent linear heads. ``forward`` returns a dict of raw logits keyed
    by ``'object_name'`` and ``'material'``.

    Args:
        num_object_classes: number of outputs of the object-type head.
        num_material_classes: number of outputs of the material head.
        hidden_size: accepted for interface compatibility; not used anywhere.
    """

    def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
        super().__init__()
        # Start from the ImageNet-pretrained ResNet-50 weights.
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Drop the last child module (ResNet's own classification layer) and
        # keep the rest as the feature extractor.
        trunk = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*trunk)
        # Transfer learning: every backbone parameter tensor except the final
        # two is frozen, i.e. nearly the whole backbone, not just early layers.
        backbone_params = list(self.backbone.parameters())
        for frozen in backbone_params[:-2]:
            frozen.requires_grad = False
        feature_dim = 2048  # width of the flattened backbone features
        self.object_classifier = nn.Linear(feature_dim, num_object_classes)
        self.material_classifier = nn.Linear(feature_dim, num_material_classes)
        # Regularization applied to the shared features before both heads.
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return ``{'object_name': logits, 'material': logits}`` for batch ``x``."""
        pooled = self.backbone(x)
        flat = pooled.view(pooled.size(0), -1)
        shared = self.dropout(flat)
        return {
            'object_name': self.object_classifier(shared),
            'material': self.material_classifier(shared),
        }


print("MultiOutputModel class defined directly in app (standalone)")
class ArtifactClassifier:
    """Load the trained MultiOutputModel and run single-image inference.

    The checkpoint is looked up at ``model_path`` first, then downloaded from
    the Hugging Face Hub; if both fail, an *untrained* model with default
    class counts is created so the app can still start.
    """

    # Default local checkpoint location (relative to the working directory).
    DEFAULT_MODEL_PATH = "train/outputs/best_model.pth"
    # Hugging Face Hub location of the published checkpoint.
    HUB_REPO_ID = "SpyC0der77/artifact-classification-model"
    HUB_FILENAME = "best_model.pth"
    # Last-resort head sizes when neither label mappings nor weights tell us.
    DEFAULT_NUM_OBJECT_CLASSES = 1018
    DEFAULT_NUM_MATERIAL_CLASSES = 192

    def __init__(self, model_path="train/outputs/best_model.pth"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        # Remember the requested checkpoint so label mappings are read from the
        # same file instead of a hard-coded path.
        self.model_path = model_path
        # Raw {label: index} mappings captured by _load_model_from_path; stays
        # None when no checkpoint could be loaded.
        self._raw_label_mappings = None
        # Try to load from local file first, then from HuggingFace
        self.model = self.load_model(model_path)
        self.model.to(self.device)
        self.model.eval()
        # Set up transforms (same as training)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Load label mappings if available
        self.label_mappings = self.load_label_mappings()
        print("Model loaded successfully!")

    def load_model(self, model_path):
        """Load the trained model from a local file, else the Hub, else defaults.

        A local file that exists but fails to load raises; Hub failures fall
        back to an untrained model built by _create_model_with_defaults.
        """
        if os.path.exists(model_path):
            print(f"Loading model from local file: {model_path}")
            return self._load_model_from_path(model_path)
        print("Local model not found, downloading from HuggingFace...")
        try:
            return self._load_model_from_hub()
        except Exception as e:
            print(f"Failed to download from HuggingFace: {e}")
            print("Falling back to local model creation...")
            return self._create_model_with_defaults()

    def _load_model_from_path(self, model_path):
        """Build a MultiOutputModel and load the weights stored at ``model_path``.

        Head sizes come from the checkpoint's ``label_mappings``. For any
        missing mapping the size is read from that head's weight in the state
        dict, and only if that is absent too from the class defaults.

        Raises:
            KeyError: if the checkpoint has no ``'model_state_dict'``.
        """
        # NOTE(review): torch.load unpickles the file and the Hub checkpoint is
        # external input; consider weights_only=True (default only in newer torch).
        checkpoint = torch.load(model_path, map_location=self.device)
        state_dict = checkpoint['model_state_dict']
        label_mappings = checkpoint.get('label_mappings', {})
        num_object_classes = len(label_mappings.get('object_name', {}))
        num_material_classes = len(label_mappings.get('material', {}))
        if num_object_classes == 0 or num_material_classes == 0:
            print("Warning: No label mappings found, using fallback class counts")
            if num_object_classes == 0:
                num_object_classes = self._head_size(
                    state_dict, 'object_classifier.weight', self.DEFAULT_NUM_OBJECT_CLASSES)
            if num_material_classes == 0:
                num_material_classes = self._head_size(
                    state_dict, 'material_classifier.weight', self.DEFAULT_NUM_MATERIAL_CLASSES)
        model = MultiOutputModel(num_object_classes, num_material_classes)
        model.load_state_dict(state_dict)
        # Only record the mappings once the weights loaded successfully, so
        # load_label_mappings() can reuse them without re-reading the file.
        self._raw_label_mappings = label_mappings
        return model

    @staticmethod
    def _head_size(state_dict, key, default):
        """Return the output size of the Linear weight at ``key``, or ``default``."""
        weight = state_dict.get(key)
        return int(weight.shape[0]) if weight is not None else default

    def _download_checkpoint(self):
        """Download the published checkpoint from the Hub and return its local path."""
        from huggingface_hub import hf_hub_download
        return hf_hub_download(repo_id=self.HUB_REPO_ID, filename=self.HUB_FILENAME)

    def _load_model_from_hub(self):
        """Download the checkpoint from the Hub and load it; re-raises on failure."""
        try:
            print("Downloading model from HuggingFace Hub...")
            model_file = self._download_checkpoint()
            print(f"Model downloaded to: {model_file}")
            return self._load_model_from_path(model_file)
        except Exception as e:
            print(f"Error downloading from HuggingFace: {e}")
            raise

    def _create_model_with_defaults(self):
        """Create an untrained model with the default class counts."""
        print("Creating model with default parameters...")
        print("Note: This model won't have the trained weights!")
        return MultiOutputModel(self.DEFAULT_NUM_OBJECT_CLASSES,
                                self.DEFAULT_NUM_MATERIAL_CLASSES)

    @staticmethod
    def _reverse_mappings(mappings):
        """Invert each per-attribute ``{label: index}`` dict into ``{index: label}``."""
        return {
            attr: {index: label for label, index in mapping.items()}
            for attr, mapping in mappings.items()
        }

    @staticmethod
    def _read_label_mappings(checkpoint_path):
        """Return the ``label_mappings`` dict stored in a checkpoint file ({} if absent)."""
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        return checkpoint.get('label_mappings', {})

    def load_label_mappings(self, model_path=None):
        """Return ``{attr: {index: label}}`` mappings for decoding predictions.

        Prefers the mappings captured from the checkpoint the model was loaded
        from. Otherwise reads them from ``model_path`` (default: the path given
        to the constructor), then from the Hub checkpoint. Returns ``{}`` when
        none can be found.
        """
        loaded = getattr(self, '_raw_label_mappings', None)
        if loaded is not None:
            return self._reverse_mappings(loaded)
        if model_path is None:
            model_path = getattr(self, 'model_path', self.DEFAULT_MODEL_PATH)
        if os.path.exists(model_path):
            try:
                return self._reverse_mappings(self._read_label_mappings(model_path))
            except Exception as e:
                print(f"Could not load local label mappings: {e}")
        try:
            print("Downloading label mappings from HuggingFace...")
            mappings_file = self._download_checkpoint()  # Contains the mappings
            reverse_mappings = self._reverse_mappings(self._read_label_mappings(mappings_file))
            print(f"Loaded {len(reverse_mappings)} label mappings from HuggingFace")
            return reverse_mappings
        except Exception as e:
            print(f"Could not load label mappings from HuggingFace: {e}")
            return {}

    def predict(self, image):
        """Classify a single image.

        Args:
            image: a PIL image, a numpy array, or anything ``Image.open``
                accepts (e.g. a file path).

        Returns:
            ``{attr: {'prediction': label, 'confidence': float, 'class_id': int}}``
            for each output head, or ``{"error": message}`` on any failure.
        """
        if image is None:
            return {"error": "No image provided"}
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            elif not isinstance(image, Image.Image):
                image = Image.open(image)
            # Force 3-channel RGB for every input type: PIL images may arrive as
            # RGBA, L, P, ... while Normalize uses three-channel mean/std.
            image = image.convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
            results = {}
            for attr in ('object_name', 'material'):
                if attr not in outputs:
                    continue
                probs = torch.softmax(outputs[attr], dim=1)
                confidence, predicted_idx = torch.max(probs, dim=1)
                pred_class = predicted_idx.item()
                # Fall back to a synthetic name when no mapping covers this index.
                pred_label = self.label_mappings.get(attr, {}).get(pred_class, f"Class_{pred_class}")
                results[attr] = {
                    'prediction': pred_label,
                    'confidence': confidence.item(),
                    'class_id': pred_class,
                }
            return results
        except Exception as e:
            return {"error": str(e)}
# Global classifier instance
classifier = None
def classify_image(image):
"""Gradio interface function"""
global classifier
if classifier is None:
return "Error: Model not loaded. Please restart the app."
try:
results = classifier.predict(image)
if "error" in results:
return f"Prediction failed: {results['error']}"
# Format results
output = "PREDICTION RESULTS\n\n"
for attr, result in results.items():
status = "OK" if result['confidence'] > 0.5 else "LOW"
output += f"{status} {attr.upper()}: {result['prediction']}\n"
output += f" Confidence: {result['confidence']:.3f}\n"
output += f" Class ID: {result['class_id']}\n\n"
# Overall confidence
confidences = [r['confidence'] for r in results.values()]
avg_confidence = sum(confidences) / len(confidences)
output += f"Average Confidence: {avg_confidence:.3f}"
return output
except Exception as e:
return f"Error during prediction: {str(e)}"
def create_interface():
"""Create and launch the Gradio interface"""
global classifier
# Initialize classifier
try:
print("Loading model...")
classifier = ArtifactClassifier()
print("Model loaded successfully!")
except Exception as e:
print(f"Failed to load model: {e}")
return
# Create interface
interface = gr.Interface(
fn=classify_image,
inputs=gr.Image(type="pil", label="Upload Artifact Image"),
outputs=gr.Textbox(label="Classification Results", lines=10),
title="Artifact Classification",
description="""
Upload an image of an archaeological artifact to get AI-powered classification!
Features:
- Object type identification (coin, vase, statue, etc.)
- Material classification (gold, silver, pottery, etc.)
- Confidence scores for each prediction
- GPU-accelerated processing (if available)
- Auto-downloads model from HuggingFace Hub
- Completely standalone - no training code needed
Supported formats: JPG, PNG, JPEG
""",
article="""
How to use:
1. Click "Upload Artifact Image" to select an image
2. Click "Submit" to run classification
3. View results with confidence scores
Model trained on: British Museum artifact dataset
Accuracy: ~71% for objects, ~62% for materials
""",
examples=[
["example_artifact.jpg"] # Add example images if available
]
)
# Launch
print("Starting Gradio interface...")
interface.launch(
server_name="0.0.0.0", # Allow external connections
server_port=7860,
share=False, # Set to True for public link
debug=False
)
if __name__ == "__main__":
create_interface()