Commit: html files
Alex Rutherford authored and Alex Rutherford committed Dec 5, 2024
1 parent 0d4b1f3 commit 939c9b6
Showing 86 changed files with 21,622 additions and 6 deletions.
File renamed without changes.
41 changes: 41 additions & 0 deletions docs/Algorithms/PPO.md
# IPPO & MAPPO

# IPPO Baseline

Pure JAX IPPO implementation, based on the PureJaxRL PPO implementation.

## 🔎 Implementation Details
General features:

* Agents are controlled by a single network architecture (either feed-forward (FF) or recurrent (RNN)).
* Parameters are shared between agents (see the sketch below).
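
To make this concrete, below is a minimal, illustrative sketch of parameter sharing (not the baseline's actual network code; the class and layer sizes are assumptions): all agents' observations are stacked into one batch and passed through a single Flax module, so one parameter set produces every agent's policy and value.

```python
# Illustrative sketch of parameter sharing (not the actual baseline code):
# all agents' observations are stacked and passed through one network,
# so a single set of parameters controls every agent.
import jax
import jax.numpy as jnp
import flax.linen as nn

class ActorCritic(nn.Module):  # hypothetical feed-forward network
    action_dim: int

    @nn.compact
    def __call__(self, obs):
        x = nn.relu(nn.Dense(64)(obs))
        logits = nn.Dense(self.action_dim)(x)  # policy head
        value = nn.Dense(1)(x)                 # value head
        return logits, value.squeeze(-1)

network = ActorCritic(action_dim=5)
obs_batch = jnp.zeros((3, 10))  # 3 agents, 10-dim observations (dummy values)
params = network.init(jax.random.PRNGKey(0), obs_batch)
logits, values = network.apply(params, obs_batch)  # one forward pass for all agents
```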

## 🚀 Usage

If you have cloned JaxMARL and are in the repository root, you can run the algorithms as scripts, e.g.
```bash
python baselines/IPPO/ippo_rnn_smax.py
```
Each file has a distinct config file which resides within [`config`](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/IPPO/config).
Each config file contains the IPPO hyperparameters, the environment's parameters and, for some configs, the `wandb` settings (`wandb` is disabled by default).

# MAPPO Baseline

Pure JAX MAPPO implementation, based on the PureJaxRL PPO implementation.

## 🔎 Implementation Details
General features:

* Agents are controlled by a single network architecture (either feed-forward (FF) or recurrent (RNN)).
* Parameters are shared between agents.
* Each script has a `WorldStateWrapper` which provides a global `"world_state"` observation (see the sketch below).
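
For illustration, here is a minimal sketch of such a wrapper (the class name and details are assumptions, not the actual `WorldStateWrapper` code): it simply concatenates every agent's local observation into one global vector exposed under the `"world_state"` key.

```python
# Minimal sketch of a world-state wrapper (illustrative only, not the
# jaxmarl implementation): concatenate all per-agent observations into a
# single global vector exposed under the "world_state" key.
import jax.numpy as jnp

class SimpleWorldStateWrapper:  # hypothetical name
    def __init__(self, env):
        self._env = env

    def reset(self, key):
        obs, state = self._env.reset(key)
        obs["world_state"] = self._world_state(obs)
        return obs, state

    def step(self, key, state, actions):
        obs, state, rewards, dones, infos = self._env.step(key, state, actions)
        obs["world_state"] = self._world_state(obs)
        return obs, state, rewards, dones, infos

    def _world_state(self, obs):
        # global state = concatenation of every agent's local observation
        return jnp.concatenate([obs[agent].flatten() for agent in self._env.agents])
```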

## 🚀 Usage

If you have cloned JaxMARL and are in the repository root, you can run the algorithms as scripts, e.g.
```bash
python baselines/MAPPO/mappo_rnn_smax.py
```
Each file has a distinct config file which resides within [`config`](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/MAPPO/config).
The config file contains the MAPPO hyperparameters, the environment's parameters and the `wandb` details (`wandb` is disabled by default).

86 changes: 86 additions & 0 deletions docs/Algorithms/QLearning.md
# QLearning

Pure JAX implementations of:

* PQN-VDN (Parallelised Q-Network)
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The PQN implementation follows [purejaxql](https://github.com/mttga/purejaxql). IQL, VDN and QMIX follow the original [PyMARL](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper's code](https://github.com/hsvgbkhgbv/shapley-q-learning).


Standard algorithms (IQL, VDN, QMIX) support:

- MPE
- SMAX
- Overcooked (QMIX not supported)

PQN-VDN supports:

- MPE
- SMAX
- Hanabi
- Overcooked

**At the moment, PQN-VDN should be the most performant baseline for Q-Learning in terms of returns and training speed.**

❗ TransfQMix and SHAQ still use an older version of the scripts and need refactoring to match the new format.


## ⚙️ Implementation Details

All the algorithms take advantage of the `CTRolloutManager` environment wrapper (found in `jaxmarl.wrappers.baselines`), which is used to:

- Batchify the step and reset functions to run parallel environments.
- Add a global observation (`obs["__all__"]`) and a global reward (`rewards["__all__"]`) to the returns of `env.step` for centralized training.
- Preprocess and standardise the observation vectors (flattening, padding, adding extra features such as one-hot agent IDs, etc.).

You might want to modify this wrapper for your needs; a usage sketch is given below.
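
A hedged usage sketch follows; the constructor arguments and the `batch_reset`/`batch_step`/`batch_sample` method names are assumptions based on typical usage, so check `jaxmarl.wrappers.baselines` for the actual signatures.

```python
# Hedged usage sketch of the batched rollout wrapper (method names and
# arguments are assumptions -- check jaxmarl.wrappers.baselines).
import jax
from jaxmarl import make
from jaxmarl.wrappers.baselines import CTRolloutManager

env = make("MPE_simple_spread_v3")
wrapped = CTRolloutManager(env, batch_size=8)  # 8 parallel environments

rng = jax.random.PRNGKey(0)
obs, state = wrapped.batch_reset(rng)
print(obs["__all__"].shape)  # global observation for centralised training

# random actions for each agent (batch_sample is assumed to exist)
actions = {a: wrapped.batch_sample(rng, a) for a in env.agents}
obs, state, rewards, dones, infos = wrapped.batch_step(rng, state, actions)
print(rewards["__all__"].shape)  # global (team) reward
```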

## 🚀 Usage

If you have cloned JaxMARL and you are in the repository root, you can run the algorithms as scripts. You will need to specify which parameter configurations will be loaded by Hydra by choosing them (or adding yours) in the config folder. Below are some examples:

```bash
# VDN with RNN in MPE Spread
python baselines/QLearning/vdn_rnn.py +alg=ql_rnn_mpe
# IQL with RNN in competitive simple_tag (predator-prey)
python baselines/QLearning/iql_rnn.py +alg=ql_rnn_mpe alg.ENV_NAME=MPE_simple_tag_v3
# QMIX in SMAX
python baselines/QLearning/qmix_rnn.py +alg=ql_rnn_smax
# VDN with CNN in Overcooked
python baselines/QLearning/vdn_cnn_overcooked.py +alg=ql_cnn_overcooked alg.ENV_KWARGS.LAYOUT=counter_circuit
# TransfQMix in SMAX
python baselines/QLearning/transf_qmix.py +alg=transf_qmix_smax

# PQN-VDN feed-forward in MPE
python baselines/QLearning/pqn_vdn_ff.py +alg=pqn_vdn_ff_mpe
# PQN-VDN feed-forward in Hanabi
python baselines/QLearning/pqn_vdn_ff.py +alg=pqn_vdn_ff_hanabi
# PQN-VDN with CNN in Overcooked
python baselines/QLearning/pqn_vdn_cnn_overcooked.py +alg=pqn_vdn_cnn_overcooked
# PQN-VDN with RNN in SMAX
python baselines/QLearning/pqn_vdn_rnn.py +alg=pqn_vdn_rnn_smax
```

With Hydra, you can also override parameters from the command line, for example:

```bash
# change learning rate
python baselines/QLearning/iql_rnn.py +alg=ql_rnn_mpe alg.LR=0.001
# change overcooked layout
python baselines/QLearning/pqn_vdn_cnn_overcooked.py +alg=pqn_vdn_cnn_overcooked alg.ENV_KWARGS.LAYOUT=counter_circuit
# change smax map
python baselines/QLearning/pqn_vdn_rnn.py +alg=pqn_vdn_rnn_smax alg.MAP_NAME=5m_vs_6m
```

Take a look at [`config.yaml`](./config/config.yaml) for the default configuration used when running these scripts. There you can choose how many seeds to vmap over and set up WandB.

**❗Note on Transformers**: TransfQMix currently supports only MPE_Spread and SMAX. You will need to wrap the observation vectors into matrices to use transformers in other environments. See: ```jaxmarl.wrappers.transformers```

## 🎯 Hyperparameter tuning

All the scripts include a tune function for hyperparameter tuning. To use it, set `"HYP_TUNE": True` in the `config.yaml` and define the hyperparameter spaces in the tune function. For more information, check the [wandb documentation](https://docs.wandb.ai/guides/sweeps).
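
As an illustration only, the sweep space defined in the tune function typically looks like a standard `wandb` sweep dictionary; the parameter names and ranges below are assumptions, not the shipped defaults.

```python
# Illustrative sweep-space sketch (parameter names and ranges are assumptions)
sweep_config = {
    "method": "bayes",
    "metric": {"name": "returns", "goal": "maximize"},  # metric name is an assumption
    "parameters": {
        "LR": {"values": [0.005, 0.001, 0.0005, 0.0001]},
        "NUM_ENVS": {"values": [8, 32, 128]},
        "EPS_FINISH": {"min": 0.01, "max": 0.1},  # hypothetical exploration parameter
    },
}

# with wandb, the sweep is then registered and run roughly as:
# import wandb
# sweep_id = wandb.sweep(sweep_config, project="jaxmarl-qlearning")
# wandb.agent(sweep_id, function=my_train_fn, count=100)
```
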
8 changes: 8 additions & 0 deletions docs/Environments/coin_game.md
# Coin

JaxMARL contains an implementation of the Coin Game environment presented in [Model-Free Opponent Shaping (Lu et al.)](https://arxiv.org/abs/2205.01447). The environment was originally described in [Maintaining cooperation in complex social dilemmas using deep reinforcement learning (Lerer et al.)](https://arxiv.org/abs/1707.01068), and [Learning with Opponent-Learning Awareness (Foerster et al.)](https://arxiv.org/abs/1709.04326) first popularised its use for opponent shaping. A description from Model-Free Opponent Shaping:

```
The Coin Game is a multi-agent grid-world environment that simulates social dilemmas like the IPD but with high dimensional dynamic states. First proposed by Lerer & Peysakhovich (2017), the game consists of two players, labeled red and blue respectively, who are tasked with picking up coins, also labeled red and blue respectively, in a 3x3 grid. If a player picks up any coin by moving into the same position as the coin, they receive a reward of +1. However, if they pick up a coin of the other player’s color, the other player receives a reward of −2. Thus, if both agents play greedily and pick up every coin, the expected reward for both agents is 0.
```
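
As a quick illustration of interacting with the environment, here is a minimal sketch assuming it is registered under the name `"coin_game"` and follows the standard JaxMARL reset/step API (adjust the name to the actual registry if it differs):

```python
# Minimal interaction sketch (the registered name "coin_game" is an assumption)
import jax
from jaxmarl import make

env = make("coin_game")
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

obs, state = env.reset(key_reset)
actions = {agent: env.action_space(agent).sample(key_act) for agent in env.agents}
obs, state, rewards, dones, infos = env.step(key_step, state, actions)
```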

185 changes: 185 additions & 0 deletions docs/Environments/hanabi.md
# Hanabi

This directory contains a MARL environment for the cooperative card game Hanabi, implemented in JAX. It is inspired by the popular [Hanabi Learning Environment (HLE)](https://arxiv.org/pdf/1902.00506.pdf), but is intended to be simpler to integrate and run with the growing ecosystem of JAX-implemented RL research pipelines.


## Action Space
Hanabi is a turn-based game. The current player can choose to discard or play any of the cards in their hand, or hint a colour or rank to any one of their teammates.
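
As a rough, hedged illustration of the resulting action count (based on standard Hanabi rules; the environment's exact action set and indexing, e.g. any no-op action, may differ):

```python
# Rough action-count sketch under standard Hanabi assumptions
# (hand size 5, 5 colours, 5 ranks, 2 players); the actual environment
# may order actions differently or include extra actions.
hand_size, num_colors, num_ranks, num_players = 5, 5, 5, 2
num_actions = (
    hand_size                                        # discard one of your cards
    + hand_size                                      # play one of your cards
    + (num_players - 1) * (num_colors + num_ranks)   # hint a colour or rank to a teammate
)
print(num_actions)  # 20 under these assumptions
```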

## Observation Space
The observations closely follow the featurization in the HLE. Each observation comprises 658 features (a sketch of slicing the vector into these blocks follows the breakdown below):

* **Hands (127)**: information about the visible hands.
    * other player's hand: 125
        * card 0: 25
        * card 1: 25
        * card 2: 25
        * card 3: 25
        * card 4: 25
    * hands missing a card: 2 (one-hot)

* **Board (76)**: encoding of the public information visible on the board.
    * Deck: 40, thermometer
    * Fireworks: 25, one-hot
    * Info Tokens: 8, thermometer
    * Life Tokens: 3, thermometer

* **Discards (50)**: encoding of the cards in the discard pile.
    * Colour R: 10 bits, one per card of that colour
    * Colour Y: 10 bits, one per card of that colour
    * Colour G: 10 bits, one per card of that colour
    * Colour W: 10 bits, one per card of that colour
    * Colour B: 10 bits, one per card of that colour

* **Last Action (55)**: encoding of the last move of the previous player.
    * Acting player index, relative to yourself: 2, one-hot
    * Move type: 4, one-hot
    * Target player index, relative to acting player: 2, one-hot
    * Colour revealed: 5, one-hot
    * Rank revealed: 5, one-hot
    * Reveal outcome: 5 bits, each bit is 1 if the corresponding card was hinted at
    * Position played/discarded: 5, one-hot
    * Card played/discarded: 25, one-hot
    * Card played scored: 1
    * Card played added an info token: 1

* **V0 belief (350)**: trivially-computed probability of each card being a specific card (given the played/discarded cards and the hints given), for each card of each player.
    * Possible card (for each card): 25 (* 10)
    * Colour hinted (for each card): 5 (* 10)
    * Rank hinted (for each card): 5 (* 10)
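
For illustration, a minimal sketch of slicing an observation vector into these blocks, using only the sizes listed above (the offsets come from this breakdown, not from the environment code):

```python
# Split the 658-dim observation into the documented blocks (sizes from the
# breakdown above; purely illustrative, not code from the environment).
import jax.numpy as jnp

SECTIONS = [("hands", 127), ("board", 76), ("discards", 50),
            ("last_action", 55), ("v0_belief", 350)]  # sums to 658

def split_obs(obs_vector):
    parts, start = {}, 0
    for name, size in SECTIONS:
        parts[name] = obs_vector[start:start + size]
        start += size
    return parts

parts = split_obs(jnp.zeros(658))
print({name: block.shape for name, block in parts.items()})
```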

## Pretrained Models

We provide some pretrained models. For example, you can use a JAX conversion of the original R2D2 OBL model as follows:

1. Download the models from Hugging Face: ```git clone https://huggingface.co/mttga/obl-r2d2-flax``` (make sure git lfs is installed). You can also use the script: ```bash jaxmarl/environments/hanabi/models/download_r2d2_obl.sh```
2. Load the parameters, import the agent wrapper and use it with JaxMarl Hanabi:

```python
# if running in a notebook, you can clone the weights first: !git clone https://huggingface.co/mttga/obl-r2d2-flax
import jax
from jax import numpy as jnp
from jaxmarl import make
from jaxmarl.wrappers.baselines import load_params
from jaxmarl.environments.hanabi.pretrained import OBLAgentR2D2

weight_file = "jaxmarl/environments/hanabi/pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
params = load_params(weight_file)

agent = OBLAgentR2D2()
agent_carry = agent.initialize_carry(jax.random.PRNGKey(0), batch_dims=(2,))

rng = jax.random.PRNGKey(0)
env = make('hanabi')
obs, env_state = env.reset(rng)
env.render(env_state)

batchify = lambda x: jnp.stack([x[agent] for agent in env.agents])
unbatchify = lambda x: {agent:x[i] for i, agent in enumerate(env.agents)}

agent_input = (
    batchify(obs),
    batchify(env.get_legal_moves(env_state))
)
agent_carry, actions = agent.greedy_act(params, agent_carry, agent_input)
actions = unbatchify(actions)

obs, env_state, rewards, done, info = env.step(rng, env_state, actions)

print('actions:', {agent:env.action_encoding[int(a)] for agent, a in actions.items()})
env.render(env_state)
```

## Rendering

You can render the full environment state:

```python
obs, env_state = env.reset(rng)
env.render(env_state)

Turn: 0

Score: 0
Information: 8
Lives: 3
Deck: 40
Discards:
Fireworks:
Actor 0 Hand:<-- current player
0 W3 || XX|RYGWB12345
1 G5 || XX|RYGWB12345
2 G4 || XX|RYGWB12345
3 G1 || XX|RYGWB12345
4 Y2 || XX|RYGWB12345
Actor 1 Hand:
0 R3 || XX|RYGWB12345
1 B1 || XX|RYGWB12345
2 G1 || XX|RYGWB12345
3 R4 || XX|RYGWB12345
4 W4 || XX|RYGWB12345
```

Or you can render the partial observation of the current agent:

```python
obs, new_env_state, rewards, dones, infos = env.step_env(rng, env_state, actions)
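# note: "a" below is assumed to be the id of the acting agent (defined elsewhere)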
obs_s = env.get_obs_str(new_env_state, env_state, a, include_belief=True, best_belief=5)
print(obs_s)

Turn: 1

Score: 0
Information available: 7
Lives available: 3
Deck remaining cards: 40
Discards:
Fireworks:
Other Hand:
0 Card: W3, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
1 Card: G5, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
2 Card: G4, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
3 Card: G1, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
4 Card: Y2, Hints: , Possible: RYGWB12345, Belief: [R1: 0.060 Y1: 0.060 G1: 0.060 W1: 0.060 B1: 0.060]
Your Hand:
0 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
1 Hints: 1, Possible: RYGWB1, Belief: [R1: 0.200 Y1: 0.200 G1: 0.200 W1: 0.200 B1: 0.200]
2 Hints: 1, Possible: RYGWB1, Belief: [R1: 0.200 Y1: 0.200 G1: 0.200 W1: 0.200 B1: 0.200]
3 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
4 Hints: , Possible: RYGWB2345, Belief: [R2: 0.057 R3: 0.057 R4: 0.057 Y2: 0.057 Y3: 0.057]
Last action: H1
Cards afected: [1 2]
Legal Actions: ['D0', 'D1', 'D2', 'D3', 'D4', 'P0', 'P1', 'P2', 'P3', 'P4', 'HY', 'HG', 'HW', 'H1', 'H2', 'H3', 'H4', 'H5']
```

## Manual Game

You can test the environment and your models by using the ```manual_game.py``` script in this folder. It allows you to control one or two agents with the keyboard and one or two agents with a pretrained model (an OBL model by default). For example, to play with an OBL pretrained model:

```bash
python manual_game.py \
    --player0 "manual" \
    --player1 "obl" \
    --weight1 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
```

Or to watch an OBL model playing with itself:

```bash
python manual_game.py \
    --player0 "obl" \
    --player1 "obl" \
    --weight0 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors" \
    --weight1 "./pretrained/obl-r2d2-flax/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
```

## Citation
The environment was originally described in the following work:
```bibtex
@article{bard2019hanabi,
  title={The Hanabi Challenge: A New Frontier for AI Research},
  author={Bard, Nolan and Foerster, Jakob N. and Chandar, Sarath and Burch, Neil and Lanctot, Marc and Song, H. Francis and Parisotto, Emilio and Dumoulin, Vincent and Moitra, Subhodeep and Hughes, Edward and Dunning, Ian and Mourad, Shibl and Larochelle, Hugo and Bellemare, Marc G. and Bowling, Michael},
  journal={Artificial Intelligence},
  year={2019}
}
```
71 changes: 71 additions & 0 deletions docs/Environments/jaxnav.md
# JaxNav

2D geometric navigation for differential drive robots. Using distance readings to nearby obstacles (mimicking LiDAR returns), the direction to their goal and their current velocity, robots must navigate to their goal without colliding with obstacles.

## Environment Details

JaxNav was first introduced in ["No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery"](https://www.arxiv.org/abs/2408.15099), with a detailed specification given in the paper's appendix.

### Map Types
The default map consists of square robots of width 0.5 m moving within a world of grid-based obstacles, with cells of size 1 m x 1 m. The map cell size can be varied to produce obstacles of higher fidelity, and the robot shape can be changed to any polygon or a circle.

We also include a map which uses polygon obstacles, but note that we have not used this code in a while, so there may well be issues with it.

### Observation space
By default, each robot receives 200 range readings from a 360-degree arc centred on its forward axis. These range readings have a maximum range of 6 m, no minimum range, and are discretised with a resolution of 0.05 m. Alongside these range readings, each robot receives its current linear and angular velocities along with the direction to its goal. The goal direction is given as a vector in polar form, where the distance is either the maximum LiDAR range if the goal is beyond the robot's "line of sight" or the actual distance if the goal is within LiDAR range. There is no communication between agents.
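
As an illustrative sketch of the goal part of this observation (variable names and the exact clipping behaviour are assumptions, not JaxNav's code):

```python
# Hedged sketch of the polar goal observation described above.
import jax.numpy as jnp

MAX_LIDAR_RANGE = 6.0  # metres, per the description above

def goal_observation(robot_pos, goal_pos):
    delta = goal_pos - robot_pos
    dist = jnp.linalg.norm(delta)
    angle = jnp.arctan2(delta[1], delta[0])
    # the distance is capped at the lidar range when the goal is out of "sight"
    return jnp.array([jnp.minimum(dist, MAX_LIDAR_RANGE), angle])

print(goal_observation(jnp.array([0.0, 0.0]), jnp.array([8.0, 0.0])))  # [6., 0.]
```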

### Action Space
The environment's default action space is a 2D continuous action, where the first dimension is the desired linear velocity and the second the desired angular velocity. Discrete actions are also supported, where the possible combinations of linear and angular velocities are discretised into 15 options.
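
A minimal sketch of one possible discretisation into 15 options (the velocity limits and grid below are assumptions, not JaxNav's actual values):

```python
# Hedged sketch: build 15 discrete actions from a 3 x 5 grid of velocities.
import itertools
import jax.numpy as jnp

linear_options = [0.0, 0.5, 1.0]               # m/s (assumed limits)
angular_options = [-1.0, -0.5, 0.0, 0.5, 1.0]  # rad/s (assumed limits)

# 3 linear x 5 angular = 15 discrete actions, each a (linear, angular) pair
DISCRETE_ACTIONS = jnp.array(list(itertools.product(linear_options, angular_options)))

def discrete_to_continuous(action_idx):
    return DISCRETE_ACTIONS[action_idx]

print(DISCRETE_ACTIONS.shape)  # (15, 2)
```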

### Reward function
By default, the reward function contains a sparse outcome-based reward alongside a dense shaping term.
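
A hedged sketch of such a sparse-plus-shaping reward (the coefficients and terms are illustrative assumptions, not JaxNav's actual values):

```python
# Hedged sketch of a sparse outcome reward plus a dense shaping term.
import jax.numpy as jnp

def reward(reached_goal, collided, prev_goal_dist, goal_dist,
           goal_bonus=1.0, collision_penalty=-1.0, shaping_coef=0.1):
    sparse = (jnp.where(reached_goal, goal_bonus, 0.0)
              + jnp.where(collided, collision_penalty, 0.0))
    dense = shaping_coef * (prev_goal_dist - goal_dist)  # reward progress towards the goal
    return sparse + dense
```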

## Visualisation
The visualiser is contained within `jaxnav_viz.py`; an example is given below:

```python
from jaxmarl.environments.jaxnav.jaxnav_env import JaxNav
from jaxmarl.environments.jaxnav.jaxnav_viz import JaxNavVisualizer
import jax

env = JaxNav(num_agents=4)

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)

obs, env_state = env.reset(_rng)

obs_list = [obs]
env_state_list = [env_state]

for _ in range(10):
    rng, act_rng, step_rng = jax.random.split(rng, 3)
    act_rngs = jax.random.split(act_rng, env.num_agents)
    actions = {a: env.action_space(a).sample(act_rngs[i]) for i, a in enumerate(env.action_spaces.keys())}
    obs, env_state, _, _, _ = env.step(step_rng, env_state, actions)
    obs_list.append(obs)
    env_state_list.append(env_state)

viz = JaxNavVisualizer(env, obs_list, env_state_list)
viz.animate("test.gif")
```

## TODOs:
- remove `self.rad` dependence for non-circular agents
- more unit tests
- add tests for non-square agents

## Citation
JaxNav was introduced in the following paper; if you use JaxNav in your work, please cite it as:

```bibtex
@misc{rutherford2024noregrets,
  title={No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery},
  author={Alexander Rutherford and Michael Beukman and Timon Willi and Bruno Lacerda and Nick Hawes and Jakob Foerster},
  year={2024},
  eprint={2408.15099},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2408.15099},
}
```