import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

from openpi.policies import policy_config
from openpi.training import config as train_config
| |
|
| |
|
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for openpi pi0 policies.

    Decodes base64-encoded camera images from the request payload, builds
    an observation dict in the format openpi expects, runs the trained
    policy, and returns the predicted action chunk.
    """

    # Canonical camera names the pi0 policy consumes; any that are missing
    # from a request are zero-filled so the policy always gets a full set.
    _REQUIRED_CAMERAS = ("cam_high", "cam_left_wrist", "cam_right_wrist")

    # Aliases accepted in request payloads, mapped to openpi camera names.
    # Unknown names pass through unchanged.
    _CAMERA_MAPPING = {
        "camera0": "cam_high",
        "camera1": "cam_left_wrist",
        "camera2": "cam_right_wrist",
        "base_camera": "cam_high",
        "left_wrist": "cam_left_wrist",
        "right_wrist": "cam_right_wrist",
        "cam_high": "cam_high",
        "cam_left_wrist": "cam_left_wrist",
        "cam_right_wrist": "cam_right_wrist",
    }

    # Model input resolution (H, W) expected by pi0.
    _IMAGE_SIZE = (224, 224)

    def __init__(self, path: str = ""):
        """Load the trained pi0 policy.

        Args:
            path: Path to the model weights directory. The MODEL_PATH
                environment variable takes precedence; "weights/pi0" is
                used when both are empty.

        Raises:
            FileNotFoundError: If ``config.json`` is missing from the
                weights directory (fails fast on a wrong path).
        """
        model_path = os.environ.get("MODEL_PATH", path) or "weights/pi0"

        # Read the model config early so a bad path fails with a clear
        # FileNotFoundError instead of deep inside policy loading.
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            model_config = json.load(f)

        # The config's "type" field is currently informational: only the
        # pi0 training config exists, and unknown types fall back to it.
        # (The original if/else had two byte-identical branches.)
        self.train_config = train_config.get_config("pi0")

        # NOTE(review): CUDA_VISIBLE_DEVICES="" is falsy and selects CPU,
        # which matches the intent of "no visible GPUs".
        self.policy = policy_config.create_trained_policy(
            self.train_config,
            model_path,
            pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
        )

        # Default length of the predicted action chunk reported to clients.
        self.default_num_steps = 50

    def _decode_base64_image(self, base64_str: str) -> np.ndarray:
        """Decode a base64 image string into an RGB uint8 array.

        Args:
            base64_str: Base64-encoded image, optionally prefixed with a
                ``data:image/...;base64,`` data-URI header.

        Returns:
            numpy array of shape (H, W, 3) with values in [0, 255].
        """
        # Strip the data-URI header if present; split only on the first
        # comma so base64 padding is untouched.
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]

        image_bytes = base64.b64decode(base64_str)

        # Force RGB so grayscale/RGBA inputs end up with exactly 3 channels.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        return np.array(image)

    def _prepare_observation(
        self,
        images: Dict[str, str],
        state: List[float],
        prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build an observation dict in the format openpi expects.

        Args:
            images: Camera name (or accepted alias) -> base64 image.
            state: Robot proprioceptive state values.
            prompt: Optional natural-language task instruction; only
                included in the observation when non-empty.

        Returns:
            Dict with "state" (float32 array), "images" (uint8 HxWx3
            arrays keyed by openpi camera name), and optionally "prompt".
        """
        processed_images: Dict[str, np.ndarray] = {}
        for input_name, image_b64 in images.items():
            openpi_name = self._CAMERA_MAPPING.get(input_name, input_name)

            image_array = self._decode_base64_image(image_b64)

            # Resize to the model's expected input resolution if needed.
            if image_array.shape[:2] != self._IMAGE_SIZE:
                image_array = np.array(
                    Image.fromarray(image_array).resize(self._IMAGE_SIZE)
                )

            processed_images[openpi_name] = image_array.astype(np.uint8)

        # Zero-fill any required camera the request did not supply.
        for cam_name in self._REQUIRED_CAMERAS:
            processed_images.setdefault(
                cam_name, np.zeros((*self._IMAGE_SIZE, 3), dtype=np.uint8)
            )

        observation: Dict[str, Any] = {
            "state": np.array(state, dtype=np.float32),
            "images": processed_images,
        }
        if prompt:
            observation["prompt"] = prompt
        return observation

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference for a single endpoint request.

        Args:
            data: Request payload with an "inputs" dict containing:
                - images: Camera name -> base64 image (required).
                - state: List of robot state values (required).
                - prompt: Optional text instruction.
                - num_actions: Optional requested action count.
                - noise: Optional noise array for action sampling.

        Returns:
            Single-element list holding either the predicted actions plus
            metadata, or {"error": ..., "success": False} on failure.
        """
        try:
            inputs = data.get("inputs", {})

            images = inputs.get("images", {})
            state = inputs.get("state", [])
            prompt = inputs.get("prompt", "")
            noise_input = inputs.get("noise")

            # NOTE(review): "num_actions" is accepted in the payload but is
            # not forwarded to the policy — the action horizon is fixed by
            # the model config. Kept for request-schema compatibility.
            _ = inputs.get("num_actions", self.default_num_steps)

            if not images:
                raise ValueError("No images provided")
            if not state:
                raise ValueError("No state provided")

            observation = self._prepare_observation(images, state, prompt)

            noise = None
            if noise_input is not None:
                noise = np.array(noise_input, dtype=np.float32)

            result = self.policy.infer(observation, noise=noise)

            actions = result["actions"]
            actions_list = (
                actions.tolist() if isinstance(actions, np.ndarray) else actions
            )

            return [{
                "actions": actions_list,
                "num_actions": len(actions_list),
                "action_horizon": len(actions_list),
                # Guard against an empty chunk before indexing row 0.
                "action_dim": len(actions_list[0]) if actions_list else 0,
                "success": True,
                "metadata": {
                    "model_type": self.train_config.model.model_type.value,
                    "policy_metadata": getattr(self.policy, '_metadata', {}),
                },
            }]

        except Exception as e:
            # Endpoint boundary: report the failure in the response body
            # rather than crashing the serving worker.
            return [{
                "error": str(e),
                "success": False
            }]