Added SUM as aggregation type for custom statistics #4816
Overall, I think this is a good change and I can see how it would be useful. I have some feedback on the particular implementation though. There are two (mostly mutually exclusive) ways to go from here:
Option 1
Instead of renaming StatsSummary.mean, you could add a sum field, and also add an aggregation_method field. Then in the StatsWriter implementations (for example
    self.summary_writers[category].add_scalar(f"{key}", value.mean, step)
) you could write either mean or sum.
You could also add an aggregated_value property to StatsSummary like
    @property
    def aggregated_value(self):
        # Pick the field that matches the stat's configured aggregation method.
        return self.sum if self.aggregation_method == StatsAggregationMethod.SUM else self.mean
Option 2
Add a method to StatsReporter like increment_stat, and call that from record_environment_stats() when agg_type == SUM. That would look something like
    def increment_stat(self, key: str, value: float) -> None:
        with StatsReporter.lock:
            if StatsReporter.stats_dict[self.category][key]:
                # Add to the last stat. If we're always using increment_stat, this
                # should be the only element in the list.
                StatsReporter.stats_dict[self.category][key][-1] += value
            else:
                # New list with the value as the only element
                StatsReporter.stats_dict[self.category][key] = [value]
(but make sure you test it 😃 )
I think I prefer the second approach because it's more consistent with the existing idea of "aggregating" in the StatsReporter and it doesn't require storing the AggregationMethod, but the first approach is more flexible (especially since we're planning to let users add their own StatsWriters soon, see the plugins PR).
Do you have a preference for either one?
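For anyone who wants to try the Option 2 idea in isolation, here is a minimal sketch. MiniStatsReporter is a hypothetical stand-in, not the real StatsReporter: it only assumes a shared, lock-protected stats_dict keyed by category and stat name, as in the snippet above. It shows that repeated increment_stat calls keep a single running total, while an append-style add_stat keeps every value.

```python
import threading
from collections import defaultdict
from typing import Dict, List


class MiniStatsReporter:
    """Hypothetical stand-in for StatsReporter; only what increment_stat needs."""

    # Shared across instances: category -> key -> list of recorded values.
    stats_dict: Dict[str, Dict[str, List[float]]] = defaultdict(
        lambda: defaultdict(list)
    )
    lock = threading.Lock()

    def __init__(self, category: str) -> None:
        self.category = category

    def add_stat(self, key: str, value: float) -> None:
        # Averaging path: keep every value so a mean can be taken later.
        with MiniStatsReporter.lock:
            MiniStatsReporter.stats_dict[self.category][key].append(value)

    def increment_stat(self, key: str, value: float) -> None:
        # Summing path: fold each value into a single running total.
        with MiniStatsReporter.lock:
            values = MiniStatsReporter.stats_dict[self.category][key]
            if values:
                values[-1] += value
            else:
                values.append(value)


reporter = MiniStatsReporter("demo")
for v in (1.0, 2.0, 4.0):
    reporter.increment_stat("summed", v)
    reporter.add_stat("averaged", v)

summed = MiniStatsReporter.stats_dict["demo"]["summed"]
averaged = MiniStatsReporter.stats_dict["demo"]["averaged"]
```

With this layout the summed key always holds a one-element list, so taking the mean of that list in the existing summary code would return the total unchanged.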
@@ -35,6 +38,7 @@ def __init__(self) -> None:
     def on_message_received(self, msg: IncomingMessage) -> None:
         """
         Receive the message from the environment, and save it for later retrieval.
+
Please remove the extra whitespace that was added in the docstrings.
The extra spaces are necessary for PyCharm.
Please try to find the settings to disable them; I use PyCharm as do several other team members, and this hasn't been a problem before. At a minimum, you need to revert files that only contain the whitespace changes:
- ml-agents/mlagents/trainers/ppo/trainer.py
- ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
- ml-agents/mlagents/trainers/trainer/rl_trainer.py
On the other hand, if you know of a way to automate this from the command line (and enforce that the :param ... names match the args), I'd love that in another PR.
@@ -272,8 +272,8 @@ def test_agent_manager_stats():
     manager.record_environment_stats(env_stats, worker_id=0)

     expected_stats = {
-        "averaged": StatsSummary(mean=2.0, std=mock.ANY, num=2),
-        "most_recent": StatsSummary(mean=4.0, std=0.0, num=1),
+        "averaged": StatsSummary(stats_value=2.0, std=mock.ANY, num=2),
You should also extend this test to include StatsAggregationMethod.SUM
New commit to extend it
ml-agents/mlagents/trainers/stats.py
    log_info.append(f"Std of Reward: {stats_summary.std:0.3f}")
    log_info.append(f"Num of Reward: {stats_summary.num:0.3f}")
I don't think you should add this; it's not useful for training.
For me it was. If you have a Decision Request script attached, the expected value may change.
In my case, MaxStep = 3000, DecisionRequest = 6, but I wasn't sure how this works at that time.
My custom stat was 0.3 every step.
0.3 * 3000 = 900
but instead I was getting 150, and I didn't know why.
Only after I added "Num of Reward" did I notice that Num was 500.
That is because 3000 / 6 = 500.
So,
0.3 * 3000 / 6 = 150.
Because I didn't know how DecisionRequest worked, it took me a while to figure this out.
Having "Num of Reward" would have saved me a few hours, I think.
Well, that was my experience. I won't mind if you want to remove it.
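The arithmetic above can be replayed in a few lines. This is only a sketch of the numbers quoted in this comment, assuming (as the comment concludes) that the stat is recorded once per decision and not once per environment step. The variable names are made up for illustration.

```python
import math

max_step = 3000       # episode length in environment steps
decision_period = 6   # steps between decisions (the "DecisionRequest = 6" above)
stat_value = 0.3      # custom stat recorded each time

# Assumption: one record per decision, so 3000 / 6 records per episode.
num_records = max_step // decision_period
values = [stat_value] * num_records

expected_if_every_step = stat_value * max_step  # the 900 that was expected
observed_sum = math.fsum(values)                # the 150 that was observed
```

Seeing num_records next to the sum is what makes the factor of 6 obvious.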
Yeah, I'd like you to remove it. You'll be able to add your own StatsWriter that gets called soon.
I removed this (and updated the tests)
@@ -75,17 +76,24 @@ def test_run_training(
     learn.run_training(0, options)
     mock_init.assert_called_once_with(
         trainer_factory_mock.return_value,
-        "results/ppo",
+        os.path.join("results", "ppo"),
I don't think this is a bad change, but does it need to happen in this PR? Does this test fail on a certain platform (Windows?) without it?
Yes, it fails on Windows (I am using Windows).
OK. I'll make a note for us to run pytest on Windows too.
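A small sketch of why the hard-coded string does not match on Windows. It uses the standard-library ntpath and posixpath modules, which are the Windows and POSIX implementations behind os.path, so the difference can be seen on any platform. This illustrates the path separator only; it is not the actual test from the PR.

```python
import ntpath
import posixpath

# What os.path.join would produce on each platform.
windows_path = ntpath.join("results", "ppo")
posix_path = posixpath.join("results", "ppo")

# The literal that the test asserted against before the change.
hardcoded = "results/ppo"

# Only the POSIX join matches the hard-coded literal.
matches_on_posix = posix_path == hardcoded
matches_on_windows = windows_path == hardcoded
```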
Hi @chriselion
but I would have to rewrite all the writer classes ( Option 2 would not give much flexibility to add other aggregations; we are writing our own implementations, but NumPy already has them implemented.
If you add the property I suggested, you wouldn't need to add any extra logic in the writers. You could even call the property
Maybe, but I'd rather not make the logic any more complicated.
We can add them as we need them.
ml-agents/mlagents/trainers/stats.py
            std=np.std(StatsReporter.stats_dict[self.category][key]),
            num=len(StatsReporter.stats_dict[self.category][key]),
        )
        if (
It's not your fault, but this method is getting messy and should be cleaned up. Something like
    stat_values = StatsReporter.stats_dict[self.category][key]
    if len(stat_values) == 0:
        return StatsSummary.empty()
    return StatsSummary(
        mean=np.mean(stat_values),
        sum=np.sum(stat_values),
        std=np.std(stat_values),
        num=len(stat_values),
        aggregation_method=StatsReporter.stats_aggregation[self.category][key],
    )
should be a bit cleaner.
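Putting the suggested fields and the aggregated_value property together, a self-contained sketch might look like the following. Several things here are assumptions for illustration: the enum members and their values, the summarize helper name, and the use of the standard-library statistics module in place of NumPy (statistics.pstdev is the population standard deviation, which is also what np.std computes by default).

```python
import statistics
from enum import Enum
from typing import List, NamedTuple


class StatsAggregationMethod(Enum):
    # Hypothetical members and values; the real enum may differ.
    AVERAGE = 0
    SUM = 1


class StatsSummary(NamedTuple):
    mean: float
    sum: float
    std: float
    num: int
    aggregation_method: StatsAggregationMethod

    @staticmethod
    def empty() -> "StatsSummary":
        return StatsSummary(0.0, 0.0, 0.0, 0, StatsAggregationMethod.AVERAGE)

    @property
    def aggregated_value(self) -> float:
        # Writers can read this one property instead of branching themselves.
        if self.aggregation_method == StatsAggregationMethod.SUM:
            return self.sum
        return self.mean


def summarize(values: List[float], method: StatsAggregationMethod) -> StatsSummary:
    # Hypothetical helper mirroring the cleaned-up method body above.
    if len(values) == 0:
        return StatsSummary.empty()
    return StatsSummary(
        mean=statistics.fmean(values),
        sum=sum(values),
        std=statistics.pstdev(values),
        num=len(values),
        aggregation_method=method,
    )


summed = summarize([1.0, 2.0, 3.0], StatsAggregationMethod.SUM)
averaged = summarize([1.0, 2.0, 3.0], StatsAggregationMethod.AVERAGE)
```

Because both mean and sum are always populated, a writer that needs a specific one (such as a gauge with a fixed name) can still read it directly.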
Hi @chriselion,
ml-agents/mlagents/trainers/stats.py
@@ -95,8 +108,8 @@ def write_stats(
     ) -> None:
         for val, stats_summary in values.items():
             set_gauge(
-                GaugeWriter.sanitize_string(f"{category}.{val}.mean"),
-                float(stats_summary.mean),
+                GaugeWriter.sanitize_string(f"{category}.{val}.aggregated_value"),
This is going to break some of our internal processes (that live outside of this repo).
Can you make it:
    set_gauge(GaugeWriter.sanitize_string(f"{category}.{val}.mean"), float(stats_summary.mean))
    set_gauge(GaugeWriter.sanitize_string(f"{category}.{val}.sum"), float(stats_summary.sum))
instead?
Thanks, I think it's almost there! I left a few final comments but otherwise it looks pretty good. Sorry to keep harping on the newlines, but I don't think a bug in PyCharm's display is a worthwhile reason to change. I'm OK with them in the files that you have other changes in, but files where those are the only changes should be reverted (I can do this for you in git, as long as you have the setting on the PR to allow repo owners to push changes).
All done, @chriselion
        self,
        step: Union[
            TerminalStep, DecisionStep
        ],  # pylint: disable=unsubscriptable-object
See pylint-dev/pylint#2377 (sigh)
@@ -37,14 +40,23 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
     )


-class StatsSummary(NamedTuple):
+class StatsSummary(NamedTuple):  # pylint: disable=inherit-non-class
See pylint-dev/pylint#3876 (sigh)
Thanks, I made a few small cleanups and reverts. Will merge this when tests pass.