hackergeek committed
Commit 2cd690b · verified · 1 Parent(s): 9ec767e

Update app.py

Files changed (1)
  1. app.py +46 -125
app.py CHANGED
@@ -1,137 +1,58 @@
- import gradio as gr
- import torch
- from transformers import ViTModel
  from PIL import Image
  from torchvision import transforms
- import json
- import os
  from huggingface_hub import hf_hub_download

- # ---------------------
- # Config
- # ---------------------
- IMG_SIZE = 224
- SEQ_LEN = 32
- VOCAB_SIZE = 75460
- REPO_ID = "hackergeek/RADIOCAP13"  # your HF repo
- WEIGHTS_FILENAME = "pytorch_model.bin"
- VOCAB_FILENAME = "vocab.json"
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # ---------------------
- # Download model files (if not present)
- # ---------------------
- # Download weights
- weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILENAME)
- # Download vocab
- vocab_path = hf_hub_download(repo_id=REPO_ID, filename=VOCAB_FILENAME)
-
- # ---------------------
- # Preprocessing & Tokenizer
- # ---------------------
- transform = transforms.Compose([
-     transforms.Resize((IMG_SIZE, IMG_SIZE)),
-     transforms.ToTensor(),
- ])

- def preprocess_image(img):
-     if img is None:
-         raise ValueError("Image is None")
-     if not isinstance(img, Image.Image):
-         img = Image.fromarray(img)
-     if img.mode != "RGB":
-         img = img.convert("RGB")
-     return transform(img)

- class SimpleTokenizer:
-     def __init__(self, word2idx=None):
-         self.word2idx = word2idx or {}
-         self.idx2word = {v: k for k, v in self.word2idx.items()}

-     @classmethod
-     def load(cls, path):
-         with open(path, "r") as f:
-             word2idx = json.load(f)
-         return cls(word2idx)
-
- # ---------------------
- # Decoder
- # ---------------------
- class BiasDecoder(torch.nn.Module):
-     def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
          super().__init__()
-         self.token_emb = torch.nn.Embedding(vocab_size, feature_dim)
-         self.pos_emb = torch.nn.Embedding(SEQ_LEN-1, feature_dim)
-         self.final_layer = torch.nn.Linear(feature_dim, vocab_size)
-
-     def forward(self, img_feat, target_seq):
-         x = self.token_emb(target_seq)
-         pos = torch.arange(x.size(1), device=x.device).clamp(max=self.pos_emb.num_embeddings-1)
-         x = x + self.pos_emb(pos)
-         x = x + img_feat.unsqueeze(1)
          return self.final_layer(x)

- # ---------------------
- # Load models
- # ---------------------
- vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
- vit.eval()
-
- decoder = BiasDecoder().to(device)
- decoder.load_state_dict(torch.load(weights_path, map_location=device))
- decoder.eval()

- tokenizer = SimpleTokenizer.load(vocab_path)
- pad_idx = tokenizer.word2idx["<PAD>"]
-
- # ---------------------
- # Caption generation
- # ---------------------
  @torch.no_grad()
- def generate_caption(img, max_len=SEQ_LEN, beam_size=3):
-     img_tensor = preprocess_image(img).unsqueeze(0).to(device)
-     img_feat = vit(pixel_values=img_tensor).pooler_output
-
-     beams = [([tokenizer.word2idx["<SOS>"]], 0.0)]
-
-     for _ in range(max_len - 1):
-         candidates = []
-         for seq, score in beams:
-             inp = torch.tensor(seq + [pad_idx] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
-             logits = decoder(img_feat, inp)
-             probs = torch.nn.functional.log_softmax(logits[0, len(seq)-1], dim=-1)
-             top_p, top_i = torch.topk(probs, beam_size)
-
-             for i in range(beam_size):
-                 candidates.append((seq + [top_i[i].item()], score + top_p[i].item()))
-
-         beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
-
-         if all(s[-1] == tokenizer.word2idx["<EOS>"] for s, _ in beams):
-             break
-
-     words = [tokenizer.idx2word.get(i, "<UNK>") for i in beams[0][0][1:] if i != pad_idx]
-     return " ".join(words)
-
- # ---------------------
- # Gradio interface
- # ---------------------
- with gr.Blocks() as demo:
-     gr.Markdown("# RADIOCAP13 — Image Captioning Demo")
-     gr.Markdown(f"**Device:** {'GPU 🚀' if torch.cuda.is_available() else 'CPU 🐢'}")
-
-     img_in = gr.Image(type="pil", label="Upload an Image")
-     out = gr.Textbox(label="Generated Caption")
-     btn = gr.Button("Generate Caption")
-     status = gr.Markdown("Ready.")
-
-     def wrapped(img):
-         status.update("Processing…")
-         caption = generate_caption(img)
-         status.update("Done ✔️")
-         return caption
-
-     btn.click(wrapped, inputs=img_in, outputs=out)
-
- if __name__ == "__main__":
-     demo.launch()
+ # pip install -q torch torchvision transformers huggingface_hub pillow gradio
+
+ import torch, json, gradio as gr
  from PIL import Image
  from torchvision import transforms
+ from transformers import ViTModel
  from huggingface_hub import hf_hub_download

+ # --------------------- Config ---------------------
+ SEQ_LEN=32
+ REPO="hackergeek/RADIOCAP13"
+ device="cuda" if torch.cuda.is_available() else "cpu"

+ # Download weights and vocab
+ w=hf_hub_download(REPO,"pytorch_model.bin")
+ v=hf_hub_download(REPO,"vocab.json")

+ # --------------------- Tokenizer ---------------------
+ with open(v) as f: vocab=json.load(f)
+ idx2word={v:k for k,v in vocab.items()}
+ pad=vocab.get("<PAD>",0); sos=vocab.get("<SOS>",1); eos=vocab.get("<EOS>",2)

+ # --------------------- Models ---------------------
+ vit=ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device).eval()
+ class D(torch.nn.Module):
+     def __init__(self):
          super().__init__()
+         self.token_emb=torch.nn.Embedding(75460,768)
+         self.pos_emb=torch.nn.Embedding(SEQ_LEN-1,768)
+         self.final_layer=torch.nn.Linear(768,75460)
+     def forward(self,f,s):
+         x=self.token_emb(s)
+         x+=self.pos_emb(torch.arange(x.size(1),device=x.device).clamp(max=self.pos_emb.num_embeddings-1))
+         x+=f.unsqueeze(1)
          return self.final_layer(x)
+ decoder=D().to(device); decoder.load_state_dict(torch.load(w,map_location=device)); decoder.eval()

+ # --------------------- Image preprocessing ---------------------
+ def preprocess(img):
+     if img.mode!="RGB": img=img.convert("RGB")
+     t=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
+     return t(img).unsqueeze(0).to(device)

+ # --------------------- Caption generation ---------------------
  @torch.no_grad()
+ def caption(img):
+     f=vit(pixel_values=preprocess(img)).pooler_output
+     seq=[sos]
+     for _ in range(SEQ_LEN-1):
+         inp=torch.tensor(seq+[pad]*(SEQ_LEN-len(seq)),device=device).unsqueeze(0)
+         logits=decoder(f,inp)
+         nxt=torch.argmax(logits[0,len(seq)-1]).item()
+         seq.append(nxt)
+         if nxt==eos: break
+     return " ".join(idx2word.get(i,"<UNK>") for i in seq[1:] if i not in [pad,sos,eos])
+
+ # --------------------- Gradio interface ---------------------
+ gr.Interface(fn=caption, inputs=gr.Image(type="pil"), outputs="text", title="RADIOCAP13 Captioning").launch()
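
For a quick sanity check of the rewritten script, a snippet along these lines could be run in a Python session after the definitions above have executed and the weights and vocab have downloaded. The image filename is hypothetical; caption() is the greedy-decoding function defined in the new app.py.

    from PIL import Image

    # Hypothetical sample image; any file PIL can open and convert to RGB works.
    img = Image.open("sample_image.png")

    # caption() extracts ViT pooled features and decodes greedily up to SEQ_LEN tokens.
    print(caption(img))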