| from transformers import Owlv2TextModel, Owlv2Processor, AutoTokenizer |
| import json |
| import torch |
| from torch import nn |
| import tqdm |
|
|
# Precompute an OWLv2 text embedding for every label string in
# id_to_str.json and save them all to embeds.pt as a {id: embedding}
# state dict.

embed_dict = nn.ParameterDict()
bsz = 8  # text phrases encoded per forward pass

with open("id_to_str.json") as f:
    # Presumably maps class/label ids to underscore-separated name strings
    # (e.g. "polar_bear") — TODO confirm against the producer of this file.
    data = json.load(f)

keys = list(data.keys())
# Ceil-divide so the final partial batch is processed too; the original
# len(keys) // bsz silently dropped up to bsz - 1 trailing keys.
num_batches = (len(keys) + bsz - 1) // bsz
bar = tqdm.tqdm(range(num_batches))

proc = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
tokenizer = AutoTokenizer.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16-ensemble")
model.eval()  # inference only; make the mode explicit

for i in bar:
    batch_keys = keys[i * bsz:(i + 1) * bsz]
    # Underscores in the stored names are word separators; the text
    # encoder expects natural-language phrases.
    batch = [data[key].replace("_", " ") for key in batch_keys]

    # Surface phrases that exceed OWLv2's 16-token text context, since the
    # processor will truncate them. (The original code tokenized here and
    # discarded the decode result, which had no effect; it also indexed
    # with range(bsz), which breaks on a short final batch.)
    for ids in tokenizer(batch)["input_ids"]:
        if len(ids) > 16:
            print(f"warning, will be truncated: {tokenizer.decode(ids)}")

    inputs = proc(text=batch, return_tensors="pt")
    # no_grad: we only need activations — without it every batch keeps its
    # autograd graph alive via the stored pooler_output slices.
    with torch.no_grad():
        output = model(**inputs)
    for k, key in enumerate(batch_keys):
        # detach().clone() yields a leaf tensor safe to hold in a
        # ParameterDict and independent of the batch output buffer.
        embed_dict[key] = output.pooler_output[k, :].detach().clone()

# NOTE(review): nn.ParameterDict rejects keys containing "." — assumes the
# ids in id_to_str.json are dot-free; verify against the data.
torch.save(embed_dict.state_dict(), "embeds.pt")