Add vLLM plugin for Chronos-2 inference#473

Open
laviier wants to merge 1 commit into amazon-science:main from laviier:vllm_plugin

Conversation

@laviier laviier commented Feb 26, 2026

Summary

This PR adds a vLLM plugin in the sub-directory vllm-plugin/ that enables serving Chronos-2 models through vLLM's inference engine. The plugin integrates with vLLM's IOProcessor and multimodal interfaces, allowing Chronos-2 forecasts to be served via the standard /pooling API endpoint.

What's included

  • vllm-plugin/ — a standalone, pip-installable package (chronos-vllm-plugin) with its own pyproject.toml
  • Model wrapper (model.py) — inlines the Chronos-2 architecture with vLLM-native layer substitutions (RMSNorm, ReplicatedLinear, get_act_fn) for optimal inference performance
  • Multimodal pipeline (multimodal.py) — routes time series data through vLLM's multimodal infrastructure via a custom "timeseries" modality
  • IOProcessor (io_processor.py) — handles request parsing, tensor preparation, and response formatting
  • Protocol layer (protocol/) — Pydantic models for request/response validation, covariate encoding (TargetEncoder/OrdinalEncoder for categorical covariates), and cross-series validation
  • Quantile utilities (utils/) — quantile selection and linear interpolation for arbitrary quantile levels
  • Full test suite — unit tests for all components
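The quantile utilities described above (quantile selection plus linear interpolation for arbitrary levels) boil down to interpolating between the model's fixed quantile grid. A minimal sketch, assuming an illustrative level grid and values (not the plugin's actual API):

```python
import numpy as np

# Assumed fixed quantile levels the model predicts, and one time step's
# predicted quantile values at those levels (illustrative numbers).
model_levels = np.array([0.1, 0.5, 0.9])
model_quantiles = np.array([8.0, 10.0, 12.0])

def interp_quantiles(levels, model_levels, model_quantiles):
    # Linear interpolation between known quantiles; levels outside the
    # grid are clamped to the nearest known quantile by np.interp.
    return np.interp(levels, model_levels, model_quantiles)

result = interp_quantiles([0.25, 0.75], model_levels, model_quantiles)
```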

Key features

  • Zero-shot probabilistic forecasting with configurable quantile levels
  • Univariate and multivariate time series support
  • Past and future covariates (numeric and categorical)
  • Cross-series learning via group attention
  • Batch forecasting with automatic row-level chunking respecting series boundaries
  • Up to 1024 time series per request

Installation

pip install chronos-forecasting[vllm]

vllm serve amazon/chronos-2 \
    --io-processor-plugin chronos2 \
    --runner pooling \
    --enforce-eager \
    --no-enable-prefix-caching \
    --skip-tokenizer-init \
    --enable-mm-embeds \
    --dtype float32 \
    --max-model-len 8192
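Once the server is up, a minimal client request to the /pooling endpoint could look like the following sketch. The payload shape mirrors the eval script later in this thread; the server URL is an assumption and the request itself is left commented out:

```python
import json

# Payload shape as used by the eval script in this thread; target values
# and parameters are illustrative.
payload = {
    "model": "amazon/chronos-2",
    "task": "plugin",
    "data": {
        "inputs": [{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}],
        "parameters": {"prediction_length": 3, "quantile_levels": [0.1, 0.5, 0.9]},
    },
}

# allow_nan=True keeps NaN values in targets serializable
body = json.dumps(payload, allow_nan=True)

# import requests
# resp = requests.post("http://localhost:8000/pooling", data=body,
#                      headers={"Content-Type": "application/json"})
# forecasts = resp.json()["data"]["predictions"]
```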

Comment thread src/chronos/chronos2/vllm/protocol/data_prep.py Outdated
Comment thread src/chronos/chronos2/vllm/utils/quantiles.py Outdated
Comment thread src/chronos/chronos2/vllm/io_processor.py Outdated
Comment thread src/chronos/chronos2/vllm/io_processor.py Outdated
Comment thread src/chronos/chronos2/vllm/model.py Outdated
Comment thread src/chronos/chronos2/vllm/multimodal.py Outdated
Signed-off-by: Li Zhang <lzhanga@amazon.com>
@abdulfatir (Contributor) commented:

@laviier Thanks a lot for addressing my comments. I am running a full eval on fev-bench. I will then review again carefully.

@abdulfatir (Contributor) left a comment:

Thanks @laviier! I ran the full fev-bench and the results match (I had to make a couple of minor modifications to make it run). I left some comments. After these are fixed, I will approve the PR and we can mark this as completed. However, for merging we need to wait and see how the API and integration discussion evolves. Thanks a lot for your help!

For completeness, here is the script to run the full fev-bench eval:

# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#   "numpy<2.2",
#   "scipy<1.15",
#   "pandas<3",
#   "fev==0.7.0",
#   "pyarrow<21",
#   "requests",
# ]
# ///

"""
This script evaluates Chronos-2 on several tasks from the fev-bench benchmark and compares the test scores to the reference scores.
"""

import json
import logging
import math

import datasets
import fev
import pandas as pd
import requests
from tqdm.auto import tqdm

logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

SELECTED_METRICS = ["SQL", "MASE", "WQL", "WAPE"]
BATCH_SIZE = 100
MAX_CONTEXT_LENGTH = 8192
VLLM_SERVER_URL = "http://localhost:8000"


class VLLMPredictor:
    """vLLM predictor that mimics SageMaker Predictor interface."""

    def __init__(self, server_url: str):
        self.server_url = server_url.rstrip("/")
        self.session = requests.Session()

        try:
            health_response = self.session.get(f"{self.server_url}/health", timeout=10)
            health_response.raise_for_status()
            print(f"Connected to vLLM server at {self.server_url}")
        except Exception as e:
            raise ConnectionError(f"Failed to connect to vLLM server: {e}") from e

    def predict(self, payload):
        try:
            request_json = {"model": "amazon/chronos-2", "task": "plugin", "data": payload}

            # Manually serialize JSON with allow_nan=True to handle NaN values
            json_data = json.dumps(request_json, allow_nan=True)

            response = self.session.post(
                f"{self.server_url}/pooling",
                data=json_data,
                timeout=300,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()

        except Exception as e:
            raise RuntimeError(f"vLLM prediction failed: {e}") from e


def convert_df_to_payload(
    past_df: pd.DataFrame,
    future_df: pd.DataFrame | None = None,
    prediction_length: int = 1,
    target: str | list = "target",
    id_column: str = "item_id",
    timestamp_column: str = "timestamp",
    extra_parameters: dict | None = None,
):
    """
    Converts past and future DataFrames into JSON payload format for the Chronos endpoint.

    Args:
        past_df: Historical data with target, timestamp_column, and id_column.
        future_df: Future covariates with timestamp_column and id_column.
        prediction_length: Number of future time steps to predict.
        target: Column name(s) for target values (str for univariate, list for multivariate).
        id_column: Column name for item IDs.
        timestamp_column: Column name for timestamps.
        extra_parameters: Extra parameters included under 'parameters' in the payload

    Returns:
        dict: JSON payload formatted for the Chronos endpoint.
    """
    past_df = past_df.sort_values([id_column, timestamp_column])
    if future_df is not None:
        future_df = future_df.sort_values([id_column, timestamp_column])

    target_cols = [target] if isinstance(target, str) else target
    past_covariate_cols = list(past_df.columns.drop([*target_cols, id_column, timestamp_column]))
    future_covariate_cols = (
        [] if future_df is None else [col for col in past_covariate_cols if col in future_df.columns]
    )

    inputs = []
    for item_id, past_group in past_df.groupby(id_column):
        if len(target_cols) > 1:
            target_values = [past_group[col].tolist() for col in target_cols]
            series_length = len(target_values[0])
        else:
            target_values = past_group[target_cols[0]].tolist()
            series_length = len(target_values)

        if series_length < 1:
            raise ValueError(f"Time series '{item_id}' is empty.")

        series_dict = {"target": target_values}

        if past_covariate_cols:
            series_dict["past_covariates"] = past_group[past_covariate_cols].to_dict(orient="list")

        if future_covariate_cols:
            future_group = future_df[future_df[id_column] == item_id]
            if len(future_group) != prediction_length:
                raise ValueError(
                    f"future_df must contain exactly {prediction_length=} values for each item_id from past_df "
                    f"(got {len(future_group)=}) for {item_id=}"
                )
            series_dict["future_covariates"] = future_group[future_covariate_cols].to_dict(orient="list")

        inputs.append(series_dict)

    extra_parameters = extra_parameters or {}
    return {
        "inputs": inputs,
        "parameters": {"prediction_length": prediction_length, **extra_parameters},
    }


def predict_for_task(predictor, task: fev.Task, batch_size: int = 100) -> list:
    predictions_per_window = []
    for window in tqdm(task.iter_windows(), total=task.num_windows):
        past_df, future_df, _ = fev.convert_input_data(window, adapter="pandas")
        past_df = past_df.groupby(task.id_column, as_index=False).tail(MAX_CONTEXT_LENGTH)
        payload = convert_df_to_payload(
            past_df,
            future_df,
            prediction_length=task.horizon,
            id_column=task.id_column,
            target=task.target_columns,
            extra_parameters={
                "quantile_levels": task.quantile_levels,
                "cross_learning": True,
            },
        )
        inputs, params = payload["inputs"], payload["parameters"]
        predictions = []
        task_n_dims = len(task.target_columns) + len(task.past_dynamic_columns) + len(task.known_dynamic_columns)
        n_items_per_batch = math.ceil(batch_size / task_n_dims)

        for i in range(0, len(inputs), n_items_per_batch):
            response = predictor.predict({"inputs": inputs[i : i + n_items_per_batch], "parameters": params})
            predictions.extend(response["data"]["predictions"])

        # Filter out metadata columns that vLLM adds but fev framework doesn't expect
        cleaned_predictions = []
        for pred in predictions:
            filtered_pred = {k: v for k, v in pred.items() if k not in ["item_id", "start"]}
            cleaned_predictions.append(filtered_pred)

        ds_dict = {}
        if isinstance(task.target_columns, list) and len(task.target_columns) > 1:
            for idx, col in enumerate(task.target_columns):
                ds = datasets.Dataset.from_list(
                    [{k: v[idx] for k, v in pred.items()} for pred in cleaned_predictions]
                ).rename_column("mean", "predictions")
                ds_dict[col] = ds
        else:
            col = task.target_columns if isinstance(task.target_columns, str) else task.target_columns[0]
            ds = datasets.Dataset.from_list(cleaned_predictions).rename_column("mean", "predictions")
            ds_dict[col] = ds

        predictions_per_window.append(ds_dict)

    return predictions_per_window


if __name__ == "__main__":
    benchmark = fev.Benchmark.from_yaml(
        "https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/fev_bench/tasks.yaml"
    )
    predictor = VLLMPredictor(VLLM_SERVER_URL)

    summaries = []
    for task in benchmark.tasks:
        print(f"Evaluating {task.task_name}...")
        predictions_per_window = predict_for_task(predictor, task, BATCH_SIZE)
        summaries.append(task.evaluation_summary(predictions_per_window, model_name="chronos-2-vllm"))

    current_summaries = pd.DataFrame(summaries).set_index("task_name").sort_index()
    reference_summaries = (
        pd.read_csv(
            "https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/fev_bench/results/chronos-2.csv"
        )
        .set_index("task_name")
        .sort_index()
    )

    current_summaries.to_csv("chronos-2-vllm-results.csv")

    print("\nCurrent vLLM Results:")
    print(current_summaries[SELECTED_METRICS].round(4))
    print("\nReference Results:")
    print(reference_summaries[SELECTED_METRICS].round(4))
    print("\nDifference in test scores (Reference - vLLM):")
    differences = reference_summaries[SELECTED_METRICS] - current_summaries[SELECTED_METRICS]
    print(differences.round(4).sort_values(by="SQL"))

--skip-tokenizer-init \
--enable-mm-embeds \
--dtype float32 \
--max-model-len 8192

I had to add this to make it work.

Suggested change
--max-model-len 8192
--max-model-len 8192 \
--max-num-batched-tokens 8192

# Constants — match the SageMaker endpoint constraints
# TODO: Get some of the values from model config
# ---------------------------------------------------------------------------
MIN_TARGET_LENGTH: int = 5

This was needed for running the full fev-bench.

Suggested change
MIN_TARGET_LENGTH: int = 5
MIN_TARGET_LENGTH: int = 1

if ts.past_covariates:
    past_covs: dict[str, np.ndarray] = {}
    for key, vals in ts.past_covariates.items():
        arr = np.array([v if v is not None else np.nan for v in vals])

This can be moved under else below to avoid looping through vals twice.
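A hypothetical sketch of that refactor; the categorical/numeric branch condition is an assumption (only the structure matters), so the None-to-NaN pass runs once, in the numeric branch only:

```python
import numpy as np

def to_array(vals):
    # Assumed branch condition for illustration: string values mean a
    # categorical covariate, anything else is treated as numeric.
    if any(isinstance(v, str) for v in vals):
        # Categorical covariate: keep raw values untouched
        return np.array(vals, dtype=object)
    else:
        # Numeric covariate: the None -> NaN replacement happens here only,
        # so vals is not traversed twice for categorical covariates
        return np.array([v if v is not None else np.nan for v in vals])
```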

if ts.future_covariates:
    future_covs: dict[str, np.ndarray] = {}
    for key, vals in ts.future_covariates.items():
        arr = np.array([v if v is not None else np.nan for v in vals])

Same as above. Can be moved inside else.

inputs=raw_inputs,
context_length=context_length,
prediction_length=pred_len,
batch_size=2**31 - 1, # large enough to fit all series in one batch

What's the reasoning for this? Should we fix a predefined batch size instead, e.g. 100? One problem with having a single batch is that you can run OOM and the server will crash if the client sends a large dataset. This happened to me for some tasks in fev-bench.
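A fixed-size chunking scheme like the one suggested here could be sketched as follows; the function name and batch size are hypothetical, not part of the plugin's API:

```python
def iter_batches(inputs, batch_size=100):
    # Yield fixed-size chunks so a large request is never processed as one
    # giant batch, capping peak memory on the server side.
    for i in range(0, len(inputs), batch_size):
        yield inputs[i : i + batch_size]

# 250 series split into chunks of at most 100
chunks = list(iter_batches(list(range(250)), batch_size=100))
```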

)

result: dict[str, Any] = {"predictions": [], "request_id": request_id}
for i, (q_tensor, m_tensor) in enumerate(zip(quantiles_out, mean_out)):

The forecast start key is not added here based on start and freq, even when they exist in the input request.

# Determine num_output_patches from preprocessor or fall back to computing from future_covariates
if num_output_patches is None:
    prediction_length = future_covariates.shape[1] if future_covariates is not None else 0
    if prediction_length == 0:

Can we verify that this is never hit? If this defaults to the largest prediction_length even for small prediction_length requests, it would create unnecessary overhead.

return torch.zeros(input_len, 0, device=positions.device, dtype=torch.float32)

# Determine num_output_patches from preprocessor or fall back to computing from future_covariates
if num_output_patches is None:

Will this ever be None?

num_output_patches = int(math.ceil(prediction_length / self.output_patch_size))
num_output_patches = min(num_output_patches, self.max_output_patches)

prediction_length = num_output_patches * self.output_patch_size

Can we rename this to something like output_length to avoid confusion with the actual prediction_length?
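For illustration, the arithmetic in the quoted snippet works out like this; the output_patch_size and max_output_patches values are assumptions, not the model's actual config:

```python
import math

# Assumed config values for illustration
output_patch_size = 16
max_output_patches = 64
prediction_length = 30  # length the client actually requested

# Round up to whole patches, capped at the model's maximum
num_output_patches = min(
    math.ceil(prediction_length / output_patch_size), max_output_patches
)
# The generated length pads past the requested prediction_length,
# which is why reusing the name prediction_length here is confusing
output_length = num_output_patches * output_patch_size
```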

Comment thread pyproject.toml
"pandas[pyarrow]>=2.0,<2.4",
]
vllm = [
"vllm>=0.13.0",

Suggested change
"vllm>=0.13.0",
"vllm>=0.13,<0.17",
