| import json |
| import re |
| from typing import List, Optional, Tuple, Union |
| import numpy as np |
| import os |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| from PIL import Image |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from huggingface_hub import login, snapshot_download |
| from paddleocr import PaddleOCR |
|
|
| |
# Authenticate with the Hugging Face Hub when HF_TOKEN is set. The value is stripped
# *before* the truthiness check so a whitespace-only secret is treated as "no token"
# instead of calling login(token="").
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None
if HF_TOKEN:
    login(token=HF_TOKEN)

# Hub id of the instruction-tuned LLM used for the medication explanations.
MED_MODEL_ID = "google/gemma-2-2b-it"

# Module-level singletons, populated by load_models().
OCR_READER = None
MED_MODEL = None
MED_TOKENIZER = None
# Hub repo holding the Korean PP-OCRv5 mobile text-recognition weights.
OCR_MODEL_REPO_ID = "PaddlePaddle/korean_PP-OCRv5_mobile_rec"
|
|
|
|
| def _collect_ocr_texts(ocr_payload) -> List[str]: |
| """PaddleOCR κ²°κ³Ό ꡬ쑰μμ ν
μ€νΈλ§ μΆμΆ""" |
| texts: List[str] = [] |
| seen = set() |
|
|
| def add_text(candidate: str): |
| if not isinstance(candidate, str): |
| return |
| normalized = candidate.strip() |
| if normalized and normalized not in seen: |
| seen.add(normalized) |
| texts.append(normalized) |
|
|
| def walk(node): |
| if isinstance(node, str): |
| add_text(node) |
| return |
|
|
| if isinstance(node, dict): |
| for key in ("text", "label", "transcription"): |
| add_text(node.get(key)) |
|
|
| for key in ("texts", "labels"): |
| values = node.get(key) |
| if isinstance(values, (list, tuple)): |
| for value in values: |
| add_text(value) |
|
|
| for key in ("text_recognition", "rec_results", "data", "results"): |
| if key in node: |
| walk(node[key]) |
| return |
|
|
| if isinstance(node, (list, tuple)): |
| if len(node) >= 2: |
| second = node[1] |
| if isinstance(second, str): |
| add_text(second) |
| elif isinstance(second, (list, tuple)) and second: |
| maybe_text = second[0] |
| add_text(maybe_text) |
|
|
| for item in node: |
| walk(item) |
|
|
| walk(ocr_payload) |
| return texts |
|
|
def load_models():
    """Load the OCR reader and the medical LLM into module globals, at most once.

    Each global (OCR_READER, MED_TOKENIZER, MED_MODEL) is only loaded while it is
    still ``None``, so repeated calls are cheap and a partially failed load is
    retried on the next call.
    """
    global OCR_READER, MED_MODEL, MED_TOKENIZER

    if OCR_READER is None:
        print("🔄 Loading PaddleOCR (Korean PP-OCRv5 mobile recognition)...")
        # Fetch only the inference artifacts (weights + configs) of the recognizer.
        rec_model_dir = snapshot_download(
            OCR_MODEL_REPO_ID,
            allow_patterns=[
                "*.pdmodel",
                "*.pdiparams",
                "*.pdparams",
                "*.json",
                "*.yml",
            ],
        )
        OCR_READER = PaddleOCR(
            lang="korean",
            use_textline_orientation=True,
            text_recognition_model_dir=rec_model_dir,
            text_recognition_model_name="korean_PP-OCRv5_mobile_rec",
        )
        print("✅ PaddleOCR loaded!")

    if MED_MODEL is None or MED_TOKENIZER is None:
        # Local import: only needed here, keeps the file-level import block unchanged.
        from transformers import BitsAndBytesConfig

        print("🔄 Loading Gemma-2-2B for medical analysis (8bit quantization)...")
        if MED_TOKENIZER is None:
            MED_TOKENIZER = AutoTokenizer.from_pretrained(MED_MODEL_ID)
        if MED_MODEL is None:
            # Passing load_in_8bit directly to from_pretrained is deprecated;
            # 8-bit loading is requested through a BitsAndBytesConfig instead.
            MED_MODEL = AutoModelForCausalLM.from_pretrained(
                MED_MODEL_ID,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
        print("✅ Medical model loaded!")


# Load everything eagerly when the module is imported.
load_models()
|
|
|
|
| def _extract_assistant_content(decoded: str) -> str: |
| """μ΄μμ€ν΄νΈ μλ΅ μΆμΆ""" |
| if "<|im_start|>assistant" in decoded: |
| content = decoded.split("<|im_start|>assistant")[-1] |
| content = content.replace("<|im_end|>", "").strip() |
| return content |
| return decoded.strip() |
|
|
|
|
| def _extract_json_block(text: str) -> Optional[str]: |
| """JSON λΈλ‘ μΆμΆ""" |
| match = re.search(r"\{.*\}", text, re.DOTALL) |
| if not match: |
| return None |
| return match.group(0) |
|
|
|
|
def _build_analysis_prompt(ocr_text: str) -> str:
    """Return the Korean instruction prompt asking the LLM to explain every medication in ``ocr_text``."""
    return f"""다음은 약 봉투나 처방전에서 추출한 텍스트입니다:

{ocr_text}

위 텍스트에서 약 이름을 찾아서, 각 약에 대해 **노인과 어린이 모두 쉽게 이해할 수 있도록** 재미있고 친근하게 설명해주세요:

📋 **각 약마다 다음 정보를 포함해주세요:**

1. 💊 **약 이름**: 정확한 약 이름
2. 🎯 **효능**: 이 약이 무엇을 치료하고 어떻게 도움이 되는지
3. ⚠️ **부작용**: 주의해야 할 부작용들
4. 💡 **복용 방법**: 언제, 어떻게 먹어야 하는지 (식전/식후, 하루 몇 번 등)
5. 🚫 **주의사항**: 이 약과 함께 먹으면 안 되는 것들 (음식, 다른 약 등)

**스타일 가이드:**
- 이모지를 적극 활용하여 재미있게 작성
- 할머니 할아버지도 초등학생도 이해할 수 있는 쉬운 단어 사용
- 각 약마다 구분선으로 구분
- 친근하고 따뜻한 말투 사용
- 마크다운 형식으로 작성

시작해주세요!"""


@spaces.GPU(duration=120)
def analyze_medication_image(image: Image.Image) -> Tuple[str, str]:
    """Run OCR on ``image``, then ask the medical LLM to explain the medications found.

    Returns ``(ocr_text, analysis_markdown)``. When OCR finds no text, returns
    ``("텍스트를 찾을 수 없습니다.", "")``.

    Raises:
        RuntimeError: wrapping any failure, with message prefix "분석 오류: ".
            The original exception is chained as ``__cause__``.
    """
    import time

    no_text_message = "텍스트를 찾을 수 없습니다."
    try:
        start_time = time.time()

        # Normalize to 3-channel RGB, then flip to BGR: PaddleOCR treats ndarray
        # input as OpenCV-style BGR (NOTE(review): confirm for the installed version).
        rgb = np.asarray(image.convert("RGB"))
        img_array = np.ascontiguousarray(rgb[:, :, ::-1])

        # Newer PaddleOCR exposes predict(); fall back to the legacy ocr() API.
        try:
            ocr_results = OCR_READER.predict(img_array)
        except (TypeError, AttributeError):
            ocr_results = OCR_READER.ocr(img_array)
        ocr_time = time.time() - start_time
        print(f"⏱️ OCR took {ocr_time:.2f}s")

        if not ocr_results:
            return no_text_message, ""

        texts = _collect_ocr_texts(ocr_results)
        if not texts:
            return no_text_message, ""

        ocr_text = "\n".join(texts)

        analysis_start = time.time()
        messages = [{"role": "user", "content": _build_analysis_prompt(ocr_text)}]
        input_text = MED_TOKENIZER.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The chat template already begins with <bos>; letting the tokenizer add
        # special tokens again would prepend a second BOS.
        inputs = MED_TOKENIZER(
            input_text, return_tensors="pt", add_special_tokens=False
        ).to(MED_MODEL.device)

        with torch.no_grad():
            outputs = MED_MODEL.generate(
                **inputs,
                max_new_tokens=768,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        prompt_length = inputs["input_ids"].shape[1]
        analysis_text = MED_TOKENIZER.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        analysis_time = time.time() - analysis_start
        total_time = time.time() - start_time
        print(f"⏱️ Medical analysis took {analysis_time:.2f}s")
        print(f"⏱️ Total processing time: {total_time:.2f}s")

        return ocr_text.strip(), analysis_text.strip()

    except Exception as e:
        # RuntimeError is still caught by callers' `except Exception`.
        raise RuntimeError(f"분석 오류: {str(e)}") from e
|
|
|
|
def _parse_medication_list(response: str) -> List[str]:
    """Return the names under "medications" in the first JSON object of ``response`` that has such a list.

    Names are stringified and stripped, and blanks are dropped. Returns ``[]`` when
    no decodable object with a ``medications`` list is present (malformed or
    missing JSON).
    """
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", response):
        try:
            data, _end = decoder.raw_decode(response, brace.start())
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(data, dict) and isinstance(data.get("medications"), list):
            names = (str(m).strip() for m in data["medications"])
            return [name for name in names if name]
    return []


def extract_medications_from_text(text: str) -> List[str]:
    """Stage 2: ask the loaded medical LLM (MED_MODEL / MED_TOKENIZER) for the medication names in ``text``.

    Returns the extracted names, or ``["약 이름을 찾지 못했습니다."]`` when the model
    output contains no usable ``{"medications": [...]}`` object.

    Raises:
        RuntimeError: on tokenizer/model failure, with message prefix "LLM 분석 오류: ".
            The original exception is chained.
    """
    instruction = (
        "You are a medical text analyzer. Extract only medication names from the given text "
        "and return them as a JSON array. Return ONLY valid JSON format."
    )
    request = (
        f"Extract all medication names from this text:\n\n{text}\n\n"
        'Return format: {"medications": ["name1", "name2"]}'
    )
    # Gemma's chat template rejects a "system" role, so the instruction is
    # prepended to the single user turn instead.
    messages = [{"role": "user", "content": f"{instruction}\n\n{request}"}]

    try:
        prompt = MED_TOKENIZER.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The template already contains <bos>; don't let the tokenizer add another.
        inputs = MED_TOKENIZER(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(MED_MODEL.device)

        with torch.no_grad():
            outputs = MED_MODEL.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.3,
                top_p=0.9,
                do_sample=True,
                pad_token_id=MED_TOKENIZER.eos_token_id,
            )

        # Decode only generated tokens. Decoding the whole sequence would include
        # the prompt, whose format example would then be "found" as the answer.
        prompt_length = inputs["input_ids"].shape[1]
        response = MED_TOKENIZER.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        raise RuntimeError(f"LLM 분석 오류: {str(e)}") from e

    medications = _parse_medication_list(response)
    return medications if medications else ["약 이름을 찾지 못했습니다."]
|
|
|
|
def extract_text_from_image(image: Image.Image) -> str:
    """Run PaddleOCR on ``image`` and return the recognized lines joined by newlines ("" if none)."""
    # PaddleOCR treats ndarray input as OpenCV-style BGR, so convert from PIL's RGB.
    rgb = np.asarray(image.convert("RGB"))
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    # Newer PaddleOCR exposes predict(); fall back to the legacy ocr() API.
    try:
        ocr_results = OCR_READER.predict(bgr)
    except (TypeError, AttributeError):
        ocr_results = OCR_READER.ocr(bgr)
    if not ocr_results:
        return ""
    return "\n".join(_collect_ocr_texts(ocr_results))


@spaces.GPU(duration=120)
def extract_medication_names(image: Image.Image) -> Tuple[str, List[str]]:
    """Two-stage pipeline: OCR the image, then have the LLM pick out medication names.

    Returns ``(extracted_text, medications)``. Failures never raise: they come back
    as ``("", [status message])``. The message is either "텍스트를 추출하지 못했습니다."
    or "오류 발생: <error>".
    """
    try:
        extracted_text = extract_text_from_image(image)
        if not extracted_text:
            return "", ["텍스트를 추출하지 못했습니다."]

        medications = extract_medications_from_text(extracted_text)
        return extracted_text, medications

    except Exception as e:
        return "", [f"오류 발생: {str(e)}"]
|
|
|
|
def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, str]:
    """Render the OCR text and the detected medication names as two Markdown strings.

    The medication panel becomes a warning header when the list is empty or when
    its first entry is one of the pipeline's status messages ("오류...",
    "약 이름을 찾지...", "텍스트를..."). Otherwise it is a numbered bold list.
    """
    text_output = f"### 📝 추출된 텍스트\n\n```\n{extracted_text}\n```"

    if not medications:
        return text_output, "### ⚠️ 약 이름을 찾지 못했습니다."

    status_prefixes = ("오류", "약 이름을 찾지", "텍스트를")
    if medications[0].startswith(status_prefixes):
        return text_output, f"### ⚠️ {medications[0]}"

    header = f"### 💊 검출된 약물 ({len(medications)}개)\n\n"
    entries = [f"{position}. **{name}**\n" for position, name in enumerate(medications, 1)]
    return text_output, header + "".join(entries)
|
|
|
|
def _ensure_pil(image_input: Optional[Union[Image.Image, np.ndarray, str]]) -> Optional[Image.Image]:
    """Convert a Gradio image input into an RGB PIL image.

    Accepts a PIL image, a numpy array, or a filesystem path (``str`` or
    ``os.PathLike``). Returns ``None`` for ``None``, for paths that are not an
    existing regular file, and for unsupported input types.
    """
    if image_input is None:
        return None

    if isinstance(image_input, Image.Image):
        # Normalize RGBA / L / P / ... so every branch hands back RGB.
        return image_input if image_input.mode == "RGB" else image_input.convert("RGB")

    if isinstance(image_input, np.ndarray):
        array = image_input
        if array.ndim == 3 and array.shape[-1] == 1:
            # Image.fromarray cannot build an image from (H, W, 1); drop the channel axis.
            array = array[..., 0]
        if array.dtype != np.uint8:
            array = np.clip(array, 0, 255).astype(np.uint8)
        return Image.fromarray(array).convert("RGB")

    if isinstance(image_input, (str, os.PathLike)):
        path = os.fspath(image_input)
        # isfile (not exists) so a directory path returns None instead of crashing Image.open.
        if not os.path.isfile(path):
            return None
        with Image.open(path) as img:
            return img.convert("RGB")

    return None
|
|
|
|
def run_analysis(image: Optional[Union[Image.Image, np.ndarray, str]], progress=gr.Progress()):
    """Gradio click handler: normalize the upload, then run OCR plus the medication explanation.

    Returns an ``(ocr_markdown, analysis_markdown)`` pair for the two result panels.
    A missing upload yields an upload prompt and an empty second panel. Any
    exception from the pipeline is rendered as a warning in the first panel.
    """
    pil_image = _ensure_pil(image)
    if pil_image is None:
        return "📷 약 봉투나 처방전 사진을 업로드해주세요.", ""

    # Both stage labels are pushed before the work starts: analyze_medication_image
    # performs OCR and the LLM step in a single call.
    stage_updates = (
        (0.3, "📸 1단계: OCR 텍스트 추출 중..."),
        (0.6, "🤖 2단계: 약 정보 분석 중..."),
    )
    for fraction, label in stage_updates:
        progress(fraction, desc=label)

    try:
        extracted, explanation = analyze_medication_image(pil_image)
        progress(1.0, desc="✅ 완료!")

        ocr_panel = f"### 📝 추출된 텍스트\n\n```\n{extracted}\n```"
        analysis_panel = f"### 💊 약 정보 설명\n\n{explanation}"
        return ocr_panel, analysis_panel
    except Exception as exc:
        return f"### ⚠️ 오류 발생\n\n{exc}", ""
|
|
|
|
| |
# Page stylesheet passed to gr.Blocks(css=...).
# gr.Button(elem_classes=[...]) puts the class on the <button> element itself, so the
# analyze-button rules select `.analyze-btn` directly. The descendant form
# `.analyze-btn button` is kept alongside it for wrapper-style layouts.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

:root {
    --primary: #6366f1;
    --secondary: #8b5cf6;
}

body {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}

.gradio-container {
    max-width: 900px !important;
    margin: auto;
    background: rgba(255, 255, 255, 0.98);
    border-radius: 24px;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.3);
    padding: 40px;
}

