Skip to content

Added SUM as aggregation type for custom statistics #4816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
/summaries
# Output Artifacts
/results
# Output Builds
/Builds

# Training environments
/envs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ public class HallwayAgent : Agent
Renderer m_GroundRenderer;
HallwaySettings m_HallwaySettings;
int m_Selection;
StatsRecorder m_statsRecorder;

public override void Initialize()
{
m_HallwaySettings = FindObjectOfType<HallwaySettings>();
m_AgentRb = GetComponent<Rigidbody>();
m_GroundRenderer = ground.GetComponent<Renderer>();
m_GroundMaterial = m_GroundRenderer.material;
m_statsRecorder = Academy.Instance.StatsRecorder;
}

public override void CollectObservations(VectorSensor sensor)
Expand Down Expand Up @@ -83,11 +85,13 @@ void OnCollisionEnter(Collision col)
{
SetReward(1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, 0.5f));
m_statsRecorder.Add("Goal/Correct", 1, StatAggregationMethod.Sum);
}
else
{
SetReward(-0.1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
m_statsRecorder.Add("Goal/Wrong", 1, StatAggregationMethod.Sum);
}
EndEpisode();
}
Expand Down Expand Up @@ -156,5 +160,7 @@ public override void OnEpisodeBegin()
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
}
m_statsRecorder.Add("Goal/Correct", 0, StatAggregationMethod.Sum);
m_statsRecorder.Add("Goal/Wrong", 0, StatAggregationMethod.Sum);
}
}
3 changes: 3 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to

### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- `StatAggregationMethod.Sum` can now be passed to `StatsRecorder.Add()`. This
will result in the values being summed (instead of averaged) when written to
TensorBoard. Thanks to @brccabral for the contribution! (#4816)

#### ml-agents / ml-agents-envs / gym-unity (Python)

Expand Down
7 changes: 6 additions & 1 deletion com.unity.ml-agents/Runtime/StatsRecorder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ public enum StatAggregationMethod
/// To avoid conflicts when training with multiple concurrent environments, only
/// stats from worker index 0 will be tracked.
/// </summary>
MostRecent = 1
MostRecent = 1,

/// <summary>
/// Values within the summary period are summed up before reporting.
/// </summary>
Sum = 2
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class StatsAggregationMethod(Enum):
# Only the most recent value is reported.
MOST_RECENT = 1

# Values within the summary period are summed up before reporting.
SUM = 2


StatList = List[Tuple[float, StatsAggregationMethod]]
EnvironmentStats = Mapping[str, StatList]
Expand All @@ -35,6 +38,7 @@ def __init__(self) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Receive the message from the environment, and save it for later retrieval.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the extra whitespace that was added to the docstrings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra spaces are necessary for PyCharm.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to find the settings to disable them; I use PyCharm as do several other team members, and this hasn't been a problem before. At a minimum, you need to revert files that only contain the whitespace changes:

  • ml-agents/mlagents/trainers/ppo/trainer.py
  • ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
  • ml-agents/mlagents/trainers/trainer/rl_trainer.py

On the other hand, if you know of a way to automate this from the command line (and enforce that the :param ... names match the args), I'd love that in another PR.

:param msg:
:return:
"""
Expand All @@ -47,6 +51,7 @@ def on_message_received(self, msg: IncomingMessage) -> None:
def get_and_reset_stats(self) -> EnvironmentStats:
"""
Returns the current stats, and resets the internal storage of the stats.

:return:
"""
s = self.stats
Expand Down
15 changes: 12 additions & 3 deletions ml-agents/mlagents/trainers/agent_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
):
"""
Create an AgentProcessor.

:param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory
when it is finished.
:param policy: Policy instance associated with this AgentProcessor.
Expand Down Expand Up @@ -112,7 +113,12 @@ def add_experiences(
)

def _process_step(
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
self,
step: Union[
TerminalStep, DecisionStep
], # pylint: disable=unsubscriptable-object
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

global_id: str,
index: int,
) -> None:
terminated = isinstance(step, TerminalStep)
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
Expand Down Expand Up @@ -318,15 +324,18 @@ def record_environment_stats(
"""
Pass stats from the environment to the StatsReporter.
Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used.
The worker_id is used to determin whether StatsReporter.set_stat should be used.
The worker_id is used to determine whether StatsReporter.set_stat should be used.

:param env_stats:
:param worker_id:
:return:
"""
for stat_name, value_list in env_stats.items():
for val, agg_type in value_list:
if agg_type == StatsAggregationMethod.AVERAGE:
self.stats_reporter.add_stat(stat_name, val)
self.stats_reporter.add_stat(stat_name, val, agg_type)
elif agg_type == StatsAggregationMethod.SUM:
self.stats_reporter.add_stat(stat_name, val, agg_type)
elif agg_type == StatsAggregationMethod.MOST_RECENT:
# In order to prevent conflicts between multiple environments,
# only stats from the first environment are recorded.
Expand Down
80 changes: 61 additions & 19 deletions ml-agents/mlagents/trainers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import time
from threading import RLock

from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod

from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import set_gauge
from torch.utils.tensorboard import SummaryWriter
Expand All @@ -20,8 +22,9 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
param: param_dict: A Dictionary of key, value parameters.
return: A string version of this dictionary.

:param param_dict: A Dictionary of key, value parameters.
:return: A string version of this dictionary.
"""
if not isinstance(param_dict, dict):
return str(param_dict)
Expand All @@ -37,14 +40,23 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
)


class StatsSummary(NamedTuple):
class StatsSummary(NamedTuple): # pylint: disable=inherit-non-class
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mean: float
std: float
num: int
sum: float
aggregation_method: StatsAggregationMethod

@staticmethod
def empty() -> "StatsSummary":
return StatsSummary(0.0, 0.0, 0)
return StatsSummary(0.0, 0.0, 0, 0.0, StatsAggregationMethod.AVERAGE)

@property
def aggregated_value(self):
if self.aggregation_method == StatsAggregationMethod.SUM:
return self.sum
else:
return self.mean


class StatsPropertyType(Enum):
Expand All @@ -71,8 +83,9 @@ def add_property(
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step.

:param category: The category that the property belongs to.
:param type: The type of property.
:param property_type: The type of property.
:param value: The property itself.
"""
pass
Expand All @@ -98,6 +111,10 @@ def write_stats(
GaugeWriter.sanitize_string(f"{category}.{val}.mean"),
float(stats_summary.mean),
)
set_gauge(
GaugeWriter.sanitize_string(f"{category}.{val}.sum"),
float(stats_summary.sum),
)


class ConsoleWriter(StatsWriter):
Expand All @@ -114,7 +131,7 @@ def write_stats(
is_training = "Not Training"
if "Is Training" in values:
stats_summary = values["Is Training"]
if stats_summary.mean > 0.0:
if stats_summary.aggregated_value > 0.0:
is_training = "Training"

elapsed_time = time.time() - self.training_start_time
Expand Down Expand Up @@ -156,10 +173,11 @@ class TensorboardWriter(StatsWriter):
def __init__(self, base_dir: str, clear_past_data: bool = False):
"""
A StatsWriter that writes to a Tensorboard summary.

:param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a
{base_dir}/{category} directory.
:param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and
category.
category.
"""
self.summary_writers: Dict[str, SummaryWriter] = {}
self.base_dir: str = base_dir
Expand All @@ -170,7 +188,9 @@ def write_stats(
) -> None:
self._maybe_create_summary_writer(category)
for key, value in values.items():
self.summary_writers[category].add_scalar(f"{key}", value.mean, step)
self.summary_writers[category].add_scalar(
f"{key}", value.aggregated_value, step
)
self.summary_writers[category].flush()

def _maybe_create_summary_writer(self, category: str) -> None:
Expand Down Expand Up @@ -214,6 +234,9 @@ class StatsReporter:
writers: List[StatsWriter] = []
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))
lock = RLock()
stats_aggregation: Dict[str, Dict[str, StatsAggregationMethod]] = defaultdict(
lambda: defaultdict(lambda: StatsAggregationMethod.AVERAGE)
)

def __init__(self, category: str):
"""
Expand All @@ -234,37 +257,51 @@ def add_property(self, property_type: StatsPropertyType, value: Any) -> None:
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step.
:param key: The type of property.

:param property_type: The type of property.
:param value: The property itself.
"""
with StatsReporter.lock:
for writer in StatsReporter.writers:
writer.add_property(self.category, property_type, value)

def add_stat(self, key: str, value: float) -> None:
def add_stat(
self,
key: str,
value: float,
aggregation: StatsAggregationMethod = StatsAggregationMethod.AVERAGE,
) -> None:
"""
Add a float value stat to the StatsReporter.

:param key: The type of statistic, e.g. Environment/Reward.
:param value: the value of the statistic.
:param aggregation: the aggregation method for the statistic, default StatsAggregationMethod.AVERAGE.
"""
with StatsReporter.lock:
StatsReporter.stats_dict[self.category][key].append(value)
StatsReporter.stats_aggregation[self.category][key] = aggregation

def set_stat(self, key: str, value: float) -> None:
"""
Sets a stat value to a float. This is for values that we don't want to average, and just
want the latest.

:param key: The type of statistic, e.g. Environment/Reward.
:param value: the value of the statistic.
"""
with StatsReporter.lock:
StatsReporter.stats_dict[self.category][key] = [value]
StatsReporter.stats_aggregation[self.category][
key
] = StatsAggregationMethod.MOST_RECENT

def write_stats(self, step: int) -> None:
"""
Write out all stored statistics that fall under the category specified.
The currently stored values will be averaged, written out as a single value,
and the buffer cleared.

:param step: Training step which to write these stats as.
"""
with StatsReporter.lock:
Expand All @@ -279,14 +316,19 @@ def write_stats(self, step: int) -> None:

def get_stats_summaries(self, key: str) -> StatsSummary:
"""
Get the mean, std, and count of a particular statistic, since last write.
Get the mean, std, count, sum and aggregation method of a particular statistic, since last write.

:param key: The type of statistic, e.g. Environment/Reward.
:returns: A StatsSummary NamedTuple containing (mean, std, count).
:returns: A StatsSummary containing summary statistics.
"""
if len(StatsReporter.stats_dict[self.category][key]) > 0:
return StatsSummary(
mean=np.mean(StatsReporter.stats_dict[self.category][key]),
std=np.std(StatsReporter.stats_dict[self.category][key]),
num=len(StatsReporter.stats_dict[self.category][key]),
)
return StatsSummary.empty()
stat_values = StatsReporter.stats_dict[self.category][key]
if len(stat_values) == 0:
return StatsSummary.empty()

return StatsSummary(
mean=np.mean(stat_values),
std=np.std(stat_values),
num=len(stat_values),
sum=np.sum(stat_values),
aggregation_method=StatsReporter.stats_aggregation[self.category][key],
)
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tests/check_env_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def write_stats(
) -> None:
for val, stats_summary in values.items():
if val == "Environment/Cumulative Reward":
print(step, val, stats_summary.mean)
self._last_reward_summary[category] = stats_summary.mean
print(step, val, stats_summary.aggregated_value)
self._last_reward_summary[category] = stats_summary.aggregated_value


# The reward processor is passed as an argument to _check_environment_trains.
Expand Down
25 changes: 23 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_agent_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,18 +262,39 @@ def test_agent_manager_stats():
{
"averaged": [(1.0, StatsAggregationMethod.AVERAGE)],
"most_recent": [(2.0, StatsAggregationMethod.MOST_RECENT)],
"summed": [(3.1, StatsAggregationMethod.SUM)],
},
{
"averaged": [(3.0, StatsAggregationMethod.AVERAGE)],
"most_recent": [(4.0, StatsAggregationMethod.MOST_RECENT)],
"summed": [(1.1, StatsAggregationMethod.SUM)],
},
]
for env_stats in all_env_stats:
manager.record_environment_stats(env_stats, worker_id=0)

expected_stats = {
"averaged": StatsSummary(mean=2.0, std=mock.ANY, num=2),
"most_recent": StatsSummary(mean=4.0, std=0.0, num=1),
"averaged": StatsSummary(
mean=2.0,
std=mock.ANY,
num=2,
sum=4.0,
aggregation_method=StatsAggregationMethod.AVERAGE,
),
"most_recent": StatsSummary(
mean=4.0,
std=0.0,
num=1,
sum=4.0,
aggregation_method=StatsAggregationMethod.MOST_RECENT,
),
"summed": StatsSummary(
mean=2.1,
std=mock.ANY,
num=2,
sum=4.2,
aggregation_method=StatsAggregationMethod.SUM,
),
}
stats_reporter.write_stats(123)
writer.write_stats.assert_any_call("FakeCategory", expected_stats, 123)
Expand Down
Loading