Skip to content

Commit

Permalink
Feat: Use numpyro NUTS backend
Browse files Browse the repository at this point in the history
  • Loading branch information
op3 committed Dec 5, 2023
1 parent e2128e7 commit 0bc07f1
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ A [No-U-Turn Sampler](https://arxiv.org/abs/1111.4246) is used to ensure fast co

* Python>=3.10
* [numpy](https://numpy.org/)
* [PyMC](https://docs.pymc.io/)>=4.0
* [PyMC](https://docs.pymc.io/)>=5.1
* [arviz](https://docs.pymc.io/)>=0.14
* [NumPyro](https://num.pyro.ai/) (*optional, for NumPyro backend, faster execution times!*)
* [uproot](https://github.com/scikit-hep/uproot4)>=4.1 (*optional, for reading and writing root files*)
* [h5py](https://www.h5py.org/) (*optional, for reading and writing hdf5 files*)
* [matplotlib](https://matplotlib.org/) (*optional, for plotting results*)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ arviz>=0.10.0
h5py
hist
numpy
pymc>=4.0.0
pymc>=5.1.1
tqdm
uproot>=4.1
numpyro
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ classifiers =
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Information Analysis
Topic :: Scientific/Engineering :: Physics
Expand All @@ -44,7 +46,7 @@ install_requires =
arviz>=0.10.0
hist
numpy
pymc>=4.0.0
pymc>=5.1.1
tqdm
uproot>=4.1.0
python_requires = >=3.10
Expand Down
12 changes: 8 additions & 4 deletions src/boris/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
import numpy as np
import pymc as pm

try:
from pytensor import tensor as pt
except ModuleNotFoundError: # pragma: no cover
from aesara import tensor as pt
from pytensor import tensor as pt
from arviz import InferenceData, waic

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -210,11 +207,18 @@ def fit(
start = {"incident": incident_start}
if background is not None:
start["background_incident"] = background_start

import importlib.util

numpyro_spec = importlib.util.find_spec("numpyro")
nuts_sampler = "pymc" if numpyro_spec is None else "numpyro"

trace = pm.sample(
ndraws,
step=step,
start=start,
idata_kwargs=dict(log_likelihood=True),
nuts_sampler=nuts_sampler,
**kwargs,
)

Expand Down
3 changes: 2 additions & 1 deletion src/boris/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import logging
import math
from collections import namedtuple
from pathlib import Path
from typing import Mapping, Literal
Expand All @@ -35,7 +36,7 @@

logger = logging.getLogger(__name__)

one_sigma = np.math.erf(np.sqrt(0.5))
one_sigma = math.erf(np.sqrt(0.5))


def get_bin_edges_from_calibration(
Expand Down

0 comments on commit 0bc07f1

Please sign in to comment.