#!/usr/bin/env python3
"""
Gradio web interface for artifact classification.

Loads a two-headed ResNet-50 classifier (object name + material) from a local
checkpoint or the HuggingFace Hub and serves it through a Gradio Interface.
"""
import os

# Fix SSL issue on Windows. This runs before gradio / huggingface_hub are
# imported so that they see the override.
# NOTE(review): the override is applied on every platform (not only Windows)
# and blanks any SSL_CERT_FILE the user configured -- confirm that is intended.
os.environ['SSL_CERT_FILE'] = ''

import json  # noqa: E402,F401  (unused here; kept for compatibility)
from pathlib import Path  # noqa: E402,F401

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from PIL import Image  # noqa: E402
from torchvision import models, transforms  # noqa: E402

# Default local checkpoint path, tried before the HuggingFace Hub.
DEFAULT_MODEL_PATH = "train/outputs/best_model.pth"
# HuggingFace Hub location of the published checkpoint.
HF_REPO_ID = "SpyC0der77/artifact-classification-model"
HF_FILENAME = "best_model.pth"
# Head sizes used when neither the checkpoint weights nor its label mappings
# reveal them.
FALLBACK_NUM_OBJECT_CLASSES = 1018
FALLBACK_NUM_MATERIAL_CLASSES = 192
# Output heads returned by MultiOutputModel.forward, in display order.
ATTRIBUTES = ('object_name', 'material')
# Example images offered in the UI, resolved relative to the current working
# directory; only the ones that exist are passed to Gradio.
EXAMPLE_IMAGES = ("example_artifact.jpg",)

APP_DESCRIPTION = """
Upload an image of an archaeological artifact to get AI-powered classification!

Features:
- Object type identification (coin, vase, statue, etc.)
- Material classification (gold, silver, pottery, etc.)
- Confidence scores for each prediction
- GPU-accelerated processing (if available)
- Auto-downloads model from HuggingFace Hub
- Completely standalone - no training code needed

Supported formats: JPG, PNG, JPEG
"""

APP_ARTICLE = """
How to use:

1. Click "Upload Artifact Image" to select an image
2. Click "Submit" to run classification
3. View results with confidence scores

Model trained on: British Museum artifact dataset
Accuracy: ~71% for objects, ~62% for materials
"""


