diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index 1ff2c66e247..19864a52817 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]:
     """If there are observations available, return them as a dictionary."""
     observations = {}
     for obs in model.observed_RVs:
-        aux_obs = getattr(obs.tag, "observations", None)
+        aux_obs = model.rvs_to_values.get(obs, None)
         if aux_obs is not None:
             try:
                 obs_data = extract_obs_data(aux_obs)
@@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):
 
         if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
             try:
-                obs_data = extract_obs_data(var.tag.observations)
+                obs_data = extract_obs_data(self.model.rvs_to_values[var])
             except TypeError:
                 warnings.warn(f"Could not extract data from symbolic observation {var}")
 
diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py
index 5ada4d67ce1..0f362b3e676 100644
--- a/pymc/distributions/__init__.py
+++ b/pymc/distributions/__init__.py
@@ -16,7 +16,6 @@
     logcdf,
     logp,
     joint_logp,
-    joint_logpt,
 )
 
 from pymc.distributions.bound import Bound
@@ -199,7 +198,6 @@
     "Censored",
     "CAR",
     "PolyaGamma",
-    "joint_logpt",
     "joint_logp",
     "logp",
     "logcdf",
diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py
index 95c31a7c93c..b5d4e2642c9 100644
--- a/pymc/distributions/logprob.py
+++ b/pymc/distributions/logprob.py
@@ -25,7 +25,7 @@
 from aeppl.logprob import logcdf as logcdf_aeppl
 from aeppl.logprob import logprob as logp_aeppl
 from aeppl.tensor import MeasurableJoin
-from aeppl.transforms import TransformValuesRewrite
+from aeppl.transforms import RVTransform, TransformValuesRewrite
 from aesara import tensor as at
 from aesara.graph.basic import graph_inputs, io_toposort
 from aesara.tensor.random.op import RandomVariable
@@ -122,12 +122,26 @@ def _get_scaling(
     )
 
 
-def joint_logpt(*args, **kwargs):
-    warnings.warn(
-        "joint_logpt has been deprecated. Use joint_logp instead.",
-        FutureWarning,
-    )
-    return joint_logp(*args, **kwargs)
+def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
+    # Raise if there are unexpected RandomVariables in the logp graph
+    # Only SimulatorRVs are allowed
+    from pymc.distributions.simulator import SimulatorRV
+
+    unexpected_rv_nodes = [
+        node
+        for node in aesara.graph.ancestors(logp_terms)
+        if (
+            node.owner
+            and isinstance(node.owner.op, RandomVariable)
+            and not isinstance(node.owner.op, SimulatorRV)
+        )
+    ]
+    if unexpected_rv_nodes:
+        raise ValueError(
+            f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
+            "This can happen when DensityDist logp or Interval transform functions "
+            "reference nonlocal variables."
+        )
 
 
 def joint_logp(
@@ -169,6 +183,10 @@ def joint_logp(
         Sum the log-likelihood or return each term as a separate list item.
 
     """
+    warnings.warn(
+        "joint_logp has been deprecated, use model.logp instead",
+        FutureWarning,
+    )
     # TODO: In future when we drop support for tag.value_var most of the following
     # logic can be removed and logp can just be a wrapper function that calls aeppl's
     # joint_logprob directly.
@@ -241,26 +259,6 @@ def joint_logp(
         **kwargs,
     )
 
-    # Raise if there are unexpected RandomVariables in the logp graph
-    # Only SimulatorRVs are allowed
-    from pymc.distributions.simulator import SimulatorRV
-
-    unexpected_rv_nodes = [
-        node
-        for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
-        if (
-            node.owner
-            and isinstance(node.owner.op, RandomVariable)
-            and not isinstance(node.owner.op, SimulatorRV)
-        )
-    ]
-    if unexpected_rv_nodes:
-        raise ValueError(
-            f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
-            "This can happen when DensityDist logp or Interval transform functions "
-            "reference nonlocal variables."
-        )
-
     # aeppl returns the logp for every single value term we provided to it. This includes
     # the extra values we plugged in above, so we filter those we actually wanted in the
     # same order they were given in.
@@ -268,6 +266,8 @@ def joint_logp(
     for value_var in rv_values.values():
         logp_var_dict[value_var] = temp_logp_var_dict[value_var]
 
+    _check_no_rvs(list(logp_var_dict.values()))
+
     if scaling:
         for value_var in logp_var_dict.keys():
             if value_var in rv_scalings:
@@ -281,6 +281,52 @@ def joint_logp(
 
     return logp_var
 
 
+def _joint_logp(
+    rvs: Sequence[TensorVariable],
+    *,
+    rvs_to_values: Dict[TensorVariable, TensorVariable],
+    rvs_to_transforms: Optional[Dict[TensorVariable, RVTransform]] = None,
+    jacobian: bool = True,
+    rvs_to_total_size: Optional[Dict[TensorVariable, TensorVariable]] = None,
+    sum: bool = True,
+    **kwargs,
+) -> Union[TensorVariable, List[TensorVariable]]:
+
+    transform_rewrite = None
+    if rvs_to_transforms:
+        values_to_transforms = {
+            rvs_to_values[rv]: transform for rv, transform in rvs_to_transforms.items()
+        }
+        transform_rewrite = TransformValuesRewrite(values_to_transforms)
+
+    temp_logp_terms = factorized_joint_logprob(
+        rvs_to_values,
+        extra_rewrites=transform_rewrite,
+        use_jacobian=jacobian,
+        **kwargs,
+    )
+
+    # aeppl returns the logp for every single value term in `rvs_to_values`. This may
+    # include terms for rvs we were not asked about, so we filter the requested ones
+    # in the same order they were given in.
+    logp_terms = {}
+    for rv in rvs:
+        value_var = rvs_to_values[rv]
+        logp_term = temp_logp_terms[value_var]
+        total_size = rvs_to_total_size.get(rv, None) if rvs_to_total_size else None
+        if total_size:
+            scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
+            logp_term *= scaling
+        logp_terms[value_var] = logp_term
+
+    _check_no_rvs(list(logp_terms.values()))
+
+    if sum:
+        return at.sum([at.sum(factor) for factor in logp_terms.values()])
+    else:
+        return list(logp_terms.values())
+
+
 def logp(rv: TensorVariable, value) -> TensorVariable:
     """Return the log-probability graph of a Random Variable"""
diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index 7b09f856b16..9c54833a305 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -20,6 +20,7 @@
 import aesara.tensor as at
 import numpy as np
 
+from aeppl.transforms import RVTransform
 from aesara.graph.basic import Variable
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor.var import TensorVariable
@@ -43,9 +44,7 @@ def convert_str_to_rv_dict(
         if isinstance(key, str):
             if is_transformed_name(key):
                 rv = model[get_untransformed_name(key)]
-                initvals[rv] = model.rvs_to_values[rv].tag.transform.backward(
-                    initval, *rv.owner.inputs
-                )
+                initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
             else:
                 initvals[model[key]] = initval
         else:
@@ -158,7 +157,7 @@ def make_initial_point_fn(
 
     initial_values = make_initial_point_expression(
         free_rvs=model.free_RVs,
-        rvs_to_values=model.rvs_to_values,
+        rvs_to_transforms=model.rvs_to_transforms,
         initval_strategies=initval_strats,
         jitter_rvs=jitter_rvs,
         default_strategy=default_strategy,
@@ -172,7 +171,7 @@ def make_initial_point_fn(
 
     varnames = []
     for var in model.free_RVs:
-        transform = getattr(model.rvs_to_values[var].tag, "transform", None)
+        transform = model.rvs_to_transforms.get(var, None)
         if transform is not None and return_transformed:
             name = get_transformed_name(var.name, transform)
         else:
@@ -197,7 +196,7 @@ def inner(seed, *args, **kwargs):
 def make_initial_point_expression(
     *,
     free_rvs: Sequence[TensorVariable],
-    rvs_to_values: Dict[TensorVariable, TensorVariable],
+    rvs_to_transforms: Dict[TensorVariable, RVTransform],
     initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
     jitter_rvs: Set[TensorVariable] = None,
     default_strategy: str = "moment",
@@ -265,7 +264,7 @@ def make_initial_point_expression(
         else:
             value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)
 
-        transform = getattr(rvs_to_values[variable].tag, "transform", None)
+        transform = rvs_to_transforms.get(variable, None)
 
         if transform is not None:
             value = transform.forward(value, *variable.owner.inputs)
diff --git a/pymc/model.py b/pymc/model.py
index 2c640b1a699..2e9bc0e3959 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -60,8 +60,7 @@
 )
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import GenTensorVariable, Minibatch
-from pymc.distributions import joint_logp
-from pymc.distributions.logprob import _get_scaling
+from pymc.distributions.logprob import _joint_logp
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
 from pymc.initial_point import make_initial_point_fn
@@ -554,6 +553,8 @@ def __init__(
             self.named_vars = treedict(parent=self.parent.named_vars)
             self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
             self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
+            self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
+            self.rvs_to_total_size = treedict(parent=self.parent.rvs_to_total_size)
             self.free_RVs = treelist(parent=self.parent.free_RVs)
             self.observed_RVs = treelist(parent=self.parent.observed_RVs)
             self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
@@ -566,6 +567,8 @@ def __init__(
             self.named_vars = treedict()
             self.values_to_rvs = treedict()
             self.rvs_to_values = treedict()
+            self.rvs_to_transforms = treedict()
+            self.rvs_to_total_size = treedict()
             self.free_RVs = treelist()
             self.observed_RVs = treelist()
             self.auto_deterministics = treelist()
@@ -723,13 +726,13 @@ def logp(
 
         # We need to separate random variables from potential terms, and remember their
        # original order so that we can merge them together in the same order at the end
-        rv_values = {}
+        rvs = []
         potentials = []
         rv_order, potential_order = [], []
         for i, var in enumerate(varlist):
-            value_var = self.rvs_to_values.get(var)
-            if value_var is not None:
-                rv_values[var] = value_var
+            rv = self.values_to_rvs.get(var, var)
+            if rv in self.basic_RVs:
+                rvs.append(rv)
                 rv_order.append(i)
             else:
                 if var in self.potentials:
@@ -741,8 +744,15 @@ def logp(
                     )
 
         rv_logps: List[TensorVariable] = []
-        if rv_values:
-            rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
+        if rvs:
+            rv_logps = _joint_logp(
+                rvs=rvs,
+                rvs_to_values=self.rvs_to_values,
+                rvs_to_transforms=self.rvs_to_transforms,
+                rvs_to_total_size=self.rvs_to_total_size,
+                jacobian=jacobian,
+                sum=False,
+            )
             assert isinstance(rv_logps, list)
 
         # Replace random variables by their value variables in potential terms
@@ -891,7 +901,7 @@ def unobserved_value_vars(self):
         untransformed_vars = []
         for rv in self.free_RVs:
             value_var = self.rvs_to_values[rv]
-            transform = getattr(value_var.tag, "transform", None)
+            transform = self.rvs_to_transforms.get(rv, None)
             if transform is not None:
                 # We need to create and add an un-transformed version of
                 # each transformed variable
@@ -942,7 +952,7 @@ def basic_RVs(self):
         These are the actual random variable terms that make up the
         "sample-space" graph (i.e. you can sample these graphs by compiling them
         with `aesara.function`). If you want the corresponding log-likelihood terms,
-        use `var.tag.value_var`.
+        use `model.value_vars` instead.
         """
         return self.free_RVs + self.observed_RVs
 
@@ -953,7 +963,7 @@ def unobserved_RVs(self):
         These are the actual random variable terms that make up the
         "sample-space" graph (i.e. you can sample these graphs by compiling them
         with `aesara.function`). If you want the corresponding log-likelihood terms,
-        use `var.tag.value_var`.
+        use `model.unobserved_value_vars` instead.
         """
         return self.free_RVs + self.deterministics
 
@@ -978,17 +988,6 @@ def dim_lengths(self) -> Dict[str, Variable]:
         """
         return self._dim_lengths
 
-    @property
-    def unobserved_RVs(self):
-        """List of all random variables, including deterministic ones.
-
-        These are the actual random variable terms that make up the
-        "sample-space" graph (i.e. you can sample these graphs by compiling them
-        with `aesara.function`). If you want the corresponding log-likelihood terms,
-        use `var.tag.value_var`.
-        """
-        return self.free_RVs + self.deterministics
-
     @property
     def test_point(self) -> Dict[str, np.ndarray]:
         """Deprecated alias for `Model.initial_point(seed=None)`."""
@@ -1318,8 +1317,9 @@ def register_rv(
         """
         name = self.name_for(name)
         rv_var.name = name
+        # TODO: Stop adding mapping information to the tag
         rv_var.tag.total_size = total_size
-        rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim)
+        self.rvs_to_total_size[rv_var] = total_size
 
         # Associate previously unknown dimension names with
         # the length of the corresponding RV dimension.
@@ -1387,7 +1387,7 @@ def make_obs_var(
 
             if test_value is not None:
                 # We try to reuse the old test value
-                rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.tag.test_value.shape)
+                rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.shape)
             else:
                 rv_var.tag.test_value = data
 
@@ -1455,6 +1455,7 @@ def make_obs_var(
            # if the size of the masked and unmasked array happened to coincide
            _, size, _, *inps = observed_rv_var.owner.inputs
            observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed")
+            # TODO: Stop adding mapping information to the tag
            observed_rv_var.tag.observations = nonmissing_data
 
            self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
@@ -1472,6 +1473,7 @@ def make_obs_var(
                data = sparse.basic.as_sparse(data, name=name)
            else:
                data = at.as_tensor_variable(data, name=name)
+            # TODO: Stop adding mapping information to the tag
            rv_var.tag.observations = data
            self.create_value_var(rv_var, transform=None, value_var=data)
            self.add_random_variable(rv_var, dims)
@@ -1499,6 +1501,7 @@ def create_value_var(
         if aesara.config.compute_test_value != "off":
             value_var.tag.test_value = rv_var.tag.test_value
 
+        # TODO: Stop adding mapping information to the tag
         rv_var.tag.value_var = value_var
 
         # Make the value variable a transformed value variable,
@@ -1507,6 +1510,7 @@ def create_value_var(
             transform = _default_transform(rv_var.owner.op, rv_var)
 
         if transform is not None and transform is not UNSET:
+            # TODO: Stop adding mapping information to the tag
             value_var.tag.transform = transform
             value_var.name = f"{value_var.name}_{transform.name}__"
             if aesara.config.compute_test_value != "off":
@@ -1514,6 +1518,7 @@ def create_value_var(
                     value_var, *rv_var.owner.inputs
                 ).tag.test_value
             self.named_vars[value_var.name] = value_var
+        self.rvs_to_transforms[rv_var] = transform
 
         self.rvs_to_values[rv_var] = value_var
         self.values_to_rvs[value_var] = rv_var
@@ -1677,8 +1682,7 @@ def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]:
         names = []
         outputs = []
         for rv in self.free_RVs:
-            rv_var = self.rvs_to_values[rv]
-            transform = getattr(rv_var.tag, "transform", None)
+            transform = self.rvs_to_transforms.get(rv, None)
             if transform is not None:
                 names.append(get_transformed_name(rv.name, transform))
                 outputs.append(transform.forward(rv, *rv.owner.inputs).shape)
@@ -1687,7 +1691,7 @@ def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]:
         f = aesara.function(
             inputs=[],
             outputs=outputs,
-            givens=[(obs, obs.tag.observations) for obs in self.observed_RVs],
+            givens=[(obs, self.rvs_to_values[obs]) for obs in self.observed_RVs],
             mode=aesara.compile.mode.FAST_COMPILE,
             on_unused_input="ignore",
         )
@@ -1996,7 +2000,6 @@ def Potential(name, var, model=None):
     """
     model = modelcontext(model)
     var.name = model.name_for(name)
-    var.tag.scaling = 1.0
     model.potentials.append(var)
     model.add_random_variable(var)
 
diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index f5dd3807e9c..93eed6ce01c 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -79,8 +79,8 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
                raise ValueError(f"{var_name} is not in this model.")
 
            for model_var in self.var_list:
-                if hasattr(model_var.tag, "observations"):
-                    if model_var.tag.observations == self.model[var_name]:
+                if model_var in self.model.observed_RVs:
+                    if self.model.rvs_to_values[model_var] == self.model[var_name]:
                        selected_names.add(model_var.name)
 
        selected_ancestors = set(
@@ -91,8 +91,8 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
        )
 
        for var in selected_ancestors.copy():
-            if hasattr(var.tag, "observations"):
-                selected_ancestors.add(var.tag.observations)
+            if var in self.model.observed_RVs:
+                selected_ancestors.add(self.model.rvs_to_values[var])
 
        # ordering of self._all_var_names is important
        return [var.name for var in selected_ancestors]
@@ -108,8 +108,8 @@ def make_compute_graph(
            parent_name = self.get_parent_names(var)
            input_map[var_name] = input_map[var_name].union(parent_name)
 
-            if hasattr(var.tag, "observations"):
-                obs_node = var.tag.observations
+            if var in self.model.observed_RVs:
+                obs_node = self.model.rvs_to_values[var]
 
                # loop created so that the elif block can go through this again
                # and remove any intermediate ops, notably dtype casting, to observations
diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py
index a5450f951ae..ed91b77f3fa 100644
--- a/pymc/sampling/forward.py
+++ b/pymc/sampling/forward.py
@@ -388,7 +388,7 @@ def sample_prior_predictive(
    for name in sorted(missing_names):
        transformed_value_var = model[name]
        rv_var = model.values_to_rvs[transformed_value_var]
-        transform = transformed_value_var.tag.transform
+        transform = model.rvs_to_transforms[rv_var]
        transformed_rv_var = transform.forward(rv_var, *rv_var.owner.inputs)
 
        names.append(name)
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index ac04f0f4654..5a971448a9d 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -210,10 +210,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
 
 
 def all_continuous(vars):
-    """Check that vars not include discrete variables, excepting observed RVs."""
-
-    vars_ = [var for var in vars if not hasattr(var.tag, "observations")]
-
-    if any([(var.dtype in discrete_types) for var in vars_]):
+    """Check that vars contains no discrete variables."""
+    if any([(var.dtype in discrete_types) for var in vars]):
        return False
    else:
diff --git a/pymc/tests/backends/fixtures.py b/pymc/tests/backends/fixtures.py
index 287ac79dae9..a1ac8ce89ec 100644
--- a/pymc/tests/backends/fixtures.py
+++ b/pymc/tests/backends/fixtures.py
@@ -143,9 +143,9 @@ def setup_class(cls):
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)
 
        if hasattr(cls, "write_partial_chain") and cls.write_partial_chain is True:
-            cls.chain_vars = [v.tag.value_var for v in cls.model.unobserved_RVs[1:]]
+            cls.chain_vars = [cls.model.rvs_to_values[v] for v in cls.model.unobserved_RVs[1:]]
        else:
-            cls.chain_vars = [v.tag.value_var for v in cls.model.unobserved_RVs]
+            cls.chain_vars = [cls.model.rvs_to_values[v] for v in cls.model.unobserved_RVs]
 
        with cls.model:
            strace0 = cls.backend(cls.name, vars=cls.chain_vars)
diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py
index 8e46228e8c5..952b6087ddb 100644
--- a/pymc/tests/distributions/test_continuous.py
+++ b/pymc/tests/distributions/test_continuous.py
@@ -73,7 +73,7 @@ def get_dist_params_and_interval_bounds(self, model, rv_name):
        interval_rv = model.named_vars[f"{rv_name}_interval__"]
        rv = model.named_vars[rv_name]
        dist_params = rv.owner.inputs
-        lower_interval, upper_interval = interval_rv.tag.transform.args_fn(*rv.owner.inputs)
+        lower_interval, upper_interval = model.rvs_to_transforms[rv].args_fn(*rv.owner.inputs)
        return (
            dist_params,
            lower_interval,
diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py
index 27ac0d3d63d..01b05f0bd24 100644
--- a/pymc/tests/distributions/test_distribution.py
+++ b/pymc/tests/distributions/test_distribution.py
@@ -215,13 +215,15 @@ def logp(value, mu):
            mu = pm.Normal("mu", size=supp_shape)
            a = pm.DensityDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
 
-        mu_val = npr.normal(loc=0, scale=1, size=supp_shape).astype(aesara.config.floatX)
-        a_val = npr.normal(loc=mu_val, scale=1, size=to_tuple(size) + (supp_shape,)).astype(
-            aesara.config.floatX
-        )
-        log_densityt = joint_logp(a, a.tag.value_var, sum=False)[0]
+        mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(aesara.config.floatX)
+        a_test_value = npr.normal(
+            loc=mu_test_value, scale=1, size=to_tuple(size) + (supp_shape,)
+        ).astype(aesara.config.floatX)
+        a_value_var = model.rvs_to_values[a]
+        mu_value_var = model.rvs_to_values[mu]
+        log_densityt = joint_logp(a, a_value_var, sum=False)[0]
        assert log_densityt.eval(
-            {a.tag.value_var: a_val, mu.tag.value_var: mu_val},
+            {a_value_var: a_test_value, mu_value_var: mu_test_value},
        ).shape == to_tuple(size)
 
    @pytest.mark.parametrize(
diff --git a/pymc/tests/distributions/test_logprob.py b/pymc/tests/distributions/test_logprob.py
index 4212b4baa7b..b4c989ebab3 100644
--- a/pymc/tests/distributions/test_logprob.py
+++ b/pymc/tests/distributions/test_logprob.py
@@ -47,7 +47,6 @@
    _get_scaling,
    ignore_logprob,
    joint_logp,
-    joint_logpt,
    logcdf,
    logp,
 )
@@ -117,17 +116,16 @@ def test_joint_logp_basic():
        b = Uniform("b", b_l, b_l + 1.0)
 
        a_value_var = m.rvs_to_values[a]
-        assert a_value_var.tag.transform
+        assert m.rvs_to_transforms[a]
 
        b_value_var = m.rvs_to_values[b]
-        assert b_value_var.tag.transform
+        assert m.rvs_to_transforms[b]
 
        c_value_var = m.rvs_to_values[c]
 
        b_logp = joint_logp(b, b_value_var, sum=False)
 
-        with pytest.warns(FutureWarning):
-            b_logpt = joint_logpt(b, b_value_var, sum=False)
+        b_logpt = joint_logp(b, b_value_var, sum=False)
 
    res_ancestors = list(walk_model(b_logp))
    res_rv_ancestors = [
@@ -363,8 +361,8 @@ def test_hierarchical_logp():
    ops = {a.owner.op for a in logp_ancestors if a.owner}
    assert len(ops) > 0
    assert not any(isinstance(o, RandomVariable) for o in ops)
-    assert x.tag.value_var in logp_ancestors
-    assert y.tag.value_var in logp_ancestors
+    assert m.rvs_to_values[x] in logp_ancestors
+    assert m.rvs_to_values[y] in logp_ancestors
 
 
 def test_hierarchical_obs_logp():
diff --git a/pymc/tests/distributions/test_transform.py b/pymc/tests/distributions/test_transform.py
index 71f4868d7aa..65e1ce6c88b 100644
--- a/pymc/tests/distributions/test_transform.py
+++ b/pymc/tests/distributions/test_transform.py
@@ -276,7 +276,7 @@ def build_model(self, distfam, params, size, transform, initval=None):
 
    def check_transform_elementwise_logp(self, model):
        x = model.free_RVs[0]
-        x_val_transf = x.tag.value_var
+        x_val_transf = model.rvs_to_values[x]
 
        pt = model.initial_point(0)
        test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape))
@@ -297,7 +297,7 @@ def check_vectortransform_elementwise_logp(self, model):
        x = model.free_RVs[0]
-        x_val_transf = x.tag.value_var
+        x_val_transf = model.rvs_to_values[x]
 
        pt = model.initial_point(0)
        test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape))
 
@@ -544,7 +544,7 @@ def test_triangular_transform():
    with pm.Model() as m:
        x = pm.Triangular("x", lower=0, c=1, upper=2)
 
-    transform = x.tag.value_var.tag.transform
+    transform = m.rvs_to_transforms[x]
    assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0)
    assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2)
diff --git a/pymc/tests/distributions/util.py b/pymc/tests/distributions/util.py
index 3f65350694d..531f15137b0 100644
--- a/pymc/tests/distributions/util.py
+++ b/pymc/tests/distributions/util.py
@@ -288,9 +288,9 @@ def _model_input_dict(model, param_vars, pt):
    for k, v in pt.items():
        rv_var = model.named_vars.get(k)
        nv = param_vars.get(k, rv_var)
-        nv = getattr(nv.tag, "value_var", nv)
+        transform = model.rvs_to_transforms.get(nv, None)
+        nv = model.rvs_to_values.get(nv, nv)
 
-        transform = getattr(nv.tag, "transform", None)
        if transform:
            # todo: the compiled graph behind this should be cached and
            # reused (if it isn't already).
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index be9e4304623..0b82841445b 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -69,6 +69,7 @@
 from pymc.backends.base import MultiTrace
 from pymc.backends.ndarray import NDArray
 from pymc.blocking import DictToArrayBijection
+from pymc.distributions.logprob import _get_scaling
 from pymc.initial_point import make_initial_point_fn
 from pymc.model import modelcontext
 from pymc.util import (
@@ -1039,7 +1040,14 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
    @node_property
    def symbolic_normalizing_constant(self):
        """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`"""
-        t = self.to_flat_input(at.max([v.tag.scaling for v in self.group]))
+        t = self.to_flat_input(
+            at.max(
+                [
+                    _get_scaling(self.model.rvs_to_total_size.get(v, None), v.shape, v.ndim)
+                    for v in self.group
+                ]
+            )
+        )
        t = self.symbolic_single_sample(t)
        return pm.floatX(t)
@@ -1171,7 +1179,14 @@ def symbolic_normalizing_constant(self):
        """
        t = at.max(
            self.collect("symbolic_normalizing_constant")
-            + [var.tag.scaling for var in self.model.observed_RVs]
+            + [
+                _get_scaling(
+                    self.model.rvs_to_total_size.get(obs, None),
+                    obs.shape,
+                    obs.ndim,
+                )
+                for obs in self.model.observed_RVs
+            ]
        )
        t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype))
        return pm.floatX(t)
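
A minimal usage sketch of the new model-level mappings introduced by this patch (illustrative only: the toy model, variable names, and comments below are not taken from the patch, and it assumes a PyMC build that contains these changes):

import numpy as np
import pymc as pm

with pm.Model() as model:
    # Toy model used only to exercise the new mappings.
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    y = pm.Normal("y", mu=0.0, sigma=sigma, observed=np.ones(3))

# Previously, value variables, transforms, and observations hung off var.tag
# (sigma.tag.value_var, sigma.tag.value_var.tag.transform, y.tag.observations).
# With this patch they are looked up through dictionaries owned by the Model:
sigma_value = model.rvs_to_values[sigma]              # transformed value variable, e.g. "sigma_log__"
sigma_transform = model.rvs_to_transforms[sigma]      # default transform (None if untransformed)
y_data = model.rvs_to_values[y]                       # observed data acts as y's value variable
y_total_size = model.rvs_to_total_size.get(y, None)   # None unless total_size was passed

The tag attributes are still populated for now (marked with "TODO: Stop adding mapping information to the tag" above), so old code keeps working, but new code should go through the model-level dictionaries.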