Develop generalization training #2232
Changes from 68 commits

@@ -0,0 +1,16 @@

```yaml
resampling-interval: 5000

mass:
    sampler-type: "uniform"
    min_value: 0.5
    max_value: 10

gravity:
    sampler-type: "uniform"
    min_value: 7
    max_value: 12

scale:
    sampler-type: "uniform"
    min_value: 0.75
    max_value: 3
```

@@ -0,0 +1,123 @@

# Training Generalized Reinforcement Learning agents

Reinforcement learning has a rather unique setup as opposed to supervised and
unsupervised learning. Agents here are trained and tested on the same exact
environment, which is analogous to a model being trained and tested on an
identical dataset in supervised learning! This setting results in overfitting;
the inability of the agent to generalize to slight tweaks or variations in the
environment. This is problematic in instances when environments are randomly
instantiated with varying properties. To make agents more robust, we train an
agent over multiple variations of the environment. The agent is trained with
the intent that it learns to maintain a minimum performance regardless of the
environment variant and that it generalizes to maintain this in unseen future
variants of the environment.

> **Review comment:** Would rephrase: "To make agents more robust, one approach is to train over multiple variations of the environment."

> **Review comment:** Would rephrase: "... and that it generalizes to be robust to future unseen variants of the environment."

Ball scale of 0.5 | Ball scale of 4
:-------------------------:|:-------------------------:
![](images/3dball_small.png) | ![](images/3dball_big.png)

_Variations of the 3D Ball environment._

To vary environments, we first decide what parameters to vary in an
environment. These parameters are known as `Reset Parameters`. In the 3D ball
environment example displayed in the figure above, the reset parameters are
`gravity`, `ball_mass` and `ball_scale`.

## How-to

For generalization training, we need to provide a way to modify the environment
by supplying a set of reset parameters. This provision can be either
deterministic or randomized. Each reset parameter is assigned a sampler. If a
sampler isn't provided for a reset parameter, the parameter maintains the
default value throughout the training, remaining unchanged. The samplers for all
the reset parameters are handled by a **Sampler Manager**, which is also
responsible for generating a new set of values for the reset parameters when
needed.

> **Review comment:** To make this clearer, I'd break this into two paragraphs. First: "For generalization training, we need to provide a way to ... reset parameters, and vary them over time. The parameters could be chosen either deterministically or randomly." 2nd paragraph: "This is done by assigning each reset parameter a sampler, which (insert description of what a sampler does). ..."
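
As a rough sketch of what this looks like in code (using the `SamplerManager` and
`SamplerFactory` added in this PR under
`ml-agents-envs/mlagents/envs/sampler_class.py`, and assuming the module is
importable as `mlagents.envs.sampler_class`; the config values below are
illustrative only):

```python
from mlagents.envs.sampler_class import SamplerManager

# One sampler per reset parameter; parameters without an entry keep their
# default values because no sampler is created for them.
reset_param_config = {
    "mass": {"sampler-type": "uniform", "min_value": 0.5, "max_value": 10},
    "gravity": {"sampler-type": "multirange_uniform", "intervals": [[7, 10], [15, 20]]},
}

sampler_manager = SamplerManager(reset_param_config)

# The manager draws a fresh value for every managed reset parameter,
# e.g. {"mass": 3.2, "gravity": 16.7}.
new_reset_params = sampler_manager.sample_all()
```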

To set up the Sampler Manager, we create a YAML file that specifies how we wish to
generate new samples. In this file, we specify the samplers and the
`resampling-duration` (number of simulation steps after which reset parameters are
resampled). Below is an example of a sampler file for the 3D ball environment.

```yaml
episode-length: 5000

mass:
    sampler-type: "uniform"
    min_value: 0.5
    max_value: 10

gravity:
    sampler-type: "multirange_uniform"
    intervals: [[7, 10], [15, 20]]

scale:
    sampler-type: "uniform"
    min_value: 0.75
    max_value: 3
```

> **Review comment** (on `episode-length: 5000`): Which one is it? `resampling-duration` or `episode-length`?

> **Review comment** (on `sampler-type: "uniform"`): You provide multi-range, normal and uniform sampling, you need to explain what they are and what their arguments correspond to.

* `resampling-duration` (int) - Specifies the number of steps for the agent to
  train under a particular environment configuration before resetting the
  environment with a new sample of reset parameters.

* `parameter_name` - Name of the reset parameter. This should match the name
  specified in the academy of the intended environment for which the agent is
  being trained. If a parameter specified in the file doesn't exist in the
  environment, then this specification will be ignored.

* `sampler-type` - Specify the sampler type to use for the reset parameter.
  This is a string that should exist in the `Sampler Factory` (explained
  below).

* `sub-arguments` - Specify the characteristic parameters for the sampler.
  In the example sampler file above, this would correspond to the `intervals`
  key under the `multirange_uniform` sampler for the gravity reset parameter.
  The key name should match the name of the corresponding argument in the
  sampler definition. (Look at defining a new sampler method)

The sampler manager allocates a sampler for a reset parameter by using the *Sampler Factory*, which maintains a dictionary mapping of string keys to sampler objects. The samplers available for reset parameter resampling are those registered in the Sampler Factory.

The implementation of the samplers can be found at `ml-agents-envs/mlagents/envs/sampler_class.py`.
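
A rough sketch of how the `sampler-type` string and its sub-arguments map onto
these classes (based on the constructors in `sampler_class.py`; not part of the
documentation file itself):

```python
from mlagents.envs.sampler_class import SamplerFactory

# The sub-arguments in the YAML file become keyword arguments of the
# corresponding sampler constructor:
#   "uniform"            -> UniformSampler(min_value, max_value)
#   "gaussian"           -> GaussianSampler(mean, st_dev)
#   "multirange_uniform" -> MultiRangeUniformSampler(intervals)

# The Sampler Manager pops the "sampler-type" key and hands the remaining
# keys to the factory, which instantiates the matching class:
gravity_sampler = SamplerFactory.init_sampler_class(
    "multirange_uniform", {"intervals": [[7, 10], [15, 20]]}
)
print(gravity_sampler.sample_parameter())  # a float drawn from [7, 10) or [15, 20)
```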

### Defining a new sampler method

Custom sampling techniques must inherit from the *Sampler* base class (included in the `sampler_class` file) and preserve the interface. Once the class for the required method is specified, it must be registered in the Sampler Factory.

This can be done by calling the *register_sampler* method of the SamplerFactory. The command is as follows:

`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_object*)`

Once the Sampler Factory reflects the new register, the custom sampler can be used for resampling reset parameters. For demonstration, let's say our sampler was implemented as below, and we register the `CustomSampler` class with the string `custom-sampler` in the Sampler Factory.

```python
import numpy as np

from mlagents.envs.sampler_class import Sampler


class CustomSampler(Sampler):

    def __init__(self, argA, argB, argC):
        self.possible_vals = [argA, argB, argC]

    def sample_parameter(self):
        return np.random.choice(self.possible_vals)
```
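
Registering and exercising the custom sampler might then look like the
following (a minimal sketch; the argument values are arbitrary):

```python
from mlagents.envs.sampler_class import SamplerFactory

# Make the string key "custom-sampler" resolve to our class.
SamplerFactory.register_sampler("custom-sampler", CustomSampler)

# The factory can now build it from config-style keyword arguments,
# exactly as it would when parsing the sampler YAML file.
sampler = SamplerFactory.init_sampler_class(
    "custom-sampler", {"argA": 2, "argB": 1, "argC": 3}
)
print(sampler.sample_parameter())  # one of 2, 1 or 3
```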

Now we need to specify this sampler in the sampler file. Let's say we wish to use this sampler for the reset parameter *mass*; the sampler file would then specify the following for mass (any order of the sub-arguments is valid):

```yaml
mass:
    sampler-type: "custom-sampler"
    argB: 1
    argA: 2
    argC: 3
```

With the sampler file set up, we can proceed to train our agent as explained in the next section.

### Training with Generalization Learning

We first begin with setting up the sampler file. After the sampler file is defined and configured, we proceed by launching `mlagents-learn` and specify our configured sampler file with the `--sampler` flag. To demonstrate, if we wanted to train a 3D ball agent with generalization using the `config/generalization-test.yaml` sampling setup, we can run

```sh
mlagents-learn config/trainer_config.yaml --sampler=config/generalize_test.yaml --run-id=3D-Ball-generalization --train
```

> **Review comment:** In the sample code you use `config/generalize_test.yaml` but in the text you use `config/generalization-test.yaml`.

We can observe progress and metrics via Tensorboard.

@@ -101,6 +101,9 @@ environment, you can set the following command line options when invoking

* `--curriculum=<file>` – Specify a curriculum JSON file for defining the
  lessons for curriculum training. See [Curriculum
  Training](Training-Curriculum-Learning.md) for more information.
* `--sampler=<file>` - Specify a sampler YAML file for defining the
  sampler for generalization training. See [Generalization
  Training](Training-Generalization-Learning.md) for more information.
* `--keep-checkpoints=<n>` – Specify the maximum number of model checkpoints to
  keep. Checkpoints are saved after the number of steps specified by the
  `save-freq` option. Once the maximum number of checkpoints has been reached,

@@ -197,6 +200,7 @@ are conducting, see:

* [Training with PPO](Training-PPO.md)
* [Using Recurrent Neural Networks](Feature-Memory.md)
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
* [Training with Generalization](Training-Generalization-Learning.md)
* [Training with Imitation Learning](Training-Imitation-Learning.md)

> **Review comment** (on the "Training with Generalization" entry): Training with Generalization sounds weird to me. I would say Training with Environment Parameters Sampling.

You can also compare the

@@ -0,0 +1,134 @@

```python
import numpy as np
from typing import *
from functools import *
from collections import OrderedDict
from abc import ABC, abstractmethod

from .exception import SamplerException


class Sampler(ABC):
    @abstractmethod
    def sample_parameter(self) -> float:
        pass


class UniformSampler(Sampler):
    """
    Uniformly draws a single sample in the range [min_value, max_value).
    """

    def __init__(
        self, min_value: Union[int, float], max_value: Union[int, float], **kwargs
    ) -> None:
        self.min_value = min_value
        self.max_value = max_value

    def sample_parameter(self) -> float:
        return np.random.uniform(self.min_value, self.max_value)
```

> **Review comment** (on `UniformSampler.__init__`): You need a doc string on the init methods so we know what the arguments correspond to.

```python
class MultiRangeUniformSampler(Sampler):
    """
    Draws a single sample uniformly from the intervals provided. The sampler
    first picks an interval based on a weighted selection, with the weights
    assigned to an interval based on its range. After picking the range,
    it proceeds to pick a value uniformly in that range.
    """

    def __init__(self, intervals: List[List[Union[int, float]]], **kwargs) -> None:
        self.intervals = intervals
        # Measure the length of the intervals
        interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
        cum_interval_length = sum(interval_lengths)
        # Assign weights to an interval proportionate to the interval size
        self.interval_weights = [x / cum_interval_length for x in interval_lengths]

    def sample_parameter(self) -> float:
        cur_min, cur_max = self.intervals[
            np.random.choice(len(self.intervals), p=self.interval_weights)
        ]
        return np.random.uniform(cur_min, cur_max)


class GaussianSampler(Sampler):
    """
    Draw a single sample value from a normal (gaussian) distribution.
    This sampler is characterized by the mean and the standard deviation.
    """

    def __init__(
        self, mean: Union[float, int], st_dev: Union[float, int], **kwargs
    ) -> None:
        self.mean = mean
        self.st_dev = st_dev

    def sample_parameter(self) -> float:
        return np.random.normal(self.mean, self.st_dev)


class SamplerFactory:
    """
    Maintain a directory of all samplers available.
    Add new samplers using the register_sampler method.
    """

    NAME_TO_CLASS = {
        "uniform": UniformSampler,
        "gaussian": GaussianSampler,
        "multirange_uniform": MultiRangeUniformSampler,
    }

    @staticmethod
    def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
        SamplerFactory.NAME_TO_CLASS[name] = sampler_cls

    @staticmethod
    def init_sampler_class(name: str, params: Dict[str, Any]):
        if name not in SamplerFactory.NAME_TO_CLASS:
            raise SamplerException(
                name + " sampler is not registered in the SamplerFactory."
                " Use the register_sampler method to register the string"
                " associated to your sampler in the SamplerFactory."
            )
        sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
        try:
            return sampler_cls(**params)
        except TypeError:
            raise SamplerException(
                "The sampler class associated to the " + name + " key in the factory "
                "was not provided the required arguments. Please ensure that the sampler "
                "config file consists of the appropriate keys for this sampler class."
            )
```

> **Review comment** (on the `NAME_TO_CLASS` mapping): Damn I like this.

```python
class SamplerManager:
    def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
        self.reset_param_dict = reset_param_dict if reset_param_dict else {}
        assert isinstance(self.reset_param_dict, dict)
        self.samplers: Dict[str, Sampler] = {}
        for param_name, cur_param_dict in self.reset_param_dict.items():
            if "sampler-type" not in cur_param_dict:
                raise SamplerException(
                    "'sampler-type' argument hasn't been supplied for the {0} parameter".format(
                        param_name
                    )
                )
            sampler_name = cur_param_dict.pop("sampler-type")
            param_sampler = SamplerFactory.init_sampler_class(
                sampler_name, cur_param_dict
            )

            self.samplers[param_name] = param_sampler

    def is_empty(self) -> bool:
        """
        Check whether the sampler_manager is empty.
        """
        return not bool(self.samplers)

    def sample_all(self) -> Dict[str, float]:
        res = {}
        for param_name, param_sampler in list(self.samplers.items()):
            res[param_name] = param_sampler.sample_parameter()
        return res
```
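
To illustrate how a trainer might consume this class, a sketch only:
`env_step` and `apply_reset_parameters` below are hypothetical stand-ins for the
trainer's environment-stepping and reset logic, and `resampling_interval`
mirrors the `resampling-interval` key from the sampler file.

```python
def training_loop(env_step, apply_reset_parameters, sampler_manager,
                  resampling_interval, total_steps):
    """Resample the managed reset parameters every `resampling_interval` steps."""
    for step in range(total_steps):
        if step % resampling_interval == 0 and not sampler_manager.is_empty():
            # Draw a fresh value for every reset parameter that has a sampler.
            apply_reset_parameters(sampler_manager.sample_all())
        env_step()
```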

> **Review comment:** nit: capitalize agents