Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
543 changes: 543 additions & 0 deletions notebooks/temporary_scratchpad_ssm_batch_dims.ipynb

Large diffs are not rendered by default.

561 changes: 470 additions & 91 deletions pymc_extras/statespace/core/statespace.py

Large diffs are not rendered by default.

37 changes: 24 additions & 13 deletions pymc_extras/statespace/filters/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pytensor.graph.basic import Node
from pytensor.tensor.random import multivariate_normal

from pymc_extras.statespace.utils.batch_tools import bmv

floatX = pytensor.config.floatX
COV_ZERO_TOL = 0

Expand Down Expand Up @@ -203,8 +205,8 @@ def step_fn(*args):
for src_idx, dst_idx in enumerate(non_seq_positions):
ordered[dst_idx] = non_seqs[src_idx]
c, d, T, Z, R, H, Q = ordered
k = T.shape[0]
a = state[:k]
k = T.shape[-1]
a = state[..., :k]

middle_rng, a_innovation = pm.MvNormal.dist(
mu=0, cov=Q, rng=rng, method=method, return_next_rng=True
Expand All @@ -213,23 +215,23 @@ def step_fn(*args):
mu=0, cov=H, rng=middle_rng, method=method, return_next_rng=True
)

a_mu = c + T @ a
a_next = a_mu + R @ a_innovation
a_mu = c + bmv(T, a)
a_next = a_mu + bmv(R, a_innovation)

y_mu = d + Z @ a_next
y_mu = d + bmv(Z, a_next)
y_next = y_mu + y_innovation

next_state = pt.concatenate([a_next, y_next], axis=0)
next_state = pt.concatenate([a_next, y_next], axis=-1)

return next_rng, next_state

Z_init = Z_ if Z_ in non_sequences else Z_[0]
H_init = H_ if H_ in non_sequences else H_[0]

init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method)
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method)
init_y_ = pm.MvNormal.dist(bmv(Z_init, init_x_), H_init, rng=rng, method=method)

init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
init_dist_ = pt.concatenate([init_x_, init_y_], axis=-1)

ss_rng, statespace = pytensor.scan(
step_fn,
Expand All @@ -242,11 +244,12 @@ def step_fn(*args):
)

if append_x0:
statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
init_dist_expanded = pt.expand_dims(init_dist_, axis=0)
statespace_ = pt.concatenate([init_dist_expanded, statespace], axis=0)
# statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
else:
statespace_ = statespace
statespace_ = pt.specify_shape(statespace_, (steps, None))
# statespace_ = pt.specify_shape(statespace_, (steps, None))

linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
Expand Down Expand Up @@ -287,7 +290,15 @@ def __new__(
dims = kwargs.pop("dims", None)
latent_dims = None
obs_dims = None
if dims is not None:
if dims is not None and len(dims) > 3:
# if len(dims) != 3:
# ValueError(
# "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states"
# )
batch_dim, time_dim, state_dim, obs_dim = dims
latent_dims = [time_dim, batch_dim, state_dim]
obs_dims = [time_dim, batch_dim, obs_dim]
elif dims is not None:
if len(dims) != 3:
ValueError(
"LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states"
Expand All @@ -313,7 +324,7 @@ def __new__(
method=method,
**kwargs,
)
latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
# latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
if k_endog is None:
k_endog = cls._get_k_endog(H)
latent_slice = slice(None, -k_endog)
Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/statespace/models/ETS.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def __init__(
k_endog,
k_states,
k_posdef,
filter_type,
filter_type=filter_type,
verbose=verbose,
measurement_error=measurement_error,
mode=mode,
Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/statespace/models/SARIMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __init__(
k_endog,
k_states,
k_posdef,
filter_type,
filter_type=filter_type,
verbose=verbose,
measurement_error=measurement_error,
mode=mode,
Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/statespace/models/VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(
k_endog,
k_states,
k_posdef,
filter_type,
filter_type=filter_type,
verbose=verbose,
measurement_error=measurement_error,
mode=mode,
Expand Down
5 changes: 5 additions & 0 deletions pymc_extras/statespace/utils/batch_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pytensor.tensor as pt


def bmv(A, x):
return pt.matmul(A, x[..., None])[..., 0]
1 change: 1 addition & 0 deletions pymc_extras/statespace/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
FACTOR_DIM = "factor"
ERROR_AR_PARAM_DIM = "error_lag_ar"
EXOG_STATE_DIM = "exogenous"
BATCH_DIM = "batch"

MISSING_FILL = -9999.0
JITTER_DEFAULT = 1e-8 if pytensor.config.floatX.endswith("64") else 1e-6
Expand Down
57 changes: 36 additions & 21 deletions pymc_extras/statespace/utils/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from pymc.exceptions import ImputationWarning
from pytensor.tensor.sharedvar import TensorSharedVariable

from pymc_extras.statespace.utils.constants import (
MISSING_FILL,
OBS_STATE_DIM,
TIME_DIM,
)
from pymc_extras.statespace.utils.constants import BATCH_DIM, MISSING_FILL, OBS_STATE_DIM, TIME_DIM

NO_TIME_INDEX_WARNING = (
"No time index found on the supplied data. A simple range index will be automatically "
Expand All @@ -36,11 +32,13 @@ def get_data_dims(data):
return data_dims


def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None):
def _validate_data_shape(
data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None, batched=False
):
if col_names is None:
col_names = []

if len(data_shape) != 2:
if not batched and len(data_shape) != 2:
raise ValueError("Data must be a 2d matrix")

if data_shape[-1] != n_obs:
Expand All @@ -59,22 +57,27 @@ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=Fal
)


def preprocess_tensor_data(data, n_obs, obs_coords=None):
def preprocess_tensor_data(data, n_obs, obs_coords=None, batched=False):
data_shape = data.shape.eval()
_validate_data_shape(data_shape, n_obs, obs_coords)
_validate_data_shape(data_shape, n_obs, obs_coords, batched=batched)
if obs_coords is not None:
warnings.warn(NO_TIME_INDEX_WARNING)
index = np.arange(data_shape[0], dtype="int")

index = (
np.arange(data_shape[0], dtype="int")
if not batched
else np.arange(data_shape[1], dtype="int")
)

return data.eval(), index


def preprocess_numpy_data(data, n_obs, obs_coords=None):
_validate_data_shape(data.shape, n_obs, obs_coords)
def preprocess_numpy_data(data, n_obs, obs_coords=None, batched=False):
_validate_data_shape(data.shape, n_obs, obs_coords, batched=batched)
if obs_coords is not None:
warnings.warn(NO_TIME_INDEX_WARNING)

index = np.arange(data.shape[0], dtype="int")
index = np.arange(data.shape[0], dtype="int") if not batched else np.arange(data.shape[1])

return data, index

Expand Down Expand Up @@ -122,11 +125,15 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
return preprocess_numpy_data(data.values, n_obs, obs_coords)


def add_data_to_active_model(values, index, data_dims=None):
def add_data_to_active_model(values, index, data_dims=None, batched=False):
pymc_mod = modelcontext(None)
if data_dims is None:
data_dims = [TIME_DIM, OBS_STATE_DIM]
time_dim = data_dims[0]
if not batched:
data_dims = [TIME_DIM, OBS_STATE_DIM]
else:
data_dims = [BATCH_DIM, TIME_DIM, OBS_STATE_DIM]

time_dim = data_dims[0] if not batched else data_dims[1]

if isinstance(index, pd.Index):
index = index.rename(time_dim)
Expand All @@ -145,10 +152,14 @@ def add_data_to_active_model(values, index, data_dims=None):

# If the data has just one column, we need to specify the shape as (None, 1), or else the JAX backend will
# raise a broadcasting error.
if values.shape[-1] == 1 or values.ndim == 1:
if (values.shape[-1] == 1 or values.ndim == 1) and not batched:
data_shape = (None, 1)
else:
elif (values.shape[-1] == 1 or values.ndim == 1) and batched:
data_shape = (values.shape[0], None, 1)
elif not batched:
data_shape = (None, values.shape[-1])
else:
data_shape = (values.shape[0], None, *values.shape[2:])

data = pm.Data("data", values, dims=data_dims, shape=data_shape)

Expand Down Expand Up @@ -184,10 +195,14 @@ def mask_missing_values_in_data(values, missing_fill_value=None):
def register_data_with_pymc(
data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None
):
batched = False
if data_dims and BATCH_DIM in data_dims:
batched = True

if isinstance(data, pt.TensorVariable | TensorSharedVariable):
values, index = preprocess_tensor_data(data, n_obs, obs_coords)
values, index = preprocess_tensor_data(data, n_obs, obs_coords, batched)
elif isinstance(data, np.ndarray):
values, index = preprocess_numpy_data(data, n_obs, obs_coords)
values, index = preprocess_numpy_data(data, n_obs, obs_coords, batched)
elif isinstance(data, pd.DataFrame | pd.Series):
values, index = preprocess_pandas_data(data, n_obs, obs_coords)
else:
Expand All @@ -196,7 +211,7 @@ def register_data_with_pymc(
data, nan_mask = mask_missing_values_in_data(values, missing_fill_value)

if register_data:
data = add_data_to_active_model(data, index, data_dims)
data = add_data_to_active_model(data, index, data_dims, batched)
else:
data = pytensor.shared(data, name="data")
return data, nan_mask
Loading
Loading