Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop generalization training #2232

Merged
merged 70 commits into from
Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
c650e8a
Adding new command line arguments
sankalp04 Jun 21, 2019
f78d974
Allow generalization training with specified arguments of min_reward …
sankalp04 Jun 28, 2019
643e816
Added sampler_class in mlagents/envs
sankalp04 Jun 28, 2019
d90031b
Change reset parameters based on reward or progress metric
sankalp04 Jul 2, 2019
a4cf9ea
Include error checking in sampling methods
sankalp04 Jul 2, 2019
e0cd1df
Incorporate generalization checks for resetting parameters in take_step
sankalp04 Jul 2, 2019
33fe8ae
Add Sampler error to track errors in the sampler class
sankalp04 Jul 2, 2019
cf0eae1
Add LessonControllerError to track errors in LessonController
sankalp04 Jul 2, 2019
155cd4e
Get rid of dead code and clean up code
sankalp04 Jul 2, 2019
65884b3
Removed check_key and replaced with **param_dict for implicit type ch…
sankalp04 Jul 3, 2019
7b21f29
Instantiate SamplerManager in learn.py instead of trainer_controller
sankalp04 Jul 3, 2019
7a1944d
Made the code cleanup changes - mostly nit changes
sankalp04 Jul 3, 2019
6b3d1c0
Cleanup PPOTrainer call and env_reset
sankalp04 Jul 3, 2019
fee2b14
Example parameter sampling file config
sankalp04 Jul 8, 2019
626b0b7
Updated example sampler config file
sankalp04 Jul 8, 2019
a194379
Remove LessonControllerException
sankalp04 Jul 8, 2019
381ff86
Conduct training without lesson_controller
sankalp04 Jul 8, 2019
e8c1e80
Remove dead code
sankalp04 Jul 8, 2019
6b82251
Suggested stylistic and nit changes made
sankalp04 Jul 8, 2019
436987e
Use seed passed in learn.py instead of in sampler config
sankalp04 Jul 10, 2019
abd4c03
Init commit of sampler tests
sankalp04 Jul 10, 2019
bfb7bd0
Init commit to sampler_class tests
sankalp04 Jul 10, 2019
390d98e
Incorrect file in test folder
sankalp04 Jul 10, 2019
e5423bc
Add tests
sankalp04 Jul 11, 2019
7ecc6d8
Add safety check in case samplers provided with incorrect or insuffic…
sankalp04 Jul 17, 2019
0f302d7
Updated test files with tests for sampler class
sankalp04 Jul 17, 2019
2f1f814
Fixed syntax errors and removed dead code
sankalp04 Jul 17, 2019
23fbff8
Changed exception types in test and imported sampler types for checks
sankalp04 Jul 17, 2019
b9a7f2a
Modify test_learn to accomodate sampler argument
sankalp04 Jul 17, 2019
1ce970c
Fixed tests
sankalp04 Jul 17, 2019
56810c0
Fixed test case of different instantiations of sampler_manager in test
sankalp04 Jul 17, 2019
ea5f4db
Attempt at mocking the sampler_class
sankalp04 Jul 17, 2019
68bdb59
Removed unused function trailing from use of lesson_controller
sankalp04 Jul 17, 2019
4d7c79f
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 18, 2019
9d1acc5
Resolved conflicts, finished type annotations, nit changes and accomo…
sankalp04 Jul 19, 2019
5063818
Trainer controller argument name changed so tests are changed accordi…
sankalp04 Jul 19, 2019
11e2c78
Modified tests
sankalp04 Jul 19, 2019
165ab57
Remove extra arg in test case
sankalp04 Jul 19, 2019
d5b0838
Fixed pytest failure in checks
sankalp04 Jul 19, 2019
c3b372c
Fixed style from black
sankalp04 Jul 19, 2019
bb84831
Fixed pending style issues
sankalp04 Jul 19, 2019
c441877
Fixed spacing issues
sankalp04 Jul 19, 2019
8f1c7a8
Fixed missing type annotation
sankalp04 Jul 19, 2019
c8493c0
Check steps for none for style check
sankalp04 Jul 19, 2019
ecdc734
Revert "Check steps for none for style check"
sankalp04 Jul 19, 2019
d4c9163
Check steps for none for type checking
sankalp04 Jul 19, 2019
337ef3e
Check lesson duration for None for style check
sankalp04 Jul 19, 2019
e285039
Clean test code to instantiate tests within function
sankalp04 Jul 19, 2019
ddc2e60
Init commit to generalization docs
sankalp04 Jul 23, 2019
e8ad471
Change key to resampling-interval
sankalp04 Jul 23, 2019
e6a1029
Remove 3D ball curves and fix formatting
sankalp04 Jul 23, 2019
30e0727
Add comments and finish annotations
sankalp04 Jul 23, 2019
603c5c8
Remove attribution between tests and environments
sankalp04 Jul 23, 2019
fb74bbe
Rename lesson-length to resampling_interval
sankalp04 Jul 23, 2019
3a5f033
Modify test cases to reflect variable name changes
sankalp04 Jul 23, 2019
ce152ee
Fix bug
sankalp04 Jul 23, 2019
b5b97bd
Remove extra file imports in test_learn
sankalp04 Jul 23, 2019
f545ed1
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 23, 2019
0e703e2
Modify test to include sampler_manager
sankalp04 Jul 23, 2019
408b32b
Fix style changes
sankalp04 Jul 23, 2019
03de771
Fixed formatting and style part II
sankalp04 Jul 23, 2019
c47ccfd
Fixed more style
sankalp04 Jul 23, 2019
44dffd7
Including missed file in style fixes
sankalp04 Jul 23, 2019
79f2079
Finish annotations
sankalp04 Jul 23, 2019
092e370
Address comments for docs
sankalp04 Jul 23, 2019
bb8654f
Fix black's bare except error
sankalp04 Jul 23, 2019
188dec7
Fixed type annotations
sankalp04 Jul 24, 2019
3b63502
Fixed final type annotation error
sankalp04 Jul 24, 2019
d93f79e
Fix training docs
sankalp04 Jul 24, 2019
f79263b
Merge branch 'develop' into develop-generalizationTraining
sankalp04 Jul 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions config/generalize_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
episode-length: 5000

