Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,22 +17,49 @@ import os
|
|
| 17 |
import json
|
| 18 |
from pathlib import Path
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
import
|
| 22 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
import train
|
| 31 |
-
MultiOutputModel = train.MultiOutputModel
|
| 32 |
-
except ImportError as e:
|
| 33 |
-
print(f"Import error: {e}")
|
| 34 |
-
print("Make sure train.py exists in the train/ directory")
|
| 35 |
-
sys.exit(1)
|
| 36 |
|
| 37 |
class ArtifactClassifier:
|
| 38 |
def __init__(self, model_path="train/outputs/best_model.pth"):
|
|
@@ -270,8 +297,9 @@ def create_interface():
|
|
| 270 |
- Object type identification (coin, vase, statue, etc.)
|
| 271 |
- Material classification (gold, silver, pottery, etc.)
|
| 272 |
- Confidence scores for each prediction
|
| 273 |
-
- GPU-accelerated processing (
|
| 274 |
- Auto-downloads model from HuggingFace Hub
|
|
|
|
| 275 |
|
| 276 |
Supported formats: JPG, PNG, JPEG
|
| 277 |
""",
|
|
|
|
| 17 |
import json
|
| 18 |
from pathlib import Path
|
| 19 |
|
| 20 |
+
# Define the model architecture directly (standalone)
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from torchvision import models
|
| 24 |
+
|
| 25 |
+
class MultiOutputModel(nn.Module):
|
| 26 |
+
"""Multi-output model for artifact classification"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
|
| 29 |
+
super(MultiOutputModel, self).__init__()
|
| 30 |
+
|
| 31 |
+
# Use a pre-trained ResNet as backbone
|
| 32 |
+
self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
|
| 33 |
+
# Remove the final classification layer
|
| 34 |
+
self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
|
| 35 |
+
|
| 36 |
+
# Freeze early layers for transfer learning
|
| 37 |
+
for param in list(self.backbone.parameters())[:-2]:
|
| 38 |
+
param.requires_grad = False
|
| 39 |
+
|
| 40 |
+
# Classification heads for each attribute
|
| 41 |
+
self.object_classifier = nn.Linear(2048, num_object_classes)
|
| 42 |
+
self.material_classifier = nn.Linear(2048, num_material_classes)
|
| 43 |
+
|
| 44 |
+
# Dropout for regularization
|
| 45 |
+
self.dropout = nn.Dropout(0.3)
|
| 46 |
+
|
| 47 |
+
def forward(self, x):
|
| 48 |
+
# Extract features using backbone
|
| 49 |
+
features = self.backbone(x)
|
| 50 |
+
features = features.view(features.size(0), -1)
|
| 51 |
+
features = self.dropout(features)
|
| 52 |
+
|
| 53 |
+
# Get predictions for each attribute
|
| 54 |
+
object_pred = self.object_classifier(features)
|
| 55 |
+
material_pred = self.material_classifier(features)
|
| 56 |
|
| 57 |
+
return {
|
| 58 |
+
'object_name': object_pred,
|
| 59 |
+
'material': material_pred,
|
| 60 |
+
}
|
| 61 |
|
| 62 |
+
print("MultiOutputModel class defined directly in app (standalone)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
class ArtifactClassifier:
|
| 65 |
def __init__(self, model_path="train/outputs/best_model.pth"):
|
|
|
|
| 297 |
- Object type identification (coin, vase, statue, etc.)
|
| 298 |
- Material classification (gold, silver, pottery, etc.)
|
| 299 |
- Confidence scores for each prediction
|
| 300 |
+
- GPU-accelerated processing (if available)
|
| 301 |
- Auto-downloads model from HuggingFace Hub
|
| 302 |
+
- Completely standalone - no training code needed
|
| 303 |
|
| 304 |
Supported formats: JPG, PNG, JPEG
|
| 305 |
""",
|