-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
setuptools-based plugin for StatsWriters #4788
Changes from 1 commit
76b43ae
c9be196
ca5b56e
d597c74
9db6da2
58a22c3
f138c48
5db6efe
b25b2e8
463be02
be038b7
cf9b4aa
a3de6cd
c72f245
f5b0854
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Dict, List | ||
from mlagents.trainers.stats import StatsWriter, StatsSummary | ||
|
||
|
||
class ExampleStatsWriter(StatsWriter): | ||
def write_stats( | ||
self, category: str, values: Dict[str, StatsSummary], step: int | ||
) -> None: | ||
print(f"ExampleStatsWriter category: {category} values: {values}") | ||
|
||
|
||
def get_example_stats_writer() -> List[StatsWriter]: | ||
print("Creating a new stats writer! This is so exciting!") | ||
return [ExampleStatsWriter()] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from setuptools import setup | ||
|
||
setup( | ||
name="mlagents_plugin_examples", | ||
version="0.0.1", | ||
entry_points={ | ||
"mlagents.stats_writer": [ | ||
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The form of this is |
||
] | ||
}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from typing import List | ||
import importlib_metadata | ||
from mlagents.trainers.stats import StatsWriter | ||
|
||
|
||
def get_default_stats_writers() -> List[StatsWriter]: | ||
# TODO move construction of default StatsWriters here | ||
return [] | ||
|
||
|
||
def register_stats_writer_plugins() -> List[StatsWriter]: | ||
all_stats_writers: List[StatsWriter] = [] | ||
eps = importlib_metadata.entry_points()["mlagents.stats_writer"] | ||
|
||
for ep in eps: | ||
print(f"registering {ep.name}") | ||
# TODO try/except around all of this | ||
plugin_func = ep.load() | ||
all_stats_writers += plugin_func() | ||
return all_stats_writers |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,13 +71,17 @@ def run(self): | |
"cattrs>=1.0.0,<1.1.0", | ||
"attrs>=19.3.0", | ||
'pypiwin32==223;platform_system=="Windows"', | ||
"importlib_metadata", # TODO for python<3.8 only | ||
], | ||
python_requires=">=3.6.1", | ||
entry_points={ | ||
"console_scripts": [ | ||
"mlagents-learn=mlagents.trainers.learn:main", | ||
"mlagents-run-experiment=mlagents.trainers.run_experiment:main", | ||
] | ||
], | ||
"mlagents.stats_writer": [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We would need to add a new key for each type of plugin interface we want, e.g. |
||
"default=mlagents.plugins.stats_writer:get_default_stats_writers" | ||
], | ||
}, | ||
cmdclass={"verify": VerifyVersionCommand}, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No plans to publish this package, but we could use it to set up tests, e.g. for bad imports.