"""
GPT-OSS Model Deployment on Modal with vLLM

This script deploys OpenAI's GPT-OSS models (20B or 120B) on Modal.com
with vLLM for efficient inference.

Usage:
    # First time setup - pre-download model weights (run once, takes ~5-10 min)
    modal run gpt_oss_inference.py::download_model

    # Test the server locally
    modal run gpt_oss_inference.py

    # Deploy to production
    modal deploy gpt_oss_inference.py

Performance Tips:
    1. Run download_model first to cache weights in the volume
    2. Reduce MAX_MODEL_LEN for faster startup (8k is sufficient for most use cases)
    3. Keep FAST_BOOT=True for cheaper GPUs (A10G, L4)
    4. Increase SCALEDOWN_WINDOW to reduce cold starts during demos

Based on: https://modal.com/docs/examples/gpt_oss_inference
"""

import json
import time
from datetime import datetime, timezone
from typing import Any

import aiohttp
import modal

# Container image: CUDA 12.8 devel base with Python 3.12, vLLM, and accelerated
# Hugging Face downloads via hf_transfer.
vllm_image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.8.1-devel-ubuntu22.04",
        add_python="3.12",
    )
    .entrypoint([])  # drop the base image's default entrypoint
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster weight downloads
    .uv_pip_install(
        "vllm==0.11.0",
        "huggingface_hub[hf_transfer]==0.35.0",
        "flashinfer-python==0.3.1",
    )
)

# Model to serve, pinned to a specific revision for reproducible deployments.
MODEL_NAME = "openai/gpt-oss-20b"
MODEL_REVISION = "d666cf3b67006cf8227666739edf25164aaffdeb"
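
# The module docstring also mentions the 120B variant. A rough sketch of the swap
# (the repo id, revision placeholder, and GPU sizing below are assumptions, not
# tested values -- the 120B model needs considerably more GPU memory):
#
#   MODEL_NAME = "openai/gpt-oss-120b"
#   MODEL_REVISION = "<pinned revision sha>"
#   GPU_CONFIG = "A100-80GB"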

# GPU to run on. Must be a key of _GPU_MAP below; unknown values fall back to A10G.
GPU_CONFIG = "A100-40GB"

# Persistent caches shared across containers: Hugging Face weights and vLLM's
# compilation artifacts.
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)

MINUTES = 60  # seconds

# Skip CUDA graph capture for a faster cold start, trading away some throughput.
FAST_BOOT = True

# Batch sizes to capture CUDA graphs for when FAST_BOOT is disabled.
CUDA_GRAPH_CAPTURE_SIZES = [1, 2, 4, 8, 16, 24, 32]

# Serve in float16 instead of bfloat16 (for GPUs without bfloat16 support).
USE_FLOAT16 = False

# Maximum context length; smaller values start faster and need less GPU memory.
MAX_MODEL_LEN = 16384

VLLM_PORT = 8000
N_GPU = 1  # number of GPUs (tensor-parallel degree)
MAX_INPUTS = 50  # concurrent requests handled per container

# How long an idle container stays warm before scaling to zero.
SCALEDOWN_WINDOW = 10 * MINUTES

app = modal.App("gpt-oss-vllm-inference")

# Map friendly GPU names to Modal GPU strings. GPU_CONFIG must be one of these
# keys, otherwise the deployment silently falls back to an A10G.
_GPU_MAP = {
    "T4": "T4",
    "L4": "L4",
    "A10G": "A10G",
    "A100": "A100-40GB",
    "A100-40GB": "A100-40GB",
    "A100-80GB": "A100-80GB",
    "H100": "H100",
}
SELECTED_GPU = _GPU_MAP.get(GPU_CONFIG, "A10G")


@app.function(
    image=vllm_image,
    volumes={"/root/.cache/huggingface": hf_cache_vol},
    timeout=30 * MINUTES,
)
def download_model():
    """
    Pre-download the model weights to the volume cache.

    Run this once with: modal run gpt_oss_inference.py::download_model
    This will cache the weights and make subsequent starts much faster.
    """
    from huggingface_hub import snapshot_download

    print(f"📥 Downloading model weights for {MODEL_NAME}...")
    print(f"   Revision: {MODEL_REVISION}")

    # Download into the default Hugging Face hub cache (backed by the volume
    # mounted above) so vLLM can find the weights via the standard cache layout.
    snapshot_download(
        MODEL_NAME,
        revision=MODEL_REVISION,
    )

    print("✅ Model weights downloaded and cached!")
    print("   Future container starts will use the cached weights.")
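
# To double-check what ended up in the cache volume, a small helper like the
# sketch below could be added (assumption: it reuses the same image and volume;
# scan_cache_dir is part of huggingface_hub):
#
#   @app.function(image=vllm_image, volumes={"/root/.cache/huggingface": hf_cache_vol})
#   def list_cached_models():
#       from huggingface_hub import scan_cache_dir
#
#       for repo in scan_cache_dir().repos:
#           print(repo.repo_id, f"{repo.size_on_disk / 1e9:.1f} GB")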


@app.function(
    image=vllm_image,
    gpu=SELECTED_GPU,
    scaledown_window=SCALEDOWN_WINDOW,
    timeout=30 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
@modal.concurrent(max_inputs=MAX_INPUTS)
@modal.web_server(port=VLLM_PORT, startup_timeout=30 * MINUTES)
def serve():
    """Start the vLLM server with GPT-OSS model."""
    import subprocess

    cmd = [
        "vllm",
        "serve",
        "--uvicorn-log-level=info",
        MODEL_NAME,
        "--revision",
        MODEL_REVISION,
        "--served-model-name",
        "llm",
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
    ]

    # Eager mode skips CUDA graph capture for a faster cold start; otherwise
    # capture graphs only for the batch sizes configured above.
    cmd += ["--enforce-eager" if FAST_BOOT else "--no-enforce-eager"]

    if not FAST_BOOT:
        cmd += [
            "-O.cudagraph_capture_sizes="
            + str(CUDA_GRAPH_CAPTURE_SIZES).replace(" ", "")
        ]

    # Numeric precision for weights and activations.
    if USE_FLOAT16:
        cmd += ["--dtype", "float16"]
    else:
        cmd += ["--dtype", "bfloat16"]

    # Cap the context window to reduce KV-cache memory and startup time.
    cmd += ["--max-model-len", str(MAX_MODEL_LEN)]

    # Custom all-reduce only matters for multi-GPU tensor parallelism.
    if N_GPU == 1:
        cmd += ["--disable-custom-all-reduce"]

    # Reuse KV cache across requests that share a prefix (e.g. the system prompt).
    cmd += ["--enable-prefix-caching"]

    cmd += ["--trust-remote-code"]
    cmd += ["--load-format", "auto"]
    cmd += ["--tensor-parallel-size", str(N_GPU)]

    # Quieter logs and 4 GiB of CPU swap space per GPU.
    cmd += ["--disable-log-stats"]
    cmd += ["--swap-space", "4"]

    print(f"Starting vLLM server with command: {' '.join(cmd)}")

    subprocess.Popen(" ".join(cmd), shell=True)
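
# Once the app is deployed, the server exposes vLLM's OpenAI-compatible routes.
# A quick smoke test from a terminal might look like this (sketch; the URL is a
# placeholder -- use the one Modal prints for `serve` on deploy):
#
#   curl https://<workspace>--gpt-oss-vllm-inference-serve.modal.run/health
#   curl https://<workspace>--gpt-oss-vllm-inference-serve.modal.run/v1/models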


@app.local_entrypoint()
async def test(test_timeout=30 * MINUTES, user_content=None, twice=True):
    """
    Test the deployed server with a sample prompt.

    Args:
        test_timeout: Maximum time to wait for server health
        user_content: Custom prompt to send (default: SVD explanation)
        twice: Whether to send a second request
    """
    url = serve.get_web_url()

    system_prompt = {
        "role": "system",
        "content": f"""You are ChatModal, a large language model trained by Modal.
Knowledge cutoff: 2024-06
Current date: {datetime.now(timezone.utc).date()}
Reasoning: low
# Valid channels: analysis, commentary, final. Channel must be included for every message.
Calls to these tools must go to the commentary channel: 'functions'.""",
    }

    if user_content is None:
        user_content = "Explain what the Singular Value Decomposition is."

    messages = [
        system_prompt,
        {"role": "user", "content": user_content},
    ]

    async with aiohttp.ClientSession(base_url=url) as session:
        print(f"Running health check for server at {url}")
        async with session.get("/health", timeout=test_timeout - 1 * MINUTES) as resp:
            up = resp.status == 200
        assert up, f"Failed health check for server at {url}"
        print(f"Successful health check for server at {url}")

        print(f"Sending messages to {url}:", *messages, sep="\n\t")
        await _send_request(session, "llm", messages)

        if twice:
            messages[0]["content"] += "\nTalk like a pirate, matey."
            print(f"Re-sending messages to {url}:", *messages, sep="\n\t")
            await _send_request(session, "llm", messages)
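
# Modal exposes the parameters of this local entrypoint as CLI options, so a
# custom prompt can be passed without editing the file. Rough sketch (exact
# flag spelling follows Modal's CLI conventions):
#
#   modal run gpt_oss_inference.py --user-content "Summarize the attention mechanism."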


async def _send_request(
    session: aiohttp.ClientSession, model: str, messages: list
) -> None:
    """Send a streaming request to the vLLM server."""
    payload: dict[str, Any] = {"messages": messages, "model": model, "stream": True}
    headers = {"Content-Type": "application/json", "Accept": "text/event-stream"}

    t = time.perf_counter()
    async with session.post(
        "/v1/chat/completions", json=payload, headers=headers, timeout=10 * MINUTES
    ) as resp:
        resp.raise_for_status()
        async for raw in resp.content:
            # Responses stream as server-sent events: "data: {json}" lines,
            # terminated by "data: [DONE]".
            line = raw.decode().strip()
            if not line or line == "data: [DONE]":
                continue
            if line.startswith("data: "):
                line = line[len("data: ") :]

            chunk = json.loads(line)
            assert chunk["object"] == "chat.completion.chunk"
            delta = chunk["choices"][0]["delta"]

            if "content" in delta:
                print(delta["content"], end="")
            elif "reasoning_content" in delta:
                print(delta["reasoning_content"], end="")
            elif not delta:
                print()
            else:
                raise ValueError(f"Unsupported response delta: {delta}")
    print("")
    print(f"Time to Last Token: {time.perf_counter() - t:.2f} seconds")


def get_endpoint_url() -> str:
    """Get the deployed endpoint URL."""
    return serve.get_web_url()
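
# Example (sketch) of calling the deployed endpoint from any OpenAI-compatible
# client. The base URL is a placeholder -- use get_endpoint_url() or the URL
# printed by `modal deploy`; vLLM ignores the API key unless one is configured,
# so any string works there.
#
#   from openai import OpenAI
#
#   client = OpenAI(
#       base_url="https://<workspace>--gpt-oss-vllm-inference-serve.modal.run/v1",
#       api_key="not-needed",
#   )
#   response = client.chat.completions.create(
#       model="llm",  # matches --served-model-name above
#       messages=[{"role": "user", "content": "Explain what the Singular Value Decomposition is."}],
#   )
#   print(response.choices[0].message.content)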


if __name__ == "__main__":
    print("Run this script with Modal:")
    print("  modal run gpt_oss_inference.py     # Test the server")
    print("  modal deploy gpt_oss_inference.py  # Deploy to production")