import argparse
import glob
import os

import cv2
import numpy
import torch
from PIL import Image

from Model import TRCaptionNet, clip_transform

def demo(opt):
    # CLIP-style image preprocessing for 224x224 inputs
    preprocess = clip_transform(224)

    # Build the captioning model from its configuration:
    # CLIP ViT-L/14 image encoder, Turkish BERT (BERTurk) for the caption text,
    # and a 16-head projection between them
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16
    })

    # Load the checkpoint onto the requested device and switch to inference mode
    device = torch.device(opt.device)
    model.load_state_dict(torch.load(opt.model_ckpt, map_location=device)["model"], strict=True)
    model = model.to(device)
    model.eval()
    # Caption every .jpg in the input directory, in name order
    image_paths = glob.glob(os.path.join(opt.input_dir, '*.jpg'))
    for image_path in sorted(image_paths):
        img_name = os.path.basename(image_path)  # portable, unlike split('/')
        img0 = Image.open(image_path).convert('RGB')  # normalize grayscale/RGBA inputs

        batch = preprocess(img0).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only, no gradient bookkeeping
            caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
        print(f"{img_name} :", caption)

        # Display the image with OpenCV: RGB -> BGR, resized to a fixed 800 px height
        orj_img = numpy.array(img0)[:, :, ::-1]
        h, w, _ = orj_img.shape
        new_h = 800
        new_w = int(new_h * (w / h))
        orj_img = cv2.resize(orj_img, (new_w, new_h))
        cv2.imshow("image", orj_img)
        cv2.waitKey(0)  # wait for a key press before moving to the next image
    return

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Turkish-Image-Captioning!')
    parser.add_argument('--model-ckpt', type=str, default='./checkpoints/TRCaptionNet_L14_berturk.pth')
    parser.add_argument('--input-dir', type=str, default='./images/')
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()
    demo(args)
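
# Example invocation (the script's filename is not shown on this page, so
# demo.py below is an assumption). All flags are optional since they match
# the argparse defaults above; pass --device cpu on machines without CUDA:
#
#     python demo.py --model-ckpt ./checkpoints/TRCaptionNet_L14_berturk.pth \
#         --input-dir ./images/ --device cuda:0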