From 151bf5802e25d44b1f915d394a4900b9ea6dcc74 Mon Sep 17 00:00:00 2001
From: Andrew Cohen
Date: Wed, 20 Jan 2021 15:51:42 -0500
Subject: [PATCH] use singular entity embedding

---
 .../trainers/tests/torch/test_attention.py    |  26 ++-
 .../mlagents/trainers/torch/attention.py      | 156 +++++++++---------
 ml-agents/mlagents/trainers/torch/layers.py   |  14 ++
 ml-agents/mlagents/trainers/torch/networks.py |  26 ++-
 4 files changed, 119 insertions(+), 103 deletions(-)

diff --git a/ml-agents/mlagents/trainers/tests/torch/test_attention.py b/ml-agents/mlagents/trainers/tests/torch/test_attention.py
index d771e324c6..631a5c6761 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_attention.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_attention.py
@@ -4,7 +4,7 @@
 from mlagents.trainers.torch.layers import linear_layer
 from mlagents.trainers.torch.attention import (
     MultiHeadAttention,
-    EntityEmbeddings,
+    EntityEmbedding,
     ResidualSelfAttention,
 )
 
@@ -70,7 +70,7 @@ def generate_input_helper(pattern):
     input_1 = generate_input_helper(masking_pattern_1)
     input_2 = generate_input_helper(masking_pattern_2)
 
-    masks = EntityEmbeddings.get_masks([input_1, input_2])
+    masks = ResidualSelfAttention.get_masks([input_1, input_2])
     assert len(masks) == 2
     masks_1 = masks[0]
     masks_2 = masks[1]
@@ -87,18 +87,16 @@ def test_simple_transformer_training():
     torch.manual_seed(1336)
     size, n_k, = 3, 5
     embedding_size = 64
-    entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
-    transformer = ResidualSelfAttention(embedding_size, [n_k])
+    entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
+    transformer = ResidualSelfAttention(embedding_size, n_k)
     l_layer = linear_layer(embedding_size, size)
     optimizer = torch.optim.Adam(
         list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
     )
     batch_size = 200
-    point_range = 3
-    init_error = -1.0
-    for _ in range(250):
-        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
-        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
+    for _ in range(200):
+        center = torch.rand((batch_size, size))
+        key = torch.rand((batch_size, n_k, size))
         with torch.no_grad():
             # create the target : The key closest to the query in euclidean distance
             distance = torch.sum(
@@ -111,19 +109,15 @@
             target = torch.stack(target, dim=0)
             target = target.detach()
 
-        embeddings = entity_embeddings(center, [key])
-        masks = EntityEmbeddings.get_masks([key])
+        embeddings = entity_embeddings(center, key)
+        masks = ResidualSelfAttention.get_masks([key])
         prediction = transformer.forward(embeddings, masks)
         prediction = l_layer(prediction)
         prediction = prediction.reshape((batch_size, size))
         error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
-        if init_error == -1.0:
-            init_error = error.item()
-        else:
-            assert error.item() < init_error
         print(error.item())
         optimizer.zero_grad()
         error.backward()
         optimizer.step()
-    assert error.item() < 0.3
+    assert error.item() < 0.02
diff --git a/ml-agents/mlagents/trainers/torch/attention.py b/ml-agents/mlagents/trainers/torch/attention.py
index b95ae541ca..622414b3a0 100644
--- a/ml-agents/mlagents/trainers/torch/attention.py
+++ b/ml-agents/mlagents/trainers/torch/attention.py
@@ -1,26 +1,34 @@
 from mlagents.torch_utils import torch
 from typing import Tuple, Optional, List
-from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer
+from mlagents.trainers.torch.layers import (
+    LinearEncoder,
+    Initialization,
+    linear_layer,
+    LayerNorm,
+)
 from mlagents.trainers.torch.model_serialization import exporting_to_onnx
 from mlagents.trainers.exception import UnityTrainerException
 
 
 class MultiHeadAttention(torch.nn.Module):
-    """
-    Multi Head Attention module. We do not use the regular Torch implementation since
-    Barracuda does not support some operators it uses.
-    Takes as input to the forward method 3 tensors:
-    - query: of dimensions (batch_size, number_of_queries, embedding_size)
-    - key: of dimensions (batch_size, number_of_keys, embedding_size)
-    - value: of dimensions (batch_size, number_of_keys, embedding_size)
-    The forward method will return 2 tensors:
-    - The output: (batch_size, number_of_queries, embedding_size)
-    - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
-    """
 
     NEG_INF = -1e6
 
     def __init__(self, embedding_size: int, num_heads: int):
+        """
+        Multi Head Attention module. We do not use the regular Torch implementation since
+        Barracuda does not support some operators it uses.
+        Takes as input to the forward method 3 tensors:
+        - query: of dimensions (batch_size, number_of_queries, embedding_size)
+        - key: of dimensions (batch_size, number_of_keys, embedding_size)
+        - value: of dimensions (batch_size, number_of_keys, embedding_size)
+        The forward method will return 2 tensors:
+        - The output: (batch_size, number_of_queries, embedding_size)
+        - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
+        :param embedding_size: The size of the embeddings that will be generated (should be
+        divisible by num_heads)
+        :param num_heads: The number of heads of the attention module
+        """
         super().__init__()
         self.n_heads = num_heads
         self.head_size: int = embedding_size // self.n_heads
@@ -77,37 +87,29 @@ def forward(
         return value_attention, att
 
 
-class EntityEmbeddings(torch.nn.Module):
-    """
-    A module used to embed entities before passing them to a self-attention block.
-    Used in conjunction with ResidualSelfAttention to encode information about a self
-    and additional entities. Can also concatenate self to entities for ego-centric self-
-    attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
-    """
-
+class EntityEmbedding(torch.nn.Module):
     def __init__(
         self,
         x_self_size: int,
-        entity_sizes: List[int],
+        entity_size: int,
+        entity_num_max_elements: Optional[int],
         embedding_size: int,
-        entity_num_max_elements: Optional[List[int]] = None,
         concat_self: bool = True,
     ):
         """
         Constructs an EntityEmbeddings module.
         :param x_self_size: Size of "self" entity.
-        :param entity_sizes: List of sizes for other entities. Should be of length
-        equivalent to the number of entities.
-        :param embedding_size: Embedding size for entity encoders.
-        :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
+        :param entity_size: Size of other entity.
+        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
             Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
-        :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
+        :param embedding_size: Embedding size for the entity encoder.
+        :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
             self-attention.
         """
         super().__init__()
         self.self_size: int = x_self_size
-        self.entity_sizes: List[int] = entity_sizes
-        self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+        self.entity_size: int = entity_size
+        self.entity_num_max_elements: int = -1
         if entity_num_max_elements is not None:
             self.entity_num_max_elements = entity_num_max_elements
 
@@ -115,60 +117,33 @@ def __init__(
         # If not concatenating self, input to encoder is just entity size
         if not concat_self:
             self.self_size = 0
-        self.ent_encoders = torch.nn.ModuleList(
-            [
-                LinearEncoder(
-                    self.self_size + ent_size,
-                    1,
-                    embedding_size,
-                    kernel_init=Initialization.Normal,
-                    kernel_gain=(0.125 / embedding_size) ** 0.5,
-                )
-                for ent_size in self.entity_sizes
-            ]
+        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
+        self.ent_encoder = LinearEncoder(
+            self.self_size + self.entity_size,
+            1,
+            embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / embedding_size) ** 0.5,
         )
 
-    def forward(
-        self, x_self: torch.Tensor, entities: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, int]:
+    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
         if self.concat_self:
+            num_entities = self.entity_num_max_elements
+            if num_entities < 0:
+                if exporting_to_onnx.is_exporting():
+                    raise UnityTrainerException(
+                        "Trying to export an attention mechanism that doesn't have a set max \
+                        number of elements."
+                    )
+                num_entities = entities.shape[1]
+            expanded_self = x_self.reshape(-1, 1, self.self_size)
+            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
             # Concatenate all observations with self
-            self_and_ent: List[torch.Tensor] = []
-            for num_entities, ent in zip(self.entity_num_max_elements, entities):
-                if num_entities < 0:
-                    if exporting_to_onnx.is_exporting():
-                        raise UnityTrainerException(
-                            "Trying to export an attention mechanism that doesn't have a set max \
-                            number of elements."
-                        )
-                    num_entities = ent.shape[1]
-                expanded_self = x_self.reshape(-1, 1, self.self_size)
-                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-        else:
-            self_and_ent = entities
-        # Encode and concatenate entites
-        encoded_entities = torch.cat(
-            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
-            dim=1,
-        )
+            entities = torch.cat([expanded_self, entities], dim=2)
+        # Encode entities
+        encoded_entities = self.ent_encoder(entities)
         return encoded_entities
 
-    @staticmethod
-    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
-        """
-        Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
-        all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
-        layer to mask the padding observations.
-        """
-        with torch.no_grad():
-            # Generate the masking tensors for each entities tensor (mask only if all zeros)
-            key_masks: List[torch.Tensor] = [
-                (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
-                for ent in observations
-            ]
-            return key_masks
-
 
 class ResidualSelfAttention(torch.nn.Module):
     """
@@ -182,7 +157,7 @@
     def __init__(
         self,
         embedding_size: int,
-        entity_num_max_elements: Optional[List[int]] = None,
+        entity_num_max_elements: Optional[int] = None,
         num_heads: int = 4,
     ):
         """
@@ -198,13 +173,13 @@ def __init__(
         super().__init__()
         self.max_num_ent: Optional[int] = None
         if entity_num_max_elements is not None:
-            _entity_num_max_elements = entity_num_max_elements
-            self.max_num_ent = sum(_entity_num_max_elements)
+            self.max_num_ent = entity_num_max_elements
 
         self.attention = MultiHeadAttention(
             num_heads=num_heads, embedding_size=embedding_size
         )
+        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
         self.fc_q = linear_layer(
             embedding_size,
             embedding_size,
             kernel_init=Initialization.Normal,
             kernel_gain=(0.125 / embedding_size) ** 0.5,
         )
@@ -229,10 +204,14 @@ def __init__(
             kernel_init=Initialization.Normal,
             kernel_gain=(0.125 / embedding_size) ** 0.5,
         )
+        self.embedding_norm = LayerNorm()
+        self.residual_norm = LayerNorm()
 
     def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
         # Gather the maximum number of entities information
         mask = torch.cat(key_masks, dim=1)
+
+        inp = self.embedding_norm(inp)
         # Feed to self attention
         query = self.fc_q(inp)  # (b, n_q, emb)
         key = self.fc_k(inp)  # (b, n_k, emb)
@@ -252,9 +231,24 @@ def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Ten
         output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
         # Residual
         output = self.fc_out(output) + inp
+        output = self.residual_norm(output)
         # Average Pooling
         numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
         denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
         output = numerator / denominator
-        # Residual between x_self and the output of the module
         return output
+
+    @staticmethod
+    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
+        all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+        layer to mask the padding observations.
+        """
+        with torch.no_grad():
+            # Generate the masking tensors for each entities tensor (mask only if all zeros)
+            key_masks: List[torch.Tensor] = [
+                (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
+                for ent in observations
+            ]
+            return key_masks
diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py
index cd48c3f8ce..becfb0f10d 100644
--- a/ml-agents/mlagents/trainers/torch/layers.py
+++ b/ml-agents/mlagents/trainers/torch/layers.py
@@ -115,6 +115,20 @@ def forward(
         pass
 
 
+class LayerNorm(torch.nn.Module):
+    """
+    A vanilla implementation of layer normalization https://arxiv.org/pdf/1607.06450.pdf
+    norm_x = (x - mean) / sqrt(mean((x - mean) ^ 2))
+    This does not include the trainable parameters gamma and beta, for speed.
+    Typically, this is norm_x * gamma + beta
+    """
+
+    def forward(self, layer_activations: torch.Tensor) -> torch.Tensor:
+        mean = torch.mean(layer_activations, dim=-1, keepdim=True)
+        var = torch.mean((layer_activations - mean) ** 2, dim=-1, keepdim=True)
+        return (layer_activations - mean) / (torch.sqrt(var + 1e-5))
+
+
 class LinearEncoder(torch.nn.Module):
     """
     Linear layers.
diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py
index 1f23a12d56..23b099289e 100644
--- a/ml-agents/mlagents/trainers/torch/networks.py
+++ b/ml-agents/mlagents/trainers/torch/networks.py
@@ -14,7 +14,7 @@
 from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
-from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbeddings
+from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbedding
 
 
 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
@@ -139,9 +139,13 @@ def __init__(
             + sum(self.action_spec.discrete_branches)
             + self.action_spec.continuous_size
         )
-        self.entity_encoder = EntityEmbeddings(
-            0, [obs_only_ent_size, q_ent_size], self.h_size, concat_self=False
+        self.obs_encoder = EntityEmbedding(
+            0, obs_only_ent_size, None, self.h_size, concat_self=False
         )
+        self.obs_action_encoder = EntityEmbedding(
+            0, q_ent_size, None, self.h_size, concat_self=False
+        )
+
         self.self_attn = ResidualSelfAttention(self.h_size)
 
         encoder_input_size = self.h_size
@@ -223,7 +227,9 @@ def forward(
         value_masks = self._get_masks_from_nans(value_inputs)
         q_masks = self._get_masks_from_nans(q_inputs)
 
-        encoded_entity = self.entity_encoder(None, [value_input_concat, q_input_concat])
+        encoded_obs = self.obs_encoder(None, value_input_concat)
+        encoded_obs_action = self.obs_action_encoder(None, q_input_concat)
+        encoded_entity = torch.cat([encoded_obs, encoded_obs_action], dim=1)
         encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks])
 
         if len(concat_encoded_obs) == 0:
@@ -643,7 +649,11 @@ def critic_pass(
             all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length
         )
         value_outputs, critic_mem_out = self.critic(
-            critic_obs, [inputs], [actions], memories=critic_mem, sequence_length=sequence_length
+            critic_obs,
+            [inputs],
+            [actions],
+            memories=critic_mem,
+            sequence_length=sequence_length,
         )
         if mar_value_outputs is None:
             mar_value_outputs = value_outputs
@@ -678,7 +688,11 @@ def get_stats_and_value(
         )
 
         value_outputs, critic_mem_outs = self.critic(
-            [inputs], critic_obs, actions, memories=critic_mem, sequence_length=sequence_length
+            [inputs],
+            critic_obs,
+            actions,
+            memories=critic_mem,
+            sequence_length=sequence_length,
        )
 
         return log_probs, entropies, value_outputs, mar_value_outputs
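Reviewer note (not part of the commit above): the sketch below is a minimal, hypothetical usage example that mirrors the updated test_simple_transformer_training, showing how the renamed EntityEmbedding and ResidualSelfAttention (which now owns get_masks) fit together after this change. Sizes are illustrative and it assumes an ml-agents checkout at this revision.

    import torch
    from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

    batch, n_entities, entity_size, emb_size = 8, 5, 3, 64
    # One embedding per entity type; x_self is tiled and concatenated onto each entity.
    embedding = EntityEmbedding(entity_size, entity_size, n_entities, emb_size)
    attention = ResidualSelfAttention(emb_size, n_entities)

    x_self = torch.rand((batch, entity_size))
    entities = torch.rand((batch, n_entities, entity_size))
    masks = ResidualSelfAttention.get_masks([entities])  # 1 where an entity row is all zeros
    encoded = embedding(x_self, entities)                # (batch, n_entities, emb_size)
    pooled = attention(encoded, masks)                   # (batch, emb_size) after masked average pooling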