diff --git a/docs/source/api/distributions/transforms.rst b/docs/source/api/distributions/transforms.rst index 1a11a71b62..ca4756688c 100644 --- a/docs/source/api/distributions/transforms.rst +++ b/docs/source/api/distributions/transforms.rst @@ -19,6 +19,7 @@ Transform instances are the entities that should be used in the logodds simplex sum_to_1 + ordered Specific Transform Classes diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index c41f727cab..d4c1e4585a 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1259,7 +1259,7 @@ class OrderedLogistic: # Ordered logistic regression with pm.Model() as model: cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2, - transform=pm.distributions.transforms.univariate_ordered) + transform=pm.distributions.transforms.ordered) y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y) idata = pm.sample() diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 5864dcb289..7c2c6f9752 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -539,7 +539,7 @@ class NormalMixture: mu=data.mean(), sigma=10, shape=n_components, - transform=pm.distributions.transforms.univariate_ordered, + transform=pm.distributions.transforms.ordered, initval=[1, 2, 3], ) σ = pm.HalfNormal("σ", sigma=10, shape=n_components) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index f3aff75fed..2b828d5c6d 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings
+
 from functools import singledispatch

 import numpy as np
@@ -39,12 +41,9 @@
     "logodds",
     "Interval",
     "log_exp_m1",
-    "univariate_ordered",
-    "multivariate_ordered",
+    "ordered",
     "log",
     "sum_to_1",
-    "univariate_sum_to_1",
-    "multivariate_sum_to_1",
     "circular",
     "CholeskyCovPacked",
     "Chain",
@@ -52,6 +51,18 @@
 ]


+def __getattr__(name):
+    if name in ("univariate_ordered", "multivariate_ordered"):
+        warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning)
+        return ordered
+
+    if name in ("univariate_sum_to_1", "multivariate_sum_to_1"):
+        warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning)
+        return sum_to_1
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
 @singledispatch
 def _default_transform(op: Op, rv: TensorVariable):
     """Return default transform for a given Distribution `Op`"""
@@ -79,31 +90,24 @@ def log_jac_det(self, value, *inputs):
 class Ordered(RVTransform):
     name = "ordered"

-    def __init__(self, ndim_supp=0):
-        if ndim_supp > 1:
-            raise ValueError(
-                f"For Ordered transformation number of core dimensions"
-                f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
-            )
-        self.ndim_supp = ndim_supp
+    def __init__(self, ndim_supp=None):
+        if ndim_supp is not None:
+            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)

     def backward(self, value, *inputs):
         x = pt.zeros(value.shape)
-        x = pt.inc_subtensor(x[..., 0], value[..., 0])
-        x = pt.inc_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
+        x = pt.set_subtensor(x[..., 0], value[..., 0])
+        x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
         return pt.cumsum(x, axis=-1)

     def forward(self, value, *inputs):
         y = pt.zeros(value.shape)
-        y = pt.inc_subtensor(y[..., 0], value[..., 0])
-        y = pt.inc_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
+        y = pt.set_subtensor(y[..., 0], value[..., 0])
+        y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
         return y

     def log_jac_det(self, value, *inputs):
-        if self.ndim_supp == 0:
-            return pt.sum(value[..., 1:], axis=-1, keepdims=True)
-        else:
-            return pt.sum(value[..., 1:], axis=-1)
+        return pt.sum(value[..., 1:], axis=-1)


 class SumTo1(RVTransform):
@@ -114,13 +118,9 @@ class SumTo1(RVTransform):

     name = "sumto1"

-    def __init__(self, ndim_supp=0):
-        if ndim_supp > 1:
-            raise ValueError(
-                f"For SumTo1 transformation number of core dimensions"
-                f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
-            )
-        self.ndim_supp = ndim_supp
+    def __init__(self, ndim_supp=None):
+        if ndim_supp is not None:
+            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)

     def backward(self, value, *inputs):
         remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)
@@ -131,10 +131,7 @@ def forward(self, value, *inputs):

     def log_jac_det(self, value, *inputs):
         y = pt.zeros(value.shape)
-        if self.ndim_supp == 0:
-            return pt.sum(y, axis=-1, keepdims=True)
-        else:
-            return pt.sum(y, axis=-1)
+        return pt.sum(y, axis=-1)


 class CholeskyCovPacked(RVTransform):
@@ -359,38 +356,21 @@ def extend_axis_rev(array, axis):
 Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
 for use in the ``transform`` argument of a random variable."""

-univariate_ordered = Ordered(ndim_supp=0)
-univariate_ordered.__doc__ = """
+# Replaces the deprecated univariate_ordered and multivariate_ordered
+ordered = Ordered()
+ordered.__doc__ = """
 Instantiation of :class:`pymc.distributions.transforms.Ordered`
-for use in the ``transform`` argument of a univariate random variable."""
-
-multivariate_ordered = Ordered(ndim_supp=1)
-multivariate_ordered.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.Ordered`
-for use in the ``transform`` argument of a multivariate random variable."""
+for use in the ``transform`` argument of a random variable."""

 log = LogTransform()
 log.__doc__ = """
 Instantiation of :class:`pymc.logprob.transforms.LogTransform`
 for use in the ``transform`` argument of a random variable."""

-univariate_sum_to_1 = SumTo1(ndim_supp=0)
-univariate_sum_to_1.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.SumTo1`
-for use in the ``transform`` argument of a univariate random variable."""
-
-multivariate_sum_to_1 = SumTo1(ndim_supp=1)
-multivariate_sum_to_1.__doc__ = """
-Instantiation of :class:`pymc.distributions.transforms.SumTo1`
-for use in the ``transform`` argument of a multivariate random variable."""
-
-# backwards compatibility
-sum_to_1 = SumTo1(ndim_supp=1)
+sum_to_1 = SumTo1()
 sum_to_1.__doc__ = """
 Instantiation of :class:`pymc.distributions.transforms.SumTo1`
-for use in the ``transform`` argument of a random variable.
-This instantiation is for backwards compatibility only.
-Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead."""
+for use in the ``transform`` argument of a random variable."""

 circular = CircularTransform()
 circular.__doc__ = """
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index b5b528ea44..30e2318dcc 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -195,6 +195,9 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
         phi_inv = self.backward(value, *inputs)
         return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))

+    def __str__(self):
+        return self.__class__.__name__
+

 @node_rewriter(tracks=None)
 def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
@@ -1219,22 +1222,46 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
         if not isinstance(logprobs, Sequence):
             logprobs = [logprobs]

-        if use_jacobian:
-            assert len(values) == len(logprobs) == len(op.transforms)
-            logprobs_jac = []
-            for value, transform, logp in zip(values, op.transforms, logprobs):
-                if transform is None:
-                    logprobs_jac.append(logp)
-                    continue
-                assert isinstance(value.owner.op, TransformedVariable)
-                original_forward_value = value.owner.inputs[1]
-                jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
+        # Handle jacobian
+        assert len(values) == len(logprobs) == len(op.transforms)
+        logprobs_jac = []
+        for value, transform, logp in zip(values, op.transforms, logprobs):
+            if transform is None:
+                logprobs_jac.append(logp)
+                continue
+
+            assert isinstance(value.owner.op, TransformedVariable)
+            original_forward_value = value.owner.inputs[1]
+            log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()
+            # The jacobian determinant has fewer dims than the logp
+            # when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
+            # In this case we have to reduce the last logp dimensions, as they are no longer independent
+            if log_jac_det.ndim < logp.ndim:
+                diff_ndims = logp.ndim - log_jac_det.ndim
+                logp = logp.sum(axis=np.arange(-diff_ndims, 0))
+            # This case is sometimes, but not always, trivial to accommodate depending on the "space rank" of the
+            # multivariate distribution.
See https://proceedings.mlr.press/v130/radul21a.html + elif log_jac_det.ndim > logp.ndim: + raise NotImplementedError( + f"Univariate transform {transform} cannot be applied to multivariate {rv_op}" + ) + else: + # Check there is no broadcasting between logp and jacobian + if logp.type.broadcastable != log_jac_det.type.broadcastable: + raise ValueError( + f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " + "There is a bug in the implementation of either one." + ) + + if use_jacobian: if value.name: - jacobian.name = f"{value.name}_jacobian" - logprobs_jac.append(logp + jacobian) - logprobs = logprobs_jac + log_jac_det.name = f"{value.name}_jacobian" + logprobs_jac.append(logp + log_jac_det) + else: + # We still want to use the reduced logp, even though the jacobian isn't included + logprobs_jac.append(logp) - return logprobs + return logprobs_jac new_op = copy(rv_op) new_op.__class__ = new_op_type diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 9912e0c9f6..4309b0e5e3 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -26,6 +26,7 @@ import pymc.distributions.transforms as tr from pymc.logprob.basic import transformed_conditional_logp +from pymc.logprob.transforms import RVTransform from pymc.pytensorf import floatX, jacobian from pymc.testing import ( Circ, @@ -149,14 +150,14 @@ def test_simplex_accuracy(): def test_sum_to_1(): - check_vector_transform(tr.univariate_sum_to_1, Simplex(2)) - check_vector_transform(tr.univariate_sum_to_1, Simplex(4)) + check_vector_transform(tr.sum_to_1, Simplex(2)) + check_vector_transform(tr.sum_to_1, Simplex(4)) - with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"): + with pytest.warns(FutureWarning, match="ndim_supp argument is deprecated"): tr.SumTo1(2) check_jacobian_det( - tr.univariate_sum_to_1, + tr.sum_to_1, Vector(Unit, 2), pt.vector, floatX(np.array([0, 0])), @@ -270,36 +271,33 @@ def test_circular(): def test_ordered(): - check_vector_transform(tr.univariate_ordered, SortedVector(6)) + check_vector_transform(tr.ordered, SortedVector(6)) - with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"): - tr.Ordered(2) + with pytest.warns(FutureWarning, match="ndim_supp argument is deprecated"): + tr.Ordered(1) check_jacobian_det( - tr.univariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False - ) - check_jacobian_det( - tr.multivariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False + tr.ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False ) - vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3))) + vals = get_values(tr.ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3))) close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_values(): - chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.ordered]) vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5))) close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_vector_transform(): - chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.ordered]) check_vector_transform(chain_tranf, UnitSortedVector(3)) @pytest.mark.xfail(reason="Fails due to precision issue. 
Values just close to expected.") def test_chain_jacob_det(): - chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) + chain_tranf = tr.Chain([tr.logodds, tr.ordered]) check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False) @@ -311,90 +309,44 @@ def build_model(self, distfam, params, size, transform, initval=None): distfam("x", size=size, transform=transform, initval=initval, **params) return m - def check_transform_elementwise_logp(self, model): + def check_transform_elementwise_logp(self, model, vector_transform=False): x = model.free_RVs[0] x_val_transf = model.rvs_to_values[x] + transform = model.rvs_to_transforms[x] + x_val_untransf = transform.backward(x_val_transf, *x.owner.inputs) point = model.initial_point(0) test_array_transf = floatX(np.random.randn(*point[x_val_transf.name].shape)) - transform = model.rvs_to_transforms[x] - test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() - - # Create input variable with same dimensionality as untransformed test_array - x_val_untransf = pt.constant(test_array_untransf).type() - - jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) - assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim - - v1 = ( - transformed_conditional_logp( - (x,), - rvs_to_values={x: x_val_transf}, - rvs_to_transforms={x: transform}, - jacobian=False, - )[0] - .sum() - .eval({x_val_transf: test_array_transf}) + test_array_untransf = x_val_untransf.eval({x_val_transf: test_array_transf}) + log_jac_det = transform.log_jac_det(x_val_transf, *x.owner.inputs) + + [transform_logp] = transformed_conditional_logp( + (x,), + rvs_to_values={x: x_val_transf}, + rvs_to_transforms={x: transform}, ) - v2 = ( - transformed_conditional_logp( - (x,), - rvs_to_values={x: x_val_untransf}, - rvs_to_transforms={}, - )[0] - .sum() - .eval({x_val_untransf: test_array_untransf}) + [untransform_logp] = transformed_conditional_logp( + (x,), + rvs_to_values={x: x_val_untransf}, + rvs_to_transforms={}, ) - close_to(v1, v2, tol) - - def check_vectortransform_elementwise_logp(self, model): - x = model.free_RVs[0] - x_val_transf = model.rvs_to_values[x] - - point = model.initial_point(0) - test_array_transf = floatX(np.random.randn(*point[x_val_transf.name].shape)) - transform = model.rvs_to_transforms[x] - test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() - - # Create input variable with same dimensionality as untransformed test_array - x_val_untransf = pt.constant(test_array_untransf).type() - - jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) - # Original distribution is univariate - if x.owner.op.ndim_supp == 0: - tr_steps = getattr(transform, "transform_list", [transform]) - transform_keeps_dim = any( - [isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps] - ) - if transform_keeps_dim: - assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim - else: - assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1) - # Original distribution is multivariate + if vector_transform: + assert transform_logp.ndim == (x.ndim - 1) == log_jac_det.ndim else: - assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim - - a = ( - transformed_conditional_logp( - (x,), - rvs_to_values={x: x_val_transf}, - rvs_to_transforms={x: transform}, - jacobian=False, - )[0] - .sum() - .eval({x_val_transf: test_array_transf}) - ) - b = ( - transformed_conditional_logp( - (x,), - rvs_to_values={x: 
x_val_untransf}, - rvs_to_transforms={}, - )[0] - .sum() - .eval({x_val_untransf: test_array_untransf}) + assert transform_logp.ndim == x.ndim == log_jac_det.ndim + + transform_logp_eval = transform_logp.eval({x_val_transf: test_array_transf}) + untransform_logp_eval = untransform_logp.eval({x_val_untransf: test_array_untransf}) + log_jac_det_eval = log_jac_det.eval({x_val_transf: test_array_transf}) + # Summing the log_jac_det separately from the untransform_logp ensures there is no broadcasting between terms + np.testing.assert_allclose( + transform_logp_eval.sum(), + untransform_logp_eval.sum() + log_jac_det_eval.sum(), + rtol=tol, ) - # Hack to get relative tolerance - close_to(a, b, np.abs(0.5 * (a + b) * tol)) + + def check_vectortransform_elementwise_logp(self, model): + self.check_transform_elementwise_logp(model, vector_transform=True) @pytest.mark.parametrize( "sigma,size", @@ -490,7 +442,7 @@ def test_normal_ordered(self): {"mu": 0.0, "sigma": 1.0}, size=3, initval=np.asarray([-1.0, 1.0, 4.0]), - transform=tr.univariate_ordered, + transform=tr.ordered, ) self.check_vectortransform_elementwise_logp(model) @@ -508,7 +460,7 @@ def test_half_normal_ordered(self, sigma, size): {"sigma": sigma}, size=size, initval=initval, - transform=tr.Chain([tr.log, tr.univariate_ordered]), + transform=tr.Chain([tr.log, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -520,7 +472,7 @@ def test_exponential_ordered(self, lam, size): {"lam": lam}, size=size, initval=initval, - transform=tr.Chain([tr.log, tr.univariate_ordered]), + transform=tr.Chain([tr.log, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -542,7 +494,7 @@ def test_beta_ordered(self, a, b, size): {"alpha": a, "beta": b}, size=size, initval=initval, - transform=tr.Chain([tr.logodds, tr.univariate_ordered]), + transform=tr.Chain([tr.logodds, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -565,7 +517,7 @@ def transform_params(*inputs): {"lower": lower, "upper": upper}, size=size, initval=initval, - transform=tr.Chain([interval, tr.univariate_ordered]), + transform=tr.Chain([interval, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -579,7 +531,7 @@ def test_vonmises_ordered(self, mu, kappa, size): {"mu": mu, "kappa": kappa}, size=size, initval=initval, - transform=tr.Chain([tr.circular, tr.univariate_ordered]), + transform=tr.Chain([tr.circular, tr.ordered]), ) self.check_vectortransform_elementwise_logp(model) @@ -592,7 +544,7 @@ def test_vonmises_ordered(self, mu, kappa, size): floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3), - tr.Chain([tr.univariate_sum_to_1, tr.logodds]), + tr.Chain([tr.sum_to_1, tr.logodds]), ), ], ) @@ -614,14 +566,15 @@ def test_uniform_other(self, lower, upper, size, transform): (floatX(np.zeros(3)), floatX(np.diag(np.ones(3))), (4,), (4, 3)), ], ) - def test_mvnormal_ordered(self, mu, cov, size, shape): + @pytest.mark.parametrize("transform", (tr.ordered, tr.sum_to_1)) + def test_mvnormal_transform(self, mu, cov, size, shape, transform): initval = np.sort(np.random.randn(*shape)) model = self.build_model( pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, - transform=tr.multivariate_ordered, + transform=transform, ) self.check_vectortransform_elementwise_logp(model) @@ -652,93 +605,73 @@ def test_discrete_trafo(): err.match("Transformations for discrete distributions") -def test_2d_univariate_ordered(): - with pm.Model() as model: - x_1d = pm.Normal( - "x_1d", - mu=[-3, -1, 1, 2], - sigma=1, - size=(4,), - 
transform=tr.univariate_ordered, - ) - x_2d = pm.Normal( - "x_2d", - mu=[-3, -1, 1, 2], - sigma=1, - size=(10, 4), - transform=tr.univariate_ordered, - ) +def test_transform_univariate_dist_logp_shape(): + with pm.Model() as m: + pm.Uniform("x", shape=(4, 3), transform=tr.logodds) - log_p = model.compile_logp(sum=False)( - {"x_1d_ordered__": floatX(np.zeros((4,))), "x_2d_ordered__": floatX(np.zeros((10, 4)))} - ) - np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1]) + assert m.logp(jacobian=False, sum=False)[0].type.shape == (4, 3) + assert m.logp(jacobian=True, sum=False)[0].type.shape == (4, 3) + with pm.Model() as m: + pm.Uniform("x", shape=(4, 3), transform=tr.ordered) -def test_2d_multivariate_ordered(): - with pm.Model() as model: - x_1d = pm.MvNormal( - "x_1d", - mu=[-1, 1], - cov=np.eye(2), - initval=[-1, 1], - transform=tr.multivariate_ordered, - ) - x_2d = pm.MvNormal( - "x_2d", - mu=[-1, 1], - cov=np.eye(2), - size=2, - initval=[[-1, 1], [-1, 1]], - transform=tr.multivariate_ordered, - ) + assert m.logp(jacobian=False, sum=False)[0].type.shape == (4,) + assert m.logp(jacobian=True, sum=False)[0].type.shape == (4,) - log_p = model.compile_logp(sum=False)( - {"x_1d_ordered__": floatX(np.zeros((2,))), "x_2d_ordered__": floatX(np.zeros((2, 2)))} - ) - np.testing.assert_allclose(log_p[0], log_p[1]) +def test_univariate_transform_multivariate_dist_raises(): + with pm.Model() as m: + pm.Dirichlet("x", [1, 1, 1], transform=tr.log) -def test_2d_univariate_sum_to_1(): - with pm.Model() as model: - x_1d = pm.Normal( - "x_1d", - mu=[-3, -1, 1, 2], - sigma=1, - size=(4,), - transform=tr.univariate_sum_to_1, - ) - x_2d = pm.Normal( - "x_2d", - mu=[-3, -1, 1, 2], - sigma=1, - size=(10, 4), - transform=tr.univariate_sum_to_1, - ) + for jacobian in (True, False): + with pytest.raises( + NotImplementedError, + match="Univariate transform LogTransform cannot be applied to multivariate", + ): + m.logp(jacobian=jacobian) - log_p = model.compile_logp(sum=False)( - {"x_1d_sumto1__": floatX(np.zeros(3)), "x_2d_sumto1__": floatX(np.zeros((10, 3)))} - ) - np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1]) +def test_invalid_jacobian_broadcast_raises(): + class BuggyTransform(RVTransform): + name = "buggy" -def test_2d_multivariate_sum_to_1(): - with pm.Model() as model: - x_1d = pm.MvNormal( - "x_1d", - mu=[-1, 1], - cov=np.eye(2), - transform=tr.multivariate_sum_to_1, - ) - x_2d = pm.MvNormal( - "x_2d", - mu=[-1, 1], - cov=np.eye(2), - size=2, - transform=tr.multivariate_sum_to_1, - ) + def forward(self, value, *inputs): + return value - log_p = model.compile_logp(sum=False)( - {"x_1d_sumto1__": floatX(np.zeros(1)), "x_2d_sumto1__": floatX(np.zeros((2, 1)))} - ) - np.testing.assert_allclose(log_p[0], log_p[1]) + def backward(self, value, *inputs): + return value + + def log_jac_det(self, value, *inputs): + return pt.zeros_like(value.sum(-1, keepdims=True)) + + buggy_transform = BuggyTransform() + + with pm.Model() as m: + pm.Uniform("x", shape=(4, 3), transform=buggy_transform) + + for jacobian in (True, False): + with pytest.raises( + ValueError, + match="are not allowed to broadcast together. 
There is a bug in the implementation of either one", + ): + m.logp(jacobian=jacobian) + + +def test_deprecated_ndim_supp_transforms(): + with pytest.warns(FutureWarning, match="deprecated"): + tr.Ordered(ndim_supp=1) + + with pytest.warns(FutureWarning, match="deprecated"): + assert tr.univariate_ordered == tr.ordered + + with pytest.warns(FutureWarning, match="deprecated"): + assert tr.multivariate_ordered == tr.ordered + + with pytest.warns(FutureWarning, match="deprecated"): + tr.SumTo1(ndim_supp=1) + + with pytest.warns(FutureWarning, match="deprecated"): + assert tr.univariate_sum_to_1 == tr.sum_to_1 + + with pytest.warns(FutureWarning, match="deprecated"): + assert tr.multivariate_sum_to_1 == tr.sum_to_1
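The module-level `__getattr__` added to `pymc/distributions/transforms.py` is the PEP 562 deprecation-shim pattern: the old names no longer exist as module attributes, so attribute lookup falls through to `__getattr__`, which warns and hands back the unified instance. A minimal sketch of the intended downstream behavior, assuming this diff is applied (the sketch itself is not part of the patch):

```python
import warnings

import pymc.distributions.transforms as tr

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = tr.univariate_ordered  # resolved via the module-level __getattr__

assert legacy is tr.ordered  # the shim returns the unified instance
assert issubclass(caught[0].category, FutureWarning)
```

Because the shim only runs on failed attribute lookups, code that already uses `tr.ordered` pays no extra cost.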
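For intuition on the unified `Ordered` bijection, here is a plain-NumPy mirror of `forward`/`backward` (a review-time sketch, not the PyTensor implementation): `backward` keeps the first entry and cumulatively sums exponentiated increments, `forward` inverts that, and since the Jacobian of `backward` is lower triangular with diagonal `(1, exp(y_1), ..., exp(y_{n-1}))`, the log-determinant is just the sum of the free increments. It always reduces exactly one core dimension, which is what lets a single instance replace the old univariate/multivariate pair:

```python
import numpy as np

def ordered_backward(value):
    # first entry unchanged; the rest become strictly positive increments
    x = np.concatenate([value[..., :1], np.exp(value[..., 1:])], axis=-1)
    return np.cumsum(x, axis=-1)  # strictly increasing along the last axis

def ordered_forward(value):
    # inverse map: first entry, then log of consecutive differences
    return np.concatenate([value[..., :1], np.log(np.diff(value, axis=-1))], axis=-1)

free = np.random.randn(10, 4)  # a batch of unconstrained draws
ordered_vals = ordered_backward(free)
assert np.all(np.diff(ordered_vals, axis=-1) > 0)
np.testing.assert_allclose(ordered_forward(ordered_vals), free)

# log|det J|: the triangular Jacobian's diagonal is (1, exp(y_1), ...),
# so the log-determinant sums the free increments over the core axis
log_jac_det = free[..., 1:].sum(axis=-1)
assert log_jac_det.shape == (10,)  # one dimension fewer than the input
```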
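That one-dimension reduction is what the new branch in `transformed_logprob` accommodates: when `log_jac_det.ndim < logp.ndim`, the trailing logp dimensions are no longer independent and are summed away. A short end-to-end check through the public API, hedged in that it assumes the post-PR behavior exercised by `test_transform_univariate_dist_logp_shape` above:

```python
import pymc as pm
import pymc.distributions.transforms as tr

with pm.Model() as m:
    # ten independent ordered length-4 vectors built from univariate normals
    pm.Normal("x", mu=[-3, -1, 1, 2], size=(10, 4), transform=tr.ordered)

point = m.initial_point()
[logp_val] = m.compile_logp(sum=False)(point)
# the (10, 4) element-wise logp is reduced over the coupled last axis
assert logp_val.shape == (10,)
```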