Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add str/repr formatting options and change defaults accordingly #4260

Merged
merged 4 commits into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
alpha = self.alpha
m = self.m

if formatting == "latex":
if "latex" in formatting:
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
else:
return f"{name} ~ BART(alpha = {alpha}, m = {m})"
4 changes: 2 additions & 2 deletions pymc3/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ def _distr_name_for_repr(self):

def _str_repr(self, **kwargs):
distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
if "formatting" in kwargs and kwargs["formatting"] == "latex":
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :]
else:
distr_repr = distr_repr[distr_repr.index(" ~") + 3 :]
self_repr = super()._str_repr(**kwargs)

if "formatting" in kwargs and kwargs["formatting"] == "latex":
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
return self_repr + " -- " + distr_repr
else:
return self_repr + "-" + distr_repr
Expand Down
37 changes: 27 additions & 10 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,44 +164,61 @@ def _distr_name_for_repr(self):
return self.__class__.__name__

def _str_repr(self, name=None, dist=None, formatting="plain"):
"""Generate string representation for this distribution, optionally
"""
Generate string representation for this distribution, optionally
including LaTeX markup (formatting='latex').

Parameters
----------
name : str
name of the distribution
dist : Distribution
the distribution object
formatting : str
one of { "latex", "plain", "latex_with_params", "plain_with_params" }
"""
if dist is None:
dist = self
if name is None:
name = "[unnamed]"
supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"}
if not formatting in supported_formattings:
raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.")

param_names = self._distr_parameters_for_repr()
param_values = [
get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names
]

if formatting == "latex":
if "latex" in formatting:
param_string = ",~".join(
[fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)]
)
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
if formatting == "latex_with_params":
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
)
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format(
var_name=name, distr_name=dist._distr_name_for_repr()
)
else:
# 'plain' is default option
# one of the plain formattings
param_string = ", ".join(
[f"{name}={value}" for name, value in zip(param_names, param_values)]
)
return "{var_name} ~ {distr_name}({params})".format(
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
)
if formatting == "plain_with_params":
return f"{name} ~ {dist._distr_name_for_repr()}({param_string})"
return f"{name} ~ {dist._distr_name_for_repr()}"

def __str__(self, **kwargs):
try:
return self._str_repr(formatting="plain", **kwargs)
except:
return super().__str__()

def _repr_latex_(self, **kwargs):
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
"""Magic method name for IPython to use for LaTeX formatting."""
return self._str_repr(formatting="latex", **kwargs)
return self._str_repr(formatting=formatting, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised that codecov is complaining that this line is not covered by tests — as far as I can tell, it should be.


def logp_nojac(self, *args, **kwargs):
"""Return the logp, but do not include a jacobian term for transforms.
Expand Down
2 changes: 1 addition & 1 deletion pymc3/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)

if formatting == "latex":
if "latex" in formatting:
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
else:
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"
Expand Down
37 changes: 20 additions & 17 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __rmatmul__(self, other):

def _str_repr(self, name=None, dist=None, formatting="plain"):
if getattr(self, "distribution", None) is None:
if formatting == "latex":
if "latex" in formatting:
return None
else:
return super().__str__()
Expand All @@ -76,8 +76,8 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
dist = self.distribution
return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)

def _repr_latex_(self, **kwargs):
return self._str_repr(formatting="latex", **kwargs)
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
return self._str_repr(formatting=formatting, **kwargs)

def __str__(self, **kwargs):
try:
Expand Down Expand Up @@ -1375,8 +1375,8 @@ def check_test_point(self, test_point=None, round_vals=2):
def _str_repr(self, formatting="plain", **kwargs):
all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs)

if formatting == "latex":
rv_reprs = [rv.__latex__() for rv in all_rv]
if "latex" in formatting:
rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv]
rv_reprs = [
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
for rv_repr in rv_reprs
Expand Down Expand Up @@ -1407,8 +1407,8 @@ def _str_repr(self, formatting="plain", **kwargs):
def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)

def _repr_latex_(self, **kwargs):
return self._str_repr(formatting="latex", **kwargs)
def _repr_latex_(self, *, formatting="latex", **kwargs):
return self._str_repr(formatting=formatting, **kwargs)

__latex__ = _repr_latex_

