Skip to content

Commit

Permalink
Do not make transforms module accessible at root level
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 18, 2022
1 parent f12edd3 commit 3a184e7
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/distributions/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Transformations
***************

.. currentmodule:: pymc.transforms
.. currentmodule:: pymc.distributions.transforms

Transform Instances
~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 0 additions & 1 deletion pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __set_compiler_flags():
from pymc.blocking import *
from pymc.data import *
from pymc.distributions import *
from pymc.distributions import transforms
from pymc.exceptions import *
from pymc.func_utils import find_constrained_prior
from pymc.math import (
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2710,7 +2710,7 @@ def test_arguments_checks(self):
with pm.Model() as m:
x = pm.Poisson.dist(0.5)
with pytest.raises(ValueError, match=msg):
pm.Bound("bound", x, transform=pm.transforms.log)
pm.Bound("bound", x, transform=pm.distributions.transforms.log)

msg = "Given dims do not exist in model coordinates."
with pm.Model() as m:
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,11 @@ def test_deterministic_of_unobserved(self):

np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100)

def test_transform_with_rv_depenency(self):
def test_transform_with_rv_dependency(self):
# Test that untransformed variables that depend on upstream variables are properly handled
with pm.Model() as m:
x = pm.HalfNormal("x", observed=1)
transform = pm.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1]))
transform = pm.distributions.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1]))
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)

Expand Down

0 comments on commit 3a184e7

Please sign in to comment.