Add C51 algorithm #266
Codecov Report
@@            Coverage Diff             @@
##           master     #266      +/-   ##
==========================================
- Coverage   94.54%   93.98%   -0.57%
==========================================
  Files          41       42       +1
  Lines        2677     2760      +83
==========================================
+ Hits         2531     2594      +63
- Misses        146      166      +20
Nice job! Could you please also modify the
That's ok I can help you fix the PEP8.
Thank you very much. I will modify the README.md and docs/index.rst in my next PR.
Just in this PR is okay.
I'm not very good at GitHub, and I could not find how to do it in this PR.
Okay, that's fine :) I'll take a look later on.
All checks have passed :) It was a tough journey.
Also, could you please add a test script for CartPole-v0 under test/discrete/?
I think you can directly add this file. No need to make a separate PR.
I hope to combine it with the results of C51 on a variety of Atari games in a future PR.
Cool, and I think you can add these results here so that this PR can be a complete version of the C51 implementation :)
I'm not sure how to add new files under the current PR, so opening a new PR is more convenient for me. In addition, I'm not quite sure when I can finish this work.
Just add the file in shengxiang19/C51 branch instead of here. See https://stackoverflow.com/questions/10147445/github-adding-commits-to-existing-pull-request
I can wait for you.
Thank you. I can try it later.
fix bugs in pc51
I have added a test_c51 for CartPole-v0 under test/discrete/. I hope you can help me check it.
I have added the results of C51 in three typical Atari environments. My current plan for C51 is done.
I'll optimize the n_step code this week (in this PR). Thanks for your great work!
I'm running QbertNoFrameskip-v4 for evaluation. Now in epoch 31 it reaches 14047.
Please double-check my implementation.
"""
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
Why could hasattr(obs, "obs") be false?
These three are the same as existing DQNPolicy. I guess we can make a separate PR to enhance these things :)
Yes I noticed that :)
dist, h = model(obs_, state=state, info=batch.info)
q = (dist * self.support).sum(2)
act: np.ndarray = to_numpy(q.max(dim=1)[1])
if hasattr(obs, "mask"):
I don't like this approach much, but right now I have no idea how to avoid it. Maybe we could add a masked_array method to the Batch class to offer something similar to numpy's masked arrays. Internally it would use the same mechanism, but it would be hidden inside Batch, which is much better in my opinion.
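To make the masked-array suggestion concrete, here is a minimal sketch in plain numpy of greedy action selection over a categorical return distribution with an optional availability mask. It is not the code in this PR and not an existing Batch method: the function name greedy_action and the convention that True in the mask means "action allowed" are assumptions made for illustration.

```python
import numpy as np


def greedy_action(dist, support, mask=None):
    """Pick greedy actions from a categorical return distribution.

    dist:    (batch, n_actions, n_atoms) atom probabilities
    support: (n_atoms,) atom values
    mask:    optional (batch, n_actions) bool array, True = action allowed
             (hypothetical convention chosen for this sketch)
    """
    # Expected Q-value: atom values weighted by their probabilities.
    q = (dist * support).sum(axis=2)
    if mask is not None:
        # numpy.ma marks True entries as invalid, so invert the availability mask.
        q = np.ma.masked_array(q, mask=~mask)
    # For masked arrays, argmax ignores masked entries by filling them with
    # the smallest representable value before comparing.
    return np.asarray(q.argmax(axis=1))


support = np.array([0.0, 1.0, 2.0])
dist = np.array([
    [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]],  # Q = [0.0, 2.0, 1.0]
    [[0.0, 1.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]],  # Q = [1.0, 0.5, 2.0]
])
mask = np.array([[True, False, True], [True, True, False]])
print(greedy_action(dist, support))        # unmasked choice
print(greedy_action(dist, support, mask))  # best action among the allowed ones
```

Hiding the inversion and the fill value behind one helper is the point of the suggestion: callers would pass a mask and never manipulate the Q-values themselves.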
batch.weight = cross_entropy.detach()  # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
I recommend using explicit variable names instead of _cnt.
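For context on the snippet above, the following is a rough numpy sketch of how a per-sample cross-entropy can serve both as the training loss and as the new priorities for a prioritized replay buffer. It only illustrates the idea; the function name cross_entropy_loss, the sample_weight argument and the eps constant are assumptions, and the PR's actual loss code may differ.

```python
import numpy as np


def cross_entropy_loss(target_dist, pred_dist, sample_weight=None, eps=1e-8):
    """Cross-entropy between target and predicted atom probabilities.

    target_dist, pred_dist: (batch, n_atoms) probability vectors
    sample_weight:          optional (batch,) importance weights (assumed input)
    Returns the scalar loss and the per-sample values, which a prioritized
    buffer could store as updated priorities.
    """
    # -sum_i m_i * log p_i for every sample; eps guards against log(0).
    per_sample = -(target_dist * np.log(pred_dist + eps)).sum(axis=1)
    if sample_weight is None:
        sample_weight = np.ones_like(per_sample)
    loss = (per_sample * sample_weight).mean()
    return loss, per_sample


target = np.array([[1.0, 0.0], [0.5, 0.5]])
pred = np.array([[0.5, 0.5], [0.5, 0.5]])
loss, priorities = cross_entropy_loss(target, pred)
print(loss, priorities)
```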
This is the PR for the C51 algorithm: https://arxiv.org/abs/1707.06887
1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.
By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get best_result: 20.50 ± 0.50 in epoch 9.
By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
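For readers new to C51, the central step described in the linked paper is the categorical projection of the Bellman-updated distribution back onto a fixed support. The numpy sketch below illustrates that step under stated assumptions; it is not the implementation added by this PR, and the function name project_distribution and its argument layout are chosen only for this example.

```python
import numpy as np


def project_distribution(next_dist, rewards, dones, support, gamma=0.99):
    """Project r + gamma * z onto the fixed atoms in `support`.

    next_dist: (batch, n_atoms) probabilities of the chosen next action
    rewards:   (batch,) float rewards
    dones:     (batch,) 1.0 for terminal transitions, else 0.0
    support:   (n_atoms,) evenly spaced atom values from v_min to v_max
    """
    v_min, v_max = support[0], support[-1]
    n_atoms = support.shape[0]
    delta_z = (v_max - v_min) / (n_atoms - 1)

    # Shift and shrink every atom, then clip to the support range.
    tz = rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :]
    tz = np.clip(tz, v_min, v_max)

    # Fractional index of each shifted atom on the original grid.
    b = (tz - v_min) / delta_z
    lower = np.clip(np.floor(b).astype(int), 0, n_atoms - 1)
    upper = np.minimum(lower + 1, n_atoms - 1)
    w_upper = b - lower        # share of the mass sent to the upper neighbour
    w_lower = 1.0 - w_upper    # equals 1 when b falls exactly on an atom

    target = np.zeros_like(next_dist)
    rows = np.broadcast_to(np.arange(next_dist.shape[0])[:, None], lower.shape)
    # np.add.at accumulates correctly when several atoms map to the same index.
    np.add.at(target, (rows, lower), next_dist * w_lower)
    np.add.at(target, (rows, upper), next_dist * w_upper)
    return target


support = np.linspace(-1.0, 1.0, 3)
next_dist = np.array([[0.0, 1.0, 0.0], [0.2, 0.3, 0.5]])
target = project_distribution(next_dist, np.array([0.5, 1.0]),
                              np.array([0.0, 1.0]), support, gamma=1.0)
print(target)  # each row still sums to 1
```

Splitting the mass with w_lower = 1 - w_upper, rather than with the two distances to floor and ceil, avoids losing probability when a shifted atom lands exactly on a grid point.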
Distributional RL algorithms are very powerful in Atari environments. I am going to implement a series of typical algorithms, i.e. C51, QR-DQN, IQN, FQF, based on the reinforcement learning platform Tianshou.
This is my first PR, for the C51 algorithm: https://arxiv.org/abs/1707.06887