LaTeX and pretty str representations for v4, continued #4849

Merged
merged 13 commits into from
Jul 13, 2021

Spaak
Member

@Spaak Spaak commented Jul 9, 2021

This achieves largely the same as discussed in #4692, though now without either type patching or overriding __repr__ or __str__ in Aesara. I've also updated the relevant tests.

Quoting the previous PR:

This PR adds LaTeX and string representations for distributions and Models for v4, fixing #4494. Specifically, it adds the following functionality:

import pymc3 as pm

# obs_x and obs_y are the observed predictor and response arrays, defined elsewhere
model = pm.Model()
with model:
    b_mu = pm.Normal('b_mu', mu=0, sigma=10)
    b_sigma = pm.HalfNormal('b_sigma', sigma=10)

    a_mu = pm.Normal('a_mu', mu=0, sigma=10)
    a_sigma = pm.HalfNormal('a_sigma', sigma=10)

    b = pm.Normal('b', mu=b_mu, sigma=b_sigma)
    a = pm.Normal('a', mu=a_mu, sigma=a_sigma)

    y_sigma = pm.HalfNormal('y_sigma', sigma=10)
    y = pm.Normal('y', mu=a+b*obs_x, sigma=y_sigma, observed=obs_y)

after which, in an interactive console session, we can do:

>>> y
y ~ N(f(a, b), y_sigma)

>>> model
   b_mu ~ N(0, 10)
b_sigma ~ N**+(0, 10)
   a_mu ~ N(0, 10)
a_sigma ~ N**+(0, 10)
      b ~ N(b_mu, b_sigma)
      a ~ N(a_mu, a_sigma)
y_sigma ~ N**+(0, 10)
      y ~ N(f(a, b), y_sigma)

Additionally, in Jupyter notebooks, the model now renders as LaTeX (MathJax, to be precise):

[image: the model above rendered as LaTeX]

There are some changes w.r.t. how this functionality was implemented in v3:

  • The distribution names are now provided by Aesara's RandomVariable._print_name.
  • Distribution parameters are printed without names (e.g. N(0, 1) instead of N(mu=0, sigma=1)). This is a consequence of RandomVariable's interface. I considered using inspection to extract the parameter names (like before). However, Aesara follows numpy's naming (e.g. "loc" and "scale" for NormalRV), while PyMC3 uses its own ("mu" and "sigma"). For now I decided not to include parameter names to avoid this mismatch, but I am open to implementing it after all. In general I would actually prefer to have the parameter names.
  • Previously, in more complex models, parameter strings would often become very long, in the form of f(f(f(f(f(a)), f(f(f(b))), .... I've now changed this so that the entire graph going into a parameter is flattened and only f(a1, a2, ..., aN) is left, with one argument for every RandomVariable a occurring in the graph. This should make the strings a lot easier to read.
  • I've grouped all printing-related code in a new module printing.
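
The flattening described in the list above can be sketched on a toy graph. This is only an illustration: the Node class, collect_named_rvs, and param_str are hypothetical names and do not use Aesara or the actual code in pymc3.printing. The idea is to walk the inputs of a parameter expression, stop at every named random variable, and print the collected names as a single f(...).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Toy graph node (hypothetical stand-in for an Aesara variable)."""
    inputs: tuple = ()
    name: Optional[str] = None   # only named random variables carry a name
    is_rv: bool = False
    value: object = None         # set for constants only


def collect_named_rvs(node, found):
    """Depth-first walk that records each named RV once and does not expand it."""
    for parent in node.inputs:
        if parent.is_rv and parent.name is not None:
            if parent.name not in found:
                found.append(parent.name)
        else:
            collect_named_rvs(parent, found)
    return found


def param_str(node):
    """Render one distribution parameter as a short string."""
    if node.is_rv and node.name is not None:
        return node.name                      # a named RV prints as its name
    if not node.inputs:
        return str(node.value)                # a constant prints as its value
    return "f(" + ", ".join(collect_named_rvs(node, [])) + ")"


# mu = a + b * obs_x, built by hand as a nested graph
a = Node(name="a", is_rv=True)
b = Node(name="b", is_rv=True)
obs_x = Node(value=2.0)
mu = Node(inputs=(a, Node(inputs=(b, obs_x))))

print(param_str(mu))  # f(a, b)
```

However deeply the intermediate operations are nested, the result depends only on which named random variables are reachable, which is what keeps the strings short.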

The code is structured somewhat differently from the previous PR. In terms of functionality, the main difference is that this PR only provides pretty printing in IPython shells (and Jupyter), by registering custom pretty-printing handlers on import of pymc3.printing. The previous PR attempted to be universal (also working in plain Python shells) by overriding __repr__, which led to potentially hairy situations.
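
As a rough sketch of what registering handlers (rather than overriding __repr__) can look like, the snippet below uses IPython's display formatters. ToyModel, str_for_toy_model, and register_ipython_handlers are hypothetical names for illustration; whether pymc3.printing registers its handlers in exactly this way is an assumption, not something stated in this PR.

```python
class ToyModel:
    """Hypothetical stand-in for a model: a list of (name, distribution) pairs."""

    def __init__(self, rvs):
        self.rvs = rvs


def _escape(text):
    # Underscores need escaping inside LaTeX text.
    return text.replace("_", r"\_")


def str_for_toy_model(model, formatting="plain"):
    """Build either an aligned plain-text or a LaTeX representation."""
    if formatting == "latex":
        rows = [
            r"\text{" + _escape(name) + r"} &\sim & \text{" + _escape(dist) + "}"
            for name, dist in model.rvs
        ]
        body = " \\\\\n".join(rows)
        return "$$\n\\begin{array}{rcl}\n" + body + "\n\\end{array}\n$$"
    width = max(len(name) for name, _ in model.rvs)
    return "\n".join(f"{name:>{width}} ~ {dist}" for name, dist in model.rvs)


def register_ipython_handlers():
    """Register formatters for ToyModel; returns False outside an IPython session."""
    try:
        from IPython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()
    if shell is None:  # plain Python interpreter, nothing to register
        return False
    formatters = shell.display_formatter.formatters
    # text/plain handlers receive the object, a pretty printer, and a cycle flag.
    formatters["text/plain"].for_type(
        ToyModel, lambda obj, p, cycle: p.text(str_for_toy_model(obj))
    )
    # text/latex handlers receive the object and return a string.
    formatters["text/latex"].for_type(
        ToyModel, lambda obj: str_for_toy_model(obj, "latex")
    )
    return True


toy = ToyModel([("a_mu", "N(0, 10)"), ("a", "N(a_mu, a_sigma)")])
print(str_for_toy_model(toy))
registered = register_ipython_handlers()
```

Because nothing on the class itself is patched, a plain Python shell still shows the default repr, which matches the trade-off described above.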

@Spaak Spaak linked an issue Jul 9, 2021 that may be closed by this pull request
@@ -668,6 +669,13 @@ def __init__(
self.deterministics = treelist()
self.potentials = treelist()

from pymc3.printing import str_for_model
Member

why not move this to the global level?

Member Author

To avoid the circular import (printing depends on model and vice versa). I don't like the look of this either, but it seems to be the recommended way to avoid circular imports.
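
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of a function-level import breaking an import cycle. The two module names (toy_model_demo, toy_printing_demo) and their contents are made up for the demonstration and are written to a temporary directory; they only mimic the model/printing relationship in spirit.

```python
import importlib
import sys
import tempfile
from pathlib import Path

MODEL_SRC = '''
class Model:
    def __init__(self, name):
        self.name = name

    def __str__(self):
        # Deferred import: runs at call time, after both modules are loaded.
        from toy_printing_demo import str_for_model
        return str_for_model(self)
'''

PRINTING_SRC = '''
from toy_model_demo import Model  # module-level import of the model module


def str_for_model(model):
    if not isinstance(model, Model):
        raise TypeError("expected a Model")
    return f"<Model {model.name}>"
'''

# Write the two hypothetical modules to a temporary directory and make it importable.
tmp_dir = Path(tempfile.mkdtemp())
(tmp_dir / "toy_model_demo.py").write_text(MODEL_SRC)
(tmp_dir / "toy_printing_demo.py").write_text(PRINTING_SRC)
sys.path.insert(0, str(tmp_dir))
importlib.invalidate_caches()

toy_model = importlib.import_module("toy_model_demo")
result = str(toy_model.Model("m"))
print(result)  # <Model m>
```

If toy_model_demo imported str_for_model at module level instead, importing it would start loading toy_printing_demo, which would in turn ask for Model before that class had been defined, and the import would fail.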

@twiecki
Member

twiecki commented Jul 9, 2021

This looks good to me code-wise. Not sure what's going on with CI.

@Spaak
Member Author

Spaak commented Jul 9, 2021

> This looks good to me code-wise. Not sure what's going on with CI.

Hmm it appears that my use of functools.partial does not play well with pickling (and hence anything that does multiprocessing). Will fix.
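
The exact failure in this PR is not shown in the thread. As a general illustration only, one common way functools.partial interacts badly with pickling is when the wrapped callable cannot be pickled by reference (for example a lambda), whereas a partial around an importable function round-trips fine:

```python
import functools
import operator
import pickle

# A partial around an importable function pickles and unpickles without trouble.
add_one = functools.partial(operator.add, 1)
restored = pickle.loads(pickle.dumps(add_one))
restored_result = restored(2)

# A partial around a lambda cannot be pickled, because pickle stores functions
# by qualified name and a lambda has no importable name.
bad = functools.partial(lambda a, b: a + b, 1)
try:
    pickle.dumps(bad)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError, TypeError):
    lambda_pickled = False

print(restored_result, lambda_pickled)
```

Multiprocessing typically sends objects between processes via pickle, so any attribute that fails to pickle can break sampling with multiple processes.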

@michaelosthege
Member

I'm on vacation and won't be able to take a look before Wednesday, so don't wait for my review.

@Spaak
Member Author

Spaak commented Jul 13, 2021

This now depends on #4858. Once that's merged I'll rebase this one and CI should hopefully be all green.

@codecov

codecov bot commented Jul 13, 2021

Codecov Report

Merging #4849 (78f9da0) into main (0b1ecdb) will increase coverage by 0.61%.
The diff coverage is 89.23%.

@@            Coverage Diff             @@
##             main    #4849      +/-   ##
==========================================
+ Coverage   72.11%   72.73%   +0.61%     
==========================================
  Files          85       86       +1     
  Lines       13739    13748       +9     
==========================================
+ Hits         9908     9999      +91     
+ Misses       3831     3749      -82     

Impacted Files                         Coverage Δ
pymc3/distributions/bart.py            18.05% <ø> (+0.52%) ⬆️
pymc3/distributions/bound.py           32.11% <ø> (+0.96%) ⬆️
pymc3/distributions/simulator.py       25.86% <ø> (+2.67%) ⬆️
pymc3/util.py                          73.50% <ø> (+9.26%) ⬆️
pymc3/printing.py                      85.85% <85.85%> (ø)
pymc3/__init__.py                      100.00% <100.00%> (ø)
pymc3/distributions/distribution.py    80.43% <100.00%> (+11.94%) ⬆️
pymc3/model.py                         83.47% <100.00%> (+1.78%) ⬆️
pymc3/smc/sample_smc.py                92.53% <100.00%> (+0.53%) ⬆️
pymc3/math.py                          67.85% <0.00%> (-0.49%) ⬇️
... and 3 more

@Spaak Spaak merged commit 8623cd4 into pymc-devs:main Jul 13, 2021
Development

Successfully merging this pull request may close these issues.

String (and graph) representations in v4
3 participants