Skip to content

Commit

Permalink
Better handling of variable names in e.g. GraphViz graphs (#4403)
Browse files Browse the repository at this point in the history
* handle changed API of theano.gof.graph.stack_search

* convert n and eta to tensors, explicitly list parameters for repr

* improve robustness of get_repr_for_variable

* Revert "handle changed API of theano.gof.graph.stack_search"

This reverts commit 6238bff.
  • Loading branch information
Spaak authored Jan 4, 2021
1 parent 240c372 commit e783106
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
7 changes: 5 additions & 2 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,8 +963,8 @@ class _LKJCholeskyCov(Continuous):
"""

def __init__(self, eta, n, sd_dist, *args, **kwargs):
self.n = n
self.eta = eta
self.n = tt.as_tensor_variable(n)
self.eta = tt.as_tensor_variable(eta)

if "transform" in kwargs and kwargs["transform"] is not None:
raise ValueError("Invalid parameter: transform.")
Expand Down Expand Up @@ -1129,6 +1129,9 @@ def random(self, point=None, size=None):
samples = np.reshape(samples, size + sample_shape)
return samples

def _distr_parameters_for_repr(self):
return ["eta", "n"]


def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=True, *args, **kwargs):
R"""Wrapper function for covariance matrix with LKJ distributed correlations.
Expand Down
8 changes: 7 additions & 1 deletion pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,13 @@ def get_default_varnames(var_iterator, include_transformed):

def get_repr_for_variable(variable, formatting="plain"):
"""Build a human-readable string representation for a variable."""
name = variable.name if variable is not None else None
if variable is not None and hasattr(variable, "name"):
name = variable.name
elif type(variable) in [float, int, str]:
name = str(variable)
else:
name = None

if name is None and variable is not None:
if hasattr(variable, "get_parents"):
try:
Expand Down

0 comments on commit e783106

Please sign in to comment.