import io
import os

import torch
import torchvision.transforms as T
from PIL import Image

from model import get_model


class EndpointHandler:
    """Classify a single image with a model loaded from a local state dict.

    Presumably written for the Hugging Face Inference Endpoints custom-handler
    convention (``EndpointHandler(path)`` + ``handler(data)``) -- confirm.
    """

    def __init__(self, path: str = "") -> None:
        """Build the model, load its weights, and set up preprocessing.

        Args:
            path: Directory containing ``doge_223_sd-03.bin``. The default
                ``""`` resolves the file relative to the current working
                directory.
        """
        self.model = get_model()
        weights_path = os.path.join(path, "doge_223_sd-03.bin")
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. If the checkpoint source is not fully trusted,
        # pass weights_only=True (requires torch >= 1.13).
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Switch layers such as dropout / batch-norm (if present) to their
        # inference behaviour.
        self.model.eval()

        # Fixed 224x224 resize (aspect ratio is not preserved), then
        # tensor conversion and per-channel normalisation. The mean/std
        # values are the common ImageNet statistics; presumably they match
        # the model's training preprocessing -- confirm.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_pil(inputs):
        """Coerce a request payload into an RGB ``PIL.Image.Image``.

        Accepts an already-decoded ``PIL.Image.Image``, raw encoded image
        bytes, or anything ``PIL.Image.open`` accepts (a filesystem path or
        a binary file-like object).
        """
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            # Image.open treats a bytes object as a *filename*; wrap raw
            # image data in an in-memory buffer instead.
            inputs = io.BytesIO(inputs)
        # convert() returns a new, fully loaded image, so it stays valid
        # after the context manager releases any file Pillow opened itself.
        with Image.open(inputs) as img:
            return img.convert("RGB")

    def __call__(self, data) -> dict:
        """Run the model on one image and return the arg-max class index.

        Args:
            data: Either a dict carrying the image under ``"inputs"``, or the
                image payload itself (PIL image, encoded bytes, path, or
                binary file-like object). The dict is not modified.

        Returns:
            ``{"label": int}`` -- the index of the largest model output
            along dim 1.
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        image = self._to_pil(inputs)
        # Prepend a batch dimension of size 1.
        tensor = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(tensor)
            # Assumes the model returns a (1, num_classes) score tensor --
            # TODO confirm against get_model().
            prediction = torch.argmax(outputs, dim=1).item()

        return {"label": prediction}