# Define the model architecture directly (standalone)
class MultiOutputModel(nn.Module):
    """Multi-output model for artifact classification.

    A torchvision ResNet-50 with its final fully-connected layer removed is
    shared by two linear heads: one predicts the object name and the other
    predicts the material.

    Args:
        num_object_classes: Number of logits produced by the object head.
        num_material_classes: Number of logits produced by the material head.
        hidden_size: Unused; kept for signature compatibility.
        pretrained_backbone: When True (the default), initialise the backbone
            with torchvision's ImageNet weights, which are downloaded on first
            use. Pass False when a checkpoint will overwrite every weight
            anyway.
    """

    def __init__(self, num_object_classes, num_material_classes, hidden_size=512,
                 pretrained_backbone=True):
        super().__init__()

        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained_backbone else None
        resnet = models.resnet50(weights=weights)
        # Drop the final classification layer. The remaining stack ends in
        # global average pooling, so it outputs (batch, 2048, 1, 1).
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Transfer-learning freeze: every backbone parameter except the LAST
        # two tensors is frozen (not only "early layers"). This only matters
        # for training; inference is unaffected.
        for param in list(self.backbone.parameters())[:-2]:
            param.requires_grad = False

        # Classification heads for each attribute.
        self.object_classifier = nn.Linear(2048, num_object_classes)
        self.material_classifier = nn.Linear(2048, num_material_classes)

        # Dropout for regularization (a no-op once the model is in eval mode).
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return raw logits keyed by 'object_name' and 'material'.

        Args:
            x: Image batch, assumed to be (batch, 3, H, W) and normalised with
                ImageNet statistics.
        """
        features = self.backbone(x)
        features = torch.flatten(features, 1)
        features = self.dropout(features)
        return {
            'object_name': self.object_classifier(features),
            'material': self.material_classifier(features),
        }


print("MultiOutputModel class defined directly in app (standalone)")


class ArtifactClassifier:
    """Single-image inference wrapper around MultiOutputModel.

    Args:
        model_path: Local checkpoint to try first. When it is missing, the
            checkpoint is downloaded from the HuggingFace Hub. When that also
            fails, an untrained model with default head sizes is created.
    """

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Raw label -> index mappings from the checkpoint that load_model
        # actually used. None until a checkpoint has been read.
        self._checkpoint_label_mappings = None

        # Try to load from local file first, then from HuggingFace.
        self.model = self.load_model(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Preprocessing, intended to match the training pipeline.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # index -> label mappings used to decode predictions.
        self.label_mappings = self.load_label_mappings()
        print("Model loaded successfully!")

    def load_model(self, model_path):
        """Load the trained model from a local file or the HuggingFace Hub.

        Returns an untrained default model if both sources fail.
        """
        if os.path.exists(model_path):
            print(f"Loading model from local file: {model_path}")
            return self._load_model_from_path(model_path)

        print("Local model not found, downloading from HuggingFace...")
        try:
            return self._load_model_from_hub()
        except Exception as e:
            print(f"Failed to download from HuggingFace: {e}")
            print("Falling back to local model creation...")
            return self._create_model_with_defaults()

    @staticmethod
    def _infer_num_classes(state_dict, label_mappings):
        """Return (num_object_classes, num_material_classes) for a checkpoint.

        The head weight shapes in state_dict are preferred because they are
        exactly what load_state_dict must match. Otherwise the label-mapping
        sizes are used, and finally the hard-coded fallbacks.
        """
        object_weight = state_dict.get('object_classifier.weight')
        material_weight = state_dict.get('material_classifier.weight')
        if object_weight is not None and material_weight is not None:
            return int(object_weight.shape[0]), int(material_weight.shape[0])

        num_object_classes = len(label_mappings.get('object_name', {}))
        num_material_classes = len(label_mappings.get('material', {}))
        if num_object_classes == 0 or num_material_classes == 0:
            print("Warning: No label mappings found, using fallback class counts")
            return FALLBACK_NUM_OBJECT_CLASSES, FALLBACK_NUM_MATERIAL_CLASSES
        return num_object_classes, num_material_classes

    def _load_model_from_path(self, model_path):
        """Build a model from a checkpoint file and remember its label mappings.

        The checkpoint must contain 'model_state_dict'. 'label_mappings' is
        optional.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load trusted
        # checkpoints (this includes the one downloaded from the Hub).
        checkpoint = torch.load(model_path, map_location=self.device)
        label_mappings = checkpoint.get('label_mappings', {}) or {}
        state_dict = checkpoint['model_state_dict']

        num_object_classes, num_material_classes = self._infer_num_classes(
            state_dict, label_mappings)

        # The checkpoint overwrites every weight, so skip the ImageNet download.
        model = MultiOutputModel(num_object_classes, num_material_classes,
                                 pretrained_backbone=False)
        model.load_state_dict(state_dict)

        # Decode predictions with the mappings of THIS checkpoint, not with
        # whichever file happens to sit at the default path.
        self._checkpoint_label_mappings = label_mappings
        return model

    @staticmethod
    def _download_checkpoint():
        """Fetch the published checkpoint via hf_hub_download; return its local path."""
        from huggingface_hub import hf_hub_download
        return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)

    def _load_model_from_hub(self):
        """Download and load the model from the HuggingFace Hub (re-raises on failure)."""
        try:
            print("Downloading model from HuggingFace Hub...")
            model_file = self._download_checkpoint()
            print(f"Model downloaded to: {model_file}")
            return self._load_model_from_path(model_file)
        except Exception as e:
            print(f"Error downloading from HuggingFace: {e}")
            raise

    def _create_model_with_defaults(self):
        """Create a model with default head sizes and no trained head weights."""
        print("Creating model with default parameters...")
        print("Note: This model won't have the trained weights!")
        return MultiOutputModel(FALLBACK_NUM_OBJECT_CLASSES, FALLBACK_NUM_MATERIAL_CLASSES)

    @staticmethod
    def _reverse_label_mappings(mappings):
        """Turn {attr: {label: index}} into {attr: {index: label}}."""
        return {
            attr: {index: label for label, index in mapping.items()}
            for attr, mapping in (mappings or {}).items()
        }

    def load_label_mappings(self):
        """Return {attr: {class_index: label}} used to decode predictions.

        If a checkpoint was loaded, its own mappings are used. Otherwise the
        default local checkpoint is tried, then the HuggingFace Hub. Returns
        {} if nothing can be loaded.
        """
        checkpoint_mappings = getattr(self, '_checkpoint_label_mappings', None)
        if checkpoint_mappings is not None:
            return self._reverse_label_mappings(checkpoint_mappings)

        # No checkpoint was loaded (default model): best-effort lookup.
        model_path = DEFAULT_MODEL_PATH
        if os.path.exists(model_path):
            try:
                checkpoint = torch.load(model_path, map_location='cpu')
                return self._reverse_label_mappings(checkpoint.get('label_mappings', {}))
            except Exception as e:
                print(f"Could not load local label mappings: {e}")

        try:
            print("Downloading label mappings from HuggingFace...")
            # The checkpoint file also contains the mappings.
            mappings_file = self._download_checkpoint()
            checkpoint = torch.load(mappings_file, map_location='cpu')
            reverse_mappings = self._reverse_label_mappings(
                checkpoint.get('label_mappings', {}))
            print(f"Loaded {len(reverse_mappings)} label mappings from HuggingFace")
            return reverse_mappings
        except Exception as e:
            print(f"Could not load label mappings from HuggingFace: {e}")
            return {}

    def predict(self, image):
        """Classify one image.

        Args:
            image: A PIL image, a numpy array, or anything Image.open accepts
                (path or file object).

        Returns:
            {attr: {'prediction': label, 'confidence': float, 'class_id': int}}
            for each attribute, or {'error': message} if anything fails.
        """
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            elif not isinstance(image, Image.Image):
                # Close the file handle; convert() returns a fully loaded copy.
                with Image.open(image) as opened:
                    image = opened.convert('RGB')
            # The network expects 3 channels. PIL inputs may be RGBA (PNG with
            # alpha), L (grayscale) or P (palette), and those would break
            # Normalize or the first convolution.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)

            results = {}
            for attr in ATTRIBUTES:
                if attr not in outputs:
                    continue
                probs = torch.softmax(outputs[attr], dim=1)
                confidence, predicted_idx = torch.max(probs, dim=1)
                pred_class = predicted_idx.item()

                attr_labels = self.label_mappings.get(attr, {})
                results[attr] = {
                    'prediction': attr_labels.get(pred_class, f"Class_{pred_class}"),
                    'confidence': confidence.item(),
                    'class_id': pred_class,
                }
            return results
        except Exception as e:
            return {"error": str(e)}


