Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

setuptools-based plugin for StatsWriters #4788

Merged
merged 15 commits into from
Feb 5, 2021
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Dict, List
from mlagents.trainers.stats import StatsWriter, StatsSummary


class ExampleStatsWriter(StatsWriter):
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int
) -> None:
print(f"ExampleStatsWriter category: {category} values: {values}")


def get_example_stats_writer() -> List[StatsWriter]:
print("Creating a new stats writer! This is so exciting!")
return [ExampleStatsWriter()]
11 changes: 11 additions & 0 deletions ml-agents-plugin-examples/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from setuptools import setup
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No plans to publish this package, but we could use it to set up tests, e.g. for bad imports.


setup(
name="mlagents_plugin_examples",
version="0.0.1",
entry_points={
"mlagents.stats_writer": [
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
Copy link
Contributor Author

@chriselion chriselion Dec 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The form of this is {entry point name}={plugin module}:{plugin_function}

]
},
)
Empty file.
20 changes: 20 additions & 0 deletions ml-agents/mlagents/plugins/stats_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List
import importlib_metadata
from mlagents.trainers.stats import StatsWriter


def get_default_stats_writers() -> List[StatsWriter]:
# TODO move construction of default StatsWriters here
return []


def register_stats_writer_plugins() -> List[StatsWriter]:
all_stats_writers: List[StatsWriter] = []
eps = importlib_metadata.entry_points()["mlagents.stats_writer"]

for ep in eps:
print(f"registering {ep.name}")
# TODO try/except around all of this
plugin_func = ep.load()
all_stats_writers += plugin_func()
return all_stats_writers
20 changes: 12 additions & 8 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
add_metadata as add_timer_metadata,
)
from mlagents_envs import logging_util
from mlagents.plugins.stats_writer import register_stats_writer_plugins

logger = logging_util.get_logger(__name__)

Expand Down Expand Up @@ -91,14 +92,17 @@ def run_training(run_seed: int, options: RunOptions) -> None:
)

# Configure Tensorboard Writers and StatsReporter
tb_writer = TensorboardWriter(
write_path, clear_past_data=not checkpoint_settings.resume
)
gauge_write = GaugeWriter()
console_writer = ConsoleWriter()
StatsReporter.add_writer(tb_writer)
StatsReporter.add_writer(gauge_write)
StatsReporter.add_writer(console_writer)
stats_writers = [
TensorboardWriter(
write_path, clear_past_data=not checkpoint_settings.resume
),
GaugeWriter(),
ConsoleWriter(),
]

stats_writers += register_stats_writer_plugins()
for sw in stats_writers:
StatsReporter.add_writer(sw)

if env_settings.env_path is None:
port = None
Expand Down
6 changes: 5 additions & 1 deletion ml-agents/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,17 @@ def run(self):
"cattrs>=1.0.0,<1.1.0",
"attrs>=19.3.0",
'pypiwin32==223;platform_system=="Windows"',
"importlib_metadata", # TODO for python<3.8 only
],
python_requires=">=3.6.1",
entry_points={
"console_scripts": [
"mlagents-learn=mlagents.trainers.learn:main",
"mlagents-run-experiment=mlagents.trainers.run_experiment:main",
]
],
"mlagents.stats_writer": [
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would need to add a new key for each type of plugin interface we want, e.g. mlagents.demonstration_provider: [...]

"default=mlagents.plugins.stats_writer:get_default_stats_writers"
],
},
cmdclass={"verify": VerifyVersionCommand},
)