Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialization routines for ODE filters #490

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
91022c2
runge kutta init factored out
pnkraemer Jul 17, 2021
03c113d
some moving around
pnkraemer Jul 17, 2021
9c0df77
TM is class now
pnkraemer Jul 17, 2021
15abde5
removed tm convenience function
pnkraemer Jul 17, 2021
97a996b
new interface in ivpfiltsmooth
pnkraemer Jul 17, 2021
5db47d3
renamed init implementation argument
pnkraemer Jul 17, 2021
e3c42e8
fixed doctest
pnkraemer Jul 17, 2021
79e4125
doc
pnkraemer Jul 17, 2021
50cc4fd
no more classmethods
pnkraemer Jul 17, 2021
6d0c1b4
fixed pytest
pnkraemer Jul 17, 2021
a7fe0c9
exactness info in initialization
pnkraemer Jul 17, 2021
2805587
jax property
pnkraemer Jul 17, 2021
94e494e
test interface
pnkraemer Jul 17, 2021
e0ef5ef
lotka volterra fixture
pnkraemer Jul 17, 2021
c5407d3
docs in test
pnkraemer Jul 17, 2021
b4b29f3
taylor mode tests
pnkraemer Jul 17, 2021
8c759a1
test utils in initialize
pnkraemer Jul 17, 2021
1732941
changed notebook
pnkraemer Jul 17, 2021
3fc8bdb
fixed notebook
pnkraemer Jul 17, 2021
12a50b1
fixed pylint imports
pnkraemer Jul 17, 2021
6c36fc0
updated tutorial (and blacked)
pnkraemer Jul 17, 2021
b2e3b22
updated description of init strategy in ivpfiltsmooth
pnkraemer Jul 17, 2021
e5b8c22
removed initialisation classmethods for good
pnkraemer Jul 17, 2021
3aef792
removed unused pylint disable
pnkraemer Jul 17, 2021
acd07f8
some tidying in taylor mode
pnkraemer Jul 17, 2021
9bdb9b4
more verbose code in taylor mode autodiff
pnkraemer Jul 17, 2021
fd60597
more verbose code in taylor mode autodiff
pnkraemer Jul 17, 2021
eaf9403
word choice in taylor mode code
pnkraemer Jul 17, 2021
71f6037
renaming
pnkraemer Jul 17, 2021
ec0359c
extracted common functionality of tests
pnkraemer Jul 17, 2021
9d842aa
fixed pylint
pnkraemer Jul 17, 2021
01cfe61
types
pnkraemer Jul 17, 2021
e603bc9
fixed imports for types
pnkraemer Jul 17, 2021
1b27cac
Merge branch 'main' into odefiltsmooth_initialize
pnkraemer Jul 17, 2021
00ff2e7
Merge branch 'main' of https://github.com/probabilistic-numerics/prob…
pnkraemer Jul 19, 2021
0a18569
fixed notebook
pnkraemer Jul 19, 2021
d190fc9
typo corrected and pass->notimplementederror
pnkraemer Jul 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api/diffeq/odefiltsmooth.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ probnum.diffeq.odefiltsmooth
.. toctree::
:hidden:

odefiltsmooth/initialize
odefiltsmooth/initialization_routines
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
probnum.diffeq.odefiltsmooth.initialization_routines
----------------------------------------------------

.. automodapi:: probnum.diffeq.odefiltsmooth.initialization_routines
:no-heading:
:headings: "*"
6 changes: 0 additions & 6 deletions docs/source/api/diffeq/odefiltsmooth/initialize.rst

This file was deleted.

