From 9c680596aa80efc38461d0c808ebafca4a24d118 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Fri, 9 Sep 2022 11:12:01 -0400 Subject: [PATCH] Split up test_ode.py --- .github/workflows/tests.yml | 7 ++-- pymc/tests/ode/__init__.py | 0 pymc/tests/{ => ode}/test_ode.py | 40 ----------------------- pymc/tests/ode/test_utils.py | 56 ++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 43 deletions(-) create mode 100644 pymc/tests/ode/__init__.py rename pymc/tests/{ => ode}/test_ode.py (92%) create mode 100644 pymc/tests/ode/test_utils.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1e8b823284f..d6a49cc1945 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -80,7 +80,8 @@ jobs: pymc/tests/gp/test_util.py pymc/tests/test_model.py pymc/tests/test_model_graph.py - pymc/tests/test_ode.py + pymc/tests/ode/test_ode.py + pymc/tests/ode/test_utils.py pymc/tests/test_profile.py pymc/tests/test_quadpotential.py @@ -151,7 +152,7 @@ jobs: test-subset: - pymc/tests/test_variational_inference.py pymc/tests/test_initial_point.py - pymc/tests/test_pickling.py pymc/tests/test_profile.py pymc/tests/test_step.py - - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/test_ode.py pymc/tests/test_smc.py pymc/tests/test_parallel_sampling.py + - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/test_smc.py pymc/tests/test_parallel_sampling.py - pymc/tests/test_sampling.py pymc/tests/test_posteriors.py fail-fast: false @@ -364,7 +365,7 @@ jobs: floatx: [float32] python-version: ["3.10"] test-subset: - - pymc/tests/test_sampling.py pymc/tests/test_ode.py + - pymc/tests/test_sampling.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/pymc/tests/ode/__init__.py b/pymc/tests/ode/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pymc/tests/test_ode.py b/pymc/tests/ode/test_ode.py similarity index 92% rename from pymc/tests/test_ode.py rename to pymc/tests/ode/test_ode.py index 96bbf22ee7b..8f90c677e9f 100644 --- a/pymc/tests/test_ode.py +++ b/pymc/tests/ode/test_ode.py @@ -18,54 +18,14 @@ import numpy as np import pytest -from scipy.integrate import odeint from scipy.stats import norm import pymc as pm from pymc.ode import DifferentialEquation -from pymc.ode.utils import augment_system from pymc.tests.helpers import fast_unstable_sampling_mode -def test_gradients(): - """Tests the computation of the sensitivities from the Aesara computation graph""" - - # ODE system for which to compute gradients - def ode_func(y, t, p): - return np.exp(-t) - p[0] * y[0] - - # Computation of graidients with Aesara - augmented_ode_func = augment_system(ode_func, 1, 1 + 1) - - # This is the new system, ODE + Sensitivities, which will be integrated - def augmented_system(Y, t, p): - dydt, ddt_dydp = augmented_ode_func(Y[:1], t, p, Y[1:]) - derivatives = np.concatenate([dydt, ddt_dydp]) - return derivatives - - # Create real sensitivities - y0 = 0.0 - t = np.arange(0, 12, 0.25).reshape(-1, 1) - a = 0.472 - p = np.array([y0, a]) - - # Derivatives of the analytic solution with respect to y0 and alpha - # Treat y0 like a parameter and solve analytically. Then differentiate. - # I used CAS to get these derivatives - y0_sensitivity = np.exp(-a * t) - a_sensitivity = ( - -(np.exp(t * (a - 1)) - 1 + (a - 1) * (y0 * a - y0 - 1) * t) * np.exp(-a * t) / (a - 1) ** 2 - ) - - sensitivity = np.c_[y0_sensitivity, a_sensitivity] - - integrated_solutions = odeint(func=augmented_system, y0=[y0, 1, 0], t=t.ravel(), args=(p,)) - simulated_sensitivity = integrated_solutions[:, 1:] - - np.testing.assert_allclose(sensitivity, simulated_sensitivity, rtol=1e-5) - - def test_simulate(): """Tests the integration in DifferentialEquation""" diff --git a/pymc/tests/ode/test_utils.py b/pymc/tests/ode/test_utils.py new file mode 100644 index 00000000000..9faf60e3b60 --- /dev/null +++ b/pymc/tests/ode/test_utils.py @@ -0,0 +1,56 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import scipy.integrate as ode + +from pymc.ode.utils import augment_system + + +def test_gradients(): + """Tests the computation of the sensitivities from the Aesara computation graph""" + + # ODE system for which to compute gradients + def ode_func(y, t, p): + return np.exp(-t) - p[0] * y[0] + + # Computation of graidients with Aesara + augmented_ode_func = augment_system(ode_func, 1, 1 + 1) + + # This is the new system, ODE + Sensitivities, which will be integrated + def augmented_system(Y, t, p): + dydt, ddt_dydp = augmented_ode_func(Y[:1], t, p, Y[1:]) + derivatives = np.concatenate([dydt, ddt_dydp]) + return derivatives + + # Create real sensitivities + y0 = 0.0 + t = np.arange(0, 12, 0.25).reshape(-1, 1) + a = 0.472 + p = np.array([y0, a]) + + # Derivatives of the analytic solution with respect to y0 and alpha + # Treat y0 like a parameter and solve analytically. Then differentiate. + # I used CAS to get these derivatives + y0_sensitivity = np.exp(-a * t) + a_sensitivity = ( + -(np.exp(t * (a - 1)) - 1 + (a - 1) * (y0 * a - y0 - 1) * t) * np.exp(-a * t) / (a - 1) ** 2 + ) + + sensitivity = np.c_[y0_sensitivity, a_sensitivity] + + integrated_solutions = ode.odeint(func=augmented_system, y0=[y0, 1, 0], t=t.ravel(), args=(p,)) + simulated_sensitivity = integrated_solutions[:, 1:] + + np.testing.assert_allclose(sensitivity, simulated_sensitivity, rtol=1e-5)