Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding meaningful str representations to PyMC3 objects #4076

Merged
merged 7 commits into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ..model import modelcontext, Model
from .report import SamplerReport, merge_reports
from ..util import get_var_name

logger = logging.getLogger('pymc3')

Expand Down Expand Up @@ -109,7 +110,7 @@ def _set_sampler_vars(self, sampler_vars):
self.sampler_vars = sampler_vars

# pylint: disable=unused-argument
def setup(self, draws, chain, sampler_vars=None) -> None:
def setup(self, draws, chain, sampler_vars=None) -> None:
"""Perform chain-specific setup.

Parameters
Expand Down Expand Up @@ -335,7 +336,7 @@ def __getitem__(self, idx):
var = idx
burn, thin = 0, 1

var = str(var)
var = get_var_name(var)
if var in self.varnames:
if var in self.stat_names:
warnings.warn("Attribute access on a trace object is ambigous. "
Expand All @@ -355,7 +356,7 @@ def __getattr__(self, name):
if name in self._attrs:
raise AttributeError

name = str(name)
name = get_var_name(name)
if name in self.varnames:
if name in self.stat_names:
warnings.warn("Attribute access on a trace object is ambigous. "
Expand Down Expand Up @@ -482,7 +483,7 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
"""
if chains is None:
chains = self.chains
varname = str(varname)
varname = get_var_name(varname)
try:
results = [self._straces[chain].get_values(varname, burn, thin)
for chain in chains]
Expand Down
3 changes: 2 additions & 1 deletion pymc3/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from ..backends import base, ndarray
from . import tracetab as ttab
from ..util import get_var_name

TEMPLATES = {
'table': ('CREATE TABLE IF NOT EXISTS [{table}] '
Expand Down Expand Up @@ -244,7 +245,7 @@ def get_values(self, varname, burn=0, thin=1):
if thin < 1:
raise ValueError('Only positive thin values are supported '
'in SQLite backend.')
varname = str(varname)
varname = get_var_name(varname)

statement_args = {'chain': self.chain}
if burn == 0 and thin == 1:
Expand Down
4 changes: 3 additions & 1 deletion pymc3/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import numpy as np
import collections

from .util import get_var_name

__all__ = ['ArrayOrdering', 'DictToArrayBijection', 'DictToVarBijection']

VarMap = collections.namedtuple('VarMap', 'var, slc, shp, dtyp')
Expand Down Expand Up @@ -237,7 +239,7 @@ class DictToVarBijection:
"""

def __init__(self, var, idx, dpoint):
self.var = str(var)
self.var = get_var_name(var)
self.idx = idx
self.dpt = dpoint

Expand Down
7 changes: 5 additions & 2 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
import theano.tensor as tt
from theano import function
from ..util import get_repr_for_variable
from ..util import get_repr_for_variable, get_var_name
import theano
from ..memoize import memoize
from ..model import (
Expand Down Expand Up @@ -174,6 +174,9 @@ def _str_repr(self, name=None, dist=None, formatting='plain'):
return "{var_name} ~ {distr_name}({params})".format(var_name=name,
distr_name=dist._distr_name_for_repr(), params=param_string)

def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)

def _repr_latex_(self, **kwargs):
"""Magic method name for IPython to use for LaTeX formatting."""
return self._str_repr(formatting="latex", **kwargs)
Expand Down Expand Up @@ -728,7 +731,7 @@ def draw_values(params, point=None, size=None):
missing_inputs = set([j for j, p in symbolic_params])
while to_eval or missing_inputs:
if to_eval == missing_inputs:
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
raise ValueError('Cannot resolve inputs for {}'.format([get_var_name(params[j]) for j in to_eval]))
to_eval = set(missing_inputs)
missing_inputs = set()
for param_idx in to_eval:
Expand Down
8 changes: 4 additions & 4 deletions pymc3/distributions/posterior_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)
from ..exceptions import IncorrectArgumentsError
from ..vartypes import theano_constant
from ..util import dataset_to_point_dict, chains_and_samples
from ..util import dataset_to_point_dict, chains_and_samples, get_var_name

# Failing tests:
# test_mixture_random_shape::test_mixture_random_shape
Expand Down Expand Up @@ -460,7 +460,7 @@ def draw_values(self) -> List[np.ndarray]:
if to_eval == missing_inputs:
raise ValueError(
"Cannot resolve inputs for {}".format(
[str(trace.varnames[j]) for j in to_eval]
[get_var_name(trace.varnames[j]) for j in to_eval]
)
)
to_eval = set(missing_inputs)
Expand Down Expand Up @@ -493,7 +493,7 @@ def draw_values(self) -> List[np.ndarray]:
return [self.evaluated[j] for j in params]

def init(self) -> None:
"""This method carries out the initialization phase of sampling
"""This method carries out the initialization phase of sampling
from the posterior predictive distribution. Notably it initializes the
``_DrawValuesContext`` bookkeeping object and evaluates the "fast drawable"
parts of the model."""
Expand Down Expand Up @@ -567,7 +567,7 @@ def draw_value(self, param, trace: Optional[_TraceDict] = None, givens=None):
The value or distribution. Constants or shared variables
will be converted to an array and returned. Theano variables
are evaluated. If `param` is a pymc3 random variable, draw
values from it and return that (as ``np.ndarray``), unless a
values from it and return that (as ``np.ndarray``), unless a
value is specified in the ``trace``.
trace: pm.MultiTrace, optional
A dictionary from pymc3 variable names to samples of their values
Expand Down
19 changes: 17 additions & 2 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .theanof import gradient, hessian, inputvars, generator
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
from .blocking import DictToArrayBijection, ArrayOrdering
from .util import get_transformed_name
from .util import get_transformed_name, get_var_name
from .exceptions import ImputationWarning

__all__ = [
Expand Down Expand Up @@ -80,6 +80,9 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
def _repr_latex_(self, **kwargs):
return self._str_repr(formatting="latex", **kwargs)

def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)

__latex__ = _repr_latex_


Expand Down Expand Up @@ -1368,6 +1371,9 @@ def _str_repr(self, formatting="plain", **kwargs):
for n, d in zip(names, distrs)]
return "\n".join(rv_reprs)

def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)

def _repr_latex_(self, **kwargs):
return self._str_repr(formatting="latex", **kwargs)

Expand Down Expand Up @@ -1480,7 +1486,8 @@ def Point(*args, **kwargs):
except Exception as e:
raise TypeError("can't turn {} and {} into a dict. {}".format(args, kwargs, e))
return dict(
(str(k), np.array(v)) for k, v in d.items() if str(k) in map(str, model.vars)
(get_var_name(k), np.array(v)) for k, v in d.items()
if get_var_name(k) in map(get_var_name, model.vars)
)


Expand Down Expand Up @@ -1872,6 +1879,14 @@ def Deterministic(name, var, model=None, dims=None):
model.add_random_variable(var, dims)
var._repr_latex_ = functools.partial(_repr_deterministic_rv, var, formatting='latex')
var.__latex__ = var._repr_latex_

# simply assigning var.__str__ is not enough, since str() will default to the class-
# defined __str__ anyway; see https://stackoverflow.com/a/5918210/1692028
old_type = type(var)
new_type = type(old_type.__name__ + '_pymc3_Deterministic', (old_type,),
{'__str__': functools.partial(_repr_deterministic_rv, var, formatting='plain')})
var.__class__ = new_type

return var


Expand Down
4 changes: 2 additions & 2 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from theano.compile import SharedVariable
from theano.tensor import Tensor

from .util import get_default_varnames
from .util import get_default_varnames, get_var_name
from .model import ObservedRV
import pymc3 as pm

Expand Down Expand Up @@ -83,7 +83,7 @@ def _filter_parents(self, var, parents) -> Set[VarName]:
if self.transform_map[p] != var.name:
keep.add(self.transform_map[p])
else:
raise AssertionError('Do not know what to do with {}'.format(str(p)))
raise AssertionError('Do not know what to do with {}'.format(get_var_name(p)))
return keep

