Spaces:
Running
Running
File size: 3,240 Bytes
a8da7b7 1faeebc a1b4ce7 a8da7b7 1faeebc a8da7b7 1faeebc ce46e86 1faeebc a8da7b7 1faeebc a8da7b7 ce46e86 a8da7b7 1faeebc a8da7b7 ce46e86 a8da7b7 1faeebc a8da7b7 1faeebc a8da7b7 ce46e86 a8da7b7 ce46e86 a8da7b7 ce46e86 a8da7b7 1faeebc a1b4ce7 1faeebc a1b4ce7 1faeebc a1b4ce7 1faeebc ce46e86 fe05156 1faeebc a1b4ce7 1faeebc a8da7b7 1faeebc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os
# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Small two-stage CNN for 32x32 RGB images (CIFAR-10 sized).

    Each stage is a 3x3 conv (padding=1, size-preserving) + ReLU + 2x2
    max-pool, so a (N, 3, 32, 32) input becomes (N, 64, 8, 8). That is
    flattened into a 512-unit hidden layer, followed by a linear layer that
    emits one raw logit per class.

    Args:
        num_classes: number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels * 8 * 8 spatial: two 2x pools shrink 32x32 down to 8x8,
        # which is why inputs must be exactly 32x32.
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for x of shape (N, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension. Unlike
        # view(-1, 64*8*8), an input that is not 32x32 now fails loudly in fc1
        # instead of being silently reshaped into a batch of the wrong size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# === Model Loading ===
model = SimpleCNN()
model_path = 'simple_cnn_dclr_tuned.pth'
if not os.path.exists(model_path):
    # Without a checkpoint the network keeps PyTorch's random initialisation,
    # so predictions will be meaningless until the weights file exists.
    print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
else:
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # (torch >= 1.13) if this checkpoint could ever come from an untrusted source.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu'))
    )
    model.eval()  # switch to inference mode once weights are in place
    print(f"Model loaded successfully from {model_path}")
# === CIFAR-10 Class Labels ===
# Position i in this list names the model's output logit i.
class_labels = "plane car bird cat deer dog frog horse ship truck".split()
# === Image Preprocessing ===
# Maps an RGB PIL image to a normalised float tensor of shape (3, 32, 32).
preprocess = transforms.Compose([
    # An (h, w) tuple forces exactly 32x32. An int size would only match the
    # shorter edge to 32, so non-square uploads would reach the network at a
    # size its 64*8*8 fc1 layer is not built for.
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # NOTE(review): these are the ImageNet channel statistics, not CIFAR-10's
    # (~0.491/0.482/0.447); they must match whatever the training script used.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# === Inference Function ===
def inference(input_image: Image.Image):
    """Classify one image with the module-level ``model``.

    Args:
        input_image: PIL image of any mode. It is converted to RGB because the
            first conv layer and the Normalize step expect 3 channels.

    Returns:
        dict mapping each name in ``class_labels`` to its softmax probability
        as a Python float (the mapping format ``gr.Label`` displays).

    Raises:
        ValueError: if ``input_image`` is None (nothing was uploaded).
    """
    if input_image is None:
        raise ValueError("No image provided.")
    # Defensive: the model stays in training mode if no weights were loaded
    # at startup, because eval() is only called after a successful load.
    if model.training:
        model.eval()
    # RGBA / grayscale / palette uploads would give 4 or 1 channels after
    # ToTensor(), which the 3-channel network cannot consume.
    rgb_image = input_image.convert('RGB')
    batch = preprocess(rgb_image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1)[0]
    return {label: float(p) for label, p in zip(class_labels, probabilities)}
# === Results Viewer Function ===
def _existing_path(path):
    """Return ``path`` if the file exists, else None (leaves a gr.Image output empty)."""
    return path if os.path.exists(path) else None


def show_results(input_image: Image.Image):
    """Run inference and collect the saved benchmark artifacts for display.

    Returns:
        A 4-tuple matching the Gradio outputs: (label -> probability dict,
        training-performance plot path or None, final-accuracy plot path or
        None, accuracy summary text).

    Raises:
        gr.Error: if no image was uploaded, so the UI shows a readable message
            instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    preds = inference(input_image)

    # Plots are optional artifacts written by training; show them only if present.
    perf_plot = _existing_path("training_performance.png")
    acc_plot = _existing_path("final_test_accuracy.png")

    test_acc_text = "Final test accuracy not available."
    acc_file = "final_test_accuracy.txt"
    if os.path.exists(acc_file):
        with open(acc_file, "r", encoding="utf-8") as f:
            test_acc_value = f.read().strip()
        # Guard against an empty file rendering as "Final Test Accuracy: %".
        if test_acc_value:
            test_acc_text = f"Final Test Accuracy: {test_acc_value}%"
    return preds, perf_plot, acc_plot, test_acc_text
# === Gradio Interface Setup ===
# Example inputs offered beneath the upload widget (none bundled at present).
example_images = []

interface = gr.Interface(
    fn=show_results,
    inputs=gr.Image(type="pil", label="Upload Image"),
    # Order must match the 4-tuple returned by show_results.
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Image(type="filepath", label="Training Performance"),
        gr.Image(type="filepath", label="Final Test Accuracy Plot"),
        gr.Textbox(label="Final Test Accuracy"),
    ],
    title="CIFAR-10 Image Classification with DCLR Optimizer",
    description=(
        "Upload an image to see predictions. "
        "Training/test plots and accuracy show benchmark results on CIFAR-10."
    ),
    examples=example_images,
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()
|