This is the official PyTorch implementation of MAMBA: An Effective World Model Approach for Meta-Reinforcement Learning by Zohar Rimon, Tom Jurgenson, Orr Krupnik, Gilad Adler and Aviv Tamar, published at ICLR 2024.
Check out more cool visualizations and details on our website: https://sites.google.com/view/mamba-iclr2024
MAMBA is a new model-based approach to meta-RL, built from elements of existing state-of-the-art model-based and meta-RL methods. We demonstrate the effectiveness of MAMBA on common meta-RL benchmark domains, where it attains greater return with better sample efficiency (up to 15×) while requiring very little hyperparameter tuning. In addition, we validate our approach on a slate of more challenging, higher-dimensional domains, taking a step towards real-world generalizing agents.
We have implemented a bunch of cool environments in the code! Here are some examples of learned MAMBA policies, all of which behave close to the Bayes-optimal policy:
For full empirical results and baseline comparisons, please refer to the paper.
If you don't need MuJoCo (i.e., for the simple environments and the dm-control environments), you can install the requirements with:
pip install -r requirements.txt
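If you are not sure whether every package from requirements.txt ended up in the active environment, a standard-library check along these lines can help. This is a sketch, not part of the repository: the helper names requirement_names and missing_packages are hypothetical, and the parser only handles simple lines such as name==version, skipping comments and pip options.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Hypothetical helper: matches the distribution name at the start of a simple
# requirement line, e.g. "numpy==1.24" -> "numpy", "pkg[extra]>=1" -> "pkg".
_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def requirement_names(lines):
    """Yield package names from requirements-style lines (simple cases only)."""
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop trailing comments
        if not line or line.startswith("-"):  # skip blanks and options like -r / -e
            continue
        match = _NAME_RE.match(line)
        if match:
            yield match.group(1)


def missing_packages(lines):
    """Return the requirement names that have no installed distribution."""
    missing = []
    for name in requirement_names(lines):
        try:
            version(name)  # raises PackageNotFoundError if not installed
        except PackageNotFoundError:
            missing.append(name)
    return missing
```

For example, missing_packages(open("requirements.txt").read().splitlines()) returns the names that still need to be installed.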
To use MuJoCo:
- Run pip install -r requirements.txt if you haven't already.
- Download MuJoCo 1.50 from here, together with an activation key, and place them in the ~/.mujoco/mjpro150 directory.
- Install mujoco-py with the following commands:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mjpro150/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia-000
conda install -c conda-forge mesalib glew glfw patchelf
pip install lockfile==0.12.2 cffi==1.15.1 Cython==0.29.36 mujoco-py==1.50.1.68
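After installing mujoco-py, a quick way to catch the most common setup mistakes (a missing ~/.mujoco/mjpro150 directory or a missing LD_LIBRARY_PATH entry) is a check like the one below. It is a hypothetical helper, not part of the repository; it assumes the default location from the steps above and does not verify the activation key.

```python
import os
from pathlib import Path

# Assumed default location from the install steps; change it if you placed
# MuJoCo 1.50 elsewhere.
DEFAULT_MJPRO_DIR = Path.home() / ".mujoco" / "mjpro150"


def check_mujoco_setup(mjpro_dir=DEFAULT_MJPRO_DIR, env=None):
    """Hypothetical check: return a list of problems with the MuJoCo 1.50 setup."""
    env = os.environ if env is None else env
    mjpro_dir = Path(mjpro_dir).expanduser()
    bin_dir = mjpro_dir / "bin"
    problems = []
    if not mjpro_dir.is_dir():
        problems.append(f"{mjpro_dir} does not exist")
    elif not bin_dir.is_dir():
        problems.append(f"{bin_dir} is missing")
    # The export commands above put the MuJoCo bin directory on LD_LIBRARY_PATH.
    lib_paths = [
        Path(p).expanduser()
        for p in env.get("LD_LIBRARY_PATH", "").split(os.pathsep)
        if p
    ]
    if bin_dir not in lib_paths:
        problems.append(f"{bin_dir} is not on LD_LIBRARY_PATH")
    return problems
```

An empty list means both checks passed; each string otherwise names one thing to fix before retrying the mujoco-py import.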
For example, run the code with:
python3 dreamer.py --configs=dmc_proprio_goal --max_episode_length=300 --num_goals=2 --num_meta_episodes=3 --steps=20000000 --task=dmc_reacher_goal --train_ratio=0.03
This runs the reacher-4 environment from the paper with 3 episodes per meta-episode, each 300 steps long, for a total of 20M environment steps with a train ratio of 0.03. The code logs to wandb. All parameters are listed in the config.yaml file.
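The numbers in the example command fit together as follows. This is a minimal sketch: parse_flags and meta_episode_budget are hypothetical helpers, num_meta_episodes is read as the number of episodes per meta-episode (as described above), and every episode is assumed to run for the full max_episode_length steps.

```python
import shlex

# The example command from above, copied verbatim.
COMMAND = (
    "python3 dreamer.py --configs=dmc_proprio_goal --max_episode_length=300 "
    "--num_goals=2 --num_meta_episodes=3 --steps=20000000 "
    "--task=dmc_reacher_goal --train_ratio=0.03"
)


def parse_flags(command):
    """Hypothetical helper: collect --key=value flags into a dict of strings."""
    flags = {}
    for token in shlex.split(command):
        if token.startswith("--") and "=" in token:
            key, value = token[2:].split("=", 1)
            flags[key] = value
    return flags


def meta_episode_budget(flags):
    """Steps per meta-episode, and how many full meta-episodes fit in --steps."""
    episodes = int(flags["num_meta_episodes"])  # episodes per meta-episode
    episode_len = int(flags["max_episode_length"])  # steps per episode
    per_meta = episodes * episode_len
    return per_meta, int(flags["steps"]) // per_meta


flags = parse_flags(COMMAND)
per_meta, n_meta = meta_episode_budget(flags)
print(per_meta, n_meta)  # 900 22222
```

Under these assumptions, the 20M-step budget corresponds to roughly 22k meta-episodes of 900 steps each.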
We provide an example of how to run policy inference and visualize the results in policy_inference.py. Currently, this example only supports the panda_reach environment, but it can easily be extended to other environments.
To visualize an experiment, change the path in policy_inference.py to your wandb experiment directory.
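When running inference with a meta-RL policy, the agent's memory should typically be kept across the episodes of a meta-episode, since that is where information about the current task accumulates. The toy example below illustrates this under stated assumptions: ToyGoalEnv and GoalSeekingPolicy are hypothetical stand-ins (not MAMBA's world model or the code in policy_inference.py), with a hidden goal at -2 or +2 that stays fixed for the whole meta-episode. It assumes, as in other context-based meta-RL methods, that the recurrent state is not reset at episode boundaries.

```python
class ToyGoalEnv:
    """Hypothetical 1-D task: the goal (-2 or +2) is fixed for a whole meta-episode."""

    def __init__(self, goal, episode_length=6):
        self.goal = goal
        self.episode_length = episode_length

    def reset(self):
        self.pos, self.t = 0, 0
        return self.pos

    def step(self, action):
        self.pos += action
        self.t += 1
        reward = 1.0 if self.pos == self.goal else 0.0
        return self.pos, reward, self.t >= self.episode_length


class GoalSeekingPolicy:
    """Hypothetical recurrent policy: its state stores what it inferred about the goal."""

    def initial_state(self):
        return {"goal": None}

    def act(self, obs, state):
        # Explore toward +2 while the goal is unknown, otherwise head to the inferred goal.
        target = 2 if state["goal"] is None else state["goal"]
        return (target > obs) - (target < obs)  # -1, 0 or +1

    def update(self, obs, reward, state):
        # Reaching +2 reveals the task: reward means the goal is +2, no reward means -2.
        if state["goal"] is None and obs == 2:
            state["goal"] = 2 if reward > 0 else -2
        return state


def run_meta_episode(env, policy, num_episodes, carry_state=True):
    """Roll out one meta-episode and return the return of each episode."""
    state = policy.initial_state()
    returns = []
    for _ in range(num_episodes):
        if not carry_state:
            state = policy.initial_state()  # forget everything between episodes
        obs, done, total = env.reset(), False, 0.0
        while not done:
            action = policy.act(obs, state)
            obs, reward, done = env.step(action)
            state = policy.update(obs, reward, state)
            total += reward
        returns.append(total)
    return returns


env = ToyGoalEnv(goal=-2)
print(run_meta_episode(env, GoalSeekingPolicy(), num_episodes=3))  # [1.0, 5.0, 5.0]
print(run_meta_episode(env, GoalSeekingPolicy(), num_episodes=3, carry_state=False))  # [1.0, 1.0, 1.0]
```

Carrying the state lets the agent exploit in episodes 2 and 3 what it learned in episode 1; resetting it forces the agent to explore again in every episode.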
To add a new environment:
- Add the environment code to the envs folder.
- Add the environment to the make_env function in dreamer.py, together with any needed wrappers (you can use envs/wrappers.py as a reference and the other environments as examples).
- Add a new config to the config.yaml file.
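The three steps above can be sketched as follows. Everything here is hypothetical and for illustration only: CounterEnv, TimeLimit, this make_env signature and merge_config are not the repository's actual code, and the sketch assumes config.yaml holds default values that a named config (the one selected with --configs) overrides.

```python
class CounterEnv:
    """Hypothetical toy environment (step 1: this would live in the envs folder)."""

    def reset(self):
        self.value = 0
        return self.value

    def step(self, action):
        self.value += action
        return self.value, float(action), False  # obs, reward, done


class TimeLimit:
    """Hypothetical wrapper ending episodes after max_steps (compare envs/wrappers.py)."""

    def __init__(self, env, max_steps):
        self.env, self.max_steps, self.t = env, max_steps, 0

    def reset(self):
        self.t = 0
        return self.env.reset()

    def step(self, action):
        obs, reward, done = self.env.step(action)
        self.t += 1
        return obs, reward, done or self.t >= self.max_steps


def make_env(config):
    """Step 2 (hypothetical signature): pick the env by task name, then wrap it."""
    if config["task"] == "toy_counter":
        env = CounterEnv()
    else:
        raise NotImplementedError(f"unknown task: {config['task']}")
    return TimeLimit(env, config["max_episode_length"])


def merge_config(defaults, overrides):
    """Step 3: overlay a named config on the defaults, recursing into nested dicts."""
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Illustrative values only; the real config.yaml keys may differ.
defaults = {"task": None, "max_episode_length": 1000, "model": {"depth": 32, "act": "SiLU"}}
toy_counter = {"task": "toy_counter", "max_episode_length": 3, "model": {"depth": 16}}
config = merge_config(defaults, toy_counter)
env = make_env(config)
env.reset()
print([env.step(1)[2] for _ in range(3)])  # [False, False, True]
print(config["model"])  # {'depth': 16, 'act': 'SiLU'}
```

Merging recursively means a new config only needs to list the keys it changes; everything else falls back to the defaults.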
This code is based on the following implementations:
If you find this code useful, please cite our paper: COMING SOON