
Develop generalization training #2232

Merged: 70 commits, Jul 25, 2019
Commits
c650e8a
Adding new command line arguments
sankalp04 Jun 21, 2019
f78d974
Allow generalization training with specified arguments of min_reward …
sankalp04 Jun 28, 2019
643e816
Added sampler_class in mlagents/envs
sankalp04 Jun 28, 2019
d90031b
Change reset parameters based on reward or progress metric
sankalp04 Jul 2, 2019
a4cf9ea
Include error checking in sampling methods
sankalp04 Jul 2, 2019
e0cd1df
Incorporate generalization checks for resetting parameters in take_step
sankalp04 Jul 2, 2019
33fe8ae
Add Sampler error to track errors in the sampler class
sankalp04 Jul 2, 2019
cf0eae1
Add LessonControllerError to track errors in LessonController
sankalp04 Jul 2, 2019
155cd4e
Get rid of dead code and clean up code
sankalp04 Jul 2, 2019
65884b3
Removed check_key and replaced with **param_dict for implicit type ch…
sankalp04 Jul 3, 2019
7b21f29
Instantiate SamplerManager in learn.py instead of trainer_controller
sankalp04 Jul 3, 2019
7a1944d
Made the code cleanup changes - mostly nit changes
sankalp04 Jul 3, 2019
6b3d1c0
Cleanup PPOTrainer call and env_reset
sankalp04 Jul 3, 2019
fee2b14
Example parameter sampling file config
sankalp04 Jul 8, 2019
626b0b7
Updated example sampler config file
sankalp04 Jul 8, 2019
a194379
Remove LessonControllerException
sankalp04 Jul 8, 2019
381ff86
Conduct training without lesson_controller
sankalp04 Jul 8, 2019
e8c1e80
Remove dead code
sankalp04 Jul 8, 2019
6b82251
Suggested stylistic and nit changes made
sankalp04 Jul 8, 2019
436987e
Use seed passed in learn.py instead of in sampler config
sankalp04 Jul 10, 2019
abd4c03
Init commit of sampler tests
sankalp04 Jul 10, 2019
bfb7bd0
Init commit to sampler_class tests
sankalp04 Jul 10, 2019
390d98e
Incorrect file in test folder
sankalp04 Jul 10, 2019
e5423bc
Add tests
sankalp04 Jul 11, 2019
7ecc6d8
Add safety check in case samplers provided with incorrect or insuffic…
sankalp04 Jul 17, 2019
0f302d7
Updated test files with tests for sampler class
sankalp04 Jul 17, 2019
2f1f814
Fixed syntax errors and removed dead code
sankalp04 Jul 17, 2019
23fbff8
Changed exception types in test and imported sampler types for checks
sankalp04 Jul 17, 2019
b9a7f2a
Modify test_learn to accomodate sampler argument
sankalp04 Jul 17, 2019
1ce970c
Fixed tests
sankalp04 Jul 17, 2019
56810c0
Fixed test case of different instantiations of sampler_manager in test
sankalp04 Jul 17, 2019
ea5f4db
Attempt at mocking the sampler_class
sankalp04 Jul 17, 2019
68bdb59
Removed unused function trailing from use of lesson_controller
sankalp04 Jul 17, 2019
4d7c79f
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 18, 2019
9d1acc5
Resolved conflicts, finished type annotations, nit changes and accomo…
sankalp04 Jul 19, 2019
5063818
Trainer controller argument name changed so tests are changed accordi…
sankalp04 Jul 19, 2019
11e2c78
Modified tests
sankalp04 Jul 19, 2019
165ab57
Remove extra arg in test case
sankalp04 Jul 19, 2019
d5b0838
Fixed pytest failure in checks
sankalp04 Jul 19, 2019
c3b372c
Fixed style from black
sankalp04 Jul 19, 2019
bb84831
Fixed pending style issues
sankalp04 Jul 19, 2019
c441877
Fixed spacing issues
sankalp04 Jul 19, 2019
8f1c7a8
Fixed missing type annotation
sankalp04 Jul 19, 2019
c8493c0
Check steps for none for style check
sankalp04 Jul 19, 2019
ecdc734
Revert "Check steps for none for style check"
sankalp04 Jul 19, 2019
d4c9163
Check steps for none for type checking
sankalp04 Jul 19, 2019
337ef3e
Check lesson duration for None for style check
sankalp04 Jul 19, 2019
e285039
Clean test code to instantiate tests within function
sankalp04 Jul 19, 2019
ddc2e60
Init commit to generalization docs
sankalp04 Jul 23, 2019
e8ad471
Change key to resampling-interval
sankalp04 Jul 23, 2019
e6a1029
Remove 3D ball curves and fix formatting
sankalp04 Jul 23, 2019
30e0727
Add comments and finish annotations
sankalp04 Jul 23, 2019
603c5c8
Remove attribution between tests and environments
sankalp04 Jul 23, 2019
fb74bbe
Rename lesson-length to resampling_interval
sankalp04 Jul 23, 2019
3a5f033
Modify test cases to reflect variable name changes
sankalp04 Jul 23, 2019
ce152ee
Fix bug
sankalp04 Jul 23, 2019
b5b97bd
Remove extra file imports in test_learn
sankalp04 Jul 23, 2019
f545ed1
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 23, 2019
0e703e2
Modify test to include sampler_manager
sankalp04 Jul 23, 2019
408b32b
Fix style changes
sankalp04 Jul 23, 2019
03de771
Fixed formatting and style part II
sankalp04 Jul 23, 2019
c47ccfd
Fixed more style
sankalp04 Jul 23, 2019
44dffd7
Including missed file in style fixes
sankalp04 Jul 23, 2019
79f2079
Finish annotations
sankalp04 Jul 23, 2019
092e370
Address comments for docs
sankalp04 Jul 23, 2019
bb8654f
Fix black's bare except error
sankalp04 Jul 23, 2019
188dec7
Fixed type annotations
sankalp04 Jul 24, 2019
3b63502
Fixed final type annotation error
sankalp04 Jul 24, 2019
d93f79e
Fix training docs
sankalp04 Jul 24, 2019
f79263b
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 24, 2019
16 changes: 16 additions & 0 deletions config/generalize_test.yaml
@@ -0,0 +1,16 @@
resampling-interval: 5000

mass:
    sampler-type: "uniform"
    min_value: 0.5
    max_value: 10

gravity:
    sampler-type: "uniform"
    min_value: 7
    max_value: 12

scale:
    sampler-type: "uniform"
    min_value: 0.75
    max_value: 3
123 changes: 123 additions & 0 deletions docs/Training-Generalization-Learning.md
@@ -0,0 +1,123 @@
# Training Generalized Reinforcement Learning Agents
> **Review comment** (@ervteng, Jul 24, 2019): nit: capitalize agents


Reinforcement learning has a rather unique setup as opposed to supervised and
unsupervised learning. Agents here are trained and tested on the exact same
environment, which is analogous to a model being trained and tested on an
identical dataset in supervised learning. This setting results in overfitting:
the inability of the agent to generalize to slight tweaks or variations in the
environment. This is problematic when environments are randomly
instantiated with varying properties. To make agents more robust, we train an
> **Review comment** (Contributor): Would rephrase: "To make agents more robust, one approach is to train over multiple variations of the environment."

agent over multiple variations of the environment. The agent is trained with
the intent that it learns to maintain a minimum performance regardless of the
environment variant and that it generalizes to maintain this in unseen future
> **Review comment** (Contributor): Would rephrase: "... and that it generalizes to be robust to future unseen variants of the environment."

variants of the environment.

Ball scale of 0.5 | Ball scale of 4
:-------------------------:|:-------------------------:
![](images/3dball_small.png) | ![](images/3dball_big.png)

_Variations of the 3D Ball environment._

To vary environments, we first decide what parameters to vary in an
environment. These parameters are known as `Reset Parameters`. In the 3D ball
environment example displayed in the figure above, the reset parameters are `gravity`, `ball_mass` and `ball_scale`.


## How-to

For generalization training, we need to provide a way to modify the environment
by supplying a set of reset parameters. This provision can be either
> **Review comment** (Contributor): To make this clearer, I'd break this into two paragraphs. First: "For generalization training, we need to provide a way to ... reset parameters, and vary them over time. The parameters could be chosen either deterministically or randomly." 2nd paragraph: "This is done by assigning each reset parameter a sampler, which (insert description of what a sampler does). ..."

deterministic or randomized. Each reset parameter is assigned a sampler. If a
sampler isn't provided for a reset parameter, that parameter keeps its default
value throughout training. The samplers for all
the reset parameters are handled by a **Sampler Manager**, which is also
responsible for generating a new set of values for the reset parameters when
needed.

To set up the Sampler Manager, we create a YAML file that specifies how we wish to
generate new samples. In this file, we specify the samplers and the
`resampling-interval` (the number of simulation steps after which reset parameters are
resampled). Below is an example of a sampler file for the 3D Ball environment.

```yaml
resampling-interval: 5000
# Review comment (Contributor): Which one is it? resampling-duration or episode-length?

mass:
    sampler-type: "uniform"
    # Review comment (Contributor): You provide multi-range, normal and uniform
    # sampling, you need to explain what they are and what their arguments
    # correspond to.
    min_value: 0.5
    max_value: 10

gravity:
    sampler-type: "multirange_uniform"
    intervals: [[7, 10], [15, 20]]

scale:
    sampler-type: "uniform"
    min_value: 0.75
    max_value: 3
```

* `resampling-interval` (int) - Specifies the number of steps the agent
  trains under a particular environment configuration before the
  environment is reset with a new sample of reset parameters.

* `parameter_name` - Name of the reset parameter. This should match the name
specified in the academy of the intended environment for which the agent is
being trained. If a parameter specified in the file doesn't exist in the
environment, then this specification will be ignored.

* `sampler-type` - Specify the sampler type to use for the reset parameter.
This is a string that should exist in the `Sampler Factory` (explained
below).

* `sub-arguments` - Specify the characteristic parameters for the sampler.
In the example sampler file above, this would correspond to the `intervals`
key under the `multirange_uniform` sampler for the gravity reset parameter.
The key name should match the name of the corresponding argument in the sampler definition (see "Defining a new sampler method" below).

The sampler manager allocates a sampler for a reset parameter by using the *Sampler Factory*, which maintains a dictionary mapping of string keys to sampler objects. Only the samplers registered in the Sampler Factory are available for reset parameter resampling.

The implementation of the samplers can be found at `ml-agents-envs/mlagents/envs/sampler_class.py`.
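To make the flow concrete, here is a small self-contained sketch of how a factory like this can turn a parsed YAML entry into a sampler. It mirrors the shape of the real implementation but is not the actual ml-agents code: it uses the stdlib `random` module instead of numpy, and the helper `init_sampler` is a hypothetical stand-in for `SamplerFactory.init_sampler_class`.

```python
import random


# Illustrative stand-in for the UniformSampler in sampler_class.py
# (stdlib random is used instead of numpy so the sketch is self-contained).
class UniformSampler:
    def __init__(self, min_value, max_value, **kwargs):
        self.min_value = min_value
        self.max_value = max_value

    def sample_parameter(self):
        return random.uniform(self.min_value, self.max_value)


# The factory is essentially a string-to-class mapping.
SAMPLER_FACTORY = {"uniform": UniformSampler}


# Hypothetical helper standing in for SamplerFactory.init_sampler_class:
# pop the sampler-type key, then pass the remaining keys as constructor kwargs.
def init_sampler(config):
    config = dict(config)
    name = config.pop("sampler-type")
    return SAMPLER_FACTORY[name](**config)


# Parsed from the "mass" entry of the sampler YAML above.
mass_sampler = init_sampler(
    {"sampler-type": "uniform", "min_value": 0.5, "max_value": 10}
)
print(0.5 <= mass_sampler.sample_parameter() <= 10)  # True
```

Because the remaining YAML keys are forwarded as keyword arguments, a sampler only works if its constructor parameter names match the keys in the config file.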

### Defining a new sampler method

Custom sampling techniques must inherit from the *Sampler* base class (included in the `sampler_class` file) and preserve the interface. Once the class for the required method is specified, it must be registered in the Sampler Factory.

This can be done by calling the *register_sampler* method of the `SamplerFactory`. The call is as follows:

`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_class*)`

Once the new sampler is registered with the Sampler Factory, it can be used for resampling reset parameters. For demonstration, let's say our sampler was implemented as below, and we register the `CustomSampler` class with the string `custom-sampler` in the Sampler Factory.

```python
import numpy as np


class CustomSampler(Sampler):
    def __init__(self, argA, argB, argC):
        self.possible_vals = [argA, argB, argC]

    def sample_parameter(self):
        return np.random.choice(self.possible_vals)
```

Now we need to specify this sampler in the sampler file. Let's say we wish to use this sampler for the reset parameter *mass*; the sampler file would then specify the following for mass (any order of the sub-arguments is valid).

```yaml
mass:
    sampler-type: "custom-sampler"
    argB: 1
    argA: 2
    argC: 3
```
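The essential mechanism here is that the sub-argument keys in the YAML are matched by name to the sampler constructor's parameters, which is why their order doesn't matter. A minimal standalone sketch of that mapping (using the stdlib `random` module in place of numpy, and `sample_parameter`, the method name the `Sampler` base class expects):

```python
import random


# Self-contained version of the CustomSampler example above
# (random.choice stands in for np.random.choice).
class CustomSampler:
    def __init__(self, argA, argB, argC):
        self.possible_vals = [argA, argB, argC]

    def sample_parameter(self):
        return random.choice(self.possible_vals)


# The YAML sub-arguments arrive as a dict and are matched to the
# constructor parameters by keyword, so their order doesn't matter.
yaml_sub_args = {"argB": 1, "argA": 2, "argC": 3}
sampler = CustomSampler(**yaml_sub_args)
print(sampler.possible_vals)  # [2, 1, 3]
print(sampler.sample_parameter() in {1, 2, 3})  # True
```

A misspelled sub-argument key would surface as a `TypeError` at construction time, which is what the factory's safety check reports as a `SamplerException`.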

With the sampler file setup, we can proceed to train our agent as explained in the next section.

### Training with Generalization Learning

We first set up the sampler file. Once the sampler file is defined and configured, we launch `mlagents-learn` and point it at the configured sampler file with the `--sampler` flag. To demonstrate, if we wanted to train a 3D Ball agent with generalization using the `config/generalize_test.yaml` sampling setup, we can run

```sh
mlagents-learn config/trainer_config.yaml --sampler=config/generalize_test.yaml --run-id=3D-Ball-generalization --train
# Review comment (Contributor): generalization-test.yaml =/= config/generalize_test.yaml
# Review comment (Contributor): In the sample code you use config/generalize_test.yaml
# but in the text you use config/generalization-test.yaml

```

We can observe progress and metrics via TensorBoard.
4 changes: 4 additions & 0 deletions docs/Training-ML-Agents.md
@@ -101,6 +101,9 @@ environment, you can set the following command line options when invoking
* `--curriculum=<file>` – Specify a curriculum JSON file for defining the
lessons for curriculum training. See [Curriculum
Training](Training-Curriculum-Learning.md) for more information.
* `--sampler=<file>` - Specify a sampler YAML file for defining the
sampler for generalization training. See [Generalization
Training](Training-Generalization-Learning.md) for more information.
* `--keep-checkpoints=<n>` – Specify the maximum number of model checkpoints to
keep. Checkpoints are saved after the number of steps specified by the
`save-freq` option. Once the maximum number of checkpoints has been reached,
@@ -197,6 +200,7 @@ are conducting, see:
* [Training with PPO](Training-PPO.md)
* [Using Recurrent Neural Networks](Feature-Memory.md)
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
* [Training with Generalization](Training-Generalization-Learning.md)
> **Review comment** (Contributor): Training with Generalization sounds weird to me. I would say Training with Environment Parameters Sampling

* [Training with Imitation Learning](Training-Imitation-Learning.md)

You can also compare the
Binary file added docs/images/3dball_big.png
Binary file added docs/images/3dball_small.png
8 changes: 8 additions & 0 deletions ml-agents-envs/mlagents/envs/exception.py
@@ -27,6 +27,14 @@ class UnityActionException(UnityException):
    pass


class SamplerException(UnityException):
    """
    Related to errors with the sampler actions.
    """

    pass


class UnityTimeOutException(UnityException):
    """
    Related to errors with communication timeouts.
134 changes: 134 additions & 0 deletions ml-agents-envs/mlagents/envs/sampler_class.py
@@ -0,0 +1,134 @@
import numpy as np
from typing import Any, Dict, List, Type, Union
from abc import ABC, abstractmethod

from .exception import SamplerException


class Sampler(ABC):
    @abstractmethod
    def sample_parameter(self) -> float:
        pass


class UniformSampler(Sampler):
    """
    Uniformly draws a single sample in the range [min_value, max_value).
    """

    # Review comment (Contributor): You need a doc string on the init methods
    # so we know what the arguments correspond to
    def __init__(
        self, min_value: Union[int, float], max_value: Union[int, float], **kwargs
    ) -> None:
        self.min_value = min_value
        self.max_value = max_value

    def sample_parameter(self) -> float:
        return np.random.uniform(self.min_value, self.max_value)


class MultiRangeUniformSampler(Sampler):
    """
    Draws a single sample uniformly from the intervals provided. The sampler
    first picks an interval based on a weighted selection, with the weights
    assigned to an interval based on its range. After picking the range,
    it proceeds to pick a value uniformly in that range.
    """

    def __init__(self, intervals: List[List[Union[int, float]]], **kwargs) -> None:
        self.intervals = intervals
        # Measure the length of the intervals
        interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
        cum_interval_length = sum(interval_lengths)
        # Assign weights to an interval proportionate to the interval size
        self.interval_weights = [x / cum_interval_length for x in interval_lengths]

    def sample_parameter(self) -> float:
        cur_min, cur_max = self.intervals[
            np.random.choice(len(self.intervals), p=self.interval_weights)
        ]
        return np.random.uniform(cur_min, cur_max)


class GaussianSampler(Sampler):
    """
    Draws a single sample value from a normal (gaussian) distribution.
    This sampler is characterized by the mean and the standard deviation.
    """

    def __init__(
        self, mean: Union[float, int], st_dev: Union[float, int], **kwargs
    ) -> None:
        self.mean = mean
        self.st_dev = st_dev

    def sample_parameter(self) -> float:
        return np.random.normal(self.mean, self.st_dev)


class SamplerFactory:
    """
    Maintains a directory of all samplers available.
    Add new samplers using the register_sampler method.
    """

    # Review comment (Contributor): Damn I like this.
    NAME_TO_CLASS = {
        "uniform": UniformSampler,
        "gaussian": GaussianSampler,
        "multirange_uniform": MultiRangeUniformSampler,
    }

    @staticmethod
    def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
        SamplerFactory.NAME_TO_CLASS[name] = sampler_cls

    @staticmethod
    def init_sampler_class(name: str, params: Dict[str, Any]) -> Sampler:
        if name not in SamplerFactory.NAME_TO_CLASS:
            raise SamplerException(
                name + " sampler is not registered in the SamplerFactory."
                " Use the register_sampler method to register the string"
                " associated to your sampler in the SamplerFactory."
            )
        sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
        try:
            return sampler_cls(**params)
        except TypeError:
            raise SamplerException(
                "The sampler class associated to the " + name + " key in the factory "
                "was not provided the required arguments. Please ensure that the sampler "
                "config file consists of the appropriate keys for this sampler class."
            )


class SamplerManager:
    def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
        self.reset_param_dict = reset_param_dict if reset_param_dict else {}
        assert isinstance(self.reset_param_dict, dict)
        self.samplers: Dict[str, Sampler] = {}
        for param_name, cur_param_dict in self.reset_param_dict.items():
            if "sampler-type" not in cur_param_dict:
                raise SamplerException(
                    "'sampler-type' argument hasn't been supplied for the {0} parameter".format(
                        param_name
                    )
                )
            sampler_name = cur_param_dict.pop("sampler-type")
            param_sampler = SamplerFactory.init_sampler_class(
                sampler_name, cur_param_dict
            )

            self.samplers[param_name] = param_sampler

    def is_empty(self) -> bool:
        """
        Check whether the sampler manager is empty.
        """
        return not bool(self.samplers)

    def sample_all(self) -> Dict[str, float]:
        res = {}
        for param_name, param_sampler in list(self.samplers.items()):
            res[param_name] = param_sampler.sample_parameter()
        return res
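To summarize how the pieces above fit together, here is a compact standalone sketch of the SamplerManager flow: each reset parameter gets a sampler built from its config dict, and `sample_all` draws one fresh value per parameter. This is an illustrative reimplementation using the stdlib `random` module, not the actual ml-agents code.

```python
import random


# Minimal stand-in sampler (random.uniform replaces np.random.uniform).
class UniformSampler:
    def __init__(self, min_value, max_value, **kwargs):
        self.min_value, self.max_value = min_value, max_value

    def sample_parameter(self):
        return random.uniform(self.min_value, self.max_value)


# Stand-in for SamplerFactory.NAME_TO_CLASS.
FACTORY = {"uniform": UniformSampler}


class SamplerManager:
    def __init__(self, reset_param_dict):
        self.samplers = {}
        for name, cfg in (reset_param_dict or {}).items():
            cfg = dict(cfg)
            # Pop the sampler-type key; the rest become constructor kwargs.
            sampler_type = cfg.pop("sampler-type")
            self.samplers[name] = FACTORY[sampler_type](**cfg)

    def is_empty(self):
        return not self.samplers

    def sample_all(self):
        # One freshly sampled value per reset parameter.
        return {name: s.sample_parameter() for name, s in self.samplers.items()}


manager = SamplerManager(
    {
        "mass": {"sampler-type": "uniform", "min_value": 0.5, "max_value": 10},
        "gravity": {"sampler-type": "uniform", "min_value": 7, "max_value": 12},
    }
)
print(sorted(manager.sample_all()))  # ['gravity', 'mass']
print(manager.is_empty())  # False
```

During training, `sample_all` is what produces the new set of reset parameter values each time the resampling interval elapses.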