-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feature] Enable default settings for TrainerSettings #4448
Changes from 6 commits
e2cecce
f6b100e
18a2329
508a8ed
7a057e1
ea0b063
a698e00
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,13 +2,24 @@ | |
|
||
import attr | ||
import cattr | ||
from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple, Union | ||
from typing import ( | ||
Dict, | ||
Optional, | ||
List, | ||
Any, | ||
DefaultDict, | ||
Mapping, | ||
Tuple, | ||
Union, | ||
ClassVar, | ||
) | ||
from enum import Enum | ||
import collections | ||
import argparse | ||
import abc | ||
import numpy as np | ||
import math | ||
import copy | ||
|
||
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser | ||
from mlagents.trainers.cli_utils import load_config | ||
|
@@ -46,6 +57,17 @@ def defaultdict_to_dict(d: DefaultDict) -> Dict: | |
return {key: cattr.unstructure(val) for key, val in d.items()} | ||
|
||
|
||
def deep_update_dict(d: Dict, update_d: Mapping) -> None: | ||
""" | ||
Similar to dict.update(), but works for nested dicts of dicts as well. | ||
""" | ||
for key, val in update_d.items(): | ||
if key in d and isinstance(d[key], Mapping) and isinstance(val, Mapping): | ||
deep_update_dict(d[key], val) | ||
else: | ||
d[key] = val | ||
|
||
|
||
class SerializationSettings: | ||
convert_to_barracuda = True | ||
convert_to_onnx = True | ||
|
@@ -539,6 +561,7 @@ class FrameworkType(Enum): | |
|
||
@attr.s(auto_attribs=True) | ||
class TrainerSettings(ExportableSettings): | ||
default_override: ClassVar[Optional["TrainerSettings"]] = None | ||
trainer_type: TrainerType = TrainerType.PPO | ||
hyperparameters: HyperparamSettings = attr.ib() | ||
|
||
|
@@ -578,8 +601,8 @@ def _check_batch_size_seq_length(self, attribute, value): | |
|
||
@staticmethod | ||
def dict_to_defaultdict(d: Dict, t: type) -> DefaultDict: | ||
return collections.defaultdict( | ||
TrainerSettings, cattr.structure(d, Dict[str, TrainerSettings]) | ||
return TrainerSettings.DefaultTrainerDict( | ||
cattr.structure(d, Dict[str, TrainerSettings]) | ||
) | ||
|
||
@staticmethod | ||
|
@@ -588,10 +611,18 @@ def structure(d: Mapping, t: type) -> Any: | |
Helper method to structure a TrainerSettings class. Meant to be registered with | ||
cattr.register_structure_hook() and called with cattr.structure(). | ||
""" | ||
|
||
if not isinstance(d, Mapping): | ||
raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.") | ||
|
||
d_copy: Dict[str, Any] = {} | ||
d_copy.update(d) | ||
|
||
# Check if a default_settings was specified. If so, used those as the default | ||
# rather than an empty dict. | ||
if TrainerSettings.default_override is not None: | ||
d_copy.update(cattr.unstructure(TrainerSettings.default_override)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this will do what you want for nested dictionaries. For example if you had
you'd end up with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed using a recursive dict update |
||
|
||
deep_update_dict(d_copy, d) | ||
|
||
for key, val in d_copy.items(): | ||
if attr.has(type(val)): | ||
|
@@ -613,6 +644,16 @@ def structure(d: Mapping, t: type) -> Any: | |
d_copy[key] = check_and_structure(key, val, t) | ||
return t(**d_copy) | ||
|
||
class DefaultTrainerDict(collections.defaultdict): | ||
def __init__(self, *args): | ||
super().__init__(TrainerSettings, *args) | ||
|
||
def __missing__(self, key: Any) -> "TrainerSettings": | ||
if TrainerSettings.default_override is not None: | ||
return copy.deepcopy(TrainerSettings.default_override) | ||
else: | ||
return TrainerSettings() | ||
|
||
|
||
# COMMAND LINE ######################################################################### | ||
@attr.s(auto_attribs=True) | ||
|
@@ -653,8 +694,9 @@ class EngineSettings: | |
|
||
@attr.s(auto_attribs=True) | ||
class RunOptions(ExportableSettings): | ||
default_settings: Optional[TrainerSettings] = None | ||
behaviors: DefaultDict[str, TrainerSettings] = attr.ib( | ||
factory=lambda: collections.defaultdict(TrainerSettings) | ||
factory=TrainerSettings.DefaultTrainerDict | ||
) | ||
env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings) | ||
engine_settings: EngineSettings = attr.ib(factory=EngineSettings) | ||
|
@@ -733,4 +775,12 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions": | |
|
||
@staticmethod | ||
def from_dict(options_dict: Dict[str, Any]) -> "RunOptions": | ||
# If a default settings was specified, set the TrainerSettings class override | ||
if ( | ||
"default_settings" in options_dict.keys() | ||
and options_dict["default_settings"] is not None | ||
): | ||
TrainerSettings.default_override = cattr.structure( | ||
options_dict["default_settings"], TrainerSettings | ||
) | ||
return cattr.structure(options_dict, RunOptions) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.