deplot / handler.py
import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Called once when the endpoint starts: load the model and processor."""
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        # Default DePlot prompt, used when the request supplies no header text.
        self.default_header = "Generate underlying data table of the figure below:"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64-encoded image string
                - parameters (optional): dict with:
                    - header: text prompt for the model (default: the DePlot prompt)
                    - max_new_tokens: maximum generation length (default: 512)

        Returns:
            A list with one dict containing the generated table text.
        """
        inputs = data.get("inputs")
        parameters = data.get("parameters", {}) or {}

        # Resolve the header text: clients send it under several different keys,
        # so check the common ones before falling back to the default prompt.
        header_text = (
            parameters.get("header")
            or parameters.get("text")
            or parameters.get("prompt")
            or data.get("header")
            or data.get("text")
            or data.get("prompt")
            or self.default_header
        )

        # Decode the base64-encoded image payload.
        if isinstance(inputs, str):
            try:
                image_bytes = base64.b64decode(inputs)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                raise ValueError(f"Failed to decode base64 image: {e}") from e
        else:
            raise ValueError("Expected a base64-encoded image string in 'inputs'")

        # Pix2Struct-based models like DePlot render the text prompt onto the
        # image, so the processor must receive the header text along with the
        # image; omitting it degrades DePlot's output.
        model_inputs = self.processor(
            images=image,
            text=header_text,
            return_tensors="pt"
        ).to(self.device)

        # Read generation parameters (only max_new_tokens is supported here).
        max_new_tokens = parameters.get("max_new_tokens", 512)

        # Generate the linearized data table.
        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens
            )

        # Decode the generated token ids into text.
        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True
        )

        return [{"generated_text": output_text}]