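"""Gradio demo for artifact classification.

Loads a MultiOutputModel checkpoint (ResNet-50 backbone) from model.pth and
predicts an artifact's object type and material composition from an uploaded image.
"""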
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import models
from typing import Dict, Tuple
import os


class MultiOutputModel(nn.Module):
    """Multi-output model for artifact classification (matches UI)"""

    def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
        super(MultiOutputModel, self).__init__()

        # Use a pre-trained ResNet as backbone
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Remove the final classification layer
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])

        # Freeze early layers for transfer learning
        for param in list(self.backbone.parameters())[:-4]:  # freeze all but the last few backbone parameter tensors
            param.requires_grad = False

        # Classification heads for each attribute
        self.object_classifier = nn.Linear(2048, num_object_classes)
        self.material_classifier = nn.Linear(2048, num_material_classes)

    def forward(self, x):
        # Extract features using backbone
        features = self.backbone(x)
        features = features.view(features.size(0), -1)

        # Get predictions for each attribute
        object_pred = self.object_classifier(features)
        material_pred = self.material_classifier(features)

        return {
            'object_name': object_pred,
            'material': material_pred,
        }




def load_model(model_path: str) -> Tuple[torch.nn.Module, Dict[str, Dict[int, str]]]:
    """Load the model from checkpoint and return model and label mappings."""
    print(f"Loading model from {model_path}...")
    checkpoint = torch.load(model_path, map_location="cpu")

    # Get label mappings to determine number of classes
    label_mappings = checkpoint.get('label_mappings', {})
    num_object_classes = len(label_mappings.get('object_name', {}))
    num_material_classes = len(label_mappings.get('material', {}))

    if num_object_classes == 0:
        print("Warning: No label mappings found, using fallback class counts")
        num_object_classes, num_material_classes = 1018, 192

    # Only the v1 architecture (MultiOutputModel with a ResNet backbone) is supported
    print("Loading v1 model (MultiOutputModel) with ResNet backbone")
    model = MultiOutputModel(num_object_classes, num_material_classes)

    # Load state dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Warning: No model_state_dict found in checkpoint")

    # Create reverse mappings (id2label)
    reverse_mappings = {}
    for attr, mapping in label_mappings.items():
        reverse_mappings[attr] = {int(v): str(k) for k, v in mapping.items()}
        print(f"Loaded {attr} mappings: {len(reverse_mappings[attr])} classes")

    return model, reverse_mappings


def run_inference(model: torch.nn.Module, pixel_values: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run inference on pixel_values and return predictions and confidences for both object_name and material."""
    model.eval()
    model.to(device)
    pixel_values = pixel_values.to(device)

    with torch.no_grad():
        outputs = model(pixel_values)

        # Handle different output formats
        if isinstance(outputs, dict):
            # Multi-output model format
            if 'object_name' in outputs and 'material' in outputs:
                logits_obj = outputs['object_name']
                logits_mat = outputs['material']
            else:
                raise ValueError("Expected 'object_name' and 'material' in model outputs")
        else:
            raise ValueError("Expected dict output with 'object_name' and 'material' keys")

        preds_obj = torch.argmax(logits_obj, dim=-1)
        probs_obj = torch.softmax(logits_obj, dim=-1)
        max_probs_obj = torch.max(probs_obj, dim=-1)[0]

        preds_mat = torch.argmax(logits_mat, dim=-1)
        probs_mat = torch.softmax(logits_mat, dim=-1)
        max_probs_mat = torch.max(probs_mat, dim=-1)[0]

    return preds_obj.cpu(), max_probs_obj.cpu(), preds_mat.cpu(), max_probs_mat.cpu()


# Global variables for model and label mappings
model = None
label_mappings = None
device = None

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Preprocess image for model inference."""
    # Standard ImageNet preprocessing (matches the ResNet-50 backbone's pretraining)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Apply transforms
    image = image.convert('RGB')
    tensor = transform(image).unsqueeze(0)  # Add batch dimension

    return tensor

def predict_artifact(image: Image.Image) -> Tuple[str, float, str, float]:
    """Predict object and material from image."""
    global model, label_mappings, device

    if model is None:
        raise ValueError("Model not loaded. Please restart the application.")

    # Preprocess image
    pixel_values = preprocess_image(image)

    # Run inference
    preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)

    # Get predictions
    object_pred_id = preds_obj[0].item()
    material_pred_id = preds_mat[0].item()
    object_conf = confs_obj[0].item()
    material_conf = confs_mat[0].item()

    # Convert IDs to labels
    object_name = label_mappings['object_name'].get(object_pred_id, f"class_{object_pred_id}")
    material_name = label_mappings['material'].get(material_pred_id, f"class_{material_pred_id}")

    return object_name, object_conf, material_name, material_conf

def gradio_predict(image):
    """Gradio interface function."""
    if image is None:
        return "Please upload an image", "", "", ""

    try:
        object_name, object_conf, material_name, material_conf = predict_artifact(image)

        # Format results
        object_result = f"**{object_name}** ({object_conf:.1%} confidence)"
        material_result = f"**{material_name}** ({material_conf:.1%} confidence)"

        return object_result, material_result, f"{object_conf:.3f}", f"{material_conf:.3f}"

    except Exception as e:
        return f"Error: {str(e)}", "", "", ""

def load_model_on_startup():
    """Load model when the application starts."""
    global model, label_mappings, device

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model from model.pth
    model_path = "model.pth"
    if not os.path.exists(model_path):
        print(f"Warning: Model file not found at {model_path}")
        print("Please ensure the model.pth file exists in the current directory before running the application.")
        return

    try:
        model, label_mappings = load_model(model_path)
        print("Model loaded successfully!")
        print(f"Object classes: {len(label_mappings.get('object_name', {}))}")
        print(f"Material classes: {len(label_mappings.get('material', {}))}")
    except Exception as e:
        print(f"Error loading model: {e}")

# Load model on startup
load_model_on_startup()

# Create Gradio interface
with gr.Blocks(title="Artifact Classification v1", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏺 Artifact Classification Model v1")
    gr.Markdown("Upload an image of an artifact to classify its **object type** and **material composition**.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Artifact Image", type="pil")
            submit_btn = gr.Button("🔍 Classify Artifact", variant="primary")

        with gr.Column():
            gr.Markdown("### 📊 Classification Results")

            object_output = gr.Markdown(label="**Object Type**")
            material_output = gr.Markdown(label="**Material**")

            with gr.Accordion("📈 Confidence Scores", open=False):
                object_conf = gr.Textbox(label="Object Confidence", interactive=False)
                material_conf = gr.Textbox(label="Material Confidence", interactive=False)

    # Connect the interface
    submit_btn.click(
        fn=gradio_predict,
        inputs=image_input,
        outputs=[object_output, material_output, object_conf, material_conf]
    )

    # Example images
    gr.Examples(
        examples=[
            # You can add example image paths here if available
        ],
        inputs=image_input,
        outputs=[object_output, material_output, object_conf, material_conf],
        fn=gradio_predict,
        cache_examples=False
    )

    gr.Markdown("""
    ### ℹ️ About
    This model uses a ResNet-50 backbone to classify museum artifacts into object types (vase, statue, pottery, etc.)
    and material compositions (ceramic, bronze, stone, etc.).

    **Model**: MultiOutputModel with ResNet-50 backbone

    **Training Data**: Oriental Museum artifacts dataset
    """)

if __name__ == "__main__":
    demo.launch()