SpyC0der77 commited on
Commit
e337fdb
·
verified ·
1 Parent(s): 7a704d0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -0
app.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio web interface for artifact classification
4
+ """
5
+
6
+ import os
7
+ # Fix SSL issue on Windows
8
+ os.environ['SSL_CERT_FILE'] = ''
9
+
10
+ import gradio as gr
11
+ import torch
12
+ import torch.nn as nn
13
+ from torchvision import transforms
14
+ from PIL import Image
15
+ import numpy as np
16
+ import os
17
+ import json
18
+ from pathlib import Path
19
+
20
+ # Import the model architecture
21
+ import sys
22
+ import os
23
+
24
+ # Add the train directory to Python path
25
+ train_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train')
26
+ sys.path.insert(0, train_dir)
27
+
28
+ # Now we can import from train.py
29
+ try:
30
+ import train
31
+ MultiOutputModel = train.MultiOutputModel
32
+ except ImportError as e:
33
+ print(f"Import error: {e}")
34
+ print("Make sure train.py exists in the train/ directory")
35
+ sys.exit(1)
36
+
37
+ class ArtifactClassifier:
38
+ def __init__(self, model_path="train/outputs/best_model.pth"):
39
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
+ print(f"Using device: {self.device}")
41
+
42
+ # Try to load from local file first, then from HuggingFace
43
+ self.model = self.load_model(model_path)
44
+ self.model.to(self.device)
45
+ self.model.eval()
46
+
47
+ # Set up transforms (same as training)
48
+ self.transform = transforms.Compose([
49
+ transforms.Resize((224, 224)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
52
+ ])
53
+
54
+ # Load label mappings if available
55
+ self.label_mappings = self.load_label_mappings()
56
+ print("Model loaded successfully!")
57
+
58
+ def load_model(self, model_path):
59
+ """Load the trained model from local file or HuggingFace Hub"""
60
+ # First try to load from local file
61
+ if os.path.exists(model_path):
62
+ print(f"Loading model from local file: {model_path}")
63
+ return self._load_model_from_path(model_path)
64
+
65
+ # If local file doesn't exist, try to download from HuggingFace
66
+ print(f"Local model not found, downloading from HuggingFace...")
67
+ try:
68
+ return self._load_model_from_hub()
69
+ except Exception as e:
70
+ print(f"Failed to download from HuggingFace: {e}")
71
+ print("Falling back to local model creation...")
72
+ return self._create_model_with_defaults()
73
+
74
+ def _load_model_from_path(self, model_path):
75
+ """Load model from local file"""
76
+ checkpoint = torch.load(model_path, map_location=self.device)
77
+
78
+ # Get label mappings to determine number of classes
79
+ label_mappings = checkpoint.get('label_mappings', {})
80
+ num_object_classes = len(label_mappings.get('object_name', {}))
81
+ num_material_classes = len(label_mappings.get('material', {}))
82
+
83
+ if num_object_classes == 0:
84
+ print("Warning: No label mappings found, using fallback class counts")
85
+ num_object_classes, num_material_classes = 1018, 192
86
+
87
+ # Create model
88
+ model = MultiOutputModel(num_object_classes, num_material_classes)
89
+ model.load_state_dict(checkpoint['model_state_dict'])
90
+
91
+ return model
92
+
93
+ def _load_model_from_hub(self):
94
+ """Download and load model from HuggingFace Hub"""
95
+ try:
96
+ from huggingface_hub import hf_hub_download
97
+
98
+ print("Downloading model from HuggingFace Hub...")
99
+ model_file = hf_hub_download(
100
+ repo_id="SpyC0der77/artifact-classification-model",
101
+ filename="best_model.pth"
102
+ )
103
+
104
+ print(f"Model downloaded to: {model_file}")
105
+ return self._load_model_from_path(model_file)
106
+
107
+ except Exception as e:
108
+ print(f"Error downloading from HuggingFace: {e}")
109
+ raise
110
+
111
+ def _create_model_with_defaults(self):
112
+ """Create model with default parameters when no model is available"""
113
+ print("Creating model with default parameters...")
114
+ print("Note: This model won't have the trained weights!")
115
+
116
+ # Use default class counts
117
+ num_object_classes, num_material_classes = 1018, 192
118
+
119
+ # Create model
120
+ model = MultiOutputModel(num_object_classes, num_material_classes)
121
+
122
+ return model
123
+
124
+ def load_label_mappings(self):
125
+ """Load label mappings for decoding predictions"""
126
+ # First try local model
127
+ model_path = "train/outputs/best_model.pth"
128
+ if os.path.exists(model_path):
129
+ try:
130
+ checkpoint = torch.load(model_path, map_location='cpu')
131
+ mappings = checkpoint.get('label_mappings', {})
132
+
133
+ # Create reverse mappings
134
+ reverse_mappings = {}
135
+ for attr, mapping in mappings.items():
136
+ reverse_mappings[attr] = {v: k for k, v in mapping.items()}
137
+
138
+ return reverse_mappings
139
+ except Exception as e:
140
+ print(f"Could not load local label mappings: {e}")
141
+
142
+ # Try to download from HuggingFace
143
+ try:
144
+ print("Downloading label mappings from HuggingFace...")
145
+ from huggingface_hub import hf_hub_download
146
+
147
+ mappings_file = hf_hub_download(
148
+ repo_id="SpyC0der77/artifact-classification-model",
149
+ filename="best_model.pth" # Contains the mappings
150
+ )
151
+
152
+ checkpoint = torch.load(mappings_file, map_location='cpu')
153
+ mappings = checkpoint.get('label_mappings', {})
154
+
155
+ # Create reverse mappings
156
+ reverse_mappings = {}
157
+ for attr, mapping in mappings.items():
158
+ reverse_mappings[attr] = {v: k for k, v in mapping.items()}
159
+
160
+ print(f"Loaded {len(reverse_mappings)} label mappings from HuggingFace")
161
+ return reverse_mappings
162
+
163
+ except Exception as e:
164
+ print(f"Could not load label mappings from HuggingFace: {e}")
165
+
166
+ return {}
167
+
168
+ def predict(self, image):
169
+ """Make prediction on uploaded image"""
170
+ try:
171
+ # Convert to PIL Image if needed
172
+ if isinstance(image, np.ndarray):
173
+ image = Image.fromarray(image).convert('RGB')
174
+ elif not isinstance(image, Image.Image):
175
+ image = Image.open(image).convert('RGB')
176
+
177
+ # Apply transforms
178
+ image_tensor = self.transform(image).unsqueeze(0).to(self.device)
179
+
180
+ # Make prediction
181
+ with torch.no_grad():
182
+ outputs = self.model(image_tensor)
183
+
184
+ # Process results
185
+ results = {}
186
+ for attr in ['object_name', 'material']:
187
+ if attr in outputs:
188
+ # Get probabilities and prediction
189
+ probs = torch.softmax(outputs[attr], dim=1)
190
+ confidence, predicted_idx = torch.max(probs, dim=1)
191
+
192
+ pred_class = predicted_idx.item()
193
+ conf = confidence.item()
194
+
195
+ # Convert to label name
196
+ if attr in self.label_mappings and pred_class in self.label_mappings[attr]:
197
+ pred_label = self.label_mappings[attr][pred_class]
198
+ else:
199
+ pred_label = f"Class_{pred_class}"
200
+
201
+ results[attr] = {
202
+ 'prediction': pred_label,
203
+ 'confidence': conf,
204
+ 'class_id': pred_class
205
+ }
206
+
207
+ return results
208
+
209
+ except Exception as e:
210
+ return {"error": str(e)}
211
+
212
+ # Global classifier instance
213
+ classifier = None
214
+
215
+ def classify_image(image):
216
+ """Gradio interface function"""
217
+ global classifier
218
+
219
+ if classifier is None:
220
+ return "Error: Model not loaded. Please restart the app."
221
+
222
+ try:
223
+ results = classifier.predict(image)
224
+
225
+ if "error" in results:
226
+ return f"Prediction failed: {results['error']}"
227
+
228
+ # Format results
229
+ output = "PREDICTION RESULTS\n\n"
230
+
231
+ for attr, result in results.items():
232
+ status = "OK" if result['confidence'] > 0.5 else "LOW"
233
+ output += f"{status} {attr.upper()}: {result['prediction']}\n"
234
+ output += f" Confidence: {result['confidence']:.3f}\n"
235
+ output += f" Class ID: {result['class_id']}\n\n"
236
+
237
+ # Overall confidence
238
+ confidences = [r['confidence'] for r in results.values()]
239
+ avg_confidence = sum(confidences) / len(confidences)
240
+ output += f"Average Confidence: {avg_confidence:.3f}"
241
+
242
+ return output
243
+
244
+ except Exception as e:
245
+ return f"Error during prediction: {str(e)}"
246
+
247
+ def create_interface():
248
+ """Create and launch the Gradio interface"""
249
+ global classifier
250
+
251
+ # Initialize classifier
252
+ try:
253
+ print("Loading model...")
254
+ classifier = ArtifactClassifier()
255
+ print("Model loaded successfully!")
256
+ except Exception as e:
257
+ print(f"Failed to load model: {e}")
258
+ return
259
+
260
+ # Create interface
261
+ interface = gr.Interface(
262
+ fn=classify_image,
263
+ inputs=gr.Image(type="pil", label="Upload Artifact Image"),
264
+ outputs=gr.Textbox(label="Classification Results", lines=10),
265
+ title="Artifact Classification",
266
+ description="""
267
+ Upload an image of an archaeological artifact to get AI-powered classification!
268
+
269
+ Features:
270
+ - Object type identification (coin, vase, statue, etc.)
271
+ - Material classification (gold, silver, pottery, etc.)
272
+ - Confidence scores for each prediction
273
+ - GPU-accelerated processing (RTX 2060)
274
+ - Auto-downloads model from HuggingFace Hub
275
+
276
+ Supported formats: JPG, PNG, JPEG
277
+ """,
278
+ article="""
279
+ How to use:
280
+ 1. Click "Upload Artifact Image" to select an image
281
+ 2. Click "Submit" to run classification
282
+ 3. View results with confidence scores
283
+
284
+ Model trained on: British Museum artifact dataset
285
+ Accuracy: ~71% for objects, ~62% for materials
286
+ """,
287
+ examples=[
288
+ ["example_artifact.jpg"] # Add example images if available
289
+ ]
290
+ )
291
+
292
+ # Launch
293
+ print("Starting Gradio interface...")
294
+ interface.launch(
295
+ server_name="0.0.0.0", # Allow external connections
296
+ server_port=7860,
297
+ share=False, # Set to True for public link
298
+ debug=False
299
+ )
300
+
301
+ if __name__ == "__main__":
302
+ create_interface()