mass:
sampler-type: "uniform"
min_value: 0.5
max_value: 10

gravity:
sampler-type: "uniform"
min_value: 7
max_value: 12

scale:
sampler-type: "uniform"
min_value: 0.75
max_value: 3
105 changes: 105 additions & 0 deletions docs/Training-Generalization-Learning.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Training Generalized Reinforcement Learning agents
Copy link
Contributor

@ervteng ervteng Jul 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: capitalize agents


Reinforcement learning has a rather unique setup as opposed to supervised and
unsupervised learning. Agents here are trained and tested on the same exact environment,
which is analogous to a model being trained and tested on an identical dataset in supervised learning! This setting results in overfitting; the inability of the agent to generalize to slight tweaks or variations in the environment. This is problematic in instances when environments are randomly instantiated with varying properties. To make agents more robust, we train an agent over multiple variations of the environment. The agent is trained with the intent that it learns to maintain a minimum performance regardless of the environment variant and that it generalizes to maintain this in unseen future variants of the environment.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 3 and 4 are cut but line 5 is not. You need to make sure that the number of characters per line is consistent throughout the document.


Ball scale of 0.5 | Ball scale of 4
:-------------------------:|:-------------------------:
![](images/3dball_small.png) | ![](images/3dball_big.png)

_Variations of the 3D Ball environment._

To vary the environments, we first decided what the reset parameters for the environment. These parameters are known as `Reset Parameters`, introduced in [Curriculum Learning](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Curriculum-Learning.md). In the 3D ball environment example displayed in the figure above, the reset parameters are `gravity`, `ball_mass` and `ball_scale`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You say "we first decided what the reset parameters are" and then introduce the reset Parameters in the following sentence. This is confusing. Also, reset parameters were not introduced in Curriculum learning, so I would say Curriculum learning also uses reset parameters instead of introduced them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To vary the environments, we first decided what the reset parameters for the environment. This is grammatically incorrect I think.


## In-practice

To test the effectiveness of this training procedure, we train 3 models over 50000 steps:
1. Model Trained on Default reset parameter values (Default model).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This numbering does not look good in the markdown. I think you need to skip a line somewhere for it to be a numbered list.

2. Model Trained on a range of reset parameter values (Random model). Reset parameter values are picked uniformly over the range and the model is trained on each configuration for 5000 steps before they are randomly sampled again. The range consists of the default values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ranges consist of the default values --> What does this mean ?

