File size: 18,834 Bytes
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
1fcb13f
 
 
 
2945c35
1fcb13f
9c6a17f
2945c35
 
 
1fcb13f
 
 
 
2945c35
 
 
 
 
 
1fcb13f
 
 
 
2945c35
 
 
 
 
ecf7f6e
 
 
2945c35
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
1fcb13f
9c6a17f
2945c35
 
 
ecf7f6e
 
2945c35
 
 
 
 
 
 
 
 
 
 
ecf7f6e
 
1fcb13f
 
 
ecf7f6e
1fcb13f
 
ecf7f6e
 
1fcb13f
ecf7f6e
 
 
1fcb13f
 
 
 
ecf7f6e
2945c35
 
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c6a17f
2945c35
ecf7f6e
 
 
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fcb13f
2945c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14a25f4
1fcb13f
 
 
 
 
 
 
 
 
 
 
 
 
3ee5ad8
737e40d
2945c35
737e40d
14a25f4
737e40d
f7f51a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2945c35
1fcb13f
2945c35
 
 
 
 
1fcb13f
2945c35
 
 
 
 
 
1fcb13f
202241e
 
 
 
c4a9342
 
202241e
 
 
 
 
 
 
1fcb13f
f7f51a2
5275b4e
f7f51a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737e40d
f7f51a2
1fcb13f
2945c35
 
 
1fcb13f
2945c35
1fcb13f
 
 
 
 
 
c4a9342
2945c35
 
1fcb13f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
import gradio as gr
from transformers import AutoModel, AutoTokenizer
from medcrab import MedCrabTranslator
import torch
import os
import sys
import tempfile
import shutil
from PIL import Image, ImageOps
import fitz
import re
import time
from io import StringIO, BytesIO
import spaces

# ==================== CONFIG ====================
# Hugging Face model id of the OCR backbone.
OCR_MODEL_NAME = "deepseek-ai/DeepSeek-OCR"
# OCR presets offered in the "Chế độ OCR" dropdown; each entry's values are
# passed to ocr_model.infer(...) as base_size / image_size / crop_mode.
MODEL_CONFIGS = dict(
    Crab=dict(base_size=1024, image_size=640, crop_mode=True),
    Base=dict(base_size=1024, image_size=1024, crop_mode=False),
)

# ==================== LOAD MODELS ====================
print("🔄 Loading OCR model...")
ocr_tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL_NAME, trust_remote_code=True)


def _load_ocr_model(attn_implementation):
    """Load the OCR weights in bfloat16 with the given attention backend."""
    return AutoModel.from_pretrained(
        OCR_MODEL_NAME,
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_safetensors=True,
    )


# Prefer Flash Attention 2; fall back to eager attention when loading with it
# raises ImportError/ValueError (e.g. the package is not installed).
try:
    ocr_model = _load_ocr_model('flash_attention_2')
    print("✅ Using Flash Attention 2")
except (ImportError, ValueError):
    print("⚠️ Flash Attention 2 not available, using eager attention")
    ocr_model = _load_ocr_model('eager')
ocr_model = ocr_model.eval()

print("🦀 Loading MedCrab translator...")
device = "cuda" if torch.cuda.is_available() else "cpu"
translator = MedCrabTranslator(device=device)
print(f"✅ MedCrab translator loaded on {device}")

# ==================== TEXT CLEANING ====================
# LaTeX commands rewritten to Unicode inside math spans.  They are matched as
# whole command names (see _LATEX_SYMBOL_RE), so longer commands sharing a
# prefix -- e.g. \cdots, \leqslant, \pmod, \left -- are never half-replaced.
_LATEX_SYMBOLS = {
    'times': '×', 'pm': '±', 'div': '÷', 'cdot': '·',
    'cdots': '⋯', 'ldots': '…',
    'approx': '≈', 'leq': '≤', 'le': '≤', 'geq': '≥', 'ge': '≥',
    'neq': '≠', 'ne': '≠',
    'rightarrow': '→', 'leftarrow': '←',
    'Rightarrow': '⇒', 'Leftarrow': '⇐',
}
# Longest names first; the negative lookahead rejects a match that is only
# the prefix of a longer command name (\cdots must not become "·s").
_LATEX_SYMBOL_RE = re.compile(
    r'\\(' + '|'.join(sorted(_LATEX_SYMBOLS, key=len, reverse=True)) + r')(?![A-Za-z])'
)


def _convert_math(content):
    """Rewrite the body of one LaTeX math span to plain text with HTML sub/sup tags."""
    # Symbols go first: unwrapping \mathrm{} can glue a command to the next
    # letters (\times\mathrm{L} -> \timesL), which whole-name matching would skip.
    content = _LATEX_SYMBOL_RE.sub(lambda m: _LATEX_SYMBOLS[m.group(1)], content)
    content = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', content)
    content = re.sub(r'\^\{([^}]+)\}', r'<sup>\1</sup>', content)
    content = re.sub(r'\^([A-Za-z0-9+\-]+)', r'<sup>\1</sup>', content)
    content = re.sub(r'_\{([^}]+)\}', r'<sub>\1</sub>', content)
    content = re.sub(r'_([A-Za-z0-9+\-]+)', r'<sub>\1</sub>', content)
    return content


def clean_mathrm(text):
    """Turn LaTeX fragments in OCR output into readable text.

    Inline math ``\\( ... \\)`` is replaced by its converted body; display
    math ``\\[ ... \\]`` is replaced by its converted body wrapped in literal
    square brackets.  Conversion strips ``\\mathrm{}``, turns ``^``/``_``
    into ``<sup>``/``<sub>`` HTML tags and maps common operators/arrows to
    Unicode.  Outside math, ``\\mathrm{}`` is unwrapped and ``\\%`` becomes
    ``%``.  Finally runs of spaces/tabs collapse to one space and every
    line is stripped.

    Args:
        text: OCR text; may be empty or None.

    Returns:
        The cleaned string ("" for empty input).
    """
    if not text:
        return ""
    text = re.sub(r'\\\((.+?)\\\)', lambda m: _convert_math(m.group(1)), text, flags=re.DOTALL)
    text = re.sub(
        r'\\\[(.+?)\\\]',
        lambda m: '[' + _convert_math(m.group(1)) + ']',
        text,
        flags=re.DOTALL,
    )
    text = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', text)
    text = text.replace(r'\%', '%')
    cleaned_lines = [re.sub(r'[ \t]+', ' ', line).strip() for line in text.split('\n')]
    return '\n'.join(cleaned_lines).strip()

# One grounding annotation: <|ref|>label<|/ref|><|det|>...<|/det|>.
# Group 1 is the label; the det payload is presumably box coordinates.
_GROUNDING_TAG_RE = re.compile(r'<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>', re.DOTALL)


def clean_output(text, include_images=False, remove_labels=False):
    """Strip grounding annotations from raw OCR output, then clean LaTeX.

    Image annotations become a numbered ``**[Figure N]**`` placeholder when
    *include_images* is true and are dropped otherwise.  Other annotations
    are dropped when *remove_labels* is true, else replaced by their label.
    All annotations are rewritten in one left-to-right pass, so figure
    numbers follow document order.

    Args:
        text: Raw OCR output; may be empty or None.
        include_images: Keep figure placeholders for image regions.
        remove_labels: Drop non-image annotations entirely.

    Returns:
        The cleaned text ("" for empty input).
    """
    if not text:
        return ""
    figure_count = 0

    def _replace(match):
        nonlocal figure_count
        if '<|ref|>image<|/ref|>' in match.group(0):
            if not include_images:
                return ''
            figure_count += 1
            return f'\n\n**[Figure {figure_count}]**\n\n'
        return '' if remove_labels else match.group(1)

    return clean_mathrm(_GROUNDING_TAG_RE.sub(_replace, text)).strip()

# ==================== OCR HELPERS ====================
def _flatten_to_rgb(image):
    """Return *image* in RGB mode, compositing any alpha channel onto white.

    A plain ``convert('RGB')`` discards alpha and exposes whatever colour the
    transparent pixels store (which may be black), hiding dark text.
    """
    if image.mode == 'RGB':
        return image
    has_alpha = image.mode in ('RGBA', 'LA', 'PA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if not has_alpha:
        return image.convert('RGB')
    rgba = image.convert('RGBA')
    background = Image.new('RGB', rgba.size, (255, 255, 255))
    background.paste(rgba, mask=rgba.getchannel('A'))
    return background


@spaces.GPU
def ocr_process_image(image, mode="Crab"):
    """OCR a PIL image with DeepSeek-OCR and return cleaned markdown.

    The model's ``infer`` writes its result to stdout, so stdout is captured
    for the call and progress/debug lines are filtered out.

    Args:
        image: PIL image, or None.
        mode: Key of ``MODEL_CONFIGS`` selecting the size/crop preset.

    Returns:
        ``"Error: Upload image"`` for a missing image, ``"No text detected"``
        when nothing remains after filtering, otherwise the result of
        ``clean_output(..., include_images=True, remove_labels=True)``.

    Raises:
        KeyError: if *mode* is not a key of ``MODEL_CONFIGS``.
    """
    from contextlib import redirect_stdout, suppress

    if image is None:
        return "Error: Upload image"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ocr_model.to(device)
    # Rotate per EXIF first, then normalise to RGB (JPEG cannot store alpha).
    image = _flatten_to_rgb(ImageOps.exif_transpose(image))
    config = MODEL_CONFIGS[mode]
    prompt = "<image>\n<|grounding|>Convert the document to markdown."
    fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    out_dir = tempfile.mkdtemp()
    captured = StringIO()
    try:
        # Saving inside the try ensures the temp file is removed even if it fails.
        image.save(tmp_path, 'JPEG', quality=95)
        with redirect_stdout(captured):
            ocr_model.infer(
                tokenizer=ocr_tokenizer,
                prompt=prompt,
                image_file=tmp_path,
                output_path=out_dir,
                base_size=config["base_size"],
                image_size=config["image_size"],
                crop_mode=config["crop_mode"],
            )
    finally:
        with suppress(OSError):
            os.unlink(tmp_path)
        shutil.rmtree(out_dir, ignore_errors=True)
    # Lines containing these markers are model diagnostics, not document text.
    noise_markers = ('image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size')
    result = '\n'.join(
        line for line in captured.getvalue().split('\n')
        if not any(marker in line for marker in noise_markers)
    ).strip()
    if not result:
        return "No text detected"
    return clean_output(result, True, True)

def ocr_process_pdf(path, mode, page_num):
    """Render one PDF page at 300 DPI and OCR it.

    Args:
        path: Path of the PDF file.
        mode: Key of ``MODEL_CONFIGS``.
        page_num: 1-based page number.

    Returns:
        The OCR result from ``ocr_process_image``, or an
        ``"Invalid page number. ..."`` message when *page_num* is out of range.
    """
    # The context manager closes the document even if loading/rendering raises.
    with fitz.open(path) as doc:
        total_pages = len(doc)
        if not 1 <= page_num <= total_pages:
            return f"Invalid page number. PDF has {total_pages} pages."
        page = doc.load_page(page_num - 1)
        # PDF user space is 72 units per inch, so 300/72 zoom gives 300 DPI.
        pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
        png_bytes = pix.tobytes("png")
    img = Image.open(BytesIO(png_bytes))
    return ocr_process_image(img, mode)

def ocr_process_file(path, mode, page_num):
    """OCR an uploaded file, dispatching on its extension.

    Args:
        path: Upload path; a case-insensitive ``.pdf`` suffix selects PDF
            handling, anything else is opened as an image.
        mode: Key of ``MODEL_CONFIGS``.
        page_num: 1-based page number (PDFs only).

    Returns:
        The OCR markdown or a status string; ``"Error: Upload file"`` when
        *path* is empty.
    """
    if not path:
        return "Error: Upload file"
    if path.lower().endswith('.pdf'):
        return ocr_process_pdf(path, mode, page_num)
    # Image.open keeps the file open for lazy decoding; the context manager
    # releases it once OCR (which reads all pixels inside the block) is done.
    with Image.open(path) as img:
        return ocr_process_image(img, mode)

# ==================== TRANSLATION HELPERS ====================
def split_by_sentences(text: str, max_words: int = 100):
    """Split *text* into translation-sized chunks while keeping line structure.

    Each non-blank line becomes a single chunk when it has at most
    *max_words* whitespace-separated words.  Longer lines are split at
    sentence ends (whitespace after ``.``, ``!`` or ``?``), and sentences are
    packed greedily into chunks of at most *max_words* words.  A sentence
    that is itself longer than *max_words* is split at commas, and its parts
    are re-joined with ``", "``.  A comma-free fragment longer than
    *max_words* is kept whole, so a chunk can still exceed the limit.

    Args:
        text: Input text; lines are separated by newline characters.
        max_words: Soft upper bound on words per chunk.

    Returns:
        A list of ``(chunk_text, newline_count)`` tuples in input order.
        ``newline_count`` is the number of line breaks that follow the chunk:
        - 0 for a chunk that is continued by the next chunk on the same line,
          and 0 for the chunk ending the last line of *text*;
        - 1 for a chunk ending any other line;
        - plus one for every blank line that follows.
        A trailing newline yields an empty final line, which counts as a
        blank line.  Blank lines before the first chunk are dropped.
    """
    def count_words(t):
        return len(t.strip().split())
    chunks = []
    lines = text.split('\n')
    i = 0
    while i < len(lines):
        line = lines[i]
        empty_count = 0
        if not line.strip():
            # Consume the whole run of blank lines and attach it to the
            # previous chunk's newline count (leading blanks are discarded).
            while i < len(lines) and not lines[i].strip():
                empty_count += 1
                i += 1
            if chunks:
                prev_text, prev_newlines = chunks[-1]
                chunks[-1] = (prev_text, prev_newlines + empty_count)
            continue
        line = line.strip()
        is_last_line = (i == len(lines) - 1)
        if count_words(line) <= max_words:
            # Short line: emit as-is.
            chunks.append((line, 0 if is_last_line else 1))
            i += 1
            continue
        # Long line: greedy packing of sentences into current_chunk.
        sentences = re.split(r'(?<=[.!?])\s+', line)
        current_chunk = ""
        current_words = 0
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            sentence_words = count_words(sentence)
            if sentence_words > max_words:
                # Oversized sentence: flush what we have, then pack its
                # comma-separated parts; the last group stays open so that
                # following sentences can still be appended to it.
                if current_chunk:
                    chunks.append((current_chunk.strip(), 0))
                    current_chunk = ""
                    current_words = 0
                sub_parts = re.split(r',\s*', sentence)
                temp_chunk = ""
                temp_words = 0
                for part in sub_parts:
                    part_words = count_words(part)
                    if temp_words + part_words > max_words and temp_chunk:
                        chunks.append((temp_chunk.strip(), 0))
                        temp_chunk = part
                        temp_words = part_words
                    else:
                        if temp_chunk:
                            temp_chunk += ", " + part
                        else:
                            temp_chunk = part
                        temp_words += part_words
                if temp_chunk.strip():
                    current_chunk = temp_chunk.strip()
                    current_words = temp_words
            elif current_words + sentence_words <= max_words:
                # Sentence still fits in the open chunk.
                if current_chunk:
                    current_chunk += " " + sentence
                else:
                    current_chunk = sentence
                current_words += sentence_words
            else:
                # Would overflow: close the open chunk and start a new one.
                chunks.append((current_chunk.strip(), 0))
                current_chunk = sentence
                current_words = sentence_words
        if current_chunk.strip():
            # The final chunk of the line carries the line break.
            chunks.append((current_chunk.strip(), 0 if is_last_line else 1))
        i += 1
    return chunks

@spaces.GPU
def translate_chunk(chunk_text):
    """Translate one text chunk with the MedCrab translator.

    Decorated with ``spaces.GPU``; the translator's model (when it exposes
    one with a ``.to`` method) is moved to CUDA, if available, inside the
    call.  NOTE(review): presumably this is because a GPU is only attached
    for the duration of the decorated call — confirm against ZeroGPU docs.

    Returns:
        The translation with surrounding whitespace stripped.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    model = getattr(translator, 'model', None)
    if hasattr(model, 'to'):
        model.to(target_device)
    translation = translator.translate(chunk_text, max_new_tokens=2048)
    return translation.strip()

_TRANSLATION_DIV = '<div style="padding:20px; line-height:1.8; font-size:15px; white-space:pre-wrap; font-family:Arial,sans-serif;">{}</div>'
_TRANSLATION_ERROR_DIV = '<div style="padding:20px; color:#ff6b6b;">{}</div>'
# One "typing" step: a whole HTML tag, a whole character entity, or a single
# character.  Revealing tags/entities atomically avoids spending a frame and
# a sleep on invisible markup and never emits a half-open tag.
_STREAM_UNIT_RE = re.compile(r'<[^<>]*>|&#?\w+;|.', re.DOTALL)


def streaming_translate(text: str):
    """Translate *text* chunk by chunk, yielding progressively longer HTML.

    The text is chunked with ``split_by_sentences``, and each chunk is
    translated with ``translate_chunk``.  The translation is revealed one
    character (or whole tag/entity) at a time, with a 15 ms pause per step.
    Chunks on the same line are joined with a space.  Line breaks after a
    chunk are appended, capped at two.

    Yields:
        HTML ``<div>`` strings.  The first yield is a warning for
        empty/blank input.  When a translation fails, the last yield is the
        text translated so far followed by an error message.
    """
    import html

    if not text or not text.strip():
        yield _TRANSLATION_ERROR_DIV.format('⚠️ Vui lòng nhập văn bản tiếng Anh để dịch.')
        return
    chunks = split_by_sentences(text, max_words=100)
    accumulated = ""
    for i, (chunk_text, newline_count) in enumerate(chunks):
        try:
            translated = translate_chunk(chunk_text)
        except Exception as e:
            error_html = _TRANSLATION_ERROR_DIV.format(
                f'❌ Lỗi dịch chunk {i+1}: {html.escape(str(e))}'
            )
            # Keep what was already translated on screen above the error.
            yield (_TRANSLATION_DIV.format(accumulated) if accumulated else '') + error_html
            return
        separator = " " if accumulated and not accumulated.endswith('\n') else ""
        shown = len(accumulated) + len(separator)
        accumulated += separator + translated
        for unit in _STREAM_UNIT_RE.findall(translated):
            shown += len(unit)
            yield _TRANSLATION_DIV.format(accumulated[:shown])
            time.sleep(0.015)
        if newline_count > 0:
            accumulated += "\n" * min(newline_count, 2)
            yield _TRANSLATION_DIV.format(accumulated)

# ==================== UI HELPERS ====================
def load_image(file_path, page_num_str="1"):
    """Load a preview image for the uploaded file.

    For PDFs the requested page (1-based, clamped to the document's range;
    an unparsable page number falls back to 1) is rendered at 300 DPI.
    Any other path is opened as an image and decoded immediately.

    Args:
        file_path: Upload path, or None/"".
        page_num_str: Page number as typed in the UI.

    Returns:
        A PIL image, or None when there is no file or loading fails (the
        error is printed).
    """
    if not file_path:
        return None
    try:
        try:
            page_num = int(page_num_str)
        except (ValueError, TypeError):
            page_num = 1
        if not file_path.lower().endswith('.pdf'):
            img = Image.open(file_path)
            # Decode now so corrupt files are caught here; for single-frame
            # images this also lets Pillow release the file handle.
            img.load()
            return img
        # The context manager closes the PDF even when rendering raises.
        with fitz.open(file_path) as doc:
            page_idx = max(0, min(page_num - 1, len(doc) - 1))
            page = doc.load_page(page_idx)
            pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
            png_bytes = pix.tobytes("png")
        return Image.open(BytesIO(png_bytes))
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

def get_pdf_page_count(file_path):
    """Return the number of pages of the PDF at *file_path*.

    Falls back to 1 for an empty path, a non-PDF path, or a PDF that cannot
    be read (the error is printed).
    """
    if file_path and file_path.lower().endswith('.pdf'):
        try:
            pdf = fitz.open(file_path)
            pages = len(pdf)
            pdf.close()
        except Exception as err:
            print(f"Error reading PDF page count: {err}")
            return 1
        return pages
    return 1

def update_page_info(file_path):
    """Build the Gradio update for the page-number textbox after a file change.

    With no file only the label is reset.  Otherwise the value is reset to
    "1", and for PDFs the label also states the page count.
    """
    generic_label = "Số trang (chỉ dùng cho PDF, mặc định: 1)"
    if not file_path:
        return gr.update(label=generic_label)
    if not file_path.lower().endswith('.pdf'):
        return gr.update(label=generic_label, value="1")
    page_count = get_pdf_page_count(file_path)
    pdf_label = f"Số trang (PDF có {page_count} trang, nhập 1-{page_count})"
    return gr.update(label=pdf_label, value="1")

# ==================== COMBINED OCR + TRANSLATION ====================
# Exact status strings the OCR helpers return in place of document text.
_OCR_STATUS_MESSAGES = frozenset({"Error: Upload image", "Error: Upload file", "No text detected"})


def _is_ocr_failure(markdown):
    """Return True when *markdown* is empty or an OCR status message, not page text.

    Matching the helpers' exact messages (rather than any text starting with
    "Error"/"Invalid") keeps real documents that begin with those words from
    being rejected, and stops "No text detected" from being translated as if
    it were page content.
    """
    if not markdown:
        return True
    return markdown in _OCR_STATUS_MESSAGES or markdown.startswith("Invalid page number.")


def ocr_and_translate_streaming(file_path, mode, page_num_str):
    """Run OCR on the uploaded file, then stream its Vietnamese translation.

    Args:
        file_path: Upload path, or None.
        mode: Key of ``MODEL_CONFIGS``.
        page_num_str: Page number from the UI; unparsable values mean page 1.

    Yields:
        HTML status strings, followed by the frames of
        ``streaming_translate``, or an error message when OCR or translation
        fails.
    """
    import html

    if not file_path:
        yield '<div style="padding:20px; color:#ff6b6b;">⚠️ Vui lòng tải file lên trước!</div>'
        return
    yield '<div style="padding:20px; color:#4CAF50;">🔍 Đang quét OCR...</div>'
    try:
        page_num = int(page_num_str)
    except (ValueError, TypeError):
        page_num = 1
    try:
        markdown = ocr_process_file(file_path, mode, page_num)
    except Exception as e:
        yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi OCR: {html.escape(str(e))}</div>'
        return
    if _is_ocr_failure(markdown):
        yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi OCR: {html.escape(markdown or "")}</div>'
        return
    yield '<div style="padding:20px; color:#2196F3;">🦀 Đang dịch...</div>'
    time.sleep(0.5)
    try:
        yield from streaming_translate(markdown)
    except Exception as e:
        yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi dịch: {html.escape(str(e))}</div>'

# ==================== GRADIO INTERFACE ====================
def load_default_example():
    """Return ``(file_path, PIL image)`` for the bundled first example.

    The example is copied into the system temp directory
    (``tempfile.gettempdir()`` rather than a hard-coded ``/tmp``).
    NOTE(review): presumably so Gradio can serve it from its temp location.
    If the copy fails, the original path is used instead.

    Returns:
        ``(None, None)`` when ``images/example1.png`` does not exist.
    """
    src = "images/example1.png"
    if not os.path.exists(src):
        # Fallback: leave both components empty.
        return None, None
    tmp_path = os.path.join(tempfile.gettempdir(), "example1.png")
    try:
        shutil.copy(src, tmp_path)
    except OSError:
        # If the copy fails, use the bundled file directly.
        tmp_path = src
    return tmp_path, Image.open(tmp_path)

# UI layout: header, then a two-column row (upload/preview/controls on the
# left, streamed HTML translation on the right), then accordions with
# examples and licence information.  Components appear in creation order.
with gr.Blocks(theme=gr.themes.Soft(), title="MedCrab Translation") as demo:
    # Header; the centring <div> is now explicitly closed.
    gr.Markdown("""
    <div style="text-align: center;">
    <h1>🦀 MedCrab Translation</h1>
    <p style="font-size: 18px;"><b>Quét PDF Y khoa → Dịch trực tiếp sang tiếng Việt (Streaming)</b></p>
    
    <p style="font-size: 15px;">
    <b>Model:</b> 
    <a href="https://huggingface.co/pnnbao-ump/MedCrab-1.5B" target="_blank">
    https://huggingface.co/pnnbao-ump/MedCrab-1.5B
    </a>
    </p>
    
    <p style="font-size: 15px;">
    <b>Dataset:</b> 
    <a href="https://huggingface.co/datasets/pnnbao-ump/MedCrab" target="_blank">
    https://huggingface.co/datasets/pnnbao-ump/MedCrab
    </a>
    </p>
    
    <p style="font-size: 15px;">
    <b>GitHub Repository:</b> 
    <a href="https://github.com/pnnbao97/MedCrab" target="_blank">
    https://github.com/pnnbao97/MedCrab
    </a>
    </p>
    
    <p style="font-size: 15px;">
    <b>Tác giả:</b> Phạm Nguyễn Ngọc Bảo
    </p>
    </div>
    """)

    with gr.Row():
        # Left column: inputs and the run button.
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Tải file lên")
            file_in = gr.File(label="PDF hoặc Hình ảnh", file_types=["image", ".pdf"], type="filepath")
            input_img = gr.Image(label="Xem trước", type="pil", height=300)
            page_input = gr.Textbox(label="Số trang (chỉ dùng cho PDF, mặc định: 1)", value="1", placeholder="Nhập số trang...")
            mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Crab", label="Chế độ OCR")
            gr.Markdown("### 🦀 Quét và Dịch")
            process_btn = gr.Button("🚀 Quét OCR + Dịch tiếng Việt", variant="primary", size="lg")
        # Right column: HTML area that receives the streamed translation.
        with gr.Column(scale=2):
            gr.Markdown("### 📄 Kết quả dịch tiếng Việt (Streaming)")
            translation_output = gr.HTML(label="", value="")

    with gr.Accordion("📚 Ví dụ mẫu", open=True):
        gr.Markdown("**Thử ngay với các ví dụ có sẵn:**")
        # Examples fill [file, mode, page]; cache_examples=False means no
        # precomputed outputs are stored.
        gr.Examples(
            examples=[
                ["images/example1.png", "Crab", "1"],
                ["images/example2.png", "Crab", "1"],
            ],
            inputs=[file_in, mode, page_input],
            outputs=[translation_output],
            fn=ocr_and_translate_streaming,
            cache_examples=False,
            label="Nhấp vào ví dụ để thử"
        )

    with gr.Accordion("⚖️ Giấy phép & Liên hệ", open=False): 
        gr.Markdown("""
        ### ⚖️ Giấy phép sử dụng
        
        MedCrab được phát hành theo giấy phép:  
        **Creative Commons Attribution–NonCommercial 4.0 International (CC BY-NC 4.0)**
        
        ### ✅ Được phép
        - Sử dụng cho mục đích cá nhân  
        - Nghiên cứu học thuật  
        - Giảng dạy, học tập, minh họa  
        - Thử nghiệm nội bộ (không phục vụ vận hành thực tế)
        
        ### ❌ Không được phép
        - Sử dụng cho mục đích thương mại dưới mọi hình thức  
        - Tích hợp vào hệ thống sản xuất (production system)  
        - Triển khai tại bệnh viện, phòng khám, cơ sở y tế  
        - Cung cấp như một dịch vụ trả phí / SaaS  
        - Bán lại, cho thuê, nhượng quyền phần mềm
        
        ### 💼 Nhu cầu sử dụng thương mại
        Nếu bạn đại diện cho:
        - Bệnh viện / phòng khám  
        - Công ty công nghệ y tế  
        - Viện nghiên cứu có hoạt động thương mại hóa  
        - Startup, doanh nghiệp triển khai sản phẩm y tế  
        
        Vui lòng liên hệ trực tiếp tác giả để trao đổi về **giấy phép thương mại**:
        
        👤 **Phạm Nguyễn Ngọc Bảo**  
        📧 Facebook: https://www.facebook.com/bao.phamnguyenngoc.5/
        
        ---
        
        ⚠️ **Lưu ý pháp lý:**  
        Phần mềm này chỉ phục vụ cho mục đích nghiên cứu và tham khảo,  
        **không thay thế cho chẩn đoán hoặc quyết định y khoa.**
        """)
    
    # Events
    # Refresh the preview when the file changes.
    file_in.change(load_image, [file_in, page_input], [input_img])
    # Relabel the page box (PDF page count) and reset it to "1".
    # NOTE(review): that reset may also fire page_input.change below.
    file_in.change(update_page_info, [file_in], [page_input])
    # Re-render the preview for the newly selected PDF page.
    page_input.change(load_image, [file_in, page_input], [input_img])
    # Main action: OCR + streamed translation into the HTML panel.
    process_btn.click(ocr_and_translate_streaming, [file_in, mode, page_input], [translation_output])

    # Load default example into both file_in (filepath) and input_img (PIL) when UI starts
    demo.load(
        load_default_example,
        inputs=None,
        outputs=[file_in, input_img]
    )

def _main():
    """Start the app: cap the Gradio request queue at 20 and launch the server."""
    print("🚀 Starting MedCrab Translation on Hugging Face Spaces...")
    demo.queue(max_size=20).launch()


if __name__ == "__main__":
    _main()