diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index e97c9d3346..59ea701e0b 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -14,11 +14,11 @@ and this project adheres to
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
-
+- The `encoding_size` setting for RewardSignals has been deprecated. Please use `network_settings` instead. (#4982)
### Bug Fixes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
-
+- An issue that caused `GAIL` to fail for environments where agents can terminate episodes by self-sacrifice has been fixed. (#4971)
## [1.8.0-preview] - 2021-02-17
### Major Changes
diff --git a/config/imitation/CrawlerStatic.yaml b/config/imitation/CrawlerStatic.yaml
index 6bda49e1ef..0cdcdd08f7 100644
--- a/config/imitation/CrawlerStatic.yaml
+++ b/config/imitation/CrawlerStatic.yaml
@@ -19,7 +19,11 @@ behaviors:
gail:
gamma: 0.99
strength: 1.0
- encoding_size: 128
+ network_settings:
+ normalize: true
+ hidden_units: 128
+ num_layers: 2
+ vis_encode_type: simple
learning_rate: 0.0003
use_actions: false
use_vail: false
diff --git a/config/imitation/FoodCollector.yaml b/config/imitation/FoodCollector.yaml
index 614772331c..a05bc3c2fe 100644
--- a/config/imitation/FoodCollector.yaml
+++ b/config/imitation/FoodCollector.yaml
@@ -19,7 +19,11 @@ behaviors:
gail:
gamma: 0.99
strength: 0.1
- encoding_size: 128
+ network_settings:
+ normalize: false
+ hidden_units: 128
+ num_layers: 2
+ vis_encode_type: simple
learning_rate: 0.0003
use_actions: false
use_vail: false
diff --git a/config/imitation/Hallway.yaml b/config/imitation/Hallway.yaml
index 709bd33d7b..1678abc783 100644
--- a/config/imitation/Hallway.yaml
+++ b/config/imitation/Hallway.yaml
@@ -24,8 +24,7 @@ behaviors:
strength: 1.0
gail:
gamma: 0.99
- strength: 0.1
- encoding_size: 128
+ strength: 0.01
learning_rate: 0.0003
use_actions: false
use_vail: false
diff --git a/config/imitation/PushBlock.yaml b/config/imitation/PushBlock.yaml
index 693407ad71..57496ebbc1 100644
--- a/config/imitation/PushBlock.yaml
+++ b/config/imitation/PushBlock.yaml
@@ -16,16 +16,28 @@ behaviors:
num_layers: 2
vis_encode_type: simple
reward_signals:
- gail:
+ extrinsic:
gamma: 0.99
strength: 1.0
- encoding_size: 128
+ gail:
+ gamma: 0.99
+ strength: 0.01
+ network_settings:
+ normalize: false
+ hidden_units: 128
+ num_layers: 2
+ vis_encode_type: simple
learning_rate: 0.0003
use_actions: false
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/PushBlock/Demos/ExpertPush.demo
keep_checkpoints: 5
- max_steps: 15000000
+ max_steps: 1000000
time_horizon: 64
summary_freq: 60000
threaded: true
+ behavioral_cloning:
+ demo_path: Project/Assets/ML-Agents/Examples/PushBlock/Demos/ExpertPush.demo
+ steps: 50000
+ strength: 1.0
+ samples_per_update: 0
diff --git a/config/imitation/Pyramids.yaml b/config/imitation/Pyramids.yaml
index 826a9f683e..813a4e54a7 100644
--- a/config/imitation/Pyramids.yaml
+++ b/config/imitation/Pyramids.yaml
@@ -22,11 +22,11 @@ behaviors:
curiosity:
strength: 0.02
gamma: 0.99
- encoding_size: 256
+ network_settings:
+ hidden_units: 256
gail:
strength: 0.01
gamma: 0.99
- encoding_size: 128
demo_path: Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo
behavioral_cloning:
demo_path: Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo
diff --git a/config/ppo/Pyramids.yaml b/config/ppo/Pyramids.yaml
index a68116cea4..000b9f8dbc 100644
--- a/config/ppo/Pyramids.yaml
+++ b/config/ppo/Pyramids.yaml
@@ -22,7 +22,8 @@ behaviors:
curiosity:
gamma: 0.99
strength: 0.02
- encoding_size: 256
+ network_settings:
+ hidden_units: 256
learning_rate: 0.0003
keep_checkpoints: 5
max_steps: 10000000
diff --git a/config/ppo/PyramidsRND.yaml b/config/ppo/PyramidsRND.yaml
index 6af2732b15..75aaad8ce1 100644
--- a/config/ppo/PyramidsRND.yaml
+++ b/config/ppo/PyramidsRND.yaml
@@ -22,11 +22,11 @@ behaviors:
rnd:
gamma: 0.99
strength: 0.01
- encoding_size: 64
+ network_settings:
+ hidden_units: 64
learning_rate: 0.0001
keep_checkpoints: 5
max_steps: 3000000
time_horizon: 128
summary_freq: 30000
- framework: pytorch
threaded: true
diff --git a/config/ppo/VisualPyramids.yaml b/config/ppo/VisualPyramids.yaml
index 48782626ad..102cbdaf64 100644
--- a/config/ppo/VisualPyramids.yaml
+++ b/config/ppo/VisualPyramids.yaml
@@ -22,7 +22,8 @@ behaviors:
curiosity:
gamma: 0.99
strength: 0.01
- encoding_size: 256
+ network_settings:
+ hidden_units: 256
learning_rate: 0.0003
keep_checkpoints: 5
max_steps: 10000000
diff --git a/config/sac/Pyramids.yaml b/config/sac/Pyramids.yaml
index b0797df503..b0bf26682d 100644
--- a/config/sac/Pyramids.yaml
+++ b/config/sac/Pyramids.yaml
@@ -24,7 +24,6 @@ behaviors:
gail:
gamma: 0.99
strength: 0.01
- encoding_size: 128
learning_rate: 0.0003
use_actions: true
use_vail: false
diff --git a/config/sac/VisualPyramids.yaml b/config/sac/VisualPyramids.yaml
index b840fb7762..c30eece10c 100644
--- a/config/sac/VisualPyramids.yaml
+++ b/config/sac/VisualPyramids.yaml
@@ -24,7 +24,6 @@ behaviors:
gail:
gamma: 0.99
strength: 0.02
- encoding_size: 128
learning_rate: 0.0003
use_actions: true
use_vail: false
diff --git a/docs/ML-Agents-Overview.md b/docs/ML-Agents-Overview.md
index 9b253e0949..f3568333e7 100644
--- a/docs/ML-Agents-Overview.md
+++ b/docs/ML-Agents-Overview.md
@@ -472,12 +472,23 @@ Learning (GAIL). In most scenarios, you can combine these two features:
- If you want to help your agents learn (especially with environments that have
sparse rewards) using pre-recorded demonstrations, you can generally enable
both GAIL and Behavioral Cloning at low strengths in addition to having an
- extrinsic reward. An example of this is provided for the Pyramids example
- environment under `PyramidsLearning` in `config/gail_config.yaml`.
-- If you want to train purely from demonstrations, GAIL and BC _without_ an
- extrinsic reward signal is the preferred approach. An example of this is
- provided for the Crawler example environment under `CrawlerStaticLearning` in
- `config/gail_config.yaml`.
+ extrinsic reward. An example of this is provided for the PushBlock example
+ environment in `config/imitation/PushBlock.yaml`.
+- If you want to train purely from demonstrations with GAIL and BC _without_ an
+ extrinsic reward signal, please see the CrawlerStatic example environment under
+ in `config/imitation/CrawlerStatic.yaml`.
+
+***Note:*** GAIL introduces a [_survivor bias_](https://arxiv.org/pdf/1809.02925.pdf)
+to the learning process. That is, by giving positive rewards based on similarity
+to the expert, the agent is incentivized to remain alive for as long as possible.
+This can directly conflict with goal-oriented tasks like our PushBlock or Pyramids
+example environments where an agent must reach a goal state thus ending the
+episode as quickly as possible. In these cases, we strongly recommend that you
+use a low strength GAIL reward signal and a sparse extrinisic signal when
+the agent achieves the task. This way, the GAIL reward signal will guide the
+agent until it discovers the extrnisic signal and will not overpower it. If the
+agent appears to be ignoring the extrinsic reward signal, you should reduce
+the strength of GAIL.
#### GAIL (Generative Adversarial Imitation Learning)
diff --git a/docs/Training-Configuration-File.md b/docs/Training-Configuration-File.md
index 9b7b875922..d9a691337a 100644
--- a/docs/Training-Configuration-File.md
+++ b/docs/Training-Configuration-File.md
@@ -101,7 +101,7 @@ To enable curiosity, provide these settings:
| :--------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `curiosity -> strength` | (default = `1.0`) Magnitude of the curiosity reward generated by the intrinsic curiosity module. This should be scaled in order to ensure it is large enough to not be overwhelmed by extrinsic reward signals in the environment. Likewise it should not be too large to overwhelm the extrinsic reward signal.
Typical range: `0.001` - `0.1` |
| `curiosity -> gamma` | (default = `0.99`) Discount factor for future rewards.
Typical range: `0.8` - `0.995` |
-| `curiosity -> encoding_size` | (default = `64`) Size of the encoding used by the intrinsic curiosity model. This value should be small enough to encourage the ICM to compress the original observation, but also not too small to prevent it from learning to differentiate between expected and actual observations.
Typical range: `64` - `256` |
+| `curiosity -> network_settings` | Please see the documentation for `network_settings` under [Common Trainer Configurations](#common-trainer-configurations). The network specs used by the intrinsic curiosity model. The value should of `hidden_units` should be small enough to encourage the ICM to compress the original observation, but also not too small to prevent it from learning to differentiate between expected and actual observations.
Typical range: `64` - `256` |
| `curiosity -> learning_rate` | (default = `3e-4`) Learning rate used to update the intrinsic curiosity module. This should typically be decreased if training is unstable, and the curiosity loss is unstable.
Typical range: `1e-5` - `1e-3` |
### GAIL Intrinsic Reward
@@ -114,7 +114,7 @@ settings:
| `gail -> strength` | (default = `1.0`) Factor by which to multiply the raw reward. Note that when using GAIL with an Extrinsic Signal, this value should be set lower if your demonstrations are suboptimal (e.g. from a human), so that a trained agent will focus on receiving extrinsic rewards instead of exactly copying the demonstrations. Keep the strength below about 0.1 in those cases.
Typical range: `0.01` - `1.0` |
| `gail -> gamma` | (default = `0.99`) Discount factor for future rewards.
Typical range: `0.8` - `0.9` |
| `gail -> demo_path` | (Required, no default) The path to your .demo file or directory of .demo files. |
-| `gail -> encoding_size` | (default = `64`) Size of the hidden layer used by the discriminator. This value should be small enough to encourage the discriminator to compress the original observation, but also not too small to prevent it from learning to differentiate between demonstrated and actual behavior. Dramatically increasing this size will also negatively affect training times.
Typical range: `64` - `256` |
+| `gail -> network_settings` | Please see the documentation for `network_settings` under [Common Trainer Configurations](#common-trainer-configurations). The network specs for the GAIL discriminator. The value of `hidden_units` should be small enough to encourage the discriminator to compress the original observation, but also not too small to prevent it from learning to differentiate between demonstrated and actual behavior. Dramatically increasing this size will also negatively affect training times.
Typical range: `64` - `256` |
| `gail -> learning_rate` | (Optional, default = `3e-4`) Learning rate used to update the discriminator. This should typically be decreased if training is unstable, and the GAIL loss is unstable.
Typical range: `1e-5` - `1e-3` |
| `gail -> use_actions` | (default = `false`) Determines whether the discriminator should discriminate based on both observations and actions, or just observations. Set to True if you want the agent to mimic the actions from the demonstrations, and False if you'd rather have the agent visit the same states as in the demonstrations but with possibly different actions. Setting to False is more likely to be stable, especially with imperfect demonstrations, but may learn slower. |
| `gail -> use_vail` | (default = `false`) Enables a variational bottleneck within the GAIL discriminator. This forces the discriminator to learn a more general representation and reduces its tendency to be "too good" at discriminating, making learning more stable. However, it does increase training time. Enable this if you notice your imitation learning is unstable, or unable to learn the task at hand. |
@@ -128,7 +128,7 @@ To enable RND, provide these settings:
| :--------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `rnd -> strength` | (default = `1.0`) Magnitude of the curiosity reward generated by the intrinsic rnd module. This should be scaled in order to ensure it is large enough to not be overwhelmed by extrinsic reward signals in the environment. Likewise it should not be too large to overwhelm the extrinsic reward signal.
Typical range: `0.001` - `0.01` |
| `rnd -> gamma` | (default = `0.99`) Discount factor for future rewards.
Typical range: `0.8` - `0.995` |
-| `rnd -> encoding_size` | (default = `64`) Size of the encoding used by the intrinsic RND model.
Typical range: `64` - `256` |
+| `rnd -> network_settings` | Please see the documentation for `network_settings` under [Common Trainer Configurations](#common-trainer-configurations). The network specs for the RND model. |
| `curiosity -> learning_rate` | (default = `3e-4`) Learning rate used to update the RND module. This should be large enough for the RND module to quickly learn the state representation, but small enough to allow for stable learning.
Typical range: `1e-5` - `1e-3`
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index 16905949f5..a5e5d520cb 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -183,6 +183,7 @@ def to_settings(self) -> type:
class RewardSignalSettings:
gamma: float = 0.99
strength: float = 1.0
+ network_settings: NetworkSettings = attr.ib(factory=NetworkSettings)
@staticmethod
def structure(d: Mapping, t: type) -> Any:
@@ -198,13 +199,26 @@ def structure(d: Mapping, t: type) -> Any:
enum_key = RewardSignalType(key)
t = enum_key.to_settings()
d_final[enum_key] = strict_to_cls(val, t)
+ # Checks to see if user specifying deprecated encoding_size for RewardSignals.
+ # If network_settings is not specified, this updates the default hidden_units
+ # to the value of encoding size. If specified, this ignores encoding size and
+ # uses network_settings values.
+ if "encoding_size" in val:
+ logger.warning(
+ "'encoding_size' was deprecated for RewardSignals. Please use network_settings."
+ )
+ # If network settings was not specified, use the encoding size. Otherwise, use hidden_units
+ if "network_settings" not in val:
+ d_final[enum_key].network_settings.hidden_units = val[
+ "encoding_size"
+ ]
return d_final
@attr.s(auto_attribs=True)
class GAILSettings(RewardSignalSettings):
- encoding_size: int = 64
learning_rate: float = 3e-4
+ encoding_size: Optional[int] = None
use_actions: bool = False
use_vail: bool = False
demo_path: str = attr.ib(kw_only=True)
@@ -212,14 +226,14 @@ class GAILSettings(RewardSignalSettings):
@attr.s(auto_attribs=True)
class CuriositySettings(RewardSignalSettings):
- encoding_size: int = 64
learning_rate: float = 3e-4
+ encoding_size: Optional[int] = None
@attr.s(auto_attribs=True)
class RNDSettings(RewardSignalSettings):
- encoding_size: int = 64
learning_rate: float = 1e-4
+ encoding_size: Optional[int] = None
# SAMPLERS #############################################################################
diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
index 1bd6cf918a..1814b22ca5 100644
--- a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
@@ -9,14 +9,16 @@
from mlagents.trainers.settings import CuriositySettings
from mlagents_envs.base_env import BehaviorSpec
+from mlagents_envs import logging_util
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_flattener import ActionFlattener
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import LinearEncoder, linear_layer
-from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.trajectory import ObsUtil
+logger = logging_util.get_logger(__name__)
+
class ActionPredictionTuple(NamedTuple):
continuous: torch.Tensor
@@ -70,13 +72,14 @@ class CuriosityNetwork(torch.nn.Module):
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
super().__init__()
self._action_spec = specs.action_spec
- state_encoder_settings = NetworkSettings(
- normalize=False,
- hidden_units=settings.encoding_size,
- num_layers=2,
- vis_encode_type=EncoderType.SIMPLE,
- memory=None,
- )
+
+ state_encoder_settings = settings.network_settings
+ if state_encoder_settings.memory is not None:
+ state_encoder_settings.memory = None
+ logger.warning(
+ "memory was specified in network_settings but is not supported by Curiosity. It is being ignored."
+ )
+
self._state_encoder = NetworkBody(
specs.observation_specs, state_encoder_settings
)
@@ -84,7 +87,7 @@ def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
self._action_flattener = ActionFlattener(self._action_spec)
self.inverse_model_action_encoding = torch.nn.Sequential(
- LinearEncoder(2 * settings.encoding_size, 1, 256)
+ LinearEncoder(2 * state_encoder_settings.hidden_units, 1, 256)
)
if self._action_spec.continuous_size > 0:
@@ -98,9 +101,12 @@ def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
self.forward_model_next_state_prediction = torch.nn.Sequential(
LinearEncoder(
- settings.encoding_size + self._action_flattener.flattened_size, 1, 256
+ state_encoder_settings.hidden_units
+ + self._action_flattener.flattened_size,
+ 1,
+ 256,
),
- linear_layer(256, settings.encoding_size),
+ linear_layer(256, state_encoder_settings.hidden_units),
)
def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
index 1a19731e7e..031fd22211 100644
--- a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
@@ -8,20 +8,22 @@
)
from mlagents.trainers.settings import GAILSettings
from mlagents_envs.base_env import BehaviorSpec
+from mlagents_envs import logging_util
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_flattener import ActionFlattener
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import linear_layer, Initialization
-from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.trajectory import ObsUtil
+logger = logging_util.get_logger(__name__)
+
class GAILRewardProvider(BaseRewardProvider):
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
super().__init__(specs, settings)
- self._ignore_done = True
+ self._ignore_done = False
self._discriminator_network = DiscriminatorNetwork(specs, settings)
self._discriminator_network.to(default_device())
_, self._demo_buffer = demo_to_buffer(
@@ -44,9 +46,12 @@ def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
)
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
+
expert_batch = self._demo_buffer.sample_mini_batch(
mini_batch.num_experiences, 1
)
+ self._discriminator_network.encoder.update_normalization(expert_batch)
+
loss, stats_dict = self._discriminator_network.compute_loss(
mini_batch, expert_batch
)
@@ -72,13 +77,13 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
self._use_vail = settings.use_vail
self._settings = settings
- encoder_settings = NetworkSettings(
- normalize=False,
- hidden_units=settings.encoding_size,
- num_layers=2,
- vis_encode_type=EncoderType.SIMPLE,
- memory=None,
- )
+ encoder_settings = settings.network_settings
+ if encoder_settings.memory is not None:
+ encoder_settings.memory = None
+ logger.warning(
+ "memory was specified in network_settings but is not supported by GAIL. It is being ignored."
+ )
+
self._action_flattener = ActionFlattener(specs.action_spec)
unencoded_size = (
self._action_flattener.flattened_size + 1 if settings.use_actions else 0
@@ -87,14 +92,14 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
specs.observation_specs, encoder_settings, unencoded_size
)
- estimator_input_size = settings.encoding_size
+ estimator_input_size = encoder_settings.hidden_units
if settings.use_vail:
estimator_input_size = self.z_size
self._z_sigma = torch.nn.Parameter(
torch.ones((self.z_size), dtype=torch.float), requires_grad=True
)
self._z_mu_layer = linear_layer(
- settings.encoding_size,
+ encoder_settings.hidden_units,
self.z_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=0.1,
diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py
index 2d41650fe6..8408b08b8d 100644
--- a/ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py
@@ -9,11 +9,13 @@
from mlagents.trainers.settings import RNDSettings
from mlagents_envs.base_env import BehaviorSpec
+from mlagents_envs import logging_util
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
-from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.trajectory import ObsUtil
+logger = logging_util.get_logger(__name__)
+
class RNDRewardProvider(BaseRewardProvider):
"""
@@ -58,13 +60,13 @@ class RNDNetwork(torch.nn.Module):
def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
super().__init__()
- state_encoder_settings = NetworkSettings(
- normalize=True,
- hidden_units=settings.encoding_size,
- num_layers=3,
- vis_encode_type=EncoderType.SIMPLE,
- memory=None,
- )
+ state_encoder_settings = settings.network_settings
+ if state_encoder_settings.memory is not None:
+ state_encoder_settings.memory = None
+ logger.warning(
+ "memory was specified in network_settings but is not supported by RND. It is being ignored."
+ )
+
self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)
def forward(self, mini_batch: AgentBuffer) -> torch.Tensor: