#!pip install -q torch torchvision transformers huggingface_hub pillow gradio
import json
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import ViTModel
from huggingface_hub import hf_hub_download

# --------------------- Settings ---------------------
SEQ_LEN = 32
REPO = "hackergeek/RADIOCAP13"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the decoder weights and the vocabulary from the Hub
w = hf_hub_download(REPO, "pytorch_model.bin")
v = hf_hub_download(REPO, "vocab.json")

# --------------------- Tokenizer ---------------------
with open(v) as f:
    vocab = json.load(f)
idx2word = {i: tok for tok, i in vocab.items()}
pad = vocab.get("<PAD>", 0)
sos = vocab.get("<SOS>", 1)
eos = vocab.get("<EOS>", 2)

# --------------------- Models ---------------------
# Frozen ViT encoder; its pooled output conditions the caption decoder.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device).eval()

class D(torch.nn.Module):
    """Minimal caption decoder: token + position embeddings, with the image
    feature added to every position, followed by a projection to the vocabulary."""
    def __init__(self):
        super().__init__()
        self.token_emb = torch.nn.Embedding(75460, 768)
        self.pos_emb = torch.nn.Embedding(SEQ_LEN - 1, 768)
        self.final_layer = torch.nn.Linear(768, 75460)

    def forward(self, f, s):
        # f: (batch, 768) pooled image feature, s: (batch, seq_len) token ids
        x = self.token_emb(s)
        pos = torch.arange(x.size(1), device=x.device).clamp(max=self.pos_emb.num_embeddings - 1)
        x = x + self.pos_emb(pos)
        x = x + f.unsqueeze(1)      # broadcast the image feature over all positions
        return self.final_layer(x)  # (batch, seq_len, vocab_size) logits

decoder = D().to(device)
decoder.load_state_dict(torch.load(w, map_location=device))
decoder.eval()
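# Quick shape check of the decoder (illustrative only, not part of the app;
# the dummy tensors below are assumptions, not data from the repo):
#   dummy_f = torch.zeros(1, 768, device=device)
#   dummy_s = torch.full((1, SEQ_LEN - 1), pad, dtype=torch.long, device=device)
#   decoder(dummy_f, dummy_s).shape  # -> torch.Size([1, 31, 75460])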
# --------------------- Image preprocessing ---------------------
def preprocess(img):
    # Resize to the ViT input size and add a batch dimension.
    if img.mode != "RGB":
        img = img.convert("RGB")
    t = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    return t(img).unsqueeze(0).to(device)

# --------------------- Caption generation ---------------------
@torch.no_grad()
def caption(img):
    # Encode the image once, then decode greedily one token at a time.
    f = vit(pixel_values=preprocess(img)).pooler_output
    seq = [sos]
    for _ in range(SEQ_LEN - 1):
        inp = torch.tensor(seq + [pad] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
        logits = decoder(f, inp)
        nxt = torch.argmax(logits[0, len(seq) - 1]).item()
        seq.append(nxt)
        if nxt == eos:
            break
    return " ".join(idx2word.get(i, "<UNK>") for i in seq[1:] if i not in (pad, sos, eos))
# --------------------- Gradio interface ---------------------
gr.Interface(fn=caption, inputs=gr.Image(type="pil"), outputs="text", title="RADIOCAP13 Captioning").launch()