diff --git a/README.md b/README.md index cdc66b1..dfde9b7 100644 --- a/README.md +++ b/README.md @@ -35,8 +35,9 @@ A [No-U-Turn Sampler](https://arxiv.org/abs/1111.4246) is used to ensure fast co * Python>=3.10 * [numpy](https://numpy.org/) -* [PyMC](https://docs.pymc.io/)>=4.0 +* [PyMC](https://docs.pymc.io/)>=5.1 * [arviz](https://docs.pymc.io/)>=0.14 +* [NumPyro](https://num.pyro.ai/) (*optional, for NumPyro backend, faster execution times!*) * [uproot](https://github.com/scikit-hep/uproot4)>=4.1 (*optional, for reading and writing root files*) * [h5py](https://www.h5py.org/) (*optional, for reading and writing hdf5 files*) * [matplotlib](https://matplotlib.org/) (*optional, for plotting results*) diff --git a/requirements.txt b/requirements.txt index d26609e..d426d0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ arviz>=0.10.0 h5py hist numpy -pymc>=4.0.0 +pymc>=5.1.1 tqdm uproot>=4.1 +numpyro diff --git a/setup.cfg b/setup.cfg index b939998..5d7597a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,8 @@ classifiers = Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 Topic :: Scientific/Engineering Topic :: Scientific/Engineering :: Information Analysis Topic :: Scientific/Engineering :: Physics @@ -44,7 +46,7 @@ install_requires = arviz>=0.10.0 hist numpy - pymc>=4.0.0 + pymc>=5.1.1 tqdm uproot>=4.1.0 python_requires = >=3.10 diff --git a/src/boris/core.py b/src/boris/core.py index afdbd63..45bc986 100644 --- a/src/boris/core.py +++ b/src/boris/core.py @@ -25,10 +25,7 @@ import numpy as np import pymc as pm -try: - from pytensor import tensor as pt -except ModuleNotFoundError: # pragma: no cover - from aesara import tensor as pt +from pytensor import tensor as pt from arviz import InferenceData, waic logger = logging.getLogger(__name__) @@ -210,11 +207,18 @@ def fit( start = {"incident": incident_start} if background is not None: start["background_incident"] = background_start + + import importlib.util + + numpyro_spec = importlib.util.find_spec("numpyro") + nuts_sampler = "pymc" if numpyro_spec is None else "numpyro" + trace = pm.sample( ndraws, step=step, start=start, idata_kwargs=dict(log_likelihood=True), + nuts_sampler=nuts_sampler, **kwargs, ) diff --git a/src/boris/utils.py b/src/boris/utils.py index ee2966b..0cf277f 100644 --- a/src/boris/utils.py +++ b/src/boris/utils.py @@ -22,6 +22,7 @@ from __future__ import annotations import logging +import math from collections import namedtuple from pathlib import Path from typing import Mapping, Literal @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) -one_sigma = np.math.erf(np.sqrt(0.5)) +one_sigma = math.erf(np.sqrt(0.5)) def get_bin_edges_from_calibration(