Add vLLM plugin for Chronos-2 inference#473
Add vLLM plugin for Chronos-2 inference #473 — laviier wants to merge 1 commit into amazon-science:main from
Conversation
Signed-off-by: Li Zhang <lzhanga@amazon.com>
|
@laviier Thanks a lot for addressing my comments. I am running a full eval on fev-bench. I will then review again carefully. |
There was a problem hiding this comment.
Thanks @laviier! I ran the full fev-bench and the results match (I had to make a couple of minor modifications to get it running). I left some comments. After these are fixed, I will approve the PR and we can mark this as completed. However, for merging we need to wait and see how the API and integration discussion evolves. Thanks a lot for your help!
For completeness, here is the script to run the full fev-bench eval:
Click to view the code
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "numpy<2.2",
# "scipy<1.15",
# "pandas<3",
# "fev==0.7.0",
# "pyarrow<21",
# "requests",
# ]
# ///
"""
This script evaluates Chronos-2 on several tasks from the fev-bench benchmark and compares the test scores to the reference scores.
"""
import json
import logging
import math
import datasets
import fev
import pandas as pd
import requests
from tqdm.auto import tqdm
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
# Metrics compared against the reference results at the end of the run.
SELECTED_METRICS = ["SQL", "MASE", "WQL", "WAPE"]
# Upper bound on series-dimensions sent to the server in one request.
BATCH_SIZE = 100
# Most recent observations kept per series (matches the server's --max-model-len 8192).
MAX_CONTEXT_LENGTH = 8192
# Base URL of the locally running vLLM server.
VLLM_SERVER_URL = "http://localhost:8000"
class VLLMPredictor:
    """vLLM predictor that mimics the SageMaker Predictor interface.

    Sends Chronos-2 forecast requests to a running vLLM server's ``/pooling``
    endpoint; connectivity is verified against ``/health`` at construction time.
    """

    def __init__(self, server_url: str):
        """Connect to the vLLM server at ``server_url``.

        Args:
            server_url: Base URL of the server; a trailing slash is stripped.

        Raises:
            ConnectionError: If the server's ``/health`` endpoint is unreachable.
        """
        self.server_url = server_url.rstrip("/")
        # Reuse a single HTTP connection pool across all requests.
        self.session = requests.Session()
        try:
            health_response = self.session.get(f"{self.server_url}/health", timeout=10)
            health_response.raise_for_status()
            print(f"Connected to vLLM server at {self.server_url}")
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise ConnectionError(f"Failed to connect to vLLM server: {e}") from e

    def predict(self, payload):
        """POST ``payload`` to the ``/pooling`` endpoint and return the parsed JSON response.

        Args:
            payload: The ``inputs``/``parameters`` dict produced by ``convert_df_to_payload``.

        Raises:
            RuntimeError: If the request fails or the server returns an error status.
        """
        try:
            request_json = {"model": "amazon/chronos-2", "task": "plugin", "data": payload}
            # Manually serialize JSON with allow_nan=True so NaN values (missing
            # observations) survive serialization instead of raising.
            json_data = json.dumps(request_json, allow_nan=True)
            response = self.session.post(
                f"{self.server_url}/pooling",
                data=json_data,
                timeout=300,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Chain the original exception for easier debugging.
            raise RuntimeError(f"vLLM prediction failed: {e}") from e
def convert_df_to_payload(
past_df: pd.DataFrame,
future_df: pd.DataFrame | None = None,
prediction_length: int = 1,
target: str | list = "target",
id_column: str = "item_id",
timestamp_column: str = "timestamp",
extra_parameters: dict | None = None,
):
"""
Converts past and future DataFrames into JSON payload format for the Chronos endpoint.
Args:
past_df: Historical data with target, timestamp_column, and id_column.
future_df: Future covariates with timestamp_column and id_column.
prediction_length: Number of future time steps to predict.
target: Column name(s) for target values (str for univariate, list for multivariate).
id_column: Column name for item IDs.
timestamp_column: Column name for timestamps.
extra_parameters: Extra parameters included under 'parameters' in the payload
Returns:
dict: JSON payload formatted for the Chronos endpoint.
"""
past_df = past_df.sort_values([id_column, timestamp_column])
if future_df is not None:
future_df = future_df.sort_values([id_column, timestamp_column])
target_cols = [target] if isinstance(target, str) else target
past_covariate_cols = list(past_df.columns.drop([*target_cols, id_column, timestamp_column]))
future_covariate_cols = (
[] if future_df is None else [col for col in past_covariate_cols if col in future_df.columns]
)
inputs = []
for item_id, past_group in past_df.groupby(id_column):
if len(target_cols) > 1:
target_values = [past_group[col].tolist() for col in target_cols]
series_length = len(target_values[0])
else:
target_values = past_group[target_cols[0]].tolist()
series_length = len(target_values)
if series_length < 1:
raise ValueError(f"Time series '{item_id}' has fewer than 5 observations.")
series_dict = {"target": target_values}
if past_covariate_cols:
series_dict["past_covariates"] = past_group[past_covariate_cols].to_dict(orient="list")
if future_covariate_cols:
future_group = future_df[future_df[id_column] == item_id]
if len(future_group) != prediction_length:
raise ValueError(
f"future_df must contain exactly {prediction_length=} values for each item_id from past_df "
f"(got {len(future_group)=}) for {item_id=}"
)
series_dict["future_covariates"] = future_group[future_covariate_cols].to_dict(orient="list")
inputs.append(series_dict)
extra_parameters = extra_parameters or {}
return {
"inputs": inputs,
"parameters": {"prediction_length": prediction_length, **extra_parameters},
}
def predict_for_task(predictor, task: fev.Task, batch_size: int = 100) -> list:
    """Run the predictor on every evaluation window of a fev task.

    Args:
        predictor: Object with a ``predict(payload)`` method (e.g. VLLMPredictor).
        task: The fev task to evaluate.
        batch_size: Upper bound on series-dimensions sent to the server per request.

    Returns:
        list: One ``{target_column: datasets.Dataset}`` dict per evaluation window.
    """
    per_window = []
    for window in tqdm(task.iter_windows(), total=task.num_windows):
        past_df, future_df, _ = fev.convert_input_data(window, adapter="pandas")
        # Truncate each series to the most recent MAX_CONTEXT_LENGTH observations.
        past_df = past_df.groupby(task.id_column, as_index=False).tail(MAX_CONTEXT_LENGTH)
        payload = convert_df_to_payload(
            past_df,
            future_df,
            prediction_length=task.horizon,
            id_column=task.id_column,
            target=task.target_columns,
            extra_parameters={
                "quantile_levels": task.quantile_levels,
                "cross_learning": True,
            },
        )
        inputs = payload["inputs"]
        params = payload["parameters"]
        # Each item carries this many dimensions; size batches so the total
        # number of dimensions per request stays within batch_size.
        n_dims = len(task.target_columns) + len(task.past_dynamic_columns) + len(task.known_dynamic_columns)
        items_per_batch = math.ceil(batch_size / n_dims)
        raw_predictions = []
        for start in range(0, len(inputs), items_per_batch):
            batch = inputs[start : start + items_per_batch]
            response = predictor.predict({"inputs": batch, "parameters": params})
            raw_predictions.extend(response["data"]["predictions"])
        # Drop metadata keys added by vLLM that the fev framework doesn't expect.
        predictions = [
            {k: v for k, v in pred.items() if k not in ("item_id", "start")}
            for pred in raw_predictions
        ]
        ds_dict = {}
        if isinstance(task.target_columns, list) and len(task.target_columns) > 1:
            # Multivariate: split each prediction into one dataset per target dimension.
            for dim, col in enumerate(task.target_columns):
                records = [{k: v[dim] for k, v in pred.items()} for pred in predictions]
                ds_dict[col] = datasets.Dataset.from_list(records).rename_column("mean", "predictions")
        else:
            cols = task.target_columns
            col = cols if isinstance(cols, str) else cols[0]
            ds_dict[col] = datasets.Dataset.from_list(predictions).rename_column("mean", "predictions")
        per_window.append(ds_dict)
    return per_window
if __name__ == "__main__":
    # Load the fev-bench task definitions.
    benchmark = fev.Benchmark.from_yaml(
        "https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/fev_bench/tasks.yaml"
    )
    predictor = VLLMPredictor(VLLM_SERVER_URL)

    # Evaluate every task and collect one summary row per task.
    summaries = []
    for task in benchmark.tasks:
        print(f"Evaluating {task.task_name}...")
        predictions_per_window = predict_for_task(predictor, task, BATCH_SIZE)
        summaries.append(task.evaluation_summary(predictions_per_window, model_name="chronos-2-vllm"))
    current_summaries = pd.DataFrame(summaries).set_index("task_name").sort_index()

    # Published reference results for comparison.
    reference_summaries = (
        pd.read_csv(
            "https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/fev_bench/results/chronos-2.csv"
        )
        .set_index("task_name")
        .sort_index()
    )

    current_summaries.to_csv("chronos-2-vllm-results.csv")

    print("\nCurrent vLLM Results:")
    print(current_summaries[SELECTED_METRICS].round(4))
    print("\nReference Results:")
    print(reference_summaries[SELECTED_METRICS].round(4))
    print("\nDifference in test scores (Reference - vLLM):")
    differences = reference_summaries[SELECTED_METRICS] - current_summaries[SELECTED_METRICS]
    # Fix: the final line had unrelated diff text fused onto it; restored the clean call.
    print(differences.round(4).sort_values(by="SQL"))
| --enable-mm-embeds \ | ||
| --dtype float32 \ | ||
| --max-model-len 8192 |
There was a problem hiding this comment.
I had to add this to make it work.
| --max-model-len 8192 | |
| --max-model-len 8192 \ | |
| --max-num-batched-tokens 8192 |
| # Constants — match the SageMaker endpoint constraints | ||
| # TODO: Get some of the values from model config | ||
| # --------------------------------------------------------------------------- | ||
| MIN_TARGET_LENGTH: int = 5 |
There was a problem hiding this comment.
This was needed for running the full fev-bench.
| MIN_TARGET_LENGTH: int = 5 | |
| MIN_TARGET_LENGTH: int = 1 |
| if ts.past_covariates: | ||
| past_covs: dict[str, np.ndarray] = {} | ||
| for key, vals in ts.past_covariates.items(): | ||
| arr = np.array([v if v is not None else np.nan for v in vals]) |
There was a problem hiding this comment.
This can be moved under else below to avoid looping through vals twice.
| if ts.future_covariates: | ||
| future_covs: dict[str, np.ndarray] = {} | ||
| for key, vals in ts.future_covariates.items(): | ||
| arr = np.array([v if v is not None else np.nan for v in vals]) |
There was a problem hiding this comment.
Same as above. Can be moved inside else.
| inputs=raw_inputs, | ||
| context_length=context_length, | ||
| prediction_length=pred_len, | ||
| batch_size=2**31 - 1, # large enough to fit all series in one batch |
There was a problem hiding this comment.
What's the reasoning for this? Should we use a predefined batch size instead, e.g. 100? One problem with having a single batch is that you can run OOM and the server will crash if the client sends a large dataset. This happened to me for some tasks in fev-bench.
| ) | ||
|
|
||
| result: dict[str, Any] = {"predictions": [], "request_id": request_id} | ||
| for i, (q_tensor, m_tensor) in enumerate(zip(quantiles_out, mean_out)): |
There was a problem hiding this comment.
Forecast start key is not added here based on start and freq, if they exist in the input request.
| # Determine num_output_patches from preprocessor or fall back to computing from future_covariates | ||
| if num_output_patches is None: | ||
| prediction_length = future_covariates.shape[1] if future_covariates is not None else 0 | ||
| if prediction_length == 0: |
There was a problem hiding this comment.
Can we verify that this is never hit? If this defaults to the largest prediction_length even for small prediction_length request, it would create unnecessary overhead.
| return torch.zeros(input_len, 0, device=positions.device, dtype=torch.float32) | ||
|
|
||
| # Determine num_output_patches from preprocessor or fall back to computing from future_covariates | ||
| if num_output_patches is None: |
| num_output_patches = int(math.ceil(prediction_length / self.output_patch_size)) | ||
| num_output_patches = min(num_output_patches, self.max_output_patches) | ||
|
|
||
| prediction_length = num_output_patches * self.output_patch_size |
There was a problem hiding this comment.
Can we rename to to something like output_length to avoid confusion with the actual prediction_length.
| "pandas[pyarrow]>=2.0,<2.4", | ||
| ] | ||
| vllm = [ | ||
| "vllm>=0.13.0", |
There was a problem hiding this comment.
| "vllm>=0.13.0", | |
| "vllm>=0.13,<0.17", |
Summary
Adding a vLLM plugin in sub-directory
vllm-plugin/that enables serving Chronos-2 models through vLLM's inference engine. The plugin integrates with vLLM's IOProcessor and multimodal interfaces, allowing Chronos-2 forecasts to be served via the standard/poolingAPI endpoint.What's included
vllm-plugin/— a standalone, pip-installable package (chronos-vllm-plugin) with its ownpyproject.tomlmodel.py) — inlines the Chronos-2 architecture with vLLM-native layer substitutions (RMSNorm, ReplicatedLinear, get_act_fn) for optimal inference performancemultimodal.py) — routes time series data through vLLM's multimodal infrastructure via a custom "timeseries" modalityio_processor.py) — handles request parsing, tensor preparation, and response formattingprotocol/) — Pydantic models for request/response validation, covariate encoding (TargetEncoder/OrdinalEncoder for categorical covariates), and cross-series validationutils/) — quantile selection and linear interpolation for arbitrary quantile levelsKey features
Installation
pip install chronos-forecasting[vllm] vllm serve amazon/chronos-2 \ --io-processor-plugin chronos2 \ --runner pooling \ --enforce-eager \ --no-enable-prefix-caching \ --skip-tokenizer-init \ --enable-mm-embeds \ --dtype float32 \ --max-model-len 8192