SpyC0der77 committed on
Commit
b6a884b
·
verified ·
1 Parent(s): 5a8efe2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -258
app.py CHANGED
@@ -1,29 +1,15 @@
1
- #!/usr/bin/env python3
2
- """
3
- Gradio web interface for artifact classification
4
- """
5
-
6
- import os
7
- # Fix SSL issue on Windows
8
- os.environ['SSL_CERT_FILE'] = ''
9
-
10
  import gradio as gr
11
  import torch
12
- import torch.nn as nn
13
- from torchvision import transforms
14
  from PIL import Image
15
- import numpy as np
16
- import os
17
- import json
18
- from pathlib import Path
19
-
20
- # Define the model architecture directly (standalone)
21
- import torch
22
  import torch.nn as nn
23
  from torchvision import models
 
 
 
24
 
25
  class MultiOutputModel(nn.Module):
26
- """Multi-output model for artifact classification"""
27
 
28
  def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
29
  super(MultiOutputModel, self).__init__()
@@ -34,21 +20,17 @@ class MultiOutputModel(nn.Module):
34
  self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
35
 
36
  # Freeze early layers for transfer learning
37
- for param in list(self.backbone.parameters())[:-2]:
38
  param.requires_grad = False
39
 
40
  # Classification heads for each attribute
41
  self.object_classifier = nn.Linear(2048, num_object_classes)
42
  self.material_classifier = nn.Linear(2048, num_material_classes)
43
 
44
- # Dropout for regularization
45
- self.dropout = nn.Dropout(0.3)
46
-
47
  def forward(self, x):
48
  # Extract features using backbone
49
  features = self.backbone(x)
50
  features = features.view(features.size(0), -1)
51
- features = self.dropout(features)
52
 
53
  # Get predictions for each attribute
54
  object_pred = self.object_classifier(features)
@@ -59,270 +41,211 @@ class MultiOutputModel(nn.Module):
59
  'material': material_pred,
60
  }
61
 
62
- print("MultiOutputModel class defined directly in app (standalone)")
63
 
64
- class ArtifactClassifier:
65
- def __init__(self, model_path="train/outputs/best_model.pth"):
66
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
- print(f"Using device: {self.device}")
68
 
69
- # Try to load from local file first, then from HuggingFace
70
- self.model = self.load_model(model_path)
71
- self.model.to(self.device)
72
- self.model.eval()
73
 
74
- # Set up transforms (same as training)
75
- self.transform = transforms.Compose([
76
- transforms.Resize((224, 224)),
77
- transforms.ToTensor(),
78
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
79
- ])
80
 
81
- # Load label mappings if available
82
- self.label_mappings = self.load_label_mappings()
83
- print("Model loaded successfully!")
 
84
 
85
- def load_model(self, model_path):
86
- """Load the trained model from local file or HuggingFace Hub"""
87
- # First try to load from local file
88
- if os.path.exists(model_path):
89
- print(f"Loading model from local file: {model_path}")
90
- return self._load_model_from_path(model_path)
91
-
92
- # If local file doesn't exist, try to download from HuggingFace
93
- print(f"Local model not found, downloading from HuggingFace...")
94
- try:
95
- return self._load_model_from_hub()
96
- except Exception as e:
97
- print(f"Failed to download from HuggingFace: {e}")
98
- print("Falling back to local model creation...")
99
- return self._create_model_with_defaults()
100
-
101
- def _load_model_from_path(self, model_path):
102
- """Load model from local file"""
103
- checkpoint = torch.load(model_path, map_location=self.device)
104
-
105
- # Get label mappings to determine number of classes
106
- label_mappings = checkpoint.get('label_mappings', {})
107
- num_object_classes = len(label_mappings.get('object_name', {}))
108
- num_material_classes = len(label_mappings.get('material', {}))
109
-
110
- if num_object_classes == 0:
111
- print("Warning: No label mappings found, using fallback class counts")
112
- num_object_classes, num_material_classes = 1018, 192
113
-
114
- # Create model
115
- model = MultiOutputModel(num_object_classes, num_material_classes)
116
  model.load_state_dict(checkpoint['model_state_dict'])
 
 
117
 
118
- return model
 
 
 
 
119
 
120
- def _load_model_from_hub(self):
121
- """Download and load model from HuggingFace Hub"""
122
- try:
123
- from huggingface_hub import hf_hub_download
124
 
125
- print("Downloading model from HuggingFace Hub...")
126
- model_file = hf_hub_download(
127
- repo_id="SpyC0der77/artifact-classification-model",
128
- filename="best_model.pth"
129
- )
130
 
