Add support for memoroids (linear recurrent models) #91

Open · wants to merge 41 commits into base: main

Conversation

@smorad commented Jun 16, 2024

What?

Implement FFM in flax

Why?

For #54

How?

Simply adds a new model

Extra

More commits coming. Just opening this to keep you in the loop.

@smorad changed the title from "Add support for memoroids (linear recurrent modesl)" to "Add support for memoroids (linear recurrent models)" on Jun 16, 2024
@smorad (Author) commented Jun 16, 2024

Maybe you can help guide me on how to integrate this. I was thinking of copying the recurrent PPO script and replacing the LSTM with this. I think the biggest issue is that I would need to do some plumbing to get the start flag used by memoroids. Basically, start should be 1 at the initial timestep of an episode and 0 otherwise.

@EdanToledo (Owner) commented Jun 17, 2024

Amazing! This should be quite simple. The timestep.last() function checks if the current time step is the last one. In the current auto-reset API, when an episode finishes and a new one begins, the environment automatically resets and returns the first observation of the new episode. This means that timestep.last() will coincide with the first observation of a new episode (i.e., the "start"), but the reward it returns is from the final step of the previous episode. You can observe this behavior in the scanned RNN network, where we use this function to reset the hidden state. We want the hidden state to be zeros for the first observation of each new episode. Let me know if that makes sense.
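
For reference, a minimal self-contained sketch of that pattern (my own illustration, not the exact Stoix ScannedRNN code; the names ScannedGRU and hidden_dim are assumptions): a scanned GRU that zeroes its hidden state wherever the reset/start flag is set.

import functools
from typing import Tuple

import chex
import flax.linen as nn
import jax
import jax.numpy as jnp


class ScannedGRU(nn.Module):
    hidden_dim: int

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(
        self, carry: chex.Array, inputs: Tuple[chex.Array, chex.Array]
    ) -> Tuple[chex.Array, chex.Array]:
        x, reset = inputs  # x: [batch, features], reset: [batch]
        # Zero the hidden state wherever a new episode starts.
        carry = jnp.where(reset[:, None], jnp.zeros_like(carry), carry)
        new_carry, y = nn.GRUCell(features=self.hidden_dim)(carry, x)
        return new_carry, y


# Usage: inputs are time-major [time, batch, ...]; resets mark episode starts.
time_len, batch, feat, hidden = 5, 2, 3, 8
xs = jnp.ones((time_len, batch, feat))
resets = jnp.zeros((time_len, batch), dtype=bool).at[0].set(True)
carry = jnp.zeros((batch, hidden))
model = ScannedGRU(hidden_dim=hidden)
params = model.init(jax.random.PRNGKey(0), carry, (xs, resets))
final_carry, ys = model.apply(params, carry, (xs, resets))  # ys: [time, batch, hidden]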

@smorad (Author) commented Jun 17, 2024

Sorry, I'm less familiar with the dm_env step format. I know with gym, the done flag denotes that the following observation is the terminal observation. With your format, is the done flag (discount == 0) set at the same time as timestep.last() == True? Obviously it's usually either one or the other that's set, but for the purposes of off-by-one errors are they equivalent?

@EdanToledo (Owner) commented

No problem. Let me just show how the auto-reset API works with an example to give you a better idea. Imagine a 3-timestep environment where each observation (obs) is simply the timestep number. Here is the rollout trajectory you would get, assuming you are storing transitions as (obs, act, rew, done):

Trajectory Example

Episode 1:

t=0:

  • obs = 0
  • act = any (action taken using obs = 0 as input)
  • rew = reward obtained from taking act using obs = 0 as input
  • done = False

t=1:

  • obs = 1
  • act = any (action taken using obs = 1 as input)
  • rew = reward obtained from taking act using obs = 1 as input
  • done = False

t=2 (Last Timestep, terminal as well as start of episode 2):

  • obs = 0 (we auto-reset to the first observation of the next episode)
  • act = any (this would now be the action taken using obs = 0 as input)
  • rew = this would now be the reward obtained from taking act using obs = 0 as input
  • done = True (indicated by timestep.last or discount == 0)

As you can see, in this final transition the obs, act, and rew actually belong to the first timestep of the second episode. So when using this as a sequence, we use the done/discount to mask the bootstrap prediction to zero, and we don't use the action and reward from this timestep.

Continuing on we would get:

Episode 2 remainder:

t=3:

  • obs = 1
  • act = any (action taken using obs = 1 as input)
  • rew = reward obtained from taking act using obs = 1 as input
  • done = False

t=4 (Last Timestep, terminal):

  • obs = 0 (auto-reset to the first observation of the next episode)
  • act = any (action taken using obs = 0 as input)
  • rew = reward obtained from taking act using obs = 0 as input
  • done = True (indicated by timestep.last or discount == 0)

Explanation

In this scenario, we never actually see obs = 2, which is the terminal observation. For non-truncating environments, the done flags (timestep.last or discount == 0 in this case) and discounts allow us to mask the bootstrap value when bootstrapping from the second obs = 0 prediction. This ensures it doesn't matter that it's not the true final observation. This does require some care when using these sequences to construct value targets, but it's not that complicated: usually you just chop off the last timestep's reward and action values. You can see examples of this in any off-policy algorithm that uses sequences (see the MPO target value construction).
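
To make the masking concrete, here is a small sketch (my own illustration with dummy numbers, not the Stoix/MPO target code) of one-step value targets built from a sequence stored as in the trajectory above, where the done flag of the next stored transition zeroes the bootstrap:

import jax.numpy as jnp


def one_step_targets(rewards, dones, values, gamma=0.99):
    # rewards[t], dones[t], values[t] follow the layout of the trajectory above.
    # Chop off the final timestep: its action/reward belong to the next episode.
    r_t = rewards[:-1]
    # dones[t + 1] == True means obs[t + 1] is a post-reset observation, so the
    # bootstrap from values[t + 1] is masked to zero.
    bootstrap_mask = 1.0 - dones[1:].astype(jnp.float32)
    return r_t + gamma * bootstrap_mask * values[1:]


# Example matching the 5-step trajectory above (done = True at t=2 and t=4).
rewards = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
dones = jnp.array([False, False, True, False, True])
values = jnp.array([2.0, 1.0, 2.0, 1.0, 2.0])
print(one_step_targets(rewards, dones, values))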

If you need access to the true final observation, it is available in the extras object. You do need to consider whether or not your environment truncates. For 1-step transitions like DQN, I simply save o_t and o_t+1, which eliminates any possible issue since I can use the true observations. The reason for auto-resetting like this is that it removes any dummy rewards and discounts: when going from the terminal obs to the starting obs, there wouldn't be a real reward, action, or discount.
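
As a rough illustration of the 1-step case (my own sketch, not Stoix code; the extras key "final_observation" is a hypothetical placeholder for wherever the wrapper stores the true terminal observation):

from typing import NamedTuple

import chex
import jax.numpy as jnp


class Transition(NamedTuple):
    obs: chex.Array
    action: chex.Array
    reward: chex.Array
    done: chex.Array
    next_obs: chex.Array


def make_transition(obs_t, action_t, reward_t, done_t, next_obs, extras) -> Transition:
    # With auto-reset, next_obs at an episode boundary is already the first
    # observation of the next episode, so swap in the true terminal observation
    # (stored under a hypothetical "final_observation" key in extras).
    true_next_obs = jnp.where(done_t, extras["final_observation"], next_obs)
    return Transition(obs_t, action_t, reward_t, done_t, true_next_obs)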

Let me know if this helps.

Lastly, to explicitly answer your question: when timestep.last() == True, that timestep corresponds to the terminal step of the episode, but as mentioned above the terminal observation itself is not actually returned (the auto-reset swaps in the first observation of the next episode).

@EdanToledo (Owner) commented Jun 23, 2024

I'm just leaving a checklist here of things that need to be done:

  • Explicit handling of the batch dimension, i.e. not relying on flax.nn.vmap (see the sketch after this list)
  • When feeding in a starting carry, I would like it not to need a sequence dimension, i.e. just a batch dimension and a feature dimension. It feels more natural in my head to feed it like this; if we decide otherwise, we need to change the other RNN classes to expect it that way.
  • Add one more type of cell (but only one more, since I think we leave the others for different PRs) just to check how general the infrastructure is.
  • Test on popgym to ensure it works
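
On the first item, here is a sketch of one way to handle the batch dimension explicitly (my own illustration with a stand-in operator, not the FFM cell): vmap the time-wise associative scan over the leading batch axis instead of wrapping the whole module in flax.nn.vmap.

import jax
import jax.numpy as jnp


def binary_op(a, b):
    # Stand-in associative operator; a real memoroid cell would combine
    # (state, timestep, reset) tuples here instead.
    return a + b


def batched_scan(x):  # x: [batch, time, feature]
    scan_over_time = lambda seq: jax.lax.associative_scan(binary_op, seq, axis=0)
    return jax.vmap(scan_over_time)(x)  # map over the leading batch axis


x = jnp.ones((2, 8, 4))
print(batched_scan(x).shape)  # (2, 8, 4)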

@EdanToledo (Owner) commented

Additionally, I've officially merged the popgym PR so now we can test on popgym envs easily when we feel ready.

@smorad (Author) commented Jun 24, 2024

With respect to removing the sequence dimension in initialize_carry: I think you would be unable to run the cell without a singleton sequence dimension. I can implement a squeeze/unsqueeze in the FFM module, but again, this would mean that you could not run the cell on its own, as follows:

h = cell.initial_state()
x = jnp.ones(..)
h, y = cell(h, x)

If you run the following snippet, you will see how the time dimension is present in FFMCell.__call__. Unlike a standard scan, the scanned function operates over more than one element at a time.

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 4)) #[Time, feature]
W = jnp.ones((4, 4))

def ascanf(x, xp):
  print(x.shape, xp.shape)
  return xp @ W

jax.lax.associative_scan(fn=ascanf, elems=x)
(512, 4) (512, 4)
(256, 4) (256, 4)
(128, 4) (128, 4)
(64, 4) (64, 4)
(32, 4) (32, 4)
(16, 4) (16, 4)
(8, 4) (8, 4)
(4, 4) (4, 4)
(2, 4) (2, 4)
(1, 4) (1, 4)
(0, 4) (0, 4)
(1, 4) (1, 4)
(3, 4) (3, 4)
(7, 4) (7, 4)
(15, 4) (15, 4)
(31, 4) (31, 4)
(63, 4) (63, 4)
(127, 4) (127, 4)
(255, 4) (255, 4)
(511, 4) (511, 4)

Let me know how you want to proceed.

@EdanToledo (Owner) commented Jun 24, 2024

Hmmm, I see. Could we implement the squeeze/unsqueeze logic only in the outermost architecture, thus still allowing the cell to be run on its own? So basically something like:

class ScannedMemoroid(nn.Module):
    cell: nn.Module

    @nn.compact
    def __call__(self, recurrent_state, inputs):
        ### CHANGE HERE
        recurrent_state = jax.tree.map(lambda x: x[None, ...], recurrent_state)  # add back a singleton sequence dim
        # Recurrent state should be ((state, timestep), reset)
        # Inputs should be (x, reset)
        x, _ = inputs
        h = self.cell.map_to_h(inputs)
        recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h)
        # recurrent_state is ((state, timestep), reset)
        out = self.cell.map_from_h(recurrent_state, x)

        # TODO: Remove this when we want to return all recurrent states instead of just the last one
        final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state)
        return final_recurrent_state, out

    @nn.nowrap
    def initialize_carry(
        self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None
    ) -> Carry:
        ### AND CHANGE HERE
        return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng))

Let me know what you think.

EDIT/UPDATE:

I just made this change, and it allows the system file to be essentially identical to the rec_ppo system which uses normal RNNs. I think this is ideal, as it means that for all future recurrent algorithms we don't need to differentiate between memoroids and normal RNNs. The only difference currently is the use of nn.vmap; once we code it to explicitly handle batch dimensions, we can use the rec_ppo file exactly as-is and change the network via the config. So the most pressing change would be to do that.

@EdanToledo (Owner) commented

@smorad I've now added the explicit expectation of a batch dimension. The network now works with rec_ppo.py natively, simply by changing the network config. For example, we can now do: python stoix/systems/ppo/rec_ppo.py network=memoroid

We now need to verify correctness. I do worry that there might be a bug somewhere, since the performance on CartPole isn't that good: we get to 200+ quite quickly, but it struggles to get to 500. Lastly, I had one concern: I see that the start variable is part of the carry. Is this normal? Ideally we feed the start sequence in via the inputs, not the carry. I'm not sure if it should be there; let me know.
