Skip to content

Commit

Permalink
Reintroduce sampling_jax.py for backward compatibility
Browse files Browse the repository at this point in the history
This is a separate commit to make sure that git tracks the rename
of the old `sampling_jax.py` to `sampling/jax.py` correctly.
  • Loading branch information
michaelosthege authored and ricardoV94 committed Nov 7, 2022
1 parent 80fc108 commit 781d974
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
warnings.warn("This module is experimental.")


__all__ = (
"get_jaxified_graph",
"get_jaxified_logp",
"sample_blackjax_nuts",
"sample_numpyro_nuts",
)


@jax_funcify.register(Assert)
@jax_funcify.register(CheckParameterValue)
@jax_funcify.register(SpecifyShape)
Expand Down
7 changes: 7 additions & 0 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# This file exists only for backward-compatibility with imports like
# `import pymc.sampling_jax` or `from pymc import sampling_jax`.

# pylint: disable=wildcard-import
# pylint: disable=unused-wildcard-import

from pymc.sampling.jax import *
8 changes: 8 additions & 0 deletions pymc/tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

import pymc as pm


def test_old_import_route():
import pymc.sampling.jax as new_sj
import pymc.sampling_jax as old_sj

assert set(new_sj.__all__) <= set(dir(old_sj))


with pytest.warns(UserWarning, match="module is experimental"):
from pymc.sampling.jax import (
_get_batched_jittered_initial_points,
Expand Down
1 change: 1 addition & 0 deletions scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
pymc/ode/ode.py
pymc/ode/utils.py
pymc/plots/__init__.py
pymc/sampling_jax.py
pymc/sampling/__init__.py
pymc/sampling/forward.py
pymc/sampling/mcmc.py
Expand Down

0 comments on commit 781d974

Please sign in to comment.