# pip install -q torch torchvision transformers huggingface_hub pillow gradio
"""Gradio demo: greedy image captioning with a frozen ViT encoder and a
small custom decoder head whose weights are pulled from the HF Hub
(hackergeek/RADIOCAP13)."""

import json

import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import ViTModel
from huggingface_hub import hf_hub_download

# --------------------- Configuration ---------------------
SEQ_LEN = 32          # maximum caption length, including the start token
VOCAB_SIZE = 75460    # must match the checkpoint's embedding/output size
REPO = "hackergeek/RADIOCAP13"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download decoder weights and vocabulary from the Hub (network required).
weights_path = hf_hub_download(REPO, "pytorch_model.bin")
vocab_path = hf_hub_download(REPO, "vocab.json")

# --------------------- Tokenizer ---------------------
with open(vocab_path, encoding="utf-8") as f:
    vocab = json.load(f)
idx2word = {idx: word for word, idx in vocab.items()}

# NOTE(review): the special-token names were garbled in the original source —
# all three lookups queried the empty string "", which would make pad/sos/eos
# collide. "<pad>"/"<sos>"/"<eos>" are assumed here; confirm against the
# repo's vocab.json.
pad = vocab.get("<pad>", 0)
sos = vocab.get("<sos>", 1)
eos = vocab.get("<eos>", 2)

# --------------------- Models ---------------------
# Frozen ViT encoder; its pooled output conditions every decoder position.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device).eval()


class Decoder(torch.nn.Module):
    """Minimal caption head: token + position embeddings, with the pooled
    image feature added at every position, projected to vocabulary logits.

    Attribute names (token_emb / pos_emb / final_layer) must stay as-is:
    they are the keys in the downloaded state dict.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = torch.nn.Embedding(VOCAB_SIZE, 768)
        self.pos_emb = torch.nn.Embedding(SEQ_LEN - 1, 768)
        self.final_layer = torch.nn.Linear(VOCAB_SIZE if False else 768, VOCAB_SIZE)

    def forward(self, feat, seq):
        """feat: (B, 768) pooled image feature; seq: (B, T) token ids.

        Returns (B, T, VOCAB_SIZE) logits.
        """
        x = self.token_emb(seq)
        # The position table only has SEQ_LEN-1 rows while the padded input
        # is SEQ_LEN long; clamping reuses the last position embedding
        # rather than indexing out of range.
        positions = torch.arange(x.size(1), device=x.device)
        positions = positions.clamp(max=self.pos_emb.num_embeddings - 1)
        x = x + self.pos_emb(positions)
        x = x + feat.unsqueeze(1)  # broadcast image feature over time steps
        return self.final_layer(x)


decoder = Decoder().to(device)
decoder.load_state_dict(torch.load(weights_path, map_location=device))
decoder.eval()

# --------------------- Image preprocessing ---------------------
# Built once at module level instead of per call.
_TRANSFORM = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)


def preprocess(img):
    """Convert a PIL image to a (1, 3, 224, 224) float tensor on `device`."""
    if img.mode != "RGB":
        img = img.convert("RGB")
    return _TRANSFORM(img).unsqueeze(0).to(device)


# --------------------- Caption generation ---------------------
@torch.no_grad()
def caption(img):
    """Greedy decoding: grow the sequence one argmax token at a time.

    Stops early on the end-of-sequence token; special tokens are stripped
    from the returned space-joined caption.
    """
    feat = vit(pixel_values=preprocess(img)).pooler_output
    seq = [sos]
    for _ in range(SEQ_LEN - 1):
        # Right-pad the partial sequence to a fixed SEQ_LEN window.
        inp = torch.tensor(
            seq + [pad] * (SEQ_LEN - len(seq)), device=device
        ).unsqueeze(0)
        logits = decoder(feat, inp)
        nxt = torch.argmax(logits[0, len(seq) - 1]).item()
        seq.append(nxt)
        if nxt == eos:
            break
    specials = {pad, sos, eos}
    return " ".join(idx2word.get(i, "") for i in seq[1:] if i not in specials)


# --------------------- Gradio UI ---------------------
gr.Interface(
    fn=caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="RADIOCAP13 Captioning",
).launch()