Motivation
AFAIK (see also #194), it is currently not possible to cherry-pick terminated envs for reset in xla mode, for two reasons:
First, the envs.reset() method doesn't accept an env_ids argument, unlike in sync mode.
Second, even if it did (step_env() and send() do seem to accept an env_ids argument), it is hard to generate a dynamically shaped env_ids array in xla mode and jax. (I ran some preliminary experiments to verify this.)
As a result, incorrect transitions end up in the collected data: (term_state, any_action, rew -> init_state),
as also pointed out in #194.
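To make the bad transition concrete, here is a minimal pure-Python sketch. `ToyAutoResetEnv` and `collect` are hypothetical names, and the auto-reset behaviour (a step on a terminated env ignores the action and returns the initial state) is an assumption made for illustration, not EnvPool's actual implementation.

```python
class ToyAutoResetEnv:
    """Hypothetical env: the state counts steps; the episode ends at `horizon`.

    Assumption for illustration: a step taken on a terminated env ignores
    the action and returns the initial state (auto-reset).
    """

    def __init__(self, horizon=2):
        self.horizon = horizon
        self.state = 0
        self.done = False

    def step(self, action):
        if self.done:
            # Auto-reset: the action is ignored and the initial state comes back.
            self.state, self.done = 0, False
            return self.state, 0.0, False
        self.state += 1
        self.done = self.state >= self.horizon
        return self.state, 1.0, self.done


def collect(env, num_steps):
    """Naively record (state, action, reward, next_state) for every step."""
    transitions = []
    state = env.state
    for _ in range(num_steps):
        action = 0  # any action; it does not matter for this toy env
        next_state, reward, done = env.step(action)
        transitions.append((state, action, reward, next_state))
        state = next_state
    return transitions


transitions = collect(ToyAutoResetEnv(horizon=2), num_steps=4)
```

In this toy run, the third recorded tuple starts from the terminal state and ends in the initial state, which is exactly the (term_state, any_action, rew -> init_state) pattern.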
Solution
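I think we could try adding a masked_env_ids argument to the reset() and send() methods, so that selective resets become possible in xla mode.
masked_env_ids would be a boolean mask with a static shape of env_nums. Only the envs whose mask entry is True would be reset, and dummy obs would be returned for the envs whose entry is False.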
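The intended semantics can be sketched in plain Python as follows. `ToyVecEnv`, `DUMMY_OBS`, and this `reset` signature are hypothetical and only illustrate the proposal; nothing like this exists in EnvPool as far as this issue states.

```python
DUMMY_OBS = -1  # placeholder returned for envs that are not reset


class ToyVecEnv:
    """Hypothetical vector env; each sub-env's observation is its step count."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.steps = [0] * num_envs

    def step(self):
        self.steps = [s + 1 for s in self.steps]
        return list(self.steps)

    def reset(self, masked_env_ids):
        # The mask always has length num_envs, so the output length is
        # static no matter how many envs are actually reset.
        assert len(masked_env_ids) == self.num_envs
        obs = []
        for i, do_reset in enumerate(masked_env_ids):
            if do_reset:
                self.steps[i] = 0
                obs.append(self.steps[i])
            else:
                obs.append(DUMMY_OBS)
        return obs


envs = ToyVecEnv(num_envs=4)
envs.step()
envs.step()
mask = [False, True, False, True]  # e.g. the `done` flags of the last step
reset_obs = envs.reset(mask)
```

Because both the mask and the returned observations always have length env_nums, a caller in xla mode would never need to build a dynamically shaped env_ids array.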
Alternative Methods
Currently, I'm working around this inconvenience by overwriting the wrong transitions with the previous correct ones. This should not make a significant difference for most algorithms.
But if the proposed solution is correct, I think it would be better to have it for elegance.
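One possible reading of this workaround, as a sketch: the helper `overwrite_invalid` and the `invalid` flags are hypothetical, and the exact code used for the workaround is not shown in this issue.

```python
def overwrite_invalid(transitions, invalid):
    """Replace each flagged transition with the latest valid one before it.

    `invalid[i]` is True when transition i starts from a terminated state.
    Flagged transitions with no valid predecessor are left unchanged.
    """
    fixed = []
    last_valid = None
    for transition, bad in zip(transitions, invalid):
        if bad and last_valid is not None:
            fixed.append(last_valid)
        else:
            fixed.append(transition)
            if not bad:
                last_valid = transition
    return fixed


# (state, action, reward, next_state); index 2 is the spurious post-termination step.
transitions = [(0, 0, 1.0, 1), (1, 0, 1.0, 2), (2, 0, 0.0, 0), (0, 0, 1.0, 1)]
invalid = [False, False, True, False]
fixed = overwrite_invalid(transitions, invalid)
```

The batch keeps its original length, at the cost of slightly over-representing the duplicated transitions.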
Additional context
Unfortunately, I'm not an expert in C++, and I'm not sure whether the proposed solution, despite being simple, would work as expected.
But based on my understanding, this should be implementable as long as we perform it in the C++ processes.