Commit 939c9b6 (1 parent: 0d4b1f3), authored and committed by Alex Rutherford on Dec 5, 2024.
Showing 86 changed files with 21,622 additions and 6 deletions.
# IPPO & MAPPO

# IPPO Baseline

Pure JAX IPPO implementation, based on the PureJaxRL PPO implementation.

## 🔎 Implementation Details
General features:

* Agents are controlled by a single network architecture (either FF or RNN).
* Parameters are shared between agents (see the sketch after this list).
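
To make the parameter-sharing point concrete, below is a minimal Flax sketch (a simplified feed-forward actor-critic, not the exact architecture used by these baselines): a single set of parameters is initialised once and applied to a stack of per-agent observations.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SharedActorCritic(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, obs):
        x = nn.relu(nn.Dense(64)(obs))
        x = nn.relu(nn.Dense(64)(x))
        logits = nn.Dense(self.action_dim)(x)  # actor head
        value = nn.Dense(1)(x)                 # critic head
        return logits, jnp.squeeze(value, axis=-1)

num_agents, obs_dim, action_dim = 3, 10, 5
network = SharedActorCritic(action_dim)
params = network.init(jax.random.PRNGKey(0), jnp.zeros((obs_dim,)))

# One parameter pytree, evaluated on every agent's observation in a single call.
agent_obs = jnp.zeros((num_agents, obs_dim))
logits, values = network.apply(params, agent_obs)
```

Because every agent is evaluated with the same `params`, gradients from all agents update one set of weights.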

## 🚀 Usage

If you have cloned JaxMARL and are in the repository root, you can run the algorithms as scripts, e.g.
```bash
python baselines/IPPO/ippo_rnn_smax.py
```
Each file has a distinct config file which resides within [`config`](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/IPPO/config).
The config file contains the IPPO hyperparameters, the environment's parameters and, for some config files, the `wandb` details (`wandb` is disabled by default).

# MAPPO Baseline

Pure JAX MAPPO implementation, based on the PureJaxRL PPO implementation.

## 🔎 Implementation Details
General features:

* Agents are controlled by a single network architecture (either FF or RNN).
* Parameters are shared between agents.
* Each script has a `WorldStateWrapper` which provides a global `"world_state"` observation (a rough sketch of the idea follows this list).
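
As a rough illustration of what such a wrapper does, the hypothetical sketch below (not the `WorldStateWrapper` shipped with the scripts) adds a concatenated global observation under the `"world_state"` key:

```python
import jax.numpy as jnp

class WorldStateSketch:
    """Hypothetical wrapper: exposes a concatenated global observation.

    Illustrates the idea only; see the MAPPO scripts for the real WorldStateWrapper.
    """

    def __init__(self, env):
        self._env = env

    def reset(self, key):
        obs, state = self._env.reset(key)
        obs["world_state"] = self._world_state(obs)
        return obs, state

    def step(self, key, state, actions):
        obs, state, rewards, dones, infos = self._env.step(key, state, actions)
        obs["world_state"] = self._world_state(obs)
        return obs, state, rewards, dones, infos

    def _world_state(self, obs):
        # Concatenate every agent's (flattened) observation into one vector.
        return jnp.concatenate([obs[a].flatten() for a in self._env.agents])
```

In MAPPO, a global observation like this is what the centralised critic conditions on, while the actors still act on their local observations.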

## 🚀 Usage

If you have cloned JaxMARL and are in the repository root, you can run the algorithms as scripts, e.g.
```bash
python baselines/MAPPO/mappo_rnn_smax.py
```
Each file has a distinct config file which resides within [`config`](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/MAPPO/config).
The config file contains the MAPPO hyperparameters, the environment's parameters and the `wandb` details (`wandb` is disabled by default).
# QLearning

Pure JAX implementations of:

* PQN-VDN (Parallelised Q-Network)
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The PQN implementation follows [purejaxql](https://github.com/mttga/purejaxql). IQL, VDN and QMIX follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).

Standard algorithms (IQL, VDN, QMIX) support:

- MPE
- SMAX
- Overcooked (QMIX not supported)

PQN-VDN supports:

- MPE
- SMAX
- Hanabi
- Overcooked

**At the moment, PQN-VDN should be the most performant Q-Learning baseline in terms of returns and training speed.**

❗ TransfQMix and SHAQ still use an old implementation of the scripts and need refactoring to match the new format.

## ⚙️ Implementation Details

All the algorithms take advantage of the `CTRolloutManager` environment wrapper (found in `jaxmarl.wrappers.baselines`), which is used to:

- Batchify the step and reset functions to run parallel environments.
- Add a global observation (`obs["__all__"]`) and a global reward (`rewards["__all__"]`) to the returns of `env.step` for centralized training.
- Preprocess and standardise the observation vectors (flatten, pad, add additional features such as a one-hot agent-id encoding, etc.).

You might want to modify this wrapper for your needs.
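
For reference, a single wrapped rollout step might look like the sketch below. The `batch_reset`/`batch_step` names and the `batch_size` argument follow the baseline scripts, but treat them as assumptions and check `jaxmarl.wrappers.baselines` for the exact API:

```python
import jax
import jax.numpy as jnp
from jaxmarl import make
from jaxmarl.wrappers.baselines import CTRolloutManager

env = make("MPE_simple_spread_v3")
wrapped = CTRolloutManager(env, batch_size=8)  # assumed: 8 parallel environments

rng, reset_rng, step_rng = jax.random.split(jax.random.PRNGKey(0), 3)
obs, state = wrapped.batch_reset(reset_rng)

# Placeholder joint actions: one discrete action per parallel environment.
actions = {a: jnp.zeros((8,), dtype=jnp.int32) for a in env.agents}
obs, state, rewards, dones, infos = wrapped.batch_step(step_rng, state, actions)

# Extras added by the wrapper for centralized training.
global_obs = obs["__all__"]
team_reward = rewards["__all__"]
```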

## 🚀 Usage

If you have cloned JaxMARL and are in the repository root, you can run the algorithms as scripts. You will need to specify which parameter configuration Hydra should load by choosing it (or adding your own) in the config folder. Below are some examples:

```bash
# VDN with an RNN in MPE Spread
python baselines/QLearning/vdn_rnn.py +alg=ql_rnn_mpe
# independent IQL with an RNN in the competitive simple_tag (predator-prey)
python baselines/QLearning/iql_rnn.py +alg=ql_rnn_mpe alg.ENV_NAME=MPE_simple_tag_v3
# QMIX in SMAX
python baselines/QLearning/qmix_rnn.py +alg=ql_rnn_smax
# VDN in Overcooked
python baselines/QLearning/vdn_cnn_overcooked.py +alg=ql_cnn_overcooked alg.ENV_KWARGS.LAYOUT=counter_circuit
# TransfQMix in SMAX
python baselines/QLearning/transf_qmix.py +alg=transf_qmix_smax

# PQN-VDN feed-forward in MPE
python baselines/QLearning/pqn_vdn_ff.py +alg=pqn_vdn_ff_mpe
# PQN-VDN feed-forward in Hanabi
python baselines/QLearning/pqn_vdn_ff.py +alg=pqn_vdn_ff_hanabi
# PQN-VDN with a CNN in Overcooked
python baselines/QLearning/pqn_vdn_cnn_overcooked.py +alg=pqn_vdn_cnn_overcooked
# PQN-VDN with an RNN in SMAX
python baselines/QLearning/pqn_vdn_rnn.py +alg=pqn_vdn_rnn_smax
```

Notice that with Hydra you can modify parameters on the fly:

```bash
# change the learning rate
python baselines/QLearning/iql_rnn.py +alg=ql_rnn_mpe alg.LR=0.001
# change the Overcooked layout
python baselines/QLearning/pqn_vdn_cnn_overcooked.py +alg=pqn_vdn_cnn_overcooked alg.ENV_KWARGS.LAYOUT=counter_circuit
# change the SMAX map
python baselines/QLearning/pqn_vdn_rnn.py +alg=pqn_vdn_rnn_smax alg.MAP_NAME=5m_vs_6m
```

Take a look at [`config.yaml`](./config/config.yaml) for the default configuration used when running these scripts. There you can choose how many seeds to vmap over and set up WANDB.

**❗Note on Transformers**: TransfQMix currently supports only MPE_Spread and SMAX. You will need to wrap the observation vectors into matrices to use transformers in other environments. See: ```jaxmarl.wrappers.transformers```

## 🎯 Hyperparameter tuning

All the scripts include a tune function for hyperparameter tuning. To use it, set `"HYP_TUNE": True` in `config.yaml` and define the hyperparameter spaces in the tune function. For more information, check the [wandb documentation](https://docs.wandb.ai/guides/sweeps).
# Coin

JaxMARL contains an implementation of the Coin Game environment presented in [Model-Free Opponent Shaping (Lu et al.)](https://arxiv.org/abs/2205.01447). The original description and usage of the environment come from [Maintaining cooperation in complex social dilemmas using deep reinforcement learning (Lerer et al.)](https://arxiv.org/abs/1707.01068), and [Learning with Opponent-Learning Awareness (Foerster et al.)](https://arxiv.org/abs/1709.04326) was the first to popularize its use for opponent shaping. A description from Model-Free Opponent Shaping:

```
The Coin Game is a multi-agent grid-world environment that simulates social dilemmas like the IPD but with high dimensional dynamic states. First proposed by Lerer & Peysakhovich (2017), the game consists of two players, labeled red and blue respectively, who are tasked with picking up coins, also labeled red and blue respectively, in a 3x3 grid. If a player picks up any coin by moving into the same position as the coin, they receive a reward of +1. However, if they pick up a coin of the other player’s color, the other player receives a reward of −2. Thus, if both agents play greedily and pick up every coin, the expected reward for both agents is 0.
```
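
A minimal interaction loop is sketched below, assuming the environment is registered under the name `'coin_game'` (check JaxMARL's registration for the exact key):

```python
import jax
from jaxmarl import make

rng = jax.random.PRNGKey(0)
env = make("coin_game")  # assumed registration name

rng, _rng = jax.random.split(rng)
obs, state = env.reset(_rng)

for _ in range(5):
    rng, act_rng, step_rng = jax.random.split(rng, 3)
    act_rngs = jax.random.split(act_rng, env.num_agents)
    # Random action for each of the two players (red and blue).
    actions = {a: env.action_space(a).sample(act_rngs[i]) for i, a in enumerate(env.agents)}
    obs, state, rewards, dones, infos = env.step(step_rng, state, actions)
```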
# Hanabi

This directory contains a MARL environment for the cooperative card game Hanabi, implemented in JAX. It is inspired by the popular [Hanabi Learning Environment (HLE)](https://arxiv.org/pdf/1902.00506.pdf), but is intended to be simpler to integrate and run with the growing ecosystem of JAX-implemented RL research pipelines.

## Action Space
Hanabi is a turn-based game. The current player can choose to discard or play any of the cards in their hand, or hint a colour or rank to any one of their teammates.

## Observation Space
The observations closely follow the featurization in the HLE. Each observation comprises 658 features:

* **Hands (127)**: information about the visible hands.
  * other player's hand: 125
    * card 0: 25
    * card 1: 25
    * card 2: 25
    * card 3: 25
    * card 4: 25
  * hands missing a card: 2 (one-hot)

* **Board (76)**: encoding of the public information visible on the board.
  * Deck: 40, thermometer
  * Fireworks: 25, one-hot
  * Info Tokens: 8, thermometer
  * Life Tokens: 3, thermometer

* **Discards (50)**: encoding of the cards in the discard pile.
  * Colour R: 10 bits for each card
  * Colour Y: 10 bits for each card
  * Colour G: 10 bits for each card
  * Colour W: 10 bits for each card
  * Colour B: 10 bits for each card

* **Last Action (55)**: encoding of the last move of the previous player.
  * Acting player index, relative to yourself: 2, one-hot
  * MoveType: 4, one-hot
  * Target player index, relative to acting player: 2, one-hot
  * Color revealed: 5, one-hot
  * Rank revealed: 5, one-hot
  * Reveal outcome: 5 bits, each bit is 1 if the card was hinted at
  * Position played/discarded: 5, one-hot
  * Card played/discarded: 25, one-hot
  * Card played scored: 1
  * Card played added info token: 1

* **V0 belief (350)**: trivially computed probability of each card being a specific card (given the played and discarded cards and the hints received), for each card of each player.
  * Possible Card (for each card): 25 (* 10)
  * Colour hinted (for each card): 5 (* 10)
  * Rank hinted (for each card): 5 (* 10)
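
The feature groups above sum to 127 + 76 + 50 + 55 + 350 = 658. A quick sanity check of the observation size (a minimal sketch, using the `make('hanabi')` registration shown later in this README):

```python
import jax
from jaxmarl import make

env = make("hanabi")
obs, state = env.reset(jax.random.PRNGKey(0))
# Each agent's observation should be a flat vector of 658 features.
print({agent: o.shape for agent, o in obs.items()})
```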

## Pretrained Models

We make some pretrained models available. For example, you can use a JAX conversion of the original R2D2 OBL model in this way:

1. Download the models from Hugging Face: ```git clone https://huggingface.co/mttga/obl-r2d2-flax``` (ensure you have git lfs installed). You can also use the script: ```bash jaxmarl/environments/hanabi/models/download_r2d2_obl.sh```
2. Load the parameters, import the agent wrapper and use it with JaxMARL Hanabi:

```python
# Step 1: download the weights, e.g. `git clone https://huggingface.co/mttga/obl-r2d2-flax`
import jax
from jax import numpy as jnp
from jaxmarl import make
from jaxmarl.wrappers.baselines import load_params
from jaxmarl.environments.hanabi.pretrained import OBLAgentR2D2

weight_file = "jaxmarl/environments/hanabi/pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
params = load_params(weight_file)

agent = OBLAgentR2D2()
agent_carry = agent.initialize_carry(jax.random.PRNGKey(0), batch_dims=(2,))

rng = jax.random.PRNGKey(0)
env = make('hanabi')
obs, env_state = env.reset(rng)
env.render(env_state)

batchify = lambda x: jnp.stack([x[agent] for agent in env.agents])
unbatchify = lambda x: {agent: x[i] for i, agent in enumerate(env.agents)}

agent_input = (
    batchify(obs),
    batchify(env.get_legal_moves(env_state))
)
agent_carry, actions = agent.greedy_act(params, agent_carry, agent_input)
actions = unbatchify(actions)

obs, env_state, rewards, done, info = env.step(rng, env_state, actions)

print('actions:', {agent: env.action_encoding[int(a)] for agent, a in actions.items()})
env.render(env_state)
```

## Rendering

You can render the full environment state:

```python
obs, env_state = env.reset(rng)
env.render(env_state)

Turn: 0

Score: 0
Information: 8
Lives: 3
Deck: 40
Discards:
Fireworks:
Actor 0 Hand: <-- current player
0 W3 || XX|RYGWB12345
1 G5 || XX|RYGWB12345
2 G4 || XX|RYGWB12345
3 G1 || XX|RYGWB12345
4 Y2 || XX|RYGWB12345
Actor 1 Hand:
0 R3 || XX|RYGWB12345
1 B1 || XX|RYGWB12345
2 G1 || XX|RYGWB12345
3 R4 || XX|RYGWB12345
4 W4 || XX|RYGWB12345
```

Or you can render the partial observation of the current agent:

```python
obs, new_env_state, rewards, dones, infos = env.step_env(rng, env_state, actions)
obs_s = env.get_obs_str(new_env_state, env_state, a, include_belief=True, best_belief=5)
print(obs_s)

Turn: 1

Score: 0
Information available: 7
Lives available: 3
Deck remaining cards: 40
Discards:
Fireworks:
Other Hand:
0 Card: W3, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
1 Card: G5, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
2 Card: G4, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
3 Card: G1, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
4 Card: Y2, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
Your Hand:
0 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
1 Hints: 1, Possible: RYGWB1, Belief: [R1: 0.200 Y1: 0.200 G1: 0.200 W1: 0.200 B1: 0.200]
2 Hints: 1, Possible: RYGWB1, Belief: [R1: 0.200 Y1: 0.200 G1: 0.200 W1: 0.200 B1: 0.200]
3 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
4 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
Last action: H1
Cards afected: [1 2]
Legal Actions: ['D0', 'D1', 'D2', 'D3', 'D4', 'P0', 'P1', 'P2', 'P3', 'P4', 'HY', 'HG', 'HW', 'H1', 'H2', 'H3', 'H4', 'H5']
```

## Manual Game

You can test the environment and your models using the ```manual_game.py``` script in this folder. It allows you to control one or two agents with the keyboard and one or two agents with a pretrained model (an OBL model by default). For example, to play with an OBL pretrained model:

```
python manual_game.py \
  --player0 "manual" \
  --player1 "obl" \
  --weight1 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
```

Or to watch an OBL model playing with itself:

```
python manual_game.py \
  --player0 "obl" \
  --player1 "obl" \
  --weight0 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors" \
  --weight1 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
```

## Citation
The environment was originally described in the following work:
```
@article{bard2019hanabi,
  title={The Hanabi Challenge: A New Frontier for AI Research},
  author={Bard, Nolan and Foerster, Jakob N. and Chandar, Sarath and Burch, Neil and Lanctot, Marc and Song, H. Francis and Parisotto, Emilio and Dumoulin, Vincent and Moitra, Subhodeep and Hughes, Edward and Dunning, Ian and Mourad, Shibl and Larochelle, Hugo and Bellemare, Marc G. and Bowling, Michael},
  journal={Artificial Intelligence Journal},
  year={2019}
}
```
# JaxNav

2D geometric navigation for differential drive robots. Using distance readings to nearby obstacles (mimicking LiDAR readings), the direction to their goal and their current velocity, robots must navigate to their goal without colliding with obstacles.

## Environment Details

JaxNav was first introduced in ["No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery"](https://www.arxiv.org/abs/2408.15099), with a detailed specification given in the paper's Appendix.

### Map Types
The default map consists of square robots of width 0.5 m moving within a world of grid-based obstacles, with cells of size 1 m x 1 m. The map cell size can be varied to produce obstacles of higher fidelity, and the robot shape can be changed to any polygon or a circle.

We also include a map which uses polygon obstacles, but note that we have not used this code in a while, so there may well be issues with it.

### Observation space
By default, each robot receives 200 range readings from a 360-degree arc centered on its forward axis. These range readings have a maximum range of 6 m but no minimum range, and are discretised with a resolution of 0.05 m. Alongside these range readings, each robot receives its current linear and angular velocities along with the direction to its goal. The goal direction is given by a vector in polar form, where the distance is either the maximum lidar range if the goal is beyond the robot's "line of sight" or the actual distance if the goal is within lidar range. There is no communication between agents.

### Action Space
The environment's default action space is a 2D continuous action, where the first dimension is the desired linear velocity and the second the desired angular velocity. Discrete actions are also supported, where the possible combinations of linear and angular velocities are discretised into 15 options (see the sketch below).
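
As an illustration of that discretisation, the sketch below builds one possible 15-option grid of (linear, angular) velocity pairs; the bounds and the 3 x 5 split are assumptions, not necessarily the values JaxNav uses:

```python
import itertools
import jax.numpy as jnp

# Hypothetical velocity limits; JaxNav's real limits may differ.
max_lin_vel = 1.0  # m/s
max_ang_vel = 1.0  # rad/s

lin_options = jnp.linspace(0.0, max_lin_vel, 3)            # 3 linear velocities
ang_options = jnp.linspace(-max_ang_vel, max_ang_vel, 5)   # 5 angular velocities

# 3 x 5 = 15 discrete actions, each mapping to a (linear, angular) command.
discrete_actions = jnp.array(list(itertools.product(lin_options, ang_options)))
print(discrete_actions.shape)  # (15, 2)
```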

### Reward function
By default, the reward function contains a sparse outcome-based reward alongside a dense shaping term, sketched below.
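
As a rough illustration of that structure (hypothetical helper and coefficients, not JaxNav's exact reward), one could write:

```python
import jax.numpy as jnp

def sketch_reward(dist_to_goal, prev_dist_to_goal, reached_goal, collided):
    """Illustrative sparse-plus-dense reward; the coefficients are made up."""
    # Sparse outcome term: bonus for reaching the goal, penalty for a collision.
    sparse = jnp.where(reached_goal, 1.0, 0.0) + jnp.where(collided, -1.0, 0.0)
    # Dense shaping term: reward progress made towards the goal this step.
    shaping = 0.1 * (prev_dist_to_goal - dist_to_goal)
    return sparse + shaping
```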

## Visualisation
A visualiser is contained within `jaxnav_viz.py`, with an example below:

```python
from jaxmarl.environments.jaxnav.jaxnav_env import JaxNav
from jaxmarl.environments.jaxnav.jaxnav_viz import JaxNavVisualizer
import jax

env = JaxNav(num_agents=4)

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)

obs, env_state = env.reset(_rng)

obs_list = [obs]
env_state_list = [env_state]

for _ in range(10):
    rng, act_rng, step_rng = jax.random.split(rng, 3)
    act_rngs = jax.random.split(act_rng, env.num_agents)
    # Sample a random action for each agent and step the environment.
    actions = {a: env.action_space(a).sample(act_rngs[i]) for i, a in enumerate(env.action_spaces.keys())}
    obs, env_state, _, _, _ = env.step(step_rng, env_state, actions)
    obs_list.append(obs)
    env_state_list.append(env_state)

viz = JaxNavVisualizer(env, obs_list, env_state_list)
viz.animate("test.gif")
```

## TODOs:
- remove `self.rad` dependence for non-circular agents
- more unit tests
- add tests for non-square agents

## Citation
JaxNav was introduced in the following paper; if you use JaxNav in your work, please cite it as:

```bibtex
@misc{rutherford2024noregrets,
      title={No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery},
      author={Alexander Rutherford and Michael Beukman and Timon Willi and Bruno Lacerda and Nick Hawes and Jakob Foerster},
      year={2024},
      eprint={2408.15099},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2408.15099},
}
```