diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index bba94f8d22b..c512d845765 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -246,7 +246,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"): alpha = self.alpha m = self.m - if formatting == "latex": + if "latex" in formatting: return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$" else: return f"{name} ~ BART(alpha = {alpha}, m = {m})" diff --git a/pymc3/distributions/bound.py b/pymc3/distributions/bound.py index ae0ac673544..bde9abc8874 100644 --- a/pymc3/distributions/bound.py +++ b/pymc3/distributions/bound.py @@ -157,13 +157,13 @@ def _distr_name_for_repr(self): def _str_repr(self, **kwargs): distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped}) - if "formatting" in kwargs and kwargs["formatting"] == "latex": + if "formatting" in kwargs and "latex" in kwargs["formatting"]: distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :] else: distr_repr = distr_repr[distr_repr.index(" ~") + 3 :] self_repr = super()._str_repr(**kwargs) - if "formatting" in kwargs and kwargs["formatting"] == "latex": + if "formatting" in kwargs and "latex" in kwargs["formatting"]: return self_repr + " -- " + distr_repr else: return self_repr + "-" + distr_repr diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 60f3c7ef453..066dbfd26df 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -164,34 +164,51 @@ def _distr_name_for_repr(self): return self.__class__.__name__ def _str_repr(self, name=None, dist=None, formatting="plain"): - """Generate string representation for this distribution, optionally + """ + Generate string representation for this distribution, optionally including LaTeX markup (formatting='latex'). + + Parameters + ---------- + name : str + name of the distribution + dist : Distribution + the distribution object + formatting : str + one of { "latex", "plain", "latex_with_params", "plain_with_params" } """ if dist is None: dist = self if name is None: name = "[unnamed]" + supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"} + if not formatting in supported_formattings: + raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.") param_names = self._distr_parameters_for_repr() param_values = [ get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names ] - if formatting == "latex": + if "latex" in formatting: param_string = ",~".join( [fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)] ) - return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format( - var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string + if formatting == "latex_with_params": + return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format( + var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string + ) + return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format( + var_name=name, distr_name=dist._distr_name_for_repr() ) else: - # 'plain' is default option + # one of the plain formattings param_string = ", ".join( [f"{name}={value}" for name, value in zip(param_names, param_values)] ) - return "{var_name} ~ {distr_name}({params})".format( - var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string - ) + if formatting == "plain_with_params": + return f"{name} ~ {dist._distr_name_for_repr()}({param_string})" + return f"{name} ~ {dist._distr_name_for_repr()}" def __str__(self, **kwargs): try: @@ -199,9 +216,9 @@ def __str__(self, **kwargs): except: return super().__str__() - def _repr_latex_(self, **kwargs): + def _repr_latex_(self, *, formatting="latex_with_params", **kwargs): """Magic method name for IPython to use for LaTeX formatting.""" - return self._str_repr(formatting="latex", **kwargs) + return self._str_repr(formatting=formatting, **kwargs) def logp_nojac(self, *args, **kwargs): """Return the logp, but do not include a jacobian term for transforms. diff --git a/pymc3/distributions/simulator.py b/pymc3/distributions/simulator.py index 74933a8af57..2662c866797 100644 --- a/pymc3/distributions/simulator.py +++ b/pymc3/distributions/simulator.py @@ -126,7 +126,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"): sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat distance = getattr(self.distance, "__name__", self.distance.__class__.__name__) - if formatting == "latex": + if "latex" in formatting: return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$" else: return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})" diff --git a/pymc3/model.py b/pymc3/model.py index 55ca443d0b2..356b6b777be 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -65,7 +65,7 @@ def __rmatmul__(self, other): def _str_repr(self, name=None, dist=None, formatting="plain"): if getattr(self, "distribution", None) is None: - if formatting == "latex": + if "latex" in formatting: return None else: return super().__str__() @@ -76,8 +76,8 @@ def _str_repr(self, name=None, dist=None, formatting="plain"): dist = self.distribution return self.distribution._str_repr(name=name, dist=dist, formatting=formatting) - def _repr_latex_(self, **kwargs): - return self._str_repr(formatting="latex", **kwargs) + def _repr_latex_(self, *, formatting="latex_with_params", **kwargs): + return self._str_repr(formatting=formatting, **kwargs) def __str__(self, **kwargs): try: @@ -1375,8 +1375,8 @@ def check_test_point(self, test_point=None, round_vals=2): def _str_repr(self, formatting="plain", **kwargs): all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs) - if formatting == "latex": - rv_reprs = [rv.__latex__() for rv in all_rv] + if "latex" in formatting: + rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv] rv_reprs = [ rv_repr.replace(r"\sim", r"&\sim &").strip("$") for rv_repr in rv_reprs @@ -1407,8 +1407,8 @@ def _str_repr(self, formatting="plain", **kwargs): def __str__(self, **kwargs): return self._str_repr(formatting="plain", **kwargs) - def _repr_latex_(self, **kwargs): - return self._str_repr(formatting="latex", **kwargs) + def _repr_latex_(self, *, formatting="latex", **kwargs): + return self._str_repr(formatting=formatting, **kwargs) __latex__ = _repr_latex_ @@ -1874,24 +1874,27 @@ def _walk_up_rv(rv, formatting="plain"): all_rvs.extend(_walk_up_rv(parent, formatting=formatting)) else: name = rv.name if rv.name else "Constant" - fmt = r"\text{{{name}}}" if formatting == "latex" else "{name}" + fmt = r"\text{{{name}}}" if "latex" in formatting else "{name}" all_rvs.append(fmt.format(name=name)) return all_rvs class DeterministicWrapper(tt.TensorVariable): def _str_repr(self, formatting="plain"): - if formatting == "latex": - return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format( - name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)) - ) + if "latex" in formatting: + if formatting == "latex_with_params": + return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format( + name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)) + ) + return fr"$\text{{{self.name}}} \sim \text{{Deterministic}}$" else: - return "{name} ~ Deterministic({args})".format( - name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting)) - ) + if formatting == "plain_with_params": + args = ", ".join(_walk_up_rv(self, formatting=formatting)) + return f"{self.name} ~ Deterministic({args})" + return f"{self.name} ~ Deterministic" - def _repr_latex_(self): - return self._str_repr(formatting="latex") + def _repr_latex_(self, *, formatting="latex_with_params", **kwargs): + return self._str_repr(formatting=formatting) __latex__ = _repr_latex_ diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index 2f4b36448a9..b2169dcb36b 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -121,7 +121,7 @@ def update_input_map(key: str, val: Set[VarName]): pass return input_map - def _make_node(self, var_name, graph): + def _make_node(self, var_name, graph, *, formatting: str = "plain"): """Attaches the given variable to a graphviz Digraph""" v = self.model[var_name] @@ -146,7 +146,7 @@ def _make_node(self, var_name, graph): elif isinstance(v, SharedVariable): label = f"{var_name}\n~\nData" else: - label = str(v).replace(" ~ ", "\n~\n") + label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n") graph.node(var_name.replace(":", "&"), label, **attrs) @@ -181,7 +181,7 @@ def get_plates(self): plates[shape].add(var_name) return plates - def make_graph(self): + def make_graph(self, formatting: str = "plain"): """Make graphviz Digraph of PyMC3 model Returns @@ -205,12 +205,12 @@ def make_graph(self): # must be preceded by 'cluster' to get a box around it with graph.subgraph(name="cluster" + label) as sub: for var_name in var_names: - self._make_node(var_name, sub) + self._make_node(var_name, sub, formatting=formatting) # plate label goes bottom right sub.attr(label=label, labeljust="r", labelloc="b", style="rounded") else: for var_name in var_names: - self._make_node(var_name, graph) + self._make_node(var_name, graph, formatting=formatting) for key, values in self.make_compute_graph().items(): for value in values: @@ -218,7 +218,7 @@ def make_graph(self): return graph -def model_to_graphviz(model=None): +def model_to_graphviz(model=None, *, formatting: str = "plain"): """Produce a graphviz Digraph from a PyMC3 model. Requires graphviz, which may be installed most easily with @@ -228,6 +228,15 @@ def model_to_graphviz(model=None): and then `pip install graphviz` to get the python bindings. See http://graphviz.readthedocs.io/en/stable/manual.html for more information. + + Parameters + ---------- + model : pm.Model + The model to plot. Not required when called from inside a modelcontext. + formatting : str + one of { "plain", "plain_with_params" } """ + if not "plain" in formatting: + raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.") model = pm.modelcontext(model) - return ModelGraph(model).make_graph() + return ModelGraph(model).make_graph(formatting=formatting) diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index c9ee9e3dd07..5b384ccc999 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -179,16 +179,28 @@ def test_model_to_graphviz_for_model_with_data_container(self): pm.Normal("obs", beta * x, obs_sigma, observed=y) pm.sample(1000, init=None, tune=1000, chains=1) - g = pm.model_to_graphviz(model) - - # Data node rendered correctly? - text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]' - assert text in g.source - # Didn't break ordinary variables? - text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]' - assert text in g.source - text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]' - assert text in g.source + for formatting in {"latex", "latex_with_params"}: + with pytest.raises(ValueError, match="Unsupported formatting"): + pm.model_to_graphviz(model, formatting=formatting) + + exp_without = [ + 'x [label="x\n~\nData" shape=box style="rounded, filled"]', + 'beta [label="beta\n~\nNormal"]', + 'obs [label="obs\n~\nNormal" style=filled]', + ] + exp_with = [ + 'x [label="x\n~\nData" shape=box style="rounded, filled"]', + 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]', + f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]', + ] + for formatting, expected_substrings in [ + ("plain", exp_without), + ("plain_with_params", exp_with), + ]: + g = pm.model_to_graphviz(model, formatting=formatting) + # check formatting of RV nodes + for expected in expected_substrings: + assert expected in g.source def test_explicit_coords(self): N_rows = 5 diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 8f3de0d7d41..6a2593ac973 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1776,6 +1776,12 @@ def setup_class(self): # Test Cholesky parameterization Z = MvNormal("Z", mu=np.zeros(2), chol=np.eye(2), shape=(2,)) + # NegativeBinomial representations to test issue 4186 + nb1 = pm.NegativeBinomial( + "nb_with_mu_alpha", mu=pm.Normal("nbmu"), alpha=pm.Gamma("nbalpha", mu=6, sigma=1) + ) + nb2 = pm.NegativeBinomial("nb_with_p_n", p=pm.Uniform("nbp"), n=10) + # Expected value of outcome mu = Deterministic("mu", floatX(alpha + tt.dot(X, b))) @@ -1799,59 +1805,92 @@ def setup_class(self): # Likelihood (sampling distribution) of observations Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y) - self.distributions = [alpha, sigma, mu, b, Z, Y_obs, bound_var] - self.expected_latex = ( - r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", - r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$", - r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$", - r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", - r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$", - r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$", - r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", - r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$", - r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$", - ) - self.expected_str = ( - r"alpha ~ Normal(mu=0.0, sigma=10.0)", - r"sigma ~ HalfNormal(sigma=1.0)", - r"mu ~ Deterministic(alpha, Constant, beta)", - r"beta ~ Normal(mu=0.0, sigma=10.0)", - r"Z ~ MvNormal(mu=array, chol_cov=array)", - r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))", - r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)", - r"kron_normal ~ KroneckerNormal(mu=array)", - r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)", - ) + self.distributions = [alpha, sigma, mu, b, Z, nb1, nb2, Y_obs, bound_var] + self.expected = { + "latex": ( + r"$\text{alpha} \sim \text{Normal}$", + r"$\text{sigma} \sim \text{HalfNormal}$", + r"$\text{mu} \sim \text{Deterministic}$", + r"$\text{beta} \sim \text{Normal}$", + r"$\text{Z} \sim \text{MvNormal}$", + r"$\text{nb_with_mu_alpha} \sim \text{NegativeBinomial}$", + r"$\text{nb_with_p_n} \sim \text{NegativeBinomial}$", + r"$\text{Y_obs} \sim \text{Normal}$", + r"$\text{bound_var} \sim \text{Bound}$ -- \text{Normal}$", + r"$\text{kron_normal} \sim \text{KroneckerNormal}$", + r"$\text{mat_normal} \sim \text{MatrixNormal}$", + ), + "plain": ( + r"alpha ~ Normal", + r"sigma ~ HalfNormal", + r"mu ~ Deterministic", + r"beta ~ Normal", + r"Z ~ MvNormal", + r"nb_with_mu_alpha ~ NegativeBinomial", + r"nb_with_p_n ~ NegativeBinomial", + r"Y_obs ~ Normal", + r"bound_var ~ Bound-Normal", + r"kron_normal ~ KroneckerNormal", + r"mat_normal ~ MatrixNormal", + ), + "latex_with_params": ( + r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", + r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$", + r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$", + r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", + r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$", + r"$\text{nb_with_mu_alpha} \sim \text{NegativeBinomial}(\mathit{mu}=\text{nbmu},~\mathit{alpha}=\text{nbalpha})$", + r"$\text{nb_with_p_n} \sim \text{NegativeBinomial}(\mathit{p}=\text{nbp},~\mathit{n}=10)$", + r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$", + r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", + r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$", + r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$", + ), + "plain_with_params": ( + r"alpha ~ Normal(mu=0.0, sigma=10.0)", + r"sigma ~ HalfNormal(sigma=1.0)", + r"mu ~ Deterministic(alpha, Constant, beta)", + r"beta ~ Normal(mu=0.0, sigma=10.0)", + r"Z ~ MvNormal(mu=array, chol_cov=array)", + r"nb_with_mu_alpha ~ NegativeBinomial(mu=nbmu, alpha=nbalpha)", + r"nb_with_p_n ~ NegativeBinomial(p=nbp, n=10)", + r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))", + r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)", + r"kron_normal ~ KroneckerNormal(mu=array)", + r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)", + ), + } def test__repr_latex_(self): - for distribution, tex in zip(self.distributions, self.expected_latex): + for distribution, tex in zip(self.distributions, self.expected["latex_with_params"]): assert distribution._repr_latex_() == tex model_tex = self.model._repr_latex_() - for tex in self.expected_latex: # make sure each variable is in the model + # make sure each variable is in the model + for tex in self.expected["latex"]: for segment in tex.strip("$").split(r"\sim"): assert segment in model_tex def test___latex__(self): - for distribution, tex in zip(self.distributions, self.expected_latex): + for distribution, tex in zip(self.distributions, self.expected["latex_with_params"]): assert distribution._repr_latex_() == distribution.__latex__() assert self.model._repr_latex_() == self.model.__latex__() def test___str__(self): - for distribution, str_repr in zip(self.distributions, self.expected_str): + for distribution, str_repr in zip(self.distributions, self.expected["plain"]): assert distribution.__str__() == str_repr model_str = self.model.__str__() - for str_repr in self.expected_str: + for str_repr in self.expected["plain"]: assert str_repr in model_str def test_str(self): - for distribution, str_repr in zip(self.distributions, self.expected_str): + for distribution, str_repr in zip(self.distributions, self.expected["plain"]): assert str(distribution) == str_repr model_str = str(self.model) - for str_repr in self.expected_str: + for str_repr in self.expected["plain"]: assert str_repr in model_str @@ -1904,17 +1943,6 @@ def test_issue_3051(self, dims, dist_cls, kwargs): assert actual_a.shape == (X.shape[0],) pass - def test_issue_4186(self): - with pm.Model(): - nb = pm.NegativeBinomial( - "nb", mu=pm.Normal("mu"), alpha=pm.Gamma("alpha", mu=6, sigma=1) - ) - assert str(nb) == "nb ~ NegativeBinomial(mu=mu, alpha=alpha)" - - with pm.Model(): - nb = pm.NegativeBinomial("nb", p=pm.Uniform("p"), n=10) - assert str(nb) == "nb ~ NegativeBinomial(p=p, n=10)" - def test_serialize_density_dist(): def func(x): diff --git a/pymc3/util.py b/pymc3/util.py index d332510c2f8..63817006831 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -137,7 +137,7 @@ def get_repr_for_variable(variable, formatting="plain"): for item in variable.get_parents()[0].inputs ] # do not escape_latex these, since it is not idempotent - if formatting == "latex": + if "latex" in formatting: return "f({args})".format( args=",~".join([n for n in names if isinstance(n, str)]) ) @@ -152,7 +152,7 @@ def get_repr_for_variable(variable, formatting="plain"): return value.item() return "array" - if formatting == "latex": + if "latex" in formatting: return fr"\text{{{name}}}" else: return name