File size: 2,385 Bytes
aaf5020
2cd690b
 
5738620
 
2cd690b
9ec767e
3471015
2cd690b
 
 
 
3471015
2cd690b
 
 
3471015
2cd690b
 
 
 
3471015
2cd690b
 
 
 
5738620
2cd690b
 
 
 
 
 
 
5738620
2cd690b
3471015
2cd690b
 
 
 
 
c0eb6b0
2cd690b
5738620
2cd690b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!pip install -q torch torchvision transformers huggingface_hub pillow gradio

import torch, json, gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import ViTModel
from huggingface_hub import hf_hub_download

# --------------------- Settings ---------------------
# Length every decoder input is padded to during caption generation
# (one <SOS> slot plus up to SEQ_LEN - 1 generated tokens).
SEQ_LEN = 32
REPO = "hackergeek/RADIOCAP13"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fetch the decoder weights (w) and the vocabulary file (v) from the Hub repo.
w, v = (hf_hub_download(REPO, name) for name in ("pytorch_model.bin", "vocab.json"))

# --------------------- Tokenizer ---------------------
# vocab.json maps token string -> integer id. JSON is UTF-8 by specification,
# so pin the encoding instead of relying on the platform's locale default
# (e.g. cp1252 on Windows), which can garble or reject non-ASCII tokens.
with open(v, encoding="utf-8") as f:
    vocab = json.load(f)
# Reverse lookup (id -> token) used to turn generated ids back into words.
idx2word = {idx: word for word, idx in vocab.items()}
# Special-token ids; fall back to 0/1/2 when the vocab has no such entries.
pad = vocab.get("<PAD>", 0)
sos = vocab.get("<SOS>", 1)
eos = vocab.get("<EOS>", 2)

# --------------------- Models ---------------------
# Pretrained ViT image encoder, moved to the selected device and put in eval mode.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
vit = vit.to(device)
vit.eval()
class D(torch.nn.Module):
    """Caption decoder head whose parameter layout matches the RADIOCAP13 checkpoint.

    Every output position is computed independently from its own token
    embedding, its (clamped) position embedding and the image feature
    vector; there is no attention or recurrence between positions.

    Args:
        vocab_size: number of token ids (rows of ``token_emb``, outputs of
            ``final_layer``). The default matches the released checkpoint.
        dim: embedding / feature width. The default matches the checkpoint.
        num_positions: rows of the position-embedding table. ``None`` means
            ``SEQ_LEN - 1``, the size stored in the checkpoint.
    """

    def __init__(self, vocab_size: int = 75460, dim: int = 768, num_positions=None):
        super().__init__()
        if num_positions is None:
            # Resolved lazily so the module-level SEQ_LEN is read only when needed.
            num_positions = SEQ_LEN - 1
        # Attribute names must stay as-is: they are the state_dict keys of the checkpoint.
        self.token_emb = torch.nn.Embedding(vocab_size, dim)
        self.pos_emb = torch.nn.Embedding(num_positions, dim)
        self.final_layer = torch.nn.Linear(dim, vocab_size)

    def forward(self, f: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for every position of ``s``.

        Args:
            f: image feature tensor, unsqueezed at dim 1 and added to every
                position. Presumably (batch, dim); TODO confirm with callers.
            s: token ids, (batch, seq_len); ``size(1)`` is taken as the
                sequence length.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        x = self.token_emb(s)
        # The position table can be shorter than the input (the checkpoint has
        # SEQ_LEN - 1 rows while caption() feeds SEQ_LEN tokens), so positions
        # past the last row reuse that last row instead of indexing out of range.
        positions = torch.arange(x.size(1), device=x.device).clamp(
            max=self.pos_emb.num_embeddings - 1
        )
        x = x + self.pos_emb(positions)
        x = x + f.unsqueeze(1)
        return self.final_layer(x)
# Build the decoder and load the checkpoint file downloaded into `w`.
# The file comes from a remote Hub repo, so restrict unpickling to tensors and
# plain containers: weights_only=True stops a tampered checkpoint from running
# arbitrary code (it is already the default from PyTorch 2.6 onward).
decoder = D().to(device)
decoder.load_state_dict(torch.load(w, map_location=device, weights_only=True))
decoder.eval()

# --------------------- Image preprocessing ---------------------
# Resize + ToTensor are stateless, so build the pipeline once instead of per call.
_TRANSFORM = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])


def preprocess(img):
    """Convert a PIL image into a (1, 3, 224, 224) float tensor on ``device``.

    Images in any mode other than RGB are converted to RGB first. ToTensor
    scales pixel values to [0, 1]; no mean/std normalization is applied.

    NOTE(review): the image processor published for
    google/vit-base-patch16-224-in21k normalizes with mean/std 0.5. Add that
    here only if the decoder was trained on normalized inputs; confirm against
    the training code before changing it.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    return _TRANSFORM(img).unsqueeze(0).to(device)

# --------------------- Caption generation ---------------------
@torch.no_grad()
def caption(img):
    """Greedily decode a caption for ``img``.

    Args:
        img: PIL image from the Gradio input, or None when no image was submitted.

    Returns:
        The generated tokens joined by spaces, with <PAD>/<SOS>/<EOS> ids removed.
        Ids missing from the vocabulary render as "<UNK>". Returns "" when
        ``img`` is None.
    """
    if img is None:
        # gr.Image(type="pil") hands over None when the user submits without an
        # image; previously this crashed inside preprocess() on `img.mode`.
        return ""
    feats = vit(pixel_values=preprocess(img)).pooler_output
    seq = [sos]
    # Decoder input is always padded to SEQ_LEN, leaving room for at most
    # SEQ_LEN - 1 generated tokens after <SOS>.
    for _ in range(SEQ_LEN - 1):
        inp = torch.tensor(seq + [pad] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
        logits = decoder(feats, inp)
        # Logits at the last filled position score the next token.
        nxt = torch.argmax(logits[0, len(seq) - 1]).item()
        seq.append(nxt)
        if nxt == eos:
            break
    special = {pad, sos, eos}
    return " ".join(idx2word.get(i, "<UNK>") for i in seq[1:] if i not in special)

# --------------------- Gradio interface ---------------------
# Single-image-in, text-out web UI around caption(); launch() starts the server.
demo = gr.Interface(
    fn=caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="RADIOCAP13 Captioning",
)
demo.launch()