Add Simulation Smoother #680

Open
jessegrabowski wants to merge 9 commits into pymc-devs:main from jessegrabowski:simulation-smoother

@jessegrabowski

jessegrabowski commented May 8, 2026

Member

A simulation smoother draws a joint sample of the full latent trajectory $\alpha_{1:T}$ from the smoothing posterior $p(\alpha_{1:T} \mid y_{1:T})$ of a linear Gaussian state-space model, rather than just the per-step marginals $p(\alpha_t \mid y_{1:T})$. The Kalman smoother already gives us the marginal mean $\hat{\alpha}_t = \mathbb{E}[\alpha_t \mid y_{1:T}]$ and marginal covariance $P_{t \mid T} = \text{Var}(\alpha_t \mid y_{1:T})$ for each $t$, but the cross-time block $\text{Cov}(\alpha_t, \alpha_s \mid y_{1:T})$ is not returned.

We need a simulation smoother in this library because plotting trajectories of the latent state (for posterior_predictive["smoothed_*"] and related uses) requires draws that respect the temporal correlation structure, not draws that treat each $\alpha_t$ as independent of the others given the data.
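As a sketch of what is missing, the smoothing posterior can be written in block form. The label $P_{t,s \mid T} = \text{Cov}(\alpha_t, \alpha_s \mid y_{1:T})$ for the cross-time block is introduced here for illustration only and is not notation used in the PR:

$$
\alpha_{1:T} \mid y_{1:T} \sim \mathcal{N}\left(
\begin{bmatrix} \hat{\alpha}_1 \\ \vdots \\ \hat{\alpha}_T \end{bmatrix},
\begin{bmatrix}
P_{1 \mid T} & \cdots & P_{1,T \mid T} \\
\vdots & \ddots & \vdots \\
P_{T,1 \mid T} & \cdots & P_{T \mid T}
\end{bmatrix}
\right)
$$

Drawing each $\alpha_t$ separately from $\mathcal{N}(\hat{\alpha}_t, P_{t \mid T})$ amounts to replacing every off-diagonal block $P_{t,s \mid T}$, $t \neq s$, with zero. The per-step means and variances stay correct, but the draws no longer behave like trajectories.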

This is visible if you run a structural time series model and plot the smoother outputs. Averaged across draws they look fine, but individual draws look jagged:

smoother_before

The point of the simulation smoother is to build back in the cross-time dependency, which has the effect of smoothing all the trajectories out:

smoother_after
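The contrast between the two figures can be reproduced in a toy setting. The sketch below does not use the library; it assumes a stand-in posterior with zero mean, unit marginal variance, and cross-time covariance `rho ** |t - s|`, then compares joint draws against independent per-step draws. The `roughness` helper is a hypothetical diagnostic chosen for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
T, rho, n_draws = 100, 0.95, 500

# Stand-in smoothing posterior (an assumption for illustration):
# zero mean, unit marginal variance, Cov(alpha_t, alpha_s | y) = rho ** |t - s|.
idx = np.arange(T)
cov = rho ** np.abs(idx[:, None] - idx[None, :])

# Joint draws use the full covariance through its Cholesky factor.
chol = np.linalg.cholesky(cov)
joint = rng.standard_normal((n_draws, T)) @ chol.T

# Marginal draws keep only the diagonal and discard the cross-time blocks.
marginal = rng.standard_normal((n_draws, T)) * np.sqrt(np.diag(cov))


def roughness(draws):
    """Mean squared first difference along time, averaged over draws."""
    return float(np.mean(np.diff(draws, axis=1) ** 2))


joint_rough = roughness(joint)
marginal_rough = roughness(marginal)
print(f"joint: {joint_rough:.3f}, marginal: {marginal_rough:.3f}")
```

Both sets of draws have the same per-step mean and variance, so their averages look alike. Only the joint draws have small step-to-step changes, since the expected squared difference is `2 * (1 - rho)` for the joint case and `2` for the independent case.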

@jessegrabowski

Member Author

Tagging @ricardoV94 because I hope you can look at the need for inline=False in the Smoother scan. Without it, compilation was hanging during the scan_merge rewrites. Maybe you have some thoughts on how to do it better.

@jessegrabowski

Member Author

Also this feature has been requested here so I have to give the people what they want!

@ricardoV94

Member

> Tagging @ricardoV94 because I hope you can look at the need for inline=False in the Smoother scan. Without it, compilation was hanging during the scan_merge rewrites. Maybe you have some thoughts on how to do it better.

Which test / minimal script triggers the failure?

@jessegrabowski

Member Author

Here's an MWE that Claude and I came up with:

"""MRE for the scan_sit_sot_to_untraced bug.

  PYTENSOR_FLAGS=linker=auto python MRE_pure.py   ->  FAIL: shape 1, expected 4
  PYTENSOR_FLAGS=linker=py   python MRE_pure.py   ->  PASS

A sit-sot output of an outer scan has its time dim stripped to 1 when it feeds
an ``inline=False`` SymbolicRandomVariable that slices ``[-1]`` / ``[:-1]`` and
re-concatenates back to T. Each of these is load-bearing (removing any one
kills the bug):
  - outer scan has TWO sit-sot outputs ``(T, k)`` and ``(T, k, k)``;
  - a ``(k, k)`` non-sequence matmul'd against the ``(k,)`` carry;
  - the inline=False op is a SymbolicRandomVariable (plain OpFromGraph
    does not trigger it);
  - BOTH RV consumers (the sibling and the downstream one) are
    SymbolicRandomVariable wrappers; swapping either for a plain RV kills it;
  - compilation with ``pytensor.function(..., updates=rng_updates)``.
"""
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random import normal
from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.pytensorf import collect_default_updates, normalize_rng_param

T, k = 4, 1


class _RV(SymbolicRandomVariable):
    default_output = 1
    def update(self, node):
        return {node.inputs[-1]: node.outputs[0]}


def seq_n(mus):
    rng = normalize_rng_param(None)
    mus_ = mus.type()
    nxt, seq = normal(loc=mus_, scale=1.0, rng=rng, return_next_rng=True)
    return _RV(inputs=[mus_, rng], outputs=[nxt, seq], ndim_supp=1)(mus, rng)


def smoother_op():
    a_ = pt.matrix("a_", shape=(T, k))
    rng_shared = pytensor.shared(np.random.default_rng(0))
    rng_ = rng_shared.type()
    a_last = a_[-1]
    a_seq = pytensor.scan(
        lambda a_t, a_next: a_t + a_next,
        sequences=[a_[:-1]],
        outputs_info=[a_last],
        go_backwards=True,
        strict=True,
        return_updates=False,
    )
    smooth = pt.concatenate([a_seq[::-1], pt.expand_dims(a_last, 0)], axis=0)
    smooth = pt.specify_shape(smooth, (T, k))
    return _RV(
        inputs=[a_, rng_],
        outputs=[rng_, smooth],
        extended_signature="(t,s),[rng]->[rng],(t,s)",
        inline=False,
    ), rng_shared


# --- outer graph ---
T_mat = pt.eye(k) * 0.9

a_filt, _ = pytensor.scan(
    lambda a_prev, P_prev, A: (A @ a_prev + 1.0, A @ P_prev),
    outputs_info=[pt.zeros(k), pt.eye(k)],
    non_sequences=[T_mat],
    n_steps=T,
    strict=True,
    return_updates=False,
)

filt_rv = seq_n(a_filt)
op, op_rng = smoother_op()
alpha = op(a_filt, op_rng)
obs_rv = seq_n(alpha)

rng_updates = collect_default_updates(inputs=[], outputs=[filt_rv, obs_rv])
fn = pytensor.function(
    [], [filt_rv, obs_rv], updates=rng_updates, on_unused_input="ignore"
)
print(f"linker={pytensor.config.linker}")
try:
    a, b = fn()
    print(f"OK a={a.shape} b={b.shape}")
except AssertionError as e:
    print(f"FAIL: {str(e).splitlines()[0]}")

@ricardoV94

ricardoV94 commented May 17, 2026

Member

The bug is two nodes sharing the same OpFromGraph Op in the numba backend. The dispatch runs the graph optimization on each call, and two passes of the graph optimization make a mess. We should only optimize the op once (again, we shouldn't even be doing this in the dispatch phase, but at rewrite time...)

In general our rewrites should be idempotent, so a second pass shouldn't fail... except if we reject inplace ops at the start. But I think scan_save_mem specifically is not behaving well on a second pass, which may be caused by recent changes to it (if those have already been released).
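The failure mode described above (one op shared by two nodes, rewritten once per node) can be illustrated with a toy sketch. Nothing here is PyTensor's dispatch machinery: `ToyOp`, `rewrite_in_place`, `dispatch_naive`, and `dispatch_once` are hypothetical names, and the "rewrite" is deliberately non-idempotent so that a second pass visibly changes the result.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so ops can go in a set
class ToyOp:
    """Hypothetical stand-in for an op that wraps an inner graph."""
    steps: list = field(default_factory=lambda: ["alloc", "scan", "slice"])


def rewrite_in_place(op):
    """Toy non-idempotent rewrite: every pass strips the leading step."""
    op.steps.pop(0)


def dispatch_naive(nodes):
    """Re-run the rewrite for every node, even when nodes share one op."""
    for op in nodes:
        rewrite_in_place(op)


def dispatch_once(nodes):
    """Run the rewrite at most once per distinct op."""
    done = set()
    for op in nodes:
        if op not in done:
            rewrite_in_place(op)
            done.add(op)


naive_op = ToyOp()
dispatch_naive([naive_op, naive_op])   # two nodes, same op: rewritten twice

cached_op = ToyOp()
dispatch_once([cached_op, cached_op])  # two nodes, same op: rewritten once

print(naive_op.steps, cached_op.steps)
```

Either fix removes the symptom in this toy: making the rewrite idempotent means a second pass is harmless, while tracking already-optimized ops avoids the second pass altogether.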

@ricardoV94

Member

Actually pymc-devs/pytensor#2098 fixed the idempotency issue of scan_save_mem (hence why I couldn't reproduce at first locally).

So nothing to change here, we can wait for the next pytensor release? It would still be good not to run the same graph rewrite twice, but that's not a blocker.

@jessegrabowski

Member Author

Tried with inline=False on pytensor main, and compilation hangs. After scan_merge I ended up with a monster graph that just crushed my CPU. I can try to dig out an MWE later, or we can look at it together at some point.

@ricardoV94

Member

Wasn't the problem with inline=True?

Which MWE hangs?

Could it be hanging with the assumptions PR? I tested with main before that (but I'm not sure what you are running now).

Also, you keep saying scan_merge is the issue. Why? The issue I found before was with the old scan_save_mem; I didn't see anything wrong with scan_merge.

Lift the inner forward-simulation scan body out of
_LinearGaussianStateSpace.rv_op into a module-level helper so it can be
reused by SimulationSmoother. No behavior change. Also add copy()
methods to BaseFilter and KalmanSmoother so callers that re-invoke
build_graph (e.g. inside SimulationSmoother.rv_op) can do so without
clobbering the host model's seq_names / non_seq_names.

A Durbin-Koopman simulation smoother that draws joint samples of the
latent trajectory from the posterior, fixing the cross-time covariance
gap in step-marginal SequenceMvNormal draws. Includes:

- SimulationSmootherRV (SymbolicRandomVariable, inline=False to avoid a
  pytensor scan_merge non-termination on the inner forward-sim + KS
  scan pair),
- SimulationSmoother distribution with rv_op that forward-sims
  (alpha_plus, y_plus), filters and smooths y_plus, and applies the DK
  identity alpha_sample = alpha_plus - a_smooth_plus + a_smooth,
- _logprob registration returning zero (only sampled, never scored),
- tests pinning the marginal mean, joint covariance against an analytic
  posterior, and against statsmodels' simulation_smoother on a Local
  Linear Trend model.

Replace the legacy SequenceMvNormal call for the smoothed group with a
SimulationSmoother draw, so smoothed-posterior trajectories returned by
sample_conditional_prior / sample_conditional_posterior carry the full
joint covariance instead of independent step-marginals. The predicted
and filtered groups keep the legacy SequenceMvNormal path. The observed
side reuses SequenceMvNormal with mus = d + Z @ alpha_sample and
covs = H, which is correct conditional on a joint latent draw.

With inline=True the inner scans are spliced into the parent graph, where
scan_merge fuses them into a cyclic graph that hangs toposort at compile time.
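
The Durbin-Koopman identity named in the commit messages above, alpha_sample = alpha_plus - a_smooth_plus + a_smooth, can be checked numerically. The sketch below is not the PR's implementation: it assumes a scalar local-level model with d = 0 and Z = 1, and it swaps the Kalman filter and smoother for dense Gaussian conditioning (a single `gain` matrix), which is only practical for small T. All variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
T, p0, q, h = 6, 1.0, 0.5, 0.8  # horizon; initial, state and observation variances

# Prior covariance of a local-level state: alpha_t = alpha_{t-1} + eta_t.
idx = np.arange(T)
cov_alpha = p0 + q * np.minimum(idx[:, None], idx[None, :])
cov_y = cov_alpha + h * np.eye(T)  # y_t = alpha_t + eps_t

# Linear smoother E[alpha | y] = gain @ y (stands in for filter + smoother).
gain = np.linalg.solve(cov_y, cov_alpha).T
post_cov = cov_alpha - gain @ cov_alpha  # analytic Cov(alpha | y)


def forward_sim(n):
    """Unconditional draws of (alpha_plus, y_plus) from the model."""
    alpha0 = np.sqrt(p0) * rng.standard_normal((n, 1))
    eta = np.sqrt(q) * rng.standard_normal((n, T - 1))
    alpha = np.concatenate([alpha0, alpha0 + np.cumsum(eta, axis=1)], axis=1)
    y = alpha + np.sqrt(h) * rng.standard_normal((n, T))
    return alpha, y


_, y_obs = forward_sim(1)            # treat one simulated series as the data
a_smooth = y_obs @ gain.T            # smoothed mean given the data

n_draws = 20000
alpha_plus, y_plus = forward_sim(n_draws)
a_smooth_plus = y_plus @ gain.T      # smoothed mean given each simulated series

# Durbin-Koopman identity.
alpha_sample = alpha_plus - a_smooth_plus + a_smooth

# Observed side: mean d + Z @ alpha_sample (d = 0, Z = 1 here), covariance H.
y_sample = alpha_sample + np.sqrt(h) * rng.standard_normal(alpha_sample.shape)

emp_mean = alpha_sample.mean(axis=0)
emp_cov = np.cov(alpha_sample, rowvar=False)
mean_err = np.abs(emp_mean - a_smooth[0]).max()
cov_err = np.abs(emp_cov - post_cov).max()
print(f"max mean error: {mean_err:.4f}, max covariance error: {cov_err:.4f}")
```

Because the smoothed mean is linear in the observations, subtracting a_smooth_plus removes exactly the part of alpha_plus explained by y_plus. The remainder has the posterior covariance, including its off-diagonal (cross-time) entries, and adding a_smooth recentres it on the data.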
@codecov-commenter

codecov-commenter commented Jun 13, 2026


Codecov Report

❌ Patch coverage is 99.14530% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 91.11%. Comparing base (86fac3c) to head (597ddae).
⚠️ Report is 13 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc_extras/statespace/filters/distributions.py | 98.96% | 1 Missing ⚠️ |

Additional details and impacted files

@@             Coverage Diff             @@
##             main     #680       +/-   ##
===========================================
+ Coverage   51.60%   91.11%   +39.50%     
===========================================
  Files          73       89       +16     
  Lines        8003     8383      +380     
===========================================
+ Hits         4130     7638     +3508     
+ Misses       3873      745     -3128

| Files with missing lines | Coverage Δ |
|---|---|
| pymc_extras/statespace/core/statespace.py | 86.03% <100.00%> (+0.13%) ⬆️ |
| pymc_extras/statespace/filters/__init__.py | 100.00% <100.00%> (ø) |
| pymc_extras/statespace/filters/kalman_filter.py | 100.00% <100.00%> (ø) |
| pymc_extras/statespace/filters/kalman_smoother.py | 100.00% <100.00%> (ø) |
| pymc_extras/statespace/filters/distributions.py | 93.54% <98.96%> (+1.42%) ⬆️ |

... and 47 files with indirect coverage changes


3 participants