# Code adapted from https://github.com/openai/CLIP/blob/main/
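"""Zero-shot image classification with an (optionally fine-tuned) CLIP model.

For every class in the chosen dataset, the prompt templates are filled in with
the class name, embedded with the CLIP text encoder, and averaged into one
L2-normalized weight vector; images are then classified by the cosine
similarity between their embeddings and these class vectors.

Example invocation (the script filename here is illustrative, not taken from
the repository):

    python zeroshot_image_classification.py --dataset CIFAR10 --data base --method COCO_CF
"""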
import argparse
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

from datasets_classes_templates import data_seeds

def zeroshot_classifier(classnames, templates, processor, model):
    """Build zero-shot classifier weights: one L2-normalized embedding per
    class, obtained by averaging the text embeddings of all prompt templates."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format with class
            text_inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(model.device)
            # Embed with the text encoder; the attention mask is passed along
            # so that padding tokens are ignored.
            class_embeddings = model.get_text_features(**text_inputs)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(model.device)
    return zeroshot_weights
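
# Note: `zeroshot_weights` has shape (embed_dim, num_classes); main() multiplies
# the normalized image embeddings of shape (batch, embed_dim) by it to obtain
# per-class logits.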

def classification_collate_fn(batch):
    """Keep PIL images in a plain list (the CLIP processor consumes them as-is);
    only the labels are collated into a tensor."""
    images, labels = zip(*batch)
    labels = torch.tensor(labels)
    return images, labels

def accuracy(output, target, topk=(1,)):
    """Return the number of correctly classified samples for each k in topk."""
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
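
# A quick sanity check of accuracy() (illustrative, kept as a comment so the
# script's behavior is unchanged):
#
#     logits = torch.tensor([[0.9, 0.05, 0.05],
#                            [0.2, 0.3, 0.5]])
#     target = torch.tensor([0, 1])
#     accuracy(logits, target, topk=(1, 2))  # -> [1.0, 2.0]
#
# The returned values are raw counts of correct predictions; main() divides
# them by the total sample count n to obtain percentages.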

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default=None, choices=['non_fine_tuned', 'MS_COCO', 'medium', 'base', 'all'], help='Data on which CLIP was fine-tuned')
    parser.add_argument("--dataset", type=str, default="CIFAR10", choices=["CIFAR10", "CIFAR100", "ImageNet", "Caltech101", "Caltech256", "Food101"])
    parser.add_argument("--method", type=str, default="COCO_CF", choices=['COCO_CF', 'APGD_1', 'APGD_4', 'NONE'])
    args = parser.parse_args()

    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    results_filename = f'./Results/fine_tuned_clip/zeroshot_image_classification_results_{args.dataset}_{args.data}_{args.method}_{current_time}.txt'
    os.makedirs(os.path.dirname(results_filename), exist_ok=True)  # make sure the results directory exists
    with open(results_filename, 'w') as f:
        f.write(f'Arguments: {args}\n\n')

    if args.data == 'MS_COCO':
        assert args.method == 'NONE', 'Use NONE as the method for MS_COCO data'

    imagenet_path = '/software/ais2t/pytorch_datasets/imagenet/'  # Fill in the path for ImageNet here
    if args.dataset == "CIFAR10":
        from datasets_classes_templates import CIFAR10_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import CIFAR10
        data = CIFAR10(root='./image_classification_datasets/cifar10/', train=False, download=True)
    elif args.dataset == "CIFAR100":
        from datasets_classes_templates import CIFAR100_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import CIFAR100
        data = CIFAR100(root='./image_classification_datasets/cifar100/', train=False, download=True)
    elif args.dataset == "ImageNet":
        from datasets_classes_templates import ImageNet_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import ImageNet
        data = ImageNet(root=imagenet_path, split='val')
    elif args.dataset == "Caltech101":
        torch.manual_seed(42)
        from datasets_classes_templates import Caltech101_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Caltech101
        data = Caltech101(root='./image_classification_datasets/', download=False)
        train_size = int(0.8 * len(data))  # 80% for training
        val_size = len(data) - train_size
        _, data = torch.utils.data.random_split(data, [train_size, val_size])
    elif args.dataset == "Caltech256":
        torch.manual_seed(42)
        from datasets_classes_templates import Caltech256_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Caltech256
        data = Caltech256(root='./image_classification_datasets/', download=False)
        train_size = int(0.8 * len(data))  # 80% for training
        val_size = len(data) - train_size
        _, data = torch.utils.data.random_split(data, [train_size, val_size])
    elif args.dataset == "Food101":
        from datasets_classes_templates import Food101_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Food101
        data = Food101(root='./image_classification_datasets/food101/', download=True, split='test')
    else:
        raise NotImplementedError
    print(f'Conducting zero-shot image classification on {args.dataset}')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model_base_path = './fine_tuned_clip_models'
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    top1_list = []
    for data_seed in data_seeds:
        print(f'Evaluating CLIP fine-tuned on {args.data} with seed {data_seed} for the method {args.method}')
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        # Load fine-tuned weights unless the plain pre-trained model is evaluated.
        if args.data != 'non_fine_tuned':
            if args.method != 'NONE':
                if args.data != 'all':
                    model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20_data_seed_{data_seed}.pt'))
                else:
                    model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt'))
            elif args.data == 'MS_COCO':  # args.method is 'NONE' here
                model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt'))
        model.eval()

        data_loader = DataLoader(data, batch_size=128, collate_fn=classification_collate_fn, shuffle=False)
        zeroshot_weights = zeroshot_classifier(classes_templates['classes'],
                                               classes_templates['templates'],
                                               processor,
                                               model)
        with torch.no_grad():
            top1, top5, n = 0., 0., 0.
            for i, (images, target) in enumerate(tqdm(data_loader)):
                target = target.to(device)
                images = processor(images=list(images), return_tensors="pt").to(device)
                # predict
                image_features = model.get_image_features(images['pixel_values'])
                image_features /= image_features.norm(dim=-1, keepdim=True)
                logits = 100. * image_features @ zeroshot_weights
                # measure accuracy
                acc1, acc5 = accuracy(logits, target, topk=(1, 5))
                top1 += acc1
                top5 += acc5
                n += image_features.size(0)

        top1 = (top1 / n) * 100
        top5 = (top5 / n) * 100
        with open(results_filename, 'a') as f:
            f.write(f'Seed {data_seed}: Top-1 Accuracy: {top1:.2f}, Top-5 Accuracy: {top5:.2f}\n')
        top1_list.append(top1)
        print(f"Top-1 accuracy: {top1:.2f}")
        print(f"Top-5 accuracy: {top5:.2f}")
        print('-' * 40)
        # These configurations use a single checkpoint, so one pass suffices.
        if args.method == 'NONE' or args.data in ['MS_COCO', 'all', 'non_fine_tuned']:
            break

    top1 = np.asarray(top1_list)
    print(f'Mean of the top-1 accuracy is {np.mean(top1):.2f}')
    print(f'Standard deviation of the top-1 accuracy is {np.std(top1):.2f}')
    with open(results_filename, 'a') as f:
        f.write(f'\nMean Top-1 Accuracy: {np.mean(top1):.2f}\n')
        f.write(f'Standard Deviation of Top-1 Accuracy: {np.std(top1):.2f}\n')

if __name__ == "__main__":
    main()