.hero {
    text-align: center;
    padding: 30px 20px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 20px;
    color: white;
    margin-bottom: 30px;
}

.hero h1 {
    font-size: 2.5rem;
    font-weight: 700;
    margin-bottom: 10px;
}

.hero p {
    font-size: 1.1rem;
    opacity: 0.95;
}

.upload-section {
    background: white;
    border-radius: 16px;
    padding: 30px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07);
    margin-bottom: 20px;
}

.result-section {
    background: white;
    border-radius: 16px;
    padding: 30px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07);
    min-height: 200px;
}

.analyze-btn,
.analyze-btn button {
    background: linear-gradient(135deg, var(--primary), var(--secondary)) !important;
    color: white !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    padding: 18px 40px !important;
    border-radius: 12px !important;
    border: none !important;
    box-shadow: 0 10px 20px -5px rgba(99, 102, 241, 0.5) !important;
    transition: all 0.3s ease !important;
}

.analyze-btn:hover,
.analyze-btn button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 15px 30px -5px rgba(99, 102, 241, 0.6) !important;
}

.gr-image {
    border-radius: 12px !important;
}
"""
|
|
# Banner HTML rendered at the top of the page via gr.HTML; styled by the `.hero`
# rules in CUSTOM_CSS.
HERO_HTML = """
<div class="hero">
    <h1>💊 우리 가족 약 도우미</h1>
    <p>약봉투/처방전 사진에서 약 정보를 쉽고 재미있게 알려드려요!</p>
