batched SSM compatibility#685
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #685 +/- ##
===========================================
+ Coverage 51.60% 91.24% +39.63%
===========================================
Files 73 79 +6
Lines 8003 8097 +94
===========================================
+ Hits 4130 7388 +3258
+ Misses 3873 709 -3164
🚀 New features to boost your workflow:
|
…, SARIMAX, ETS so that positional arguments don't clash with the new parameter, updated error in control flow for data registry
jessegrabowski
left a comment
There was a problem hiding this comment.
This is a really great start, i'll fire up the notebook and play it. Here are some thoughts off the cuff.
@ricardoV94 your voice would be very welcome here :)
| data_dims = None | ||
|
|
||
| if self.batch_size: | ||
| data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) |
There was a problem hiding this comment.
What if there are more than one batch dim?
| shock_trajectory = shock_trajectory.swapaxes(0, 1) | ||
|
|
||
| def bmv(A, x): | ||
| return pt.matmul(A, x[..., None])[..., 0] |
There was a problem hiding this comment.
I thought we automatically do something like this already
|
|
||
| pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM]) | ||
| if self.batch_size: | ||
| irf = irf.swapaxes(0, 1) |
There was a problem hiding this comment.
This si the right pattern but you need to reason about the case where there are multiple batches
…tch_tools utilities file, use slicing logic instead of branching for forecast indexing, make batch_size a tuple for destructuring instead of branching
…_sizes, added a utility to infer the batch dimension
… vectorize certain matrices 2. Updated observation covariance computation in sample_conditional to support batched dimensions 3. added tests for batched dimension SSMs
The following is a WIP implementation of allowing the SSM module to work with multiple time-series via batched inputs.
This first draft allows batched sampling and supports the following internals:
sample_unconditional_priorsample_conditional_priorsample_unconditional_posteriorsample_conditional_posteriorAdditional testing is still required and redundancies in code need to be cleaned-up.
Update 1:
I updated the way batching is defined from implicitly defined to explicitly defined by a
batch_sizeparameter and I fixed an error that I introduced in the data registry control flow. AllStateSpacetests now are passing locally. This just means that adding the batching implementation did not break any existing logic.@jessegrabowski, Here are the next steps before this will be ready:
StateSpacemethods support until all methods are supported