diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 6f6e6a41..3b8a57e6 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -41,7 +41,7 @@ jobs:
           echo "Checking import and version number (on release)"
           venv-bdist/bin/python -c "import pymc_extras as pmx; assert pmx.__version__ == '${{ github.ref_name }}'[1:] if '${{ github.ref_type }}' == 'tag' else pmx.__version__; print(pmx.__version__)"
           cd ..
-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
           name: artifact
           path: dist/*
@@ -58,7 +58,7 @@ jobs:
       # write id-token is necessary for trusted publishing (OIDC)
       id-token: write
     steps:
-      - uses: actions/download-artifact@v3
+      - uses: actions/download-artifact@v4
         with:
           name: artifact
           path: dist
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index ad60a8da..a1aa85c7 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -3,14 +3,15 @@ channels:
 - conda-forge
 - nodefaults
 dependencies:
-- pymc>=5.19.1
+- pymc>=5.20
 - pytest-cov>=2.5
 - pytest>=3.0
 - dask
 - xhistogram
 - statsmodels
+- numba<=0.60.0
 - pip
 - pip:
   - blackjax
   - scikit-learn
-  - better_optimize>=0.0.10
+  - better_optimize
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index 8814c696..1d1eb774 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -9,8 +9,9 @@ dependencies:
 - dask
 - xhistogram
 - statsmodels
+- numba<=0.60.0
+- pymc>=5.20
 - pip:
-  - pymc>=5.19.1 # CI was failing to resolve
   - blackjax
   - scikit-learn
-  - better_optimize>=0.0.10
+  - better_optimize
diff --git a/notebooks/Exponential Trend Smoothing.ipynb b/notebooks/Exponential Trend Smoothing.ipynb
index f86f7613..fff68327 100644
--- a/notebooks/Exponential Trend Smoothing.ipynb
+++ b/notebooks/Exponential Trend Smoothing.ipynb
@@ -188,7 +188,7 @@
     "\n",
     "    # For the forecasts we need a function that lets us take draws from the distribution. We'll get the mean\n",
     "    # and covariance from samples by calling it a lot of times.\n",
-    "    f_forecast = pm.compile_pymc(pm.inputvars(obs_forecast), obs_forecast, mode=\"JAX\")\n",
+    "    f_forecast = pm.compile(pm.inputvars(obs_forecast), obs_forecast, mode=\"JAX\")\n",
     "\n",
     "    return f_ets, f_forecast\n",
     "\n",
@@ -863,17 +863,17 @@
     "\n"
    ],
    "text/plain": [
-    "\u001b[3m Model Requirements \u001b[0m\n",
+    "\u001B[3m Model Requirements \u001B[0m\n",
     " \n",
-    " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape\u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mDimensions\u001b[0m\u001b[1m \u001b[0m \n",
+    " \u001B[1m \u001B[0m\u001B[1mVariable \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mShape\u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mConstraints \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mDimensions\u001B[0m\u001B[1m \u001B[0m \n",
     " ──────────────────────────────────────────────────── \n",
-    " initial_level \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " alpha \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < alpha < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " sigma_state \u001b[3;35mNone\u001b[0m Positive \u001b[3;35mNone\u001b[0m \n",
+    " initial_level \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " alpha \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < alpha < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " sigma_state \u001B[3;35mNone\u001B[0m Positive \u001B[3;35mNone\u001B[0m \n",
     " \n",
-    "\u001b[2;3m These parameters should be assigned priors inside a \u001b[0m\n",
-    "\u001b[2;3m PyMC model block before calling the \u001b[0m\n",
-    "\u001b[2;3m build_statespace_graph method. \u001b[0m\n"
+    "\u001B[2;3m These parameters should be assigned priors inside a \u001B[0m\n",
+    "\u001B[2;3m PyMC model block before calling the \u001B[0m\n",
+    "\u001B[2;3m build_statespace_graph method. \u001B[0m\n"
    ]
   },
   "metadata": {},
@@ -1394,19 +1394,19 @@
     "\n"
    ],
    "text/plain": [
-    "\u001b[3m Model Requirements \u001b[0m\n",
+    "\u001B[3m Model Requirements \u001B[0m\n",
     " \n",
-    " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape\u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mDimensions\u001b[0m\u001b[1m \u001b[0m \n",
+    " \u001B[1m \u001B[0m\u001B[1mVariable \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mShape\u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mConstraints \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mDimensions\u001B[0m\u001B[1m \u001B[0m \n",
     " ──────────────────────────────────────────────────── \n",
-    " initial_level \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " initial_trend \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " alpha \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < alpha < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " beta \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < beta < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " sigma_state \u001b[3;35mNone\u001b[0m Positive \u001b[3;35mNone\u001b[0m \n",
+    " initial_level \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " initial_trend \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " alpha \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < alpha < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " beta \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < beta < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " sigma_state \u001B[3;35mNone\u001B[0m Positive \u001B[3;35mNone\u001B[0m \n",
     " \n",
-    "\u001b[2;3m These parameters should be assigned priors inside a \u001b[0m\n",
-    "\u001b[2;3m PyMC model block before calling the \u001b[0m\n",
-    "\u001b[2;3m build_statespace_graph method. \u001b[0m\n"
+    "\u001B[2;3m These parameters should be assigned priors inside a \u001B[0m\n",
+    "\u001B[2;3m PyMC model block before calling the \u001B[0m\n",
+    "\u001B[2;3m build_statespace_graph method. \u001B[0m\n"
    ]
   },
   "metadata": {},
@@ -2044,20 +2044,20 @@
     "\n"
    ],
    "text/plain": [
-    "\u001b[3m Model Requirements \u001b[0m\n",
+    "\u001B[3m Model Requirements \u001B[0m\n",
     " \n",
-    " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape\u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mDimensions\u001b[0m\u001b[1m \u001b[0m \n",
+    " \u001B[1m \u001B[0m\u001B[1mVariable \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mShape\u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mConstraints \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mDimensions\u001B[0m\u001B[1m \u001B[0m \n",
     " ──────────────────────────────────────────────────── \n",
-    " initial_level \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " initial_trend \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " alpha \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < alpha < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " beta \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < beta < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " phi \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < phi < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " sigma_state \u001b[3;35mNone\u001b[0m Positive \u001b[3;35mNone\u001b[0m \n",
+    " initial_level \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " initial_trend \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " alpha \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < alpha < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " beta \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < beta < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " phi \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < phi < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " sigma_state \u001B[3;35mNone\u001B[0m Positive \u001B[3;35mNone\u001B[0m \n",
     " \n",
-    "\u001b[2;3m These parameters should be assigned priors inside a \u001b[0m\n",
-    "\u001b[2;3m PyMC model block before calling the \u001b[0m\n",
-    "\u001b[2;3m build_statespace_graph method. \u001b[0m\n"
+    "\u001B[2;3m These parameters should be assigned priors inside a \u001B[0m\n",
+    "\u001B[2;3m PyMC model block before calling the \u001B[0m\n",
+    "\u001B[2;3m build_statespace_graph method. \u001B[0m\n"
    ]
   },
   "metadata": {},
@@ -2664,19 +2664,19 @@
     "\n"
    ],
    "text/plain": [
-    "\u001b[3m Model Requirements \u001b[0m\n",
+    "\u001B[3m Model Requirements \u001B[0m\n",
     " \n",
-    " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n",
+    " \u001B[1m \u001B[0m\u001B[1mVariable \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mShape \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mConstraints \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1m Dimensions\u001B[0m\u001B[1m \u001B[0m \n",
     " ──────────────────────────────────────────────────────────────────────────────────────────── \n",
-    " initial_level \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m,\u001b[1m)\u001b[0m \n",
-    " initial_trend \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m,\u001b[1m)\u001b[0m \n",
-    " alpha \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m,\u001b[1m)\u001b[0m \u001b[1;36m0\u001b[0m < alpha < \u001b[1;36m1\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m,\u001b[1m)\u001b[0m \n",
-    " beta \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m,\u001b[1m)\u001b[0m \u001b[1;36m0\u001b[0m < beta < \u001b[1;36m1\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m,\u001b[1m)\u001b[0m \n",
-    " phi \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m,\u001b[1m)\u001b[0m \u001b[1;36m0\u001b[0m < phi < \u001b[1;36m1\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m,\u001b[1m)\u001b[0m \n",
-    " state_cov \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m2\u001b[0m\u001b[1m)\u001b[0m Positive Semi-definite \u001b[1m(\u001b[0m\u001b[32m'observed_state'\u001b[0m, \u001b[32m'observed_state_aux'\u001b[0m\u001b[1m)\u001b[0m \n",
+    " initial_level \u001B[1m(\u001B[0m\u001B[1;36m2\u001B[0m,\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m,\u001B[1m)\u001B[0m \n",
+    " initial_trend \u001B[1m(\u001B[0m\u001B[1;36m2\u001B[0m,\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m,\u001B[1m)\u001B[0m \n",
+    " alpha \u001B[1m(\u001B[0m\u001B[1;36m2\u001B[0m,\u001B[1m)\u001B[0m \u001B[1;36m0\u001B[0m < alpha < \u001B[1;36m1\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m,\u001B[1m)\u001B[0m \n",
+    " beta \u001B[1m(\u001B[0m\u001B[1;36m2\u001B[0m,\u001B[1m)\u001B[0m \u001B[1;36m0\u001B[0m < beta < \u001B[1;36m1\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m,\u001B[1m)\u001B[0m \n",
+    " phi \u001B[1m(\u001B[0m\u001B[1;36m2\u001B[0m,\u001B[1m)\u001B[0m \u001B[1;36m0\u001B[0m < phi < \u001B[1;36m1\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m,\u001B[1m)\u001B[0m \n",
+    " state_cov \u001B[1m(\u001B[0m\u001B[1;36m2\u001B[0m, \u001B[1;36m2\u001B[0m\u001B[1m)\u001B[0m Positive Semi-definite \u001B[1m(\u001B[0m\u001B[32m'observed_state'\u001B[0m, \u001B[32m'observed_state_aux'\u001B[0m\u001B[1m)\u001B[0m \n",
     " \n",
-    "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block before calling the \u001b[0m\n",
-    "\u001b[2;3m build_statespace_graph method. \u001b[0m\n"
+    "\u001B[2;3m These parameters should be assigned priors inside a PyMC model block before calling the \u001B[0m\n",
+    "\u001B[2;3m build_statespace_graph method. \u001B[0m\n"
    ]
   },
   "metadata": {},
@@ -3633,21 +3633,21 @@
     "\n"
    ],
    "text/plain": [
-    "\u001b[3m Model Requirements \u001b[0m\n",
+    "\u001B[3m Model Requirements \u001B[0m\n",
     " \n",
-    " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape\u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n",
+    " \u001B[1m \u001B[0m\u001B[1mVariable \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mShape\u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1mConstraints \u001B[0m\u001B[1m \u001B[0m \u001B[1m \u001B[0m\u001B[1m Dimensions\u001B[0m\u001B[1m \u001B[0m \n",
     " ────────────────────────────────────────────────────────────── \n",
-    " initial_level \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " initial_trend \u001b[3;35mNone\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " initial_seasonal \u001b[1m(\u001b[0m\u001b[1;36m12\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'seasonal_lag'\u001b[0m,\u001b[1m)\u001b[0m \n",
-    " alpha \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < alpha < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " beta \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < beta < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " gamma \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < gamma< \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " phi \u001b[3;35mNone\u001b[0m \u001b[1;36m0\u001b[0m < phi < \u001b[1;36m1\u001b[0m \u001b[3;35mNone\u001b[0m \n",
-    " sigma_state \u001b[3;35mNone\u001b[0m Positive \u001b[3;35mNone\u001b[0m \n",
+    " initial_level \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " initial_trend \u001B[3;35mNone\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " initial_seasonal \u001B[1m(\u001B[0m\u001B[1;36m12\u001B[0m,\u001B[1m)\u001B[0m \u001B[1m(\u001B[0m\u001B[32m'seasonal_lag'\u001B[0m,\u001B[1m)\u001B[0m \n",
+    " alpha \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < alpha < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " beta \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < beta < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " gamma \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < gamma< \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " phi \u001B[3;35mNone\u001B[0m \u001B[1;36m0\u001B[0m < phi < \u001B[1;36m1\u001B[0m \u001B[3;35mNone\u001B[0m \n",
+    " sigma_state \u001B[3;35mNone\u001B[0m Positive \u001B[3;35mNone\u001B[0m \n",
     " \n",
-    "\u001b[2;3m These parameters should be assigned priors inside a PyMC model \u001b[0m\n",
-    "\u001b[2;3m block before calling the build_statespace_graph method. \u001b[0m\n"
+    "\u001B[2;3m These parameters should be assigned priors inside a PyMC model \u001B[0m\n",
+    "\u001B[2;3m block before calling the build_statespace_graph method. \u001B[0m\n"
    ]
   },
   "metadata": {},
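The pm.compile_pymc call sites updated throughout this patch track the rename to pm.compile in recent PyMC (the environment files above now require pymc>=5.20 accordingly). As a quick orientation, a minimal sketch of the renamed helper; the scalar graph here is illustrative and not taken from the patch:

import pymc as pm
import pytensor.tensor as pt

# pm.compile accepts the same (inputs, outputs, **kwargs) arguments that
# pm.compile_pymc did; only the public name changed.
x = pt.dscalar("x")
y = 2 * x + 1
f = pm.compile([x], y)
assert f(3.0) == 7.0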
diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py
index be5b6c48..5ef7a58e 100644
--- a/pymc_extras/__init__.py
+++ b/pymc_extras/__init__.py
@@ -15,7 +15,9 @@
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/find_map.py
index 72ce3b19..063f6ce9 100644
--- a/pymc_extras/inference/find_map.py
+++ b/pymc_extras/inference/find_map.py
@@ -1,9 +1,9 @@
 import logging
 
 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args
 
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -30,13 +30,29 @@ def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()
 
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False
 
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0
 
     return eigvec @ np.diag(eigval) @ eigvec.T
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)
 
 
-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None
 
@@ -152,7 +170,7 @@ def f_hess_jax(x):
 
     return f_loss_and_grad, f_hess, f_hessp
 
 
-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -193,19 +211,19 @@
 
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)
 
     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
     return [f_loss_and_grad, f_hess, f_hessp]
@@ -240,7 +258,7 @@
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -265,6 +283,8 @@
     )
 
     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
 
     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
 
-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@
 
     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
 
     return f_loss, f_hess, f_hessp
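One behavioral fix is buried among the renames above: get_nearest_psd now diagonalizes with np.linalg.eigh rather than np.linalg.eig. Because C = (A + A.T) / 2 is symmetric by construction, eigh guarantees real eigenvalues and orthonormal eigenvectors, while eig can leak spurious complex parts through rounding. A standalone sketch mirroring the patched function; the demo matrix is illustrative:

import numpy as np

def nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, then clip negative eigenvalues to zero (as in get_nearest_psd).
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # real-valued by construction
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [0.0, 1.0]])
print(np.linalg.eigvalsh(nearest_psd(A)))  # all eigenvalues >= 0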
diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py
index 1cfe0413..c6c9fd95 100644
--- a/pymc_extras/inference/laplace.py
+++ b/pymc_extras/inference/laplace.py
@@ -16,6 +16,7 @@
 import logging
 
 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal
 
@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata
 
 
-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(
 
     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +390,7 @@
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )
 
-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)
 
     else:
@@ -472,15 +476,17 @@
         and 1).
 
         .. warning::
-            This argumnet should be considered highly experimental. It has not been verified if this method produces
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.
 
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of sampling chains running in parallel.
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. If 'ignore' or 'warn', the closest
         positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,11 +553,12 @@
         **optimizer_kwargs,
     )
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,
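Taken together, find_MAP, fit_mvn_at_MAP, and sample_laplace_posterior form the pipeline behind fit_laplace. A hedged usage sketch modeled on the tests further down; the toy model and data are illustrative, not part of the patch:

import numpy as np
import pymc as pm
import pymc_extras as pmx

y = np.random.default_rng(0).normal(loc=3.0, scale=1.5, size=100)

with pm.Model():
    mu = pm.Normal("mu", mu=0, sigma=10)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # Optimize to the MAP, fit an MVN there, then draw chains * draws samples.
    # No MCMC runs; `chains` only shapes the output for ArviZ compatibility.
    idata = pmx.fit(
        method="laplace",
        optimize_method="BFGS",
        chains=2,
        draws=500,
        gradient_backend="pytensor",
        random_seed=0,
    )

print(idata.posterior["mu"].shape)  # (2, 500): one entry per chain and draw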
diff --git a/pymc_extras/model/marginal/marginal_model.py b/pymc_extras/model/marginal/marginal_model.py
index 178847a2..b6ca25bf 100644
--- a/pymc_extras/model/marginal/marginal_model.py
+++ b/pymc_extras/model/marginal/marginal_model.py
@@ -19,7 +19,8 @@
     model_free_rv,
     model_from_fgraph,
 )
-from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
+from pymc.pytensorf import compile as compile_pymc
 from pymc.util import RandomState, _get_seeds_per_chain
 from pytensor import In, Out
 from pytensor.compile import SharedVariable
diff --git a/pymc_extras/statespace/core/compile.py b/pymc_extras/statespace/core/compile.py
index a9e13ee2..b6641ed7 100644
--- a/pymc_extras/statespace/core/compile.py
+++ b/pymc_extras/statespace/core/compile.py
@@ -30,7 +30,7 @@ def compile_statespace(
 
     inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
 
-    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
+    _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
 
     def f(*, draws=1, **params):
         if isinstance(steps, pt.Variable):
diff --git a/requirements.txt b/requirements.txt
index 1051d2ee..a4f00ee2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
-pymc>=5.19.1
+pymc>=5.20
 scikit-learn
+better-optimize
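The marginal_model.py hunk above keeps the module's internal compile_pymc name alive by aliasing the new PyMC function, so none of its call sites need to change; the pattern in isolation:

# Bind the renamed pymc helper to the identifier the module already uses.
from pymc.pytensorf import compile as compile_pymc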
diff --git a/tests/test_find_map.py b/tests/test_find_map.py
index c762ba55..34c8fc76 100644
--- a/tests/test_find_map.py
+++ b/tests/test_find_map.py
@@ -54,24 +54,28 @@ def compute_z(x):
 
 
 @pytest.mark.parametrize(
-    "method, use_grad, use_hess",
+    "method, use_grad, use_hess, use_hessp",
     [
-        ("nelder-mead", False, False),
-        ("powell", False, False),
-        ("CG", True, False),
-        ("BFGS", True, False),
-        ("L-BFGS-B", True, False),
-        ("TNC", True, False),
-        ("SLSQP", True, False),
-        ("dogleg", True, True),
-        ("trust-ncg", True, True),
-        ("trust-exact", True, True),
-        ("trust-krylov", True, True),
-        ("trust-constr", True, True),
+        ("nelder-mead", False, False, False),
+        ("powell", False, False, False),
+        ("CG", True, False, False),
+        ("BFGS", True, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("TNC", True, False, False),
+        ("SLSQP", True, False, False),
+        ("dogleg", True, True, False),
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+        ("trust-ncg", True, True, False),
+        ("trust-ncg", True, False, True),
+        ("trust-exact", True, True, False),
+        ("trust-krylov", True, True, False),
+        ("trust-krylov", True, False, True),
+        ("trust-constr", True, True, False),
     ],
 )
 @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
+def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
         **extra_kwargs,
         use_grad=use_grad,
         use_hess=use_hess,
+        use_hessp=use_hessp,
         progressbar=False,
         gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},
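The widened parametrization above adds use_hessp cases because SciPy's Newton-CG and trust-region methods accept either a full Hessian (hess) or a Hessian-vector product (hessp), and the optimizer should reach the same answer either way. A minimal SciPy sketch of that equivalence; the quadratic objective is illustrative:

import numpy as np
from scipy.optimize import minimize

def f(x):
    return (x**2).sum()

def grad(x):
    return 2 * x

def hess(x):
    return 2 * np.eye(x.size)  # materializes an n x n matrix

def hessp(x, p):
    return 2 * p  # Hessian-vector product: never builds the full matrix

x0 = np.array([1.0, -2.0])
res_hess = minimize(f, x0, jac=grad, hess=hess, method="Newton-CG")
res_hessp = minimize(f, x0, jac=grad, hessp=hessp, method="Newton-CG")
assert np.allclose(res_hess.x, res_hessp.x, atol=1e-6)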
diff --git a/tests/test_laplace.py b/tests/test_laplace.py
index 50cdd9ad..18214da6 100644
--- a/tests/test_laplace.py
+++ b/tests/test_laplace.py
@@ -19,10 +19,10 @@
 
 import pymc_extras as pmx
 
-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
-    fit_mvn_to_MAP,
+    fit_mvn_at_MAP,
     sample_laplace_posterior,
 )
 
@@ -37,7 +37,11 @@ def rng():
     "ignore:hessian will stop negating the output in a future version of PyMC.\n"
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
-def test_laplace():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -55,7 +59,13 @@ def test_laplace():
         vars = [mu, logsigma]
 
         idata = pmx.fit(
-            method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
+            method="laplace",
+            optimize_method="trust-ncg",
+            draws=draws,
+            random_seed=173300,
+            chains=1,
+            compile_kwargs={"mode": mode},
+            gradient_backend=gradient_backend,
         )
 
     assert idata.posterior["mu"].shape == (1, draws)
@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
 
 
-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
         method="laplace",
         optimize_method="BFGS",
         progressbar=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
         optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
         random_seed=173300,
     )
@@ -111,8 +125,11 @@
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
 
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
             use_hessp=True,
             progressbar=False,
             compile_kwargs=dict(mode=mode),
-            gradient_backend="jax" if mode == "JAX" else "pytensor",
+            gradient_backend=gradient_backend,
         )
 
     for value in optimized_point.values():
         assert value.shape == (3,)
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         transform_samples=transform_samples,
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]
 
 
-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
 
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
@@ -188,8 +209,8 @@
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )
 
     assert idata["posterior"].beta.shape[-2:] == (3, 2)
@@ -206,7 +227,11 @@
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
             use_hessp=True,
             fit_in_unconstrained_space=fit_in_unconstrained_space,
             optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+            compile_kwargs={"mode": mode},
+            gradient_backend=gradient_backend,
         )
 
     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)