131
- print(f"Model downloaded to: {model_file}")
132
- return self._load_model_from_path(model_file)
 
 
 
133
 
134
- except Exception as e:
135
- print(f"Error downloading from HuggingFace: {e}")
136
- raise
137
 
138
- def _create_model_with_defaults(self):
139
- """Create model with default parameters when no model is available"""
140
- print("Creating model with default parameters...")
141
- print("Note: This model won't have the trained weights!")
 
 
 
 
 
 
142
 
143
- # Use default class counts
144
- num_object_classes, num_material_classes = 1018, 192
 
145
 
146
- # Create model
147
- model = MultiOutputModel(num_object_classes, num_material_classes)
148
-
149
- return model
150
-
151
- def load_label_mappings(self):
152
- """Load label mappings for decoding predictions"""
153
- # First try local model
154
- model_path = "train/outputs/best_model.pth"
155
- if os.path.exists(model_path):
156
- try:
157
- checkpoint = torch.load(model_path, map_location='cpu')
158
- mappings = checkpoint.get('label_mappings', {})
159
-
160
- # Create reverse mappings
161
- reverse_mappings = {}
162
- for attr, mapping in mappings.items():
163
- reverse_mappings[attr] = {v: k for k, v in mapping.items()}
164
-
165
- return reverse_mappings
166
- except Exception as e:
167
- print(f"Could not load local label mappings: {e}")
168
-
169
- # Try to download from HuggingFace
170
- try:
171
- print("Downloading label mappings from HuggingFace...")
172
- from huggingface_hub import hf_hub_download
173
-
174
- mappings_file = hf_hub_download(
175
- repo_id="SpyC0der77/artifact-classification-model",
176
- filename="best_model.pth" # Contains the mappings
177
- )
178
-
179
- checkpoint = torch.load(mappings_file, map_location='cpu')
180
- mappings = checkpoint.get('label_mappings', {})
181
-
182
- # Create reverse mappings
183
- reverse_mappings = {}
184
- for attr, mapping in mappings.items():
185
- reverse_mappings[attr] = {v: k for k, v in mapping.items()}
186
-
187
- print(f"Loaded {len(reverse_mappings)} label mappings from HuggingFace")
188
- return reverse_mappings
189
-
190
- except Exception as e:
191
- print(f"Could not load label mappings from HuggingFace: {e}")
192
-
193
- return {}
194
-
195
- def predict(self, image):
196
- """Make prediction on uploaded image"""
197
- try:
198
- # Convert to PIL Image if needed
199
- if isinstance(image, np.ndarray):
200
- image = Image.fromarray(image).convert('RGB')
201
- elif not isinstance(image, Image.Image):
202
- image = Image.open(image).convert('RGB')
203
-
204
- # Apply transforms
205
- image_tensor = self.transform(image).unsqueeze(0).to(self.device)
206
-
207
- # Make prediction
208
- with torch.no_grad():
209
- outputs = self.model(image_tensor)
210
-
211
- # Process results
212
- results = {}
213
- for attr in ['object_name', 'material']:
214
- if attr in outputs:
215
- # Get probabilities and prediction
216
- probs = torch.softmax(outputs[attr], dim=1)
217
- confidence, predicted_idx = torch.max(probs, dim=1)
218
-
219
- pred_class = predicted_idx.item()
220
- conf = confidence.item()
221
-
222
- # Convert to label name
223
- if attr in self.label_mappings and pred_class in self.label_mappings[attr]:
224
- pred_label = self.label_mappings[attr][pred_class]
225
- else:
226
- pred_label = f"Class_{pred_class}"
227
-
228
- results[attr] = {
229
- 'prediction': pred_label,
230
- 'confidence': conf,
231
- 'class_id': pred_class
232
- }
233
-
234
- return results
235
-
236
- except Exception as e:
237
- return {"error": str(e)}
238
-
239
- # Global classifier instance
240
- classifier = None
241
-
242
- def classify_image(image):
243
- """Gradio interface function"""
244
- global classifier
245
-
246
- if classifier is None:
247
- return "Error: Model not loaded. Please restart the app."
248
 
249
- try:
250
- results = classifier.predict(image)
251
 
252
- if "error" in results:
253
- return f"Prediction failed: {results['error']}"
254
 
255
- # Format results
256
- output = "PREDICTION RESULTS\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
- for attr, result in results.items():
259
- status = "OK" if result['confidence'] > 0.5 else "LOW"
260
- output += f"{status} {attr.upper()}: {result['prediction']}\n"
261
- output += f" Confidence: {result['confidence']:.3f}\n"
262
- output += f" Class ID: {result['class_id']}\n\n"
263
 
