Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Gradio web interface for artifact classification | |
| """ | |
| import os | |
def apply_windows_ssl_workaround(platform=None, environ=None):
    """Blank out ``SSL_CERT_FILE``, but only when running on Windows.

    The workaround targets a Windows-specific SSL problem. Applying it on
    every platform would also discard a CA bundle that was deliberately
    configured through ``SSL_CERT_FILE`` on other systems, so the variable
    is left untouched everywhere else.

    Args:
        platform: Platform identifier to test; defaults to ``sys.platform``.
        environ: Mapping to modify; defaults to ``os.environ``.

    Returns:
        True if the variable was blanked, False if nothing was changed.
    """
    import sys

    if platform is None:
        platform = sys.platform
    if environ is None:
        environ = os.environ

    if platform != "win32":
        return False

    environ['SSL_CERT_FILE'] = ''
    return True


# Fix SSL issue on Windows (must run before the gradio import below)
apply_windows_ssl_workaround()
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| import json | |
| from pathlib import Path | |
| # Define the model architecture directly (standalone) | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
class MultiOutputModel(nn.Module):
    """Artifact classifier with one shared backbone and two linear heads.

    The backbone is a torchvision ResNet-50 with its last child module
    removed. Its flattened output feeds an object-name head and a material
    head, each an ``nn.Linear`` with 2048 input features.
    """

    def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
        """Build the backbone and the two classification heads.

        Args:
            num_object_classes: Output size of the object-name head.
            num_material_classes: Output size of the material head.
            hidden_size: Accepted for interface compatibility; not used by
                this implementation.
        """
        super().__init__()

        # ImageNet-initialised ResNet-50 minus its final classification layer
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        feature_layers = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*feature_layers)

        # Transfer learning: freeze every backbone parameter tensor except
        # the final two.
        backbone_params = list(self.backbone.parameters())
        for frozen_param in backbone_params[:-2]:
            frozen_param.requires_grad = False

        # One linear head per predicted attribute
        self.object_classifier = nn.Linear(2048, num_object_classes)
        self.material_classifier = nn.Linear(2048, num_material_classes)

        # Regularisation applied to the shared features
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return a dict of raw logits keyed by ``'object_name'`` and ``'material'``."""
        backbone_out = self.backbone(x)
        batch_size = backbone_out.size(0)
        shared = self.dropout(backbone_out.view(batch_size, -1))
        return {
            'object_name': self.object_classifier(shared),
            'material': self.material_classifier(shared),
        }


