diff --git a/ml-agents/mlagents/trainers/exception.py b/ml-agents/mlagents/trainers/exception.py index d4e80646c9..2788157588 100644 --- a/ml-agents/mlagents/trainers/exception.py +++ b/ml-agents/mlagents/trainers/exception.py @@ -11,6 +11,14 @@ class TrainerError(Exception): pass +class TrainerConfigError(Exception): + """ + Any error related to the configuration of trainers in the ML-Agents Toolkit. + """ + + pass + + class CurriculumError(TrainerError): """ Any error related to training with a curriculum. diff --git a/ml-agents/mlagents/trainers/tests/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/test_simple_rl.py index cbab4c3c59..8d100d2057 100644 --- a/ml-agents/mlagents/trainers/tests/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/test_simple_rl.py @@ -128,8 +128,8 @@ def close(self): pass -PPO_CONFIG = """ - default: +PPO_CONFIG = f""" + {BRAIN_NAME}: trainer: ppo batch_size: 16 beta: 5.0e-3 @@ -153,8 +153,8 @@ def close(self): gamma: 0.99 """ -SAC_CONFIG = """ - default: +SAC_CONFIG = f""" + {BRAIN_NAME}: trainer: sac batch_size: 8 buffer_size: 500 diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_util.py b/ml-agents/mlagents/trainers/tests/test_trainer_util.py index 03fd4f397c..1f514bfdfa 100644 --- a/ml-agents/mlagents/trainers/tests/test_trainer_util.py +++ b/ml-agents/mlagents/trainers/tests/test_trainer_util.py @@ -7,7 +7,8 @@ from mlagents.trainers.trainer_util import load_config, _load_config from mlagents.trainers.trainer_metrics import TrainerMetrics from mlagents.trainers.ppo.trainer import PPOTrainer -from mlagents.envs.exception import UnityEnvironmentException +from mlagents.trainers.exception import TrainerConfigError +from mlagents.trainers.brain import BrainParameters @pytest.fixture @@ -36,6 +37,10 @@ def dummy_config(): use_curiosity: false curiosity_strength: 0.0 curiosity_enc_size: 1 + reward_signals: + extrinsic: + strength: 1.0 + gamma: 0.99 """ ) @@ -212,7 +217,7 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock): BrainParametersMock.return_value.brain_name = "testbrain" external_brains = {"testbrain": BrainParametersMock()} - with pytest.raises(UnityEnvironmentException): + with pytest.raises(TrainerConfigError): trainer_factory = trainer_util.TrainerFactory( trainer_config=bad_config, summaries_dir=summaries_dir, @@ -228,8 +233,68 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock): trainers[brain_name] = trainer_factory.generate(brain_parameters) +def test_handles_no_default_section(): + """ + Make sure the trainer setup handles a missing "default" in the config. + """ + brain_name = "testbrain" + config = dummy_config() + no_default_config = {brain_name: config["default"]} + brain_parameters = BrainParameters( + brain_name=brain_name, + vector_observation_space_size=1, + camera_resolutions=[], + vector_action_space_size=[2], + vector_action_descriptions=[], + vector_action_space_type=0, + ) + + trainer_factory = trainer_util.TrainerFactory( + trainer_config=no_default_config, + summaries_dir="test_dir", + run_id="testrun", + model_path="model_dir", + keep_checkpoints=1, + train_model=True, + load_model=False, + seed=42, + ) + trainer_factory.generate(brain_parameters) + + +def test_raise_if_no_config_for_brain(): + """ + Make sure the trainer setup raises a friendlier exception if both "default" and the brain name + are missing from the config. + """ + brain_name = "testbrain" + config = dummy_config() + bad_config = {"some_other_brain": config["default"]} + brain_parameters = BrainParameters( + brain_name=brain_name, + vector_observation_space_size=1, + camera_resolutions=[], + vector_action_space_size=[2], + vector_action_descriptions=[], + vector_action_space_type=0, + ) + + trainer_factory = trainer_util.TrainerFactory( + trainer_config=bad_config, + summaries_dir="test_dir", + run_id="testrun", + model_path="model_dir", + keep_checkpoints=1, + train_model=True, + load_model=False, + seed=42, + ) + with pytest.raises(TrainerConfigError): + trainer_factory.generate(brain_parameters) + + def test_load_config_missing_file(): - with pytest.raises(UnityEnvironmentException): + with pytest.raises(TrainerConfigError): load_config("thisFileDefinitelyDoesNotExist.yaml") @@ -250,6 +315,6 @@ def test_load_config_invalid_yaml(): - not - parse """ - with pytest.raises(UnityEnvironmentException): + with pytest.raises(TrainerConfigError): fp = io.StringIO(file_contents) _load_config(fp) diff --git a/ml-agents/mlagents/trainers/trainer_util.py b/ml-agents/mlagents/trainers/trainer_util.py index a9bd65b7ed..64d1e2e0f5 100644 --- a/ml-agents/mlagents/trainers/trainer_util.py +++ b/ml-agents/mlagents/trainers/trainer_util.py @@ -2,7 +2,7 @@ from typing import Any, Dict, TextIO from mlagents.trainers.meta_curriculum import MetaCurriculum -from mlagents.envs.exception import UnityEnvironmentException +from mlagents.trainers.exception import TrainerConfigError from mlagents.trainers.trainer import Trainer, UnityTrainerException from mlagents.trainers.brain import BrainParameters from mlagents.trainers.ppo.trainer import PPOTrainer @@ -80,8 +80,14 @@ def initialize_trainer( :param multi_gpu: Whether to use multi-GPU training :return: """ - trainer_parameters = trainer_config["default"].copy() brain_name = brain_parameters.brain_name + if "default" not in trainer_config and brain_name not in trainer_config: + raise TrainerConfigError( + f'Trainer config must have either a "default" section, or a section for the brain name ({brain_name}). ' + "See config/trainer_config.yaml for an example." + ) + + trainer_parameters = trainer_config.get("default", {}).copy() trainer_parameters["summary_path"] = "{basedir}/{name}".format( basedir=summaries_dir, name=str(run_id) + "_" + brain_name ) @@ -96,13 +102,19 @@ def initialize_trainer( trainer_parameters.update(trainer_config[_brain_key]) trainer: Trainer = None # type: ignore # will be set to one of these, or raise - if trainer_parameters["trainer"] == "offline_bc": + if "trainer" not in trainer_parameters: + raise TrainerConfigError( + f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).' + ) + trainer_type = trainer_parameters["trainer"] + + if trainer_type == "offline_bc": raise UnityTrainerException( "The offline_bc trainer has been removed. To train with demonstrations, " "please use a PPO or SAC trainer with the GAIL Reward Signal and/or the " "Behavioral Cloning feature enabled." ) - elif trainer_parameters["trainer"] == "ppo": + elif trainer_type == "ppo": trainer = PPOTrainer( brain_parameters, meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length @@ -115,7 +127,7 @@ def initialize_trainer( run_id, multi_gpu, ) - elif trainer_parameters["trainer"] == "sac": + elif trainer_type == "sac": trainer = SACTrainer( brain_parameters, meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length @@ -128,10 +140,8 @@ def initialize_trainer( run_id, ) else: - raise UnityEnvironmentException( - "The trainer config contains " - "an unknown trainer type for " - "brain {}".format(brain_name) + raise TrainerConfigError( + f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}' ) return trainer @@ -141,11 +151,9 @@ def load_config(config_path: str) -> Dict[str, Any]: with open(config_path) as data_file: return _load_config(data_file) except IOError: - raise UnityEnvironmentException( - f"Config file could not be found at {config_path}." - ) + raise TrainerConfigError(f"Config file could not be found at {config_path}.") except UnicodeDecodeError: - raise UnityEnvironmentException( + raise TrainerConfigError( f"There was an error decoding Config file from {config_path}. " f"Make sure your file is save using UTF-8" ) @@ -158,7 +166,7 @@ def _load_config(fp: TextIO) -> Dict[str, Any]: try: return yaml.safe_load(fp) except yaml.parser.ParserError as e: - raise UnityEnvironmentException( + raise TrainerConfigError( "Error parsing yaml file. Please check for formatting errors. " "A tool such as http://www.yamllint.com/ can be helpful with this." ) from e