
Stochastic MuZero issues invalid actions and outcomes #88

Closed
carlosgmartin opened this issue Jan 14, 2024 · 7 comments

@carlosgmartin
Contributor

carlosgmartin commented Jan 14, 2024

Example:

import jax
import mctx
from jax import numpy as jnp, random


def main():
    n_actions = 7
    n_outcomes = 3
    batch_size = 1

    root = mctx.RootFnOutput(  # type: ignore
        prior_logits=jnp.zeros([batch_size, n_actions]),
        value=jnp.zeros(batch_size),
        embedding=jnp.zeros(batch_size),
    )

    def decision_recurrent_fn(params, key, action, state):
        jax.debug.print("action: {}", action)
        afterstate = state
        output = mctx.DecisionRecurrentFnOutput(  # type: ignore
            chance_logits=jnp.zeros([batch_size, n_outcomes]),
            afterstate_value=jnp.zeros(batch_size),
        )
        return output, afterstate

    def chance_recurrent_fn(params, key, outcome, afterstate):
        jax.debug.print("outcome: {}", outcome)
        state = afterstate
        output = mctx.ChanceRecurrentFnOutput(  # type: ignore
            action_logits=jnp.zeros([batch_size, n_actions]),
            value=jnp.zeros(batch_size),
            reward=jnp.zeros(batch_size),
            discount=jnp.ones(batch_size),
        )
        return output, state

    mctx.stochastic_muzero_policy(
        params={},
        rng_key=random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_recurrent_fn,
        chance_recurrent_fn=chance_recurrent_fn,
        num_simulations=20,
    )


if __name__ == "__main__":
    main()

Output:

action: [0]
action: [1]
outcome: [-6]
action: [7]
outcome: [0]
action: [4]
outcome: [-3]
action: [8]
outcome: [1]
action: [5]
outcome: [-2]
action: [6]
outcome: [-1]
action: [0]
outcome: [-7]
action: [3]
outcome: [-4]
action: [2]
outcome: [-5]
action: [9]
outcome: [2]
action: [5]
outcome: [-2]
action: [7]
outcome: [0]
action: [5]
outcome: [-2]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [2]
outcome: [-5]
action: [2]
outcome: [-5]

The actions range from 0 to 9 (7+3=10 in total), even though there are only 7 actions.
The outcomes range from -7 (a negative integer!) to 2 (7+3=10 in total), even though there are only 3 outcomes.

This may have something to do with the math inside stochastic_recurrent_fn.
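Every logged action/outcome pair fits outcome = action - n_actions (for example, action 8 with outcome 1). That suggests both functions receive the same combined index over n_actions + n_outcomes = 10 slots, with the chance function's copy shifted down by 7. The short check below replays the 20 logged pairs against that hypothesis. It only inspects the printed values and does not call mctx.

```python
N_ACTIONS = 7   # n_actions in the issue's script
N_OUTCOMES = 3  # n_outcomes in the issue's script

# (action, outcome) pairs copied from the log above, in order. The very first
# "action: [0]" line has no matching outcome line and is left out.
logged_pairs = [
    (1, -6), (7, 0), (4, -3), (8, 1), (5, -2), (6, -1), (0, -7),
    (3, -4), (2, -5), (9, 2), (5, -2), (7, 0), (5, -2), (7, 0),
    (7, 0), (7, 0), (7, 0), (7, 0), (2, -5), (2, -5),
]

# Hypothesis: one combined index i in [0, N_ACTIONS + N_OUTCOMES) is passed to
# both functions. The decision function gets it unchanged, and the chance
# function gets i - N_ACTIONS.
assert all(outcome == action - N_ACTIONS for action, outcome in logged_pairs)

# Under that reading, each pair carries exactly one meaningful value.
valid_actions = sorted({a for a, _ in logged_pairs if a < N_ACTIONS})
valid_outcomes = sorted({o for _, o in logged_pairs if 0 <= o < N_OUTCOMES})
print(len(logged_pairs), valid_actions, valid_outcomes)
# 20 [0, 1, 2, 3, 4, 5, 6] [0, 1, 2]
```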

Version information:

$ python3 --version
Python 3.11.6
$ python3 -c "import mctx; print(mctx.__version__)"
0.0.5
$ python3 -c "import jax; print(jax.__version__)"
0.4.23
$ python3 -c "import jaxlib; print(jaxlib.__version__)"
0.4.23
@evanatyourservice

I asked about this here. Because the decision and chance functions alternate, the unused portion is discarded (the invalid values you are seeing wouldn't be used).
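Below is a simplified sketch of that "compute both, keep one" pattern. It assumes, consistent with the log, that both functions are called on one combined index and that the per-node result is chosen with an element-wise select. `toy_decision_fn`, `toy_chance_fn` and `combined_step` are hypothetical stand-ins, not mctx internals.

```python
import jax.numpy as jnp

N_ACTIONS = 7  # decision actions, as in the issue's example

def toy_decision_fn(action):
    # Hypothetical stand-in for decision_recurrent_fn.
    return action * 10.0

def toy_chance_fn(outcome):
    # Hypothetical stand-in for chance_recurrent_fn.
    return outcome * 100.0

def combined_step(combined_index, is_decision_node):
    """Sketch: evaluate both branches for every batch element, keep one."""
    action = combined_index               # meaningful only at decision nodes
    outcome = combined_index - N_ACTIONS  # meaningful only at chance nodes
    decision_out = toy_decision_fn(action)  # runs even at chance nodes
    chance_out = toy_chance_fn(outcome)     # runs even at decision nodes
    # The element-wise select discards the branch that does not apply.
    return jnp.where(is_decision_node, decision_out, chance_out), action, outcome

# A decision node that picked action 4 and a chance node that drew outcome 1
# (combined index 7 + 1 = 8), matching two pairs from the log.
out, action, outcome = combined_step(jnp.array([4, 8]), jnp.array([True, False]))
print(out, action, outcome)
```

The decision function still sees 8 and the chance function still sees -3, but those results are dropped by `jnp.where`. That is why the invalid values appear in the prints without affecting the search, and also why they still cost compute.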

@carlosgmartin
Contributor Author

@evanatyourservice Hmm, doesn't that mean computation is being wasted on the invalid actions and outcomes?

@evanatyourservice

Yeah, here's another issue I opened about that. There's no simple workaround right now because JAX doesn't support a batched cond: under vmap, a cond with a batched predicate evaluates both branches. Here's a related issue in jax (there are a few). It doesn't matter much for mctx on TPU or GPU, because the two branches can mostly be computed in parallel, but the issue is there nonetheless. It becomes a bigger problem when the two branches need very different amounts of computation.
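The `jax.lax.cond` documentation notes that when cond is vmapped over a batch of predicates, it is converted to a select, so both branches are evaluated. The sketch below makes that visible by inspecting jaxprs with `jax.make_jaxpr`. `branchy` is a hypothetical example in which `jnp.sin` and `jnp.cos` stand in for the two branches.

```python
import jax
import jax.numpy as jnp
from jax import lax

def branchy(pred, x):
    # Hypothetical two-branch step: sin and cos stand in for the decision and
    # chance computations.
    return lax.cond(pred, jnp.sin, jnp.cos, x)

xs = jnp.arange(3.0)
preds = jnp.array([True, False, True])

# Per-element predicates: vmap rewrites the cond as a select over both branches.
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(branchy))(preds, xs))
# One predicate shared by the whole batch: the cond primitive is kept.
shared_jaxpr = str(
    jax.make_jaxpr(jax.vmap(branchy, in_axes=(None, 0)))(jnp.array(True), xs))

print("cond[" in batched_jaxpr, "select_n" in batched_jaxpr)
print("cond[" in shared_jaxpr)
print(jax.vmap(branchy)(preds, xs))  # sin(0), cos(1), sin(2)
```

In the batched case, sin and cos are both computed for all three elements, and `select_n` picks the results. With a shared predicate, the cond survives into the compiled program, so only one branch runs. In the MCTS, each batch element can be at a different node type, which puts it in the first case.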

Maybe the MCTS could be rewritten so that only one branch is computed; I haven't really thought it through. If you see a solution, please share it. I might look into it, since my nets are sometimes large and there would be quite a difference in speed and compute.

@fidlej
Collaborator

fidlej commented Jan 14, 2024

Thanks @evanatyourservice for explaining the issue.
If you rewrite the MCTS to better support Stochastic MuZero, maybe keep that in your fork. I do not plan big changes to mctx.

fidlej closed this as completed on Jan 14, 2024
@carlosgmartin
Contributor Author

carlosgmartin commented Jan 14, 2024

@fidlej According to the JAX docs on out-of-bounds indexing:

Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
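The quote refers to two behaviors described earlier on that docs page: out-of-bounds retrievals are clamped, and out-of-bounds updates are dropped. The following is a small illustration that assumes a hypothetical 3-entry per-outcome `table`. An in-range negative index such as the logged outcome -2 does not look out of bounds at all, because it wraps Python-style to a real entry.

```python
import jax.numpy as jnp

# Hypothetical per-outcome table for the issue's 3 chance outcomes.
table = jnp.array([10.0, 20.0, 30.0])

# Retrieval past the end is clamped to the last element (in practice).
past_end = table[5]
# A negative index in [-3, 0) wraps Python-style, so outcome -2 reads entry 1.
wrapped = table[-2]
# An out-of-bounds update is dropped, leaving the values unchanged.
updated = table.at[5].set(99.0)

print(past_end, wrapped, updated)
```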

If out-of-bounds indices being passed into functions that expect valid actions or outcomes could potentially cause undefined behavior somewhere downstream, would it perhaps be a good idea to clip the indices passed to decision_recurrent_fn and chance_recurrent_fn to their valid ranges?
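The thread does not settle whether mctx itself should clip. A user-side version of the clipping could look like the sketch below. `with_clipped_index` and `echo_fn` are hypothetical names, not part of mctx. The wrapper assumes the index is the third positional argument, as in the `decision_recurrent_fn(params, key, action, state)` and `chance_recurrent_fn(params, key, outcome, afterstate)` signatures from the issue's script.

```python
import jax.numpy as jnp

def with_clipped_index(recurrent_fn, num_valid):
    """Hypothetical wrapper: clip the index argument to [0, num_valid - 1]."""
    def wrapped(params, key, index, embedding):
        index = jnp.clip(index, 0, num_valid - 1)
        return recurrent_fn(params, key, index, embedding)
    return wrapped

def echo_fn(params, key, index, embedding):
    # Hypothetical stand-in that just returns the index it received.
    return index

clipped_decision = with_clipped_index(echo_fn, num_valid=7)  # n_actions
clipped_chance = with_clipped_index(echo_fn, num_valid=3)    # n_outcomes

print(clipped_decision(None, None, jnp.array([9, 4]), None))   # [6 4]
print(clipped_chance(None, None, jnp.array([-7, 1]), None))    # [0 1]
```

In the issue's script, this would be applied as `decision_recurrent_fn=with_clipped_index(decision_recurrent_fn, n_actions)` and `chance_recurrent_fn=with_clipped_index(chance_recurrent_fn, n_outcomes)`. The search already discards the out-of-range calls, so clipping should not change which results are kept. It only keeps the user functions away from out-of-bounds indexing and does nothing about the wasted computation.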

@iamreadyi

Quoting @fidlej: "Thanks @evanatyourservice for explaining the issue. If you rewrite the MCTS to better support Stochastic MuZero, maybe keep that in your fork. I do not plan big changes to mctx."

Will you close the enhancement issue about Stochastic MuZero?

@fidlej
Collaborator

fidlej commented Jan 15, 2024

Thanks for noticing. Done.
