import os
import re  # Regular-expression library (kept from the original; not used below)
import tempfile

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import spaces
from PIL import Image, ImageDraw  # ImageDraw is currently unused

# --- 1. Load Model and Tokenizer (Done only once at startup) ---
print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the model to CPU first; it is moved to the GPU inside process_ocr_task.
model = AutoModel.from_pretrained(
    model_name,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.eval()
print("βœ… Model loaded successfully.")

# Task label -> prompt. Per the DeepSeek-OCR model card, prompts begin with
# the "<image>" placeholder followed by the instruction.
# The keys double as the Task Type dropdown choices (insertion order preserved).
TASK_PROMPTS = {
    "πŸ“ Free OCR": "<image>\nFree OCR.",
    "πŸ“„ Convert to Markdown": "<image>\n<|grounding|>Convert the document to markdown.",
    "πŸ“ˆ Parse Figure": "<image>\nParse the figure.",
}
DEFAULT_TASK_PROMPT = TASK_PROMPTS["πŸ“ Free OCR"]

# Resolution label -> keyword arguments forwarded to model.infer().
# The keys double as the Resolution dropdown choices (insertion order preserved).
SIZE_CONFIGS = {
    "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
    "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
    "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
    "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
}
DEFAULT_MODEL_SIZE = "Gundam (Recommended)"

# File extensions find_result_image will try to open.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")


# --- Helper function to find pre-generated result images ---
def find_result_image(path):
    """Return the first result image found in directory *path*, or None.

    A file qualifies when its name contains "grounding" or "result" and it has
    an image extension. Candidates are checked in sorted order so the choice is
    deterministic. The image is copied fully into memory, so it stays usable
    after *path* (a temporary directory) has been deleted.
    """
    for filename in sorted(os.listdir(path)):
        if "grounding" not in filename and "result" not in filename:
            continue
        image_path = os.path.join(path, filename)
        # Skip sub-directories and non-image outputs such as Markdown files.
        if not os.path.isfile(image_path):
            continue
        if not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        try:
            with Image.open(image_path) as img:
                # copy() forces a full load before the file handle is closed.
                return img.copy()
        except OSError as e:
            print(f"Error opening result image {filename}: {e}")
    return None


# --- 2. Main Processing Function ---
@spaces.GPU
def process_ocr_task(image, model_size, task_type, ref_text=None):
    """Run DeepSeek-OCR on *image* for the selected task and resolution.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        model_size: A key of SIZE_CONFIGS; unknown values fall back to
            "Gundam (Recommended)".
        task_type: A key of TASK_PROMPTS; unknown values fall back to Free OCR.
        ref_text: Currently unused. Optional so the UI can call this function
            with three inputs while existing four-argument callers still work.

    Returns:
        A ``(text, result_image)`` tuple matching the ``[output_text,
        output_image]`` components. ``result_image`` is whatever
        find_result_image() locates among the files the model saved, or None.
    """
    if image is None:
        return "Please upload an image first.", None

    print("πŸš€ Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("βœ… Model is on GPU.")

    prompt = TASK_PROMPTS.get(task_type, DEFAULT_TASK_PROMPT)
    config = SIZE_CONFIGS.get(model_size, SIZE_CONFIGS[DEFAULT_MODEL_SIZE])

    with tempfile.TemporaryDirectory() as output_path:
        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        print(f"πŸƒ Running inference with prompt: {prompt}")
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )
        print(f"====\nπŸ“„ Text Result: {text_result}\n====")

        # Must run before the temporary directory (and its files) is removed.
        result_image = find_result_image(output_path)

    return text_result, result_image


# --- 3. Build the Gradio Interface ---
with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🐳 Full Demo of DeepSeek-OCR 🐳

**πŸ’‘ How to use:**
1. **Upload an image** using the upload box.
2. Select a **Resolution**. `Gundam` is recommended for most documents.
3. Choose a **Task Type**:
    - **πŸ“ Free OCR**: Extracts raw text from the image.
    - **πŸ“„ Convert to Markdown**: Converts the document into Markdown, preserving structure.
    - **πŸ“ˆ Parse Figure**: Extracts structured data from charts and figures.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="πŸ–ΌοΈ Upload Image",
                sources=["upload", "clipboard"],
            )
            model_size = gr.Dropdown(
                choices=list(SIZE_CONFIGS),
                value=DEFAULT_MODEL_SIZE,
                label="βš™οΈ Resolution Size",
            )
            task_type = gr.Dropdown(
                choices=list(TASK_PROMPTS),
                value="πŸ“„ Convert to Markdown",
                label="πŸš€ Task Type",
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="πŸ“„ Text Result", lines=15, show_copy_button=True)
            output_image = gr.Image(label="πŸ–ΌοΈ Image Result (if any)", type="pil")

    # Without this listener the button was inert and the outputs never updated.
    submit_btn.click(
        fn=process_ocr_task,
        inputs=[image_input, model_size, task_type],
        outputs=[output_text, output_image],
    )

# --- 4. Launch the App ---
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)