Fix bug in univariate Ordered and SumTo1 transform logp #6903

Merged · 2 commits · Sep 13, 2023
1 change: 1 addition & 0 deletions docs/source/api/distributions/transforms.rst
@@ -19,6 +19,7 @@ Transform instances are the entities that should be used in the
    logodds
    simplex
    sum_to_1
+   ordered


Specific Transform Classes
2 changes: 1 addition & 1 deletion pymc/distributions/discrete.py
@@ -1259,7 +1259,7 @@ class OrderedLogistic:
        # Ordered logistic regression
        with pm.Model() as model:
            cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
-                                 transform=pm.distributions.transforms.univariate_ordered)
+                                 transform=pm.distributions.transforms.ordered)
            y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
            idata = pm.sample()
2 changes: 1 addition & 1 deletion pymc/distributions/mixture.py
@@ -539,7 +539,7 @@ class NormalMixture:
                mu=data.mean(),
                sigma=10,
                shape=n_components,
-               transform=pm.distributions.transforms.univariate_ordered,
+               transform=pm.distributions.transforms.ordered,
                initval=[1, 2, 3],
            )
            σ = pm.HalfNormal("σ", sigma=10, shape=n_components)
86 changes: 33 additions & 53 deletions pymc/distributions/transforms.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
+
from functools import singledispatch

import numpy as np
@@ -39,19 +41,28 @@
    "logodds",
    "Interval",
    "log_exp_m1",
-    "univariate_ordered",
-    "multivariate_ordered",
+    "ordered",
    "log",
    "sum_to_1",
-    "univariate_sum_to_1",
-    "multivariate_sum_to_1",
    "circular",
    "CholeskyCovPacked",
    "Chain",
    "ZeroSumTransform",
]


+def __getattr__(name):
+    if name in ("univariate_ordered", "multivariate_ordered"):
+        warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning)
+        return ordered
+
+    if name in ("univariate_sum_to_1", "multivariate_sum_to_1"):
+        warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning)
+        return sum_to_1
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
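The module-level `__getattr__` (PEP 562) keeps the removed names importable while steering users toward the new ones. A minimal sketch of the intended behavior, assuming a PyMC build that includes this PR:

import warnings

import pymc as pm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = pm.distributions.transforms.univariate_ordered  # old name, resolved via __getattr__

assert t is pm.distributions.transforms.ordered
assert issubclass(caught[-1].category, FutureWarning)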


@singledispatch
def _default_transform(op: Op, rv: TensorVariable):
"""Return default transform for a given Distribution `Op`"""
@@ -79,31 +90,24 @@ def log_jac_det(self, value, *inputs):
class Ordered(RVTransform):
    name = "ordered"

-    def __init__(self, ndim_supp=0):
-        if ndim_supp > 1:
-            raise ValueError(
-                f"For Ordered transformation number of core dimensions"
-                f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
-            )
-        self.ndim_supp = ndim_supp
+    def __init__(self, ndim_supp=None):
+        if ndim_supp is not None:
+            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)

    def backward(self, value, *inputs):
        x = pt.zeros(value.shape)
-        x = pt.inc_subtensor(x[..., 0], value[..., 0])
-        x = pt.inc_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
+        x = pt.set_subtensor(x[..., 0], value[..., 0])
+        x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
        return pt.cumsum(x, axis=-1)

    def forward(self, value, *inputs):
        y = pt.zeros(value.shape)
-        y = pt.inc_subtensor(y[..., 0], value[..., 0])
-        y = pt.inc_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
+        y = pt.set_subtensor(y[..., 0], value[..., 0])
+        y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
        return y

    def log_jac_det(self, value, *inputs):
-        if self.ndim_supp == 0:
-            return pt.sum(value[..., 1:], axis=-1, keepdims=True)
-        else:
-            return pt.sum(value[..., 1:], axis=-1)
+        return pt.sum(value[..., 1:], axis=-1)
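The Ordered hunk swaps `inc_subtensor` for the equivalent (and more idiomatic) `set_subtensor` on a freshly zeroed tensor, and drops the `ndim_supp` branch from `log_jac_det`. A NumPy round-trip sketch of the same forward/backward pair (hypothetical helper names, 1-D input for simplicity):

import numpy as np

def ordered_forward(x):
    # sorted vector -> unconstrained: first element as-is, log of the increments
    return np.concatenate([x[:1], np.log(np.diff(x))])

def ordered_backward(y):
    # unconstrained -> sorted vector: cumsum of (y[0], exp(y[1]), ..., exp(y[-1]))
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))

x = np.array([-1.0, 0.5, 2.0])
y = ordered_forward(x)
assert np.allclose(ordered_backward(y), x)

# The Jacobian of backward is lower triangular with diagonal (1, exp(y[1]), exp(y[2])),
# so log|det| = y[1:].sum(): one scalar per core vector, which is exactly what the
# fixed log_jac_det returns (no keepdims).
assert np.isclose(y[1:].sum(), np.log(np.diff(x)).sum())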


class SumTo1(RVTransform):
@@ -114,13 +118,9 @@ class SumTo1(RVTransform):

    name = "sumto1"

-    def __init__(self, ndim_supp=0):
-        if ndim_supp > 1:
-            raise ValueError(
-                f"For SumTo1 transformation number of core dimensions"
-                f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
-            )
-        self.ndim_supp = ndim_supp
+    def __init__(self, ndim_supp=None):
+        if ndim_supp is not None:
+            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)

    def backward(self, value, *inputs):
        remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)
@@ -131,10 +131,7 @@ def forward(self, value, *inputs):

    def log_jac_det(self, value, *inputs):
        y = pt.zeros(value.shape)
-        if self.ndim_supp == 0:
-            return pt.sum(y, axis=-1, keepdims=True)
-        else:
-            return pt.sum(y, axis=-1)
+        return pt.sum(y, axis=-1)
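The `forward` body of SumTo1 is collapsed in the hunk above; the sketch below assumes it simply drops the last coordinate, which is consistent with the `backward` shown. Plain NumPy again, with hypothetical helper names:

import numpy as np

def sum_to_1_forward(x):
    return x[..., :-1]  # assumed: drop the redundant last coordinate

def sum_to_1_backward(y):
    remaining = 1 - y.sum(axis=-1, keepdims=True)
    return np.concatenate([y, remaining], axis=-1)

x = np.array([0.2, 0.3, 0.5])
assert np.allclose(sum_to_1_backward(sum_to_1_forward(x)), x)

# backward is affine with unit Jacobian determinant, so log_jac_det is a sum of
# zeros: a single zero scalar per core vector after the fix, rather than a
# broadcastable keepdims column.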


class CholeskyCovPacked(RVTransform):
@@ -359,38 +356,21 @@ def extend_axis_rev(array, axis):
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
for use in the ``transform`` argument of a random variable."""

-univariate_ordered = Ordered(ndim_supp=0)
-univariate_ordered.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.Ordered`
-for use in the ``transform`` argument of a univariate random variable."""
-
-multivariate_ordered = Ordered(ndim_supp=1)
-multivariate_ordered.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.Ordered`
-for use in the ``transform`` argument of a multivariate random variable."""
+# Deprecated
+ordered = Ordered()
+ordered.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.Ordered`
+for use in the ``transform`` argument of a random variable."""

log = LogTransform()
log.__doc__ = """
Instantiation of :class:`pymc.logprob.transforms.LogTransform`
for use in the ``transform`` argument of a random variable."""

-univariate_sum_to_1 = SumTo1(ndim_supp=0)
-univariate_sum_to_1.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.SumTo1`
-for use in the ``transform`` argument of a univariate random variable."""
-
-multivariate_sum_to_1 = SumTo1(ndim_supp=1)
-multivariate_sum_to_1.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.SumTo1`
-for use in the ``transform`` argument of a multivariate random variable."""
-
-# backwards compatibility
-sum_to_1 = SumTo1(ndim_supp=1)
+sum_to_1 = SumTo1()
sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
-for use in the ``transform`` argument of a random variable.
-This instantiation is for backwards compatibility only.
-Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead."""
+for use in the ``transform`` argument of a random variable."""

circular = CircularTransform()
circular.__doc__ = """
55 changes: 41 additions & 14 deletions pymc/logprob/transforms.py
@@ -195,6 +195,9 @@
        phi_inv = self.backward(value, *inputs)
        return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))

+    def __str__(self):
+        return f"{self.__class__.__name__}"
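The `__str__` override makes a transform print as its bare class name, which keeps the error messages introduced in the next hunk readable. A quick check, assuming the subclasses inherit it:

from pymc.distributions.transforms import Ordered

print(Ordered())  # "Ordered" rather than the default <... object at 0x...> repr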


@node_rewriter(tracks=None)
def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
@@ -1219,22 +1222,46 @@
    if not isinstance(logprobs, Sequence):
        logprobs = [logprobs]

-    if use_jacobian:
-        assert len(values) == len(logprobs) == len(op.transforms)
-        logprobs_jac = []
-        for value, transform, logp in zip(values, op.transforms, logprobs):
-            if transform is None:
-                logprobs_jac.append(logp)
-                continue
-            assert isinstance(value.owner.op, TransformedVariable)
-            original_forward_value = value.owner.inputs[1]
-            jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
-            if value.name:
-                jacobian.name = f"{value.name}_jacobian"
-            logprobs_jac.append(logp + jacobian)
-        logprobs = logprobs_jac
-
-    return logprobs
+    # Handle jacobian
+    assert len(values) == len(logprobs) == len(op.transforms)
+    logprobs_jac = []
+    for value, transform, logp in zip(values, op.transforms, logprobs):
+        if transform is None:
+            logprobs_jac.append(logp)
+            continue
+
+        assert isinstance(value.owner.op, TransformedVariable)
+        original_forward_value = value.owner.inputs[1]
+        log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()
+        # The jacobian determinant has fewer dims than the logp
+        # when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
+        # In this case we have to reduce the last logp dimensions, as they are no longer independent
+        if log_jac_det.ndim < logp.ndim:
+            diff_ndims = logp.ndim - log_jac_det.ndim
+            logp = logp.sum(axis=np.arange(-diff_ndims, 0))
+        # This case is sometimes, but not always, trivial to accommodate depending on the "space rank" of the
+        # multivariate distribution. See https://proceedings.mlr.press/v130/radul21a.html
+        elif log_jac_det.ndim > logp.ndim:
+            raise NotImplementedError(
+                f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
+            )
+        else:
+            # Check there is no broadcasting between logp and jacobian
+            if logp.type.broadcastable != log_jac_det.type.broadcastable:
+                raise ValueError(
+                    f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
+                    "There is a bug in the implementation of either one."
+                )
+
+        if use_jacobian:
+            if value.name:
+                log_jac_det.name = f"{value.name}_jacobian"
+            logprobs_jac.append(logp + log_jac_det)
+        else:
+            # We still want to use the reduced logp, even though the jacobian isn't included
+            logprobs_jac.append(logp)
+
+    return logprobs_jac

new_op = copy(rv_op)
new_op.__class__ = new_op_type
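From the user side, this is the heart of the bug fix: with the old keepdims behavior, the single jacobian scalar was broadcast against the elementwise logps and effectively counted once per element. A hedged sketch of the fixed behavior (values illustrative, not from the PR):

import numpy as np
import pymc as pm
from pymc.distributions.transforms import ordered

with pm.Model() as m:
    # Three i.i.d. normals constrained to be increasing. Before this fix the
    # single jacobian term was broadcast against the three elementwise logps,
    # triple-counting it; now the logps are reduced over the core dimension first.
    x = pm.Normal("x", mu=0, sigma=1, shape=3, transform=ordered,
                  initval=np.array([-1.0, 0.0, 1.0]))

print(m.compile_logp()(m.initial_point()))  # jacobian now counted exactly once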