import os

# Disable CUDA for Hugging Face endpoints unless explicitly enabled.
# Use setdefault (not plain assignment) so a value exported in the
# environment is honored — the original unconditional assignment clobbered
# it, contradicting the "unless explicitly enabled" intent. Set it before
# the tirex/torch imports in case either reads it at import time.
os.environ.setdefault("TIREX_NO_CUDA", "1")

import torch
from tirex import load_model, ForecastModel


class EndpointModel:
    """Hugging Face Inference Endpoints wrapper around the TiRex forecaster.

    Instantiated once at endpoint startup; called once per inference request
    with a JSON-decoded payload.

    NOTE(review): HF custom handlers conventionally must be a class named
    ``EndpointHandler`` whose ``__init__`` accepts a model ``path`` — confirm
    this class name/signature matches what the endpoint loader expects.
    """

    def __init__(self) -> None:
        """Load the TiRex model once at startup.

        Resolves the checkpoint ``NX-AI/TiRex`` from the Hugging Face hub.
        """
        self.model: ForecastModel = load_model("NX-AI/TiRex")

    def __call__(self, inputs: dict) -> dict:
        """Run one forecast request.

        Args:
            inputs: JSON-decoded request body with keys:
                - ``"data"``: 2D array (batch_size x context_length) of floats.
                - ``"prediction_length"``: optional positive int, default 64.

        Returns:
            JSON-serializable dict with ``"quantiles"`` (quantile level ->
            nested list) and ``"mean"`` (nested list).

        Raises:
            KeyError: if ``"data"`` is missing from the request.
            ValueError: if ``"prediction_length"`` is not a positive integer.
        """
        # EAFP: a missing "data" key raises KeyError, which is the clearest
        # signal to the caller that the payload is malformed.
        data = torch.tensor(inputs["data"], dtype=torch.float32)

        prediction_length = inputs.get("prediction_length", 64)
        # Reject bad horizons here with a clear message rather than letting a
        # negative/str value surface as an opaque error inside the model.
        # (bool is excluded: it is an int subclass but nonsensical here.)
        if (
            isinstance(prediction_length, bool)
            or not isinstance(prediction_length, int)
            or prediction_length <= 0
        ):
            raise ValueError(
                f"prediction_length must be a positive int, got {prediction_length!r}"
            )

        quantiles, mean = self.model.forecast(
            context=data,
            prediction_length=prediction_length,
        )

        # NOTE(review): this assumes forecast() returns quantiles as a mapping
        # (level -> tensor); some TiRex versions return a stacked tensor
        # instead — confirm against the installed tirex API.
        return {
            "quantiles": {k: v.tolist() for k, v in quantiles.items()},
            "mean": mean.tolist(),
        }