Add docs for dqn.py #157

Merged 9 commits on Apr 20, 2022
35 changes: 35 additions & 0 deletions benchmark/dqn/README.md
@@ -0,0 +1,35 @@
# Deep Q-Learning Benchmark

This directory contains instructions for reproducing our DQN experiments.

Prerequisites:
* Python 3.8+
* [Poetry](https://python-poetry.org)
* [GitHub CLI](https://cli.github.com/)


## Reproducing CleanRL's DQN Benchmark

### Classic Control

```bash
git clone https://github.com/vwxyzjn/cleanrl.git && cd cleanrl
gh pr checkout 157
poetry install
bash benchmark/dqn/classic_control.sh
```

Note that you may need to override `--wandb-entity cleanrl` with your own W&B entity if you have not obtained access to the `cleanrl/openbenchmark` project.


### Atari games

```bash
git clone https://github.com/vwxyzjn/cleanrl.git && cd cleanrl
gh pr checkout 124
poetry install
poetry install -E atari
bash benchmark/dqn/atari.sh
```

Note that you may need to override `--wandb-entity cleanrl` with your own W&B entity if you have not obtained access to the `cleanrl/openbenchmark` project.
14 changes: 14 additions & 0 deletions benchmark/dqn/classic_control.sh
@@ -0,0 +1,14 @@
# CartPole-v1
poetry run python cleanrl/dqn.py --env-id CartPole-v1 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn.py --env-id CartPole-v1 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn.py --env-id CartPole-v1 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# Acrobot-v1
poetry run python cleanrl/dqn.py --env-id Acrobot-v1 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn.py --env-id Acrobot-v1 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn.py --env-id Acrobot-v1 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# MountainCar-v0
poetry run python cleanrl/dqn.py --env-id MountainCar-v0 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn.py --env-id MountainCar-v0 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn.py --env-id MountainCar-v0 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
6 changes: 5 additions & 1 deletion cleanrl/c51.py
@@ -156,7 +156,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device, optimize_memory_usage=True
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

7 changes: 6 additions & 1 deletion cleanrl/c51_atari.py
@@ -177,7 +177,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device, optimize_memory_usage=True
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
optimize_memory_usage=True,
handle_timeout_termination=True,
)
start_time = time.time()

24 changes: 14 additions & 10 deletions cleanrl/dqn.py
@@ -37,7 +37,7 @@ def parse_args():
# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=25000,
parser.add_argument("--total-timesteps", type=int, default=500000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
@@ -49,17 +49,17 @@ def parse_args():
help="the timesteps it takes to update the target network")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
help="the maximum norm for the gradient clipping")
parser.add_argument("--batch-size", type=int, default=32,
parser.add_argument("--batch-size", type=int, default=128,
help="the batch size of sample from the reply memory")
parser.add_argument("--start-e", type=float, default=1,
help="the starting epsilon for exploration")
parser.add_argument("--end-e", type=float, default=0.05,
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.8,
parser.add_argument("--exploration-fraction", type=float, default=0.5,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
parser.add_argument("--learning-starts", type=int, default=10000,
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=1,
parser.add_argument("--train-frequency", type=int, default=10,
help="the frequency of training")
args = parser.parse_args()
# fmt: on
Expand All @@ -86,11 +86,11 @@ class QNetwork(nn.Module):
def __init__(self, env):
super(QNetwork, self).__init__()
self.network = nn.Sequential(
nn.Linear(np.array(env.single_observation_space.shape).prod(), 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, env.single_action_space.n),
nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, env.single_action_space.n),
)

def forward(self, x):
@@ -141,7 +141,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device, optimize_memory_usage=True
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

7 changes: 6 additions & 1 deletion cleanrl/dqn_atari.py
@@ -162,7 +162,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device, optimize_memory_usage=True
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
optimize_memory_usage=True,
handle_timeout_termination=True,
)
start_time = time.time()

8 changes: 7 additions & 1 deletion cleanrl/sac_continuous_action.py
@@ -204,7 +204,13 @@ def to(self, device):
alpha = args.alpha

envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device)
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
8 changes: 7 additions & 1 deletion cleanrl/td3_continuous_action.py
@@ -158,7 +158,13 @@ def forward(self, x):
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate)

envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device)
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
104 changes: 89 additions & 15 deletions docs/rl-algorithms/dqn.md
@@ -15,24 +15,11 @@ Original papers:

| Variants Implemented | Description |
| ----------- | ----------- |
| :material-github: [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), :material-file-document: [docs](/rl-algorithms/dqn/#dqnpy) | For classic control tasks like `CartPole-v1`. |
| :material-github: [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), :material-file-document: [docs](/rl-algorithms/dqn/#dqn_ataripy) | For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques. |
| :material-github: [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), :material-file-document: [docs](/rl-algorithms/dqn/#dqnpy) | For classic control tasks like `CartPole-v1`. |

Below are our single-file implementations of DQN:

## `dqn.py`

The [dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py) has the following features:

* Works with the `Box` observation space of low-level features
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`

### Implementation details

[dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py) includes the 11 core implementation details:



## `dqn_atari.py`

@@ -90,7 +77,7 @@ with the Bellman update target is $y = r + \gamma \, Q^{'}(s', a')$ and the repl
- `dqn_atari.py` uses `--total-timesteps=10000000` (i.e., 10M timesteps = 40M frames because of frame-skipping) whereas (Mnih et al., 2015)[^1] uses `--total-timesteps=50000000` (i.e., 50M timesteps = 200M frames) (See "Training details" under "METHODS" on page 6 and the related source code [run_gpu#L32](https://github.com/deepmind/dqn/blob/9d9b1d13a2b491d6ebd4d046740c511c662bbe0f/run_gpu#L32), [dqn/train_agent.lua#L81-L82](https://github.com/deepmind/dqn/blob/9d9b1d13a2b491d6ebd4d046740c511c662bbe0f/dqn/train_agent.lua#L81-L82), and [dqn/train_agent.lua#L165-L169](https://github.com/deepmind/dqn/blob/9d9b1d13a2b491d6ebd4d046740c511c662bbe0f/dqn/train_agent.lua#L165-L169)).
- `dqn_atari.py` uses `--end-e=0.01` (the final exploration epsilon) whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses `--end-e=0.1`.
- `dqn_atari.py` uses `--exploration-fraction=0.1` whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses `--exploration-fraction=0.02` (both correspond to 250000 steps, or 1M frames, being the point at which epsilon is annealed to `--end-e=0.1`).
- `dqn_atari.py` treats termination and truncation the same way due to the gym interface[^2] whereas (Mnih et al., 2015)[^1] correctly handles truncation.
- `dqn_atari.py` handles truncation and termination properly like (Mnih et al., 2015)[^1] by using SB3's replay buffer's `handle_timeout_termination=True`.
1. `dqn_atari.py` uses a self-contained evaluation scheme: `dqn_atari.py` reports the episodic returns obtained throughout training, whereas (Mnih et al., 2015)[^1] is trained with `--end-e=0.1` but reported episodic returns using a separate evaluation process with `--end-e=0.01` (See "Evaluation procedure" under "METHODS" on page 6).
1. `dqn_atari.py` rescales the gradient so that its norm does not exceed `0.5`, as done in PPO (:material-github: [ppo2/model.py#L102-L108](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L102-L108)); a minimal sketch of this update follows the list.
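
For reference, the update described in these notes condenses to a few lines. The snippet below is a minimal sketch, assuming the fields returned by SB3's `ReplayBuffer.sample()` (`observations`, `actions`, `rewards`, `next_observations`, `dones`) and a hypothetical `dqn_update` helper; the actual scripts inline this logic in the training loop rather than factoring it out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dqn_update(q_network, target_network, optimizer, rb, batch_size, gamma, max_grad_norm=0.5):
    """One gradient step on a batch sampled from the SB3 replay buffer.

    With `handle_timeout_termination=True`, the buffer stores `done=0` for
    time-limit truncations, so `(1 - dones)` only stops bootstrapping on
    true terminations.
    """
    data = rb.sample(batch_size)
    with torch.no_grad():
        # Bellman target: y = r + gamma * max_a' Q'(s', a')
        target_max, _ = target_network(data.next_observations).max(dim=1)
        td_target = data.rewards.flatten() + gamma * target_max * (1 - data.dones.flatten())
    # Q(s, a) for the actions actually stored in the buffer
    old_val = q_network(data.observations).gather(1, data.actions).squeeze()
    loss = F.mse_loss(td_target, old_val)
    optimizer.zero_grad()
    loss.backward()
    # rescale the gradient so its global norm does not exceed `max_grad_norm`
    nn.utils.clip_grad_norm_(q_network.parameters(), max_grad_norm)
    optimizer.step()
    return loss
```
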

Expand Down Expand Up @@ -128,6 +115,93 @@ Tracked experiments and game play videos:
<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Atari-CleanRL-s-DQN--VmlldzoxNjk3NjYx" style="width:100%; height:500px" title="CleanRL DQN Tracked Experiments"></iframe>


## `dqn.py`

The [dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py) has the following features:

* Works with the `Box` observation space of low-level features
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`
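
A quick way to check whether an environment fits these requirements is to inspect its spaces. The snippet below is a small sketch using the classic `gym` API that CleanRL targeted at the time; the exact bounds printed may differ across gym versions.

```python
import gym

env = gym.make("CartPole-v1")
print(env.observation_space)  # a Box space with 4 low-level features (cart position/velocity, pole angle/velocity)
print(env.action_space)       # Discrete(2): push the cart left or right

assert isinstance(env.observation_space, gym.spaces.Box)
assert isinstance(env.action_space, gym.spaces.Discrete)
```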


### Usage

```bash
python cleanrl/dqn.py --env-id CartPole-v1
```


### Explanation of the logged metrics

See [related docs](/rl-algorithms/dqn/#explanation-of-the-logged-metrics) for `dqn_atari.py`.

### Implementation details

The [dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py) shares the same implementation details as [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), except that `dqn.py` runs with different hyperparameters and a different neural network architecture. Specifically,

1. `dqn.py` uses a simpler neural network as follows:
```python
self.network = nn.Sequential(
nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, env.single_action_space.n),
)
```
2. `dqn.py` runs with different hyperparameters:

```bash
python dqn.py --total-timesteps 500000 \
--learning-rate 2.5e-4 \
--buffer-size 10000 \
--gamma 0.99 \
--target-network-frequency 500 \
--max-grad-norm 0.5 \
--batch-size 128 \
--start-e 1 \
--end-e 0.05 \
--exploration-fraction 0.5 \
--learning-starts 10000 \
--train-frequency 10
```
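
For context, `--start-e`, `--end-e`, and `--exploration-fraction` determine the exploration epsilon through the `linear_schedule` helper whose signature appears in the diff hunks above. The snippet below is a minimal reimplementation consistent with that signature, shown only to illustrate how these flags interact:

```python
def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    """Linearly anneal epsilon from `start_e` to `end_e` over `duration` steps, then hold at `end_e`."""
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


# With the flags above, epsilon decays from 1.0 to 0.05 over the first half of
# the 500k total timesteps (exploration fraction 0.5 -> 250k steps), then stays at 0.05.
epsilon = linear_schedule(1.0, 0.05, int(0.5 * 500_000), t=100_000)
print(epsilon)  # ~0.62
```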


### Experiment results

PR :material-github: [vwxyzjn/cleanrl#157](https://github.com/vwxyzjn/cleanrl/pull/157) tracks our effort to conduct experiments, and the reproduction instructions can be found at :material-github: [vwxyzjn/cleanrl/benchmark/dqn](https://github.com/vwxyzjn/cleanrl/tree/master/benchmark/dqn).

Below are the average episodic returns for `dqn.py`.


| Environment | `dqn.py` |
| ----------- | ----------- |
| CartPole-v1 | 471.21 ± 43.45 |
| Acrobot-v1 | -93.37 ± 8.46 |
| MountainCar-v0 | -170.51 ± 26.22 |


Note that DQN has no official benchmark on classic control environments, so we did not include a comparison. That said, our `dqn.py` achieves near-perfect scores in `CartPole-v1` and `Acrobot-v1`; furthermore, it obtains successful runs in the sparse-reward environment `MountainCar-v0`.


Learning curves:

<div class="grid-container">
<img src="../dqn/CartPole-v1.png">

<img src="../dqn/Acrobot-v1.png">

<img src="../dqn/MountainCar-v0.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Classic-Control-CleanRL-s-DQN--VmlldzoxODE4Mjg1" style="width:100%; height:500px" title="CleanRL DQN Tracked Experiments"></iframe>




[^1]: Mnih, V., Kavukcuoglu, K., Silver, D. et al. Human-level control through deep reinforcement learning. Nature 518, 529–533 (2015). https://doi.org/10.1038/nature14236
[^2]: \[Proposal\] Formal API handling of truncation vs termination. https://github.com/openai/gym/issues/2510
[^3]: Hessel, M., Modayil, J., Hasselt, H.V., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M.G., & Silver, D. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.
Binary file added docs/rl-algorithms/dqn/Acrobot-v1.png
Binary file added docs/rl-algorithms/dqn/CartPole-v1.png
Binary file added docs/rl-algorithms/dqn/MountainCar-v0.png