Skip to content

Commit

Permalink
Better error handling if trainer config doesn't contain "default" sec…
Browse files Browse the repository at this point in the history
…tion (#3063)
  • Loading branch information
Chris Elion authored Dec 12, 2019
1 parent 69c086e commit d5c6ff8
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 22 deletions.
8 changes: 8 additions & 0 deletions ml-agents/mlagents/trainers/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ class TrainerError(Exception):
pass


class TrainerConfigError(Exception):
"""
Any error related to the configuration of trainers in the ML-Agents Toolkit.
"""

pass


class CurriculumError(TrainerError):
"""
Any error related to training with a curriculum.
Expand Down
8 changes: 4 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_simple_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def close(self):
pass


PPO_CONFIG = """
default:
PPO_CONFIG = f"""
{BRAIN_NAME}:
trainer: ppo
batch_size: 16
beta: 5.0e-3
Expand All @@ -153,8 +153,8 @@ def close(self):
gamma: 0.99
"""

SAC_CONFIG = """
default:
SAC_CONFIG = f"""
{BRAIN_NAME}:
trainer: sac
batch_size: 8
buffer_size: 500
Expand Down
73 changes: 69 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_trainer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from mlagents.trainers.trainer_util import load_config, _load_config
from mlagents.trainers.trainer_metrics import TrainerMetrics
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.brain import BrainParameters


@pytest.fixture
Expand Down Expand Up @@ -36,6 +37,10 @@ def dummy_config():
use_curiosity: false
curiosity_strength: 0.0
curiosity_enc_size: 1
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)

Expand Down Expand Up @@ -212,7 +217,7 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
BrainParametersMock.return_value.brain_name = "testbrain"
external_brains = {"testbrain": BrainParametersMock()}

with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
trainer_factory = trainer_util.TrainerFactory(
trainer_config=bad_config,
summaries_dir=summaries_dir,
Expand All @@ -228,8 +233,68 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
trainers[brain_name] = trainer_factory.generate(brain_parameters)


def test_handles_no_default_section():
"""
Make sure the trainer setup handles a missing "default" in the config.
"""
brain_name = "testbrain"
config = dummy_config()
no_default_config = {brain_name: config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
camera_resolutions=[],
vector_action_space_size=[2],
vector_action_descriptions=[],
vector_action_space_type=0,
)

trainer_factory = trainer_util.TrainerFactory(
trainer_config=no_default_config,
summaries_dir="test_dir",
run_id="testrun",
model_path="model_dir",
keep_checkpoints=1,
train_model=True,
load_model=False,
seed=42,
)
trainer_factory.generate(brain_parameters)


def test_raise_if_no_config_for_brain():
"""
Make sure the trainer setup raises a friendlier exception if both "default" and the brain name
are missing from the config.
"""
brain_name = "testbrain"
config = dummy_config()
bad_config = {"some_other_brain": config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
camera_resolutions=[],
vector_action_space_size=[2],
vector_action_descriptions=[],
vector_action_space_type=0,
)

trainer_factory = trainer_util.TrainerFactory(
trainer_config=bad_config,
summaries_dir="test_dir",
run_id="testrun",
model_path="model_dir",
keep_checkpoints=1,
train_model=True,
load_model=False,
seed=42,
)
with pytest.raises(TrainerConfigError):
trainer_factory.generate(brain_parameters)


def test_load_config_missing_file():
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
load_config("thisFileDefinitelyDoesNotExist.yaml")


Expand All @@ -250,6 +315,6 @@ def test_load_config_invalid_yaml():
- not
- parse
"""
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
fp = io.StringIO(file_contents)
_load_config(fp)
36 changes: 22 additions & 14 deletions ml-agents/mlagents/trainers/trainer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, TextIO

from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.ppo.trainer import PPOTrainer
Expand Down Expand Up @@ -80,8 +80,14 @@ def initialize_trainer(
:param multi_gpu: Whether to use multi-GPU training
:return:
"""
trainer_parameters = trainer_config["default"].copy()
brain_name = brain_parameters.brain_name
if "default" not in trainer_config and brain_name not in trainer_config:
raise TrainerConfigError(
f'Trainer config must have either a "default" section, or a section for the brain name ({brain_name}). '
"See config/trainer_config.yaml for an example."
)

trainer_parameters = trainer_config.get("default", {}).copy()
trainer_parameters["summary_path"] = "{basedir}/{name}".format(
basedir=summaries_dir, name=str(run_id) + "_" + brain_name
)
Expand All @@ -96,13 +102,19 @@ def initialize_trainer(
trainer_parameters.update(trainer_config[_brain_key])

trainer: Trainer = None # type: ignore # will be set to one of these, or raise
if trainer_parameters["trainer"] == "offline_bc":
if "trainer" not in trainer_parameters:
raise TrainerConfigError(
f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).'
)
trainer_type = trainer_parameters["trainer"]

if trainer_type == "offline_bc":
raise UnityTrainerException(
"The offline_bc trainer has been removed. To train with demonstrations, "
"please use a PPO or SAC trainer with the GAIL Reward Signal and/or the "
"Behavioral Cloning feature enabled."
)
elif trainer_parameters["trainer"] == "ppo":
elif trainer_type == "ppo":
trainer = PPOTrainer(
brain_parameters,
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length
Expand All @@ -115,7 +127,7 @@ def initialize_trainer(
run_id,
multi_gpu,
)
elif trainer_parameters["trainer"] == "sac":
elif trainer_type == "sac":
trainer = SACTrainer(
brain_parameters,
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length
Expand All @@ -128,10 +140,8 @@ def initialize_trainer(
run_id,
)
else:
raise UnityEnvironmentException(
"The trainer config contains "
"an unknown trainer type for "
"brain {}".format(brain_name)
raise TrainerConfigError(
f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
)
return trainer

Expand All @@ -141,11 +151,9 @@ def load_config(config_path: str) -> Dict[str, Any]:
with open(config_path) as data_file:
return _load_config(data_file)
except IOError:
raise UnityEnvironmentException(
f"Config file could not be found at {config_path}."
)
raise TrainerConfigError(f"Config file could not be found at {config_path}.")
except UnicodeDecodeError:
raise UnityEnvironmentException(
raise TrainerConfigError(
f"There was an error decoding Config file from {config_path}. "
f"Make sure your file is save using UTF-8"
)
Expand All @@ -158,7 +166,7 @@ def _load_config(fp: TextIO) -> Dict[str, Any]:
try:
return yaml.safe_load(fp)
except yaml.parser.ParserError as e:
raise UnityEnvironmentException(
raise TrainerConfigError(
"Error parsing yaml file. Please check for formatting errors. "
"A tool such as http://www.yamllint.com/ can be helpful with this."
) from e

0 comments on commit d5c6ff8

Please sign in to comment.