Add support for memoroids (linear recurrent models) #91
Conversation
Maybe you can help guide me on how to integrate this. I was thinking of copying the recurrent PPO script and replacing the LSTM with this. I think the biggest issue is that I would need to do some plumbing to get the …
Amazing! This should be quite simple. The …
Sorry, I'm less familiar with the …
No prob. Let me just show how the auto-reset API works with an example to give you a better idea. Imagine a 3-timestep environment where each observation (obs) is simply the timestep number. Here is the rollout trajectory you would get, considering you are storing transitions (obs, act, rew, done):

Trajectory Example

Episode 1:
- t=0:
- t=1:
- t=2 (last timestep; terminal, as well as the start of episode 2):

As you can see, in this final transition the obs, act and reward are actually related to the first timestep of the second episode. So when utilising this as a sequence, we would use the done/discount to mask the bootstrap prediction to zero, and we don't use the action and reward from this timestep. Continuing on, we would get:

Episode 2 remainder:
- t=3:
- t=4 (last timestep; terminal):

Explanation

In this scenario, we never actually see the true terminal observation in the stored trajectory. If you need access to the true final observation, it is available in the …

Let me know if this helps. Lastly, to just explicitly answer your questions: when timestep.last() == True, this means that the observation in that timestep object is the terminal obs; however, as mentioned above, this is not actually returned.
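To make the masking step concrete, here is a minimal sketch of how a zero done/discount on the auto-reset transition masks the bootstrap when computing returns over such a sequence. The reward/discount values and the helper function are purely illustrative assumptions, not the actual Stoix rollout data or API:

```python
import jax.numpy as jnp

# Illustrative arrays for the 5-step trajectory above (values assumed).
# On each auto-reset transition the discount is 0.0, and the stored action
# and reward belong to the next episode, so they are not used for learning.
rewards = jnp.array([0.1, 0.1, 0.0, 0.1, 0.0])
discounts = jnp.array([1.0, 1.0, 0.0, 1.0, 0.0])
bootstrap_value = 3.0  # value estimate for the observation after the rollout
gamma = 0.99


def discounted_returns(rewards, discounts, bootstrap_value, gamma):
    """Backward recursion G_t = r_t + gamma * d_t * G_{t+1}, with G_T = bootstrap."""
    acc = bootstrap_value
    returns = []
    for r, d in zip(reversed(rewards.tolist()), reversed(discounts.tolist())):
        # A zero discount zeroes everything flowing back across the episode
        # boundary, so no value leaks between episodes.
        acc = r + gamma * d * acc
        returns.append(acc)
    return jnp.array(returns[::-1])


print(discounted_returns(rewards, discounts, bootstrap_value, gamma))
```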
I'm just leaving a checklist here of things that need to be done:
…back the training script that can use ffm
Additionally, I've officially merged the popgym PR, so now we can test on popgym envs easily when we feel ready.
With respect to removing the sequence dimension, as in:

```python
h = cell.initial_state()
x = jnp.ones(...)
h, y = cell(h, x)
```

If you run the following snippet, you will see how the time dimension is present inside the scan:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 4))  # [Time, feature]
W = jnp.ones((4, 4))


def ascanf(x, xp):
    print(x.shape, xp.shape)
    return xp @ W


jax.lax.associative_scan(fn=ascanf, elems=x)
```

Let me know how you want to proceed
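As a side note on what explicit batch handling could eventually look like, here is a minimal sketch (my own illustration, not code from this PR; the combine function just mirrors the toy ascanf above) of lifting the same time-major scan over a leading batch axis with jax.vmap:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 1024, 4))  # [Batch, Time, Feature]
W = jnp.ones((4, 4))


def combine(a, b):
    # Toy combine mirroring ascanf above; operands keep a leading Time axis.
    return b @ W


# Scan over time within each sequence, vmapped over the batch axis.
scan_over_time = lambda seq: jax.lax.associative_scan(combine, seq)
out = jax.vmap(scan_over_time)(x)
print(out.shape)  # (8, 1024, 4)
```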
Hmmm, I see. Could we implement the squeeze/unsqueeze logic only in the outermost architecture, thus still allowing the cell to be run on its own? So basically something like:

```python
class ScannedMemoroid(nn.Module):
    cell: nn.Module

    @nn.compact
    def __call__(self, recurrent_state, inputs):
        ### CHANGE HERE
        recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state)
        # Recurrent state should be ((state, timestep), reset)
        # Inputs should be (x, reset)
        x, _ = inputs
        h = self.cell.map_to_h(inputs)
        recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h)
        # recurrent_state is ((state, timestep), reset)
        out = self.cell.map_from_h(recurrent_state, x)

        # TODO: Remove this when we want to return all recurrent states instead of just the last one
        final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state)
        return final_recurrent_state, out

    @nn.nowrap
    def initialize_carry(
        self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None
    ) -> Carry:
        ### AND CHANGE HERE
        return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng))
```

Let me know what you think?

EDIT/UPDATE: I just made this change and it allows the system file to be essentially identical to the rec_ppo system, which uses normal RNNs. I think this is ideal, as it means that for all future recurrent algorithms we don't need to differentiate between memoroids and normal RNNs. The only difference currently is the use of the nn.vmap; once we code it to explicitly handle batch dimensions, we can use the rec_ppo file exactly and change the network via the config. So the most pressing change would be to do that.
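For reference, this is roughly the kind of nn.vmap wrapping being referred to: a minimal sketch that assumes ScannedMemoroid is time-major and that parameters should be shared across the batch (the axis choices are illustrative, not the PR's final code):

```python
import flax.linen as nn

# Hypothetical wrapper: lift the time-major ScannedMemoroid over a leading
# batch axis, sharing parameters and RNGs across batch elements.
BatchedMemoroid = nn.vmap(
    ScannedMemoroid,
    in_axes=0,   # batch axis of (recurrent_state, inputs)
    out_axes=0,  # batch axis of (final_recurrent_state, out)
    variable_axes={"params": None},  # one set of parameters for the whole batch
    split_rngs={"params": False},
)
```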
…and return a non sequence dimension carry
…orks with memoroid conf
@smorad I've now added the explicit expectation of a batch dimension - the network now works with rec_ppo.py natively, simply by changing the network conf. For example, we can now do:

```
python stoix/systems/ppo/rec_ppo.py network=memoroid
```

We now need to verify correctness. I do worry that there might be a bug somewhere, since the performance on CartPole isn't that good: we get to 200+ quite quickly, but it struggles to get to 500.

Lastly, I had one concern: I see that the start variable is a part of the carry - is this normal? Ideally, we feed the start sequence in via the inputs, not the carry. I'm not sure if it should be there? Let me know.
What?
Implement FFM in flax
Why?
For #54
How?
Simply adds a new model
Extra
More commits coming. Just opening this to keep you in the loop.