Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple stick breaking #4129

Merged
merged 33 commits into from
Sep 27, 2020
Merged

Simple stick breaking #4129

merged 33 commits into from
Sep 27, 2020

Conversation

katosh
Copy link
Contributor

@katosh katosh commented Sep 23, 2020

This is another attempt to introduce a new transformation of the n-simplex. The stickbreaking transformation is prominently used by the Dirichlet distribution as it maps the range of the Distribution (the n-simplex) to R^(n-1) where we can sample freely and apply, e.g., ADVI. The issue with the current implementation is that the transformation of later values in the vector depends on previous values. This introduces a dependency that can be confounding for ADVI and seems to produce numerical inaccuracies in some cases.

There was a previous attempt to merge the new transformation but it had a mistake in the determinant of the jacobian: #3638

The current strikebreaking in master is an implementation of the transformation from Stan: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html Which is just a repeated application of the logit transformation with adjusting range.

Advantages

Current StickBreaking

import pymc3 as pm
import pandas as pd

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10)*5e-3, shape=10)
    trace1 = pm.sample()
pd.DataFrame(trace1['decomp_stickbreaking__']).plot.kde(figsize=(10,4));

image

New StickBreaking2

import pymc3 as pm
import pandas as pd
from pymc3.distributions.transforms import StickBreaking2

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10)*5e-3, shape=10,
                          transform=StickBreaking2())
    trace2 = pm.sample()
pd.DataFrame(trace2['decomp_stickbreaking__']).plot.kde(figsize=(10,4));

image

The PR includes tests and there are no breaking changes as it only introduces a new transformation pymc3.distributions.transforms.StickBreaking2 and leaves the original pymc3.distributions.transforms.StickBreaking untouched.

@codecov
Copy link

codecov bot commented Sep 23, 2020

Codecov Report

Merging #4129 into master will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #4129   +/-   ##
=======================================
  Coverage   88.74%   88.74%           
=======================================
  Files          89       89           
  Lines       14037    14024   -13     
=======================================
- Hits        12457    12446   -11     
+ Misses       1580     1578    -2     
Impacted Files Coverage Δ
pymc3/distributions/transforms.py 97.70% <100.00%> (+0.25%) ⬆️
pymc3/distributions/continuous.py 92.93% <0.00%> (+0.11%) ⬆️

@katosh
Copy link
Contributor Author

katosh commented Sep 24, 2020

I will remove the NumPy implementation of the backward transformation StickBreaking2.backards_val since it is not tested and not implemented for any other transformation.

@katosh
Copy link
Contributor Author

katosh commented Sep 24, 2020

I investigated sampling divergencies in the examples above. I changed the parameter for the Dirichlet distribution to np.ones(10)*1e-2 so not all NUTS samples diverge. Then I used the pairplot_divergence analogous to https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html.

import matplotlib.pyplot as plt

def pairplot_divergence(trace, var1, var2, i1=0, i2=0):
    v1 = trace.get_values(varname=var1, combine=True)[:, i1]
    v2 = trace.get_values(varname=var2, combine=True)[:, i2]
    _, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.plot(v1, v2, 'o', color='b', alpha=.5)
    divergent = trace['diverging']
    ax.plot(v1[divergent], v2[divergent], 'o', color='r')
    ax.set_xlabel('{}[{}]'.format(var1, i1))
    ax.set_ylabel('{}[{}]'.format(var2, i2))
    ax.set_title('scatter plot between {}[{}] and {}[{}]'.format(var1, i1, var2, i2));
    return ax

Current StickBreaking

pairplot_divergence(trace1, 'decomp', 'decomp', i1=2, i2=3)

image

New StickBreaking2

pairplot_divergence(trace2, 'decomp', 'decomp', i1=2, i2=3)

image

Conclusion

The parameterization from StickBreaking2 can cure divergencies and hence avoid biases in some cases.

@katosh
Copy link
Contributor Author

katosh commented Sep 24, 2020

It seems forward_val is sometimes called with an argument point, e.g., here:
https://github.com/pymc-devs/pymc3/blob/ba77d8502704e8aeb112782ee104fb339393cb19/pymc3/util.py#L183
So I will include the ignored parameter in StickBreaking2.forward_val.

Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, this alternative stickbreaking seems fine, but why wouldn't it entirely replace the old one?

Also, if possible, this PR should add a new test that confirms one of the advantages of this transform over the old one. Ideally, such a test wouldn't require anything as costly as sampling. Is there a value range that demonstrates the improved numerical stability?

RELEASE-NOTES.md Outdated Show resolved Hide resolved
@twiecki
Copy link
Member

twiecki commented Sep 26, 2020

@katosh I would then still take the eps kwarg and raise a deprecation error.

