import gradio as gr
from transformers import AutoModel, AutoTokenizer
from medcrab import MedCrabTranslator
import torch
import os
import sys
import tempfile
import shutil
from PIL import Image, ImageOps
import fitz
import re
import time
from io import StringIO, BytesIO
import spaces
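# Third-party dependencies implied by the imports above (the pip package names are
# assumptions, not pinned by this file): gradio, transformers, torch, pillow,
# PyMuPDF (imported as `fitz`), spaces (Hugging Face ZeroGPU helper), and the
# medcrab package that provides MedCrabTranslator.
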
# ==================== CONFIG ====================
OCR_MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'
MODEL_CONFIGS = {
    "Crab": {"base_size": 1024, "image_size": 640, "crop_mode": True},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
}
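# Two OCR presets are exposed in the UI: "Crab" uses a smaller tile size
# (image_size=640) with crop_mode enabled, while "Base" processes the page at
# 1024 px without cropping. The exact effect of crop_mode is defined by the
# DeepSeek-OCR remote code; these values are simply forwarded to model.infer().
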
# ==================== LOAD MODELS ====================
print("🔄 Loading OCR model...")
ocr_tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL_NAME, trust_remote_code=True)
try:
    ocr_model = AutoModel.from_pretrained(
        OCR_MODEL_NAME,
        attn_implementation='flash_attention_2',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_safetensors=True
    )
    print("✅ Using Flash Attention 2")
except (ImportError, ValueError):
    print("⚠️ Flash Attention 2 not available, using eager attention")
    ocr_model = AutoModel.from_pretrained(
        OCR_MODEL_NAME,
        attn_implementation='eager',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_safetensors=True
    )
ocr_model = ocr_model.eval()

print("🦀 Loading MedCrab translator...")
device = "cuda" if torch.cuda.is_available() else "cpu"
translator = MedCrabTranslator(device=device)
print(f"✅ MedCrab translator loaded on {device}")
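# Both models are created once at import time; the request handlers below call
# .to(device) again before each use (see ocr_process_image() and translate_chunk()),
# so inference works whether the GPU is attached lazily (e.g. ZeroGPU) or available
# from startup.
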
# ==================== TEXT CLEANING ====================
def clean_mathrm(text):
    if not text:
        return ""

    def process_math_block(match):
        math_content = match.group(1)
        math_content = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', math_content)
        math_content = re.sub(r'\^\{([^}]+)\}', r'<sup>\1</sup>', math_content)
        math_content = re.sub(r'\^([A-Za-z0-9+\-]+)', r'<sup>\1</sup>', math_content)
        math_content = re.sub(r'_\{([^}]+)\}', r'<sub>\1</sub>', math_content)
        math_content = re.sub(r'_([A-Za-z0-9+\-]+)', r'<sub>\1</sub>', math_content)
        replacements = {
            r'\times': '×', r'\pm': '±', r'\div': '÷', r'\cdot': '·',
            r'\approx': '≈', r'\leq': '≤', r'\geq': '≥', r'\neq': '≠',
            r'\rightarrow': '→', r'\leftarrow': '←',
            r'\Rightarrow': '⇒', r'\Leftarrow': '⇐',
        }
        for latex_cmd, unicode_char in replacements.items():
            math_content = math_content.replace(latex_cmd, unicode_char)
        return math_content

    text = re.sub(r'\\\((.+?)\\\)', process_math_block, text, flags=re.DOTALL)

    def process_bracket_block(m):
        class FakeMatch:
            def __init__(self, content):
                self.content = content
            def group(self, n):
                return self.content
        content = process_math_block(FakeMatch(m.group(1)))
        return '[' + content + ']'

    text = re.sub(r'\\\[(.+?)\\\]', process_bracket_block, text, flags=re.DOTALL)
    text = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', text)
    text = text.replace(r'\%', '%')
    lines = text.split('\n')
    cleaned_lines = [re.sub(r'[ \t]+', ' ', line).strip() for line in lines]
    return '\n'.join(cleaned_lines).strip()
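# Worked example of clean_mathrm(), traced from the regexes above (the input is a
# hypothetical OCR line, not from the source):
#   r"Serum \(\mathrm{Na}^{+}\) was \(140 \pm 2\) mmol/L"
#   -> "Serum Na<sup>+</sup> was 140 ± 2 mmol/L"
# \mathrm{} wrappers are dropped, ^/_ become <sup>/<sub>, common LaTeX operators are
# mapped to Unicode, and runs of spaces/tabs are collapsed line by line.
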
def clean_output(text, include_images=False, remove_labels=False):
    if not text:
        return ""
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)
    img_num = 0
    for match in matches:
        if '<|ref|>image<|/ref|>' in match[0]:
            if include_images:
                text = text.replace(match[0], f'\n\n**[Figure {img_num + 1}]**\n\n', 1)
                img_num += 1
            else:
                text = text.replace(match[0], '', 1)
        else:
            if remove_labels:
                text = text.replace(match[0], '', 1)
            else:
                text = text.replace(match[0], match[1], 1)
    return clean_mathrm(text).strip()
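# Worked example of clean_output() with include_images=True, remove_labels=True
# (the grounding-token layout is an assumption based on the regex above):
#   "<|ref|>image<|/ref|><|det|>[[10, 20, 300, 400]]<|/det|>Heart failure is ..."
#   -> "**[Figure 1]**\n\nHeart failure is ..."
# Image references become numbered figure placeholders; other ref/det spans are
# deleted (remove_labels=True) or collapsed to their label text (remove_labels=False).
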
# ==================== OCR HELPERS ====================
def ocr_process_image(image, mode="Crab"):
    if image is None:
        return "Error: Upload image"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ocr_model.to(device)
    if image.mode in ('RGBA', 'LA', 'P'):
        image = image.convert('RGB')
    image = ImageOps.exif_transpose(image)
    config = MODEL_CONFIGS[mode]
    prompt = "<image>\n<|grounding|>Convert the document to markdown."
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    image.save(tmp.name, 'JPEG', quality=95)
    tmp.close()
    out_dir = tempfile.mkdtemp()
    stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        ocr_model.infer(
            tokenizer=ocr_tokenizer,
            prompt=prompt,
            image_file=tmp.name,
            output_path=out_dir,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"]
        )
        result = '\n'.join([l for l in sys.stdout.getvalue().split('\n')
                            if not any(s in l for s in ['image:', 'other:', 'PATCHES', '====', 'BASE:', '%|', 'torch.Size'])]).strip()
    finally:
        sys.stdout = stdout
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        shutil.rmtree(out_dir, ignore_errors=True)
    if not result:
        return "No text detected"
    return clean_output(result, True, True)
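# Note: the infer() call above writes its transcription to stdout (hence the StringIO
# redirect); lines that look like progress or debug output ('PATCHES', 'torch.Size',
# tqdm '%|' bars, ...) are filtered out before the text is cleaned.
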
def ocr_process_pdf(path, mode, page_num):
    doc = fitz.open(path)
    total_pages = len(doc)
    if page_num < 1 or page_num > total_pages:
        doc.close()
        return f"Invalid page number. PDF has {total_pages} pages."
    page = doc.load_page(page_num - 1)
    pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
    img = Image.open(BytesIO(pix.tobytes("png")))
    doc.close()
    return ocr_process_image(img, mode)

def ocr_process_file(path, mode, page_num):
    if not path:
        return "Error: Upload file"
    if path.lower().endswith('.pdf'):
        return ocr_process_pdf(path, mode, page_num)
    else:
        return ocr_process_image(Image.open(path), mode)
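# Routing: PDFs are rasterized one page at a time at 300 DPI (fitz.Matrix(300/72,
# 300/72)) and sent through the same image path; any other extension is opened
# directly with PIL.
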
# ==================== TRANSLATION HELPERS ====================
def split_by_sentences(text: str, max_words: int = 100):
    def count_words(t):
        return len(t.strip().split())

    chunks = []
    lines = text.split('\n')
    i = 0
    while i < len(lines):
        line = lines[i]
        empty_count = 0
        if not line.strip():
            while i < len(lines) and not lines[i].strip():
                empty_count += 1
                i += 1
            if chunks:
                prev_text, prev_newlines = chunks[-1]
                chunks[-1] = (prev_text, prev_newlines + empty_count)
            continue
        line = line.strip()
        is_last_line = (i == len(lines) - 1)
        if count_words(line) <= max_words:
            chunks.append((line, 0 if is_last_line else 1))
            i += 1
            continue
        sentences = re.split(r'(?<=[.!?])\s+', line)
        current_chunk = ""
        current_words = 0
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            sentence_words = count_words(sentence)
            if sentence_words > max_words:
                if current_chunk:
                    chunks.append((current_chunk.strip(), 0))
                    current_chunk = ""
                    current_words = 0
                sub_parts = re.split(r',\s*', sentence)
                temp_chunk = ""
                temp_words = 0
                for part in sub_parts:
                    part_words = count_words(part)
                    if temp_words + part_words > max_words and temp_chunk:
                        chunks.append((temp_chunk.strip(), 0))
                        temp_chunk = part
                        temp_words = part_words
                    else:
                        if temp_chunk:
                            temp_chunk += ", " + part
                        else:
                            temp_chunk = part
                        temp_words += part_words
                if temp_chunk.strip():
                    current_chunk = temp_chunk.strip()
                    current_words = temp_words
            elif current_words + sentence_words <= max_words:
                if current_chunk:
                    current_chunk += " " + sentence
                else:
                    current_chunk = sentence
                current_words += sentence_words
            else:
                chunks.append((current_chunk.strip(), 0))
                current_chunk = sentence
                current_words = sentence_words
        if current_chunk.strip():
            chunks.append((current_chunk.strip(), 0 if is_last_line else 1))
        i += 1
    return chunks
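# Chunking example (hypothetical input): a ~250-word paragraph, a blank line, then a
# short heading as the final line yields roughly
#   [(chunk_1, 0), (chunk_2, 0), (chunk_3, 2), ("Heading", 0)]
# i.e. (chunk_text, trailing_newline_count) tuples with each chunk at most ~100 words.
# Oversized sentences are further split on commas, and blank lines are folded into the
# previous chunk's newline count.
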
def translate_chunk(chunk_text):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if hasattr(translator, 'model') and hasattr(translator.model, 'to'):
        translator.model.to(device)
    return translator.translate(chunk_text, max_new_tokens=2048).strip()
def streaming_translate(text: str):
    if not text or not text.strip():
        yield '<div style="padding:20px; color:#ff6b6b;">⚠️ Vui lòng nhập văn bản tiếng Anh để dịch.</div>'
        return
    chunks = split_by_sentences(text, max_words=100)
    accumulated = ""
    for i, (chunk_text, newline_count) in enumerate(chunks):
        try:
            translated = translate_chunk(chunk_text)
            if accumulated and not accumulated.endswith('\n'):
                accumulated += " " + translated
            else:
                accumulated += translated
            chunk_start = len(accumulated) - len(translated)
            for j in range(len(translated)):
                current_display = accumulated[:chunk_start + j + 1]
                html_output = f'<div style="padding:20px; line-height:1.8; font-size:15px; white-space:pre-wrap; font-family:Arial,sans-serif;">{current_display}</div>'
                yield html_output
                time.sleep(0.015)
            if newline_count > 0:
                actual_newlines = min(newline_count, 2)
                accumulated += "\n" * actual_newlines
            html_output = f'<div style="padding:20px; line-height:1.8; font-size:15px; white-space:pre-wrap; font-family:Arial,sans-serif;">{accumulated}</div>'
            yield html_output
        except Exception as e:
            yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi dịch chunk {i+1}: {str(e)}</div>'
            return
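# The generator yields a full HTML snippet per character (with a 15 ms pause), which
# Gradio re-renders as the gr.HTML value to produce a typewriter effect; newline counts
# from split_by_sentences() are capped at 2 so paragraph breaks are preserved without
# large gaps.
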
# ==================== UI HELPERS ====================
def load_image(file_path, page_num_str="1"):
    if not file_path:
        return None
    try:
        try:
            page_num = int(page_num_str)
        except (ValueError, TypeError):
            page_num = 1
        if file_path.lower().endswith('.pdf'):
            doc = fitz.open(file_path)
            page_idx = max(0, min(page_num - 1, len(doc) - 1))
            page = doc.load_page(page_idx)
            pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72), alpha=False)
            img = Image.open(BytesIO(pix.tobytes("png")))
            doc.close()
            return img
        else:
            return Image.open(file_path)
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

def get_pdf_page_count(file_path):
    if not file_path or not file_path.lower().endswith('.pdf'):
        return 1
    try:
        doc = fitz.open(file_path)
        count = len(doc)
        doc.close()
        return count
    except Exception as e:
        print(f"Error reading PDF page count: {e}")
        return 1

def update_page_info(file_path):
    if not file_path:
        return gr.update(label="Số trang (chỉ dùng cho PDF, mặc định: 1)")
    if file_path.lower().endswith('.pdf'):
        page_count = get_pdf_page_count(file_path)
        return gr.update(
            label=f"Số trang (PDF có {page_count} trang, nhập 1-{page_count})",
            value="1"
        )
    return gr.update(
        label="Số trang (chỉ dùng cho PDF, mặc định: 1)",
        value="1"
    )
# ==================== COMBINED OCR + TRANSLATION ====================
# ZeroGPU entry point: the `spaces` import and the Space's "Running on Zero" runtime
# suggest the GPU-bound handler should be decorated with @spaces.GPU (generator
# functions are supported for streaming).
@spaces.GPU
def ocr_and_translate_streaming(file_path, mode, page_num_str):
    if not file_path:
        yield '<div style="padding:20px; color:#ff6b6b;">⚠️ Vui lòng tải file lên trước!</div>'
        return
    yield '<div style="padding:20px; color:#4CAF50;">🔍 Đang quét OCR...</div>'
    try:
        try:
            page_num = int(page_num_str)
        except (ValueError, TypeError):
            page_num = 1
        markdown = ocr_process_file(file_path, mode, page_num)
        if not markdown or markdown.startswith("Error") or markdown.startswith("Invalid"):
            yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi OCR: {markdown}</div>'
            return
    except Exception as e:
        yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi OCR: {str(e)}</div>'
        return
    yield '<div style="padding:20px; color:#2196F3;">🦀 Đang dịch...</div>'
    time.sleep(0.5)
    try:
        yield from streaming_translate(markdown)
    except Exception as e:
        yield f'<div style="padding:20px; color:#ff6b6b;">❌ Lỗi dịch: {str(e)}</div>'
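# End-to-end flow for a click on process_btn: status banner -> OCR the selected page to
# markdown -> status banner -> stream the Vietnamese translation chunk by chunk. Any OCR
# or translation error is surfaced in the same output HTML component.
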
# ==================== GRADIO INTERFACE ====================
def load_default_example():
    src = "images/example1.png"
    if not os.path.exists(src):
        # fallback: return empty values
        return None, None
    tmp_path = "/tmp/example1.png"
    try:
        shutil.copy(src, tmp_path)
    except Exception:
        # if copy fails, try to use src directly
        tmp_path = src
    img = Image.open(tmp_path)
    return tmp_path, img
with gr.Blocks(theme=gr.themes.Soft(), title="MedCrab Translation") as demo:
    gr.Markdown("""
<div style="text-align: center;">
  <h1>🦀 MedCrab Translation</h1>
  <p style="font-size: 18px;"><b>Quét PDF Y khoa → Dịch trực tiếp sang tiếng Việt (Streaming)</b></p>
  <p style="font-size: 15px;">
    <b>Model:</b>
    <a href="https://huggingface.co/pnnbao-ump/MedCrab-1.5B" target="_blank">
      https://huggingface.co/pnnbao-ump/MedCrab-1.5B
    </a>
  </p>
  <p style="font-size: 15px;">
    <b>Dataset:</b>
    <a href="https://huggingface.co/datasets/pnnbao-ump/MedCrab" target="_blank">
      https://huggingface.co/datasets/pnnbao-ump/MedCrab
    </a>
  </p>
  <p style="font-size: 15px;">
    <b>GitHub Repository:</b>
    <a href="https://github.com/pnnbao97/MedCrab" target="_blank">
      https://github.com/pnnbao97/MedCrab
    </a>
  </p>
  <p style="font-size: 15px;">
    <b>Tác giả:</b> Phạm Nguyễn Ngọc Bảo
  </p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Tải file lên")
            file_in = gr.File(label="PDF hoặc Hình ảnh", file_types=["image", ".pdf"], type="filepath")
            input_img = gr.Image(label="Xem trước", type="pil", height=300)
            page_input = gr.Textbox(label="Số trang (chỉ dùng cho PDF, mặc định: 1)", value="1", placeholder="Nhập số trang...")
            mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Crab", label="Chế độ OCR")
            gr.Markdown("### 🦀 Quét và Dịch")
            process_btn = gr.Button("🚀 Quét OCR + Dịch tiếng Việt", variant="primary", size="lg")
        with gr.Column(scale=2):
            gr.Markdown("### 📄 Kết quả dịch tiếng Việt (Streaming)")
            translation_output = gr.HTML(label="", value="")

    with gr.Accordion("📚 Ví dụ mẫu", open=True):
        gr.Markdown("**Thử ngay với các ví dụ có sẵn:**")
        gr.Examples(
            examples=[
                ["images/example1.png", "Crab", "1"],
                ["images/example2.png", "Crab", "1"],
            ],
            inputs=[file_in, mode, page_input],
            outputs=[translation_output],
            fn=ocr_and_translate_streaming,
            cache_examples=False,
            label="Nhấp vào ví dụ để thử"
        )
    with gr.Accordion("⚖️ Giấy phép & Liên hệ", open=False):
        gr.Markdown("""
### ⚖️ Giấy phép sử dụng
MedCrab được phát hành theo giấy phép:
**Creative Commons Attribution–NonCommercial 4.0 International (CC BY-NC 4.0)**

### ✅ Được phép
- Sử dụng cho mục đích cá nhân
- Nghiên cứu học thuật
- Giảng dạy, học tập, minh họa
- Thử nghiệm nội bộ (không phục vụ vận hành thực tế)

### ❌ Không được phép
- Sử dụng cho mục đích thương mại dưới mọi hình thức
- Tích hợp vào hệ thống sản xuất (production system)
- Triển khai tại bệnh viện, phòng khám, cơ sở y tế
- Cung cấp như một dịch vụ trả phí / SaaS
- Bán lại, cho thuê, nhượng quyền phần mềm

### 💼 Nhu cầu sử dụng thương mại
Nếu bạn đại diện cho:
- Bệnh viện / phòng khám
- Công ty công nghệ y tế
- Viện nghiên cứu có hoạt động thương mại hóa
- Startup, doanh nghiệp triển khai sản phẩm y tế

Vui lòng liên hệ trực tiếp tác giả để trao đổi về **giấy phép thương mại**:

👤 **Phạm Nguyễn Ngọc Bảo**
📧 Facebook: https://www.facebook.com/bao.phamnguyenngoc.5/

---

⚠️ **Lưu ý pháp lý:**
Phần mềm này chỉ phục vụ cho mục đích nghiên cứu và tham khảo,
**không thay thế cho chẩn đoán hoặc quyết định y khoa.**
""")
    # Events
    file_in.change(load_image, [file_in, page_input], [input_img])
    file_in.change(update_page_info, [file_in], [page_input])
    page_input.change(load_image, [file_in, page_input], [input_img])
    process_btn.click(ocr_and_translate_streaming, [file_in, mode, page_input], [translation_output])

    # Load default example into both file_in (filepath) and input_img (PIL) when UI starts
    demo.load(
        load_default_example,
        inputs=None,
        outputs=[file_in, input_img]
    )
if __name__ == "__main__":
    print("🚀 Starting MedCrab Translation on Hugging Face Spaces...")
    demo.queue(max_size=20).launch()