Reproduction of Muesli #350 (Closed)

vwxyzjn opened this issue Jan 14, 2023 · 23 comments

Problem Description

Muesli is a next-generation policy gradient algorithm from DeepMind that performs exceptionally well. Notably, it can match MuZero’s SOTA results in Atari without using deep search such as MCTS. In addition to a more robust objective, Muesli also learns a model and can handle off-policy data and value inaccuracies.

It would be incredibly useful to reproduce Muesli in CleanRL, making Muesli more accessible and easy to use. One possible application is to use it to fine-tune LLMs, which can help https://github.com/CarperAI/trlx (cc @LouisCastricato). I also think being able to handle off-policy data will make Muesli of special interest for human-in-the-loop RL (cc @cloderic, @saikrishna-1996), given that Muesli is like PPO but deals more gracefully with off-policy data, such as the retroactive rewards enabled by Cogment.

With that said, this issue describes a roadmap to reproduce Muesli and associated challenges and opportunities.

Reproduction Analysis

Muesli does look really impressive, and I like the paper a lot! However, I'd expect reproduction to be challenging for two main reasons. First, Muesli is not open-sourced and there are no independent reference implementations, so we're going to do this from scratch. Second, Muesli uses DeepMind's Podracer architecture, which is arguably harder to reproduce.

Roadmap

I think the best way to reproduce it is to work on a semi-synchronous version of Muesli, similar to how OpenAI reproduced A3C as a synchronous A2C first. "Semi-synchronous" means using EnvPool to do rollouts asynchronously while doing learning synchronously. This kind of architecture has many benefits: it is easier to reason about, pretty efficient, and applies to LLM tuning in a fairly straightforward way. We then test and iterate development on Atari. The rough steps are as follows (a sketch of the semi-synchronous rollout loop follows the list):

  • baseline: produce a PPO baseline on the 57 Atari games
  • prototype: develop a Muesli prototype that can reach a score of 791 in Breakout
  • iterate: test and iterate this Muesli implementation on the 57 Atari games
  • contribute: go through our contribution process
    • The main things we are looking for are 1) single-file implementations (minimal lines of code), 2) documentation explaining notable implementation details, and 3) benchmarking and matching the performance of reference implementations. Please check out our contribution guide for the usual process; Add RPO to CleanRL #331 is a good example of how new algorithms are contributed end-to-end.
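
As a concrete illustration of the "semi-synchronous" idea, here is a minimal sketch of an EnvPool async rollout loop (the env id, sizes, step count, and random stand-in policy are placeholders; the synchronous learner update would consume the collected batches):

import envpool
import numpy as np

# async rollouts: recv() returns a batch of whichever envs are ready,
# send() steps only those envs while the others keep running
envs = envpool.make("Breakout-v5", env_type="gym", num_envs=64, batch_size=16)
envs.async_reset()
for step in range(1000):
    obs, rew, done, info = envs.recv()
    env_id = info["env_id"]
    actions = np.random.randint(envs.action_space.n, size=len(env_id))  # stand-in policy
    envs.send(actions, env_id)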

I will dive into a bit of detail in the next section.

Baseline

We first need to understand the Atari baseline. Muesli's Atari setup can be summarized as follows:

import envpool

# env_id, num_envs, async_batch_size, and seed are defined elsewhere in the script
envs = envpool.make(
    env_id,
    env_type="gym",
    num_envs=num_envs,
    batch_size=async_batch_size,
    stack_num=4, # Hessel et al 2022, Muesli paper, Table 10
    img_height=96, # Hessel et al 2022, Muesli paper, Table 4
    img_width=96, # Hessel et al 2022, Muesli paper, Table 4
    episodic_life=False,  # Hessel et al 2022, Muesli paper, Table 4
    repeat_action_probability=0.25,  # Hessel et al 2022, Muesli paper, Table 4
    noop_max=0,  # Hessel et al 2022, Muesli paper, Table 4
    full_action_space=True,  # Hessel et al 2022, Muesli paper, Table 4
    max_episode_steps=int(108000 / 4),  # Hessel et al 2022, Muesli paper, Table 4, we divide by 4 because of the skipped frames
    reward_clip=True,
    seed=seed,
)

Its agent then uses an IMPALA CNN (64, 128, 128, 64) with an LSTM; Muesli and its PPO baseline reach ~562% and ~300% median HNS (human-normalized score), respectively, in 200M frames (50M agent steps).
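
For reference, here is a minimal Flax sketch of one IMPALA-style conv sequence (a strided max-pool followed by two residual blocks); an IMPALA CNN such as (64, 128, 128, 64) stacks several of these, and this is an illustration rather than the exact Muesli network:

import flax.linen as nn

class ResidualBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        inputs = x
        x = nn.relu(x)
        x = nn.Conv(self.channels, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.Conv(self.channels, kernel_size=(3, 3))(x)
        return x + inputs

class ConvSequence(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.channels, kernel_size=(3, 3))(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
        x = ResidualBlock(self.channels)(x)
        x = ResidualBlock(self.channels)(x)
        return x

# an IMPALA CNN (64, 128, 128, 64) torso would apply ConvSequence(64), ConvSequence(128),
# ConvSequence(128), ConvSequence(64) in order, then flatten, a dense layer, and (in Muesli) an LSTM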

My current best setting to replicate Muesli's PPO is as follows (#338):

 envs = envpool.make(
     env_id,
     env_type="gym",
     num_envs=num_envs,
     batch_size=async_batch_size,
     stack_num=4, # Hessel et al 2022, Muesli paper, Table 10
+    img_height=86, # changed from the Muesli paper's 96 (Table 4); see note below
+    img_width=86, # changed from the Muesli paper's 96 (Table 4); see note below
     episodic_life=False,  # Hessel et al 2022, Muesli paper, Table 4
     repeat_action_probability=0.25,  # Hessel et al 2022, Muesli paper, Table 4
+    noop_max=1,  # changed from the paper's noop_max=0 (Table 4) because of sail-sg/envpool#234
     full_action_space=True,  # Hessel et al 2022, Muesli paper, Table 4
     max_episode_steps=int(108000 / 4),  # Hessel et al 2022, Muesli paper, Table 4, we divide by 4 because of the skipped frames
     reward_clip=True,
     seed=seed,
 )

#338 uses an 86x86 image size to stay consistent with existing work such as IMPALA, and noop_max=1 because of sail-sg/envpool#234. I also used an IMPALA CNN (16, 32, 32) because my RTX 3060 Ti can only fit (16, 32, 32), and I have not implemented the LSTM yet. My reproduction in #338 yields:

pip install openrlbenchmark --upgrade
# expect the following command to run for hours
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=envpool-atari&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_truncation_machado' 'ppo_atari_envpool_async_jax_scan_impalanet_machado'  \
    --env-ids Alien-v5 Amidar-v5 Assault-v5 Asterix-v5 Asteroids-v5 Atlantis-v5 BankHeist-v5 BattleZone-v5 BeamRider-v5 Berzerk-v5 Bowling-v5 Boxing-v5 Breakout-v5 Centipede-v5 ChopperCommand-v5 CrazyClimber-v5 Defender-v5 DemonAttack-v5 DoubleDunk-v5 Enduro-v5 FishingDerby-v5 Freeway-v5 Frostbite-v5 Gopher-v5 Gravitar-v5 Hero-v5 IceHockey-v5 PrivateEye-v5 Qbert-v5 Riverraid-v5 RoadRunner-v5 Robotank-v5 Seaquest-v5 Skiing-v5 Solaris-v5 SpaceInvaders-v5 StarGunner-v5 Surround-v5 Tennis-v5 TimePilot-v5 Tutankham-v5 UpNDown-v5 Venture-v5 VideoPinball-v5 WizardOfWor-v5 YarsRevenge-v5 Zaxxon-v5 Jamesbond-v5 Kangaroo-v5 Krull-v5 KungFuMaster-v5 MontezumaRevenge-v5 MsPacman-v5 NameThisGame-v5 Phoenix-v5 Pitfall-v5 Pong-v5 \
    --check-empty-runs False \
    --ncols 5 \
    --ncols-legend 2 \
    --output-filename machado_50M_impala \
    --scan-history

python -m openrlbenchmark.hns --files machado_50M_impala.csv 

# outputs:
openrlbenchmark/envpool-atari/ppo_atari_envpool_xla_jax_truncation_machado ({})
┣━━ median hns: 1.5679929625118418
┣━━ mean hns: 8.352308370550299
openrlbenchmark/envpool-atari/ppo_atari_envpool_async_jax_scan_impalanet_machado ({})
┣━━ median hns: 1.5167741935483872
┣━━ mean hns: 11.038219990985528
  • ppo_atari_envpool_async_jax_scan_impalanet_machado uses the full action set (18 actions), an IMPALA CNN (16, 32, 32), and a new set of untuned hyperparameters (e.g., num_envs=64);
  • ppo_atari_envpool_xla_jax_truncation_machado uses the minimal action set, the Nature CNN, and openai/baselines Atari hyperparameters.
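
As a quick reminder (this is the standard definition, not a description of openrlbenchmark's internals), the per-game human-normalized score aggregated above is:

def human_normalized_score(agent_score, random_score, human_score):
    # 0.0 corresponds to the random-play reference score, 1.0 to the human reference score
    return (agent_score - random_score) / (human_score - random_score)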

machado_50M_impala (each row lists the game, then the score of openrlbenchmark/envpool-atari/ppo_atari_envpool_xla_jax_truncation_machado, then the score of openrlbenchmark/envpool-atari/ppo_atari_envpool_async_jax_scan_impalanet_machado):
Alien-v5 2626.45 ± 0.00 4354.95 ± 0.00
Amidar-v5 1323.18 ± 0.00 1026.85 ± 0.00
Assault-v5 3225.26 ± 0.00 5073.31 ± 0.00
Asterix-v5 11081.75 ± 0.00 333898.08 ± 0.00
Asteroids-v5 1788.17 ± 0.00 10818.26 ± 0.00
Atlantis-v5 775537.50 ± 0.00 875300.00 ± 0.00
BankHeist-v5 1172.79 ± 0.00 1167.93 ± 0.00
BattleZone-v5 33153.75 ± 0.00 32018.12 ± 0.00
BeamRider-v5 3161.69 ± 0.00 9482.59 ± 0.00
Berzerk-v5 814.51 ± 0.00 633.49 ± 0.00
Bowling-v5 50.89 ± 0.00 29.95 ± 0.00
Boxing-v5 96.49 ± 0.00 99.59 ± 0.00
Breakout-v5 373.85 ± 0.00 470.53 ± 0.00
Centipede-v5 3082.35 ± 0.00 3854.42 ± 0.00
ChopperCommand-v5 12176.00 ± 0.00 22109.44 ± 0.00
CrazyClimber-v5 135736.62 ± 0.00 124531.36 ± 0.00
Defender-v5 57146.62 ± 0.00 68088.55 ± 0.00
DemonAttack-v5 12115.09 ± 0.00 63779.84 ± 56545.29
DoubleDunk-v5 -0.68 ± 0.00 -0.23 ± 0.00
Enduro-v5 1734.94 ± 0.00 2334.03 ± 0.00
FishingDerby-v5 42.35 ± 0.00 39.52 ± 0.00
Freeway-v5 33.63 ± 0.00 33.61 ± 0.00
Frostbite-v5 269.59 ± 0.00 269.77 ± 0.00
Gopher-v5 16318.98 ± 0.00 21028.99 ± 0.00
Gravitar-v5 2695.44 ± 0.00 512.07 ± 0.00
Hero-v5 33900.23 ± 0.00 13957.64 ± 0.00
IceHockey-v5 -4.38 ± 0.00 1.26 ± 0.00
PrivateEye-v5 72.25 ± 0.00 0.00 ± 0.00
Qbert-v5 22940.53 ± 0.00 18828.40 ± 0.00
Riverraid-v5 11789.25 ± 0.00 21917.63 ± 0.00
RoadRunner-v5 57660.25 ± 0.00 44616.17 ± 0.00
Robotank-v5 27.58 ± 0.00 28.30 ± 0.00
Seaquest-v5 1882.95 ± 0.00 955.04 ± 0.00
Skiing-v5 -29998.00 ± 0.00 -12812.83 ± 0.00
Solaris-v5 2067.62 ± 0.00 1913.79 ± 0.00
SpaceInvaders-v5 2849.32 ± 0.00 34592.05 ± 0.00
StarGunner-v5 34239.50 ± 0.00 116080.91 ± 0.00
Surround-v5 6.12 ± 0.00 3.36 ± 0.00
Tennis-v5 -0.35 ± 0.00 -0.29 ± 0.00
TimePilot-v5 10997.38 ± 0.00 31976.45 ± 0.00
Tutankham-v5 305.59 ± 0.00 226.72 ± 0.00
UpNDown-v5 263615.69 ± 0.00 359771.62 ± 0.00
Venture-v5 0.00 ± 0.00 0.00 ± 0.00
VideoPinball-v5 412265.15 ± 0.00 433614.56 ± 0.00
WizardOfWor-v5 11283.50 ± 0.00 7176.12 ± 0.00
YarsRevenge-v5 97739.87 ± 0.00 97599.37 ± 0.00
Zaxxon-v5 16688.62 ± 0.00 0.00 ± 0.00
Jamesbond-v5 522.06 ± 0.00 643.03 ± 0.00
Kangaroo-v5 14603.50 ± 0.00 14197.66 ± 0.00
Krull-v5 9884.79 ± 0.00 8577.02 ± 0.00
KungFuMaster-v5 31035.50 ± 0.00 34155.31 ± 0.00
MontezumaRevenge-v5 0.00 ± 0.00 0.00 ± 0.00
MsPacman-v5 4838.56 ± 0.00 4524.04 ± 0.00
NameThisGame-v5 11958.65 ± 0.00 13315.72 ± 0.00
Phoenix-v5 5685.30 ± 0.00 45747.97 ± 0.00
Pitfall-v5 0.00 ± 0.00 -15.12 ± 0.00
Pong-v5 16.10 ± 0.00 20.19 ± 0.00

This performance does not match that reported for Muesli's PPO baseline; here are some ideas to help us close the gap:

  • use the same image size 96x96 (we shouldn't worry about noop_max=1 because it's really just 1 frame of difference)
  • use Muesli's network size such as IMPALA CNN (64, 128, 128, 64)
  • implement LSTM
  • try a different set of hyperparameters (maybe don't decrease num_steps to 32)
  • investigate failure cases such as Zaxxon-v5

Prototype

In any case, our PPO in #338 can be a good starting place to implement the semi-synchronous version of Muesli. I'd suggest cloning ppo_atari_envpool_async_jax_scan_impalanet_machado.py into a new file called muesli_atari_envpool_async_jax_scan_machado.py, implementing Muesli there, and iterating experiments on Breakout-v5 to see if we can replicate the reported score of 791. If that is successful, we can run a more comprehensive benchmark and proceed with our contribution process.
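
To make the prototype step a bit more concrete, here is a rough JAX sketch of the CMPO piece of Muesli's policy objective as I read the paper: the regularization target is the prior policy reweighted by exponentiated clipped advantages, and the policy loss adds a KL term toward that target. The clipping threshold, the advantage normalization, and the function names are placeholders, not the paper's exact formulation:

import jax
import jax.numpy as jnp

def cmpo_target(prior_logits, advantages, clip_value=1.0):
    # pi_cmpo(a|s) ∝ pi_prior(a|s) * exp(clip(adv(s, a), -c, c)), computed in logit space
    clipped_adv = jnp.clip(advantages, -clip_value, clip_value)
    return jax.nn.softmax(prior_logits + clipped_adv, axis=-1)

def cmpo_regularizer(policy_logits, prior_logits, advantages, clip_value=1.0):
    # KL(pi_cmpo || pi) over the discrete (Atari) action set; the target is treated as a constant
    target = jax.lax.stop_gradient(cmpo_target(prior_logits, advantages, clip_value))
    log_pi = jax.nn.log_softmax(policy_logits, axis=-1)
    return jnp.sum(target * (jnp.log(target + 1e-8) - log_pi), axis=-1)

# usage sketch: total policy loss ≈ pg_loss + lambda * cmpo_regularizer(...).mean()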


@shermansiu commented:

Actually, there's an implementation of Muesli in https://github.com/YuriCat/MuesliJupyterExample, though I'm not sure how good it is.

And Muesli is just MPO + a learned latent dynamics/prediction network model (predicts $r$, $V$, and $\pi$), so we can refer to implementations of MPO (https://github.com/daisatojp/mpo, https://github.com/theogruner/rl_pro_telu) and MuZero, I guess.
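
As a rough illustration of that structure, here is a hypothetical Flax sketch of a one-step latent model with reward/value/policy heads (names, sizes, and the single dense dynamics step are placeholders, not Muesli's actual architecture, which also uses an LSTM representation):

import flax.linen as nn
import jax.numpy as jnp

class LatentModel(nn.Module):
    num_actions: int
    hidden: int = 256

    @nn.compact
    def __call__(self, latent, action_onehot):
        # one-step latent dynamics: next latent from current latent + action
        x = jnp.concatenate([latent, action_onehot], axis=-1)
        next_latent = nn.relu(nn.Dense(self.hidden)(x))
        reward = nn.Dense(1)(next_latent)                         # predicted r
        value = nn.Dense(1)(next_latent)                          # predicted V
        policy_logits = nn.Dense(self.num_actions)(next_latent)   # predicted pi
        return next_latent, reward, value, policy_logits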

@shermansiu commented:

This means that improvements from EfficientZero could hypothetically transfer over to Muesli's learned world model, but this would need to be verified experimentally.

@xrsrke commented Jan 16, 2023

@vwxyzjn I have just started reading the paper. Do you have a target date for completing it? I will try to finish on time.

@vwxyzjn (Owner, Author) commented Jan 16, 2023

@shermansiu, thanks for the reference! They look like valuable resources. That said, I would like to see more benchmark information. For example, https://github.com/YuriCat/MuesliJupyterExample has a proof-of-concept on Tic-tac-toe, but it might not necessarily work for Atari.

@xrsrke, thanks for your interest! There is no specific timeline at the moment, but I would encourage folks to post related updates here for transparency.

@shermansiu commented:

Anyway, I've started work on this.

@vwxyzjn (Owner, Author) commented Jan 17, 2023

Btw, I came across @hr0nix's https://github.com/hr0nix/dejax. It may come in handy for Muesli's replay buffer.

@Howuhh (Contributor) commented Jan 19, 2023

FYI: new work from DeepMind on online meta-RL also uses Muesli as the base RL algorithm, which is interesting:
https://sites.google.com/view/adaptive-agent/

@shermansiu commented:

Interesting! Taking a look.

@shermansiu commented:

Also, interestingly enough, TorchBeast does not use an LSTM in the IMPALA network and was still able to match or exceed IMPALA's performance.

@vwxyzjn (Owner, Author) commented Jan 19, 2023

@Howuhh, thanks for the reference! @shermansiu, note that there are some caveats.

  1. The TorchBeast paper's tfimpala is modified from deepmind/scalable_agent; however, tfimpala does not necessarily reproduce the performance reported in the IMPALA paper. Using Breakout as an example, the IMPALA paper reports a score of 640.43 for the shallow model, whereas tfimpala reports ~150. It's possible that DeepMind used a different codebase for its Atari experiments; it's also possible that different Atari simulators caused the gap (e.g., DeepMind seems to have used xitari internally).
  2. To measure performance more quantitatively, we need to use human-normalized scores. moolib performs better than torchbeast, but when measured in human-normalized scores it is still a bit lower than expected. See "Atari median human-normalized score" (facebookresearch/moolib#30, comment) for more detail.
  3. Note that there are additional reproduction issues with monobeast (which is different from polybeast): "torchbeast works poorly on atari" (facebookresearch/torchbeast#37) and "Cannot reproduce the performance of 'SpaceInvaders' game?" (facebookresearch/torchbeast#25).

@shermansiu commented:

It turns out that dejax was insufficient for sampling sequences, and I couldn't find an existing alternative implementation, so I implemented my own version.

Using a naive replay buffer that doesn't organize updates by environment makes sampling sequences difficult. It also makes it difficult to implement model rollouts (length 5) over a sequence of length 30...

Surprisingly, implementing a replay buffer that supports envpool's asynchronous updates took longer than expected. The implementation is fully vectorized and mostly jitted, which I am proud of, but sadly it uses SRSWR (simple random sampling with replacement), which can increase the variance of the estimator. But I have something ready, with tests.
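
For illustration only (this is not the actual PR code; data, valid_len, and the shapes are hypothetical), a minimal JAX sketch of SRSWR sequence sampling from per-environment storage:

import jax
import jax.numpy as jnp

def sample_sequences(rng, data, valid_len, batch_size, seq_len):
    # data: (num_envs, capacity, ...) transitions stored per env
    # valid_len: (num_envs,) number of valid steps per env
    # returns (batch_size, seq_len, ...) windows, sampled with replacement (SRSWR)
    env_rng, start_rng = jax.random.split(rng)
    env_ids = jax.random.randint(env_rng, (batch_size,), 0, data.shape[0])
    # clamp the start index range so a full window always fits
    max_start = jnp.maximum(valid_len[env_ids] - seq_len + 1, 1)
    starts = jax.random.randint(start_rng, (batch_size,), 0, max_start)

    def take_window(env_id, start):
        return jax.lax.dynamic_slice_in_dim(data[env_id], start, seq_len, axis=0)

    return jax.vmap(take_window)(env_ids, starts)

# usage sketch with dummy data: 8 envs, capacity 128, 4-dim observations
batch = sample_sequences(jax.random.PRNGKey(0), jnp.zeros((8, 128, 4)),
                         jnp.full((8,), 128), batch_size=32, seq_len=30)  # (32, 30, 4)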

@hr0nix commented Jan 22, 2023

@shermansiu What kind of functionality are you currently missing in dejax? Perhaps that's something I can add relatively easily.

Do I understand correctly that you need to sample fixed-length chunks of the long trajectories stored in the replay buffer?

@vwxyzjn (Owner, Author) commented Jan 22, 2023

Glad to hear that @shermansiu! Feel free to put your current code in a PR — I might be able to help review or add suggestions.

Are you looking to implement LSTM?

Also, thanks for the comment @hr0nix!

@shermansiu commented:

@vwxyzjn Yeah, sure! I think I'm almost done at this point: I'm trying to get something out later today.

@hr0nix, here are the things I wanted but couldn't get out of dejax:

  • Sampling sequences
  • jitted operations
  • Fully vectorized sampling

Here's my implementation of the replay buffer. And again, the PR should be coming out later today (fingers crossed!)
https://gist.github.com/shermansiu/b492fddf4127f4214d57a647c0160b8f

@shermansiu commented:

@vwxyzjn It took longer than expected, but I made the PR! I still need to debug and test a few things, but at least the code is available. And yes, this version uses an LSTM in the representation network!

@hr0nix commented Jan 24, 2023

Hi @shermansiu, can you elaborate a bit?

Sampling sequences

dejax currently samples the objects stored in the replay buffer, which might themselves be trajectory chunks (that's how I've been using it). If this doesn't work for you, I assume that you need to store states but sample segments of consecutive states? Can you store trajectories instead, then sample a trajectory and a subsegment of that trajectory?

jitted operations

All replay buffer operations in dejax can be jitted and there are tests to ensure that it works.

Fully vectorized sampling

Sampling can be vectorized simply by using vmap.
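
For instance, a toy illustration (the buffer state and sample_one function are placeholders, not dejax's actual API):

import jax
import jax.numpy as jnp

# a batched sampler is just the single-sample function vmapped over a batch of RNG keys
buffer_state = jnp.arange(100)

def sample_one(state, key):
    idx = jax.random.randint(key, (), 0, state.shape[0])
    return state[idx]

keys = jax.random.split(jax.random.PRNGKey(0), 32)
batch = jax.vmap(sample_one, in_axes=(None, 0))(buffer_state, keys)  # shape (32,)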

@shermansiu commented Jan 24, 2023

@hr0nix

Sampling sequences

I tried storing trajectories, but they have different shapes, which isn't supported by dejax (see utils.assert_tree_is_batch_of_tree, which is called by circular_buffer:push). This was the main problem I faced; the other points are mainly minor nitpicks. Trying to add vector padding efficiently is challenging when using envpool's asynchronous mode, which is why I didn't use dejax. But I'm open to suggestions on how to use it to store vectors of different lengths!

I'm sure I could use dejax.clustered if more features were added, and we can refactor the code to use it later once that happens.

jitted operations

Good to know, thanks!

Fully-vectorized sampling

I meant having operations that are expressed as native JAX vector operations, as opposed to being wrapped in vmap or jax.lax.scan. Once again, this is a nitpick and possibly premature optimization.

Sampling from the same data source, with two different buffer heads

Also, the replay buffer and the online queue in my implementation share the same underlying data source, which might be hacky for a general-purpose buffer?

nit: Buffer operations aren't methods of the buffer state itself. Even jnp ndarrays have convenient methods.

In my implementation, all of the functions are implemented as methods of the buffer. dejax's implementation could be changed to do this, but then the API would not be backwards compatible.

nit: The buffer methods in dejax have the suffix _fn.

This makes it inconsistent with the naming conventions used by the rest of the JAX ecosystem, where the functions are simply called init or update.


TL;DR: The main reason I didn't use dejax is that trajectories have different shapes, which made it difficult to use with envpool's asynchronous API (that was one of the first things I tried). I'm open to more suggestions on how to use dejax here, as I'm not a fan of reinventing the wheel.

All in all, I'd say dejax is a pretty good package and shoutout to @hr0nix for his awesome work on it!

@hr0nix commented Jan 24, 2023

Thanks for clarifying! I'll think about how to address the issues you listed.

@shermansiu commented:

Well, I have an implementation that runs. I'm just not sure whether the returns are normal, though (around 1-2)?

@shermansiu commented:

At least the returns are steadily going up.

@shermansiu commented:

I'm figuring out why the r_loss and v_loss variables are NaN.

@shermansiu commented:

Solved the NaN loss issues, as well as a few other loss-related bugs.

Now the reward, value, and policy model losses remain relatively constant, and the CMPO regularization term goes up slightly; only the policy gradient loss goes down. The returns are still around 1-2.

@vwxyzjn (Owner, Author) commented Jan 26, 2023

FWIW, I have been playing with a prototype of the Podracer architecture used in Muesli. It might come in handy if this prototype is successful; then we can port #354 to it.
