File size: 1,461 Bytes
155c2a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch
import os
from tirex import load_model, ForecastModel
# Disable CUDA for Hugging Face endpoints unless explicitly enabled.
# setdefault (rather than plain assignment) respects a value the operator
# already exported — e.g. TIREX_NO_CUDA=0 to explicitly enable CUDA —
# which is what "unless explicitly enabled" promises.
# NOTE(review): this runs *after* `from tirex import ...` above; if tirex
# reads TIREX_NO_CUDA at import time this is too late and the line should
# move above the import — confirm against tirex's loading code.
os.environ.setdefault('TIREX_NO_CUDA', '1')
class EndpointModel:
    """Hugging Face Inference Endpoints handler wrapping the TiRex forecaster.

    The endpoint framework instantiates this class once at container startup
    (loading the model) and then invokes the instance for every request with
    a JSON-deserialized ``inputs`` dict.
    """

    def __init__(self):
        """Load the TiRex model once at startup.

        Downloads/resolves the ``NX-AI/TiRex`` checkpoint from the
        Hugging Face hub and keeps it on the instance for reuse across
        requests.
        """
        self.model: ForecastModel = load_model("NX-AI/TiRex")

    def __call__(self, inputs: dict) -> dict:
        """Run one forecast request; called for every inference request.

        Inputs must be JSON-serializable. Example request::

            {
                "data": [[0.1, 0.2, 0.3, ...], [0.5, 0.6, ...]],  # batch_size x context_length
                "prediction_length": 64
            }

        Returns a JSON-safe dict with keys ``"quantiles"`` and ``"mean"``.

        Raises:
            ValueError: if the required ``"data"`` field is missing.
        """
        if "data" not in inputs:
            # A bare KeyError from inputs["data"] gives the caller no hint;
            # fail with an explicit, actionable message instead.
            raise ValueError(
                "missing required 'data' field "
                "(2D array: batch_size x context_length)"
            )

        # Convert input data to a torch tensor.
        context = torch.tensor(inputs["data"], dtype=torch.float32)
        # Default prediction length if not provided (matches the docstring).
        prediction_length = inputs.get("prediction_length", 64)

        # Serving only — no autograd graph needed; no_grad saves memory.
        with torch.no_grad():
            quantiles, mean = self.model.forecast(
                context=context,
                prediction_length=prediction_length,
            )

        # TiRex may hand back quantiles either as a mapping of
        # quantile-level -> tensor or as a single stacked tensor — TODO
        # confirm against the installed tirex version. Support both: the
        # mapping path is identical to the original behavior; the tensor
        # path previously crashed with AttributeError on .items().
        if hasattr(quantiles, "items"):
            quantiles_json = {k: v.tolist() for k, v in quantiles.items()}
        else:
            quantiles_json = quantiles.tolist()

        # Return both quantiles and mean as Python lists (JSON-safe).
        return {
            "quantiles": quantiles_json,
            "mean": mean.tolist(),
        }
|