Skip to content

batched SSM compatibility#685

Draft
Dekermanjian wants to merge 10 commits into
pymc-devs:mainfrom
Dekermanjian:ssm_batch_dims
Draft

batched SSM compatibility#685
Dekermanjian wants to merge 10 commits into
pymc-devs:mainfrom
Dekermanjian:ssm_batch_dims

Conversation

@Dekermanjian

@Dekermanjian Dekermanjian commented May 20, 2026

Copy link
Copy Markdown
Contributor

The following is a WIP implementation of allowing the SSM module to work with multiple time-series via batched inputs.

This first draft allows batched sampling and supports the following internals:

  • sample_unconditional_prior
  • sample_conditional_prior
  • sample_unconditional_posterior
  • sample_conditional_posterior

Additional testing is still required and redundancies in code need to be cleaned-up.

Update 1:
I updated the way batching is defined from implicitly defined to explicitly defined by a batch_size parameter and I fixed an error that I introduced in the data registry control flow. All StateSpace tests now are passing locally. This just means that adding the batching implementation did not break any existing logic.

@jessegrabowski, Here are the next steps before this will be ready:

  • Quick review just so that I know directionally if this is how the vision of batching should be
  • Continue to add additional StateSpace methods support until all methods are supported
  • Write tests for batching that runs through the full SSM workflow

@Dekermanjian Dekermanjian self-assigned this May 20, 2026
@review-notebook-app

Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter

codecov-commenter commented May 20, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 94.94949% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.24%. Comparing base (86fac3c) to head (86916f0).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
pymc_extras/statespace/core/statespace.py 94.77% 8 Missing ⚠️
pymc_extras/statespace/utils/data_tools.py 92.00% 2 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main     #685       +/-   ##
===========================================
+ Coverage   51.60%   91.24%   +39.63%     
===========================================
  Files          73       79        +6     
  Lines        8003     8097       +94     
===========================================
+ Hits         4130     7388     +3258     
+ Misses       3873      709     -3164     
Files with missing lines Coverage Δ
pymc_extras/statespace/filters/distributions.py 92.26% <100.00%> (+0.14%) ⬆️
pymc_extras/statespace/models/ETS.py 98.89% <ø> (ø)
pymc_extras/statespace/models/SARIMAX.py 98.96% <ø> (ø)
pymc_extras/statespace/models/VARMAX.py 99.39% <ø> (ø)
pymc_extras/statespace/utils/batch_tools.py 100.00% <100.00%> (ø)
pymc_extras/statespace/utils/constants.py 100.00% <100.00%> (ø)
pymc_extras/statespace/utils/data_tools.py 85.82% <92.00%> (-0.62%) ⬇️
pymc_extras/statespace/core/statespace.py 89.30% <94.77%> (+3.41%) ⬆️

... and 39 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jessegrabowski jessegrabowski left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a really great start, i'll fire up the notebook and play it. Here are some thoughts off the cuff.

@ricardoV94 your voice would be very welcome here :)

Comment thread pymc_extras/statespace/core/statespace.py Outdated
Comment thread pymc_extras/statespace/core/statespace.py Outdated
data_dims = None

if self.batch_size:
data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if there are more than one batch dim?

Comment thread pymc_extras/statespace/core/statespace.py
Comment thread pymc_extras/statespace/core/statespace.py
Comment thread pymc_extras/statespace/core/statespace.py
Comment thread pymc_extras/statespace/core/statespace.py Outdated
Comment thread pymc_extras/statespace/core/statespace.py Outdated
shock_trajectory = shock_trajectory.swapaxes(0, 1)

def bmv(A, x):
return pt.matmul(A, x[..., None])[..., 0]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we automatically do something like this already


pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
if self.batch_size:
irf = irf.swapaxes(0, 1)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This si the right pattern but you need to reason about the case where there are multiple batches

Comment thread pymc_extras/statespace/filters/distributions.py Outdated
…tch_tools utilities file, use slicing logic instead of branching for forecast indexing, make batch_size a tuple for destructuring instead of branching
…_sizes, added a utility to infer the batch dimension
… vectorize certain matrices

2. Updated observation covariance computation in sample_conditional to support batched dimensions
3. added tests for batched dimension SSMs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants