-
Notifications
You must be signed in to change notification settings - Fork 3
CSMC & Sparse Particle Storage with RB and GPU-acceleration #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Use offset arrays to include initial state in reference trajectory. Implement Kalman smoother for use unit tests.
Looks like the tests before were passing by chance. I've made a few changes so they now have a valid unit test that passes:
The first change required modifying the GaussianContainer to have proposal/filtered states, which breaks the RBPF test (though the others still pass) as discussed in #14. I think the way around this might be to have initialise/update/predict act on the state itself, not the container of proposal/filtered states. The |
Code is very messy at this point but...it works 🙌 Have a unit test comparing GPU-accelerated, Rao-Blackwellised CSMC to RTS on a dummy Gaussian problem and the smoothing distributions match. I will probably clean up once we chat on Friday to discuss interface changes. The biggest long-term issue is that I'm storing the reference trajectories essentially as a vector of (D x 1) CuArrays (i.e. particle containers of size 1). This is probably not very efficient. Maybe it's best falling back to the CPU for the reference trajectory stuff since we're no longer in a parallel setting. Though a question I've been wondering is whether having one reference trajectory is always optimal for CSMC/PG. The theory works the same with any number...and the GPU would be useful if we had multiple (can compute the ancestries in parallel). |
Here's the PDF detailing the changes involved in this PR. |
The last commit starts the conversion of the interface to act on distributions rather than the combined intermediate storage. With this change the Kalman smoother, particle filter and RBPF are all simultaneously passing. I apologise that this has made some other parts of the code a bit clunkier and may get in the way of some of your ideas. I don't see this as a final version though so happy to make changes to make things more elegant provided it still allows for the RBPF to work nicely. Some of the main changes that were required:
I'll spend the rest of the day making sure the other tests pass and introducing the other changes we discussed last Friday. Thank you for your patience whilst I make these quite clunky changes—the fast GPU—RB—PGAS should all be worth it though! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like what I see so far as long as unit tests are passing across the board. I also want to make sure this code is type stable since this could be a huge bottleneck in terms of performance.
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
|
||
[compat] | ||
DataStructures = "0.18.20" | ||
GaussianDistributions = "0.5.2" | ||
OffsetArrays = "1.14.1" | ||
Statistics = "1.11.1" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a little odd, I was able to remove this on my work computer (which is stuck at Julia-1.10.4)
src/containers.jl
Outdated
mutable struct Intermediate | ||
proposed::Any | ||
filtered::Any | ||
ancestors::Any | ||
Intermediate() = new() | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not too keen on the idea of sacrificing type stability for convenience. If we can ensure that the RBPF is type stable with Intermediate
that would be ideal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. I think we can compute these types at instantiation time, it'll just be a bit of a faff. It's basically a generalisation of the rb_type
used with the CPU.
src/algorithms/bootstrap.jl
Outdated
function instantiate( | ||
model::StateSpaceModel{T}, filter::BootstrapFilter; kwargs... | ||
) where {T} | ||
N = filter.N | ||
particle_state = ParticleState(Vector{Vector{T}}(undef, N), Vector{T}(undef, N)) | ||
return ParticleContainer( | ||
particle_state, deepcopy(particle_state), Vector{Int}(undef, N) | ||
) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can generalize this to any AbstractParticleFilter
, except for the RBPF
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. Can probably even extend it to the RBPF too. It's just replacing the Vector{T} with the Rao-Blackwellised particle
Yeah, I definitely think some things have slipped through in the process. Code, especially the GPU stuff, seems slower now than I expected. |
Previous unit test was passing accidentally since for large N_particles, the randomly selected trajectories are good smoothers even without the reference trajectory.
The core ideas of this PR are now implemented and all tests are passing. A bit more tidying up of code and refactoring would be a good idea but might be best to get the PR merged first. |
Since these operations involve extending an array (which involves copying), there's no point doing this in-place as this makes it harder to recurse through the container structures.
Remove redundant types Renamed structs for consistency Generalised RB methods to all distributions
This PR is (finally) stable, clean, and ready to merge. Please note that I've made some significant changes to |
* Update documentation to match new interface * Incorporate SSMProblems into AbstractMCMC * Update Project.toml (#24) * Update Project.toml * Update make.jl --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
* Initial commit * Readme * Project * Prototype design. (#2) * prototype * Update SSMProblems.jl * add logM (#3) * Convert example into docstring. * Move `logM`. --------- Co-authored-by: FredericWantiez <frederic.wantiez@gmail.com> * export * example * Fred/ancestor (#5) * Gibbs * Add ancestor resampling * Better names * Clean up package (#6) * Gibbs * Setup github * Update Readme * Docs * GH actions * Format * Upgrade node * Use GH token * Fix links * Clean up * Write a proper example implementation (#7) * SMC * remove old file * Fix types * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Some minor changes (#8) * use recommended style for interface methods * minor changes to example * Update smc.jl * Update smc.jl --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Particle filter example bug fix and refactoring (#10) * fix: corrected observation generation for particle filter example * refactor: tidied particle filter example code - Removed recursive particle show method which flooded REPL - Removed redundant resampling logic - Replaced variance with std in Normal() calls - Tidied final scatter plot * Updated formatting for named argument Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: corrected flipped noise standard deviations --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix format action (#13) * Fix docs action (#12) * Fix docs action * Add DOCUMENTER_KEY * Use julia-docdeploy action * Show link to docs preview (#15) * Update documentation (#16) * Add details to doc * Fix source * Typo * Update SSM Interface * Fix linearize bug * Interface * Format * Trying things * Fix transition!! * Format * Fix doc * Helper * Format * Forget about particles * Optional timestep * Add utils * Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Utils module --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update README.md (#17) * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update index.md * Update documentation to match new interface (#18) * Update documentation to match new interface * Update index.md --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Build examples with doc (#19) * Build examples with doc * Reduce size of plot * Colors * Size option * Increase size per page * Update README.md * Update Project.toml * Modify SSMProblems to work with AbstractMCMC interface (#22) * Update documentation to match new interface * Incorporate SSMProblems into AbstractMCMC * Update Project.toml (#24) * Update Project.toml * Update make.jl --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Add Kalman filter example (#26) * Add Kalman filter example * Fix formatting issues * Add literate for docs * add missing deps * Comments for literate * Format * Use `Gaussian` (#28) * Format, use `Gaussian` * Fix the maths * Format * Tweaks * Update script.jl * Update script.jl --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update Project.toml --------- Co-authored-by: FredericWantiez <frederic.wantiez@gmail.com> Co-authored-by: Hong Ge <hg344@cam.ac.uk> Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update script.jl * Update script.jl * Create DocsNav.yml * Add example script for PMMH (#37) * Add example script for PMMH * Add Literate.jl * Update script.jl * Update script.jl * Update DocsNav.yml * Update SSMProblems.jl interface (#38) * Add split dynamics/observation interface with "extra" variables * Add utilities for forward simulation and distribution definitions * Removed redundant particle container code * Update naming convention for initialisation log-density Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update initialisation naming Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update naming convention Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Change sampler to AbstractMCMC Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Add section heading for SSM Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Remove redundant method check * Correct dependencies * Revert to positional arguments * Correct forward simulation element type * Update Kalman filter example to new interface * Fix formatting issue * Add missing import Co-authored-by: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> * Add default rngs to samplers through macro * Remove unnecessary section heading * Add missing dependency * Tidied Kalman filter example * Update documentation main page * Fully document Kalman filter example * Remove outdated examples * Add documentation for extra argument * Apply suggestions from code review * Update script.jl * Update examples/kalman-filter/script.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update script.jl * Remove PMMH (until new API) * Remove ref to 'Utils' * Fix broken link * Remove default RNG macro * Simplify interface methods * Correct old function names * Add parametric type to Kalman filter * Update main doc page * Update README * Make parameter order consistent --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: FredericWantiez <frederic.wantiez@gmail.com> * Minor tweaks and typo fixes (#41) * Update index.md * Update index.md * Update Project.toml * Suppress output in example script --------- Co-authored-by: THargreaves <tim.hargreaves@icloud.com> * Split method definitions to avoid docstring overwriting (#42) * Function docstring formatting (#45) * added TagBot & CompatHelper workflows (#47) * added TagBot & CompatHelper workflows * using existing Documenter Key for CompatHelper * CompatHelper: add new compat entry for Distributions at version 0.25, (keep existing compat) (#48) Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org> * Update DocsPreviewCleanup.yml * TagBot Permission Issue fixed (#50) * Add DOCUMENTER_KEY to Docs workflow (#52) * Update DocsPreviewCleanup.yml * Correct type signature for forward simulation method * Alignment of obs/dyn time steps and refactored forward simulation (#55) * Update interface documentation to align dyn/obs time steps * Refactor forward simulation, add type parameters, add unit test * Add type parameters to docstrings * Update kalman example * Add extra for initialisation and simplify obs type parameter * Bump minor version * Update unit tests to match aligned interface * Fix code comment rendering * Interface Changes for Use in Filtering (#56) * added basic particle methods and filters * added qualifiers * added parameter priors * added adaptive resampling to bootstrap filter (WIP) * Julia fomatter changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * changed eltype for <: StateSpaceModel * updated naming conventions * formatter * fixed adaptive resampling * added particle ancestry * formatter issues * fixed metropolis and added rejection resampler * Keep track of free indices using stack * updated particle types and organized directory * weakened SSM type parameter assertions * improved particle state containment and resampling * added hacky sparse ancestry to example * fixed RNG in rejection resampling * improved callbacks and resamplers * formatting * added conditional SMC * improved linear model type structure * formatter * replaced extra with kwargs * formatter * migrated filtering code * Add unittests for new interface * Update documentation to match kwargs * Rename extras/kwargs docs file * remove redundant forward simulations --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tim Hargreaves <tim.hargreaves@icloud.com> * Bump 0.4.0 (#58) * Update docs to match kwargs interface * Add method definitions for batch simulation/log-densities * Bump minor version * Correct docstring overwriting for batch methods * Update type parameters to contain both arithmetic and element type * Correct docstring indentations * Correct RBPF forward simulation * Documentation and Turing Navigation CI improvement (#61) * Update Docs.yml * Update DocsNav.yml * No need of deploydocs() after using new Docs & DocsNav workflows * Remove research files from repository * removed SSMProblems README, LICENSE, GHA workflow, JULIAFORMATTER in favor of merger * SSMProblems: added missing docstring to avoid documentation failure * added pkg_path in Docs workflow to fix package development --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: FredericWantiez <frederic.wantiez@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com> Co-authored-by: Tim Hargreaves <38204689+THargreaves@users.noreply.github.com> Co-authored-by: David Widmann <devmotion@users.noreply.github.com> Co-authored-by: Hong Ge <hg344@cam.ac.uk> Co-authored-by: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Co-authored-by: THargreaves <tim.hargreaves@icloud.com> Co-authored-by: Will Tebbutt <willtebbutt00@gmail.com> Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org> Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
A very much draft piece of work tightening up and testing the CSMC interface. Just wanted to check it technically works.
Contributions and comments welcomed.
TODO:
On the last point, the callback is not ran after initialisation so x0 of the reference trajectory is not stored.