import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor


class EndpointHandler:
    """Inference-endpoint handler for a Pix2Struct / DePlot checkpoint.

    Loads the model once at startup and turns base64-encoded chart images
    into their underlying data table as text.
    """

    def __init__(self, path: str = ""):
        """Called when the endpoint starts. Load model and processor.

        Args:
            path: Local path or hub id of the Pix2Struct checkpoint.
        """
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        # Default prompt for DePlot-style chart-to-table extraction.
        self.default_header = "Generate underlying data table of the figure below:"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64 encoded image string (a plain base64 payload
                  or a ``data:image/...;base64,...`` data URI)
                - parameters (optional): dict with:
                    - header: text prompt for the model (default: DePlot prompt)
                    - max_new_tokens: max generation length (default: 512)

        Returns:
            List containing the generated table text.

        Raises:
            ValueError: if ``inputs`` is missing, not a string, not valid
                base64, or ``max_new_tokens`` is not an integer.
        """
        inputs = data.get("inputs")
        # `or {}` guards against an explicit `"parameters": null` in the
        # request body; a plain .get default only covers a missing key.
        parameters = data.get("parameters") or {}

        # Resolve the prompt: accept several key spellings in either the
        # parameters dict or the top-level payload, falling back to the
        # DePlot default.
        header_text = (
            parameters.get("header")
            or parameters.get("text")
            or parameters.get("prompt")
            or data.get("header")
            or data.get("text")
            or data.get("prompt")
            or self.default_header
        )

        if not isinstance(inputs, str):
            raise ValueError("Expected base64 encoded image string in 'inputs'")

        # Accept browser-style data URIs by stripping the metadata prefix
        # before the comma (e.g. "data:image/png;base64,").
        payload = inputs.strip()
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]

        try:
            image_bytes = base64.b64decode(payload)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to decode base64 image: {e}")

        # Pix2Struct renders the text prompt onto the image, so the header
        # must be supplied to the processor together with the image.
        model_inputs = self.processor(
            images=image,
            text=header_text,
            return_tensors="pt",
        ).to(self.device)

        # Validate the generation budget early; JSON clients often send it
        # as a string, which would otherwise fail deep inside generate().
        try:
            max_new_tokens = int(parameters.get("max_new_tokens", 512))
        except (TypeError, ValueError):
            raise ValueError("'max_new_tokens' must be an integer")

        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True,
        )

        return [{"generated_text": output_text}]