</div>
"""
|
|
| |
# Footer shown under the results: usage steps, disclaimer, tech stack, setup note.
USAGE_GUIDE_MD = """
---

**ℹ️ 사용 방법**
1. 약 봉투나 처방전 사진을 업로드하세요
2. '약 정보 분석하기' 버튼을 클릭하세요
3. 왼쪽에는 추출된 텍스트, 오른쪽에는 쉬운 설명이 나타납니다!

**⚠️ 주의사항**
- 이 앱은 참고용이며, 실제 복약은 반드시 의사나 약사의 지시를 따르세요
- AI가 생성한 정보이므로 정확하지 않을 수 있습니다

**🤖 기술 스택**
- PaddleOCR PP-OCRv5 (한국어 최적화 OCR)
- Google Gemma-2-2B-IT (8bit 양자화, 빠른 의료 정보 분석)

**🔑 설정 방법**
- Hugging Face Spaces의 Settings → Repository secrets에서 `HF_TOKEN` 추가 필요
"""

with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HERO_HTML)

    # Upload card: image picker plus the trigger button.
    with gr.Column(elem_classes=["upload-section"]):
        gr.Markdown("### 📸 사진 업로드")
        image_input = gr.Image(
            type="numpy",
            image_mode="RGB",
            label="약봉투 또는 처방전 사진",
            height=350,
        )
        analyze_button = gr.Button(
            "🔍 약 정보 분석하기",
            elem_classes=["analyze-btn"],
            size="lg",
        )

    # Result panels side by side: raw OCR text on the left, explanation on the right.
    with gr.Row():
        with gr.Column(elem_classes=["result-section"]):
            gr.Markdown("### 📝 1단계: 추출된 텍스트")
            ocr_output = gr.Markdown("OCR로 추출된 텍스트가 여기 표시됩니다.")

        with gr.Column(elem_classes=["result-section"]):
            gr.Markdown("### 💊 2단계: 쉬운 약 설명")
            analysis_output = gr.Markdown("노인과 어린이도 이해하기 쉬운 약 정보가 여기 표시됩니다.")

    analyze_button.click(
        fn=run_analysis,
        inputs=image_input,
        outputs=[ocr_output, analysis_output],
    )

    gr.Markdown(USAGE_GUIDE_MD)
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue on the Blocks app, then start the server.
    queued_app = demo.queue()
    queued_app.launch()
|
|