[Feature Request] A simple (and effective?) way to support cherry-picked env reset in xla mode #293

Open
bkkgbkjb opened this issue Feb 26, 2024 · 1 comment

@bkkgbkjb

Motivation

AFAIK (and as also noted in #194), it is currently impossible to cherry-pick terminated envs for reset in xla mode, because:

  1. The envs.reset() method doesn't accept an env_ids argument in xla mode, unlike in sync mode. And even if it did:
  2. step_env() and send() do seem to accept env_ids arguments, but it is hard to generate a dynamically shaped env_ids array in xla mode under JAX, because the number of finished envs changes from step to step. (I've run some preliminary experiments that confirm this.)
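Why point 2 bites: under `jax.jit`, every array shape must be known at trace time, so an env_ids array whose length equals the number of finished envs cannot be produced inside a jitted step. A common workaround is to pad the ids to a fixed length; `jax.numpy.nonzero` accepts static `size=` and `fill_value=` arguments for exactly this purpose. The sketch below shows the same padding idea in plain NumPy so it runs without JAX. The function name `padded_env_ids` and the sentinel `-1` are hypothetical choices, and envpool's send() would still have to learn to skip such padded ids, which is essentially what a boolean mask argument would provide.

```python
import numpy as np


def padded_env_ids(done_mask, fill_value=-1):
    """Ids of finished envs, padded to a fixed length of num_envs (hypothetical helper).

    np.flatnonzero alone returns a variable-length array (one entry per
    finished env), which a jitted JAX step cannot produce. Padding keeps the
    output shape at (num_envs,) no matter how many envs finished.
    """
    done_mask = np.asarray(done_mask, dtype=bool)
    ids = np.flatnonzero(done_mask)  # length varies from step to step
    out = np.full(done_mask.shape[0], fill_value, dtype=np.int64)
    out[: ids.size] = ids  # real ids first, sentinel padding after
    return out


print(padded_env_ids([False, True, False, True]))    # [ 1  3 -1 -1]
print(padded_env_ids([False, False, False, False]))  # [-1 -1 -1 -1]
```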

The resulting problem is that incorrect transitions appear, of the form (term_state, any_action, rew -> init_state), as also pointed out in #194.

Solution

I think we could add a masked_env_ids argument to the reset() and send() methods.

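Then, in xla mode, we could do something like this:

```python
import jax.numpy as jnp

obs, _ = envs.reset()
handle, recv, send, step = envs.xla()

while True:
    # some_acts: actions chosen by the policy for all envs
    handle, (obs, rew, terms, truncs, info) = step(handle, some_acts)

    # Proposed masked auto-resetting. NOTE: `masked_env_ids` is the argument
    # being requested in this issue; it does not exist in envpool today.
    auto_reset_masks = jnp.logical_or(terms, truncs)
    reset_obs, _ = envs.reset(masked_env_ids=auto_reset_masks)

    # Broadcast the (num_envs,) mask over the trailing observation dims.
    mask = auto_reset_masks.reshape((-1,) + (1,) * (obs.ndim - 1))
    obs = jnp.where(mask, reset_obs, obs)
```

masked_env_ids would have a static shape of (num_envs,). reset() would only reset the envs whose mask entry is True and return dummy observations for the rest; the jnp.where then keeps the step's observations for those envs.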
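To make the proposed semantics concrete, here is a minimal NumPy simulation under stated assumptions. `ToyBatchedEnv` and its `masked_reset` method are hypothetical stand-ins (not envpool's API) for a batched env that has the requested `reset(masked_env_ids=...)`; the toy env does not auto-reset, and the dummy value -1 is an arbitrary choice. The point it illustrates is that every array shape depends only on num_envs, never on how many envs finished, which is what would let the same logic trace under XLA with `jnp` in place of `np`.

```python
import numpy as np


class ToyBatchedEnv:
    """Hypothetical batched env illustrating the proposed masked reset.

    NOT envpool's API: `masked_reset` plays the role of the requested
    `reset(masked_env_ids=...)`.
    """

    def __init__(self, episode_lens):
        self.episode_lens = np.asarray(episode_lens)  # per-env episode length
        self.num_envs = self.episode_lens.shape[0]
        self.t = np.zeros(self.num_envs, dtype=np.int64)  # per-env step counter

    def _obs(self):
        # Observation: each env's step counter, repeated over 2 features.
        return np.repeat(self.t[:, None], 2, axis=1).astype(np.float32)

    def step(self, actions):
        del actions  # the toy dynamics ignore actions
        self.t += 1
        terms = self.t >= self.episode_lens
        return self._obs(), terms

    def masked_reset(self, masked_env_ids):
        # Static shapes: input is (num_envs,), output is (num_envs, 2),
        # regardless of how many envs are actually being reset.
        assert masked_env_ids.shape == (self.num_envs,)
        self.t = np.where(masked_env_ids, 0, self.t)
        dummy = np.full((self.num_envs, 2), -1.0, dtype=np.float32)
        return np.where(masked_env_ids[:, None], self._obs(), dummy)


env = ToyBatchedEnv(episode_lens=[2, 3, 4])
obs = env.masked_reset(np.ones(env.num_envs, dtype=bool))  # reset every env
for _ in range(2):
    obs, terms = env.step(actions=np.zeros(env.num_envs))
    reset_obs = env.masked_reset(terms)
    # Same merge as in the proposal: fresh obs for finished envs, else keep obs.
    obs = np.where(terms[:, None], reset_obs, obs)

print(obs)
```

After two steps only env 0 has finished, so only its row is replaced by a fresh observation while envs 1 and 2 keep their current ones.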

Alternative Methods

Currently, I'm working around this by overwriting the wrong transitions with the previous correct ones. This should not make a significant difference for most algorithms.
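One possible reading of this workaround, sketched in NumPy: in a rollout buffer with leading shape (T, num_envs), transition t of env i is the bogus (term_state, any_action, rew -> init_state) step whenever step t-1 ended that env's episode, so it is overwritten with transition t-1 of the same env. The function name `overwrite_reset_transitions` and the dict-of-arrays buffer layout are assumptions for illustration, not the author's actual code.

```python
import numpy as np


def overwrite_reset_transitions(batch, dones):
    """Replace auto-reset transitions with the previous transition of the same env.

    batch: dict of arrays with leading shape (T, num_envs, ...), e.g. obs,
           actions, rewards, next_obs (hypothetical buffer layout).
    dones: bool array (T, num_envs); dones[t, i] is True if step t ended
           env i's episode.
    """
    dones = np.asarray(dones, dtype=bool)
    # Step t is a reset step (term_state, any_action, rew -> init_state) when
    # step t-1 ended the episode. Step 0 cannot be checked without the previous
    # rollout's final done flags, so it is left untouched here.
    invalid = np.zeros_like(dones)
    invalid[1:] = dones[:-1]

    fixed = {}
    for key, arr in batch.items():
        arr = np.array(arr, copy=True)  # leave the caller's buffer unchanged
        for t in range(1, arr.shape[0]):
            # Copy env rows of step t-1 into step t wherever step t is invalid.
            arr[t, invalid[t]] = arr[t - 1, invalid[t]]
        fixed[key] = arr
    return fixed


# Toy rollout: T=3 steps, 2 envs; env 1's episode ends at step 0.
obs = np.arange(6, dtype=np.float32).reshape(3, 2)
rewards = np.array([[0.0, 1.0], [0.0, 5.0], [0.0, 0.0]])
dones = np.array([[False, True], [False, False], [False, False]])
fixed = overwrite_reset_transitions({"obs": obs, "rewards": rewards}, dones)
print(fixed["obs"])
```

The duplicated transition keeps its terminal flag if dones is included in the batch, so a bootstrapped target computed from it still treats the episode as ended.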

But if the proposed solution is sound, I think it would be better to have it, for elegance.

Additional context

Unfortunately, I'm not an expert in C++, so I'm not sure whether the proposed solution, simple as it is, would work as expected.
But based on my understanding, it should be implementable as long as the masking is done on the C++ side.


@JesseFarebro

This would be great to have, as exact evaluation like the approach in #113 (comment) currently can't be jitted.
