Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| ''' | |
| @File : demo_metauas.py | |
| @Time : 2025/03/26 23:49:14 | |
| @Author : Bin-Bin Gao | |
| @Email : csgaobb@gmail.com | |
| @Homepage: https://csgaobb.github.io/ | |
| @Version : 1.0 | |
| @Desc : MetaUAS Demo | |
| ''' | |
| import os | |
| import cv2 | |
| import torch | |
| import json | |
| import shutil | |
| import kornia as K | |
| import numpy as np | |
| from easydict import EasyDict | |
| from argparse import ArgumentParser | |
| from metauas import MetaUAS, set_random_seed, normalize, apply_ad_scoremap, read_image_as_tensor, safely_load_state_dict | |
if __name__ == "__main__":
    # MetaUAS demo: load a checkpoint, run one (query, prompt) image pair
    # through the model and write the resulting anomaly heat-map overlay
    # to ./anomaly_map.jpg.

    # Fix the RNG seed so repeated runs of the demo are comparable.
    random_seed = 1
    set_random_seed(random_seed)

    # Checkpoint and its matching input resolution.
    # Alternative: ckt_path = "weights/metauas-512.ckpt" with img_size = 512.
    ckt_path = 'weights/metauas-256.ckpt'
    img_size = 256

    # load model
    encoder = 'efficientnet-b4'
    decoder = 'unet'
    encoder_depth = 5
    decoder_depth = 5
    num_crossfa_layers = 3
    alignment_type = 'sa'
    fusion_policy = 'cat'

    model = MetaUAS(
        encoder,
        decoder,
        encoder_depth,
        decoder_depth,
        num_crossfa_layers,
        alignment_type,
        fusion_policy,
    )
    model = safely_load_state_dict(model, ckt_path)
    model.cuda()
    model.eval()

    # load test images
    path_root = "./images/"
    path_to_prompt = path_root + "036.png"
    path_to_query = path_root + "024.png"

    query = read_image_as_tensor(path_to_query)
    prompt = read_image_as_tensor(path_to_prompt)

    # NOTE(review): only dim 1 of the query is compared against img_size, and
    # the prompt is resized under the same condition. This presumably assumes
    # read_image_as_tensor returns a (C, H, W) tensor -- confirm.
    if query.shape[1] != img_size:
        # With return_transform=True the augmentation result is indexed with
        # [0] to pick the resized image (the transform itself is discarded).
        resize_trans = K.augmentation.Resize([img_size, img_size], return_transform=True)
        query = resize_trans(query)[0]
        prompt = resize_trans(prompt)[0]

    test_data = {
        "query_image": query.cuda(),
        "prompt_image": prompt.cuda(),
    }

    # forward
    # Inference only: the output is detached right below, so skip building the
    # autograd graph to avoid holding activations in GPU memory.
    with torch.no_grad():
        predicted_masks = model(test_data)

    # visualization
    # NOTE(review): indexing [0] then permute(1, 2, 0) expects a batched
    # (B, C, H, W) query here; scaling by 255 presumably assumes values in
    # [0, 1] -- confirm against read_image_as_tensor.
    query_img = test_data["query_image"][0] * 255
    query_img = query_img.permute(1, 2, 0)

    # Invert the prediction (1 - mask), add a trailing channel axis and repeat
    # it three times so the map has three identical channels.
    pred = (1 - predicted_masks.squeeze().detach())[:, :, None].cpu().numpy().repeat(3, 2)

    # normalize just for analysis
    scoremap_self = apply_ad_scoremap(query_img.cpu(), normalize(pred))

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite('./anomaly_map.jpg', scoremap_self):
        raise RuntimeError("Failed to write './anomaly_map.jpg'")