264
- # Overall confidence
265
- confidences = [r['confidence'] for r in results.values()]
266
- avg_confidence = sum(confidences) / len(confidences)
267
- output += f"Average Confidence: {avg_confidence:.3f}"
268
 
269
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  except Exception as e:
272
- return f"Error during prediction: {str(e)}"
273
 
274
- def create_interface():
275
- """Create and launch the Gradio interface"""
276
- global classifier
 
 
 
 
 
 
 
 
 
 
277
 
278
- # Initialize classifier
279
  try:
280
- print("Loading model...")
281
- classifier = ArtifactClassifier()
282
  print("Model loaded successfully!")
 
 
283
  except Exception as e:
284
- print(f"Failed to load model: {e}")
285
- return
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- # Create interface
288
- interface = gr.Interface(
289
- fn=classify_image,
290
- inputs=gr.Image(type="pil", label="Upload Artifact Image"),
291
- outputs=gr.Textbox(label="Classification Results", lines=10),
292
- title="Artifact Classification",
293
- description="""
294
- Upload an image of an archaeological artifact to get AI-powered classification!
295
-
296
- Features:
297
- - Object type identification (coin, vase, statue, etc.)
298
- - Material classification (gold, silver, pottery, etc.)
299
- - Confidence scores for each prediction
300
- - GPU-accelerated processing (if available)
301
- - Auto-downloads model from HuggingFace Hub
302
- - Completely standalone - no training code needed
303
-
304
- Supported formats: JPG, PNG, JPEG
305
- """,
306
- article="""
307
- How to use:
308
- 1. Upload an artifact image using the file picker
309
- 2. Click "Submit" to run classification
310
- 3. View results with confidence scores and predictions
311
-
312
- Model trained on: British Museum artifact dataset
313
- Accuracy: ~71% for objects, ~62% for materials
314
- """,
315
- examples=[]
316
  )
317
 
318
- # Launch
319
- print("Starting Gradio interface...")
320
- interface.launch(
321
- server_name="0.0.0.0", # Allow external connections
322
- server_port=7860,
323
- share=False, # Set to True for public link
324
- debug=False
 
 
325
  )
326
 
 
 
 
 
 
 
 
 
 
327
  if __name__ == "__main__":
328
- create_interface()
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
 
 
3
  from PIL import Image
4
+ import torchvision.transforms as transforms
 
 
 
 
 
 
5
  import torch.nn as nn
6
  from torchvision import models
7
+ from typing import Dict, Tuple
8
+ import os
9
+
10
 
11
  class MultiOutputModel(nn.Module):
12
+ """Multi-output model for artifact classification (matches UI)"""
13
 
14
  def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
15
  super(MultiOutputModel, self).__init__()
 
20
  self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
21
 
22
  # Freeze early layers for transfer learning
23
+ for param in list(self.backbone.parameters())[:-4]: # Unfreeze more layers for better fine-tuning
24
  param.requires_grad = False
25
 
26
  # Classification heads for each attribute
27
  self.object_classifier = nn.Linear(2048, num_object_classes)
28
  self.material_classifier = nn.Linear(2048, num_material_classes)
29
 
 
 
 
30
  def forward(self, x):
31
  # Extract features using backbone
32
  features = self.backbone(x)
33
  features = features.view(features.size(0), -1)
 
34
 
35
  # Get predictions for each attribute
36
  object_pred = self.object_classifier(features)
 
41
  'material': material_pred,
42
  }
43
 
 
44
 
 
 
 
 
45
 
 
 
 
 
46
 
47
def _head_size(state_dict, weight_key, default):
    """Return the output width of a linear head stored in *state_dict*, else *default*."""
    weight = state_dict.get(weight_key) if state_dict else None
    return default if weight is None else int(weight.shape[0])


