
Add QR-DQN #13

Merged · 34 commits · Dec 21, 2020

Conversation

@toshikwa (Contributor) commented Dec 8, 2020

Implement QR-DQN

Description

Paper: https://arxiv.org/abs/1710.10044

  • Implement QR-DQN
  • Add quantile_huber_loss to sb3_contrib.common.utils
  • Add documentation for QR-DQN
  • Add benchmark

Context

  • I have raised an issue to propose this change (required)

closes #12

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • The functionality/performance matches that of the source (required for new training algorithms or training-related features).
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have included an example of using the feature (required for new features).
  • I have included baseline results (required for new training algorithms or training-related features).
  • I have updated the documentation accordingly.
  • I have updated the changelog accordingly (required).
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

Benchmark

| Environments | QR-DQN (logs_qrdqn/) | DQN  |
| ------------ | -------------------- | ---- |
| Breakout     | 413 +/- 21           | ~300 |
| Pong         | 20 +/- 0             | ~20  |

| Environments | DQN (logs_qrdqn/) | QR-DQN (logs_qrdqn/) |
| ------------ | ----------------- | -------------------- |
| CartPole     | 386 +/- 64        | 500 +/- 0            |
| MountainCar  | -111 +/- 4        | -107 +/- 4           |
| LunarLander  | 168 +/- 39        | 195 +/- 28           |
| Acrobot      | -73 +/- 2         | -74 +/- 2            |

(Learning-curve plots: Breakout, Pong, Acrobot, CartPole, LunarLander, MountainCar)

(Resolved review comments on sb3_contrib/__init__.py and sb3_contrib/common/utils.py)
@araffin (Member) left a comment

Overall, it looks good already =)

Obviously missing tests, documentation and a benchmark, but it is a good start!

Did you have any issues so far, or things that slowed you down because of how SB3 works, while implementing QR-DQN?
(asking so we know what we can improve in SB3 ;))

@toshikwa (Contributor Author) commented Dec 8, 2020

@araffin
Hi, I have roughly implemented QR-DQN, and there are some points I want to discuss.

1. I named some classes and variables like QuantileNetwork and current_quantile; however, in TQC there are variables like next_z. Could you tell me which names you prefer?

2. I added quantile_huber_loss to sb3_contrib.common.utils. However, if you plan to add IQN (paper) later, tau should be an argument. What do you think?

3. I didn't inherit any of DQN's classes, for simplicity. I'm familiar with RL, but not with software architecture, and I'm not sure whether I should inherit and re-use them. Could you give me any advice?

4. I don't have enough computational resources to benchmark on Atari. Could you help me test the performance once I finish implementing it?
(I only have 16 GB of RAM and I usually store images as a list of LazyFrames, which uses 4 times less memory. Maybe it would be good to add such a feature to reduce memory usage?)

Thanks ;)

@araffin (Member) commented Dec 8, 2020

, however in TQC there are some variables like next_z. Could you tell me what names do you prefer?

current_quantile is much better ;) (please change TQC too so the two match, probably next_quantiles if it is next_z)

I added quantile_huber_loss to sb3_contrib.common.utils. However, if you plan to add IQN(paper) later, tau should be an argument. What do you think?

What is tau exactly? (it would be nice to have a better name; I'm quite new to quantile regression, so everything is not completely clear yet)
Yes, it sounds good if it does not add too much complexity.
(I need to re-read IQN later)

I didn't inherit any classes from DQN's for simplicity.

that's fine ;)
It is part of the design choices we made for SB3.
We favor simplicity and readability over modularity. (even though we try to avoid code duplication when possible)

Could you help me test the performance once I finish implementing it?

yes ;) but please test it first on a simpler env and with a smaller replay buffer size.

I only have 16GB of RAM and I usually store images as a list of LazyFrames, which uses 4 times less memory. Maybe it's better to add such a feature to reduce memory usage?

Look at the zoo and the replay buffer: we have an optimize_memory_usage argument for that ;)

EDIT: I'm not sure about the lazy frames, but if it is simple enough to implement, it would be a good addition ;) (I think it was there in SB2 but not used)

@toshikwa (Contributor Author) commented Dec 8, 2020

please change TQC too so the two match, probably next_quantiles if it is next_z

We favor simplicity and readability over modularity. (even though we try to avoid code duplication when possible)

yes ;) but please test it first on a simpler env and with a smaller replay buffer size.

Sure, thank you for your advice.

what is tau exactly?

We model the quantile function, which is the mapping from the cumulative probability to the quantile value.
tau is the cumulative probabilities, which are fixed at equal intervals in TQC and QR-DQN (not in IQN).
So, cum_p or cum_prob would be a good candidate?
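
For reference, a rough sketch of what such a quantile_huber_loss helper could look like with cum_prob made explicit (an illustration only, not the final sb3_contrib.common.utils implementation; the sum_over_quantiles switch anticipates the discussion further down the thread):

```python
import torch as th

def quantile_huber_loss(current_quantiles: th.Tensor, target_quantiles: th.Tensor,
                        sum_over_quantiles: bool = True) -> th.Tensor:
    """Sketch: current_quantiles (batch, n_quantiles), target_quantiles (batch, n_target_quantiles)."""
    n_quantiles = current_quantiles.shape[-1]
    # cum_prob ("tau"): midpoints of n_quantiles equal intervals in [0, 1]
    cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float32) + 0.5) / n_quantiles
    cum_prob = cum_prob.view(1, -1, 1)
    # Pairwise TD errors, shape (batch, n_quantiles, n_target_quantiles)
    pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
    abs_delta = pairwise_delta.abs()
    # Huber loss with kappa = 1
    huber = th.where(abs_delta > 1, abs_delta - 0.5, 0.5 * pairwise_delta ** 2)
    # Asymmetric weighting by |tau - 1{delta < 0}|
    loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber
    # Sum over the current-quantile dimension (QR-DQN/IQN style) or average everything (TQC style)
    return loss.sum(dim=1).mean() if sum_over_quantiles else loss.mean()
```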

Look at the zoo and the replay buffer: we have an optimize_memory_usage argument for that ;)

I didn't notice it. It seems it stores the observation and the next observation in the same array, which is very efficient 👍

I'm not sure for the lazy frame, but if it is simple enough to implement,

LazyFrames is a list of arrays that is only converted into a (frame-stacked) array when accessed. Because it stores a frame-stacked observation as a list of references to the underlying arrays, it never stores the same frame twice.
It's really simple to implement, but it is a little bit slow.
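
For context, a minimal sketch of the LazyFrames idea (modelled on the classic OpenAI Baselines helper, not SB3 code):

```python
import numpy as np

class LazyFrames:
    """Keep references to the individual frames and only build the stacked array on access."""

    def __init__(self, frames):
        self._frames = list(frames)  # references to shared frames, no copies

    def __array__(self, dtype=None):
        # Materialize the frame stack only when the observation is actually used
        out = np.concatenate(self._frames, axis=-1)
        return out.astype(dtype) if dtype is not None else out

    def __len__(self):
        return len(self._frames)
```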

Thank you for your kind response.

@araffin (Member) commented Dec 8, 2020

tau is the cumulative probabilities, which are fixed at equal intervals in TQC and QR-DQN (not in IQN).

thanks for the refresher =).
So, if I remember correctly, this is the case for QR-DQN because each quantile is a "unit quantile" (in the sense that each of them has the same weight)?

cum_prob would be a good candidate?

yep + a comment ;) (in case of doubt, a more verbose name is always good ;))

I didn't notice it. It seems it stores the observation and the next observation in the same array, which is very efficient +1

yes, it is not as memory efficient as LazyFrames, but it is as fast as a normal buffer and also works without frame stacking, at the cost of some complexity (that's why it is False by default).

It's really simple to implement as this, but it is a little bit slow.

how much slower? (1.2x or ~2x slower)

@toshikwa (Contributor Author) commented Dec 8, 2020

So, if I remember correctly, this is the case for QR-DQN because each quantile is a "unit quantile" (in the sense that each of them has the same weight)?

Yes 👍

how much slower? (1.2x or ~2x slower)

I will test it again in a couple of days.

(Resolved review comments on sb3_contrib/qrdqn/__init__.py, sb3_contrib/qrdqn/policies.py, sb3_contrib/qrdqn/qrdqn.py, sb3_contrib/tqc/tqc.py, and sb3_contrib/common/utils.py)
current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(2)

# Compute Quantile Huber loss
loss = quantile_huber_loss(current_quantiles, target_quantiles) * self.n_quantiles
Member:

Why do you multiply by self.n_quantiles? (that was not done in the original TQC code if I recall correctly... I did not check for QR-DQN yet)

Contributor Author:

We sum over the quantile dimension, which is common in QR-DQN, IQN, and FQF.
I'm not sure why they take the mean in TQC.

Member:

It seems that Dopamine is using the mean loss by default: https://github.com/google/dopamine/blob/master/dopamine/jax/agents/quantile/quantile_agent.py#L91
The same goes for the IQN implementation in Pfrl: https://github.com/pfnet/pfrl/blob/master/pfrl/agents/iqn.py#L211

or the Facebook ReAgent implementation: https://github.com/facebookresearch/ReAgent/blob/master/reagent/training/qrdqn_trainer.py#L157
Maybe that could be a parameter? (and we should check how it affects learning)

@toshikwa (Contributor Author), Dec 10, 2020:

It seems that Dopamine is using the mean loss by default: https://github.com/google/dopamine/blob/master/dopamine/jax/agents/quantile/quantile_agent.py#L91
The same goes for the IQN implementation in Pfrl: https://github.com/pfnet/pfrl/blob/master/pfrl/agents/iqn.py#L211

Their implementations sum over the (current) quantile dimension, so they are the same as mine, aren't they?
Multiplying by n_quantiles amounts to summing over the (current) quantile dimension instead of averaging over it.

EDIT: They are the same as our QR-DQN loss, not as the TQC loss.

Contributor Author:

Maybe we should add an argument like sum_over_quantiles: bool to quantile_huber_loss?
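
For illustration, with such a flag the call in qrdqn.py could look roughly like this (hypothetical signature, in line with the sketch earlier in the thread):

```python
# Hypothetical: sum over the current-quantile dimension inside the helper,
# which is equivalent to the current `quantile_huber_loss(...) * self.n_quantiles`
# when the helper averages over that dimension internally.
loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
```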

Member:

Their implementations sum over the (current) quantile dimension, so they are the same as mine, aren't they?

I will try to check later when I'm rested ;)

Maybe we should add an argument like sum_over_quantiles: bool to quantile_huber_loss?

Probably, yes

@araffin (Member) left a comment

I could finally do the review. LGTM, thanks =)
(the only thing missing is adding the benchmark to the doc)

@araffin changed the title from "WIP: Add QR-DQN" to "Add QR-DQN" on Dec 19, 2020
@toshikwa (Contributor Author) commented:

Thank you so much ;)

the only thing missing is adding the benchmark to the doc

Could you do this for me, or should I update the doc?

@araffin (Member) commented Dec 20, 2020

Could you do this for me, or should I update the doc?

done ;)
I also added a small benchmark (5 seeds, tuned hyperparameters) on classic control tasks.

@Miffyli I'll let you decide whether we need more experiments (even though we match the Intel Coach results) or whether we merge ;)

@Miffyli (Member) commented Dec 20, 2020

@araffin

Visually comparing the results here vs. the original paper, the Breakout and Pong results seem to match up to the 10e6 steps trained, so I would trust this implementation enough to use it as QR-DQN myself, and I believe we can merge it :).

Just two things: the docs should mention what we compared against (Intel Coach and, very roughly, the original paper), and we should also update the zoo instructions once we merge the branch.

@araffin merged commit b30397f into Stable-Baselines-Team:master on Dec 21, 2020
@toshikwa (Contributor Author) commented:

@araffin @Miffyli
Thank you for your kind support ;)

@Squeemos commented:

Hate to post on this very late, but is there a specific reason that we cannot use action_noise in QRDQN? I was looking to use it for exploration rather than an exploration scheduler, and would prefer using action noise over noisy nets for now.

@Miffyli (Member) commented Mar 23, 2022

Hate to post on this very late, but is there a specific reason that we cannot use action_noise in QRDQN? I was looking to use it for exploration rather than an exploration scheduler, and would prefer using action noise over noisy nets for now.

QRDQN supports only Discrete action spaces, and there is no action_noise for Discrete spaces (at least none implemented in SB3 as of writing).

@hh0rva1h commented Jul 4, 2022

My results on 3 seeds (Breakout and Pong), without a smoothing window and with the default hyperparameters (so the exploration fraction is wrong and the evaluation epsilon is the same as for training):
| Environments | QR-DQN (logs_qrdqn/) |
| ------------ | -------------------- |
| Breakout     | 413 +/- 21           |
| Pong         | 20 +/- 0             |

(Learning-curve plots: Breakout, Pong)

Overall, it looks good as it matches the Intel Coach results (+ a nice boost in final performance, ~400 vs ~300 for DQN on Breakout), but as mentioned by @Miffyli it would be nice to have at least one run with the original evaluation setting.

@araffin I tried to reproduce those results, but it did not work out. Which version of the games was used in your benchmark? https://www.gymlibrary.ml/environments/atari/ documents 3 versions (v0, v4, v5), there are a couple of different options for the v5 setting, and there are different variants of the environment for v4 as well. I tried "Breakout-v5" with default options as well as "BreakoutNoFrameskip-v4", but my learning curves do not look anything like yours.
Code for reproduction (I know I could use make_atari_env directly; however, to use my own monitoring wrappers, which are not included in this example, I need to build the environment manually in order to have proper access to the env without the vec wrapper and frame stacking):

```python
import gym
# from rl_utils.monitor import Monitor  # my custom monitor
from sb3_contrib import QRDQN

from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv

env = gym.make('BreakoutNoFrameskip-v4')
# env = gym.make('ALE/Breakout-v5', full_action_space=False, frameskip=1)
env = AtariWrapper(env)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, n_stack=4)
# env = Monitor(env, params, "./results")  # commented out, this is my custom monitor for recording learning curves
# env = make_atari_env(env, n_envs=1, seed=0)

model = QRDQN(
    "CnnPolicy",
    env,
    verbose=1,
    exploration_fraction=0.025,  # match QRDQN PR
    seed=100,
)
model.learn(total_timesteps=10_000_000, log_interval=4)
```

My results looked like the following:

(Screenshots of the resulting learning curves)

I would really appreciate it if you could give me some hints on reproducing your baseline.

@araffin (Member) commented Jul 4, 2022

Hello @hh0rva1h ,

Which version of the games were used in your benchmark?

we always use the NoFrameskip-v4 (and gym 0.21).

I tried to reproduce those results, but it did not work out.

Please use the RL Zoo for that, we have instructions in the documentation:
https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html#how-to-replicate-the-results

Instructions are also on huggingface hub: https://huggingface.co/sb3/qrdqn-BreakoutNoFrameskip-v4

and the detailed hyperparameters are here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/qrdqn.yml#L1

however for usage with my own monitoring wrappers not included in this example

I think your issue may come from the monitoring, please read DLR-RM/stable-baselines3#181.

You should at least use make_vec_env and pass your custom wrappers.
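
For example, a minimal sketch of that pattern (MyMonitor stands in for your own monitoring wrapper and is not an SB3 class):

```python
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack

def wrap(env):
    env = AtariWrapper(env)
    # env = MyMonitor(env)  # your custom monitoring wrapper goes here
    return env

# make_vec_env adds a Monitor wrapper and handles seeding for you
env = make_vec_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, wrapper_class=wrap)
env = VecFrameStack(env, n_stack=4)
```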

@mreitschuster commented Jul 4, 2022

we always use the NoFrameskip-v4 (and gym 0.21).

@araffin any chance you have the reasoning in mind? I thought NoFrameskip-v4 was to be avoided due to the "memorizing sequences" problem (even with the up to 30 no-ops from the Atari wrapper, I suspect it just memorizes 31 sequences).

@araffin (Member) commented Jul 4, 2022

any chance you have the reasoning in mind?

You can read more about why in the PR on updating gym:

TL;DR: the new Atari v5 environments were not yet benchmarked, and as we have everything using NoFrameskip-v4, it is best to keep it like that so we can compare apples to apples.

due to the "memorizing sequences" problem

Do you have a reference for that?

@mreitschuster commented Jul 4, 2022

@araffin: I would have concluded that from "Revisiting the Arcade Learning Environment", Machado et al.:
page 7: "The Arcade Learning Environment is also fully deterministic – each game starts in the same state and outcomes are fully determined by the state and the action. As such, it is possible to achieve high scores by learning an open-loop policy, i.e., by simply memorizing a good action sequence, rather than learning to make good decisions in a variety of game scenarios (Bellemare, Naddaf, Veness, & Bowling, 2015)."
page 14 (Initial no-ops): "The environment remains deterministic beyond the choice of starting state. Brute-like methods still perform well."

"Brute-like" refers to methods that rely purely on memorization.

Of course it is not memorizing a sequence (as the action always depends only on the previous state), but I imagine it as only being able to properly recognize states that come from that exact playbook.

And to verify: train a model using NoFrameskip-v4 and evaluate it on v4 (stochasticity via frameskip) or v5 (stochasticity via sticky actions) -> it will transfer poorly.
And the other way round: train on v4 or v5 and evaluate on NoFrameskip-v4 -> good transfer.

I imagine a more sophisticated approach would be to use NoFrameskip-v4, randomly take some set of integers A from B = {0, ..., 30} for the no-ops, train the model only on B \ A, and evaluate it on A, i.e. use the selected no-ops in A only in evaluation but not in training. If the model performs poorly on those no-op starts on which it was not trained -> we know it just memorized.

When trying to reproduce good results on Breakout, I ran into exactly that issue: it is simple and straightforward to produce high scores with NoFrameskip-v4 (even with no-ops), but with v4 or v5 it becomes much harder.

@hh0rva1h commented Jul 4, 2022

@mreitschuster I read about the memorization issue as well. Do you by any chance have baseline curves at hand for the v4 or v5 scenarios? It would be really great to have more baselines apart from NoFrameskip-v4.
Also, do you by any chance know how one would exactly match the Atari setting used in the DQN and QR-DQN papers? I could not extract much info about this from the original papers. From skimming over the code, I guess using the vanilla gym wrapper https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py would be closer to the paper setup than the baselines wrappers, would you agree?
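
For reference, a rough sketch of what using the vanilla gym preprocessing could look like (the defaults shown are my understanding of gym's AtariPreprocessing, which follows the Machado et al. recommendations rather than the exact original DQN paper setup; please double-check):

```python
import gym
from gym.wrappers import AtariPreprocessing, FrameStack

env = gym.make("BreakoutNoFrameskip-v4")
# Vanilla gym preprocessing: up to 30 no-ops, frame skip of 4, 84x84 grayscale observations
env = AtariPreprocessing(env, noop_max=30, frame_skip=4, screen_size=84,
                         terminal_on_life_loss=False, grayscale_obs=True)
env = FrameStack(env, num_stack=4)
```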
@araffin Thanks for the pointers, will look into it and come back to you.

@mreitschuster commented Jul 4, 2022

@hh0rva1h

Do you by any chance have baseline curves at hand for the v4 or v5 scenarios? It would be really great to have more baselines apart from NoFrameskip-v4.

Not with QR-DQN (yet), I am working on it. With PPO I have quite a few, but mostly with a non-standard wrapper configuration. But that would go off-topic from this PR; you can have a look at my work on tuning & env selection. Short answer: without an aimbot I get to a score of 180, and with an aimbot to 220 (on 1e7 training steps). I haven't found a PM functionality in GitHub - if you want a discussion on that, feel free to open an issue/discussion and add me there.

Also, do you by any chance know how one would exactly match the Atari setting used in the DQN and QR-DQN papers? ... would be closer to the paper setup than the baselines wrappers, would you agree?

Sorry, no, I don't know. I haven't dived too deep into QR-DQN yet. I was just testing models on Breakout and wanted to check my results against others' - QR-DQN looked promising, showing scores of 400, but after realizing it was in the deterministic environment, the excitement level dropped.

@araffin (Member) commented Jul 4, 2022

Thanks @mreitschuster for the links!

stochasticity via frameskip

The RL Zoo models are using frame skipping (via the MaxAndSkip wrapper), but the frame skipping is not stochastic; is that what you meant? (see https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/atari_wrappers.py#L223)

train a model using NoFrameskip-v4 and evaluate on v4

You need to be careful when you do that because, as mentioned before, if you use the Atari wrapper, there is frame skipping.

stochasticity via stickies

by sticky actions you mean randomly repeating actions?

@mreitschuster commented Jul 4, 2022

@araffin

The RL Zoo models are using frame skipping (via the MaxAndSkip wrapper), but the frame skipping is not stochastic; is that what you meant?

It is my understanding that we have two different kinds of frameskip available: stochastic frameskip (randomly skipping frames), which is built into *-v4 (but deactivated when using *NoFrameskip-v4); it is also not active in v5 (which injects stochasticity differently).

And then there is the deterministic frameskip, found in the SB3 AtariWrapper as well as in the *Deterministic-v4 environments.
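
To make the distinction concrete, a small sketch of the Breakout variants being discussed (the defaults are my understanding of the gym/ALE registry; please double-check against your installed versions):

```python
import gym

# No built-in frameskip and no sticky actions (deterministic dynamics);
# the SB3 AtariWrapper then adds a fixed frameskip of 4 via MaxAndSkipEnv.
env_noframeskip = gym.make("BreakoutNoFrameskip-v4")

# Stochastic frameskip built into the env: each action is repeated for a
# randomly chosen number of frames, no sticky actions.
env_random_skip = gym.make("Breakout-v4")

# Fixed frameskip of 4 built into the env (no stochasticity from skipping).
env_deterministic = gym.make("BreakoutDeterministic-v4")

# ALE v5: fixed frameskip of 4 plus sticky actions
# (repeat_action_probability=0.25 by default) as the source of stochasticity.
env_sticky = gym.make("ALE/Breakout-v5")
```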

You need to be careful when you do that because, as mentioned before, if you use the Atari wrapper, there is frame skipping.

Yes, you are right, but the wrapper provides deterministic frameskip (for speeding up the game), not the stochastic one (to train more transferable skills).

by sticky actions you mean randomly repeating actions?

yes.

I think if we want to deepen this (and I would be very happy to), it should be a separate discussion.
