Commit
Use attention tests from master
Ervin Teng committed Feb 11, 2021
1 parent 492fd17 commit 78e052b
Showing 2 changed files with 137 additions and 95 deletions.
71 changes: 59 additions & 12 deletions ml-agents/mlagents/trainers/tests/torch/test_attention.py
@@ -1,12 +1,14 @@
import pytest
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.layers import linear_layer, LinearEncoder
from mlagents.trainers.torch.attention import (
MultiHeadAttention,
EntityEmbeddings,
EntityEmbedding,
ResidualSelfAttention,
get_zero_entities_mask,
)


@@ -71,7 +73,7 @@ def generate_input_helper(pattern):
input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)

masks = EntityEmbeddings.get_masks([input_1, input_2])
masks = get_zero_entities_mask([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]
@@ -83,13 +85,60 @@ def generate_input_helper(pattern)
assert masks_2[0, 1] == 0 if i % 2 == 0 else 1


@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
# We make sure that a mask of all zeros or all ones will not trigger an error
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,
weight_decay=1e-6,
)
batch_size = 20
for _ in range(5):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
# create the target: the key closest to the query in Euclidean distance
distance = torch.sum(
(center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()

embeddings = entity_embeddings(center, key)
masks = [torch.ones_like(key[:, :, 0]) * mask_value]
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
optimizer.zero_grad()
error.backward()
optimizer.step()


def test_predict_closest_training():
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
transformer = ResidualSelfAttention(embedding_size, [n_k])
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())
@@ -114,8 +163,8 @@ def test_predict_closest_training():
target = torch.stack(target, dim=0)
target = target.detach()

embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
embeddings = entity_embeddings(center, key)
masks = get_zero_entities_mask([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
@@ -135,14 +184,12 @@ def test_predict_minimum_training():
n_k = 5
size = n_k + 1
embedding_size = 64
entity_embeddings = EntityEmbeddings(
size, [size], embedding_size, [n_k], concat_self=False
)
entity_embedding = EntityEmbedding(size, n_k, embedding_size) # no self
transformer = ResidualSelfAttention(embedding_size)
l_layer = LinearEncoder(embedding_size, 2, n_k)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())
list(entity_embedding.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,
@@ -166,8 +213,8 @@ def test_predict_minimum_training():
sliced_oh = onehots[:, : num + 1]
inp = torch.cat([inp, sliced_oh], dim=2)

embeddings = entity_embeddings(inp, [inp])
masks = EntityEmbeddings.get_masks([inp])
embeddings = entity_embedding(inp, inp)
masks = get_zero_entities_mask([inp])
prediction = transformer(embeddings, masks)
prediction = l_layer(prediction)
ce = loss(prediction, argmin)
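For orientation, a minimal sketch of the updated API these tests exercise: EntityEmbedding replaces EntityEmbeddings, masking moves to the module-level get_zero_entities_mask, and the self observation is opted into via add_self_embedding. The shapes and hyperparameters below mirror the tests and are illustrative only, not part of this commit.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)

size, n_k, embedding_size, batch_size = 3, 5, 64, 20
entity_embedding = EntityEmbedding(size, n_k, embedding_size)
entity_embedding.add_self_embedding(size)         # ego-centric: concatenate self to each entity
attention = ResidualSelfAttention(embedding_size, n_k)
head = linear_layer(embedding_size, size)

x_self = torch.rand((batch_size, size))           # "self" observation
entities = torch.rand((batch_size, n_k, size))    # variable-number-of-entities observation
masks = get_zero_entities_mask([entities])        # 1 where an entity is all-zero padding
embeddings = entity_embedding(x_self, entities)   # (batch_size, n_k, embedding_size)
pooled = attention(embeddings, masks)             # pooled encoding fed to the prediction head
prediction = head(pooled).reshape((batch_size, size))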
161 changes: 78 additions & 83 deletions ml-agents/mlagents/trainers/torch/attention.py
@@ -10,22 +10,41 @@
from mlagents.trainers.exception import UnityTrainerException


class MultiHeadAttention(torch.nn.Module):
def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
]
return key_masks
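# Illustrative example (not part of this commit): for a batch of 2 observations
# with 3 entities of size 4, where the last entity of the first sample is zero
# padding,
#
#     obs = torch.zeros((2, 3, 4))
#     obs[0, :2, :] = 1.0
#     obs[1, :, :] = 1.0
#     masks = get_zero_entities_mask([obs])
#
# masks[0] has shape (2, 3) and equals [[0., 0., 1.], [0., 0., 0.]]:
# a 1 marks a padded entity for the attention layer to ignore.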


class MultiHeadAttention(torch.nn.Module):

NEG_INF = -1e6

def __init__(self, embedding_size: int, num_heads: int):
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
:param embedding_size: The size of the embeddings that will be generated (should be
divisible by num_heads)
:param num_heads: The number of heads of the attention module
"""
super().__init__()
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads
@@ -82,7 +101,7 @@ def forward(
return value_attention, att


class EntityEmbeddings(torch.nn.Module):
class EntityEmbedding(torch.nn.Module):
"""
A module used to embed entities before passing them to a self-attention block.
Used in conjunction with ResidualSelfAttention to encode information about a self
@@ -92,95 +111,69 @@ class EntityEmbeddings(torch.nn.Module):

def __init__(
self,
x_self_size: int,
entity_sizes: List[int],
entity_size: int,
entity_num_max_elements: Optional[int],
embedding_size: int,
entity_num_max_elements: Optional[List[int]] = None,
concat_self: bool = True,
):
"""
Constructs an EntityEmbeddings module.
Constructs an EntityEmbedding module.
:param x_self_size: Size of "self" entity.
:param entity_sizes: List of sizes for other entities. Should be of length
equivalent to the number of entities.
:param embedding_size: Embedding size for entity encoders.
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
:param entity_size: Size of other entities.
:param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
:param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
:param embedding_size: Embedding size for the entity encoder.
:param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
self-attention.
"""
super().__init__()
self.self_size: int = x_self_size
self.entity_sizes: List[int] = entity_sizes
self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
self.self_size: int = 0
self.entity_size: int = entity_size
self.entity_num_max_elements: int = -1
if entity_num_max_elements is not None:
self.entity_num_max_elements = entity_num_max_elements

self.concat_self: bool = concat_self
# If not concatenating self, input to encoder is just entity size
if not concat_self:
self.self_size = 0
self.embedding_size = embedding_size
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.ent_encoders = torch.nn.ModuleList(
[
LinearEncoder(
self.self_size + ent_size,
1,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
for ent_size in self.entity_sizes
]
self.self_ent_encoder = LinearEncoder(
self.entity_size,
1,
self.embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.embedding_size) ** 0.5,
)
self.embedding_norm = LayerNorm()

def forward(
self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> Tuple[torch.Tensor, int]:
if self.concat_self:
# Concatenate all observations with self
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entity_num_max_elements, entities):
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = ent.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
else:
self_and_ent = entities
# Encode and concatenate entites
encoded_entities = torch.cat(
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
dim=1,
def add_self_embedding(self, size: int) -> None:
self.self_size = size
self.self_ent_encoder = LinearEncoder(
self.self_size + self.entity_size,
1,
self.embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.embedding_size) ** 0.5,
)
encoded_entities = self.embedding_norm(encoded_entities)
return encoded_entities

@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
]
return key_masks
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
if self.self_size > 0:
num_entities = self.entity_num_max_elements
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = entities.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# Concatenate all observations with self
entities = torch.cat([expanded_self, entities], dim=2)
# Encode entities
encoded_entities = self.self_ent_encoder(entities)
return encoded_entities


class ResidualSelfAttention(torch.nn.Module):
"""
Residual self attention, inspired by https://arxiv.org/pdf/1909.07528.pdf. Can be used
with an EntityEmbeddings module, to apply multi head self attention to encode information
with an EntityEmbedding module, to apply multi head self attention to encode information
about a "Self" and a list of relevant "Entities".
"""

@@ -189,7 +182,7 @@ class ResidualSelfAttention(torch.nn.Module):
def __init__(
self,
embedding_size: int,
entity_num_max_elements: Optional[List[int]] = None,
entity_num_max_elements: Optional[int] = None,
num_heads: int = 4,
):
"""
@@ -205,8 +198,7 @@ def __init__(
super().__init__()
self.max_num_ent: Optional[int] = None
if entity_num_max_elements is not None:
_entity_num_max_elements = entity_num_max_elements
self.max_num_ent = sum(_entity_num_max_elements)
self.max_num_ent = entity_num_max_elements

self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size
@@ -237,11 +229,14 @@ def __init__(
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.embedding_norm = LayerNorm()
self.residual_norm = LayerNorm()

def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
# Gather the maximum number of entities information
mask = torch.cat(key_masks, dim=1)

inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)
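For the variant without a self observation (as in test_predict_minimum_training), a hedged sketch of how zero-padded entities flow through the new modules: get_zero_entities_mask flags the padded slots and ResidualSelfAttention uses that mask to keep attention off the padding. Sizes are arbitrary and the snippet is illustrative, not part of this commit.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)

batch_size, n_entities, entity_size, emb = 4, 6, 8, 64
entity_embedding = EntityEmbedding(entity_size, n_entities, emb)  # add_self_embedding not called
attention = ResidualSelfAttention(emb, n_entities)

entities = torch.rand((batch_size, n_entities, entity_size))
entities[:, -2:, :] = 0.0                          # last two entities are zero padding
masks = get_zero_entities_mask([entities])         # masks[0]: (batch_size, n_entities), 1 at padded slots
embeddings = entity_embedding(entities, entities)  # first argument is unused when no self embedding is set
encoded = attention(embeddings, masks)             # attention to the padded entities is masked out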
