#!pip install -q torch torchvision transformers huggingface_hub pillow gradio
import torch, json, gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import ViTModel
from huggingface_hub import hf_hub_download
# --------------------- Settings ---------------------
SEQ_LEN = 32
REPO = "hackergeek/RADIOCAP13"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Download the checkpoint weights and vocabulary from the Hub
w = hf_hub_download(REPO, "pytorch_model.bin")
v = hf_hub_download(REPO, "vocab.json")
# --------------------- Tokenizer ---------------------
with open(v) as f:
    vocab = json.load(f)
idx2word = {i: t for t, i in vocab.items()}  # inverse map: token id -> token string
pad = vocab.get("<PAD>", 0); sos = vocab.get("<SOS>", 1); eos = vocab.get("<EOS>", 2)
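# For reference, a minimal sketch of the encoding this vocab implies (whitespace
# tokenization is an assumption, mirroring the join in caption() below; 3 as the
# <UNK> id is hypothetical):
# encode = lambda text: [sos] + [vocab.get(tok, 3) for tok in text.split()] + [eos]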
# --------------------- Models ---------------------
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device).eval()

class D(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = torch.nn.Embedding(75460, 768)      # vocab size x hidden dim
        self.pos_emb = torch.nn.Embedding(SEQ_LEN - 1, 768)  # learned positions for caption tokens
        self.final_layer = torch.nn.Linear(768, 75460)       # project back to vocab logits
    def forward(self, f, s):
        x = self.token_emb(s)
        x += self.pos_emb(torch.arange(x.size(1), device=x.device).clamp(max=self.pos_emb.num_embeddings - 1))
        x += f.unsqueeze(1)  # broadcast the pooled ViT feature over every position
        return self.final_layer(x)

decoder = D().to(device)
decoder.load_state_dict(torch.load(w, map_location=device))
decoder.eval()
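# Hedged shape sanity check (the dummy inputs are hypothetical, not part of the app):
# with torch.no_grad():
#     dummy_f = torch.zeros(1, 768, device=device)                              # stand-in ViT feature
#     dummy_s = torch.full((1, SEQ_LEN), pad, dtype=torch.long, device=device)  # all-<PAD> sequence
#     assert decoder(dummy_f, dummy_s).shape == (1, SEQ_LEN, 75460)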
# --------------------- Image preprocessing ---------------------
def preprocess(img):
    if img.mode != "RGB":
        img = img.convert("RGB")
    t = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    return t(img).unsqueeze(0).to(device)  # (1, 3, 224, 224), values in [0, 1]
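# Note: no mean/std normalization is applied; this assumes the RADIOCAP13 checkpoint
# was trained on raw ToTensor outputs rather than ViTImageProcessor-normalized inputs.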
# --------------------- Caption generation ---------------------
@torch.no_grad()
def caption(img):
    f = vit(pixel_values=preprocess(img)).pooler_output  # (1, 768) image feature
    seq = [sos]
    # Greedy decoding: pad the sequence to SEQ_LEN, read the argmax at the last filled position
    for _ in range(SEQ_LEN - 1):
        inp = torch.tensor(seq + [pad] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
        logits = decoder(f, inp)
        nxt = torch.argmax(logits[0, len(seq) - 1]).item()
        seq.append(nxt)
        if nxt == eos:
            break
    return " ".join(idx2word.get(i, "<UNK>") for i in seq[1:] if i not in (pad, sos, eos))
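# Quick local test without the UI (hypothetical image path; left commented out so
# the Space still launches straight into Gradio):
# print(caption(Image.open("example.jpg")))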
# --------------------- Gradio interface ---------------------
gr.Interface(fn=caption, inputs=gr.Image(type="pil"), outputs="text", title="RADIOCAP13 Captioning").launch()
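# When running outside HF Spaces (e.g. locally or in Docker), launch(server_name="0.0.0.0")
# or launch(share=True) exposes the app beyond localhost.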