[rllib] Add debug info back to PPO and fix optimizer compatibility #2366
Conversation
    for k, v in iter_extra_fetches.items():
        all_extra_fetches[k] += [v]
        iter_extra_fetches[k].append(v)
    print(i, _averaged(iter_extra_fetches))
This was incorrect before; you want to return the last epoch's stats, not the mean across all epochs.
The previous implementation returned the stats of all epochs, not just an average, so it was no less "incorrect".
The reason you want the last epoch's stats is PPO-specific.
Eh, I'm not sure why returning all epoch stats makes sense in any scenario, unless your epochs were very small and the values noisy.
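For context, a minimal sketch of the behavior being discussed (the `sgd_step` callable and `run_sgd_epochs` helper are hypothetical names, not the actual rllib code): collect the extra fetches per SGD epoch and report only the mean of the last epoch, since earlier epochs describe stale parameters.

```python
from collections import defaultdict

import numpy as np


def run_sgd_epochs(minibatches, sgd_step, num_epochs):
    """Run SGD epochs and return the averaged fetches of the last epoch only.

    `sgd_step` is a hypothetical callable that runs one minibatch update and
    returns a dict of extra fetches (e.g. policy loss, KL divergence).
    """
    last_epoch_stats = {}
    for epoch in range(num_epochs):
        iter_extra_fetches = defaultdict(list)
        for batch in minibatches:
            for k, v in sgd_step(batch).items():
                iter_extra_fetches[k].append(v)
        # Average over the minibatches of this epoch.
        last_epoch_stats = {
            k: float(np.mean(v)) for k, v in iter_extra_fetches.items()}
    # PPO's KL/entropy/loss stats are only meaningful for the final epoch,
    # because the policy has moved during the earlier epochs.
    return last_epoch_stats
```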
richardliaw left a comment
(These comments are from a partial review a couple of days ago.)
python/ray/rllib/agents/ppo/ppo.py
Outdated
            self.local_evaluator, self.remote_evaluators)

        def _train(self):
            def postprocess_samples(batch):
remove
python/ray/rllib/agents/ppo/ppo.py
Outdated
    # Which observation filter to apply to the observation
    "observation_filter": "MeanStdFilter",
    # Debug only: use the sync samples optimizer instead of the multi-gpu one
    "debug_use_simple_optimizer": False,
Why is this debug-only?
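For illustration, a hedged sketch of what such a flag selects between; the two classes below are stand-ins for rllib's sync samples and multi-GPU optimizers, and the constructor arguments are illustrative rather than the exact signatures.

```python
class SyncSamplesOptimizer:
    """Stand-in for the simpler, synchronous sample-collection optimizer."""

    def __init__(self, local_evaluator, remote_evaluators):
        self.local_evaluator = local_evaluator
        self.remote_evaluators = remote_evaluators


class LocalMultiGPUOptimizer:
    """Stand-in for the default multi-GPU optimizer."""

    def __init__(self, local_evaluator, remote_evaluators):
        self.local_evaluator = local_evaluator
        self.remote_evaluators = remote_evaluators


def make_optimizer(config, local_evaluator, remote_evaluators):
    # The flag swaps the default multi-GPU optimizer for the sync samples
    # optimizer; per the PR description this also enables LSTM models.
    if config.get("debug_use_simple_optimizer", False):
        return SyncSamplesOptimizer(local_evaluator, remote_evaluators)
    return LocalMultiGPUOptimizer(local_evaluator, remote_evaluators)
```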
    all_extra_fetches = defaultdict(list)
    num_batches = (
        int(tuples_per_device) // int(self.per_device_batch_size))
    print("== sgd epochs ==")
Can you gate this behind a verbose flag?
(i.e., gate all of the prints in the optimizers behind a verbose flag?)
Let's do this with log levels in the future. For now, we've always printed by default, so this just restores the 0.4 behavior.
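A minimal sketch of the suggested gating (the `verbose` parameter and `_averaged` helper here are illustrative, not rllib APIs): printing stays on by default to match the 0.4 behavior, while the logger call shows the log-level approach mentioned for the future.

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def _averaged(fetches):
    """Illustrative helper: mean of each collected metric."""
    return {k: float(np.mean(v)) for k, v in fetches.items()}


def report_epoch(i, iter_extra_fetches, verbose=True):
    stats = _averaged(iter_extra_fetches)
    if verbose:
        # Default: print, matching the 0.4 behavior restored by this PR.
        print(i, stats)
    else:
        # Future direction suggested in the thread: route through log levels.
        logger.debug("sgd epoch %s: %s", i, stats)
```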
    rewards_plus_v = np.concatenate(
        [rollout["rewards"], np.array([last_r])])
    traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]
    traj["value_targets"] = np.zeros_like(traj["advantages"])
You can use a critic without using GAE, but this does not allow that functionality. Can you add documentation noting this?
Done
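To make the distinction concrete, here is a hedged sketch of the non-GAE path (the `use_critic` flag and `vf_preds` argument are hypothetical additions for illustration, not what the PR implements): without a critic the advantage is just the discounted return and the value targets are zeros, as in the quoted code; with a critic the return becomes the value target and the advantage is the return minus the value prediction.

```python
import numpy as np
import scipy.signal


def discount(x, gamma):
    # Reverse discounted cumulative sum (the usual lfilter trick).
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]


def compute_advantages_no_gae(rewards, last_r, gamma,
                              vf_preds=None, use_critic=False):
    rewards_plus_v = np.concatenate([rewards, np.array([last_r])])
    returns = discount(rewards_plus_v, gamma)[:-1]
    if use_critic:
        # Critic without GAE: advantage = return - V(s); fit V to the return.
        advantages = returns - vf_preds
        value_targets = returns
    else:
        # The quoted code path: no critic, so there is nothing to fit.
        advantages = returns
        value_targets = np.zeros_like(advantages)
    return advantages, value_targets
```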
What do these changes do?
Make PPO work with the sync samples optimizer too (this enables LSTM support), and add back the debug stats.