"""Strip a training checkpoint down to the weights used by the Gradio demo.

Loads ``model_best_refcoco_0508.pth`` on the CPU, prints the keys of its
``'model'`` state dict, keeps only the entries whose key contains
``image_model``, ``language_model`` or ``classifier``, and saves that
filtered dict to ``gradio.pth``.
"""
import torch

SRC_CHECKPOINT = 'model_best_refcoco_0508.pth'
DST_CHECKPOINT = 'gradio.pth'
# A state-dict entry is kept if its key contains any of these substrings.
KEEP_SUBSTRINGS = ('image_model', 'language_model', 'classifier')


def filter_state_dict(state_dict, substrings=KEEP_SUBSTRINGS):
    """Return a new dict holding the entries whose key contains any of *substrings*.

    The input mapping is not modified and its iteration order is preserved.
    Matching is a plain substring test, so ``'classifier'`` also matches a
    key such as ``'aux_classifier.weight'``.

    Args:
        state_dict: Mapping from parameter name to value (e.g. tensors).
        substrings: Iterable of substrings; a key matching any one is kept.

    Returns:
        A new ``dict`` with the selected entries.
    """
    substrings = tuple(substrings)  # allow one-shot iterables to be reused per key
    return {
        key: value
        for key, value in state_dict.items()
        if any(sub in key for sub in substrings)
    }


def main(src=SRC_CHECKPOINT, dst=DST_CHECKPOINT):
    """Load *src*, print its model keys, and save the filtered weights to *dst*."""
    # NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True; if this
    # checkpoint stores arbitrary pickled objects the load may fail -- confirm.
    # map_location='cpu' lets a checkpoint saved from GPU load without a GPU.
    checkpoint = torch.load(src, map_location='cpu')
    state_dict = checkpoint['model']
    print(state_dict.keys())
    torch.save(filter_state_dict(state_dict), dst)


if __name__ == '__main__':
    main()