-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Develop generalization training #2232
Changes from 49 commits
c650e8a
f78d974
643e816
d90031b
a4cf9ea
e0cd1df
33fe8ae
cf0eae1
155cd4e
65884b3
7b21f29
7a1944d
6b3d1c0
fee2b14
626b0b7
a194379
381ff86
e8c1e80
6b82251
436987e
abd4c03
bfb7bd0
390d98e
e5423bc
7ecc6d8
0f302d7
2f1f814
23fbff8
b9a7f2a
1ce970c
56810c0
ea5f4db
68bdb59
4d7c79f
9d1acc5
5063818
11e2c78
165ab57
d5b0838
c3b372c
bb84831
c441877
8f1c7a8
c8493c0
ecdc734
d4c9163
337ef3e
e285039
ddc2e60
e8ad471
e6a1029
30e0727
603c5c8
fb74bbe
3a5f033
ce152ee
b5b97bd
f545ed1
0e703e2
408b32b
03de771
c47ccfd
44dffd7
79f2079
092e370
bb8654f
188dec7
3b63502
d93f79e
f79263b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
episode-length: 5000 | ||
|
||
mass: | ||
sampler-type: "uniform" | ||
min_value: 0.5 | ||
max_value: 10 | ||
|
||
gravity: | ||
sampler-type: "uniform" | ||
min_value: 7 | ||
max_value: 12 | ||
|
||
scale: | ||
sampler-type: "uniform" | ||
min_value: 0.75 | ||
max_value: 3 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Training Generalized Reinforcement Learning agents | ||
|
||
Reinforcement learning has a rather unique setup as opposed to supervised and | ||
unsupervised learning. Agents here are trained and tested on the same exact environment, | ||
which is analogous to a model being trained and tested on an identical dataset in supervised learning! This setting results in overfitting; the inability of the agent to generalize to slight tweaks or variations in the environment. This is problematic in instances when environments are randomly instantiated with varying properties. To make agents more robust, we train an agent over multiple variations of the environment. The agent is trained with the intent that it learns to maintain a minimum performance regardless of the environment variant and that it generalizes to maintain this in unseen future variants of the environment. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line 3 and 4 are cut but line 5 is not. You need to make sure that the number of characters per line is consistent throughout the document. |
||
|
||
Ball scale of 0.5 | Ball scale of 4 | ||
:-------------------------:|:-------------------------: | ||
![](images/3dball_small.png) | ![](images/3dball_big.png) | ||
|
||
_Variations of the 3D Ball environment._ | ||
|
||
To vary the environments, we first decided what the reset parameters for the environment. These parameters are known as `Reset Parameters`, introduced in [Curriculum Learning](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Curriculum-Learning.md). In the 3D ball environment example displayed in the figure above, the reset parameters are `gravity`, `ball_mass` and `ball_scale`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You say "we first decided what the reset parameters are" and then introduce the reset Parameters in the following sentence. This is confusing. Also, reset parameters were not introduced in Curriculum learning, so I would say Curriculum learning also uses reset parameters instead of introduced them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
## In-practice | ||
|
||
To test the effectiveness of this training procedure, we train 3 models over 50000 steps: | ||
1. Model Trained on Default reset parameter values (Default model). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This numbering does not look good in the markdown. I think you need to skip a line somewhere for it to be a numbered list. |
||
2. Model Trained on a range of reset parameter values (Random model). Reset parameter values are picked uniformly over the range and the model is trained on each configuration for 5000 steps before they are randomly sampled again. The range consists of the default values. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
3. Model Trained on a range of reset parameters values (Extreme model). The range here are the extreme ends, which lie outside the range used to train the Random model. Reset parameter values are selected randomly and the model is trained for 5000 steps on each configuration of values before resampling the values again. | ||
|
||
Then, these 3 models are tested for their performance in all the reset parameter settings for 50000 steps each. The figure below outlines the performance of each model. | ||
|
||
![3DBall environment mean rewards](images/3DBall_generalization_rewards.png) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The results are bad. I would not put them in the documentation for the release. |
||
|
||
In a relatively simple environment we are able to see differences in the performance between differently trained models, with the models exposed to various environments performing better in variations of environments than the default model trained over a single environment. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not obvious in the graph provided. |
||
|
||
## How-to | ||
|
||
For generalization training, we need to provide a way to modify the environment by supplying a set of reset parameters. This provision can be either deterministic or randomized. Each reset parameter is assigned a sampler. If a sampler isn't provided for a reset parameter, the parameter maintains the default value throughout the training, remaining unchanged. The samplers for all the reset parameters are handled by a **Sampler Manager**, which is also responsible for generating a new set of values for the reset parameters when needed. | ||
|
||
To setup the Sampler Manager, we setup a YAML file that specifies how we wish to generate new samples. In this file, we specify the samplers and the `resampling-duration` (number of training steps after which reset parameters are resampled). Below is an example of a sampler file for the 3D ball environment. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
```yaml | ||
episode-length: 5000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which on is it ? resampling-duration or episode length ? |
||
|
||
mass: | ||
sampler-type: "uniform" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You provide multi-range, normal and uniform sampling, you need to explain what they are and what their arguments correspond to. |
||
min_value: 0.5 | ||
max_value: 10 | ||
|
||
gravity: | ||
sampler-type: "multirange_uniform" | ||
intervals: [[7, 10], [15, 20]] | ||
|
||
scale: | ||
sampler-type: "uniform" | ||
min_value: 0.75 | ||
max_value: 3 | ||
|
||
``` | ||
|
||
* `resampling-duration` (int) - Specifies the number of steps for agent to train under a particular environment configuration before resetting the environment with a new sample of reset parameters. | ||
|
||
* `parameter_name` - Name of the reset parameter. This should match the name specified in the academy of the intended environment for which the agent is being trained. If a parameter specified in the file doesn't exist in the environment, then this specification will be ignored. | ||
|
||
* `sampler-type` - Specify the sampler type to use for the reset parameter. This is a string that should exist in the `Sampler Factory` (explained below). | ||
|
||
* `sub-arguments` - Specify the characteristic parameters for the sampler. In the example sampler file above, this would correspond to the `intervals` key under the `multirange_uniform` sampler for the gravity reset parameter. The key name should match the name of the corresponding argument in the sampler definition. (Look at defining a new sampler method) | ||
|
||
The sampler manager allocates a sampler for a reset parameter by using the *Sampler Factory*, which maintains a dictionary mapping of string keys to sampler objects. The available samplers to be used for reset parameter resampling is as available in the Sampler Factory. | ||
|
||
The implementation of the samplers can be found at `ml-agents-envs/mlagents/envs/sampler_class.py`. | ||
|
||
### Defining a new sampler method | ||
|
||
Custom sampling techniques must inherit from the *Sampler* base class (included in the `sampler_class` file) and preserve the interface. Once the class for the required method is specified, it must be registered in the Sampler Factory. | ||
|
||
This can be done by subscribing to the *register_sampler* method of the SamplerFactory. The command is as follows: | ||
|
||
`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_object*)` | ||
|
||
Once the Sampler Factory reflects the new register, the custom sampler can be used for resampling reset parameter. For demonstration, lets say our sampler was implemented as below, and we register the `CustomSampler` class with the string `custom-sampler` in the Sampler Factory. | ||
|
||
```python | ||
class CustomSampler(Sampler): | ||
|
||
def __init__(self, argA, argB, argC): | ||
self.possible_vals = [argA, argB, argC] | ||
|
||
def sample_all(self): | ||
return np.random.choice(self.possible_vals) | ||
``` | ||
|
||
Now we need to specify this sampler in the sampler file. Lets say we wish to use this sampler for the reset parameter *mass*; the sampler file would specify the same for mass as the following. | ||
|
||
```yaml | ||
mass: | ||
sampler-type: "custom-sampler" | ||
argA: 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the ordering of the args relevant or only their names ? This should be specified. |
||
argB: 2 | ||
argC: 3 | ||
``` | ||
|
||
With the sampler file setup, we can proceed to train our agent as explained in the next section. | ||
|
||
### Training with Generalization Learning | ||
|
||
We first begin with setting up the sampler file. After the sampler file is defined and configured, we proceed by launching `mlagents-learn` and specify our configured sampler file with the `--sampler` flag. To demonstrate, if we wanted to train a 3D ball agent with generalization using the `generalization-test.yaml` sampling setup, we can run | ||
|
||
```sh | ||
mlagents-learn config/trainer_config.yaml --sampler=config/generalize_test.yaml --run-id=3D-Ball-generalization --train | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the sample code you use config/generalize_test.yaml but in the text you use config/generalization-test.yaml |
||
``` | ||
|
||
We can observe progress and metrics via Tensorboard. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,6 +101,9 @@ environment, you can set the following command line options when invoking | |
* `--curriculum=<file>` – Specify a curriculum JSON file for defining the | ||
lessons for curriculum training. See [Curriculum | ||
Training](Training-Curriculum-Learning.md) for more information. | ||
* `--sampler=<file>` - Specify a sampler YAML file for defining the | ||
sampler for generalization training. See [Generalization | ||
Training](Training-Generalization-Learning.md) for more information. | ||
* `--keep-checkpoints=<n>` – Specify the maximum number of model checkpoints to | ||
keep. Checkpoints are saved after the number of steps specified by the | ||
`save-freq` option. Once the maximum number of checkpoints has been reached, | ||
|
@@ -197,6 +200,7 @@ are conducting, see: | |
* [Training with PPO](Training-PPO.md) | ||
* [Using Recurrent Neural Networks](Feature-Memory.md) | ||
* [Training with Curriculum Learning](Training-Curriculum-Learning.md) | ||
* [Training with Generalization](Training-Generalization-Learning.md) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Training with Generalization sounds weird to me. I would say Training with Environment Parameters Sampling |
||
* [Training with Imitation Learning](Training-Imitation-Learning.md) | ||
|
||
You can also compare the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import numpy as np | ||
from typing import * | ||
from functools import * | ||
from collections import OrderedDict | ||
from abc import ABC, abstractmethod | ||
|
||
from .exception import SamplerException | ||
|
||
|
||
class Sampler(ABC): | ||
@abstractmethod | ||
def sample_parameter(self) -> float: | ||
pass | ||
|
||
|
||
class UniformSampler(Sampler): | ||
# kwargs acts as a sink for extra unneeded args | ||
def __init__(self, min_value, max_value, **kwargs): | ||
self.min_value = min_value | ||
self.max_value = max_value | ||
|
||
def sample_parameter(self) -> float: | ||
return np.random.uniform(self.min_value, self.max_value) | ||
|
||
|
||
class MultiRangeUniformSampler(Sampler): | ||
def __init__(self, intervals, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should add more comments one the classes. You need to define what the sampler does (Uniform vs normal for instance) as well as the parameters type and function. For example : Intervals is a list of floats, this needs to be specified somewhere. |
||
self.intervals = intervals | ||
# Measure the length of the intervals | ||
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals] | ||
# Cumulative size of the intervals | ||
cum_interval_length = sum(interval_lengths) | ||
# Assign weights to an interval proportionate to the interval size | ||
self.interval_weights = [x / cum_interval_length for x in interval_lengths] | ||
|
||
def sample_parameter(self) -> float: | ||
cur_min, cur_max = self.intervals[ | ||
np.random.choice(len(self.intervals), p=self.interval_weights) | ||
] | ||
return np.random.uniform(cur_min, cur_max) | ||
|
||
|
||
class GaussianSampler(Sampler): | ||
def __init__(self, mean, var, **kwargs): | ||
self.mean = mean | ||
self.var = var | ||
|
||
def sample_parameter(self) -> float: | ||
return np.random.normal(self.mean, self.var) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. variance =/= standard deviation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch, let me fix the naming |
||
|
||
|
||
# To introduce new sampling methods, just need to 'register' them to this sampler factory | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment should be formatted consistently with the other comments of the repo. |
||
class SamplerFactory: | ||
NAME_TO_CLASS = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Damn I like this. |
||
"uniform": UniformSampler, | ||
"gaussian": GaussianSampler, | ||
"multirange_uniform": MultiRangeUniformSampler, | ||
} | ||
|
||
@staticmethod | ||
def register_sampler(name, sampler_cls): | ||
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls | ||
|
||
@staticmethod | ||
def init_sampler_class(name, param_dict): | ||
if name not in SamplerFactory.NAME_TO_CLASS: | ||
raise SamplerException( | ||
name + " sampler is not registered in the SamplerFactory." | ||
" Use the register_sample method to register the string" | ||
" associated to your sampler in the SamplerFactory." | ||
) | ||
sampler_cls = SamplerFactory.NAME_TO_CLASS[name] | ||
try: | ||
return sampler_cls(**param_dict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it a parameter dict or a list of parameters ? I think this variable should be named params (I think it is usually called this way) |
||
except: | ||
raise SamplerException( | ||
"The sampler class associated to the " + name + " key in the factory " | ||
"was not provided the required arguments. Please ensure that the sampler " | ||
"config file consists of the appropriate keys for this sampler class." | ||
) | ||
|
||
|
||
class SamplerManager: | ||
def __init__(self, reset_param_dict): | ||
self.reset_param_dict = reset_param_dict if reset_param_dict else {} | ||
assert isinstance(self.reset_param_dict, dict) | ||
self.samplers = OrderedDict() | ||
for param_name, cur_param_dict in self.reset_param_dict.items(): | ||
if "sampler-type" not in cur_param_dict: | ||
raise SamplerException( | ||
"'sampler_type' argument hasn't been supplied for the {0} parameter".format( | ||
param_name | ||
) | ||
) | ||
sampler_name = cur_param_dict.pop("sampler-type") | ||
param_sampler = SamplerFactory.init_sampler_class( | ||
sampler_name, cur_param_dict | ||
) | ||
|
||
self.samplers[param_name] = param_sampler | ||
|
||
def is_empty(self) -> bool: | ||
""" | ||
If self.samplers is empty, then bool of it returns false, indicating that the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sentence is confusing. Shouldn't is_empty be true when the samplers is empty ? |
||
sampler manager isn't managing any samplers. | ||
""" | ||
return not bool(self.samplers) | ||
|
||
def sample_all(self) -> Dict[str, float]: | ||
res = {} | ||
for param_name, param_sampler in list(self.samplers.items()): | ||
res[param_name] = param_sampler.sample_parameter() | ||
return res |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from math import isclose | ||
import pytest | ||
|
||
from mlagents.envs.sampler_class import SamplerManager | ||
from mlagents.envs.sampler_class import ( | ||
UniformSampler, | ||
MultiRangeUniformSampler, | ||
GaussianSampler, | ||
) | ||
from mlagents.envs.exception import UnityException | ||
|
||
|
||
def basic_3Dball_sampler(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason these need to be functions? Why not just put the dict literal in the test function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just put them as function for cleaner code along with consistency with other test files. Nothing special otherwise. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would rather this function not have the name 3Dball in it. Call it sampler1 or something. The test should in my opinion be agnostic of the names of the demo environments |
||
return { | ||
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10}, | ||
"gravity": { | ||
"sampler-type": "multirange_uniform", | ||
"intervals": [[8, 11], [15, 20]], | ||
}, | ||
} | ||
|
||
|
||
def check_value_in_intervals(val, intervals): | ||
check_in_bounds = [a <= val <= b for a, b in intervals] | ||
return any(check_in_bounds) | ||
|
||
|
||
def test_3Dball_sampler(): | ||
config = basic_3Dball_sampler() | ||
sampler = SamplerManager(config) | ||
|
||
assert sampler.is_empty() is False | ||
assert isinstance(sampler.samplers["mass"], UniformSampler) | ||
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler) | ||
|
||
cur_sample = sampler.sample_all() | ||
|
||
# Check uniform sampler for mass | ||
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"] | ||
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"] | ||
assert config["mass"]["min_value"] <= cur_sample["mass"] | ||
assert config["mass"]["max_value"] >= cur_sample["mass"] | ||
|
||
# Check multirange_uniform sampler for gravity | ||
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"] | ||
assert check_value_in_intervals( | ||
cur_sample["gravity"], sampler.samplers["gravity"].intervals | ||
) | ||
|
||
|
||
def basic_tennis_sampler(): | ||
return {"angle": {"sampler-type": "gaussian", "mean": 0, "var": 1}} | ||
|
||
|
||
def test_tennis_sampler(): | ||
config = basic_tennis_sampler() | ||
sampler = SamplerManager(config) | ||
assert sampler.is_empty() is False | ||
assert isinstance(sampler.samplers["angle"], GaussianSampler) | ||
|
||
cur_sample = sampler.sample_all() | ||
|
||
# Check angle gaussian sampler | ||
assert sampler.samplers["angle"].mean == config["angle"]["mean"] | ||
assert sampler.samplers["angle"].var == config["angle"]["var"] | ||
|
||
|
||
def test_empty_samplers(): | ||
empty_sampler = SamplerManager({}) | ||
assert empty_sampler.is_empty() | ||
empty_cur_sample = empty_sampler.sample_all() | ||
assert empty_cur_sample == {} | ||
|
||
none_sampler = SamplerManager(None) | ||
assert none_sampler.is_empty() | ||
none_cur_sample = none_sampler.sample_all() | ||
assert none_cur_sample == {} | ||
|
||
|
||
def incorrect_uniform_sampler(): | ||
# Do not specify required arguments to uniform sampler | ||
return {"mass": {"sampler-type": "uniform", "min-value": 10}} | ||
|
||
|
||
def incorrect_sampler_config(): | ||
# Do not specify 'sampler-type' key | ||
return {"mass": {"min-value": 2, "max-value": 30}} | ||
|
||
|
||
def test_incorrect_uniform_sampler(): | ||
config = incorrect_uniform_sampler() | ||
with pytest.raises(UnityException): | ||
SamplerManager(config) | ||
|
||
|
||
def test_incorrect_sampler(): | ||
config = incorrect_sampler_config() | ||
with pytest.raises(UnityException): | ||
SamplerManager(config) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,3 +23,5 @@ class MetaCurriculumError(TrainerError): | |
""" | ||
Any error related to the configuration of a metacurriculum. | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extra line ? |
||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: capitalize agents