-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Added SUM as aggregation type for custom statistics #4816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
chriselion
merged 21 commits into
Unity-Technologies:master
from
brccabral:mlagents_python
Jan 8, 2021
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
953b08c
remove Builds folder from git
brccabral 38afbde
log number of rewards in cmd summary
brccabral b9a18a0
added sum to StatsSummary
brccabral 9c3f059
added SUM in Unity StatAggregationMethod
brccabral b7328ba
added support for SUM as StatsAggregationMethod in python mlagents
brccabral f178316
example to use SUM as aggregation
brccabral e7c45cd
fixed field order with default values for StatsSummary
brccabral b6c9a2e
simplified StatsSummary
brccabral 9206b46
add default value for custom stats
brccabral 4a44514
fixed tests
brccabral 1299062
extended test test_agent_manager_stats in test_agent_processor.py to …
brccabral f667c0b
refractor StatsSummary to add sum as property
brccabral 1531ec3
fixed tests
brccabral dadf67d
Unity coding standard
brccabral b5e6397
reverted docstring empty lines
brccabral 913fddf
GaugeWriter fix
brccabral 8bc9892
revert some whitespace
cc76bea
undo whitespace
31f400d
undo undesired change, undo whitespace
eaf51a8
revert unit test logging strings
201f099
changelog
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
/summaries | ||
# Output Artifacts | ||
/results | ||
# Output Builds | ||
/Builds | ||
|
||
# Training environments | ||
/envs | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ def __init__( | |
): | ||
""" | ||
Create an AgentProcessor. | ||
|
||
:param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory | ||
when it is finished. | ||
:param policy: Policy instance associated with this AgentProcessor. | ||
|
@@ -112,7 +113,12 @@ def add_experiences( | |
) | ||
|
||
def _process_step( | ||
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int | ||
self, | ||
step: Union[ | ||
TerminalStep, DecisionStep | ||
], # pylint: disable=unsubscriptable-object | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See pylint-dev/pylint#2377 (sigh) |
||
global_id: str, | ||
index: int, | ||
) -> None: | ||
terminated = isinstance(step, TerminalStep) | ||
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None)) | ||
|
@@ -318,15 +324,18 @@ def record_environment_stats( | |
""" | ||
Pass stats from the environment to the StatsReporter. | ||
Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used. | ||
The worker_id is used to determin whether StatsReporter.set_stat should be used. | ||
The worker_id is used to determine whether StatsReporter.set_stat should be used. | ||
|
||
:param env_stats: | ||
:param worker_id: | ||
:return: | ||
""" | ||
for stat_name, value_list in env_stats.items(): | ||
for val, agg_type in value_list: | ||
if agg_type == StatsAggregationMethod.AVERAGE: | ||
self.stats_reporter.add_stat(stat_name, val) | ||
self.stats_reporter.add_stat(stat_name, val, agg_type) | ||
elif agg_type == StatsAggregationMethod.SUM: | ||
self.stats_reporter.add_stat(stat_name, val, agg_type) | ||
elif agg_type == StatsAggregationMethod.MOST_RECENT: | ||
# In order to prevent conflicts between multiple environments, | ||
# only stats from the first environment are recorded. | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ | |
import time | ||
from threading import RLock | ||
|
||
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod | ||
|
||
from mlagents_envs.logging_util import get_logger | ||
from mlagents_envs.timers import set_gauge | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
@@ -20,8 +22,9 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str: | |
""" | ||
Takes a parameter dictionary and converts it to a human-readable string. | ||
Recurses if there are multiple levels of dict. Used to print out hyperparameters. | ||
param: param_dict: A Dictionary of key, value parameters. | ||
return: A string version of this dictionary. | ||
|
||
:param param_dict: A Dictionary of key, value parameters. | ||
:return: A string version of this dictionary. | ||
""" | ||
if not isinstance(param_dict, dict): | ||
return str(param_dict) | ||
|
@@ -37,14 +40,23 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str: | |
) | ||
|
||
|
||
class StatsSummary(NamedTuple): | ||
class StatsSummary(NamedTuple): # pylint: disable=inherit-non-class | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See pylint-dev/pylint#3876 (sigh) |
||
mean: float | ||
std: float | ||
num: int | ||
sum: float | ||
aggregation_method: StatsAggregationMethod | ||
|
||
@staticmethod | ||
def empty() -> "StatsSummary": | ||
return StatsSummary(0.0, 0.0, 0) | ||
return StatsSummary(0.0, 0.0, 0, 0.0, StatsAggregationMethod.AVERAGE) | ||
|
||
@property | ||
def aggregated_value(self): | ||
if self.aggregation_method == StatsAggregationMethod.SUM: | ||
return self.sum | ||
else: | ||
return self.mean | ||
|
||
|
||
class StatsPropertyType(Enum): | ||
|
@@ -71,8 +83,9 @@ def add_property( | |
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters, | ||
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible | ||
with all types of properties. For instance, a TB writer doesn't need a max step. | ||
|
||
:param category: The category that the property belongs to. | ||
:param type: The type of property. | ||
:param property_type: The type of property. | ||
:param value: The property itself. | ||
""" | ||
pass | ||
|
@@ -98,6 +111,10 @@ def write_stats( | |
GaugeWriter.sanitize_string(f"{category}.{val}.mean"), | ||
float(stats_summary.mean), | ||
) | ||
set_gauge( | ||
GaugeWriter.sanitize_string(f"{category}.{val}.sum"), | ||
float(stats_summary.sum), | ||
) | ||
|
||
|
||
class ConsoleWriter(StatsWriter): | ||
|
@@ -114,7 +131,7 @@ def write_stats( | |
is_training = "Not Training" | ||
if "Is Training" in values: | ||
stats_summary = values["Is Training"] | ||
if stats_summary.mean > 0.0: | ||
if stats_summary.aggregated_value > 0.0: | ||
is_training = "Training" | ||
|
||
elapsed_time = time.time() - self.training_start_time | ||
|
@@ -156,10 +173,11 @@ class TensorboardWriter(StatsWriter): | |
def __init__(self, base_dir: str, clear_past_data: bool = False): | ||
""" | ||
A StatsWriter that writes to a Tensorboard summary. | ||
|
||
:param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a | ||
{base_dir}/{category} directory. | ||
:param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and | ||
category. | ||
category. | ||
""" | ||
self.summary_writers: Dict[str, SummaryWriter] = {} | ||
self.base_dir: str = base_dir | ||
|
@@ -170,7 +188,9 @@ def write_stats( | |
) -> None: | ||
self._maybe_create_summary_writer(category) | ||
for key, value in values.items(): | ||
self.summary_writers[category].add_scalar(f"{key}", value.mean, step) | ||
self.summary_writers[category].add_scalar( | ||
f"{key}", value.aggregated_value, step | ||
) | ||
self.summary_writers[category].flush() | ||
|
||
def _maybe_create_summary_writer(self, category: str) -> None: | ||
|
@@ -214,6 +234,9 @@ class StatsReporter: | |
writers: List[StatsWriter] = [] | ||
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) | ||
lock = RLock() | ||
stats_aggregation: Dict[str, Dict[str, StatsAggregationMethod]] = defaultdict( | ||
lambda: defaultdict(lambda: StatsAggregationMethod.AVERAGE) | ||
) | ||
|
||
def __init__(self, category: str): | ||
""" | ||
|
@@ -234,37 +257,51 @@ def add_property(self, property_type: StatsPropertyType, value: Any) -> None: | |
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters, | ||
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible | ||
with all types of properties. For instance, a TB writer doesn't need a max step. | ||
:param key: The type of property. | ||
|
||
:param property_type: The type of property. | ||
:param value: The property itself. | ||
""" | ||
with StatsReporter.lock: | ||
for writer in StatsReporter.writers: | ||
writer.add_property(self.category, property_type, value) | ||
|
||
def add_stat(self, key: str, value: float) -> None: | ||
def add_stat( | ||
self, | ||
key: str, | ||
value: float, | ||
aggregation: StatsAggregationMethod = StatsAggregationMethod.AVERAGE, | ||
) -> None: | ||
""" | ||
Add a float value stat to the StatsReporter. | ||
|
||
:param key: The type of statistic, e.g. Environment/Reward. | ||
:param value: the value of the statistic. | ||
:param aggregation: the aggregation method for the statistic, default StatsAggregationMethod.AVERAGE. | ||
""" | ||
with StatsReporter.lock: | ||
StatsReporter.stats_dict[self.category][key].append(value) | ||
StatsReporter.stats_aggregation[self.category][key] = aggregation | ||
chriselion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def set_stat(self, key: str, value: float) -> None: | ||
""" | ||
Sets a stat value to a float. This is for values that we don't want to average, and just | ||
want the latest. | ||
|
||
:param key: The type of statistic, e.g. Environment/Reward. | ||
:param value: the value of the statistic. | ||
""" | ||
with StatsReporter.lock: | ||
StatsReporter.stats_dict[self.category][key] = [value] | ||
StatsReporter.stats_aggregation[self.category][ | ||
key | ||
] = StatsAggregationMethod.MOST_RECENT | ||
|
||
def write_stats(self, step: int) -> None: | ||
""" | ||
Write out all stored statistics that fall under the category specified. | ||
The currently stored values will be averaged, written out as a single value, | ||
and the buffer cleared. | ||
|
||
:param step: Training step which to write these stats as. | ||
""" | ||
with StatsReporter.lock: | ||
|
@@ -279,14 +316,19 @@ def write_stats(self, step: int) -> None: | |
|
||
def get_stats_summaries(self, key: str) -> StatsSummary: | ||
""" | ||
Get the mean, std, and count of a particular statistic, since last write. | ||
Get the mean, std, count, sum and aggregation method of a particular statistic, since last write. | ||
|
||
:param key: The type of statistic, e.g. Environment/Reward. | ||
:returns: A StatsSummary NamedTuple containing (mean, std, count). | ||
:returns: A StatsSummary containing summary statistics. | ||
""" | ||
if len(StatsReporter.stats_dict[self.category][key]) > 0: | ||
return StatsSummary( | ||
mean=np.mean(StatsReporter.stats_dict[self.category][key]), | ||
std=np.std(StatsReporter.stats_dict[self.category][key]), | ||
num=len(StatsReporter.stats_dict[self.category][key]), | ||
) | ||
return StatsSummary.empty() | ||
stat_values = StatsReporter.stats_dict[self.category][key] | ||
if len(stat_values) == 0: | ||
return StatsSummary.empty() | ||
|
||
return StatsSummary( | ||
mean=np.mean(stat_values), | ||
std=np.std(stat_values), | ||
num=len(stat_values), | ||
sum=np.sum(stat_values), | ||
aggregation_method=StatsReporter.stats_aggregation[self.category][key], | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the extra whitespaces that were added in the docstrings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the extra spaces are necessary for PyCharm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please try to find the settings to disable them; I use PyCharm as do several other team members, and this hasn't been a problem before. At a minimum, you need to revert files that only contain the whitespace changes:
On the other hand, if you know of a way to automate this from the command line (and enforce that the
:param ...
names match the args), I'd love that in another PR.