From 78e052be8f36381bb6857817ff0f505716be83b9 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 11 Feb 2021 12:02:42 -0500 Subject: [PATCH] Use attention tests from master --- .../trainers/tests/torch/test_attention.py | 71 ++++++-- .../mlagents/trainers/torch/attention.py | 161 +++++++++--------- 2 files changed, 137 insertions(+), 95 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_attention.py b/ml-agents/mlagents/trainers/tests/torch/test_attention.py index 4ae0e5137a..c914c28d79 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_attention.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_attention.py @@ -1,3 +1,4 @@ +import pytest from mlagents.torch_utils import torch import numpy as np @@ -5,8 +6,9 @@ from mlagents.trainers.torch.layers import linear_layer, LinearEncoder from mlagents.trainers.torch.attention import ( MultiHeadAttention, - EntityEmbeddings, + EntityEmbedding, ResidualSelfAttention, + get_zero_entities_mask, ) @@ -71,7 +73,7 @@ def generate_input_helper(pattern): input_1 = generate_input_helper(masking_pattern_1) input_2 = generate_input_helper(masking_pattern_2) - masks = EntityEmbeddings.get_masks([input_1, input_2]) + masks = get_zero_entities_mask([input_1, input_2]) assert len(masks) == 2 masks_1 = masks[0] masks_2 = masks[1] @@ -83,13 +85,60 @@ def generate_input_helper(pattern): assert masks_2[0, 1] == 0 if i % 2 == 0 else 1 +@pytest.mark.parametrize("mask_value", [0, 1]) +def test_all_masking(mask_value): + # We make sure that a mask of all zeros or all ones will not trigger an error + np.random.seed(1336) + torch.manual_seed(1336) + size, n_k, = 3, 5 + embedding_size = 64 + entity_embeddings = EntityEmbedding(size, n_k, embedding_size) + entity_embeddings.add_self_embedding(size) + transformer = ResidualSelfAttention(embedding_size, n_k) + l_layer = linear_layer(embedding_size, size) + optimizer = torch.optim.Adam( + list(entity_embeddings.parameters()) + + list(transformer.parameters()) + + list(l_layer.parameters()), + lr=0.001, + weight_decay=1e-6, + ) + batch_size = 20 + for _ in range(5): + center = torch.rand((batch_size, size)) + key = torch.rand((batch_size, n_k, size)) + with torch.no_grad(): + # create the target : The key closest to the query in euclidean distance + distance = torch.sum( + (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2 + ) + argmin = torch.argmin(distance, dim=1) + target = [] + for i in range(batch_size): + target += [key[i, argmin[i], :]] + target = torch.stack(target, dim=0) + target = target.detach() + + embeddings = entity_embeddings(center, key) + masks = [torch.ones_like(key[:, :, 0]) * mask_value] + prediction = transformer.forward(embeddings, masks) + prediction = l_layer(prediction) + prediction = prediction.reshape((batch_size, size)) + error = torch.mean((prediction - target) ** 2, dim=1) + error = torch.mean(error) / 2 + optimizer.zero_grad() + error.backward() + optimizer.step() + + def test_predict_closest_training(): np.random.seed(1336) torch.manual_seed(1336) size, n_k, = 3, 5 embedding_size = 64 - entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k]) - transformer = ResidualSelfAttention(embedding_size, [n_k]) + entity_embeddings = EntityEmbedding(size, n_k, embedding_size) + entity_embeddings.add_self_embedding(size) + transformer = ResidualSelfAttention(embedding_size, n_k) l_layer = linear_layer(embedding_size, size) optimizer = torch.optim.Adam( list(entity_embeddings.parameters()) @@ -114,8 +163,8 @@ def test_predict_closest_training(): 
target = torch.stack(target, dim=0) target = target.detach() - embeddings = entity_embeddings(center, [key]) - masks = EntityEmbeddings.get_masks([key]) + embeddings = entity_embeddings(center, key) + masks = get_zero_entities_mask([key]) prediction = transformer.forward(embeddings, masks) prediction = l_layer(prediction) prediction = prediction.reshape((batch_size, size)) @@ -135,14 +184,12 @@ def test_predict_minimum_training(): n_k = 5 size = n_k + 1 embedding_size = 64 - entity_embeddings = EntityEmbeddings( - size, [size], embedding_size, [n_k], concat_self=False - ) + entity_embedding = EntityEmbedding(size, n_k, embedding_size) # no self transformer = ResidualSelfAttention(embedding_size) l_layer = LinearEncoder(embedding_size, 2, n_k) loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam( - list(entity_embeddings.parameters()) + list(entity_embedding.parameters()) + list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001, @@ -166,8 +213,8 @@ def test_predict_minimum_training(): sliced_oh = onehots[:, : num + 1] inp = torch.cat([inp, sliced_oh], dim=2) - embeddings = entity_embeddings(inp, [inp]) - masks = EntityEmbeddings.get_masks([inp]) + embeddings = entity_embedding(inp, inp) + masks = get_zero_entities_mask([inp]) prediction = transformer(embeddings, masks) prediction = l_layer(prediction) ce = loss(prediction, argmin) diff --git a/ml-agents/mlagents/trainers/torch/attention.py b/ml-agents/mlagents/trainers/torch/attention.py index 61c0cf7d80..9b503e2d98 100644 --- a/ml-agents/mlagents/trainers/torch/attention.py +++ b/ml-agents/mlagents/trainers/torch/attention.py @@ -10,22 +10,41 @@ from mlagents.trainers.exception import UnityTrainerException -class MultiHeadAttention(torch.nn.Module): +def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]: """ - Multi Head Attention module. We do not use the regular Torch implementation since - Barracuda does not support some operators it uses. - Takes as input to the forward method 3 tensors: - - query: of dimensions (batch_size, number_of_queries, embedding_size) - - key: of dimensions (batch_size, number_of_keys, embedding_size) - - value: of dimensions (batch_size, number_of_keys, embedding_size) - The forward method will return 2 tensors: - - The output: (batch_size, number_of_queries, embedding_size) - - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) + Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was + all zeros (on dimension 2) and 0 otherwise. This is used in the Attention + layer to mask the padding observations. """ + with torch.no_grad(): + # Generate the masking tensors for each entities tensor (mask only if all zeros) + key_masks: List[torch.Tensor] = [ + (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations + ] + return key_masks + + +class MultiHeadAttention(torch.nn.Module): NEG_INF = -1e6 def __init__(self, embedding_size: int, num_heads: int): + """ + Multi Head Attention module. We do not use the regular Torch implementation since + Barracuda does not support some operators it uses. 
+ Takes as input to the forward method 3 tensors: + - query: of dimensions (batch_size, number_of_queries, embedding_size) + - key: of dimensions (batch_size, number_of_keys, embedding_size) + - value: of dimensions (batch_size, number_of_keys, embedding_size) + The forward method will return 2 tensors: + - The output: (batch_size, number_of_queries, embedding_size) + - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) + :param embedding_size: The size of the embeddings that will be generated (should be + dividable by the num_heads) + :param total_max_elements: The maximum total number of entities that can be passed to + the module + :param num_heads: The number of heads of the attention module + """ super().__init__() self.n_heads = num_heads self.head_size: int = embedding_size // self.n_heads @@ -82,7 +101,7 @@ def forward( return value_attention, att -class EntityEmbeddings(torch.nn.Module): +class EntityEmbedding(torch.nn.Module): """ A module used to embed entities before passing them to a self-attention block. Used in conjunction with ResidualSelfAttention to encode information about a self @@ -92,95 +111,69 @@ class EntityEmbeddings(torch.nn.Module): def __init__( self, - x_self_size: int, - entity_sizes: List[int], + entity_size: int, + entity_num_max_elements: Optional[int], embedding_size: int, - entity_num_max_elements: Optional[List[int]] = None, - concat_self: bool = True, ): """ - Constructs an EntityEmbeddings module. + Constructs an EntityEmbedding module. :param x_self_size: Size of "self" entity. - :param entity_sizes: List of sizes for other entities. Should be of length - equivalent to the number of entities. - :param embedding_size: Embedding size for entity encoders. - :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted. + :param entity_size: Size of other entities. + :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted. Needs to be assigned in order for model to be exportable to ONNX and Barracuda. - :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric + :param embedding_size: Embedding size for the entity encoder. + :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric self-attention. 
""" super().__init__() - self.self_size: int = x_self_size - self.entity_sizes: List[int] = entity_sizes - self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes) + self.self_size: int = 0 + self.entity_size: int = entity_size + self.entity_num_max_elements: int = -1 if entity_num_max_elements is not None: self.entity_num_max_elements = entity_num_max_elements - - self.concat_self: bool = concat_self - # If not concatenating self, input to encoder is just entity size - if not concat_self: - self.self_size = 0 + self.embedding_size = embedding_size # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf - self.ent_encoders = torch.nn.ModuleList( - [ - LinearEncoder( - self.self_size + ent_size, - 1, - embedding_size, - kernel_init=Initialization.Normal, - kernel_gain=(0.125 / embedding_size) ** 0.5, - ) - for ent_size in self.entity_sizes - ] + self.self_ent_encoder = LinearEncoder( + self.entity_size, + 1, + self.embedding_size, + kernel_init=Initialization.Normal, + kernel_gain=(0.125 / self.embedding_size) ** 0.5, ) - self.embedding_norm = LayerNorm() - def forward( - self, x_self: torch.Tensor, entities: List[torch.Tensor] - ) -> Tuple[torch.Tensor, int]: - if self.concat_self: - # Concatenate all observations with self - self_and_ent: List[torch.Tensor] = [] - for num_entities, ent in zip(self.entity_num_max_elements, entities): - if num_entities < 0: - if exporting_to_onnx.is_exporting(): - raise UnityTrainerException( - "Trying to export an attention mechanism that doesn't have a set max \ - number of elements." - ) - num_entities = ent.shape[1] - expanded_self = x_self.reshape(-1, 1, self.self_size) - expanded_self = torch.cat([expanded_self] * num_entities, dim=1) - self_and_ent.append(torch.cat([expanded_self, ent], dim=2)) - else: - self_and_ent = entities - # Encode and concatenate entites - encoded_entities = torch.cat( - [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)], - dim=1, + def add_self_embedding(self, size: int) -> None: + self.self_size = size + self.self_ent_encoder = LinearEncoder( + self.self_size + self.entity_size, + 1, + self.embedding_size, + kernel_init=Initialization.Normal, + kernel_gain=(0.125 / self.embedding_size) ** 0.5, ) - encoded_entities = self.embedding_norm(encoded_entities) - return encoded_entities - @staticmethod - def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]: - """ - Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was - all zeros (on dimension 2) and 0 otherwise. This is used in the Attention - layer to mask the padding observations. - """ - with torch.no_grad(): - # Generate the masking tensors for each entities tensor (mask only if all zeros) - key_masks: List[torch.Tensor] = [ - (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations - ] - return key_masks + def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + if self.self_size > 0: + num_entities = self.entity_num_max_elements + if num_entities < 0: + if exporting_to_onnx.is_exporting(): + raise UnityTrainerException( + "Trying to export an attention mechanism that doesn't have a set max \ + number of elements." 
+ ) + num_entities = entities.shape[1] + expanded_self = x_self.reshape(-1, 1, self.self_size) + expanded_self = torch.cat([expanded_self] * num_entities, dim=1) + # Concatenate all observations with self + entities = torch.cat([expanded_self, entities], dim=2) + # Encode entities + encoded_entities = self.self_ent_encoder(entities) + return encoded_entities class ResidualSelfAttention(torch.nn.Module): """ Residual self attentioninspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used - with an EntityEmbeddings module, to apply multi head self attention to encode information + with an EntityEmbedding module, to apply multi head self attention to encode information about a "Self" and a list of relevant "Entities". """ @@ -189,7 +182,7 @@ class ResidualSelfAttention(torch.nn.Module): def __init__( self, embedding_size: int, - entity_num_max_elements: Optional[List[int]] = None, + entity_num_max_elements: Optional[int] = None, num_heads: int = 4, ): """ @@ -205,8 +198,7 @@ def __init__( super().__init__() self.max_num_ent: Optional[int] = None if entity_num_max_elements is not None: - _entity_num_max_elements = entity_num_max_elements - self.max_num_ent = sum(_entity_num_max_elements) + self.max_num_ent = entity_num_max_elements self.attention = MultiHeadAttention( num_heads=num_heads, embedding_size=embedding_size @@ -237,11 +229,14 @@ def __init__( kernel_init=Initialization.Normal, kernel_gain=(0.125 / embedding_size) ** 0.5, ) + self.embedding_norm = LayerNorm() self.residual_norm = LayerNorm() def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor: # Gather the maximum number of entities information mask = torch.cat(key_masks, dim=1) + + inp = self.embedding_norm(inp) # Feed to self attention query = self.fc_q(inp) # (b, n_q, emb) key = self.fc_k(inp) # (b, n_k, emb)
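
For reviewers trying out the new interface (EntityEmbedding handling a single entity type, get_zero_entities_mask as a free function, and ResidualSelfAttention taking a plain int for the maximum number of entities), here is a minimal usage sketch distilled from the updated tests above. The batch size, entity count, and feature sizes are illustrative values only, not anything this patch prescribes.

    from mlagents.torch_utils import torch
    from mlagents.trainers.torch.attention import (
        EntityEmbedding,
        ResidualSelfAttention,
        get_zero_entities_mask,
    )

    batch_size, self_size, entity_size, n_entities, emb_size = 8, 3, 3, 5, 64

    # One EntityEmbedding per entity type; add_self_embedding() rebuilds the
    # encoder so that x_self is concatenated onto every entity, replacing the
    # old concat_self=True constructor flag.
    entity_embedding = EntityEmbedding(entity_size, n_entities, emb_size)
    entity_embedding.add_self_embedding(self_size)
    attention = ResidualSelfAttention(emb_size, n_entities)

    x_self = torch.rand((batch_size, self_size))
    entities = torch.rand((batch_size, n_entities, entity_size))
    # Entities that are all zeros (padding) get masked out; this replaces the
    # old EntityEmbeddings.get_masks static method.
    masks = get_zero_entities_mask([entities])

    embeddings = entity_embedding(x_self, entities)  # (batch_size, n_entities, emb_size)
    pooled = attention(embeddings, masks)            # (batch_size, emb_size)

The new test_all_masking case above also exercises the degenerate masks (all zeros or all ones, built with torch.ones_like(key[:, :, 0]) * mask_value), so a fully padded entity set goes through this same code path without error.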