Reproduction of Muesli #350
Comments
Actually, there's an implementation of Muesli in https://github.com/YuriCat/MuesliJupyterExample, though I'm not sure how good it is. And Muesli is just MPO + a learned latent dynamics/prediction network model (it predicts rewards, values, and policies rather than reconstructing observations).
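To make the "learned latent dynamics/prediction network" part concrete, here is a minimal, hypothetical Flax sketch (module and parameter names are mine, not taken from Muesli or the linked notebook) of a one-step latent model whose heads predict reward, value, and policy; a representation network such as an IMPALA CNN would produce the initial latent from observations.
```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LatentDynamicsModel(nn.Module):
    """One latent step: (z_t, a_t) -> z_{t+1} plus reward/value/policy heads."""
    num_actions: int
    latent_dim: int = 256

    @nn.compact
    def __call__(self, latent, action):
        action_onehot = jax.nn.one_hot(action, self.num_actions)
        x = jnp.concatenate([latent, action_onehot], axis=-1)
        next_latent = nn.Dense(self.latent_dim)(nn.relu(nn.Dense(self.latent_dim)(x)))
        reward = nn.Dense(1)(next_latent)                        # predicted reward
        value = nn.Dense(1)(next_latent)                         # predicted value
        policy_logits = nn.Dense(self.num_actions)(next_latent)  # predicted policy
        return next_latent, (reward, value, policy_logits)
```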
This means that improvements from EfficientZero could hypothetically transfer over to Muesli's learned world model, but this would need to be verified experimentally.
@vwxyzjn I have just started reading the paper. Do you have a plan for when you want to complete it? I will try to finish it on time.
@shermansiu, thanks for the reference! They look like valuable resources. That said, I would like to see more benchmark information. For example, https://github.com/YuriCat/MuesliJupyterExample has a proof-of-concept on Tic-tac-toe, but it might not necessarily work for Atari. @xrsrke, thanks for your interest! There is no specific timeline at the moment, but I would encourage folks to post related updates here for transparency.
Anyways, I've started work on this.
Btw, I came across @hr0nix's https://github.com/hr0nix/dejax. It may come in handy for Muesli's replay buffer.
FYI: new work from DeepMind about online meta-RL also uses Muesli as the base RL algorithm, which is interesting.
Interesting! Taking a look.
Also, interestingly enough, TorchBeast does not use an LSTM in the IMPALA network and was still able to match or exceed IMPALA's performance.
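For reference, the convolutional trunk being discussed looks roughly like this. This is a sketch written from the IMPALA paper's description rather than copied from TorchBeast or CleanRL, with the channel sizes left configurable; the LSTM core, when used, would sit after the final dense layer.
```python
import flax.linen as nn

class ResidualBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        y = nn.Conv(self.channels, (3, 3))(nn.relu(x))
        y = nn.Conv(self.channels, (3, 3))(nn.relu(y))
        return x + y

class ConvSequence(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.channels, (3, 3))(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        x = ResidualBlock(self.channels)(x)
        x = ResidualBlock(self.channels)(x)
        return x

class ImpalaTrunk(nn.Module):
    channels: tuple = (16, 32, 32)  # or (64, 128, 128, 64) for the larger network

    @nn.compact
    def __call__(self, x):  # x: (batch, height, width, channels), float
        for ch in self.channels:
            x = ConvSequence(ch)(x)
        x = nn.relu(x.reshape((x.shape[0], -1)))
        return nn.relu(nn.Dense(256)(x))  # an LSTM core, when used, would follow here
```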
@Howuhh, thanks for the reference! @shermansiu, note that there are some caveats.
It turns out that using a naive replay buffer that doesn't reorganize transitions according to the environment they came from makes sampling sequences difficult. It also makes it awkward to implement model rollouts (length 5) over a sequence of length 30... Surprisingly, implementing a replay buffer that could support this kind of fixed-length sequence sampling turned out to be harder than expected.
@shermansiu What kind of functionality are you currently missing in dejax? Perhaps that's something I can add relatively easily. Do I understand correctly that you need to sample fixed-length chunks of the long trajectories stored in the replay buffer?
Glad to hear that, @shermansiu! Feel free to put your current code in a PR; I might be able to help review or add suggestions. Are you looking to implement an LSTM? Also, thanks for the comment, @hr0nix!
@vwxyzjn Yeah, sure! I think I'm almost done at this point: I'm trying to get something out later today. @hr0nix, I'll describe below the things I wanted but couldn't get out of dejax.
Here's my implementation of the replay buffer. And again, the PR should be coming out later today (fingers crossed!).
@vwxyzjn It took longer than expected, but I opened the PR! I still need to debug and test some things, but at least the code is available. And yes, this version uses an LSTM in the representation network!
Hi @shermansiu, can you elaborate on it slightly?
dejax currently samples the objects stored in the replay buffer, which might themselves be trajectory chunks (that's how I've been using it). If this doesn't work for you, I assume that you need to store states but sample segments of consecutive states? Can you store trajectories instead, sample a trajectory, and then sample a subsegment of that trajectory?
All replay buffer operations in dejax can be jitted and there are tests to ensure that it works.
Sampling can be vectorized simply by using vmap.
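Here is a minimal sketch of the approach described above, under the assumption that we store fixed-length trajectory chunks (this is my own toy code, not dejax's API): the buffer state is a plain pytree, adds are circular overwrites, and sampling picks a random chunk plus a random fixed-length window inside it, so both operations can be jitted and sampling is already batched.
```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ChunkBuffer(NamedTuple):
    chunks: jnp.ndarray      # (capacity, chunk_len, *item_shape)
    insert_pos: jnp.ndarray  # scalar int32, next slot to overwrite
    size: jnp.ndarray        # scalar int32, number of valid chunks

def init_buffer(capacity, chunk_len, item_shape, dtype=jnp.float32):
    return ChunkBuffer(
        chunks=jnp.zeros((capacity, chunk_len) + item_shape, dtype),
        insert_pos=jnp.zeros((), jnp.int32),
        size=jnp.zeros((), jnp.int32),
    )

@jax.jit
def add_chunk(buf, chunk):
    # Circular writes: overwrite the oldest slot once the buffer is full.
    capacity = buf.chunks.shape[0]
    return ChunkBuffer(
        chunks=buf.chunks.at[buf.insert_pos].set(chunk),
        insert_pos=(buf.insert_pos + 1) % capacity,
        size=jnp.minimum(buf.size + 1, capacity),
    )

@partial(jax.jit, static_argnames=("batch_size", "sub_len"))
def sample_subsequences(buf, rng, batch_size, sub_len):
    # Pick random stored chunks, then a random fixed-length window inside each.
    # Assumes the buffer is non-empty and sub_len <= chunk_len.
    rng_chunk, rng_start = jax.random.split(rng)
    chunk_len = buf.chunks.shape[1]
    chunk_idx = jax.random.randint(rng_chunk, (batch_size,), 0, buf.size)
    starts = jax.random.randint(rng_start, (batch_size,), 0, chunk_len - sub_len + 1)
    window = starts[:, None] + jnp.arange(sub_len)[None, :]  # (batch, sub_len)
    return buf.chunks[chunk_idx[:, None], window]            # (batch, sub_len, *item_shape)
```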
I tried storing trajectories, but they have different shapes, which isn't supported by dejax. I'm sure I could use a workaround, though.
Good to know, thanks!
I meant having operations that are parameterized by native JAX vector operations, as opposed to being wrapped by vmap.
Also, the replay buffer and the online queue in my implementation have the same data source, which would potentially be hacky for a general-purpose buffer?
In my implementation, all of the functions are implemented as methods of the buffer.
This makes it inconsistent with the naming conventions used by the rest of the JAX ecosystem, where such operations are usually exposed as pure functions rather than methods. TL;DR: the main reason I didn't use dejax is that it doesn't support sampling fixed-length sub-sequences from variable-length trajectories out of the box. All in all, I'd say it was easier to write a small custom buffer for this particular use case.
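To illustrate the convention being discussed with hypothetical toy code (this is neither dejax's API nor the PR's code): the same buffer can be exposed as methods on a stateful object, or as an immutable state pytree plus pure functions, which is the style most of the JAX ecosystem (e.g. optax's init/update pairs) follows and which stays jit-friendly.
```python
from typing import NamedTuple

import jax.numpy as jnp

# Method style: state and operations live on one mutable object.
class MethodStyleBuffer:
    def __init__(self, capacity, item_shape):
        self.storage = jnp.zeros((capacity,) + item_shape)
        self.size = 0

    def add(self, item):
        self.storage = self.storage.at[self.size % self.storage.shape[0]].set(item)
        self.size += 1

# Functional style: an immutable state pytree plus pure functions
# that take the old state and return a new one.
class BufferState(NamedTuple):
    storage: jnp.ndarray
    size: jnp.ndarray

def buffer_init(capacity, item_shape):
    return BufferState(jnp.zeros((capacity,) + item_shape), jnp.zeros((), jnp.int32))

def buffer_add(state, item):
    capacity = state.storage.shape[0]
    return BufferState(state.storage.at[state.size % capacity].set(item), state.size + 1)
```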
Thanks for clarifying! I'll think about how to address the issues you listed.
Well, I have an implementation that runs. I'm just not sure if the returns are normal, though (they're about 1-2).
At least the returns are steadily going up.
Figuring out why the losses look off.
Solved the loss issue. Now the reward, value, and policy model losses remain relatively constant, and the CMPO regularization term goes up slightly; only the policy gradient loss goes down. The returns are still about 1-2.
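When debugging a situation like this, it can help to return every term separately instead of only the summed loss. A minimal sketch, with hypothetical coefficient names and placeholder values (not taken from the paper or the PR):
```python
def total_loss(pg_loss, cmpo_reg, value_loss, reward_model_loss, policy_model_loss,
               cmpo_coef=1.0, value_coef=0.5, model_coef=1.0):
    terms = {
        "pg_loss": pg_loss,
        "cmpo_reg": cmpo_coef * cmpo_reg,
        "value_loss": value_coef * value_loss,
        "reward_model_loss": model_coef * reward_model_loss,
        "policy_model_loss": model_coef * policy_model_loss,
    }
    # Logging each term separately makes it obvious which head is (or isn't) learning.
    return sum(terms.values()), terms
```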
FWIW, I have been playing with a prototype of the Podracer architecture used in Muesli. It might come in handy if this prototype is successful, and then we can port #354 to it.
Problem Description
Muesli is a next-generation policy gradient algorithm from DeepMind that performs exceptionally well. Notably, it can match MuZero's SOTA results on Atari without using deep search such as MCTS. In addition to a more robust objective, Muesli learns a model and can handle off-policy data and value inaccuracies.
It would be incredibly useful to reproduce Muesli in CleanRL, making Muesli more accessible and easier to use. One possible application is using it to fine-tune LLMs, which could help https://github.com/CarperAI/trlx (cc @LouisCastricato). I also think the ability to handle off-policy data makes Muesli of special interest for human-in-the-loop RL (cc @cloderic, @saikrishna-1996), given that Muesli is like PPO but can more gracefully deal with off-policy data, such as the retroactive rewards enabled by Cogment.
With that said, this issue describes a roadmap to reproduce Muesli and associated challenges and opportunities.
Reproduction Analysis
Muesli does look really impressive, and I like the paper a lot! However, I'd expect reproduction to be challenging for two main reasons. First, Muesli is not open-sourced, and there are no other independent reference implementations, so we're going to do this from scratch. Second, Muesli uses DeepMind's Podracer Architecture, which is arguably harder to reproduce.
Roadmap
I think the best way to reproduce it is to work on a semi-synchronous version of Muesli, similar to how OpenAI reproduced A3C as a synchronous A2C first. "Semi-synchronous" means using EnvPool to do rollouts asynchronously while doing learning synchronously. This kind of architecture has many benefits: it is easier to reason about, pretty efficient, and applies to LLM fine-tuning in a fairly straightforward way. We then test and iterate development on Atari. The rough steps are as follows: (1) match Muesli's PPO baseline on Atari, (2) prototype a semi-synchronous Muesli on top of that PPO and iterate on Breakout-v5, and (3) if that works, run a more comprehensive benchmark and go through the contribution process.
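As a sketch of what "semi-synchronous" could look like with EnvPool's low-level async API (the numbers are placeholders, the random actions stand in for the policy, and the exact recv() return signature may differ across EnvPool versions):
```python
import envpool
import numpy as np

# Rollouts are asynchronous via send/recv; the learner update stays an ordinary
# synchronous call in the same process.
envs = envpool.make("Breakout-v5", env_type="gym", num_envs=64, batch_size=32)
envs.async_reset()

for update in range(1_000):
    rollout = []
    for _ in range(128):  # collect a fixed number of batched env-steps per update
        obs, reward, done, info = envs.recv()
        env_id = info["env_id"]
        action = np.random.randint(0, envs.action_space.n, size=len(env_id))
        envs.send(action, env_id)
        rollout.append((obs, action, reward, done, env_id))
    # ...synchronous learning step on `rollout` would go here...
```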
I will dive into a bit of detail in the next section.
Baseline
We first need to understand the Atari baseline. Muesli's Atari setup can be summarized as follows: its agent uses IMPALA CNN (64, 128, 128, 64) with LSTM, and Muesli and its PPO baseline reach ~562% and ~300% median HNS (human-normalized score) in 200M frames (50M steps), respectively.
My current best setting to replicate Muesli's PPO is as follows (#338): #338 used an 86x86 image size to be consistent with existing work such as IMPALA, and noop_max=1 because of sail-sg/envpool#234. I also used IMPALA CNN (16, 32, 32), because my 3060 Ti GPU can only fit (16, 32, 32), and I did not bother to implement LSTM yet...
In my reproduction in #338, ppo_atari_envpool_async_jax_scan_impalanet_machado uses the full action set (18 actions), IMPALA CNN (16, 32, 32), and a new set of un-tuned hyperparameters (num_envs=64 and others), while ppo_atari_envpool_xla_jax_truncation_machado uses the minimal action set, Nature CNN, and the openai/baselines Atari hyperparameters.
The performance does not match the reported performance of Muesli's PPO, and here are some ideas to help us match the baseline:
- noop_max=1 (it's really just 1 frame of difference)
- use IMPALA CNN (64, 128, 128, 64)
- set num_steps to 32
Prototype
In any case, our PPO in #338 can be a good starting place to implement the semi-synchronous version of Muesli. I'd suggest cloning ppo_atari_envpool_async_jax_scan_impalanet_machado.py into a new file called muesli_atari_envpool_async_jax_scan_machado.py, implementing Muesli there, and iterating experiments on Breakout-v5 to see if we can replicate the game score of 791. If that's successful, we can run a more comprehensive benchmark and proceed with our contribution process.
Checklist
- poetry install (see CleanRL's installation guideline)
Current Behavior
Expected Behavior
Possible Solution
Steps to Reproduce