Embarrassing parallel inference #494
Hi @ramav87, I am not sure whether dask is compatible with JAX. Could you test it first? You would probably need to use mp strategies. Otherwise, you can use
@ramav87 I think that using pmap works for your case. I am not sure whether there will be any performance issues, so feel free to open a separate issue if you observe something strange.
Yes, this works for me. Thank you so much!!!
@fehiepsi, is there a way to implement the For example:
will throw an error like
I've tried working around this by setting up boolean masks to select rows from the dataframe, but then I start running into
Does
work for you?
Nope.
As far as I can tell, JAX doesn't want me to use
I'm totally open to other methods to
This also fails regardless of whether the iterable I pass is a list, array, or DeviceArray:
with
Hi @d-diaz, I'm not sure what you are asking for. If you want to use
We don't support what the JAX team does not support, such as pmap over a list. It's also hard for us to understand the error you got without reproducible code. Please feel free to open a thread in our forum for discussion. If you think there are issues here, please open a separate issue on GitHub.
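Why "pmap over a list" does not do what one might expect: `jax.pmap` treats a Python list as a pytree. It maps over the leading axis of each array inside the list, not over the list items themselves. The minimal sketch below assumes all shards have the same shape. It stacks them into one array whose leading dimension equals `jax.local_device_count()`. The function `shard_mean` is a hypothetical stand-in for the per-shard work.

```python
import jax
import jax.numpy as jnp


def shard_mean(x):
    # Hypothetical per-shard work; each device receives one row of the stacked input.
    return jnp.mean(x)


n_devices = jax.local_device_count()
# A Python list with one equal-length array per device.
shard_list = [jnp.arange(4.0) + 10.0 * i for i in range(n_devices)]

# Stack into shape (n_devices, 4) so that pmap sends row i to device i.
stacked = jnp.stack(shard_list)
means = jax.pmap(shard_mean)(stacked)
```

If the shards have different lengths, they need to be padded (or handled outside `pmap`) before stacking.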
I'm trying to run MCMC chains on separate shards of data, with each data shard sent to a different device. The model in your comment doesn't take any inputs, so I can't figure out how I'm supposed to pass different shards of data to the model on each processor.
Hi @d-diaz, in that code we
A problem I have been unable to solve in pymc3 and numpyro is the use of embarrassingly parallel for loops (e.g., with dask): assume we have a list of observations, and I am trying
where the function `compute_nuts` does MCMC and returns the trace. Is there any way to do this at the moment with numpyro? Even when I set JAX to use only the CPU, dask gives an assertion error:
This is obviously not surprising, but is there a workaround for this kind of task?