diff --git a/.gitignore b/.gitignore index fa799ee0..fc4f0723 100755 --- a/.gitignore +++ b/.gitignore @@ -342,3 +342,4 @@ replay_pid* .DS_Store /.quarto/ +*_files diff --git a/.pre-commit-rst-placeholder.sh b/.pre-commit-rst-placeholder.sh old mode 100644 new mode 100755 diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 00000000..567609b1 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst index b8a35815..d6f0e0f4 100644 --- a/docs/source/tutorials/index.rst +++ b/docs/source/tutorials/index.rst @@ -9,3 +9,4 @@ This section contains tutorials that demonstrate how to use the `pyrenew` packag getting-started example-with-datasets pyrenew_demo + extending_pyrenew diff --git a/model/.gitignore b/model/.gitignore new file mode 100644 index 00000000..691e2357 --- /dev/null +++ b/model/.gitignore @@ -0,0 +1,2 @@ +/.quarto/ +_compiled diff --git a/model/Makefile b/model/Makefile index c80166ab..6fa9d67a 100644 --- a/model/Makefile +++ b/model/Makefile @@ -26,7 +26,7 @@ test: poetry run pytest --mpl --mpl-default-tolerance=10 docs: docs/pyrenew_demo.md docs/getting-started.md \ - docs/example-with-datasets.md + docs/example-with-datasets.md docs/extending_pyrenew.md docs/pyrenew_demo.md: docs/pyrenew_demo.qmd poetry run quarto render docs/pyrenew_demo.qmd @@ -37,16 +37,20 @@ docs/getting-started.md: docs/getting-started.qmd docs/example-with-datasets.md: docs/example-with-datasets.qmd poetry run quarto render docs/example-with-datasets.qmd +docs/extending_pyrenew.md: docs/extending_pyrenew.qmd + poetry run quarto render docs/extending_pyrenew.qmd + docs/py: docs/notebooks - jupyter nbconvert --to python docs/pyrenew_demo.ipynb - jupyter nbconvert --to python docs/getting-started.ipynb - jupyter nbconvert --to python docs/example-with-datasets.ipynb + for i in docs/*.ipynb; do \ + jupyter nbconvert --to python $$i ; \ + done docs/notebooks: - quarto convert docs/pyrenew_demo.qmd --output docs/pyrenew_demo.ipynb - quarto convert docs/getting-started.qmd --output docs/getting-started.ipynb - quarto convert docs/example-with-datasets.qmd --output \ - docs/example-with-datasets.ipynb + for i in docs/*.qmd; do \ + if [ $$i -nt $$(basename $$i .qmd).ipynb ]; then \ + quarto convert $$i --output docs/$$(basename $$i .qmd).ipynb ; \ + fi \ + done test_images: echo "Generating reference images for tests" diff --git a/model/docs/.gitignore b/model/docs/.gitignore index 6930a0ef..4234ad15 100644 --- a/model/docs/.gitignore +++ b/model/docs/.gitignore @@ -5,3 +5,6 @@ /.quarto/ _compiled + +*.ipynb +*.py diff --git a/model/docs/extending_pyrenew.qmd b/model/docs/extending_pyrenew.qmd new file mode 100644 index 00000000..2076d054 --- /dev/null +++ b/model/docs/extending_pyrenew.qmd @@ -0,0 +1,256 @@ +--- +title: Extending pyrenew +format: gfm +--- + +This tutorial illustrates how to extend `pyrenew` with custom `RandomVariable` classes. We will use the `InfectionsWithFeedback` class as an example. The `InfectionsWithFeedback` class is a `RandomVariable` that models the number of infections at time $t$ as a function of the number of infections at time $t - \tau$ and the reproduction number at time $t$. The reproduction number at time $t$ is a function of the *unadjusted* reproduction number at time $t - \tau$ and the number of infections at time $t - \tau$: + +$$ +\begin{align*} +I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\ +\mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(-\gamma(t)\sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right) +\end{align*} +$$ + +Where $\mathcal{R}^u(t)$ is the unadjusted reproduction number, $g(t)$ is the generation interval, $\gamma(t)$ is the infection feedback strength, and $f(t)$ is the infection feedback pmf. + +## The expected outcome + +Before we start, let's simulate the model with the original `InfectionsWithFeedback` class. To keep it simple, we will simulate the model with no observation process, in other words, only with latent infections. The following code-chunk loads the required libraries and defines the model components: + +```{python} +#| label: setup +import jax +import jax.numpy as jnp +import numpy as np +import numpyro as npro +import numpyro.distributions as dist +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.latent import InfectionsWithFeedback +from pyrenew.model import RtInfectionsRenewalModel +from pyrenew.process import RtRandomWalkProcess +from pyrenew.metaclass import DistributionalRV +``` + +The following code-chunk defines the model components. Notice that for both the generation interval and the infection feedback, we use a deterministic PMF with equal probabilities: + +```{python} +#| label: model-components +gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1])) +feedback_strength = DeterministicVariable(0.05) + +I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + +latent_infections = InfectionsWithFeedback( + infection_feedback_strength = feedback_strength, + infection_feedback_pmf = gen_int, +) + +rt = RtRandomWalkProcess() +``` + +With all the components defined, we can build the model: + +```{python} +#| label: build1 +model0 = RtInfectionsRenewalModel( + gen_int=gen_int, + I0=I0, + latent_infections=latent_infections, + Rt_process=rt, + observation_process=None, +) +``` + +And simulate it from: + +```{python} +#| label: simulate1 +# Sampling and fitting model 0 (with no obs for infections) +np.random.seed(223) +with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + model0_samp = model0.sample(n_timepoints=30) +``` + +```{python} +#| label: fig-simulate1 +#| fig-cap: Simulated infections with no observation process +import matplotlib.pyplot as plt +fig, ax = plt.subplots() +ax.plot(model0_samp.latent_infections) +ax.set_xlabel("Time") +ax.set_ylabel("Infections") +plt.show() +``` + +## Pyrenew's random variable class + +### Fundamentals + +All instances of PyRenew's `RandomVariable` should have at least three functions: `__init__()`, `validate()`, and `sample()`. The `__init__()` function is the constructor and initializes the class. The `validate()` function checks if the class is correctly initialized. Finally, the `sample()` method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a `RandomVariable` class based on `numpyro.distributions.Normal`: + +```python +from pyrenew.metaclass import RandomVariable + +class MyNormal(RandomVariable): + def __init__(self, loc, scale): + self.validate(scale) + self.loc = loc + self.scale = scale + return None + + @staticmethod + def validate(self): + if self.scale <= 0: + raise ValueError("Scale must be positive") + return None + + def sample(self, **kwargs): + return (dist.Normal(loc=self.loc, scale=self.scale),) +``` + +The `@staticmethod` decorator exposes the `validate` function to be used outside the class. Next, we show how to build a more complex `RandomVariable` class; the `InfectionsWithFeedback` class. + +### The `InfectionsWithFeedback` class + +Although returning namedtuples is not strictly required, they are the recommended return type, as they make the code more readable. The following code-chunk shows how to create a named tuple for the `InfectionsWithFeedback` class: + +```{python} +#| label: data-class +from collections import namedtuple + +# Creating a tuple to store the output +InfFeedbackSample = namedtuple( + typename='InfFeedbackSample', + field_names=['infections', 'rt'], + defaults=(None, None) +) +``` + +The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.datautils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: + +```{python} +#| label: new-model-def +#| code-line-numbers: true +# Creating the class +from pyrenew.metaclass import RandomVariable +from pyrenew.latent import compute_infections_from_rt_with_feedback +from pyrenew import datautils as du +from jax.typing import ArrayLike +import jax.numpy as jnp + +class InfFeedback(RandomVariable): + """Latent infections""" + + def __init__( + self, + infection_feedback_strength: RandomVariable, + infection_feedback_pmf: RandomVariable, + infections_mean_varname: str = "latent_infections", + ) -> None: + """Constructor""" + + self.infection_feedback_strength = infection_feedback_strength + self.infection_feedback_pmf = infection_feedback_pmf + self.infections_mean_varname = infections_mean_varname + + return None + + def validate(self): + """ + Generally, this method should be more meaningful, but we will skip it for now + """ + return None + + def sample( + self, + Rt: ArrayLike, + I0: ArrayLike, + gen_int: ArrayLike, + **kwargs, + ) -> tuple: + """Sample infections with feedback""" + + # Generation interval + gen_int_rev = jnp.flip(gen_int) + + # Baseline infections + I0_vec = I0[-gen_int_rev.size :] + + # Sampling inf feedback strength and adjusting the shape + inf_feedback_strength, *_ = self.infection_feedback_strength.sample( + **kwargs, + ) + inf_feedback_strength = du.pad_x_to_match_y( + x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0] + ) + + # Sampling inf feedback and adjusting the shape + inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) + + # Generating the infections with feedback + all_infections, Rt_adj = compute_infections_from_rt_with_feedback( + I0=I0_vec, + Rt_raw=Rt, + infection_feedback_strength=inf_feedback_strength, + reversed_generation_interval_pmf=gen_int_rev, + reversed_infection_feedback_pmf=inf_fb_pmf_rev, + ) + + # Storing adjusted Rt for future use + npro.deterministic("Rt_adjusted", Rt_adj) + + # Preparing theoutput + + return InfFeedbackSample( + infections=all_infections, + rt=Rt_adj, + ) +``` + +The core of the class is implemented in the `sample()` method. Things to highlight from the above code: + +1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback.sample()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable.sample()` calls are expected to include the `**kwargs` argument, even if unused. + +2. **Calls to `RandomVariable.sample()`**: All calls to `RandomVariable.sample()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength.sample()` and `infection_feedback_pmf.sample()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`). + +3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later. + +4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`. + +```{python} +#| label: simulation2 +latent_infections2 = InfFeedback( + infection_feedback_strength = feedback_strength, + infection_feedback_pmf = gen_int, +) + +model1 = RtInfectionsRenewalModel( + gen_int=gen_int, + I0=I0, + latent_infections=latent_infections2, + Rt_process=rt, + observation_process=None, +) + +# Sampling and fitting model 0 (with no obs for infections) +np.random.seed(223) +with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + model1_samp = model1.sample(n_timepoints=30) +``` + +Comparing `model0` with `model1`, these two should match: + +```{python} +#| label: fig-model0-vs-model1 +#| fig-cap: Comparing latent infections from model 0 and model 1 +import matplotlib.pyplot as plt +fig, ax = plt.subplots(ncols=2) +ax[0].plot(model0_samp.latent_infections) +ax[1].plot(model1_samp.latent_infections) +ax[0].set_xlabel("Time (model 0)") +ax[1].set_xlabel("Time (model 1)") +ax[0].set_ylabel("Infections") +plt.show() +``` diff --git a/model/poetry.lock b/model/poetry.lock index 558785b9..57c77dbd 100644 --- a/model/poetry.lock +++ b/model/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "alabaster" @@ -73,6 +73,45 @@ files = [ [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] +[[package]] +name = "beautifulsoup4" +version = "4.12.3" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, + {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, +] + +[package.dependencies] +soupsieve = ">1.2" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + +[[package]] +name = "bleach" +version = "6.1.0" +description = "An easy safelist-based HTML-sanitizing tool." +optional = false +python-versions = ">=3.8" +files = [ + {file = "bleach-6.1.0-py3-none-any.whl", hash = "sha256:3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6"}, + {file = "bleach-6.1.0.tar.gz", hash = "sha256:0a31f1837963c41d46bbf1331b8778e1308ea0791db03cc4e7357b97cf42a8fe"}, +] + +[package.dependencies] +six = ">=1.9.0" +webencodings = "*" + +[package.extras] +css = ["tinycss2 (>=1.1.0,<1.3)"] + [[package]] name = "certifi" version = "2024.2.2" @@ -462,6 +501,17 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "defusedxml" +version = "0.7.1" +description = "XML bomb protection for Python stdlib modules" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, + {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, +] + [[package]] name = "docutils" version = "0.21.2" @@ -686,13 +736,13 @@ test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "num [[package]] name = "jax" -version = "0.4.27" +version = "0.4.28" description = "Differentiate, compile, and transform Numpy code." optional = false python-versions = ">=3.9" files = [ - {file = "jax-0.4.27-py3-none-any.whl", hash = "sha256:02cafc7310d0b89bead77a1559b719fbfa84c0e6683715b4941e04487a6d377e"}, - {file = "jax-0.4.27.tar.gz", hash = "sha256:f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c"}, + {file = "jax-0.4.28-py3-none-any.whl", hash = "sha256:6a181e6b5a5b1140e19cdd2d5c4aa779e4cb4ec627757b918be322d8e81035ba"}, + {file = "jax-0.4.28.tar.gz", hash = "sha256:dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9"}, ] [package.dependencies] @@ -710,43 +760,43 @@ scipy = [ [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.26)"] -cpu = ["jaxlib (==0.4.27)"] -cuda = ["jaxlib (==0.4.27+cuda12.cudnn89)"] -cuda12 = ["jax-cuda12-plugin (==0.4.27)", "jaxlib (==0.4.27)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -cuda12-cudnn89 = ["jaxlib (==0.4.27+cuda12.cudnn89)"] -cuda12-local = ["jaxlib (==0.4.27+cuda12.cudnn89)"] -cuda12-pip = ["jaxlib (==0.4.27+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -minimum-jaxlib = ["jaxlib (==0.4.23)"] -tpu = ["jaxlib (==0.4.27)", "libtpu-nightly (==0.1.dev20240507)", "requests"] +ci = ["jaxlib (==0.4.27)"] +cpu = ["jaxlib (==0.4.28)"] +cuda = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12 = ["jax-cuda12-plugin (==0.4.28)", "jaxlib (==0.4.28)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +cuda12-cudnn89 = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12-local = ["jaxlib (==0.4.28+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.28+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +minimum-jaxlib = ["jaxlib (==0.4.27)"] +tpu = ["jaxlib (==0.4.28)", "libtpu-nightly (==0.1.dev20240508)", "requests"] [[package]] name = "jaxlib" -version = "0.4.27" +version = "0.4.28" description = "XLA library for JAX" optional = false python-versions = ">=3.9" files = [ - {file = "jaxlib-0.4.27-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:9493a92c8a8796bbb96b422430465437a1d4426e0b444a86bb2e7c2ea9dfbe69"}, - {file = "jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:93f0714b2c37dbc3c31e30c7b40296b9947d0bd61070410b271527084cf7f66d"}, - {file = "jaxlib-0.4.27-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8267b015b3b1f8d1fc6779f93c035e4ccc3092692d0b17b05c5169e73785862f"}, - {file = "jaxlib-0.4.27-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d04febab81452f0bc611ec88f23c9a480fd968b845a9f942283bd6fbd229ac23"}, - {file = "jaxlib-0.4.27-cp310-cp310-win_amd64.whl", hash = "sha256:1828c0f0546cf9c252ef6afbd0b93fc0836a3c5ee27336c0b4a33cf7c602625e"}, - {file = "jaxlib-0.4.27-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8feb311b63a6e1b23acc2b69070bf4feae07d8c4e673679d9d1e7d5928692bc5"}, - {file = "jaxlib-0.4.27-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:867705bcc7f2e7769b719b03257542647922423165cba79c63a27b4d5e957eaa"}, - {file = "jaxlib-0.4.27-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:f030b9f8914e2de193cc1bd7afd9eca2e03aa48347372fbb89dbb481c2614505"}, - {file = "jaxlib-0.4.27-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2b75c09d838fccfb7d4463025b597bc377a1a349d09ffa8fb8c5a7f88197f7e8"}, - {file = "jaxlib-0.4.27-cp311-cp311-win_amd64.whl", hash = "sha256:d37f1d8cab6fca11d8ba44d6404fdbb9b796d4e6f447d12317e4ce0660310072"}, - {file = "jaxlib-0.4.27-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:1bbfabae8c4e8c1daed701d3f2b2390399bdc805c152b9f3446f2912c85f7b65"}, - {file = "jaxlib-0.4.27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6ea7b3c56d2341a4b6f1a69bd8e11dbdef3c566bca6538d77f24e23a7fc54601"}, - {file = "jaxlib-0.4.27-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:c8cfc399f01f717fe80eb47f76ed13c037340fb274703b0e61a9bd8ab00b5488"}, - {file = "jaxlib-0.4.27-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:85f8a0ae137dbd01ce563deae8220dde82d7a98ad4bf3a3cf159deee39dc81c9"}, - {file = "jaxlib-0.4.27-cp312-cp312-win_amd64.whl", hash = "sha256:9379594f35a9cab8dded3b78a930fe022fd838159388909f75c02ab307dff057"}, - {file = "jaxlib-0.4.27-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:d909bfbdfe358aa48fd3c8a9059df797655cca84e560b45de1cf2894c702ef7e"}, - {file = "jaxlib-0.4.27-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b545c4379d5ec84025985875620f5b1a6519a6bb932fc4368c76fc859af131f4"}, - {file = "jaxlib-0.4.27-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b3d2761a4f0a964521bdd8966abdb072fdead75fe6d542e31e57e96d420d8a06"}, - {file = "jaxlib-0.4.27-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:76640b14f618a4615a44339d1f6287a50e4e03a6612a56de4ff8e24488e7429a"}, - {file = "jaxlib-0.4.27-cp39-cp39-win_amd64.whl", hash = "sha256:9d10fa2bf384ef27ad0689a746f441bab7df770906ae2f69f3126d84b09ecda5"}, + {file = "jaxlib-0.4.28-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:a421d237f8c25d2850166d334603c673ddb9b6c26f52bc496704b8782297bd66"}, + {file = "jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f038e68bd10d1a3554722b0bbe36e6a448384437a75aa9d283f696f0ed9f8c09"}, + {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:fabe77c174e9e196e9373097cefbb67e00c7e5f9d864583a7cfcf9dabd2429b6"}, + {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e3bcdc6f8e60f8554f415c14d930134e602e3ca33c38e546274fd545f875769b"}, + {file = "jaxlib-0.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:a8b31c0e5eea36b7915696b9be40ea8646edc395a3e5437bf7ef26b7239a567a"}, + {file = "jaxlib-0.4.28-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2ff8290edc7b92c7eae52517f65492633e267b2e9067bad3e4c323d213e77cf5"}, + {file = "jaxlib-0.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:793857faf37f371cafe752fea5fc811f435e43b8fb4b502058444a7f5eccf829"}, + {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b41a6b0d506c09f86a18ecc05bd376f072b548af89c333107e49bb0c09c1a3f8"}, + {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:45ce0f3c840cff8236cff26c37f26c9ff078695f93e0c162c320c281f5041275"}, + {file = "jaxlib-0.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:d4d762c3971d74e610a0e85a7ee063cea81a004b365b2a7dc65133f08b04fac5"}, + {file = "jaxlib-0.4.28-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:d6c09a545329722461af056e735146d2c8c74c22ac7426a845eb69f326b4f7a0"}, + {file = "jaxlib-0.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8dd8bffe3853702f63cd924da0ee25734a4d19cd5c926be033d772ba7d1c175d"}, + {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:de2e8521eb51e16e85093a42cb51a781773fa1040dcf9245d7ea160a14ee5a5b"}, + {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:46a1aa857f4feee8a43fcba95c0e0ab62d40c26cc9730b6c69655908ba359f8d"}, + {file = "jaxlib-0.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:eee428eac31697a070d655f1f24f6ab39ced76750d93b1de862377a52dcc2401"}, + {file = "jaxlib-0.4.28-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:4f98cc837b2b6c6dcfe0ab7ff9eb109314920946119aa3af9faa139718ff2787"}, + {file = "jaxlib-0.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b01562ec8ad75719b7d0389752489e97eb6b4dcb4c8c113be491634d5282ad3c"}, + {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:aa77a9360a395ba9faf6932df637686fb0c14ddcf4fdc1d2febe04bc88a580a6"}, + {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4a56ebf05b4a4c1791699d874e072f3f808f0986b4010b14fb549a69c90ca9dc"}, + {file = "jaxlib-0.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:459a4ddcc3e120904b9f13a245430d7801d707bca48925981cbdc59628057dc8"}, ] [package.dependencies] @@ -873,6 +923,17 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] +[[package]] +name = "jupyterlab-pygments" +version = "0.3.0" +description = "Pygments theme using JupyterLab CSS variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780"}, + {file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"}, +] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -1057,39 +1118,40 @@ files = [ [[package]] name = "matplotlib" -version = "3.8.4" +version = "3.9.0" description = "Python plotting package" optional = false python-versions = ">=3.9" files = [ - {file = "matplotlib-3.8.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014"}, - {file = "matplotlib-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106"}, - {file = "matplotlib-3.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10"}, - {file = "matplotlib-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0"}, - {file = "matplotlib-3.8.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef"}, - {file = "matplotlib-3.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338"}, - {file = "matplotlib-3.8.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661"}, - {file = "matplotlib-3.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c"}, - {file = "matplotlib-3.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa"}, - {file = "matplotlib-3.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71"}, - {file = "matplotlib-3.8.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b"}, - {file = "matplotlib-3.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae"}, - {file = "matplotlib-3.8.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616"}, - {file = "matplotlib-3.8.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732"}, - {file = "matplotlib-3.8.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb"}, - {file = "matplotlib-3.8.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30"}, - {file = "matplotlib-3.8.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25"}, - {file = "matplotlib-3.8.4-cp312-cp312-win_amd64.whl", hash = "sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a"}, - {file = "matplotlib-3.8.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6"}, - {file = "matplotlib-3.8.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67"}, - {file = "matplotlib-3.8.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc"}, - {file = "matplotlib-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9"}, - {file = "matplotlib-3.8.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54"}, - {file = "matplotlib-3.8.4-cp39-cp39-win_amd64.whl", hash = "sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0"}, - {file = "matplotlib-3.8.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35"}, - {file = "matplotlib-3.8.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f"}, - {file = "matplotlib-3.8.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94"}, - {file = "matplotlib-3.8.4.tar.gz", hash = "sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea"}, + {file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"}, + {file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"}, + {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241"}, + {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d"}, + {file = "matplotlib-3.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4"}, + {file = "matplotlib-3.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463"}, + {file = "matplotlib-3.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38"}, + {file = "matplotlib-3.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152"}, + {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85"}, + {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb"}, + {file = "matplotlib-3.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674"}, + {file = "matplotlib-3.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be"}, + {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"}, + {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"}, + {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"}, + {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"}, + {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"}, + {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"}, + {file = "matplotlib-3.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956"}, + {file = "matplotlib-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a"}, + {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321"}, + {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89"}, + {file = "matplotlib-3.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b"}, + {file = "matplotlib-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"}, + {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"}, ] [package.dependencies] @@ -1097,12 +1159,15 @@ contourpy = ">=1.0.1" cycler = ">=0.10" fonttools = ">=4.22.0" kiwisolver = ">=1.3.1" -numpy = ">=1.21" +numpy = ">=1.23" packaging = ">=20.0" pillow = ">=8" pyparsing = ">=2.3.1" python-dateutil = ">=2.7" +[package.extras] +dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] + [[package]] name = "matplotlib-inline" version = "0.1.7" @@ -1117,6 +1182,17 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "mistune" +version = "3.0.2" +description = "A sane and fast Markdown parser with useful plugins and renderers" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, + {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, +] + [[package]] name = "ml-dtypes" version = "0.4.0" @@ -1186,6 +1262,43 @@ dev = ["pre-commit"] docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"] test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] +[[package]] +name = "nbconvert" +version = "7.16.4" +description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)." +optional = false +python-versions = ">=3.8" +files = [ + {file = "nbconvert-7.16.4-py3-none-any.whl", hash = "sha256:05873c620fe520b6322bf8a5ad562692343fe3452abda5765c7a34b7d1aa3eb3"}, + {file = "nbconvert-7.16.4.tar.gz", hash = "sha256:86ca91ba266b0a448dc96fa6c5b9d98affabde2867b363258703536807f9f7f4"}, +] + +[package.dependencies] +beautifulsoup4 = "*" +bleach = "!=5.0.0" +defusedxml = "*" +jinja2 = ">=3.0" +jupyter-core = ">=4.7" +jupyterlab-pygments = "*" +markupsafe = ">=2.0" +mistune = ">=2.0.3,<4" +nbclient = ">=0.5.0" +nbformat = ">=5.7" +packaging = "*" +pandocfilters = ">=1.4.1" +pygments = ">=2.4.1" +tinycss2 = "*" +traitlets = ">=5.1" + +[package.extras] +all = ["flaky", "ipykernel", "ipython", "ipywidgets (>=7.5)", "myst-parser", "nbsphinx (>=0.2.12)", "playwright", "pydata-sphinx-theme", "pyqtwebengine (>=5.15)", "pytest (>=7)", "sphinx (==5.0.2)", "sphinxcontrib-spelling", "tornado (>=6.1)"] +docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)", "sphinxcontrib-spelling"] +qtpdf = ["pyqtwebengine (>=5.15)"] +qtpng = ["pyqtwebengine (>=5.15)"] +serve = ["tornado (>=6.1)"] +test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"] +webpdf = ["playwright"] + [[package]] name = "nbformat" version = "5.10.4" @@ -1340,6 +1453,17 @@ files = [ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] +[[package]] +name = "pandocfilters" +version = "1.5.1" +description = "Utilities for writing pandoc filters in python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc"}, + {file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"}, +] + [[package]] name = "parso" version = "0.8.4" @@ -1457,13 +1581,13 @@ xmp = ["defusedxml"] [[package]] name = "platformdirs" -version = "4.2.1" +version = "4.2.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"}, - {file = "platformdirs-4.2.1.tar.gz", hash = "sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf"}, + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, ] [package.extras] @@ -1488,22 +1612,22 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "polars" -version = "0.20.25" +version = "0.20.28" description = "Blazingly fast DataFrame library" optional = false python-versions = ">=3.8" files = [ - {file = "polars-0.20.25-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:126e3b7d9394e4b23b4cc48919b7188203feeeb35d861ad808f281eaa06d76e2"}, - {file = "polars-0.20.25-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:3bda62b681726538714a1159638ab7c9eeca6b8633fd778d84810c3e13b9c7e3"}, - {file = "polars-0.20.25-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62c8826e81c759f07bf5c0ae00f57a537644ae05fe68737185666b8ad8430664"}, - {file = "polars-0.20.25-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:0fb5e7a4a9831fba742f1c706e01656607089b6362a5e6f8d579b134a99795ce"}, - {file = "polars-0.20.25-cp38-abi3-win_amd64.whl", hash = "sha256:9eaeb9080c853e11b207d191025e0ba8fd59ea06a36c22d410a48f2f124e18cd"}, - {file = "polars-0.20.25.tar.gz", hash = "sha256:4308d63f956874bac9ae040bdd6d62b2992d0b1e1349301bc7a3b59458189108"}, + {file = "polars-0.20.28-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:18eb471826edaf00b9f4ff0885d19b824e48b4a16521ee173e24ef17691b623e"}, + {file = "polars-0.20.28-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2258ab7ee65e1cfc86f9e762ae10c9959e79f8650f121d30404bef2a266922e0"}, + {file = "polars-0.20.28-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53b2d164ae9645fffb54920eed91e0a48f72fcc3dc9414c9e67871cdcf94533d"}, + {file = "polars-0.20.28-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:4524bc108ddacd337c096abeabcd3310f42f10ba0f53c52d3ac5c9a17a9e71af"}, + {file = "polars-0.20.28-cp38-abi3-win_amd64.whl", hash = "sha256:eae3549c3748e4323ca038314484215823bc7c3f764bb685ea1080e132fd29fd"}, + {file = "polars-0.20.28.tar.gz", hash = "sha256:ac3a59032b88a7e036eebe03f393f4f7e24f5197a5311f51b9f2b59458031fc6"}, ] [package.extras] adbc = ["adbc-driver-manager", "adbc-driver-sqlite"] -all = ["polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,pyiceberg,sqlalchemy,timezone,torch,xlsx2csv,xlsxwriter]"] +all = ["polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,iceberg,numpy,pandas,plot,pyarrow,pydantic,sqlalchemy,timezone,xlsx2csv,xlsxwriter]"] async = ["nest-asyncio"] cloudpickle = ["cloudpickle"] connectorx = ["connectorx (>=0.3.2)"] @@ -1511,6 +1635,7 @@ deltalake = ["deltalake (>=0.15.0)"] fastexcel = ["fastexcel (>=0.9)"] fsspec = ["fsspec"] gevent = ["gevent"] +iceberg = ["pyiceberg (>=0.5.0)"] matplotlib = ["matplotlib"] numpy = ["numpy (>=1.16.0)"] openpyxl = ["openpyxl (>=3.0.0)"] @@ -1518,11 +1643,9 @@ pandas = ["pandas", "pyarrow (>=7.0.0)"] plot = ["hvplot (>=0.9.1)"] pyarrow = ["pyarrow (>=7.0.0)"] pydantic = ["pydantic"] -pyiceberg = ["pyiceberg (>=0.5.0)"] pyxlsb = ["pyxlsb (>=1.0)"] sqlalchemy = ["pandas", "sqlalchemy"] timezone = ["backports-zoneinfo", "tzdata"] -torch = ["torch"] xlsx2csv = ["xlsx2csv (>=0.8.0)"] xlsxwriter = ["xlsxwriter"] @@ -1634,13 +1757,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "8.2.0" +version = "8.2.1" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"}, - {file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"}, + {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, + {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, ] [package.dependencies] @@ -2099,6 +2222,17 @@ files = [ {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, ] +[[package]] +name = "soupsieve" +version = "2.5" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.8" +files = [ + {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, + {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, +] + [[package]] name = "sphinx" version = "7.3.7" @@ -2261,6 +2395,24 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tinycss2" +version = "1.3.0" +description = "A tiny CSS parser" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tinycss2-1.3.0-py3-none-any.whl", hash = "sha256:54a8dbdffb334d536851be0226030e9505965bb2f30f21a4a82c55fb2a80fae7"}, + {file = "tinycss2-1.3.0.tar.gz", hash = "sha256:152f9acabd296a8375fbca5b84c961ff95971fcfc32e79550c8df8e29118c54d"}, +] + +[package.dependencies] +webencodings = ">=0.4" + +[package.extras] +doc = ["sphinx", "sphinx_rtd_theme"] +test = ["pytest", "ruff"] + [[package]] name = "tomli" version = "2.0.1" @@ -2366,7 +2518,18 @@ files = [ {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] +[[package]] +name = "webencodings" +version = "0.5.1" +description = "Character encoding aliases for legacy web content" +optional = false +python-versions = "*" +files = [ + {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, + {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, +] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "32cf7460d2913c29354a0b65b52ab47f9b1fae93a14240c0acb0dcd69b562ad5" +content-hash = "620b4f3c0809f97d326ba2142e27cd25ebd08e3e362f9f9311ab6127649f91d4" diff --git a/model/pyproject.toml b/model/pyproject.toml index 2cfe23f4..1ed67e45 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -16,6 +16,8 @@ jax = "^0.4.24" numpy = "^1.26.4" polars = "^0.20.13" pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference +nbconvert = "^7.16.4" +pytest-mpl = "^0.17.0" [tool.poetry.group.dev] optional = true diff --git a/model/src/pyrenew/convolve.py b/model/src/pyrenew/convolve.py index e545e360..ae838d2d 100755 --- a/model/src/pyrenew/convolve.py +++ b/model/src/pyrenew/convolve.py @@ -48,7 +48,8 @@ def _new_scanner( def new_double_scanner( - dists: tuple[ArrayLike, ArrayLike], transforms: tuple[Callable, Callable] + dists: tuple[ArrayLike, ArrayLike], + transforms: tuple[Callable, Callable], ) -> Callable: """ Factory function to create a scanner function that applies two sequential transformations @@ -74,7 +75,8 @@ def new_double_scanner( t1, t2 = transforms def _new_scanner( - history_subset: ArrayLike, multipliers: tuple[float, float] + history_subset: ArrayLike, + multipliers: tuple[float, float], ) -> tuple[ArrayLike, tuple[float, float]]: # numpydoc ignore=GL08 m1, m2 = multipliers m_net1 = t1(m1 * jnp.dot(d1, history_subset)) diff --git a/model/src/pyrenew/datautils.py b/model/src/pyrenew/datautils.py new file mode 100644 index 00000000..fe7b6944 --- /dev/null +++ b/model/src/pyrenew/datautils.py @@ -0,0 +1,75 @@ +""" +Utility functions for data processing. +""" + +import jax.numpy as jnp +from jax.typing import ArrayLike + + +def pad_to_match( + x: ArrayLike, + y: ArrayLike, + fill_value: float = 0.0, + fix_y: bool = False, +) -> tuple[ArrayLike, ArrayLike]: + """ + Pad the shorter array at the end to match the length of the longer array. + + Parameters + ---------- + x : ArrayLike + First array. + y : ArrayLike + Second array. + fill_value : float, optional + Value to use for padding, by default 0.0. + fix_y : bool, optional + If True, raise an error when `y` is shorter than `x`, by default False. + + Returns + ------- + tuple[ArrayLike, ArrayLike] + Tuple of the two arrays with the same length. + """ + + x = jnp.atleast_1d(x) + y = jnp.atleast_1d(y) + + x_len = x.size + y_len = y.size + if x_len > y_len: + if fix_y: + raise ValueError( + "Cannot fix y when x is longer than y." + + f" x_len: {x_len}, y_len: {y_len}." + ) + + y = jnp.pad(y, (0, x_len - y_len), constant_values=fill_value) + + elif y_len > x_len: + x = jnp.pad(x, (0, y_len - x_len), constant_values=fill_value) + + return x, y + + +def pad_x_to_match_y( + x: ArrayLike, + y: ArrayLike, + fill_value: float = 0.0, +) -> ArrayLike: + """ + Pad the `x` array at the end to match the length of the `y` array. + + Parameters + ---------- + x : ArrayLike + First array. + y : ArrayLike + Second array. + + Returns + ------- + Array + Padded array. + """ + return pad_to_match(x, y, fill_value=fill_value, fix_y=True)[0] diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index f7a4db6c..1e0bcc04 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -3,6 +3,7 @@ from __future__ import annotations +import jax.numpy as jnp from jax.typing import ArrayLike from pyrenew.metaclass import RandomVariable @@ -33,7 +34,7 @@ def __init__( """ self.validate(vars) - self.vars = vars + self.vars = jnp.atleast_1d(vars) self.label = label return None diff --git a/model/src/pyrenew/latent/__init__.py b/model/src/pyrenew/latent/__init__.py index 5487c70e..61adc9c0 100644 --- a/model/src/pyrenew/latent/__init__.py +++ b/model/src/pyrenew/latent/__init__.py @@ -9,6 +9,7 @@ logistic_susceptibility_adjustment, ) from pyrenew.latent.infections import Infections +from pyrenew.latent.infectionswithfeedback import InfectionsWithFeedback __all__ = [ "HospitalAdmissions", @@ -16,4 +17,5 @@ "logistic_susceptibility_adjustment", "compute_infections_from_rt", "compute_infections_from_rt_with_feedback", + "InfectionsWithFeedback", ] diff --git a/model/src/pyrenew/latent/infection_functions.py b/model/src/pyrenew/latent/infection_functions.py index 19257791..565e2d54 100755 --- a/model/src/pyrenew/latent/infection_functions.py +++ b/model/src/pyrenew/latent/infection_functions.py @@ -23,7 +23,7 @@ def compute_infections_from_rt( ---------- I0 : ArrayLike Array of initial infections of the - same length as the generation inferval + same length as the generation interval pmf vector. Rt : ArrayLike Timeseries of R(t) values @@ -90,10 +90,10 @@ def compute_infections_from_rt_with_feedback( I0: ArrayLike, Rt_raw: ArrayLike, infection_feedback_strength: ArrayLike, - generation_interval_pmf: ArrayLike, - infection_feedback_pmf: ArrayLike, + reversed_generation_interval_pmf: ArrayLike, + reversed_infection_feedback_pmf: ArrayLike, ) -> tuple: - """ + r""" Generate infections according to a renewal process with infection feedback (generalizing Asher 2018: @@ -103,7 +103,7 @@ def compute_infections_from_rt_with_feedback( ---------- I0 : ArrayLike Array of initial infections of the - same length as the generation inferval + same length as the generation interval pmf vector. Rt_raw : ArrayLike Timeseries of raw R(t) values not @@ -114,30 +114,74 @@ def compute_infections_from_rt_with_feedback( strength in time) or a vector representing the infection feedback strength at a given point in time. - generation_interval_pmf : ArrayLike + reversed_generation_interval_pmf : ArrayLike discrete probability mass vector representing the generation interval - of the infection process - infection_feedback_pmf : ArrayLike + of the infection process, where the final + entry represents an infection 1 time unit in the + past, the second-to-last entry represents + an infection two time units in the past, etc. + reversed_infection_feedback_pmf : ArrayLike discrete probability mass vector - whose `i`th entry represents the - relative contribution to infection + representing the infection feedback + process, where the final entry represents + the relative contribution to infection feedback from infections that occurred - `i` days in the past. + 1 time unit in the past, the second-to-last + entry represents the contribution from infections + that occurred 2 time units in the past, etc. Returns ------- tuple - A tuple `(Rt_adjusted, infections)`, + A tuple `(infections, Rt_adjusted)`, where `Rt_adjusted` is the infection-feedback-adjusted timeseries of the reproduction number R(t) and infections is the incident infection timeseries. + + Notes + ----- + This function implements the following renewal process: + + .. math:: + + I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) + + \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\ + \sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right) + + where :math:`\mathcal{R}(t)` is the reproductive number, + :math:`\gamma(t)` is the infection feedback strength, + :math:`T_g` is the max-length of the + generation interval, :math:`\mathcal{R}^u(t)` is the raw reproduction + number, :math:`f(t)` is the infection feedback pmf, and :math:`T_f` + is the max-length of the infection feedback pmf. + + Note that negative :math:`\gamma(t)` implies + that recent incident infections reduce :math:`\mathcal{R}(t)` + below its raw value in the absence of feedback, while + positive :math:`\gamma` implies that recent incident infections + _increase_ :math:`\mathcal{R}(t)` above its raw value, and + :math:`gamma(t)=0` implies no feedback. + + In general, negative :math:`\gamma` is the more common modeling + choice, as it can be used to model susceptible depletion, + reductions in contact rate due to awareness of high incidence, + et cetera. """ feedback_scanner = new_double_scanner( - (infection_feedback_pmf, generation_interval_pmf), - (jnp.exp, lambda x: x), + dists=( + reversed_infection_feedback_pmf, + reversed_generation_interval_pmf, + ), + transforms=(jnp.exp, lambda x: x), ) - latest, infs_and_R = jax.lax.scan( - feedback_scanner, I0, (infection_feedback_strength, Rt_raw) + latest, infs_and_R_adj = jax.lax.scan( + f=feedback_scanner, + init=I0, + xs=(infection_feedback_strength, Rt_raw), ) - return infs_and_R + + infections, R_adjustment = infs_and_R_adj + Rt_adjusted = R_adjustment * Rt_raw + return infections, Rt_adjusted diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 5edfbb99..acdfb101 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -31,8 +31,8 @@ def __repr__(self): class Infections(RandomVariable): r"""Latent infections - This class samples infections given Rt, initial infections, and generation - interval. + This class samples infections given Rt, + initial infections, and generation interval. Notes ----- @@ -79,7 +79,7 @@ def sample( I0: ArrayLike, gen_int: ArrayLike, **kwargs, - ) -> tuple: + ) -> InfectionsSample: """ Samples infections given Rt, initial infections, and generation interval. @@ -89,9 +89,11 @@ def sample( Rt : ArrayLike Reproduction number. I0 : ArrayLike - Initial infections. + Initial infections vector + of the same length as the + generation interval. gen_int : ArrayLike - Generation interval. + Generation interval pmf vector. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. @@ -104,8 +106,15 @@ def sample( gen_int_rev = jnp.flip(gen_int) - n_lead = gen_int_rev.size - 1 - I0_vec = jnp.hstack([jnp.zeros(n_lead), I0]) + if I0.size < gen_int_rev.size: + raise ValueError( + "Initial infections vector must be at least as long as " + "the generation interval. " + f"Initial infections vector length: {I0.size}, " + f"generation interval length: {gen_int_rev.size}." + ) + else: + I0_vec = I0[-gen_int_rev.size :] all_infections = inf.compute_infections_from_rt( I0=I0_vec, diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py new file mode 100644 index 00000000..5e65a2e2 --- /dev/null +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +# numpydoc ignore=GL08 + +from typing import NamedTuple + +import jax.numpy as jnp +import numpyro as npro +import pyrenew.datautils as du +import pyrenew.latent.infection_functions as inf +from numpy.typing import ArrayLike +from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype + + +class InfectionsRtFeedbackSample(NamedTuple): + """ + A container for holding the output from the InfectionsWithFeedback. + + Attributes + ---------- + infections : ArrayLike | None, optional + The estimated latent infections. Defaults to None. + rt : ArrayLike | None, optional + The adjusted reproduction number. Defaults to None. + """ + + infections: ArrayLike | None = None + rt: ArrayLike | None = None + + def __repr__(self): + return f"InfectionsSample(infections={self.infections}, rt={self.rt})" + + +class InfectionsWithFeedback(RandomVariable): + r""" + Latent infections + + This class computes infections, given Rt, initial infections, and generation + interval. + + Notes + ----- + This function implements the following renewal process (reproduced from + :func:`pyrenew.latent.infection_functions.sample_infections_with_feedback`): + + .. math:: + + I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) + + \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(-\gamma(t)\ + \sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right) + + where :math:`\mathcal{R}(t)` is the reproductive number, :math:`\gamma(t)` + is the infection feedback strength, :math:`T_g` is the max-length of the + generation interval, :math:`\mathcal{R}^u(t)` is the raw reproduction + number, :math:`f(t)` is the infection feedback pmf, and :math:`T_f` + is the max-length of the infection feedback pmf. + """ + + def __init__( + self, + infection_feedback_strength: RandomVariable, + infection_feedback_pmf: RandomVariable, + infections_mean_varname: str = "latent_infections", + ) -> None: + """ + Default constructor for Infections class. + + Parameters + ---------- + infection_feedback_strength : RandomVariable + Infection feedback strength. + infection_feedback_pmf : RandomVariable + Infection feedback pmf. + infections_mean_varname : str, optional + Name to be assigned to the deterministic variable in the model. + Defaults to "latent_infections". + + Returns + ------- + None + """ + + self.validate(infection_feedback_strength, infection_feedback_pmf) + + self.infection_feedback_strength = infection_feedback_strength + self.infection_feedback_pmf = infection_feedback_pmf + self.infections_mean_varname = infections_mean_varname + + return None + + @staticmethod + def validate( + inf_feedback_strength: any, + inf_feedback_pmf: any, + ) -> None: # numpydoc ignore=GL08 + """ + Validates the input parameters. + + Parameters + ---------- + inf_feedback_strength : RandomVariable + Infection feedback strength. + inf_feedback_pmf : RandomVariable + Infection feedback pmf. + + Returns + ------- + None + """ + _assert_sample_and_rtype(inf_feedback_strength) + _assert_sample_and_rtype(inf_feedback_pmf) + + return None + + def sample( + self, + Rt: ArrayLike, + I0: ArrayLike, + gen_int: ArrayLike, + **kwargs, + ) -> InfectionsRtFeedbackSample: + """ + Samples infections given Rt, initial infections, and generation + interval. + + Parameters + ---------- + Rt : ArrayLike + Reproduction number. + I0 : ArrayLike + Initial infections, as an array + at least as long as the + interval PMF. + gen_int : ArrayLike + Generation interval PMF. + **kwargs : dict, optional + Additional keyword arguments passed through to internal + sample calls, should there be any. + + Returns + ------- + InfectionsWithFeedback + Named tuple with "infections". + """ + if I0.size < gen_int.size: + raise ValueError( + "Initial infections must be at least as long as the " + f"generation interval. Got {I0.size} initial infections " + f"and {gen_int.size} generation interval." + ) + + gen_int_rev = jnp.flip(gen_int) + + I0 = I0[-gen_int_rev.size :] + + # Sampling inf feedback strength + inf_feedback_strength, *_ = self.infection_feedback_strength.sample( + **kwargs, + ) + + # Making sure inf_feedback_strength spans the Rt length + if inf_feedback_strength.size == 1: + inf_feedback_strength = du.pad_x_to_match_y( + x=inf_feedback_strength, + y=Rt, + fill_value=inf_feedback_strength[0], + ) + elif inf_feedback_strength.size != Rt.size: + raise ValueError( + "Infection feedback strength must be of size 1 or the same " + f"size as the reproduction number. Got {inf_feedback_strength.size} " + f"and {Rt.size} respectively." + ) + + # Sampling inf feedback pmf + inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs) + + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) + + all_infections, Rt_adj = inf.compute_infections_from_rt_with_feedback( + I0=I0, + Rt_raw=Rt, + infection_feedback_strength=inf_feedback_strength, + reversed_generation_interval_pmf=gen_int_rev, + reversed_infection_feedback_pmf=inf_fb_pmf_rev, + ) + + npro.deterministic("Rt_adjusted", Rt_adj) + + return InfectionsRtFeedbackSample( + infections=all_infections, + rt=Rt_adj, + ) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index eeb9fc76..70d73359 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -8,6 +8,7 @@ from typing import NamedTuple, get_type_hints import jax +import jax.numpy as jnp import matplotlib.pyplot as plt import numpyro as npro import polars as pl @@ -216,10 +217,12 @@ def sample( DistributionalRVSample """ return DistributionalRVSample( - value=npro.sample( - name=self.name, - fn=self.dist, - obs=obs, + value=jnp.atleast_1d( + npro.sample( + name=self.name, + fn=self.dist, + obs=obs, + ) ), ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index c98f9ce8..4d161d42 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -6,7 +6,8 @@ from typing import NamedTuple import jax.numpy as jnp -from jax.typing import ArrayLike +import pyrenew.datautils as du +from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype @@ -284,6 +285,14 @@ def sample( # Sampling initial infections i0, *_ = self.sample_i0(**kwargs) + # Padding i0 to match gen_int + # PADDING SHOULD BE REMOVED ONCE + # https://github.com/CDCgov/multisignal-epi-inference/pull/124 + # is merged. + # SEE ALSO: + # https://github.com/CDCgov/multisignal-epi-inference/pull/123#discussion_r1612337288 + i0 = du.pad_x_to_match_y(x=i0, y=gen_int, fill_value=0.0) + # Sampling from the latent process latent, *_ = self.sample_infections_latent( Rt=Rt, diff --git a/model/src/pyrenew/transform.py b/model/src/pyrenew/transform.py new file mode 100755 index 00000000..2bed7fca --- /dev/null +++ b/model/src/pyrenew/transform.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +""" +Transform classes for PyRenew +""" +from __future__ import annotations + +from abc import ABCMeta, abstractmethod + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + + +class AbstractTransform(metaclass=ABCMeta): + """ + Abstract base class for transformations + """ + + def __call__(self, x): + return self.transform(x) + + @abstractmethod + def transform(self, x): + """ + Transform generated predictions + """ + pass + + @abstractmethod + def inverse(self, x): + """ + Take the inverse of transformed predictions + """ + pass + + +class IdentityTransform(AbstractTransform): + """ + Identity transformation, which + is its own inverse. + + f(x) = x + f^-1(x) = x + """ + + def transform(self, x: any): # numpydoc ignore=SS01 + """ + Transform function + + Parameters + ---------- + x : any + Input, usually ArrayLike + + Returns + ------- + any + The same object that was inputted. + """ + return x + + def inverse(self, x: any): # numpydoc ignore=SS01 + """ + Inverse function + + Parameters + ---------- + x : any + Input, usually ArrayLike + + Returns + ------- + any + The same object that was inputted. + """ + return x + + +class LogTransform(AbstractTransform): + """ + Logarithmic (base e) transformation, whose + inverse is exponentiation. + + f(x) = log(x) + f^-1(x) = exp(x) + """ + + def transform(self, x: ArrayLike): # numpydoc ignore=SS01 + """ + Log transform function + + Parameters + ---------- + x : ArrayLike + Input, usually predictions array.. + + Returns + ------- + ArrayLike + Log-transformed input + """ + return jnp.log(x) + + def inverse(self, x: ArrayLike): # numpydoc ignore=SS01 + """ + Inverse of log transform function + + Parameters + ---------- + x : ArrayLike + Input, usually log-scale predictions array. + + Returns + ------- + ArrayLike + Exponentiated input + """ + return jnp.exp(x) + + +class LogitTransform(AbstractTransform): + """ + Logistic transformation, whose + inverse is the inverse logit or + 'expit' function: + + f(x) = log(x) - log(1 - x) + f^-1(x) = 1 / (1 + exp(-x)) + """ + + def transform(self, x: ArrayLike): # numpydoc ignore=SS01 + """ + Logit transform function + + Parameters + ---------- + x : ArrayLike + Input, usually predictions array. + + Returns + ------- + ArrayLike + Logit transformed input. + """ + return jax.scipy.special.logit(x) + + def inverse(self, x: ArrayLike): # numpydoc ignore=SS01 + """ + Inverse of logit transform function + + Parameters + ---------- + x : ArrayLike + Input, usually logit-transformed predictions array. + + Returns + ------- + ArrayLike + Inversed logit transformed input. + """ + return jax.scipy.special.expit(x) + + +class ScaledLogitTransform(AbstractTransform): + """ + Scaled logistic transformation from the + interval (0, X_max) to the interval + (-infinity, +infinity). + It's inverse is the inverse logit or + 'expit' function multiplied by X_max + f(x) = log(x/X_max) - log(1 - x/X_max) + f^-1(x) = X_max / (1 + exp(-x)) + """ + + def __init__(self, x_max: float): # numpydoc ignore=RT01 + """ + Default constructor + + Parameters + ---------- + x_max : float + Maximum value on the untransformed scale + (will be transformed to +infinity) + """ + self.x_max = x_max + + def transform(self, x: ArrayLike): # numpydoc ignore=SS01 + """ + Scaled logit transform function + + Parameters + ---------- + x : ArrayLike + Input, usually predictions array. + + Returns + ------- + ArrayLike + x_max scaled logit transformed input. + """ + return jax.scipy.special.logit(x / self.x_max) + + def inverse(self, x: ArrayLike): # numpydoc ignore=SS01 + """ + Inverse of scaled logit transform function + + Parameters + ---------- + x : ArrayLike + Input, usually scaled logit predictions array. + + Returns + ------- + ArrayLike + Inverse of x_max scaled logit transformed input. + """ + return self.x_max * jax.scipy.special.expit(x) diff --git a/model/src/test/baseline/test_model_basicrenewal_plot.png b/model/src/test/baseline/test_model_basicrenewal_plot.png index 80a0d656..c1ddb461 100644 Binary files a/model/src/test/baseline/test_model_basicrenewal_plot.png and b/model/src/test/baseline/test_model_basicrenewal_plot.png differ diff --git a/model/src/test/test_datautils.py b/model/src/test/test_datautils.py new file mode 100644 index 00000000..860c5f60 --- /dev/null +++ b/model/src/test/test_datautils.py @@ -0,0 +1,49 @@ +""" +Tests for the datautils module. +""" + +import jax.numpy as jnp +import pyrenew.datautils as du +import pytest + + +def test_datautils_pad_to_match(): + """ + Verifies extension when required and error when `fix_y` is True. + """ + + x = jnp.array([1, 2, 3]) + y = jnp.array([1, 2]) + + x_pad, y_pad = du.pad_to_match(x, y) + + assert x_pad.size == y_pad.size + assert x_pad.size == 3 + + x = jnp.array([1, 2]) + y = jnp.array([1, 2, 3]) + + x_pad, y_pad = du.pad_to_match(x, y) + + assert x_pad.size == y_pad.size + assert x_pad.size == 3 + + x = jnp.array([1, 2, 3]) + y = jnp.array([1, 2]) + + # Verify that the function raises an error when `fix_y` is True + with pytest.raises(ValueError): + x_pad, y_pad = du.pad_to_match(x, y, fix_y=True) + + +def test_datautils_pad_x_to_match_y(): + """ + Verifies extension when required + """ + + x = jnp.array([1, 2]) + y = jnp.array([1, 2, 3]) + + x_pad = du.pad_x_to_match_y(x, y) + + assert x_pad.size == 3 diff --git a/model/src/test/test_infection_functions.py b/model/src/test/test_infection_functions.py new file mode 100644 index 00000000..a7a9f407 --- /dev/null +++ b/model/src/test/test_infection_functions.py @@ -0,0 +1,58 @@ +""" +Test functions from the latent.infection_functions +submodule +""" + +import jax.numpy as jnp +from numpy.testing import assert_array_equal +from pyrenew.latent import infection_functions as inf + + +def test_compute_infections_from_rt_with_feedback(): + """ + test that the implementation of infection + feedback is as expected + """ + + # if feedback is zero, results should be + # equivalent to compute_infections_from_rt + # and Rt_adjusted should be Rt_raw + + gen_ints = [ + jnp.array([0.25, 0.5, 0.25]), + jnp.array([1.0]), + jnp.ones(35) / 35, + ] + + inf_pmfs = [jnp.ones_like(x) for x in gen_ints] + + I0s = [ + jnp.array([0.235, 6.523, 100052.0]), + jnp.array([5.0]), + 3.5235 * jnp.ones(35), + ] + + Rts_raw = [ + jnp.array([1.25, 0.52, 23.0, 1.0]), + jnp.ones(500), + jnp.zeros(253), + ] + + for I0, gen_int, inf_pmf in zip(I0s, gen_ints, inf_pmfs): + for Rt_raw in Rts_raw: + ( + infs_feedback, + Rt_adj, + ) = inf.compute_infections_from_rt_with_feedback( + I0, Rt_raw, jnp.zeros_like(Rt_raw), gen_int, inf_pmf + ) + + assert_array_equal( + inf.compute_infections_from_rt(I0, Rt_raw, gen_int), + infs_feedback, + ) + + assert_array_equal(Rt_adj, Rt_raw) + pass + pass + return None diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py new file mode 100644 index 00000000..ed1e9434 --- /dev/null +++ b/model/src/test/test_infectionsrtfeedback.py @@ -0,0 +1,146 @@ +""" +Test the InfectionsWithFeedback class +""" + +import jax.numpy as jnp +import numpy as np +import numpyro as npro +import pyrenew.latent as latent +from jax.typing import ArrayLike +from numpy.testing import assert_array_almost_equal, assert_array_equal +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable + + +def _infection_w_feedback_alt( + gen_int: ArrayLike, + Rt: ArrayLike, + I0: ArrayLike, + inf_feedback_strength: ArrayLike, + inf_feedback_pmf: ArrayLike, +) -> tuple: + """ + Calculate the infections with feedback. + Parameters + ---------- + gen_int : ArrayLike + Generation interval. + Rt : ArrayLike + Reproduction number. + I0 : ArrayLike + Initial infections. + inf_feedback_strength : ArrayLike + Infection feedback strength. + inf_feedback_pmf : ArrayLike + Infection feedback pmf. + + Returns + ------- + tuple + """ + + Rt = np.array(Rt) # coerce from jax to use numpy-like operations + T = len(Rt) + len_gen = len(gen_int) + I_vec = np.concatenate([I0, np.zeros(T)]) + Rt_adj = np.zeros(T) + + for t in range(T): + Rt_adj[t] = Rt[t] * np.exp( + inf_feedback_strength[t] + * np.dot(I_vec[t : t + len_gen], np.flip(inf_feedback_pmf)) + ) + + I_vec[t + len_gen] = Rt_adj[t] * np.dot( + I_vec[t : t + len_gen], np.flip(gen_int) + ) + + return {"infections": I_vec[-T:], "rt": Rt_adj} + + +def test_infectionsrtfeedback(): + """ + Test the InfectionsWithFeedback matching the Infections class. + """ + + Rt = jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]) + I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) + + # By doing the infection feedback strength 0, Rt = Rt_adjusted + # So infection should be equal in both + inf_feed_strength = DeterministicVariable(jnp.zeros_like(Rt)) + inf_feedback_pmf = DeterministicPMF(gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + infections = latent.Infections() + + with npro.handlers.seed(rng_seed=0): + samp1 = InfectionsWithFeedback.sample( + gen_int=gen_int, + Rt=Rt, + I0=I0, + ) + + samp2 = infections.sample( + gen_int=gen_int, + Rt=Rt, + I0=I0, + ) + + assert_array_equal(samp1.infections, samp2.infections) + assert_array_equal(samp1.rt, Rt) + + return None + + +def test_infectionsrtfeedback_feedback(): + """ + Test the InfectionsWithFeedback with feedback + """ + + Rt = jnp.array([0.5, 0.6, 1.5, 2.523, 0.7, 0.8]) + I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) + + inf_feed_strength = DeterministicVariable(jnp.repeat(0.5, len(Rt))) + inf_feedback_pmf = DeterministicPMF(gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + infections = latent.Infections() + + with npro.handlers.seed(rng_seed=0): + samp1 = InfectionsWithFeedback.sample( + gen_int=gen_int, + Rt=Rt, + I0=I0, + ) + + samp2 = infections.sample( + gen_int=gen_int, + Rt=Rt, + I0=I0, + ) + + res = _infection_w_feedback_alt( + gen_int=gen_int, + Rt=Rt, + I0=I0, + inf_feedback_strength=inf_feed_strength.sample()[0], + inf_feedback_pmf=inf_feedback_pmf.sample()[0], + ) + + assert not jnp.array_equal(samp1.infections, samp2.infections) + assert_array_almost_equal(samp1.infections, res["infections"]) + assert_array_almost_equal(samp1.rt, res["rt"]) + + return None diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 8d61aef0..144654aa 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -24,8 +24,8 @@ def test_admissions_sample(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): sim_rt, *_ = rt.sample(n_timepoints=30) - gen_int = jnp.array([0.25, 0.25, 0.25, 0.25]) - i0 = 10 + gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) + i0 = 10 * jnp.ones_like(gen_int) inf1 = Infections() diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 96df037f..58a06989 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -5,14 +5,15 @@ import numpy as np import numpy.testing as testing import numpyro as npro +import pytest from pyrenew.latent import Infections from pyrenew.process import RtRandomWalkProcess def test_infections_as_deterministic(): """ - Check that an InfectionObservation - can be initialized and sampled from (deterministic) + Test that the Infections class samples the same infections when + the same seed is used. """ np.random.seed(223) @@ -24,12 +25,21 @@ def test_infections_as_deterministic(): inf1 = Infections() + obs = dict( + Rt=sim_rt, + I0=jnp.zeros(gen_int.size), + gen_int=gen_int, + ) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - obs = dict(Rt=sim_rt, I0=10, gen_int=gen_int) inf_sampled1 = inf1.sample(**obs) inf_sampled2 = inf1.sample(**obs) - # Should match! testing.assert_array_equal( inf_sampled1.infections, inf_sampled2.infections ) + + # Check that Initial infections vector must be at least as long as the generation interval. + with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with pytest.raises(ValueError): + obs["I0"] = jnp.array([1]) + inf1.sample(**obs) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index f7ad04c9..0b8087a1 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -132,7 +132,26 @@ def test_model_basicrenewal_with_obs_model(): @pytest.mark.mpl_image_compare -def test_model_basicrenewal_plot() -> plt.Figure: # numpydoc ignore=GL08 +def test_model_basicrenewal_plot() -> plt.Figure: + """ + Check that the posterior sample looks the same (reproducibility) + + Returns + ------- + plt.Figure + The figure object + + Notes + ----- + IMPORTANT: If this test fails, it may be that you need + to regenerate the figures. To do so, you can the test using the following + command: + + poetry run pytest --mpl-generate-path=src/test/baseline + + This will skip validating the figure and save the new figure in the + `src/test/baseline` folder. + """ gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")