Develop generalization training #2232


Merged
merged 70 commits into from Jul 25, 2019

Changes from 17 commits

Commits (70)
c650e8a
Adding new command line arguments
sankalp04 Jun 21, 2019
f78d974
Allow generalization training with specified arguments of min_reward …
sankalp04 Jun 28, 2019
643e816
Added sampler_class in mlagents/envs
sankalp04 Jun 28, 2019
d90031b
Change reset parameters based on reward or progress metric
sankalp04 Jul 2, 2019
a4cf9ea
Include error checking in sampling methods
sankalp04 Jul 2, 2019
e0cd1df
Incorporate generalization checks for resetting parameters in take_step
sankalp04 Jul 2, 2019
33fe8ae
Add Sampler error to track errors in the sampler class
sankalp04 Jul 2, 2019
cf0eae1
Add LessonControllerError to track errors in LessonController
sankalp04 Jul 2, 2019
155cd4e
Get rid of dead code and clean up code
sankalp04 Jul 2, 2019
65884b3
Removed check_key and replaced with **param_dict for implicit type ch…
sankalp04 Jul 3, 2019
7b21f29
Instantiate SamplerManager in learn.py instead of trainer_controller
sankalp04 Jul 3, 2019
7a1944d
Made the code cleanup changes - mostly nit changes
sankalp04 Jul 3, 2019
6b3d1c0
Cleanup PPOTrainer call and env_reset
sankalp04 Jul 3, 2019
fee2b14
Example parameter sampling file config
sankalp04 Jul 8, 2019
626b0b7
Updated example sampler config file
sankalp04 Jul 8, 2019
a194379
Remove LessonControllerException
sankalp04 Jul 8, 2019
381ff86
Conduct training without lesson_controller
sankalp04 Jul 8, 2019
e8c1e80
Remove dead code
sankalp04 Jul 8, 2019
6b82251
Suggested stylistic and nit changes made
sankalp04 Jul 8, 2019
436987e
Use seed passed in learn.py instead of in sampler config
sankalp04 Jul 10, 2019
abd4c03
Init commit of sampler tests
sankalp04 Jul 10, 2019
bfb7bd0
Init commit to sampler_class tests
sankalp04 Jul 10, 2019
390d98e
Incorrect file in test folder
sankalp04 Jul 10, 2019
e5423bc
Add tests
sankalp04 Jul 11, 2019
7ecc6d8
Add safety check in case samplers provided with incorrect or insuffic…
sankalp04 Jul 17, 2019
0f302d7
Updated test files with tests for sampler class
sankalp04 Jul 17, 2019
2f1f814
Fixed syntax errors and removed dead code
sankalp04 Jul 17, 2019
23fbff8
Changed exception types in test and imported sampler types for checks
sankalp04 Jul 17, 2019
b9a7f2a
Modify test_learn to accomodate sampler argument
sankalp04 Jul 17, 2019
1ce970c
Fixed tests
sankalp04 Jul 17, 2019
56810c0
Fixed test case of different instantiations of sampler_manager in test
sankalp04 Jul 17, 2019
ea5f4db
Attempt at mocking the sampler_class
sankalp04 Jul 17, 2019
68bdb59
Removed unused function trailing from use of lesson_controller
sankalp04 Jul 17, 2019
4d7c79f
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 18, 2019
9d1acc5
Resolved conflicts, finished type annotations, nit changes and accomo…
sankalp04 Jul 19, 2019
5063818
Trainer controller argument name changed so tests are changed accordi…
sankalp04 Jul 19, 2019
11e2c78
Modified tests
sankalp04 Jul 19, 2019
165ab57
Remove extra arg in test case
sankalp04 Jul 19, 2019
d5b0838
Fixed pytest failure in checks
sankalp04 Jul 19, 2019
c3b372c
Fixed style from black
sankalp04 Jul 19, 2019
bb84831
Fixed pending style issues
sankalp04 Jul 19, 2019
c441877
Fixed spacing issues
sankalp04 Jul 19, 2019
8f1c7a8
Fixed missing type annotation
sankalp04 Jul 19, 2019
c8493c0
Check steps for none for style check
sankalp04 Jul 19, 2019
ecdc734
Revert "Check steps for none for style check"
sankalp04 Jul 19, 2019
d4c9163
Check steps for none for type checking
sankalp04 Jul 19, 2019
337ef3e
Check lesson duration for None for style check
sankalp04 Jul 19, 2019
e285039
Clean test code to instantiate tests within function
sankalp04 Jul 19, 2019
ddc2e60
Init commit to generalization docs
sankalp04 Jul 23, 2019
e8ad471
Change key to resampling-interval
sankalp04 Jul 23, 2019
e6a1029
Remove 3D ball curves and fix formatting
sankalp04 Jul 23, 2019
30e0727
Add comments and finish annotations
sankalp04 Jul 23, 2019
603c5c8
Remove attribution between tests and environments
sankalp04 Jul 23, 2019
fb74bbe
Rename lesson-length to resampling_interval
sankalp04 Jul 23, 2019
3a5f033
Modify test cases to reflect variable name changes
sankalp04 Jul 23, 2019
ce152ee
Fix bug
sankalp04 Jul 23, 2019
b5b97bd
Remove extra file imports in test_learn
sankalp04 Jul 23, 2019
f545ed1
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 23, 2019
0e703e2
Modify test to include sampler_manager
sankalp04 Jul 23, 2019
408b32b
Fix style changes
sankalp04 Jul 23, 2019
03de771
Fixed formatting and style part II
sankalp04 Jul 23, 2019
c47ccfd
Fixed more style
sankalp04 Jul 23, 2019
44dffd7
Including missed file in style fixes
sankalp04 Jul 23, 2019
79f2079
Finish annotations
sankalp04 Jul 23, 2019
092e370
Address comments for docs
sankalp04 Jul 23, 2019
bb8654f
Fix black's bare except error
sankalp04 Jul 23, 2019
188dec7
Fixed type annotations
sankalp04 Jul 24, 2019
3b63502
Fixed final type annotation error
sankalp04 Jul 24, 2019
d93f79e
Fix training docs
sankalp04 Jul 24, 2019
f79263b
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 24, 2019
19 changes: 19 additions & 0 deletions config/generalize_test.yaml
@@ -0,0 +1,19 @@
episode-length: 1000

mass:
    sampler-type: "uniform"
    min_value: 0.5
    max_value: 10
    seed: 5

gravity:
    sampler-type: "uniform"
    min_value: 7
    max_value: 12
    seed: 5

scale:
    sampler-type: "uniform"
    min_value: 0.75
    max_value: 3
    seed: 5
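
Note: this file only exercises the "uniform" sampler. For illustration, a "multirange_uniform" entry might look like the sketch below; the intervals key matches MultiRangeUniformSampler's constructor argument in this PR, but the values here are made up:

mass:
    sampler-type: "multirange_uniform"
    intervals: [[0.5, 2.0], [5.0, 10.0]]
    seed: 5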
6 changes: 6 additions & 0 deletions ml-agents-envs/mlagents/envs/exception.py
@@ -26,6 +26,12 @@ class UnityActionException(UnityException):

pass

class SamplerException(UnityException):
"""
Related to errors with the sampler actions.
"""

pass

class UnityTimeOutException(UnityException):
"""
99 changes: 99 additions & 0 deletions ml-agents-envs/mlagents/envs/sampler_class.py
@@ -0,0 +1,99 @@
import numpy as np
from functools import *
from abc import ABC, abstractmethod

from .exception import SamplerException

class SamplerException(Exception):
Contributor: BTW, this should inherit from UnityException, as with all of the other exception classes.

Contributor Author: Instance of dead code I was using for testing; removed in the next update.

pass

class Sampler(ABC):

@abstractmethod
def sample_parameter(self, *args, **kwargs):
Contributor: sample_parameter shouldn't take args and kwargs (or should maybe take a context arg if you're designing for the future).

Contributor: Type annotation for the return type: -> float:

Contributor: And the same annotation for all the implementations.

pass


class UniformSampler(Sampler):
# kwargs acts as a sink for extra unneeded args
def __init__(self, min_value, max_value, **kwargs):
self.min_value = min_value
self.max_value = max_value

def sample_parameter(self):
return np.random.uniform(self.min_value, self.max_value)


class MultiRangeUniformSampler(Sampler):
def __init__(self, intervals, **kwargs):
Contributor: I think you should add more comments on these classes. You need to define what each sampler does (uniform vs. normal, for instance) as well as the parameters' types and purpose. For example: intervals is a list of floats; this needs to be specified somewhere.

self.intervals = intervals
# Measure the length of the intervals
self.interval_lengths = list(map(lambda x: abs(x[1] - x[0]), self.intervals))
Contributor: I find list comprehensions more readable than maps and reduces (and reduce was also moved out of the builtins in Python 3). I think this works:

self.interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
self.cum_interval_length = sum(self.interval_lengths)
self.interval_weights = [x / self.cum_interval_length for x in self.interval_lengths]

and I also think you can make interval_lengths and cum_interval_length local variables.

Contributor Author: Converted them to local variables.

# Cumulative size of the intervals
self.cum_interval_length = reduce(lambda x,y: x + y, self.interval_lengths, 0)
# Assign weights to an interval proportionate to the interval size
self.interval_weights = list(map(lambda x: x/self.cum_interval_length, self.interval_lengths))


def sample_parameter(self):
cur_min, cur_max = self.intervals[np.random.choice(len(self.intervals), p=self.interval_weights)]
return np.random.uniform(cur_min, cur_max)


class GaussianSampler(Sampler):
def __init__(self, mean, var, **kwargs):
self.mean = mean
self.var = var

def sample_parameter(self):
return np.random.normal(self.mean, self.var)
Contributor: variance != standard deviation

Contributor Author: Good catch, let me fix the naming.
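
For reference, a minimal sketch of the rename being discussed; the parameter name st_dev is an assumption, the point being that np.random.normal's second argument (scale) expects a standard deviation, not a variance:

import numpy as np

class GaussianSampler(Sampler):
    # "st_dev" is a hypothetical rename of "var"; np.random.normal's
    # scale argument is a standard deviation.
    def __init__(self, mean, st_dev, **kwargs):
        self.mean = mean
        self.st_dev = st_dev

    def sample_parameter(self):
        return np.random.normal(self.mean, self.st_dev)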



# To introduce new sampling methods, just need to 'register' them to this sampler factory
Contributor: This comment should be formatted consistently with the other comments of the repo.

class SamplerFactory:
NAME_TO_CLASS = {
Contributor: Damn, I like this.

"uniform": UniformSampler,
"gaussian": GaussianSampler,
"multirange_uniform": MultiRangeUniformSampler,
}

@staticmethod
def register_sampler(name, sampler_cls):
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls

@staticmethod
def init_sampler_class(name, param_dict):
if name not in SamplerFactory.NAME_TO_CLASS:
raise SamplerException(
name + " sampler is not registered in the SamplerFactory."
" Use the register_sample method to register the string"
" associated to your sampler in the SamplerFactory."
)
sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
return sampler_cls(**param_dict)
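
Since registration is just a dictionary write, plugging in a custom distribution is a one-liner. A sketch under stated assumptions — TriangularSampler and its parameters are invented for illustration and are not part of this PR:

import numpy as np

class TriangularSampler(Sampler):
    # Hypothetical sampler, shown only to demonstrate registration.
    def __init__(self, low, mode, high, **kwargs):
        self.low = low
        self.mode = mode
        self.high = high

    def sample_parameter(self):
        return np.random.triangular(self.low, self.mode, self.high)

SamplerFactory.register_sampler("triangular", TriangularSampler)
sampler = SamplerFactory.init_sampler_class(
    "triangular", {"low": 0.5, "mode": 1.0, "high": 3.0}
)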


class SamplerManager:
def __init__(self, reset_param_dict):
self.reset_param_dict = reset_param_dict
self.samplers = {}
if reset_param_dict == None:
return
for param_name, cur_param_dict in self.reset_param_dict.items():
if "sampler-type" not in cur_param_dict:
raise SamplerException(
"'sampler_type' argument hasn't been supplied for the {0} parameter".format(param_name)
)
sampler_name = cur_param_dict.pop("sampler-type")
param_sampler = SamplerFactory.init_sampler_class(sampler_name, cur_param_dict)

self.samplers[param_name] = param_sampler

def sample_all(self):
Contributor: Suggested change:
-    def sample_all(self):
+    def sample_all(self) -> Dict[str, float]:

res = {}
if self.samplers == {}:
Contributor: I'm not sure this is what you want here; it's definitely more pythonic to do if not self.samplers:. But even then, you don't need this if at all, since a for loop over an empty dict simply iterates zero times.

pass
else:
for param_name, param_sampler in list(self.samplers.items()):
Contributor: Don't need the list() here (only if we were trying for some sort of Python 2+3 compatibility).

res[param_name] = param_sampler.sample_parameter()
return res
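
Folding the three review suggestions above together, sample_all could look something like this sketch (assuming Dict is imported from typing; this is not necessarily the code as merged):

from typing import Dict

def sample_all(self) -> Dict[str, float]:
    # Iterating an empty dict yields nothing, so no emptiness check is needed.
    res = {}
    for param_name, param_sampler in self.samplers.items():
        res[param_name] = param_sampler.sample_parameter()
    return res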
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/exception.py
@@ -23,3 +23,6 @@ class MetaCurriculumError(TrainerError):
"""
Any error related to the configuration of a metacurriculum.
"""

Contributor: Extra line?

pass

56 changes: 39 additions & 17 deletions ml-agents/mlagents/trainers/learn.py
@@ -16,7 +16,8 @@
from mlagents.trainers.exception import TrainerError
from mlagents.trainers import MetaCurriculumError, MetaCurriculum
from mlagents.envs import UnityEnvironment
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.exception import UnityEnvironmentException, SamplerException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.subprocess_environment import SubprocessUnityEnvironment

@@ -52,6 +53,8 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
fast_simulation = not bool(run_options["--slow"])
no_graphics = run_options["--no-graphics"]
trainer_config_path = run_options["<trainer-config-path>"]
sampler_file_path = run_options["--sampler"] if run_options["--sampler"] != "None" else None

# Recognize and use docker volume if one is passed as an argument
if not docker_target_name:
model_path = "./models/{run_id}-{sub_id}".format(run_id=run_id, sub_id=sub_id)
@@ -73,6 +76,21 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
docker_target_name=docker_target_name
)

sampler = None
Contributor: nit: I think sampler_config is a better name for this; sampler implies it's an instance of a Sampler.

lesson_length = None
if sampler_file_path is not None:
sampler = load_config(sampler_file_path)
if ("episode-length") in sampler:
lesson_length = sampler["episode-length"]
del sampler["episode-length"]
else:
raise SamplerException(
"Episode Length was not specified in the sampler file."
" Please specify it with the 'episode-length' key in the sampler config file."
)
sampler_manager = SamplerManager(sampler)


trainer_config = load_config(trainer_config_path)
env_factory = create_environment_factory(
env_path,
@@ -84,6 +102,7 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
env = SubprocessUnityEnvironment(env_factory, num_envs)
maybe_meta_curriculum = try_create_meta_curriculum(curriculum_folder, env)


# Create controller and begin training.
tc = TrainerController(
model_path,
@@ -98,6 +117,8 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
env.external_brains,
run_seed,
fast_simulation,
sampler_manager,
lesson_length,
)

# Signal that environment has been launched.
@@ -242,22 +263,23 @@ def main():
mlagents-learn --help

Options:
--env=<file> Name of the Unity executable [default: None].
--curriculum=<directory> Curriculum json directory for environment [default: None].
--keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
--lesson=<n> Start learning from this lesson [default: 0].
--load Whether to load the model or randomly initialize [default: False].
--run-id=<path> The directory name for model and summary statistics [default: ppo].
--num-runs=<n> Number of concurrent training sessions [default: 1].
--save-freq=<n> Frequency at which to save model [default: 50000].
--seed=<n> Random seed used for training [default: -1].
--slow Whether to run the game at training speed [default: False].
--train Whether to train model, or only run inference [default: False].
--base-port=<n> Base port for environment communication [default: 5005].
--num-envs=<n> Number of parallel environments to use for training [default: 1]
--docker-target-name=<dt> Docker volume to store training-specific files [default: None].
--no-graphics Whether to run the environment in no-graphics mode [default: False].
--debug Whether to run ML-Agents in debug mode with detailed logging [default: False].
--env=<file> Name of the Unity executable [default: None].
--curriculum=<directory> Curriculum json directory for environment [default: None].
--keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
--lesson=<n> Start learning from this lesson [default: 0].
--load Whether to load the model or randomly initialize [default: False].
--run-id=<path> The directory name for model and summary statistics [default: ppo].
--num-runs=<n> Number of concurrent training sessions [default: 1].
--save-freq=<n> Frequency at which to save model [default: 50000].
--seed=<n> Random seed used for training [default: -1].
--slow Whether to run the game at training speed [default: False].
--train Whether to train model, or only run inference [default: False].
--base-port=<n> Base port for environment communication [default: 5005].
--num-envs=<n> Number of parallel environments to use for training [default: 1]
--docker-target-name=<dt> Docker volume to store training-specific files [default: None].
--no-graphics Whether to run the environment in no-graphics mode [default: False].
--sampler=<directory> Reset parameter yaml directory for sampling of environment reset parameters [default: None].
Contributor: Is this a directory or a file? In the code, it looks like a file.

--debug Whether to run ML-Agents in debug mode with detailed logging [default: False].
"""

options = docopt(_USAGE)
76 changes: 56 additions & 20 deletions ml-agents/mlagents/trainers/trainer_controller.py
@@ -22,7 +22,9 @@
from mlagents.trainers.meta_curriculum import MetaCurriculum



class TrainerController(object):
# Type of reset_param_dict unspecified as typing library does not support heterogeneous dictionary types
Contributor: You should be able to use Dict[str, Any] for heterogeneous values (or Dict[str, Union[type1, type2]] if you know the possible types). But I think the proto specifies string keys and float values.

Contributor Author: Trailing comment from old code, no longer relevant here, but the type information is good to know.

def __init__(
self,
model_path: str,
@@ -37,6 +39,8 @@ def __init__(
external_brains: Dict[str, BrainParameters],
training_seed: int,
fast_simulation: bool,
sampler_manager,
Contributor (chriselion, Jul 8, 2019): Type annotation here: sampler_manager: SamplerManager,

Contributor Author: Added the type annotation.

lesson_length: Optional[int],
):
"""
:param model_path: Path to save the model.
@@ -50,6 +54,8 @@ def __init__(
:param lesson: Start learning from this lesson.
:param external_brains: dictionary of external brain names to BrainInfo objects.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
:param sampler_manager: SamplerManager object which stores information about samplers to use for the reset parameters.
:param lesson_length: Specifies number of steps after which reset parameters are resampled.
"""

self.model_path = model_path
@@ -72,6 +78,8 @@ def __init__(
self.fast_simulation = fast_simulation
np.random.seed(self.seed)
tf.set_random_seed(self.seed)
self.sampler_manager = sampler_manager
self.lesson_length = lesson_length
Contributor: A bit uncomfortable with multiple things called lesson_length. This lesson length is different from the one used by meta_curriculum, right? Is it the same as the episode-length defined in the YAML file?

Contributor Author: This is the same as the episode-length in the yaml file; not sure episode length would be appropriate either. Any thoughts about lesson_duration instead?

Contributor: I would use something that doesn't involve the word lesson (curriculum) or episode (the time until Done()). Maybe resampling_interval or something similar? Whatever you choose, it should be the same in the yaml and in the code.


def _get_measure_vals(self):
if self.meta_curriculum:
@@ -90,6 +98,7 @@ def _get_measure_vals(self):
measure_val = np.mean(self.trainers[brain_name].reward_buffer)
brain_names_to_measure_vals[brain_name] = measure_val
return brain_names_to_measure_vals

else:
return None

@@ -167,13 +176,16 @@ def initialize_trainers(self, trainer_config: Dict[str, Dict[str, str]]):
self.run_id,
)
elif trainer_parameters_dict[brain_name]["trainer"] == "ppo":
# Find lesson length based on the form of learning
if self.meta_curriculum:
lesson_length = self.meta_curriculum.brains_to_curriculums[
brain_name].min_lesson_length
else:
lesson_length = 0

self.trainers[brain_name] = PPOTrainer(
self.external_brains[brain_name],
self.meta_curriculum.brains_to_curriculums[
brain_name
].min_lesson_length
if self.meta_curriculum
else 0,
lesson_length,
trainer_parameters_dict[brain_name],
self.train_model,
self.load_model,
@@ -203,20 +215,30 @@ def _create_model_path(model_path):
"permissions are set correctly.".format(model_path)
)

@staticmethod
def _check_reset_params(reset_params, new_config):
Contributor: Just checking, where is this method called?

Contributor Author: Code trailing from the use of lesson_controller; not needed anymore, so it is removed in the next update.

for k in new_config:
if (k in reset_params) and (isinstance(config[k], (int, float))):
continue
elif not isinstance(new_config[k], (int, float)):
raise UnityEnvironmentException(
"The parameter '{0}'' generated by the sampler doesn't exist in this environment.".format(
k
)
)

def _reset_env(self, env: BaseUnityEnvironment):
"""Resets the environment.

Returns:
A Data structure corresponding to the initial reset state of the
environment.
"""
if self.meta_curriculum is not None:
return env.reset(
train_mode=self.fast_simulation,
config=self.meta_curriculum.get_config(),
)
else:
return env.reset(train_mode=self.fast_simulation)
sampled_reset_param = self.sampler_manager.sample_all()
new_meta_curriculum_config = (self.meta_curriculum.get_config()
if self.meta_curriculum else {})
sampled_reset_param.update(new_meta_curriculum_config)
Contributor: Does this mean meta_curriculum would override sampler_manager if provided? I'd just make sure this is clear in the documentation, or say that you have to use one or the other.

Contributor: It only overrides if they conflict on a key. One of them has to take priority; not sure which one it should be, though.

Contributor Author: meta_curriculum does take priority over generalization training, as this made the most sense.

return env.reset(train_mode = self.fast_simulation, config = sampled_reset_param)
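
The priority discussed above falls out of dict.update: keys in new_meta_curriculum_config overwrite matching keys in sampled_reset_param. A tiny sketch with made-up values:

sampled = {"gravity": 9.1, "mass": 2.0}  # from the samplers
meta = {"gravity": 9.8}                  # from meta_curriculum
sampled.update(meta)                     # meta_curriculum wins on key conflicts
print(sampled)                           # {'gravity': 9.8, 'mass': 2.0}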

def start_learning(self, env: BaseUnityEnvironment, trainer_config):
# TODO: Should be able to start learning at different lesson numbers
@@ -263,6 +285,22 @@ def start_learning(self, env: BaseUnityEnvironment, trainer_config):
self._write_training_metrics()
self._export_graph()

def end_trainer_episodes(self, env, lessons_incremented):
Contributor: Type annotations. Suggested change:
-    def end_trainer_episodes(self, env, lessons_incremented):
+    def end_trainer_episodes(self, env: BaseUnityEnvironment, lessons_incremented: Dict[str, bool]) -> (whatever _reset_env returns):

curr_info = self._reset_env(env)
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
for brain_name, changed in lessons_incremented.items():
if changed:
self.trainers[brain_name].reward_buffer.clear()
Contributor: If we're doing a generalization reset, do we need to clear the reward buffers? I don't think that will currently happen.

Contributor Author (sankalp04, Jul 19, 2019): Confirming: there is no need to clear reward buffers in the case of a generalization reset, as generalization doesn't use reward buffers to decide when to reset.

Contributor: As discussed offline, please add a comment to this effect; it's not obvious from the code.

return curr_info

def check_empty_sampler_manager(self):
"""
If self.samplers is empty, then bool of it returns false, indicating
there is no sampler manager.
"""
return not bool(self.sampler_manager.samplers)
Contributor: How about making this a method or property on SamplerManager instead?

Contributor: +1 on making check_empty_sampler_manager part of SamplerManager.

Contributor Author: Made the check part of the SamplerManager class.
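
A minimal sketch of what moving the check onto SamplerManager could look like (the method name is_empty is an assumption, not necessarily what was merged):

class SamplerManager:
    ...
    def is_empty(self) -> bool:
        # True when no samplers were configured for this run.
        return not bool(self.samplers)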


def take_step(self, env: BaseUnityEnvironment, curr_info: AllBrainInfo):
if self.meta_curriculum:
# Get the sizes of the reward buffers.
@@ -279,14 +317,12 @@ def take_step(self, env: BaseUnityEnvironment, curr_info: AllBrainInfo):

# If any lessons were incremented or the environment is
# ready to be reset
if self.meta_curriculum and any(lessons_incremented.values()):
curr_info = self._reset_env(env)
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
for brain_name, changed in lessons_incremented.items():
if changed:
self.trainers[brain_name].reward_buffer.clear()

if ( ((self.meta_curriculum) and any(lessons_incremented.values()))
Contributor: This is way too complicated. If you're actually checking what the comment says, then split it up into something like:

lessons_were_incremented = ...
ready_for_reset = ...
if lessons_were_incremented or ready_for_reset:
    # do stuff

(Also watch out for the modulo by zero.)

Contributor Author: Added a check to ensure that global_step isn't 0, for modulo safety.

or ( (not self.check_empty_sampler_manager()) and (self.global_step % self.lesson_length == 0)
and (self.global_step != 0)) ):
curr_info = self.end_trainer_episodes(env, lessons_incremented)
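
Spelling out the reviewer's suggested refactor as a sketch (the variable names follow the comment above; this is not the code as merged):

lessons_were_incremented = bool(self.meta_curriculum) and any(
    lessons_incremented.values()
)
ready_for_reset = (
    not self.check_empty_sampler_manager()
    and self.global_step != 0                       # guard against modulo by zero
    and self.global_step % self.lesson_length == 0  # time to resample
)
if lessons_were_incremented or ready_for_reset:
    curr_info = self.end_trainer_episodes(env, lessons_incremented)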


# Decide and take an action
take_action_vector = {}
take_action_memories = {}