import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

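# Example request body (assumption: the standard Hugging Face Inference
# Endpoints custom-handler contract, matching the __call__ docstring below):
#
#   {
#     "inputs": "<base64-encoded PNG/JPEG of a chart>",
#     "parameters": {
#       "header": "Generate underlying data table of the figure below:",
#       "max_new_tokens": 512
#     }
#   }
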
class EndpointHandler:
    """Custom handler for serving a Pix2Struct/DePlot chart-to-table model."""

    def __init__(self, path: str = ""):
        """Called when the endpoint starts. Load model and processor."""
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)

        # Run on GPU when available and keep the model in eval mode for inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Default DePlot-style prompt for chart-to-table generation.
        self.default_header = "Generate underlying data table of the figure below:"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64-encoded image string
                - parameters (optional): dict with:
                    - header: text prompt for the model; "text" and "prompt"
                      are accepted as aliases, at the top level as well
                      (default: DePlot prompt)
                    - max_new_tokens: max generation length (default: 512)

        Returns:
            List containing the generated table text
        """
        inputs = data.get("inputs")
        parameters = data.get("parameters") or {}  # tolerate an explicit null

        # Accept the prompt under several common key names, falling back to the
        # default DePlot prompt.
        header_text = (
            parameters.get("header")
            or parameters.get("text")
            or parameters.get("prompt")
            or data.get("header")
            or data.get("text")
            or data.get("prompt")
            or self.default_header
        )

        # Decode the base64 payload into a PIL image.
        if isinstance(inputs, str):
            try:
                image_bytes = base64.b64decode(inputs)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                raise ValueError(f"Failed to decode base64 image: {e}") from e
        else:
            raise ValueError("Expected base64 encoded image string in 'inputs'")

        model_inputs = self.processor(
            images=image,
            text=header_text,
            return_tensors="pt",
        ).to(self.device)

        max_new_tokens = parameters.get("max_new_tokens", 512)

        # Generate and decode the table text.
        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True,
        )

        return [{"generated_text": output_text}]
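

# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract): the endpoint runtime normally instantiates EndpointHandler itself
# and passes the parsed request body as `data`. MODEL_PATH and IMAGE_PATH
# below are illustrative defaults, not values required by this handler.
if __name__ == "__main__":
    import json
    import sys

    MODEL_PATH = sys.argv[1] if len(sys.argv) > 1 else "google/deplot"
    IMAGE_PATH = sys.argv[2] if len(sys.argv) > 2 else "chart.png"

    handler = EndpointHandler(MODEL_PATH)
    with open(IMAGE_PATH, "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "parameters": {"max_new_tokens": 512},
        }
    print(json.dumps(handler(payload), indent=2))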