zoolang / handler.py
vsimmer's picture
Upload handler.py
a20ad08 verified
import base64
import io
import os

import torch
import torchvision.transforms as T
from PIL import Image

from model import get_model
class EndpointHandler:
    """Callable inference handler that classifies a single image.

    The network is built by ``get_model()`` and its weights are read from
    ``doge_223_sd-03.bin`` inside the directory passed to the constructor.
    Calling the instance with a request payload returns the index of the
    highest-scoring class.
    """

    def __init__(self, path=""):
        """Build the model, load its weights and set up preprocessing.

        Args:
            path: Directory that contains ``doge_223_sd-03.bin``.
        """
        self.model = get_model()
        weights_path = os.path.join(path, "doge_223_sd-03.bin")
        # map_location="cpu" lets the checkpoint load on hosts without a GPU.
        # NOTE(review): torch.load unpickles the file; only ship trusted
        # checkpoints alongside this handler.
        self.model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        # Inference only: switch layers such as dropout/batch-norm to eval mode.
        self.model.eval()
        # Preprocessing applied to every image before it reaches the model.
        # The mean/std values are the ones commonly used for ImageNet-trained
        # models -- presumably what this checkpoint was trained with.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_image(payload):
        """Turn one request payload into a PIL image (not yet RGB-converted).

        Accepted forms:
            * bytes-like object holding an encoded image file;
            * ``str`` naming an existing file, or otherwise a base64-encoded
              image file;
            * ``os.PathLike`` or a binary file-like object (has ``read``);
            * an already-decoded image object (anything exposing ``convert``).

        Raises:
            TypeError: If ``payload`` is none of the supported forms.
        """
        if isinstance(payload, (bytes, bytearray, memoryview)):
            # Image.open needs a path or file object, so wrap the raw bytes.
            return Image.open(io.BytesIO(payload))
        if isinstance(payload, str):
            if not os.path.exists(payload):
                # Not a file on disk: JSON requests usually carry the image
                # as a base64 string, so try that before giving up.
                try:
                    raw = base64.b64decode(payload, validate=True)
                except ValueError:  # binascii.Error is a ValueError subclass
                    raw = b""
                if raw:
                    return Image.open(io.BytesIO(raw))
            # Existing path, or not base64 either: let Image.open handle it
            # (and raise its usual error for a missing file).
            return Image.open(payload)
        if isinstance(payload, os.PathLike) or hasattr(payload, "read"):
            return Image.open(payload)
        if hasattr(payload, "convert"):
            # Already a decoded image; nothing to open.
            return payload
        raise TypeError(
            f"Unsupported image input type: {type(payload).__name__}"
        )

    def __call__(self, data):
        """Classify one image and return the predicted class index.

        Args:
            data: Request payload. When it is a dict, the image is taken
                from its ``"inputs"`` key, which is removed from the dict;
                a dict without that key, or any non-dict payload, is itself
                treated as the image input. See ``_load_image`` for the
                accepted image forms.

        Returns:
            ``{"label": index}`` where ``index`` is the position of the
            largest model output along dimension 1.

        Raises:
            TypeError: If no supported image input can be found in ``data``.
        """
        inputs = data.pop("inputs", data) if isinstance(data, dict) else data
        image = self._load_image(inputs).convert("RGB")
        # The model is fed a batch, so add a leading batch dimension of 1.
        tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            outputs = self.model(tensor)
        prediction = torch.argmax(outputs, dim=1).item()
        return {"label": prediction}