def get_parents(self, var: Tensor) -> Set[VarName]:
Expand Down
3 changes: 2 additions & 1 deletion pymc3/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..model import modelcontext
from ..theanof import inputvars
from ..blocking import ArrayOrdering, DictToArrayBijection
from ..util import get_var_name
import numpy as np
from numpy.random import uniform
from enum import IntEnum, unique
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(self, vars, shared, blocked=True):
"""
self.vars = vars
self.ordering = ArrayOrdering(vars)
self.shared = {str(var): shared for var, shared in shared.items()}
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
self.blocked = blocked
self.bij = None

Expand Down
3 changes: 2 additions & 1 deletion pymc3/tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import pymc3 as pm
from pymc3.util import get_var_name
import numpy as np
import numpy.testing as npt
from scipy import stats
Expand Down Expand Up @@ -145,7 +146,7 @@ def setup_class(cls):
cls.trace = pm.sample(cls.n_samples, tune=cls.tune, step=cls.step, cores=cls.chains)
cls.samples = {}
for var in cls.model.unobserved_RVs:
cls.samples[str(var)] = cls.trace.get_values(var, burn=cls.burn)
cls.samples[get_var_name(var)] = cls.trace.get_values(var, burn=cls.burn)

def test_neff(self):
if hasattr(self, 'min_n_eff'):
Expand Down
26 changes: 21 additions & 5 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@ def test_bound():
BoundPoissonPositionalArgs = Bound(Poisson, upper=6)("x", 2.0)


class TestLatex:
class TestStrAndLatexRepr:
def setup_class(self):
# True parameter values
alpha, sigma = 1, 1
Expand Down Expand Up @@ -1800,30 +1800,46 @@ def setup_class(self):
# Likelihood (sampling distribution) of observations
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
self.distributions = [alpha, sigma, mu, b, Z, Y_obs]
self.expected = (
self.expected_latex = (
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$",
r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$",
r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
)
self.expected_str = (
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
r"sigma ~ HalfNormal(sigma=1.0)",
r"mu ~ Deterministic(alpha, Constant, beta)",
r"beta ~ Normal(mu=0.0, sigma=10.0)",
r"Z ~ MvNormal(mu=array, chol_cov=array)",
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
)

def test__repr_latex_(self):
for distribution, tex in zip(self.distributions, self.expected):
for distribution, tex in zip(self.distributions, self.expected_latex):
assert distribution._repr_latex_() == tex

model_tex = self.model._repr_latex_()

for tex in self.expected: # make sure each variable is in the model
for tex in self.expected_latex: # make sure each variable is in the model
for segment in tex.strip("$").split(r"\sim"):
assert segment in model_tex

def test___latex__(self):
for distribution, tex in zip(self.distributions, self.expected):
for distribution, tex in zip(self.distributions, self.expected_latex):
assert distribution._repr_latex_() == distribution.__latex__()
assert self.model._repr_latex_() == self.model.__latex__()

def test___str__(self):
for distribution, str_repr in zip(self.distributions, self.expected_str):
assert distribution.__str__() == str_repr

model_str = self.model.__str__()
for str_repr in self.expected_str:
assert str_repr in model_str


def test_discrete_trafo():
with pytest.raises(ValueError) as err:
Expand Down
22 changes: 22 additions & 0 deletions pymc3/tests/test_starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from .models import simple_model, non_normal, simple_arbitrary_det
from .helpers import select_by_precision

from pytest import raises


def test_accuracy_normal():
_, model, (mu, _) = simple_model()
Expand Down Expand Up @@ -83,3 +85,23 @@ def test_find_MAP():

close_to(map_est2['mu'], 0, tol)
close_to(map_est2['sigma'], 1, tol)


def test_allinmodel():
model1 = Model()
model2 = Model()
with model1:
x1 = Normal('x1', mu=0, sigma=1)
y1 = Normal('y1', mu=0, sigma=1)
with model2:
x2 = Normal('x2', mu=0, sigma=1)
y2 = Normal('y2', mu=0, sigma=1)

starting.allinmodel([x1, y1], model1)
starting.allinmodel([x1], model1)
with raises(ValueError, match=r"Some variables not in the model: \['x2', 'y2'\]"):
starting.allinmodel([x2, y2], model1)
with raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
starting.allinmodel([x2, y1], model1)
with raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
starting.allinmodel([x2], model1)
3 changes: 2 additions & 1 deletion pymc3/tuning/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..model import modelcontext, Point
from ..theanof import hessian_diag, inputvars
from ..blocking import DictToArrayBijection, ArrayOrdering
from ..util import get_var_name

__all__ = ['find_hessian', 'trace_cov', 'guess_scaling']

Expand Down Expand Up @@ -135,7 +136,7 @@ def trace_cov(trace, vars=None, model=None):
vars = trace.varnames

def flat_t(var):
x = trace[str(var)]
x = trace[get_var_name(var)]
return x.reshape((x.shape[0], np.prod(x.shape[1:], dtype=int)))

return np.cov(np.concatenate(list(map(flat_t, vars)), 1).T)
3 changes: 2 additions & 1 deletion pymc3/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ..theanof import inputvars
import theano.gradient as tg
from ..blocking import DictToArrayBijection, ArrayOrdering
from ..util import update_start_vals, get_default_varnames
from ..util import update_start_vals, get_default_varnames, get_var_name

import warnings
from inspect import getargspec
Expand Down Expand Up @@ -196,6 +196,7 @@ def nan_to_high(x):
def allinmodel(vars, model):
notin = [v for v in vars if v not in model.vars]
if notin:
notin = list(map(get_var_name, notin))
raise ValueError("Some variables not in the model: " + str(notin))


Expand Down
Loading