Expand Down Expand Up @@ -1874,24 +1874,27 @@ def _walk_up_rv(rv, formatting="plain"):
all_rvs.extend(_walk_up_rv(parent, formatting=formatting))
else:
name = rv.name if rv.name else "Constant"
fmt = r"\text{{{name}}}" if formatting == "latex" else "{name}"
fmt = r"\text{{{name}}}" if "latex" in formatting else "{name}"
all_rvs.append(fmt.format(name=name))
return all_rvs


class DeterministicWrapper(tt.TensorVariable):
def _str_repr(self, formatting="plain"):
if formatting == "latex":
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
)
if "latex" in formatting:
if formatting == "latex_with_params":
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
)
return fr"$\text{{{self.name}}} \sim \text{{Deterministic}}$"
else:
return "{name} ~ Deterministic({args})".format(
name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting))
)
if formatting == "plain_with_params":
args = ", ".join(_walk_up_rv(self, formatting=formatting))
return f"{self.name} ~ Deterministic({args})"
return f"{self.name} ~ Deterministic"

def _repr_latex_(self):
return self._str_repr(formatting="latex")
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
return self._str_repr(formatting=formatting)

__latex__ = _repr_latex_

Expand Down
23 changes: 16 additions & 7 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def update_input_map(key: str, val: Set[VarName]):
pass
return input_map

def _make_node(self, var_name, graph):
def _make_node(self, var_name, graph, *, formatting: str = "plain"):
"""Attaches the given variable to a graphviz Digraph"""
v = self.model[var_name]

Expand All @@ -146,7 +146,7 @@ def _make_node(self, var_name, graph):
elif isinstance(v, SharedVariable):
label = f"{var_name}\n~\nData"
else:
label = str(v).replace(" ~ ", "\n~\n")
label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n")

graph.node(var_name.replace(":", "&"), label, **attrs)

Expand Down Expand Up @@ -181,7 +181,7 @@ def get_plates(self):
plates[shape].add(var_name)
return plates

def make_graph(self):
def make_graph(self, formatting: str = "plain"):
"""Make graphviz Digraph of PyMC3 model

Returns
Expand All @@ -205,20 +205,20 @@ def make_graph(self):
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name="cluster" + label) as sub:
for var_name in var_names:
self._make_node(var_name, sub)
self._make_node(var_name, sub, formatting=formatting)
# plate label goes bottom right
sub.attr(label=label, labeljust="r", labelloc="b", style="rounded")
else:
for var_name in var_names:
self._make_node(var_name, graph)
self._make_node(var_name, graph, formatting=formatting)

for key, values in self.make_compute_graph().items():
for value in values:
graph.edge(value.replace(":", "&"), key.replace(":", "&"))
return graph


def model_to_graphviz(model=None):
def model_to_graphviz(model=None, *, formatting: str = "plain"):
"""Produce a graphviz Digraph from a PyMC3 model.

Requires graphviz, which may be installed most easily with
Expand All @@ -228,6 +228,15 @@ def model_to_graphviz(model=None):
and then `pip install graphviz` to get the python bindings. See
http://graphviz.readthedocs.io/en/stable/manual.html
for more information.

Parameters
----------
model : pm.Model
The model to plot. Not required when called from inside a modelcontext.
formatting : str
one of { "plain", "plain_with_params" }
"""
if not "plain" in formatting:
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
model = pm.modelcontext(model)
return ModelGraph(model).make_graph()
return ModelGraph(model).make_graph(formatting=formatting)
32 changes: 22 additions & 10 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,28 @@ def test_model_to_graphviz_for_model_with_data_container(self):
pm.Normal("obs", beta * x, obs_sigma, observed=y)
pm.sample(1000, init=None, tune=1000, chains=1)

g = pm.model_to_graphviz(model)

# Data node rendered correctly?
text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]'
assert text in g.source
# Didn't break ordinary variables?
text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]'
assert text in g.source
text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]'
assert text in g.source
for formatting in {"latex", "latex_with_params"}:
with pytest.raises(ValueError, match="Unsupported formatting"):
pm.model_to_graphviz(model, formatting=formatting)

exp_without = [
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
'beta [label="beta\n~\nNormal"]',
'obs [label="obs\n~\nNormal" style=filled]',
]
exp_with = [
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]',
f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]',
]
for formatting, expected_substrings in [
("plain", exp_without),
("plain_with_params", exp_with),
]:
g = pm.model_to_graphviz(model, formatting=formatting)
# check formatting of RV nodes
for expected in expected_substrings:
assert expected in g.source

def test_explicit_coords(self):
N_rows = 5
Expand Down
Loading