Thanks. It seems straightforward to replace the `if` statements with `jax.lax.select` or `jax.lax.cond` logic and move the policy calls into the branches. Is there any particular reason you chose Python `if` statements instead?
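To make the suggestion concrete, here is a minimal sketch of what moving the policy calls into branches might look like. It is not the dreamerv3 code: `task_policy`, `expl_policy`, `random_policy`, and the integer `mode_index` are hypothetical stand-ins, and it uses `jax.lax.switch`, the multi-branch generalization of `jax.lax.cond`, since the goal is to support more than two behaviors. It also assumes every branch returns outputs with the same shape and dtype, which `lax.cond` and `lax.switch` require.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for behavior policies; these are NOT the dreamerv3 API.
def task_policy(latent):
  return jnp.tanh(latent)

def expl_policy(latent):
  return jnp.sin(latent)

def random_policy(latent):
  return jnp.zeros_like(latent)

# Assumed mapping: 0 = task, 1 = exploration, 2 = random.
BRANCHES = (task_policy, expl_policy, random_policy)

@jax.jit
def policy(latent, mode_index):
  # Every branch is traced once at compile time, but only the branch
  # selected by mode_index is executed at run time.
  return jax.lax.switch(mode_index, BRANCHES, latent)

latent = jnp.array([0.0, 0.5, 1.0])
task_action = policy(latent, 0)
expl_action = policy(latent, 1)
```

Note that `jax.lax.select` would not help here on its own, because it chooses elementwise between two operands that both have to be computed first. Also, under `jax.vmap` with a batched predicate or index, `lax.cond` and `lax.switch` are lowered to a select, so all branches run in that case.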
I am interested in adding more behaviors.
In the `policy` function of `agent.py`, all behaviors are called before the mode is even checked.

dreamerv3/dreamerv3/agent.py
Lines 51 to 60 in 423291a
This doesn't seem like it will scale if we add a lot of different behaviors.
Is this because we want JAX to trace all variables in the train function, i.e. in the `JaxAgent` where we get the initial variables of the train function for optimization?

dreamerv3/dreamerv3/jaxagent.py
Line 228 in 423291a