Skip to content

Commit

Permalink
Rename logp_transform to _get_default_transform and move it to tr…
Browse files Browse the repository at this point in the history
…ansforms.py
  • Loading branch information
ricardoV94 committed Mar 18, 2022
1 parent e8c07ef commit 8b9987a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 15 deletions.
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from pymc.distributions.logprob import ( # isort:skip
logcdf,
logp,
logp_transform,
joint_logpt,
)

Expand Down Expand Up @@ -195,6 +194,5 @@
"PolyaGamma",
"joint_logpt",
"logp",
"logp_transform",
"logcdf",
]
9 changes: 5 additions & 4 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def polyagamma_cdf(*args, **kwargs):
from scipy.special import expit

from pymc.aesaraf import floatX
from pymc.distributions import logp_transform, transforms
from pymc.distributions import transforms
from pymc.distributions.dist_math import (
SplineWrapper,
check_parameters,
Expand All @@ -87,6 +87,7 @@ def polyagamma_cdf(*args, **kwargs):
)
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.transforms import _get_default_transform
from pymc.math import invlogit, logdiffexp, logit
from pymc.util import UNSET

Expand Down Expand Up @@ -139,17 +140,17 @@ class CircularContinuous(Continuous):
"""Base class for circular continuous distributions"""


@logp_transform.register(PositiveContinuous)
@_get_default_transform.register(PositiveContinuous)
def pos_cont_transform(op):
return transforms.log


@logp_transform.register(UnitContinuous)
@_get_default_transform.register(UnitContinuous)
def unit_cont_transform(op):
return transforms.logodds


@logp_transform.register(CircularContinuous)
@_get_default_transform.register(CircularContinuous)
def circ_cont_transform(op):
return transforms.circular

Expand Down
7 changes: 0 additions & 7 deletions pymc/distributions/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from collections.abc import Mapping
from functools import singledispatch
from typing import Dict, List, Optional, Sequence, Union

import aesara
Expand All @@ -25,7 +24,6 @@
from aeppl.logprob import logprob as logp_aeppl
from aeppl.transforms import TransformValuesOpt
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.graph.op import Op
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand All @@ -39,11 +37,6 @@
from pymc.aesaraf import floatX


@singledispatch
def logp_transform(op: Op):
return None


def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int):
"""
Gets scaling constant for logp.
Expand Down
8 changes: 8 additions & 0 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch

import aesara.tensor as at

Expand All @@ -22,6 +23,7 @@
RVTransform,
Simplex,
)
from aesara.graph import Op

__all__ = [
"RVTransform",
Expand All @@ -38,6 +40,12 @@
]


@singledispatch
def _get_default_transform(op: Op):
"""Return default transform for a given Distribution `Op`"""
return None


class LogExpM1(RVTransform):
name = "log_exp_m1"

Expand Down
5 changes: 3 additions & 2 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@
)
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import GenTensorVariable, Minibatch
from pymc.distributions import joint_logpt, logp_transform
from pymc.distributions import joint_logpt
from pymc.distributions.logprob import _get_scaling
from pymc.distributions.transforms import _get_default_transform
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
from pymc.initial_point import make_initial_point_fn
from pymc.math import flatten_list
Expand Down Expand Up @@ -1419,7 +1420,7 @@ def create_value_var(
# Make the value variable a transformed value variable,
# if there's an applicable transform
if transform is UNSET and rv_var.owner:
transform = logp_transform(rv_var.owner.op)
transform = _get_default_transform(rv_var.owner.op)

if transform is not None and transform is not UNSET:
value_var.tag.transform = transform
Expand Down

0 comments on commit 8b9987a

Please sign in to comment.