# Global classifier instance, set by create_interface().
classifier = None


def classify_image(image):
    """Gradio callback: classify `image` and return a formatted text report."""
    if classifier is None:
        return "Error: Model not loaded. Please restart the app."
    # Gradio passes None when Submit is clicked without an uploaded image.
    if image is None:
        return "Please upload an image first."

    try:
        results = classifier.predict(image)
        if "error" in results:
            return f"Prediction failed: {results['error']}"
        if not results:
            return "Prediction failed: model returned no outputs"

        output = "PREDICTION RESULTS\n\n"
        for attr, result in results.items():
            status = "OK" if result['confidence'] > 0.5 else "LOW"
            output += f"{status} {attr.upper()}: {result['prediction']}\n"
            output += f" Confidence: {result['confidence']:.3f}\n"
            output += f" Class ID: {result['class_id']}\n\n"

        # Overall confidence (results is non-empty here, so no division by zero).
        confidences = [r['confidence'] for r in results.values()]
        avg_confidence = sum(confidences) / len(confidences)
        output += f"Average Confidence: {avg_confidence:.3f}"
        return output
    except Exception as e:
        return f"Error during prediction: {str(e)}"


def create_interface():
    """Load the global classifier, then build and launch the Gradio interface.

    Returns without launching if the model fails to load.
    """
    global classifier

    try:
        print("Loading model...")
        classifier = ArtifactClassifier()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # Only offer example images that exist. Gradio can fail at startup when
    # an example file path is missing.
    examples = [[path] for path in EXAMPLE_IMAGES if os.path.exists(path)] or None

    interface = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil", label="Upload Artifact Image"),
        outputs=gr.Textbox(label="Classification Results", lines=10),
        title="Artifact Classification",
        description=APP_DESCRIPTION,
        article=APP_ARTICLE,
        examples=examples,
    )

    print("Starting Gradio interface...")
    interface.launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,
        share=False,  # Set to True for public link
        debug=False,
    )


if __name__ == "__main__":
    create_interface()