Computing both decision and chance branches of recurrent function in Stochastic MuZero is slow #79
Unfortunately this doesn't seem to translate well to GPU, even when unrolling the batch using a list comprehension instead of scan. I tested it on CPU and it seems to only help in that scenario. Unless there's something I'm not seeing, there may not be a good way to do this without batched cond through XLA :(
Thanks for sharing your findings.
Not having batched cond is a little rough sometimes lol. I suppose this could be reopened when that comes around.
@fidlej Sorry if you're busy, but do you think the newish jax.experimental.sparse API could be useful for this problem?
Thanks for asking. The multiplication of sparse matrices is probably not a good fit for this conditional computation. |
Hello all! Thank you for creating mctx for the community. I added a performance enhancement to my copy of mctx and am wondering if you're interested in adding it to the official repo.
Computing both the decision and chance branches in Stochastic MuZero on every expansion is necessary if the embeddings have different shapes/dtypes, but it is otherwise slow and adds a good bit of overhead (especially for a high simulation count or large networks). I modified the recurrent fn so it only has to compute one branch per expansion when the decision and chance embeddings have the same structure, shape, and dtype.
Overall these are my changes:
In base.py, another stochastic recurrent state was added that holds only one embedding:
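A minimal sketch of what that state could look like (the class and field names here are hypothetical, not the actual patch; mctx's existing `StochasticRecurrentState` carries separate state and afterstate embeddings):

```python
import chex

# Hypothetical single-embedding state: since only one branch is taken per
# expansion, a single slot can hold either the state or afterstate embedding.
@chex.dataclass(frozen=True)
class EfficientStochasticRecurrentState:
  embedding: chex.ArrayTree     # [B, ...] state *or* afterstate embedding.
  is_decision_node: chex.Array  # [B] picks the decision or chance branch.
```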
The modified version of _make_stochastic_recurrent_fn unrolls the batch using scan so that lax.cond can be used to compute only one branch:
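Roughly along these lines (a sketch under assumptions, not the actual diff: `decision_fn` and `chance_fn` are stand-ins for the decision/chance recurrent functions, and both branches must return outputs with identical tree structure, shapes, and dtypes for `lax.cond` to accept them):

```python
import jax

def make_efficient_stochastic_recurrent_fn(decision_fn, chance_fn):
  """Sketch: scan over the batch so lax.cond computes one branch per item."""

  def recurrent_fn(params, rng_key, action, state):
    def step(rng, inputs):
      one_action, one_state = inputs
      rng, sub = jax.random.split(rng)
      # Per-example cond: only the taken branch runs, unlike the batched
      # version, which evaluates both branches for every element.
      out = jax.lax.cond(
          one_state.is_decision_node,
          lambda: decision_fn(params, sub, one_action, one_state.embedding),
          lambda: chance_fn(params, sub, one_action, one_state.embedding))
      return rng, out

    # scan unrolls the leading (batch) axis of `action` and `state`.
    _, outputs = jax.lax.scan(step, rng_key, (action, state))
    return outputs

  return recurrent_fn
```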
Finally, in the policy function there's a chex check to see whether we can use the efficient version:
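Presumably something like the following (a guess at the shape of the check; chex provides tree-structure and shape/dtype assertions that raise `AssertionError` on mismatch):

```python
import chex

def can_use_efficient_version(state_embedding, afterstate_embedding):
  # Hypothetical helper: the fast path is only valid when both embeddings
  # share tree structure, shapes, and dtypes.
  try:
    chex.assert_trees_all_equal_structs(state_embedding, afterstate_embedding)
    chex.assert_trees_all_equal_shapes_and_dtypes(
        state_embedding, afterstate_embedding)
    return True
  except AssertionError:
    return False
```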
When the efficient version is used, it saves a bit of memory by not having to hold both embeddings, and it is faster, with a wider performance gap the higher the simulation count or the larger the network. Maybe there is a better option than scan for the unroll; I just know vmap can't be used, because under vmap cond is converted back to select.
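For reference, that vmap behavior is easy to demonstrate standalone: with a batched predicate, vmap lowers lax.cond to a select, so both branches execute for every batch element:

```python
import jax
import jax.numpy as jnp

def f(pred, x):
  # Unbatched: only one branch runs. Under vmap, this cond is lowered to a
  # select, and BOTH lambdas are evaluated for every element.
  return jax.lax.cond(pred, lambda: x * 2.0, lambda: x - 1.0)

preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)
print(jax.vmap(f)(preds, xs))  # [0., 0., 4.]
```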
Let me know if you're interested in adding this to mctx; I can make a testing colab to compare performance and would be happy to open a pull request :)