def load_model(model_path: str) -> Tuple[torch.nn.Module, Dict[str, Dict[int, str]]]:
    """Load a MultiOutputModel checkpoint together with its id-to-label tables.

    The checkpoint is read as a dict that may contain ``'label_mappings'``
    (attribute name -> {label: class id}) and ``'model_state_dict'``
    (weights for ``MultiOutputModel``).

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.

    Returns:
        ``(model, reverse_mappings)``. ``reverse_mappings`` maps each attribute
        name to ``{class id: label}`` and is empty when the checkpoint carries
        no label mappings. The model is left on CPU; weights are only loaded
        when the checkpoint has a ``'model_state_dict'`` entry.
    """
    print(f"Loading model from {model_path}...")
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")

    label_mappings = checkpoint.get('label_mappings', {})
    state_dict = checkpoint.get('model_state_dict')

    # The label tables determine how wide each classification head must be.
    num_object_classes = len(label_mappings.get('object_name', {}))
    num_material_classes = len(label_mappings.get('material', {}))

    if num_object_classes == 0:
        print("Warning: No label mappings found, using fallback class counts")
        # Prefer the head sizes actually stored in the checkpoint so that
        # load_state_dict cannot fail on a shape mismatch; the historical
        # constants are used only when the weights are missing too.
        num_object_classes = _head_size(state_dict, 'object_classifier.weight', 1018)
        num_material_classes = _head_size(state_dict, 'material_classifier.weight', 192)

    # Only the v1 architecture (MultiOutputModel with ResNet backbone) is supported.
    print("Loading v1 model (MultiOutputModel) with ResNet backbone")
    model = MultiOutputModel(num_object_classes, num_material_classes)

    if state_dict is not None:
        model.load_state_dict(state_dict)
    else:
        # Best effort: the caller still gets a model, but with untrained heads.
        print("Warning: No model_state_dict found in checkpoint")

    # Invert label -> id into id -> label for decoding predictions.
    reverse_mappings = {}
    for attr, mapping in label_mappings.items():
        reverse_mappings[attr] = {int(idx): str(label) for label, idx in mapping.items()}
        print(f"Loaded {attr} mappings: {len(reverse_mappings[attr])} classes")

    return model, reverse_mappings
 
 
 
82
 
 
 
 
 
 
83
 
84
def run_inference(model: torch.nn.Module, pixel_values: torch.Tensor, device: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the model once and return top-1 class ids and confidences per head.

    Args:
        model: Module whose forward pass returns a dict with ``'object_name'``
            and ``'material'`` logits.
        pixel_values: Input batch forwarded to the model.
        device: Device the model and the batch are moved to before the pass.

    Returns:
        ``(object ids, object confidences, material ids, material confidences)``,
        all moved back to CPU. Confidences are the maximum softmax probability
        along the last dimension.

    Raises:
        ValueError: If the model output is not a dict or lacks either key.
    """
    model.eval()
    model.to(device)
    batch = pixel_values.to(device)

    with torch.no_grad():
        outputs = model(batch)

    # Reject anything that is not the two-headed dict format up front.
    if not isinstance(outputs, dict):
        raise ValueError("Expected dict output with 'object_name' and 'material' keys")
    if 'object_name' not in outputs or 'material' not in outputs:
        raise ValueError("Expected 'object_name' and 'material' in model outputs")

    results = []
    for head in ('object_name', 'material'):
        logits = outputs[head]
        class_ids = torch.argmax(logits, dim=-1)
        top_probs = torch.softmax(logits, dim=-1).max(dim=-1).values
        results.extend((class_ids.cpu(), top_probs.cpu()))

    return tuple(results)
 
113
 
 
 
114
 
115
+ # Global variables for model and label mappings
116
+ model = None
117
+ label_mappings = None
118
+ device = None
119
+
120
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized single-item batch for the model.

    The image is converted to RGB, resized with ``Resize(256)``, center-cropped
    to 224, converted to a tensor and normalized per channel (the mean/std
    constants are presumably the usual ImageNet statistics -- confirm against
    the training pipeline).

    Args:
        image: Image to preprocess.

    Returns:
        The transformed tensor with a leading batch dimension added.
    """
    rgb_image = image.convert('RGB')

    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # unsqueeze(0) adds the batch dimension the model's forward pass receives.
    return pipeline(rgb_image).unsqueeze(0)
135
+
136
def predict_artifact(image: Image.Image) -> Tuple[str, float, str, float]:
    """Predict the object type and material of the artifact in *image*.

    Reads the module-level ``model``, ``label_mappings`` and ``device`` that
    ``load_model_on_startup`` populates.

    Args:
        image: PIL image to classify.

    Returns:
        ``(object label, object confidence, material label, material confidence)``.
        When no label is known for a predicted class id, the label falls back
        to ``"class_<id>"``.

    Raises:
        ValueError: If the model has not been loaded.
    """
    if model is None:
        raise ValueError("Model not loaded. Please restart the application.")

    pixel_values = preprocess_image(image)
    preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)

    # The batch holds exactly one image, so index 0 is the only prediction.
    object_pred_id = preds_obj[0].item()
    material_pred_id = preds_mat[0].item()
    object_conf = confs_obj[0].item()
    material_conf = confs_mat[0].item()

    # load_model returns an empty mapping dict when the checkpoint has no
    # labels; look attributes up with .get so the "class_<id>" fallback is
    # reachable instead of raising KeyError.
    mappings = label_mappings or {}
    object_name = mappings.get('object_name', {}).get(object_pred_id, f"class_{object_pred_id}")
    material_name = mappings.get('material', {}).get(material_pred_id, f"class_{material_pred_id}")

    return object_name, object_conf, material_name, material_conf
160
+
161
def gradio_predict(image):
    """Gradio callback: classify *image* and format the four output fields.

    Returns a 4-tuple of strings: object summary, material summary, object
    confidence and material confidence. When no image is supplied, or when
    prediction raises, the first field carries the message and the remaining
    three are empty strings.
    """
    if image is None:
        return "Please upload an image", "", "", ""

    try:
        obj_label, obj_score, mat_label, mat_score = predict_artifact(image)

        # Markdown summaries for the result panels, raw scores for the accordion.
        obj_summary = f"**{obj_label}** ({obj_score:.1%} confidence)"
        mat_summary = f"**{mat_label}** ({mat_score:.1%} confidence)"
        return obj_summary, mat_summary, f"{obj_score:.3f}", f"{mat_score:.3f}"
    except Exception as exc:
        # UI boundary: surface the failure text instead of crashing the callback.
        return f"Error: {exc}", "", "", ""
177
 
178
def load_model_on_startup():
    """Populate the module-level ``model``, ``label_mappings`` and ``device``.

    The device is chosen first (CUDA when available, otherwise CPU). If the
    checkpoint file is missing, a warning is printed and the model stays
    unset; any error raised while loading is printed rather than propagated.
    """
    global model, label_mappings, device

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    checkpoint_path = "model/v1/best_model.pth"
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Model file not found at {checkpoint_path}")
        print("Please ensure the model file exists before running the application.")
        return

    try:
        model, label_mappings = load_model(checkpoint_path)
        print("Model loaded successfully!")
        for title, attr in (("Object", 'object_name'), ("Material", 'material')):
            print(f"{title} classes: {len(label_mappings.get(attr, {}))}")
    except Exception as err:
        # Startup boundary: report the failure and leave the app running.
        print(f"Error loading model: {err}")
199
+
200
+ # Load model on startup
201
+ load_model_on_startup()
202
+
203
+ # Create Gradio interface
204
+ with gr.Blocks(title="Artifact Classification v1", theme=gr.themes.Soft()) as demo:
205
+ gr.Markdown("# 🏺 Artifact Classification Model v1")
206
+ gr.Markdown("Upload an image of an artifact to classify its **object type** and **material composition**.")
207
+
208
+ with gr.Row():
209
+ with gr.Column():
210
+ image_input = gr.Image(label="Upload Artifact Image", type="pil")
211
+ submit_btn = gr.Button("🔍 Classify Artifact", variant="primary")
212
 
213
+ with gr.Column():
214
+ gr.Markdown("### 📊 Classification Results")
215
+
216
+ object_output = gr.Markdown(label="**Object Type**")
217
+ material_output = gr.Markdown(label="**Material**")
218
+
219
+ with gr.Accordion("📈 Confidence Scores", open=False):
220
+ object_conf = gr.Textbox(label="Object Confidence", interactive=False)
221
+ material_conf = gr.Textbox(label="Material Confidence", interactive=False)
222
+
223
+ # Connect the interface
224
+ submit_btn.click(
225
+ fn=gradio_predict,
226
+ inputs=image_input,
227
+ outputs=[object_output, material_output, object_conf, material_conf]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  )
229
 
230
+ # Example images
231
+ gr.Examples(
232
+ examples=[
233
+ # You can add example image paths here if available
234
+ ],
235
+ inputs=image_input,
236
+ outputs=[object_output, material_output, object_conf, material_conf],
237
+ fn=gradio_predict,
238
+ cache_examples=False
239
  )
240
 
241
+ gr.Markdown("""
242
+ ### ℹ️ About
243
+ This model uses a ResNet-50 backbone to classify museum artifacts into object types (vase, statue, pottery, etc.)
244
+ and material compositions (ceramic, bronze, stone, etc.).
245
+
246
+ **Model**: MultiOutputModel with ResNet-50 backbone
247
+ **Training Data**: Oriental Museum artifacts dataset
248
+ """)
249
+
250
  if __name__ == "__main__":
251
+ demo.launch()