3. Model Trained on a range of reset parameters values (Extreme model). The range here are the extreme ends, which lie outside the range used to train the Random model. Reset parameter values are selected randomly and the model is trained for 5000 steps on each configuration of values before resampling the values again.

Then, these 3 models are tested for their performance in all the reset parameter settings for 50000 steps each. The figure below outlines the performance of each model.

![3DBall environment mean rewards](images/3DBall_generalization_rewards.png)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The results are bad. I would not put them in the documentation for the release.


In a relatively simple environment we are able to see differences in the performance between differently trained models, with the models exposed to various environments performing better in variations of environments than the default model trained over a single environment.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not obvious in the graph provided.


## How-to

For generalization training, we need to provide a way to modify the environment by supplying a set of reset parameters. This provision can be either deterministic or randomized. Each reset parameter is assigned a sampler. If a sampler isn't provided for a reset parameter, the parameter maintains the default value throughout the training, remaining unchanged. The samplers for all the reset parameters are handled by a **Sampler Manager**, which is also responsible for generating a new set of values for the reset parameters when needed.

To setup the Sampler Manager, we setup a YAML file that specifies how we wish to generate new samples. In this file, we specify the samplers and the `resampling-duration` (number of training steps after which reset parameters are resampled). Below is an example of a sampler file for the 3D ball environment.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resampling duration does not exist in the YAML provided.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

number of training steps after which reset parameters are resampled --> these are not necessarily training steps, they are just simulations steps.


```yaml
episode-length: 5000
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which on is it ? resampling-duration or episode length ?


mass:
sampler-type: "uniform"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You provide multi-range, normal and uniform sampling, you need to explain what they are and what their arguments correspond to.

min_value: 0.5
max_value: 10

gravity:
sampler-type: "multirange_uniform"
intervals: [[7, 10], [15, 20]]

scale:
sampler-type: "uniform"
min_value: 0.75
max_value: 3

```

* `resampling-duration` (int) - Specifies the number of steps for agent to train under a particular environment configuration before resetting the environment with a new sample of reset parameters.

* `parameter_name` - Name of the reset parameter. This should match the name specified in the academy of the intended environment for which the agent is being trained. If a parameter specified in the file doesn't exist in the environment, then this specification will be ignored.

* `sampler-type` - Specify the sampler type to use for the reset parameter. This is a string that should exist in the `Sampler Factory` (explained below).

* `sub-arguments` - Specify the characteristic parameters for the sampler. In the example sampler file above, this would correspond to the `intervals` key under the `multirange_uniform` sampler for the gravity reset parameter. The key name should match the name of the corresponding argument in the sampler definition. (Look at defining a new sampler method)

The sampler manager allocates a sampler for a reset parameter by using the *Sampler Factory*, which maintains a dictionary mapping of string keys to sampler objects. The available samplers to be used for reset parameter resampling is as available in the Sampler Factory.

The implementation of the samplers can be found at `ml-agents-envs/mlagents/envs/sampler_class.py`.

### Defining a new sampler method

Custom sampling techniques must inherit from the *Sampler* base class (included in the `sampler_class` file) and preserve the interface. Once the class for the required method is specified, it must be registered in the Sampler Factory.

This can be done by subscribing to the *register_sampler* method of the SamplerFactory. The command is as follows:

`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_object*)`

Once the Sampler Factory reflects the new register, the custom sampler can be used for resampling reset parameter. For demonstration, lets say our sampler was implemented as below, and we register the `CustomSampler` class with the string `custom-sampler` in the Sampler Factory.

```python
class CustomSampler(Sampler):

def __init__(self, argA, argB, argC):
self.possible_vals = [argA, argB, argC]

def sample_all(self):
return np.random.choice(self.possible_vals)
```

Now we need to specify this sampler in the sampler file. Lets say we wish to use this sampler for the reset parameter *mass*; the sampler file would specify the same for mass as the following.

```yaml
mass:
sampler-type: "custom-sampler"
argA: 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the ordering of the args relevant or only their names ? This should be specified.

argB: 2
argC: 3
```

With the sampler file setup, we can proceed to train our agent as explained in the next section.

### Training with Generalization Learning

We first begin with setting up the sampler file. After the sampler file is defined and configured, we proceed by launching `mlagents-learn` and specify our configured sampler file with the `--sampler` flag. To demonstrate, if we wanted to train a 3D ball agent with generalization using the `generalization-test.yaml` sampling setup, we can run

```sh
mlagents-learn config/trainer_config.yaml --sampler=config/generalize_test.yaml --run-id=3D-Ball-generalization --train
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generalization-test.yaml =/= config/generalize_test.yaml

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the sample code you use config/generalize_test.yaml but in the text you use config/generalization-test.yaml

```

We can observe progress and metrics via Tensorboard.
4 changes: 4 additions & 0 deletions docs/Training-ML-Agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ environment, you can set the following command line options when invoking
* `--curriculum=<file>` – Specify a curriculum JSON file for defining the
lessons for curriculum training. See [Curriculum
Training](Training-Curriculum-Learning.md) for more information.
* `--sampler=<file>` - Specify a sampler YAML file for defining the
sampler for generalization training. See [Generalization
Training](Training-Generalization-Learning.md) for more information.
* `--keep-checkpoints=<n>` – Specify the maximum number of model checkpoints to
keep. Checkpoints are saved after the number of steps specified by the
`save-freq` option. Once the maximum number of checkpoints has been reached,
Expand Down Expand Up @@ -197,6 +200,7 @@ are conducting, see:
* [Training with PPO](Training-PPO.md)
* [Using Recurrent Neural Networks](Feature-Memory.md)
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
* [Training with Generalization](Training-Generalization-Learning.md)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Training with Generalization sounds weird to me. I would say Training with Environment Parameters Sampling

* [Training with Imitation Learning](Training-Imitation-Learning.md)

You can also compare the
Expand Down
Binary file added docs/images/3DBall_generalization_rewards.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/3dball_big.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/3dball_small.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions ml-agents-envs/mlagents/envs/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class UnityActionException(UnityException):
pass


class SamplerException(UnityException):
"""
Related to errors with the sampler actions.
"""

pass


class UnityTimeOutException(UnityException):
"""
Related to errors with communication timeouts.
Expand Down
113 changes: 113 additions & 0 deletions ml-agents-envs/mlagents/envs/sampler_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import numpy as np
from typing import *
from functools import *
from collections import OrderedDict
from abc import ABC, abstractmethod

from .exception import SamplerException


class Sampler(ABC):
@abstractmethod
def sample_parameter(self) -> float:
pass


class UniformSampler(Sampler):
# kwargs acts as a sink for extra unneeded args
def __init__(self, min_value, max_value, **kwargs):
self.min_value = min_value
self.max_value = max_value

def sample_parameter(self) -> float:
return np.random.uniform(self.min_value, self.max_value)


class MultiRangeUniformSampler(Sampler):
def __init__(self, intervals, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should add more comments one the classes. You need to define what the sampler does (Uniform vs normal for instance) as well as the parameters type and function. For example : Intervals is a list of floats, this needs to be specified somewhere.

self.intervals = intervals
# Measure the length of the intervals
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
# Cumulative size of the intervals
cum_interval_length = sum(interval_lengths)
# Assign weights to an interval proportionate to the interval size
self.interval_weights = [x / cum_interval_length for x in interval_lengths]

def sample_parameter(self) -> float:
cur_min, cur_max = self.intervals[
np.random.choice(len(self.intervals), p=self.interval_weights)
]
return np.random.uniform(cur_min, cur_max)


class GaussianSampler(Sampler):
def __init__(self, mean, var, **kwargs):
self.mean = mean
self.var = var

def sample_parameter(self) -> float:
return np.random.normal(self.mean, self.var)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variance =/= standard deviation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, let me fix the naming



# To introduce new sampling methods, just need to 'register' them to this sampler factory
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment should be formatted consistently with the other comments of the repo.

class SamplerFactory:
NAME_TO_CLASS = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn I like this.

"uniform": UniformSampler,
"gaussian": GaussianSampler,
"multirange_uniform": MultiRangeUniformSampler,
}

@staticmethod
def register_sampler(name, sampler_cls):
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls

@staticmethod
def init_sampler_class(name, param_dict):
if name not in SamplerFactory.NAME_TO_CLASS:
raise SamplerException(
name + " sampler is not registered in the SamplerFactory."
" Use the register_sample method to register the string"
" associated to your sampler in the SamplerFactory."
)
sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
try:
return sampler_cls(**param_dict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a parameter dict or a list of parameters ? I think this variable should be named params (I think it is usually called this way)

except:
raise SamplerException(
"The sampler class associated to the " + name + " key in the factory "
"was not provided the required arguments. Please ensure that the sampler "
"config file consists of the appropriate keys for this sampler class."
)


class SamplerManager:
def __init__(self, reset_param_dict):
self.reset_param_dict = reset_param_dict if reset_param_dict else {}
assert isinstance(self.reset_param_dict, dict)
self.samplers = OrderedDict()
for param_name, cur_param_dict in self.reset_param_dict.items():
if "sampler-type" not in cur_param_dict:
raise SamplerException(
"'sampler_type' argument hasn't been supplied for the {0} parameter".format(
param_name
)
)
sampler_name = cur_param_dict.pop("sampler-type")
param_sampler = SamplerFactory.init_sampler_class(
sampler_name, cur_param_dict
)

self.samplers[param_name] = param_sampler

def is_empty(self) -> bool:
"""
If self.samplers is empty, then bool of it returns false, indicating that the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sentence is confusing. Shouldn't is_empty be true when the samplers is empty ?

sampler manager isn't managing any samplers.
"""
return not bool(self.samplers)

def sample_all(self) -> Dict[str, float]:
res = {}
for param_name, param_sampler in list(self.samplers.items()):
res[param_name] = param_sampler.sample_parameter()
return res
99 changes: 99 additions & 0 deletions ml-agents-envs/mlagents/envs/tests/test_sampler_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from math import isclose
import pytest

from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.sampler_class import (
UniformSampler,
MultiRangeUniformSampler,
GaussianSampler,
)
from mlagents.envs.exception import UnityException


def basic_3Dball_sampler():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason these need to be functions? Why not just put the dict literal in the test function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just put them as function for cleaner code along with consistency with other test files. Nothing special otherwise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather this function not have the name 3Dball in it. Call it sampler1 or something. The test should in my opinion be agnostic of the names of the demo environments

return {
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
"gravity": {
"sampler-type": "multirange_uniform",
"intervals": [[8, 11], [15, 20]],
},
}


def check_value_in_intervals(val, intervals):
check_in_bounds = [a <= val <= b for a, b in intervals]
return any(check_in_bounds)


def test_3Dball_sampler():
config = basic_3Dball_sampler()
sampler = SamplerManager(config)

assert sampler.is_empty() is False
assert isinstance(sampler.samplers["mass"], UniformSampler)
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)

cur_sample = sampler.sample_all()

# Check uniform sampler for mass
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
assert config["mass"]["min_value"] <= cur_sample["mass"]
assert config["mass"]["max_value"] >= cur_sample["mass"]

# Check multirange_uniform sampler for gravity
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"]
assert check_value_in_intervals(
cur_sample["gravity"], sampler.samplers["gravity"].intervals
)


def basic_tennis_sampler():
return {"angle": {"sampler-type": "gaussian", "mean": 0, "var": 1}}


def test_tennis_sampler():
config = basic_tennis_sampler()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["angle"], GaussianSampler)

cur_sample = sampler.sample_all()

# Check angle gaussian sampler
assert sampler.samplers["angle"].mean == config["angle"]["mean"]
assert sampler.samplers["angle"].var == config["angle"]["var"]


def test_empty_samplers():
empty_sampler = SamplerManager({})
assert empty_sampler.is_empty()
empty_cur_sample = empty_sampler.sample_all()
assert empty_cur_sample == {}

none_sampler = SamplerManager(None)
assert none_sampler.is_empty()
none_cur_sample = none_sampler.sample_all()
assert none_cur_sample == {}


def incorrect_uniform_sampler():
# Do not specify required arguments to uniform sampler
return {"mass": {"sampler-type": "uniform", "min-value": 10}}


def incorrect_sampler_config():
# Do not specify 'sampler-type' key
return {"mass": {"min-value": 2, "max-value": 30}}


def test_incorrect_uniform_sampler():
config = incorrect_uniform_sampler()
with pytest.raises(UnityException):
SamplerManager(config)


def test_incorrect_sampler():
config = incorrect_sampler_config()
with pytest.raises(UnityException):
SamplerManager(config)
2 changes: 2 additions & 0 deletions ml-agents/mlagents/trainers/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ class MetaCurriculumError(TrainerError):
"""
Any error related to the configuration of a metacurriculum.
"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra line ?

pass
Loading