From e7831062f4cf2581f1435998aa6090c24677a234 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Mon, 4 Jan 2021 12:39:24 +0100 Subject: [PATCH] Better handling of variable names in e.g. GraphViz graphs (#4403) * handle changed API of theano.gof.graph.stack_search * convert n and eta to tensors, explicitly list parameters for repr * improve robustness of get_repr_for_variable * Revert "handle changed API of theano.gof.graph.stack_search" This reverts commit 6238bff6c38f5b0ff3b822e908780efb14401f50. --- pymc3/distributions/multivariate.py | 7 +++++-- pymc3/util.py | 8 +++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index a90d2393894..5a4d1cf992b 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -963,8 +963,8 @@ class _LKJCholeskyCov(Continuous): """ def __init__(self, eta, n, sd_dist, *args, **kwargs): - self.n = n - self.eta = eta + self.n = tt.as_tensor_variable(n) + self.eta = tt.as_tensor_variable(eta) if "transform" in kwargs and kwargs["transform"] is not None: raise ValueError("Invalid parameter: transform.") @@ -1129,6 +1129,9 @@ def random(self, point=None, size=None): samples = np.reshape(samples, size + sample_shape) return samples + def _distr_parameters_for_repr(self): + return ["eta", "n"] + def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=True, *args, **kwargs): R"""Wrapper function for covariance matrix with LKJ distributed correlations. diff --git a/pymc3/util.py b/pymc3/util.py index 54e19a6c80e..84b4f6c3e5f 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -131,7 +131,13 @@ def get_default_varnames(var_iterator, include_transformed): def get_repr_for_variable(variable, formatting="plain"): """Build a human-readable string representation for a variable.""" - name = variable.name if variable is not None else None + if variable is not None and hasattr(variable, "name"): + name = variable.name + elif type(variable) in [float, int, str]: + name = str(variable) + else: + name = None + if name is None and variable is not None: if hasattr(variable, "get_parents"): try: