Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Parallel Q-Networks algorithm (PQN) #472

Merged
merged 10 commits into from
Nov 14, 2024
Merged

Conversation

roger-creus
Copy link
Collaborator

@roger-creus roger-creus commented Jul 17, 2024

Description

Adding PQN from Simplifying Deep Temporal Difference Learning

I have implemented both pqn.py and pqn_atari_envpool.py. The results are promising for the Cartpole version. Check them out here. I am now running some debugging experiments for the Atari version.

Some details about the implementations:

  • Both use envpool
  • Hyperaprameters try to match the configs from the official implementations but some are changed (epsilon-decay schedule matches the DQN implementation from cleanRL. I haven't checked the importance of the hyperparameter the defaults in CleanRL made more sense to me)
  • For comparing pqn.py and dqn.py in cartpole I multiplied the rewards from the environment by 0.1 as done in the official implementation of PQN. performance increases for both algos.
  • Using LayerNorm in the networks instead of allowing the user to select between Layer or Batch norm. Layer norm should work better.
  • Not giving the user the option to add BatchNorm to the inputs to the network (i.e. states) as in the official implementaiton.

Overall the implementation is similar to ppo with envpool (so very fast!) but with the sample-efficiency of Q-learning! Nice algorithm! :)

Let me know how to proceed!

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the tests accordingly (if applicable).
  • I have updated the documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers.

If you need to run benchmark experiments for a performance-impacting changes:

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team.
  • I have used the benchmark utility to submit the tracked experiments to the openrlbenchmark/cleanrl W&B project, optionally with --capture_video.
  • I have performed RLops with python -m openrlbenchmark.rlops.
    • For new feature or bug fix:
      • I have used the RLops utility to understand the performance impact of the changes and confirmed there is no regression.
    • For new algorithm:
      • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves generated by the python -m openrlbenchmark.rlops utility to the documentation.
    • I have added links to the tracked experiments in W&B, generated by python -m openrlbenchmark.rlops ....your_args... --report, to the documentation.

Copy link

vercel bot commented Jul 17, 2024

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name Status Preview Comments Updated (UTC)
cleanrl ✅ Ready (Inspect) Visit Preview 💬 Add feedback Nov 14, 2024 3:11pm

@sdpkjc
Copy link
Collaborator

sdpkjc commented Jul 17, 2024

Hey Roger, it's really cool to see you adding PQN to CleanRL! I've read the paper before, and I think your implementation is great. When it comes time to run benchmarks or add documentation, let's collaborate to see how we can best do it. Looking forward to seeing the completed PR! 🚀👍

@roger-creus
Copy link
Collaborator Author

roger-creus commented Jul 17, 2024

I think the code might be ready to be benchmarked. These are some results in Breakout. It seems to converge to 400 score in 10M which would match DQN. The official imlpementation reports 515 score after 400M steps. Should I be added to the openrlbenchmark W&B team?

image

@sdpkjc
Copy link
Collaborator

sdpkjc commented Jul 18, 2024

I think the code might be ready to be benchmarked. These are some results in Breakout. It seems to converge to 400 score in 10M which would match DQN. The official imlpementation reports 515 score after 400M steps. Should I be added to the openrlbenchmark W&B team?

image

I noticed that the epsilon greedy implementation in our current setup differs from the official one, where each environment independently performs epsilon greedy exploration, whereas in our implementation, all environments share a single random number. This might have an impact when running many environments in parallel. Of course, there could be other reasons for the performance differences too. Let's start by running some benchmark tests to see if the performance also falls short in other environments. Looking forward to working through this together!

https://github.com/mttga/purejaxql/blob/9878a74439593c5d0acc8e506fefc44daa230c51/purejaxql/pqn_atari.py#L312-L325

… some envs can explore and some exploit, like in the official implementation
@roger-creus
Copy link
Collaborator Author

Very nice catch! Let me try to set up the benchmark experiments :)

@roger-creus
Copy link
Collaborator Author

Here are some first results!
I think they look pretty good but maybe in BeamRider-v5 is falling a bit short. Let me double check the implementation and run some more experiments

@vwxyzjn
Copy link
Owner

vwxyzjn commented Jul 19, 2024

Been watching this from far, very cool work!!

@pseudo-rnd-thoughts
Copy link
Collaborator

pseudo-rnd-thoughts commented Jul 19, 2024

Nice job, your results show it takes 25 minutes for 10 million frames while the paper reports 200 million in an hour.
Do you know why there are such significant differences in performance?

No equivalent to jax.jit or jax.lax.scan?

@roger-creus
Copy link
Collaborator Author

Updated results here. I wonder how should I generate the comparison between DQN/PQN with the rlops function since I am using envpool and I am not being able to compare pqn_atari_envpool.py in Breakout-v5 vs dqn_atari.py in BreakoutNoFrameskip-v4 for instance. Should I make a version of PQN that doesn't use envpool?

@pseudo-rnd-thoughts It is probably because jax.lax.scan. I am not used to coding in jax, but after searching online, it seems pytorch does not have a function like scan...

@sdpkjc
Copy link
Collaborator

sdpkjc commented Jul 20, 2024

Maybe try torch.compile?

@roger-creus
Copy link
Collaborator Author

Hey! How do you think we should proceed?

I believe that it will be hard to match the speed of the JAX-based original implementation in this torch implementation, but at least it provides a Q-learning alternative + envpool that matches CleanRL envpool PPO, which can already be very useful! :)

@roger-creus
Copy link
Collaborator Author

I realized I was re-computing the values for each state in the rollouts when computing Q(lambda) returns. I have now used a values buffer (as used in the PPO implementations actually) and replaced the computations in the Q(lambda) process. Performance remains the same and the code is now approx 150% faster.

Also, I added pqn_atari_lstm_envpool! First results in the atari environments show the implementation is correct. Please double check! :)

Please let me know how we should continue!

@pseudo-rnd-thoughts
Copy link
Collaborator

@roger-creus There is a larger issue of EnvPool with rollouts and computing the loss function, see #475

@roger-creus
Copy link
Collaborator Author

roger-creus commented Oct 23, 2024

@vwxyzjn Hey! Following up :) What more do you think should be done for this PR?
See the results here

@vwxyzjn
Copy link
Owner

vwxyzjn commented Oct 23, 2024

Hey @roger-creus thanks for the ping and the nice RP! Could you add documentation like the other algorithms?

Copy link
Owner

@vwxyzjn vwxyzjn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments on style. Overall we should try to reduce the lines of code differences against a reference file (e.g., pqn.py and ppo.py or ppo_atari_envpool.py and pqn_atari_envpool.py)

cleanrl/dqn.py Outdated Show resolved Hide resolved
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.99
"""the discount factor gamma"""
num_minibatches: int = 32
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number of minibatches seems quite high. Where is this coming from? When comparing to other parameters I noticed quite some difference as well.

image

It's also inconsistent with the pqn_atari_lstm_envpool
image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took the PQN hyperparameters from the official implementation. Overall, I have found num_envs=128, num_steps=32, and num_minibatches=32 to be a great combination of both speed and policy rewards (8.2k FPS in my setup compared to the current ppo_atari_envpool for which I get 2.7k FPS).

I am running some experiments to evaluate PQN with the 2 different configurations:

  • num_envs=128, num_steps=32, num_minibatches=32
    vs
  • num_envs=8, num_steps=128, num_minibatches=4

and will update you with the results. Maybe in a future PR I can do the same for PPO?

The inconsistencies with pqn_lstm will be fixed!

Copy link
Collaborator Author

@roger-creus roger-creus Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually for now I will just leave it to the current defaults of CleanRL: num_envs=8, num_steps=128, num_minibatches=4 and once this PR is merged I can make a new one to decide the best hyperparameters for both PQN and PPO (envpool versions).

Comment on lines 115 to 128
nn.Conv2d(4, 32, 8, stride=4),
nn.LayerNorm([32, 20, 20]),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2),
nn.LayerNorm([64, 9, 9]),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1),
nn.LayerNorm([64, 7, 7]),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512),
nn.LayerNorm(512),
nn.ReLU(),
nn.Linear(512, env.single_action_space.n),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe do the layer_init to PQN as well like in PPO?

cleanrl/pqn_atari_lstm_envpool.py Outdated Show resolved Hide resolved
cleanrl/pqn.py Outdated Show resolved Hide resolved
… PQN Lstm. Ran validation experiments. Reverting undesired changes in DQN. WIP documentation
@roger-creus
Copy link
Collaborator Author

Thanks for your review @vwxyzjn !

  • pqn.py now uses gymnasium instead of envpool (I ran some experiments in CartPole and it continues to work).
  • pqn_atari_envpool.py and pqn_atari_envpool_lstm.py now have the same rollout hyperparameters as PPO.
  • Validation experiments are successful -- they match PPO. See that PQN-LSTM experiments are now running.
  • I added a documentation page but I am unsure how to add the plots/curves/tables for performance with RLOps. I tried but got an error, maybe because the logs are currently in my WandB account.

Let me know how to proceed! :)

cleanrl/pqn.py Outdated
@@ -0,0 +1,247 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be updated

@vwxyzjn
Copy link
Owner

vwxyzjn commented Oct 30, 2024

Hi @roger-creus sorry the slow response. Lots of going on lately. Could you clone the openrlbenchmark repo and run the following? Namely:

git clone https://github.com/openrlbenchmark/openrlbenchmark.git
poetry install
python -m openrlbenchmark.rlops \
    --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'pqn_atari_envpool?tag=pr-472&cl=CleanRL PQN' \
    --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \
    --no-check-empty-runs \
    --pc.ncols 3 \
    --pc.ncols-legend 3 \
    --rliable \
    --rc.score_normalization_method maxmin \
    --rc.normalized_score_threshold 1.0 \
    --rc.sample_efficiency_plots \
    --rc.sample_efficiency_and_walltime_efficiency_method Median \
    --rc.performance_profile_plots  \
    --rc.aggregate_metrics_plots  \
    --rc.sample_efficiency_num_bootstrap_reps 10 \
    --rc.performance_profile_num_bootstrap_reps 10 \
    --rc.interval_estimates_num_bootstrap_reps 10 \
    --output-filename static/0compare \
    --scan-history

I can't do the plots because I don't have access to the cleanrl project. Could you make it public? Thanks.

image

@vwxyzjn
Copy link
Owner

vwxyzjn commented Oct 30, 2024

I tested it with a new install and openrlbenchmark seems to work fine with the demo command: https://github.com/openrlbenchmark/openrlbenchmark?tab=readme-ov-file#get-started

image

@roger-creus
Copy link
Collaborator Author

Thanks for the reply @vwxyzjn ! I have updated the documentation, added the plots, etc. I have checked the changes locally and I think it looks fine! All tests pass, etc.

This PR is pretty much ready in my opinion, but let me know if you would like some additional details or experiments! :)

@roger-creus
Copy link
Collaborator Author

@vwxyzjn Just following up :) Let me know if you'd like anything else in this PR

@vwxyzjn
Copy link
Owner

vwxyzjn commented Nov 14, 2024

This is an amazing PR. Merging as is. Thanks for the great work @roger-creus!!! I am also sent you an invite as a collaborator for CleanRL. Feel free to merge the PR.

@roger-creus roger-creus merged commit e648ee2 into vwxyzjn:master Nov 14, 2024
38 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants