File size: 3,497 Bytes
a8da7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os # Import os to check for model file

# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Small two-convolution CNN producing 10 class logits.

    Architecture::

        conv(3->32, 3x3, pad 1) -> ReLU -> maxpool(2)
        conv(32->64, 3x3, pad 1) -> ReLU -> maxpool(2)
        linear(64*8*8 -> 512) -> ReLU -> linear(512 -> 10)

    The padded 3x3 convolutions keep the spatial size and each 2x2 max-pool
    halves it, so ``fc1`` (which expects 64 * 8 * 8 features) only accepts
    inputs of shape ``(batch, 3, 32, 32)``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable
        # for previously saved checkpoints to load.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, 10)``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension. A wrongly sized
        # input then fails loudly in fc1 instead of being reshaped into a
        # different number of "samples" (as view(-1, 4096) would do).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# === Model Loading ===
model = SimpleCNN()
model_path = 'simple_cnn_dclr_tuned.pth'

# Load trained weights if the checkpoint file is present. The file's contents
# are handed straight to load_state_dict, so it only needs to hold tensors;
# weights_only=True (torch >= 1.13) refuses to unpickle arbitrary objects.
if os.path.isfile(model_path):
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    print(f"Model loaded successfully from {model_path}")
else:
    # The app still starts, but predictions come from randomly initialised weights.
    print(f"Warning: Model file '{model_path}' not found. Please ensure 'train_dclr_model.py' has been run.")

# This module only runs inference, so switch to evaluation mode in both branches.
model.eval()


# === CIFAR-10 Class Labels ===
# Position i in this list names output logit i of the model.
class_labels = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# === Image Preprocessing ===
# The model's first fully connected layer only accepts 32x32 inputs. Resize to
# an explicit (height, width): a single int would scale only the shorter side
# and keep the aspect ratio, giving non-32x32 tensors for non-square images.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # NOTE(review): these are ImageNet mean/std values. They must match the
    # normalisation used during training -- confirm against train_dclr_model.py.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# === Inference Function ===
def inference(input_image: Image.Image):
    """Classify one PIL image and return per-class confidences.

    Args:
        input_image: Image from the ``gr.Image(type='pil')`` input. Gradio
            passes ``None`` when the user submits without an image.

    Returns:
        dict mapping each label in ``class_labels`` to its softmax probability
        as a Python float -- the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided (shown as a message in the UI).
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    if model.training:  # Ensure model is in eval mode
        model.eval()

    # Normalise the colour mode so the tensor always has the 3 channels that
    # conv1 and the 3-value Normalize expect (e.g. for RGBA or grayscale input).
    rgb_image = input_image.convert('RGB')
    batch = preprocess(rgb_image).unsqueeze(0)  # add batch dim -> (1, 3, H, W)

    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)[0]

    # tolist() yields plain Python floats, pairing each label with its score.
    return dict(zip(class_labels, probabilities.tolist()))

# === Gradio Interface Setup ===
# No bundled example images yet, so users upload their own. To add examples
# (e.g. on a Hugging Face Space), put image files in an 'examples/' directory
# next to this script and list their paths here.
example_images = []

# Single image in, top-3 class confidences out.
image_input = gr.Image(type='pil', label='Input Image')
top3_output = gr.Label(num_top_classes=3, label='Predictions')

interface = gr.Interface(
    fn=inference,
    inputs=image_input,
    outputs=top3_output,
    title="CIFAR-10 Image Classification with DCLR Optimizer",
    description=(
        "Upload an image and see the model's predictions "
        "using a SimpleCNN trained with the DCLR optimizer."
    ),
    examples=example_images,
    allow_flagging="never",
)

# === Launch Gradio App ===
if __name__ == "__main__":
    interface.launch()