Question about policy(...) implementation. #63

Closed
edwhu opened this issue Jun 3, 2023 · 3 comments

Comments

edwhu commented Jun 3, 2023

I am interested in adding more behaviors.

In the `policy` function of `agent.py`, all behaviors are called before the mode is even checked:

```python
def policy(self, obs, state, mode='train'):
  self.config.jax.jit and print('Tracing policy function.')
  obs = self.preprocess(obs)
  (prev_latent, prev_action), task_state, expl_state = state
  embed = self.wm.encoder(obs)
  latent, _ = self.wm.rssm.obs_step(
      prev_latent, prev_action, embed, obs['is_first'])
  # The outputs of this first call are discarded.
  self.expl_behavior.policy(latent, expl_state)
  # Every behavior is called here, before `mode` is checked further down.
  task_outs, task_state = self.task_behavior.policy(latent, task_state)
  expl_outs, expl_state = self.expl_behavior.policy(latent, expl_state)
```

This doesn't seem like it will scale if we add a lot of different behaviors.
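To make the scaling concern concrete, here is a minimal JAX sketch of the same call-everything-then-select pattern written over a dictionary of behaviors, so that adding a behavior means adding a dictionary entry rather than another line of calls. Everything in it is hypothetical: `task_policy`, `expl_policy`, `BEHAVIORS`, and `MODE_TO_BEHAVIOR` are toy stand-ins, not DreamerV3 code, and the mode-to-behavior mapping is an assumption. It also assumes `mode` is passed to `jax.jit` as a static argument, so the Python lookup on it is resolved while tracing.

```python
import jax
import jax.numpy as jnp

# Hypothetical behaviors: each maps (latent, state) -> (outputs, new_state).
def task_policy(latent, state):
    return {"action": jnp.tanh(latent)}, state + 1

def expl_policy(latent, state):
    return {"action": -jnp.tanh(latent)}, state + 1

BEHAVIORS = {"task": task_policy, "expl": expl_policy}

# Hypothetical mapping from the `mode` string to the behavior whose outputs are used.
MODE_TO_BEHAVIOR = {"train": "task", "eval": "task", "explore": "expl"}

def policy(latent, states, mode="train"):
    outs, new_states = {}, {}
    # Call every behavior before looking at `mode`, as agent.py does,
    # so adding a behavior only means adding an entry to BEHAVIORS.
    for name, fn in BEHAVIORS.items():
        outs[name], new_states[name] = fn(latent, states[name])
    # `mode` is static, so this dictionary lookup happens at trace time.
    return outs[MODE_TO_BEHAVIOR[mode]], new_states

policy_jit = jax.jit(policy, static_argnames="mode")

latent = jnp.ones(3)
states = {name: jnp.zeros(()) for name in BEHAVIORS}
outs, states = policy_jit(latent, states, mode="explore")
```

Every behavior is still called (and therefore traced) in every mode; the dictionary removes boilerplate but does not by itself avoid computing the behaviors whose outputs go unused.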

Is this because we want JAX to trace all variables in the train function, i.e. in `JaxAgent`, where we get the initial variables of the train function for optimization?

```python
varibs = self._train(varibs, rng, data, state, init_only=True)
```
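The reasoning behind this question can be illustrated with a small sketch, assuming (as the question implies, though the thread does not spell out the mechanism) that parameters are created lazily the first time a module asks for them during an init-only pass. `LazyParams`, `task_policy`, `expl_policy`, and the `call_all` flag are hypothetical names for illustration only; this is not the ninjax or DreamerV3 API.

```python
import jax
import jax.numpy as jnp

class LazyParams:
    """Hypothetical registry: a parameter is created the first time it is requested."""

    def __init__(self, rng):
        self.rng = rng
        self.values = {}

    def get(self, name, shape):
        if name not in self.values:
            self.rng, key = jax.random.split(self.rng)
            self.values[name] = jax.random.normal(key, shape)
        return self.values[name]

def task_policy(params, latent):
    return latent @ params.get("task/w", (latent.shape[-1], 2))

def expl_policy(params, latent):
    return latent @ params.get("expl/w", (latent.shape[-1], 2))

def policy(params, latent, mode, call_all=True):
    if call_all:
        # agent.py pattern: call every behavior, then pick by mode.
        outs = {"task": task_policy(params, latent),
                "expl": expl_policy(params, latent)}
        return outs["expl" if mode == "explore" else "task"]
    # Branch-first alternative: only the behavior this mode needs is called.
    if mode == "explore":
        return expl_policy(params, latent)
    return task_policy(params, latent)

# Simulated init-only pass that exercises a single mode.
latent = jnp.ones(4)
full = LazyParams(jax.random.PRNGKey(0))
policy(full, latent, "train", call_all=True)
partial = LazyParams(jax.random.PRNGKey(0))
policy(partial, latent, "train", call_all=False)
print(sorted(full.values))     # ['expl/w', 'task/w']
print(sorted(partial.values))  # ['task/w']
```

If the init pass only follows the branch for one mode, the other behavior's parameters are never created, and a later call in a different mode would look up weights that do not exist.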

danijar (Owner) commented Jun 4, 2023

Yes, that's the reason.

edwhu (Author) commented Jun 6, 2023

Thanks. It seems straightforward to replace the `if` statements with `jax.lax.select` or `jax.lax.cond` logic and move the policy calls into the branches. Is there a particular reason you chose Python `if` statements instead?
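For reference, here is a minimal sketch of what the suggested `lax`-based dispatch could look like. Because JAX arrays cannot hold strings, the sketch assumes the mode is encoded as an integer index; `MODES` and the toy `task_policy`/`expl_policy` are hypothetical. It uses `jax.lax.switch`, the multi-branch counterpart of `jax.lax.cond`: all branches are traced when the function is compiled, and the branch that executes is chosen at run time from the traced index.

```python
import jax
import jax.numpy as jnp

# Hypothetical integer encoding of the string modes (JAX arrays cannot hold strings).
MODES = ("eval", "explore", "train")

def task_policy(latent):
    return jnp.tanh(latent)

def expl_policy(latent):
    return -jnp.tanh(latent)

@jax.jit
def policy(latent, mode_index):
    # All branches are traced at compile time; `mode_index` is a traced integer,
    # so the branch that actually executes is selected at run time.
    # Note: lax.switch clamps an out-of-range index to the nearest valid branch.
    return jax.lax.switch(
        mode_index,
        [task_policy, expl_policy, task_policy],  # one branch per entry in MODES
        latent,
    )

out = policy(jnp.ones(3), MODES.index("explore"))
```

Since every branch is traced, this would still touch all behaviors during initialization; the cost is a branch inside the compiled program, which is the kind of runtime overhead the maintainer's reply refers to.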

danijar (Owner) commented Jun 6, 2023

A `cond` has runtime overhead (and it also wouldn't work with the current API, because JAX/GPUs don't support string types: jax-ml/jax#3045).
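Both points can be seen in a toy example. A string cannot be passed to a jitted function as a traced argument, so a string `mode` has to be declared static with `static_argnames`; JAX then resolves the Python `if` while tracing and caches one compiled program per distinct `mode` value, so no branch is evaluated at run time. `policy` and `TRACE_COUNT` here are hypothetical stand-ins, not the DreamerV3 functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: the Python body of a jitted function only runs while tracing.
TRACE_COUNT = {"n": 0}

def policy(latent, mode="train"):
    TRACE_COUNT["n"] += 1
    # With `mode` static, this comparison happens at trace time, not on device.
    if mode == "explore":
        return -jnp.tanh(latent)
    return jnp.tanh(latent)

# One compiled program per distinct `mode` string.
policy_static = jax.jit(policy, static_argnames="mode")

# Without `static_argnames`, `mode` would have to be a JAX type, which a str is not.
policy_dynamic = jax.jit(policy)

x = jnp.ones(3)
policy_static(x, mode="explore")
policy_static(x, mode="explore")  # cache hit: no retrace
policy_static(x, mode="eval")     # new static value: traced again

try:
    policy_dynamic(x, "explore")
except TypeError as err:
    print("TypeError:", err)
```

The trade-off is a separate compilation per mode instead of a single program with a runtime branch; with only a handful of modes, that cost is paid once per mode.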

danijar closed this as completed on Jun 6, 2023.