21 changes: 15 additions & 6 deletions docs/source/tutorials/odes/odesolvers_from_scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# ODE-Solvers from Scratch\n",
"# ODE Solvers from Scratch\n",
"\n",
"All the other tutorials show how to use the ODE-solver with the `probsolve_ivp` function.\n",
"This is great, though `probnum` has more customisation to offer."
Expand Down Expand Up @@ -87,8 +87,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we construct the ODE filter. One choice that has not been made yet is the initialiation strategy. The current default choice is to initialise by fitting the prior to a few steps of a Runge-Kutta solution. An alternative is to use automatic differentiation, which is currently in development.\n",
"An easy-access version of those initialisation strategies is to use the constructor `GaussianIVPFilter.construct_with_rk_init`. "
"Next, we construct the ODE filter. One choice that has not been made yet is the initialiation strategy. The current default choice is to initialise by fitting the prior to a few steps of a Runge-Kutta solution. An alternative is to use automatic differentiation.\n",
"\n",
"All of those options can be found in `diffeq.odefiltsmooth.initialize`.\n"
]
},
{
Expand All @@ -97,8 +98,16 @@
"metadata": {},
"outputs": [],
"source": [
"diffmodel =statespace.PiecewiseConstantDiffusion(t0=t0)\n",
"solver = diffeq.odefiltsmooth.GaussianIVPFilter.construct_with_rk_init(ivp, prior_process=prior_process, measurement_model=ekf, diffusion_model=diffmodel, with_smoothing=True)\n"
"diffmodel = statespace.PiecewiseConstantDiffusion(t0=t0)\n",
"init_routine = diffeq.odefiltsmooth.initialization_routines.RungeKuttaInitialization()\n",
"solver = diffeq.odefiltsmooth.GaussianIVPFilter(\n",
" ivp,\n",
" prior_process=prior_process,\n",
" measurement_model=ekf,\n",
" initialization_routine=init_routine,\n",
" diffusion_model=diffmodel,\n",
" with_smoothing=True,\n",
")"
]
},
{
Expand Down Expand Up @@ -132,7 +141,7 @@
"metadata": {},
"outputs": [],
"source": [
"evalgrid = np.arange(ivp.t0, ivp.tmax, step=0.1)\n"
"evalgrid = np.arange(ivp.t0, ivp.tmax, step=0.1)"
]
},
{
Expand Down
11 changes: 7 additions & 4 deletions src/probnum/diffeq/_probsolve_ivp.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,14 @@ def probsolve_ivp(
measmod = odefiltsmooth.GaussianIVPFilter.string_to_measurement_model(
method, ivp, prior_process
)
solver = odefiltsmooth.GaussianIVPFilter.construct_with_rk_init(
ivp,
prior_process,
measmod,

rk_init = odefiltsmooth.initialization_routines.RungeKuttaInitialization()
solver = odefiltsmooth.GaussianIVPFilter(
ivp=ivp,
prior_process=prior_process,
measurement_model=measmod,
with_smoothing=dense_output,
initialization_routine=rk_init,
diffusion_model=diffusion,
)

Expand Down
94 changes: 10 additions & 84 deletions src/probnum/diffeq/odefiltsmooth/_ivpfiltsmooth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Gaussian IVP filtering and smoothing."""

from typing import Callable, Optional
from typing import Optional

import numpy as np
import scipy.linalg

from probnum import filtsmooth, problems, randprocs, randvars, statespace, utils
from probnum.diffeq import _odesolver
from probnum.diffeq.odefiltsmooth import _kalman_odesolution, initialize
from probnum.diffeq.odefiltsmooth import _kalman_odesolution, initialization_routines


class GaussianIVPFilter(_odesolver.ODESolver):
Expand All @@ -28,9 +28,10 @@ class GaussianIVPFilter(_odesolver.ODESolver):
ODE measurement model.
with_smoothing
To smooth after the solve or not to smooth after the solve.
init_implementation :
Initialization algorithm. Either via Scipy (``initialize_odefilter_with_rk``) or via Taylor-mode AD (``initialize_odefilter_with_taylormode``).
For more convenient construction, consider :func:`GaussianIVPFilter.construct_with_rk_init` and :func:`GaussianIVPFilter.construct_with_taylormode_init`.
initialization_routine :
Initialization algorithm.
Either via fitting the prior to a few steps of a Runge-Kutta method (:class:`RungeKuttaInitialization`)
or via Taylor-mode automatic differentiation (:class:``TaylorModeInitialization``).
diffusion_model :
Diffusion model. This determines which kind of calibration is used. We refer to Bosch et al. (2020) [1]_ for a survey.
_reference_coordinates :
Expand All @@ -51,16 +52,7 @@ def __init__(
prior_process: randprocs.MarkovProcess,
measurement_model: statespace.DiscreteGaussian,
with_smoothing: bool,
init_implementation: Callable[
[
Callable,
np.ndarray,
float,
randprocs.MarkovProcess,
Optional[Callable],
],
randvars.Normal,
],
initialization_routine: initialization_routines.InitializationRoutine,
diffusion_model: Optional[statespace.Diffusion] = None,
_reference_coordinates: Optional[int] = 0,
):
Expand All @@ -74,7 +66,7 @@ def __init__(

self.sigma_squared_mle = 1.0
self.with_smoothing = with_smoothing
self.init_implementation = init_implementation
self.initialization_routine = initialization_routine
super().__init__(ivp=ivp, order=prior_process.transition.ordint)

# Set up the diffusion_model style: constant or piecewise constant.
Expand All @@ -98,75 +90,9 @@ def __init__(
# or from any other state.
self._reference_coordinates = _reference_coordinates

# Construct an ODE solver from different initialisation methods.
# The reason for implementing these via classmethods is that different
# initialisation methods require different parameters.

@classmethod
def construct_with_rk_init(
cls,
ivp,
prior_process,
measurement_model,
with_smoothing,
diffusion_model=None,
_reference_coordinates=0,
init_h0=0.01,
init_method="DOP853",
):
"""Create a Gaussian IVP filter that is initialised via
:func:`initialize_odefilter_with_rk`."""

def init_implementation(f, y0, t0, prior_process, df=None):
return initialize.initialize_odefilter_with_rk(
f=f,
y0=y0,
t0=t0,
prior_process=prior_process,
df=df,
h0=init_h0,
method=init_method,
)

return cls(
ivp,
prior_process,
measurement_model,
with_smoothing,
init_implementation=init_implementation,
diffusion_model=diffusion_model,
_reference_coordinates=_reference_coordinates,
)

@classmethod
def construct_with_taylormode_init(
cls,
ivp,
prior_process,
measurement_model,
with_smoothing,
diffusion_model=None,
_reference_coordinates=0,
):
"""Create a Gaussian IVP filter that is initialised via
:func:`initialize_odefilter_with_taylormode`."""
return cls(
ivp,
prior_process,
measurement_model,
with_smoothing,
init_implementation=initialize.initialize_odefilter_with_taylormode,
diffusion_model=diffusion_model,
_reference_coordinates=_reference_coordinates,
)

def initialise(self):
initrv = self.init_implementation(
self.ivp.f,
self.ivp.y0,
self.ivp.t0,
self.prior_process,
self.ivp.df,
initrv = self.initialization_routine(
ivp=self.ivp, prior_process=self.prior_process
)

return self.ivp.t0, initrv
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Initialisation procedures for ODE filters."""

from ._initialization_routine import InitializationRoutine
from ._runge_kutta import RungeKuttaInitialization
from ._taylor_mode import TaylorModeInitialization

__all__ = [
"InitializationRoutine",
"RungeKuttaInitialization",
"TaylorModeInitialization",
]


# Set correct module paths (for superclasses).
# Corrects links and module paths in documentation.
InitializationRoutine.__module__ = (
"probnum.diffeq.odefiltsmooth.initialization_routines"
)
RungeKuttaInitialization.__module__ = (
"probnum.diffeq.odefiltsmooth.initialization_routines"
)
TaylorModeInitialization.__module__ = (
"probnum.diffeq.odefiltsmooth.initialization_routines"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Interface for ODE filter initialization."""

import abc

from probnum import problems, randprocs, randvars


class InitializationRoutine(abc.ABC):
"""Interface for initialization routines for a filtering-based ODE solver.

One crucial factor for stable implementation of probabilistic ODE solvers is
starting with a good approximation of the derivatives of the initial condition [1]_.
(This is common in all Nordsieck-like ODE solvers.)
For this reason, efficient methods of initialization need to be devised.
All initialization routines in ProbNum implement the interface :class:`InitializationRoutine`.

References
----------
.. [1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers,
*arXiv:2012.10106*, 2020.
"""

def __init__(self, is_exact: bool, requires_jax: bool):
self._is_exact = is_exact
self._requires_jax = requires_jax

@abc.abstractmethod
def __call__(
self, ivp: problems.InitialValueProblem, prior_process: randprocs.MarkovProcess
) -> randvars.RandomVariable:
raise NotImplementedError

@property
def is_exact(self) -> bool:
"""Exactness of the computed initial values.

Some initialization routines yield the exact initial derivatives, some others
only yield approximations.
"""
return self._is_exact

@property
def requires_jax(self) -> bool:
"""Whether the implementation of the routine relies on JAX."""
return self._requires_jax
Loading