Skip to content
85 changes: 67 additions & 18 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from pymc_extras.statespace.filters.distributions import (
LinearGaussianStateSpace,
SequenceMvNormal,
SimulationSmoother,
)
from pymc_extras.statespace.filters.utilities import stabilize
from pymc_extras.statespace.utils.constants import (
Expand Down Expand Up @@ -1398,6 +1399,14 @@ def _sample_conditional(
compile_kwargs = kwargs.pop("compile_kwargs", {})
compile_kwargs.setdefault("mode", self.mode)

# Recover the static time length so it can be pinned on the smoothed-state
# mean fed into ``SimulationSmoother``. With ``inline=False`` the inner
# forward-sim / kalman scans are opaque to the parent fgraph, so shape
# info that would normally flow in from ``freeze_dims_and_data`` is lost
# and JAX's ``lax.scan`` chokes on a dynamic n_steps.
data_for_len = data if data is not None else self._fit_data
static_T = data_for_len.shape[0] if hasattr(data_for_len, "shape") else len(data_for_len)

with pm.Model(coords=self._fit_coords) as forward_model:
(
[
Expand Down Expand Up @@ -1428,26 +1437,66 @@ def _sample_conditional(
else (None, None)
)

SequenceMvNormal(
f"{name}_{group}",
mus=mu,
covs=cov,
logp=dummy_ll,
dims=state_dims,
method=mvn_method,
)
if name == "smoothed":
# Durbin-Koopman simulation smoother: draws a joint sample of
# alpha_{1:T} with the full posterior covariance, not just the
# per-step marginals. See SimulationSmoother docstring.
mu_pinned = pt.specify_shape(mu, (static_T, *mu.type.shape[1:]))
alpha_sample = SimulationSmoother(
f"{name}_{group}",
a_smooth=mu_pinned,
x0=x0,
P0=P0,
c=c,
d=d,
T=T,
Z=Z,
R=R,
H=H,
Q=Q,
kalman_filter=self.kalman_filter.copy(),
kalman_smoother=self.kalman_smoother.copy(),
sequence_names=tuple(self.kalman_filter.seq_names),
dims=state_dims,
method=mvn_method,
)
# Conditional on a joint draw of the latent path, the observation
# noise eta_t is iid by the LGSSM definition, so a per-step MvN with
# cov=H is the correct joint over the obs trajectory. Broadcast H
# against the time dim of obs_mu so SequenceMvNormal's gufunc
# ``(n),(n,n)->(n)`` doesn't collapse to a length-1 scan when H is
# rank-3 but time-broadcastable.
obs_mu = d + (Z @ alpha_sample[..., None]).squeeze(-1)
obs_cov = pt.zeros((obs_mu.shape[0], 1, 1), dtype=H.dtype) + H
SequenceMvNormal(
f"{name}_{group}_observed",
mus=obs_mu,
covs=obs_cov,
logp=pt.zeros_like(obs_mu),
dims=obs_dims,
method=mvn_method,
)
else:
SequenceMvNormal(
f"{name}_{group}",
mus=mu,
covs=cov,
logp=dummy_ll,
dims=state_dims,
method=mvn_method,
)

obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H
obs_mu = d + (Z @ mu[..., None]).squeeze(-1)
obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H

SequenceMvNormal(
f"{name}_{group}_observed",
mus=obs_mu,
covs=obs_cov,
logp=dummy_ll,
dims=obs_dims,
method=mvn_method,
)
SequenceMvNormal(
f"{name}_{group}_observed",
mus=obs_mu,
covs=obs_cov,
logp=dummy_ll,
dims=obs_dims,
method=mvn_method,
)

# TODO: Remove this after pm.Flat initial values are fixed
forward_model.rvs_to_initial_values = {
Expand Down
6 changes: 5 additions & 1 deletion pymc_extras/statespace/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
from pymc_extras.statespace.filters.distributions import (
LinearGaussianStateSpace,
SimulationSmoother,
)
from pymc_extras.statespace.filters.kalman_filter import (
SquareRootFilter,
StandardFilter,
Expand All @@ -9,6 +12,7 @@
__all__ = [
"KalmanSmoother",
"LinearGaussianStateSpace",
"SimulationSmoother",
"SquareRootFilter",
"StandardFilter",
"UnivariateFilter",
Expand Down
Loading
Loading