Stochastic MuZero issues invalid actions and outcomes #88
I asked about this here: because the decision and chance functions alternate, the unused portion is discarded (the invalid values you are seeing wouldn't be used).
@evanatyourservice Hmm, doesn't that mean computation is being wasted on the invalid actions and outcomes?
Yeah, here's another issue I opened about that. There isn't a simple workaround right now because jaxlib hasn't implemented a batched cond; here's a related issue in jax (there are a few). It doesn't matter much for mctx on a TPU or GPU, because the two branches can mostly be computed in parallel, but the issue is there nonetheless, and it becomes a bigger problem when the two branches differ in the amount of computation. Maybe the MCTS could be rewritten so that only one branch is computed; I haven't really thought it through. If you see a solution, please share it. I might look into it myself, since my nets are sometimes large and there would be quite a difference in speed and compute.
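A minimal pure-Python sketch of the cost being discussed, assuming (as we understand JAX's documented behavior) that a `lax.cond` under `vmap` with a batched predicate is lowered to a select, so both branches run for every batch element. `decision_fn`, `chance_fn`, and `batched_select` are hypothetical names for illustration, not mctx or JAX APIs.

```python
# Pure-Python sketch (not mctx or JAX code): how a batched conditional behaves
# when the predicate differs across the batch. Both branches are evaluated for
# every element and a per-element select keeps one result.

def decision_fn(x):
    """Hypothetical stand-in for the decision-node branch."""
    return x * 10


def chance_fn(x):
    """Hypothetical stand-in for the chance-node branch."""
    return x + 100


def batched_select(is_decision_node, xs):
    """Evaluate both branches on all elements, then keep one per element.

    Returns the selected outputs and the total number of branch evaluations.
    """
    decision_out = [decision_fn(x) for x in xs]  # runs for every element
    chance_out = [chance_fn(x) for x in xs]      # also runs for every element
    selected = [
        d if is_decision else c
        for is_decision, d, c in zip(is_decision_node, decision_out, chance_out)
    ]
    return selected, len(decision_out) + len(chance_out)


outputs, evaluations = batched_select([True, False, True], [1, 2, 3])
print(outputs)      # [10, 102, 30]
print(evaluations)  # 6: three results kept, six branch calls made
```

Half of the branch calls produce values that are thrown away, which is why the overhead grows when one branch is much more expensive than the other.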
Thanks @evanatyourservice for explaining the issue.
@fidlej According to the JAX docs on out-of-bounds indexing:
If out-of-bounds indices passed into functions that expect valid actions or outcomes could cause undefined behavior somewhere downstream, would it be a good idea to clip the indices passed to
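A sketch of the clipping idea, assuming a combined index space of `num_actions + num_chance_outcomes` values from which both interpretations are derived. JAX clamps out-of-bounds indices for reads, so unclipped values do not crash but silently read edge entries; clipping makes the valid range explicit to user code. `clip` and `split_and_clip` are hypothetical helpers, not part of mctx.

```python
# Hypothetical helpers (not mctx code): derive both interpretations of a
# combined index and clip each one into its valid range before it reaches
# user-supplied decision/chance functions.

def clip(value, low, high):
    """Clamp value into the closed interval [low, high]."""
    return max(low, min(value, high))


def split_and_clip(action_or_chance, num_actions, num_chance_outcomes):
    """Return (action, chance_outcome), each clipped to a valid index."""
    action = clip(action_or_chance, 0, num_actions - 1)
    chance_outcome = clip(action_or_chance - num_actions, 0,
                          num_chance_outcomes - 1)
    return action, chance_outcome


# With 7 actions and 3 chance outcomes (combined indices 0..9):
print(split_and_clip(0, 7, 3))  # (0, 0): the outcome -7 is clipped up to 0
print(split_and_clip(9, 7, 3))  # (6, 2): the action 9 is clipped down to 6
```

Clipping does not change which branch's output is kept; it only ensures the branch whose output is discarded never receives an out-of-range index.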
Will you close the enhancement issue about Stochastic MuZero?
Thanks for noticing. Done.
Example:
Output:
The actions range from 0 to 9 (7+3=10 in total), even though there are only 7 actions.
The outcomes range from -7 (a negative integer!) to 2 (7+3=10 in total), even though there are only 3 outcomes.
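These ranges are consistent with the following arithmetic. It assumes (our reading, not confirmed in this thread) that `stochastic_recurrent_fn` works over a combined index of `7 + 3 = 10` values, passing the raw index as the action and the index minus `num_actions` as the chance outcome, for every node regardless of type. `interpret` is a hypothetical name.

```python
# Sketch of the assumed index arithmetic: one combined index, two readings.
NUM_ACTIONS = 7          # decision actions in the report
NUM_CHANCE_OUTCOMES = 3  # chance outcomes in the report


def interpret(action_or_chance, num_actions=NUM_ACTIONS):
    """Read one combined index both as an action and as a chance outcome."""
    action = action_or_chance                        # decision reading
    chance_outcome = action_or_chance - num_actions  # chance reading
    return action, chance_outcome


combined = range(NUM_ACTIONS + NUM_CHANCE_OUTCOMES)  # indices 0..9
actions = [interpret(i)[0] for i in combined]
outcomes = [interpret(i)[1] for i in combined]

print(min(actions), max(actions))    # 0 9
print(min(outcomes), max(outcomes))  # -7 2

# Only indices below NUM_ACTIONS are valid actions, and only non-negative
# chance readings are valid outcomes.
num_valid_actions = sum(1 for a in actions if a < NUM_ACTIONS)
num_valid_outcomes = sum(1 for c in outcomes if c >= 0)
print(num_valid_actions, num_valid_outcomes)  # 7 3
```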
This may have something to do with the math inside `stochastic_recurrent_fn`.

Version information: