File size: 3,240 Bytes
a8da7b7
 
 
 
 
 
1faeebc
a1b4ce7
a8da7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1faeebc
a8da7b7
1faeebc
 
 
 
ce46e86
1faeebc
a8da7b7
1faeebc
a8da7b7
 
ce46e86
a8da7b7
1faeebc
a8da7b7
ce46e86
a8da7b7
1faeebc
a8da7b7
 
1faeebc
a8da7b7
ce46e86
a8da7b7
ce46e86
a8da7b7
 
 
ce46e86
a8da7b7
 
1faeebc
 
 
a1b4ce7
1faeebc
 
 
a1b4ce7
1faeebc
 
 
 
 
 
a1b4ce7
1faeebc
ce46e86
fe05156
1faeebc
a1b4ce7
1faeebc
 
 
 
 
 
 
 
 
 
 
 
 
a8da7b7
 
1faeebc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os

# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Small convolutional classifier that emits 10 raw class scores.

    Two 3x3 convolutions (padding 1), each followed by ReLU and a shared 2x2
    max-pool, feed two fully connected layers. ``fc1`` is sized for
    64 * 8 * 8 features per sample, i.e. a 32x32 input halved twice.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 32 -> 64 channels, spatial size preserved.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # One pooling layer instance is reused after both convolutions.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Return unnormalised class scores (no softmax) for a batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to rows of fc1's input width (64 * 8 * 8).
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# === Model Loading ===
model = SimpleCNN()
model_path = 'simple_cnn_dclr_tuned.pth'

# Restore the saved weights onto the CPU when the checkpoint file is present.
# When it is missing, only a warning is printed and the network keeps its
# freshly initialised parameters.
if not os.path.exists(model_path):
    print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
else:
    _state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(_state_dict)
    model.eval()
    print(f"Model loaded successfully from {model_path}")

# === CIFAR-10 Class Labels ===
# Position i in this list names the class reported for model output index i.
class_labels = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# === Image Preprocessing ===
# Converts a PIL image into a normalised tensor with a fixed 32x32 spatial
# size, matching the 64 * 8 * 8 feature width the network's first linear layer
# expects after two 2x2 poolings.
preprocess = transforms.Compose([
    # A bare int only scales the shorter edge and keeps the aspect ratio, so a
    # non-square upload would stay non-square (e.g. 32x48). An explicit
    # (height, width) pair forces every image to exactly 32x32.
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # NOTE(review): these look like the common ImageNet channel statistics;
    # confirm they match the normalisation used when the checkpoint was trained.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# === Inference Function ===
def inference(input_image: Image.Image) -> dict:
    """Classify one image and return a probability for every class label.

    Args:
        input_image: PIL image to classify. Any PIL mode is accepted; the
            image is converted to RGB before preprocessing.

    Returns:
        Dict mapping each entry of ``class_labels`` to its softmax
        probability as a Python float.

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise ValueError("inference() requires an image, got None.")

    if model.training:
        model.eval()

    # The first convolution takes 3 input channels and Normalize is given
    # three channel statistics, so grayscale, palette and RGBA images must be
    # brought to RGB first.
    rgb_image = input_image.convert('RGB')
    batch = preprocess(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)
        # Row 0 is the single image in the batch.
        probabilities = F.softmax(logits, dim=1)[0].tolist()

    return dict(zip(class_labels, probabilities))

# === Results Viewer Function ===
def show_results(input_image: Image.Image):
    """Run the classifier and gather saved benchmark artefacts for the UI.

    Args:
        input_image: PIL image from the Gradio image input, or ``None`` when
            the form was submitted without an image.

    Returns:
        A 4-tuple ``(preds, perf_plot, acc_plot, test_acc_text)``:
        the label-to-probability dict from ``inference``; the path
        ``"training_performance.png"`` or ``None`` if that file is absent;
        the path ``"final_test_accuracy.png"`` or ``None`` if that file is
        absent; and a human-readable accuracy line.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an internal traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    preds = inference(input_image)

    # Plots are optional: pass the path through only when the file exists.
    perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
    acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None

    # Read the accuracy figure directly rather than checking existence first,
    # so a file that disappears or cannot be read falls back to the default
    # text instead of failing the whole request.
    test_acc_text = "Final test accuracy not available."
    try:
        with open("final_test_accuracy.txt", "r", encoding="utf-8") as f:
            test_acc_value = f.read().strip()
    except OSError:
        test_acc_value = ""
    if test_acc_value:
        test_acc_text = f"Final Test Accuracy: {test_acc_value}%"

    return preds, perf_plot, acc_plot, test_acc_text

# === Gradio Interface Setup ===
example_images = []

# Components are instantiated in the same order as before: the image input
# first, then the four outputs in the order show_results returns them.
_image_input = gr.Image(type='pil', label='Upload Image')
_result_outputs = [
    gr.Label(num_top_classes=3, label='Predictions'),
    gr.Image(type='filepath', label='Training Performance'),
    gr.Image(type='filepath', label='Final Test Accuracy Plot'),
    gr.Textbox(label='Final Test Accuracy'),
]

interface = gr.Interface(
    fn=show_results,
    inputs=_image_input,
    outputs=_result_outputs,
    title='CIFAR-10 Image Classification with DCLR Optimizer',
    description='Upload an image to see predictions. Training/test plots and accuracy show benchmark results on CIFAR-10.',
    examples=example_images,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()