@katosh
Copy link
Contributor Author

katosh commented Sep 26, 2020

Code coverage is reduced since:

  1. The deprecation warning is not tested.
  2. The length of the completely covered code is reduced.

@brandonwillard
Copy link
Contributor

  1. The deprecation warning is not tested.

You can add a noqa to that line.

@MarcoGorelli
Copy link
Contributor

MarcoGorelli commented Sep 26, 2020

  1. The deprecation warning is not tested.

You can add a noqa to that line.

why not write a test with

with pytest.warns(DeprecationWarning("<warning text>")):
    <test which sets `eps` parameter>

which covers it?

@brandonwillard
Copy link
Contributor

why not write a test with

You can definitely do that, but we're not really testing much of our own code in this case, so it's not a particularly relevant unit test.

@katosh
Copy link
Contributor Author

katosh commented Sep 26, 2020

already done the test :)

@katosh
Copy link
Contributor Author

katosh commented Sep 26, 2020

I tested how close to the edge of the simplex we can go before the transformation starts to break and for the cases I tested it seems to work down to the smallest float64:

>>> import numpy as np
>>> from pymc3.distributions.transforms import stick_breaking
>>> a = 5e-324
>>> vec = np.array([a, a, a, 1-(3*a)]) # a point very close to the edge of the 4-simplex
>>> stick_breaking.backward(stick_breaking.forward(vec).eval()).eval()
array([5.e-324, 5.e-324, 5.e-324, 1.e+000]) # very close to vec

However, the same can possibly not be said about the jacobian!

@katosh
Copy link
Contributor Author

katosh commented Sep 27, 2020

I investigated the jacobian by plotting its values for points StickBreaking.forwad(Simplex(3)) with

import itertools
import numpy as np
import theano
from pymc3.distributions.transforms import stick_breaking
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

def plot_jacobian_det(line):
    xl = list()
    yl = list()
    zl = list()
    p = theano.tensor.dvector('p')
    jd = theano.function([p], stick_breaking.jacobian_det(p))

    for x, y in itertools.product(line, repeat=2):
        xl.append(x)
        yl.append(y)
        log_jacobian_det = jd(np.array([x, y]))
        zl.append(log_jacobian_det)

    x = np.stack(xl)
    y = np.stack(yl)
    z = np.stack(zl)

    fig = plt.figure()
    ax = Axes3D(fig)
    ax.plot_trisurf(x, y, z, cmap=cm.jet)

I looked at the values in the center of the simplex:

plot_jacobian_det(np.linspace(-1, 1, 100))

image

and all the way to the edge

plot_jacobian_det(np.linspace(-500, 500, 100))

image
Note that stick_breaking.backward(np.array([-350, 350])).eval() is array([9.85967654e-305, 1.00000000e+000, 9.92959040e-153]) so we go as close to the edge of the simplex as possible with float64.

There are no numerical issues apparent and the log-determinant of the jacobian seems to behave as expected down to values so close to the edge of the simplex that they cannot be distinguished by the Dirichlet distribution (I belive they are mapped to the simplex befor Dirichlet.logp is evaluated). But this is of course not an exhaustive investigation.

@twiecki
Copy link
Member

twiecki commented Sep 27, 2020

@katosh This looks great and quite thorough. Is there anything missing before merging from your end?

@katosh
Copy link
Contributor Author

katosh commented Sep 27, 2020

I am done so far but of course, I can do further testing if someone has a request.

@twiecki
Copy link
Member

twiecki commented Sep 27, 2020

I think this is great, thanks so much for the contribution!

@twiecki twiecki merged commit fd76e96 into pymc-devs:master Sep 27, 2020
@katosh
Copy link
Contributor Author

katosh commented Sep 27, 2020

Awesome, thank you for having me be part of this project!

@helmutsimon
Copy link
Contributor

It appears that StickBreaking.forward_val is being eliminated, with no equivalent in the new version. This would concern me, as I happen to use it in a public repository. I could work around it, but perhaps there are others using it also. Is there any depreciation warning in the meantime? I only found out about this because I wanted to add a backward_val.

@ricardoV94
Copy link
Member

ricardoV94 commented Aug 24, 2021

It appears that StickBreaking.forward_val is being eliminated, with no equivalent in the new version. This would concern me, as I happen to use it in a public repository. I could work around it, but perhaps there are others using it also. Is there any depreciation warning in the meantime? I only found out about this because I wanted to add a backward_val.

Do you mind opening a separate issue for that? This one is pretty long and the forward_val were removed for all distributions not just StickBreaking

@helmutsimon
Copy link
Contributor

Do you mind opening a separate issue for that? This one is pretty long and the forward_val were removed for all distributions not just StickBreaking

See discourse topic here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants