Create a NumPyroNUTS Op #4646
Conversation
Force-pushed from 07fc65a to d24fbf7
The un-transforming is now done in the same JAX run as the sampling.

Actually, it looks like something extra is needed for shared variables under the current approach. Right now, shared variables used in a log-likelihood graph only ever appear in the "inner" …
Force-pushed from d24fbf7 to 23f7a7b
All right, I figured out a way to use the shared variables from within the "inner" … That should be an acceptable limitation, because log-likelihood graphs probably shouldn't have any sort of "state". Plus, if they did, I don't think it would work with PyMC3's samplers either, for roughly the same reasons.
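The "stateless graph" point can be illustrated with plain JAX (an editorial sketch, not code from this PR; the `loglike` function and `state` dict are hypothetical): once a function has been traced and jitted, any values it closed over are baked in as constants, so mutating them afterwards has no effect on the compiled function. This is why shared variables that act as mutable state cannot survive the JAX sampling path.

```python
import jax
import jax.numpy as jnp

# Stand-in for a shared variable: mutable Python state that the
# log-likelihood closes over, rather than taking as an input.
state = {"data": jnp.array([1.0, 2.0, 3.0])}

@jax.jit
def loglike(x):
    # state["data"] is read only while JAX traces the function.
    return -0.5 * jnp.sum((state["data"] - x) ** 2)

print(loglike(0.0))  # -7.0, using data = [1, 2, 3]

# Mutating the "shared" value afterwards does nothing: the jitted
# function already baked the traced value in as a constant.
state["data"] = jnp.array([10.0])
print(loglike(0.0))  # still -7.0
```

In PyMC3's own (non-JAX) samplers, shared variables are re-read on every call of the compiled Theano/Aesara function, which is where the two backends differ.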
I tried running this on https://github.com/pymc-devs/pymc-examples/blob/main/examples/case_studies/factor_analysis.ipynb but got: …
@twiecki, are you running the code on this branch? It looks like you're using Theano-PyMC (e.g. using the … alias). I just ran this locally and it worked, although the …:

```python
import numpy as np
import arviz as az
import matplotlib
import pymc3 as pm
import aesara.tensor as at

from pymc3.sampling_jax import sample_numpyro_nuts

n = 250
k_true = 5
d = 9
err_sd = 2
N_SAMPLE = 350

M = np.random.binomial(1, 0.25, size=(k_true, n))
Q = np.hstack(
    [np.random.exponential(2 * k_true - k, size=(d, 1)) for k in range(k_true)]
) * np.random.binomial(1, 0.75, size=(d, k_true))
Y = np.round(1000 * np.dot(Q, M) + np.random.normal(size=(d, n)) * err_sd) / 1000

k = 2

with pm.Model() as PPCA:
    W = pm.Normal("W", size=(d, k))
    F = pm.Normal("F", size=(k, n))
    psi = pm.HalfNormal("psi", 1.0)
    X = pm.Normal("X", mu=at.dot(W, F), sigma=psi, observed=Y)
    W_plot = pm.Deterministic("W_plot", W[1:3, 0])
    F_plot = pm.Deterministic("F_plot", F[0, 1:3])

    trace = sample_numpyro_nuts()

trace.posterior["W_plot"] = trace.posterior.W[:, :, 1:3, 0]
trace.posterior["F_plot"] = trace.posterior.F[:, :, 0, 1:3]

az.plot_trace(trace, ("W_plot", "F_plot", "psi"))
```
This is Aesara master and this branch; odd.

Looks like I haven't rebased my local version of this branch since 94213ca, so your issue might've been caused by something introduced at or after that commit, if you've rebased it locally. I'll try those exact versions of …
I just got a …
It appears to work with …

I get the same error with …

Oh, wait, if you're on the …

I guess I should use the version that's pinned.

I just tried it with Aesara …

Yes, I ran the code you posted above. This is on OSX.

Is your local branch in sync with this remote?
Yep, I checked it out today.
Since it's complaining about a …

That did not help.

Running this same code on …

Does the test in this branch pass locally?
Force-pushed from 77736e1 to d950485
I added a macOS test for that model, and it appears to have passed in CI.

All right, updating from this PR, it works now! Did you change anything other than adding the OSX test? It's hard to see what's changing here when you rewrite commits.

The failing test seems to be the one reported in #4661.
Force-pushed from d950485 to 56e1688
From the GitHub UI, if you click on the …

Locally, a …

Yes, it's a flaky test; I believe I've tried to deal with that one before.

@brandonwillard Can we merge this?

@twiecki, why is "Rebase and merge" disabled in this repo?

@brandonwillard Because of conflicts, which I thought I had resolved.
Force-pushed from e779e10 to 35482f3
It wasn't saying there was a conflict on this page, but I rebased the branch itself, and that seems to have cleared it up.
This PR creates a NumPyroNUTS Op for better integration with Aesara's JAX backend. Closes #4142.

- The current implementation doesn't un-transform the transformed variables.
- Also, someone who knows the NumPyro functions needs to confirm that the output dimensions are correct.

Really, this is just a prototype that finally demonstrates the point I was trying to make in #4142. Someone who has more time/interest should take over this PR and add the finishing touches, because I don't know when I'll get around to doing it myself.