print("MultiOutputModel class defined directly in app (standalone)")
class ArtifactClassifier:
    """Load the artifact classification model and run single-image inference.

    The model is read from a local checkpoint when it exists, otherwise it is
    downloaded from the HuggingFace Hub; if both fail, an untrained model with
    default head sizes is created. Label mappings stored in the loaded
    checkpoint are reused for decoding predictions so that labels always
    correspond to the weights actually in use.
    """

    DEFAULT_MODEL_PATH = "train/outputs/best_model.pth"
    HUB_REPO_ID = "SpyC0der77/artifact-classification-model"
    HUB_FILENAME = "best_model.pth"
    # Head sizes used only when neither the weights nor the mappings give one
    FALLBACK_CLASS_COUNTS = {'object_name': 1018, 'material': 192}

    def __init__(self, model_path="train/outputs/best_model.pth"):
        """Load the model, build the preprocessing pipeline and label tables.

        Args:
            model_path: Path of the local checkpoint to try first.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Remembered so label mappings are read from the same checkpoint
        self.model_path = model_path
        self._checkpoint_label_mappings = {}

        # Try to load from local file first, then from HuggingFace
        self.model = self.load_model(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Set up transforms (same as training)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Index -> label tables for decoding; empty when none are available
        self.label_mappings = self.load_label_mappings()
        print("Model loaded successfully!")

    @staticmethod
    def _invert_mappings(mappings):
        """Turn ``{attr: {label: index}}`` into ``{attr: {index: label}}``."""
        return {
            attr: {index: label for label, index in mapping.items()}
            for attr, mapping in mappings.items()
        }

    @staticmethod
    def _infer_num_classes(state_dict, label_mappings):
        """Return ``(num_object_classes, num_material_classes)`` for a checkpoint.

        The first dimension of each head's weight tensor is authoritative,
        because the model must match it for ``load_state_dict`` to succeed.
        The label-mapping length and then the hard-coded fallback are used
        only when the weight is absent.
        """
        heads = (
            ('object_name', 'object_classifier'),
            ('material', 'material_classifier'),
        )
        counts = []
        for attr, head in heads:
            weight = state_dict.get(f'{head}.weight')
            if weight is not None:
                counts.append(int(weight.shape[0]))
            elif label_mappings.get(attr):
                counts.append(len(label_mappings[attr]))
            else:
                counts.append(ArtifactClassifier.FALLBACK_CLASS_COUNTS[attr])
        return tuple(counts)

    def load_model(self, model_path):
        """Load the trained model from local file or HuggingFace Hub.

        Falls back to an untrained model with default head sizes when the
        local file is missing and the Hub download fails.
        """
        # First try to load from local file
        if os.path.exists(model_path):
            print(f"Loading model from local file: {model_path}")
            return self._load_model_from_path(model_path)

        # If local file doesn't exist, try to download from HuggingFace
        print("Local model not found, downloading from HuggingFace...")
        try:
            return self._load_model_from_hub()
        except Exception as e:
            # Best effort: keep the app usable even without trained weights
            print(f"Failed to download from HuggingFace: {e}")
            print("Falling back to local model creation...")
            return self._create_model_with_defaults()

    def _load_model_from_path(self, model_path):
        """Build the model from a checkpoint file and remember its label mappings.

        Raises:
            KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        checkpoint = torch.load(model_path, map_location=self.device)

        label_mappings = checkpoint.get('label_mappings', {}) or {}
        if not label_mappings.get('object_name'):
            print("Warning: No label mappings found in checkpoint, "
                  "sizing heads from checkpoint weights")

        state_dict = checkpoint['model_state_dict']
        num_object_classes, num_material_classes = self._infer_num_classes(
            state_dict, label_mappings
        )

        model = MultiOutputModel(num_object_classes, num_material_classes)
        model.load_state_dict(state_dict)

        # Keep the mappings that belong to these weights for decoding later
        self._checkpoint_label_mappings = label_mappings
        return model

    def _load_model_from_hub(self):
        """Download and load model from HuggingFace Hub.

        Raises:
            Exception: Any import, download or load failure is logged and re-raised.
        """
        try:
            from huggingface_hub import hf_hub_download

            print("Downloading model from HuggingFace Hub...")
            model_file = hf_hub_download(
                repo_id=self.HUB_REPO_ID,
                filename=self.HUB_FILENAME,
            )
            print(f"Model downloaded to: {model_file}")
            return self._load_model_from_path(model_file)
        except Exception as e:
            print(f"Error downloading from HuggingFace: {e}")
            raise

    def _create_model_with_defaults(self):
        """Create model with default parameters when no model is available."""
        print("Creating model with default parameters...")
        print("Note: This model won't have the trained weights!")
        return MultiOutputModel(
            self.FALLBACK_CLASS_COUNTS['object_name'],
            self.FALLBACK_CLASS_COUNTS['material'],
        )

    def load_label_mappings(self):
        """Load label mappings for decoding predictions.

        Mappings captured from the checkpoint that supplied the model weights
        are preferred. Otherwise the configured local checkpoint is read, and
        finally the HuggingFace Hub checkpoint.

        Returns:
            ``{attr: {class_index: label}}``; an empty dict if nothing could be loaded.
        """
        cached = getattr(self, '_checkpoint_label_mappings', None)
        if cached:
            return self._invert_mappings(cached)

        # Use the path given to __init__ rather than a hard-coded location
        model_path = getattr(self, 'model_path', self.DEFAULT_MODEL_PATH)
        if os.path.exists(model_path):
            try:
                checkpoint = torch.load(model_path, map_location='cpu')
                return self._invert_mappings(checkpoint.get('label_mappings', {}))
            except Exception as e:
                print(f"Could not load local label mappings: {e}")

        # Try to download from HuggingFace
        try:
            print("Downloading label mappings from HuggingFace...")
            from huggingface_hub import hf_hub_download

            mappings_file = hf_hub_download(
                repo_id=self.HUB_REPO_ID,
                filename=self.HUB_FILENAME,  # Contains the mappings
            )
            checkpoint = torch.load(mappings_file, map_location='cpu')
            reverse_mappings = self._invert_mappings(checkpoint.get('label_mappings', {}))
            print(f"Loaded {len(reverse_mappings)} label mappings from HuggingFace")
            return reverse_mappings
        except Exception as e:
            print(f"Could not load label mappings from HuggingFace: {e}")
            return {}

    def predict(self, image):
        """Make prediction on uploaded image.

        Args:
            image: A PIL image, a numpy array, or anything ``Image.open`` accepts.

        Returns:
            ``{attr: {'prediction', 'confidence', 'class_id'}}`` on success,
            or ``{'error': message}`` on failure (including a missing image).
        """
        if image is None:
            return {"error": "No image provided"}

        try:
            # Convert to PIL Image if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            elif not isinstance(image, Image.Image):
                image = Image.open(image)
            # Always force 3 channels: RGBA/L/P inputs would otherwise break
            # the 3-channel normalisation below.
            image = image.convert('RGB')

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)

            results = {}
            for attr in ('object_name', 'material'):
                if attr not in outputs:
                    continue

                probs = torch.softmax(outputs[attr], dim=1)
                confidence, predicted_idx = torch.max(probs, dim=1)
                pred_class = predicted_idx.item()

                # Fall back to a synthetic name when no label is known
                attr_labels = self.label_mappings.get(attr, {})
                pred_label = attr_labels.get(pred_class, f"Class_{pred_class}")

                results[attr] = {
                    'prediction': pred_label,
                    'confidence': confidence.item(),
                    'class_id': pred_class,
                }
            return results
        except Exception as e:
            # Report failures as data so the UI callback can display them
            return {"error": str(e)}
# Shared classifier instance; stays None until create_interface() loads it
classifier = None


def classify_image(image):
    """Gradio callback that renders the classifier's prediction as plain text.

    Returns an error message string when no classifier has been loaded, when
    the classifier reports an error, or when formatting the result fails.
    """
    if classifier is None:
        return "Error: Model not loaded. Please restart the app."

    try:
        results = classifier.predict(image)
        if "error" in results:
            return f"Prediction failed: {results['error']}"

        # Collect the report piece by piece and join once at the end
        pieces = ["PREDICTION RESULTS\n\n"]
        for attr, result in results.items():
            if result['confidence'] > 0.5:
                status = "OK"
            else:
                status = "LOW"
            pieces.append(f"{status} {attr.upper()}: {result['prediction']}\n")
            pieces.append(f" Confidence: {result['confidence']:.3f}\n")
            pieces.append(f" Class ID: {result['class_id']}\n\n")

        # Mean confidence across every predicted attribute
        scores = [entry['confidence'] for entry in results.values()]
        avg_confidence = sum(scores) / len(scores)
        pieces.append(f"Average Confidence: {avg_confidence:.3f}")
        return "".join(pieces)
    except Exception as e:
        return f"Error during prediction: {str(e)}"
def create_interface():
    """Create and launch the Gradio interface.

    Loads the classifier into the module-level ``classifier`` variable, builds
    the ``gr.Interface`` around ``classify_image`` and starts the server on
    ``0.0.0.0:7860``. If the model cannot be loaded the failure is printed
    and the function returns without launching anything.
    """
    global classifier

    # Initialize classifier
    try:
        print("Loading model...")
        classifier = ArtifactClassifier()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # Only offer example images that are actually present, so a missing
    # example file is never handed to Gradio.
    example_files = ("example_artifact.jpg",)
    available_examples = [[path] for path in example_files if os.path.exists(path)]

    # Create interface
    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Artifact Image"),
        outputs=gr.Textbox(label="Classification Results", lines=10),
        title="Artifact Classification",
        description="""
Upload an image of an archaeological artifact to get AI-powered classification!
Features:
- Object type identification (coin, vase, statue, etc.)
- Material classification (gold, silver, pottery, etc.)
- Confidence scores for each prediction
- GPU-accelerated processing (if available)
- Auto-downloads model from HuggingFace Hub
- Completely standalone - no training code needed
Supported formats: JPG, PNG, JPEG
""",
        article="""
How to use:
1. Click "Upload Artifact Image" to select an image
2. Click "Submit" to run classification
3. View results with confidence scores
Model trained on: British Museum artifact dataset
Accuracy: ~71% for objects, ~62% for materials
""",
        examples=available_examples or None,
    )

    # Launch
    print("Starting Gradio interface...")
    interface.launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,
        share=False,  # Set to True for public link
        debug=False,
    )


if __name__ == "__main__":
    create_interface()