pt.max is not differentiable in PyMC models #7251

Closed
jessegrabowski opened this issue Apr 13, 2024 · 8 comments · Fixed by #7261

Comments

@jessegrabowski
Member

Description

In the following model, model.dlogp() raises a NotImplementedError:

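import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal('x', shape=(2,))
    mu = x.max()  # equivalent to pt.max(x)
    pm.Normal('obs', mu, observed=np.random.uniform())
model.dlogp()  # raises NotImplementedError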

This should be differentiable (pt.max has gradients implemented), so it seems like something is going wrong in the rewrites, either in the logp rewrites or in the handling of MaxAndArgmax.
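For context on why the model should be differentiable: the gradient of max(x) with respect to x is 1 at the maximal entry and 0 elsewhere, which is well defined almost everywhere. The sketch below illustrates that rule in plain NumPy and checks it against central finite differences. It is not PyTensor's implementation; the helper names `max_grad` and `finite_diff` are hypothetical, and sending gradient to every tied maximum is an assumption of this sketch.

```python
import numpy as np

def max_grad(x: np.ndarray) -> np.ndarray:
    """Gradient of max(x) w.r.t. x: 1 where x equals its max, 0 elsewhere.

    Hypothetical helper; ties give 1 to every maximal entry (an assumption)."""
    return (x == x.max()).astype(x.dtype)

def finite_diff(f, x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Central finite differences, used to check the analytic gradient."""
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

# Same shape as `x` in the model above: a length-2 vector
x = np.array([0.3, -1.2])
print(max_grad(x))
print(finite_diff(np.max, x))
```

Both calls should agree away from ties, which is why a NotImplementedError from model.dlogp() points at the rewrites rather than at max itself.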

@bburan

bburan commented Apr 15, 2024

min also raises a NotImplementedError.

@ricardoV94
Member

min is just implemented as the negative of the max of the negative (min(x) = -max(-x)), so that's expected.
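A small NumPy sketch of the identity min(x) = -max(-x) shows why fixing max also fixes min: the gradient of min follows from the max gradient by the chain rule through the two negations, so it lands on the argmin. The helper names (`max_grad`, `my_min`, `min_grad`) are hypothetical and this is not PyTensor's code; the tie handling in `max_grad` is an assumption.

```python
import numpy as np

def max_grad(x: np.ndarray) -> np.ndarray:
    # 1 where x attains its max, 0 elsewhere (ties: every maximal entry, an assumption)
    return (x == x.max()).astype(x.dtype)

def my_min(x: np.ndarray) -> float:
    # min expressed as the negative of the max of the negative
    return -np.max(-x)

def min_grad(x: np.ndarray) -> np.ndarray:
    # Chain rule: d/dx[-max(-x)] = (-1) * max_grad(-x) * (-1), i.e. a mask at the argmin
    return -max_grad(-x) * -1.0

x = np.array([0.3, -1.2])
print(my_min(x), np.min(x))
print(min_grad(x))
```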

@tanish1729

Hi! I can take up this issue. It looks like it involves sorting out how gradients are handled in the rewrites for some specific functions. Could you provide some more details so I can get started?

@ricardoV94
Member

@tanish1729 this one is not really a beginner friendly issue. I'll try and fix it now myself. Let us know if you need help finding more suitable issues

@tanish1729

Oh great, I see you did this yourself. I'll go through the code and see if I can understand it.
What are some other good beginner-friendly issues open right now?

@ricardoV94
Member


You can filter issues on GitHub by label: https://github.com/pymc-devs/pymc/issues?q=is%3Aissue+is%3Aopen+label%3A%22beginner+friendly%22

@bburan

bburan commented Apr 17, 2024

@ricardoV94 Thanks so much for fixing this so quickly. I can confirm this solved the issue in my model, and I am now able to sample my model using NUTS only. This brings the runtime down from ~2 hours to 10 minutes (3 minutes if I use an experimental NUTS sampler such as numpyro).

@ricardoV94
Member

You're welcome. By the way, we don't consider the numpyro integration experimental anymore.
