"""
Gradio application for MedGemma inference with ZeroGPU.

This script defines a minimal Gradio interface around Google's
``medgemma-27b-it`` multi-modal model.  It is designed to run on
Hugging Face Spaces using the **ZeroGPU** hardware option.  ZeroGPU
allocates an NVIDIA H200 GPU slice for the duration of each call and
releases it afterwards.  The interface accepts a textual **prompt**
(English only), an optional image upload and an optional **system
prompt** to steer the model.  All responses are returned in English and
include a short disclaimer reminding users to consult a medical
professional.

If you set an ``API_KEY`` secret in your Space, callers must supply the
same value in the API key field (rendered as a password input).  Otherwise
the endpoint will be publicly accessible.  See the README for details.

Note: ZeroGPU Spaces currently only work with the **Gradio** SDK and
support specific versions of PyTorch and Python.
Running this script outside of a Space will work on CPU or dedicated
GPU hardware, but ZeroGPU GPU allocation only takes effect when the
Space hardware is set to *ZeroGPU (Dynamic resources)*.
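
Example client call (a sketch using ``gradio_client``; the Space id below is a
placeholder, and the API key argument matters only when the ``API_KEY`` secret
is configured)::

    from gradio_client import Client

    client = Client("your-username/medgemma-zerogpu")  # placeholder Space id
    answer = client.predict(
        "Is there evidence of pneumonia in this study?",  # prompt
        None,                                             # optional image
        "",                                               # optional system prompt
        "",                                               # API key, if required
        api_name="/predict",
    )
    print(answer)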
"""

import os
from typing import Optional

import gradio as gr
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    GenerationConfig,
    pipeline,
)
import spaces  # for the @spaces.GPU decorator

# ----------------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------------

HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError(
        "HF_TOKEN environment variable must be set as a Secret in the Space."
    )

# Optional API key: when set, clients must provide the same value in the
# ``api_key`` field of the Gradio interface.  If not set, no authentication
# is enforced.
API_KEY = os.getenv("API_KEY")

MODEL_ID = "google/medgemma-27b-it"

# Load the processor outside of the GPU context – this is lightweight
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)

eos_id = processor.tokenizer.eos_token_id
pad_id = processor.tokenizer.pad_token_id or eos_id

# Banned phrases to reduce chatty or irrelevant responses
ban_list = [
    "Disclaimer",
    "disclaimer",
    "As an AI Chatbot",
    "as an AI Chatbot",
    "I cannot give medical advice",
    "I cannot provide medical advice",
    "I cannot give medical advise",
    "user",
    "response",
    "display",
    "response>",
    "```",
    "label",
    "tool_code",
]
bad_words_ids = [processor.tokenizer(b, add_special_tokens=False).input_ids for b in ban_list]

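# Deterministic decoding: greedy search (do_sample=False) with a mild
# repetition penalty; temperature is ignored when sampling is disabled.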
gen_cfg = GenerationConfig(
    max_new_tokens=120,
    do_sample=False,
    repetition_penalty=1.12,
    no_repeat_ngram_size=6,
    length_penalty=1.0,
    temperature=0.0,
    eos_token_id=eos_id,
    pad_token_id=pad_id,
    bad_words_ids=bad_words_ids,
)

# We'll load the model lazily inside run_model to ensure GPU allocation
# occurs within the ZeroGPU context.  Cache the model and pipeline on
# first use so subsequent calls are faster.  A simple attribute on the
# function serves as a persistent cache.


@spaces.GPU(duration=120)
def run_model(prompt: str, image: Optional[Image.Image], system_prompt: Optional[str]) -> str:
    """Execute the MedGemma model.

    This function will be run inside the ZeroGPU allocation context.  It
    lazily loads the model and pipeline on first invocation and reuses
    them for subsequent calls.  Inputs are combined with an optional
    system prompt to produce the full prompt.  The model's output is
    returned as a plain English string.

    Args:
        prompt: The user's question (English only).
        image: An optional PIL Image.  If provided, the model will use
            both text and image modalities; otherwise text-only.
        system_prompt: An optional system prompt to steer the model.  If
            None or empty, a default instruction is used.

    Returns:
        The raw English output from the model (without disclaimer).
    """
    # Lazy-load the model and pipeline on first use
    if not hasattr(run_model, "model"):
        # Choose dtype and device placement: bfloat16 on CUDA to save memory
        # on the H200 slice, float32 on CPU.  device_map="auto" lets
        # Accelerate place the weights across the available devices.
        model_kwargs: dict = {
            "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            "token": HF_TOKEN,
        }
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
        # Create a pipeline for convenience
        vlm = pipeline(
            task="image-text-to-text",
            model=model,
            processor=processor,
            generation_config=gen_cfg,
        )
        # Store for reuse
        run_model.model = model
        run_model.vlm = vlm
    else:
        vlm = run_model.vlm

    # Compose the full prompt
    sys_prompt = (
        system_prompt.strip()
        if system_prompt and system_prompt.strip()
        else "You are a concise radiology assistant. Answer the user's question based on the image and text."
    )
    full_prompt = sys_prompt + "\n" + prompt.strip()

    # Run inference
    if image is not None:
        # Pass both modalities by keyword so the pipeline cannot mistake the
        # prompt for an image path.
        result = vlm(images=image, text=full_prompt)
    else:
        result = vlm(text=full_prompt)
    output = result[0]["generated_text"]
    # The pipeline may echo the prompt; keep only the newly generated answer.
    if output.startswith(full_prompt):
        output = output[len(full_prompt):].lstrip()
    return output


def predict(
    prompt: str,
    image: Optional[Image.Image] = None,
    system_prompt: Optional[str] = None,
    api_key: Optional[str] = None,
) -> str:
    """Wrapper function for Gradio.

    Handles optional API key authentication and appends a disclaimer to
    the model's output.  See README for details.

    Args:
        prompt: The user's question in English.
        image: An optional PIL image.
        system_prompt: Optional system prompt to steer the model.
        api_key: Optional API key supplied by the client.  If the
            ``API_KEY`` secret is set and this does not match, the
            request is rejected.

    Returns:
        A string containing the model's answer followed by a
        disclaimer.  If authentication fails an error message is
        returned instead.
    """
    # Enforce API key if configured
    if API_KEY:
        if api_key is None or api_key != API_KEY:
            return "Error: Invalid or missing API key."

    # Validate prompt
    if not prompt or not prompt.strip():
        return "Error: Prompt cannot be empty."

    try:
        answer = run_model(prompt, image, system_prompt)
    except Exception as e:
        return f"Error during inference: {e}"
    disclaimer = (
        "\n\nThis response is generated by an AI model and may be incorrect. "
        "Always consult a licensed medical professional for health questions."
    )
    return answer.strip() + disclaimer
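

# A minimal local smoke test (text only).  It assumes HF_TOKEN is exported and
# that the 27B weights fit on the local hardware; on Spaces the Gradio UI calls
# ``predict`` for you, so this is only for quick sanity checking:
#
#     print(predict("What are typical chest X-ray findings in pneumonia?"))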


def build_demo() -> gr.Interface:
    """Construct the Gradio UI for this application."""
    # Define inputs: prompt, optional image, optional system prompt, and an
    # API key field that is only visible when the ``API_KEY`` secret is set.
    # When API_KEY is not configured, the api_key input is hidden and ignored.
    inputs = [
        gr.Textbox(
            label="Prompt (English only)",
            lines=4,
            placeholder="Describe the medical image or ask a question."
        ),
        gr.Image(
            type="pil",
            label="Optional image"
        ),
        gr.Textbox(
            label="Optional system prompt",
            lines=2,
            placeholder="e.g. You are a concise radiology assistant."
        ),
        gr.Textbox(
            label="API key",
            lines=1,
            placeholder="Enter API key if required",
            type="password",
            visible=bool(API_KEY),
        ),
    ]
    outputs = gr.Textbox(label="Answer")
    description = (
        "Ask MedGemma a question about a medical image or condition. "
        "Optionally provide a system prompt to guide the model's behaviour. "
        "All responses are in English and include a disclaimer."
    )
    demo = gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title="MedGemma ZeroGPU (Gradio)",
        description=description,
        allow_flagging="never",
    )
    return demo


demo = build_demo()

if __name__ == "__main__":
    # Start the Gradio server when the script is run directly (e.g. for local
    # testing).  On Hugging Face Spaces the Gradio runtime serves the app, so
    # no extra arguments are needed here.
    demo.launch()