Use singular entity embedding #4873

Merged: 1 commit, Jan 20, 2021
26 changes: 10 additions & 16 deletions ml-agents/mlagents/trainers/tests/torch/test_attention.py
@@ -4,7 +4,7 @@
from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
MultiHeadAttention,
EntityEmbeddings,
EntityEmbedding,
ResidualSelfAttention,
)

@@ -70,7 +70,7 @@ def generate_input_helper(pattern):
input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)

masks = EntityEmbeddings.get_masks([input_1, input_2])
masks = ResidualSelfAttention.get_masks([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]
@@ -87,18 +87,16 @@ def test_simple_transformer_training():
torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
transformer = ResidualSelfAttention(embedding_size, [n_k])
entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
)
batch_size = 200
point_range = 3
init_error = -1.0
for _ in range(250):
center = torch.rand((batch_size, size)) * point_range * 2 - point_range
key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
for _ in range(200):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
# create the target : The key closest to the query in euclidean distance
distance = torch.sum(
@@ -111,19 +109,15 @@
target = torch.stack(target, dim=0)
target = target.detach()

embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
embeddings = entity_embeddings(center, key)
masks = ResidualSelfAttention.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
if init_error == -1.0:
init_error = error.item()
else:
assert error.item() < init_error
print(error.item())
optimizer.zero_grad()
error.backward()
optimizer.step()
assert error.item() < 0.3
assert error.item() < 0.02
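
For readers migrating call sites, a minimal before/after sketch of the API change this test exercises (a sketch assuming the mlagents.trainers package from this branch; the sizes are illustrative, not taken from the test):

# Sketch of the singular EntityEmbedding API introduced by this PR.
import torch
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

batch_size, n_entities, entity_size, embedding_size = 8, 5, 3, 64
x_self = torch.rand((batch_size, entity_size))
entities = torch.rand((batch_size, n_entities, entity_size))

# Old (plural): EntityEmbeddings(entity_size, [entity_size], embedding_size, [n_entities])
#               masks = EntityEmbeddings.get_masks([entities])
#               embeddings = entity_embeddings(x_self, [entities])
# New (singular): one module per entity type, plain tensors instead of lists.
entity_embedding = EntityEmbedding(entity_size, entity_size, n_entities, embedding_size)
attention = ResidualSelfAttention(embedding_size, n_entities)

masks = ResidualSelfAttention.get_masks([entities])  # get_masks now lives on ResidualSelfAttention
embeddings = entity_embedding(x_self, entities)      # (batch_size, n_entities, embedding_size)
pooled = attention(embeddings, masks)                # (batch_size, embedding_size)
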
156 changes: 75 additions & 81 deletions ml-agents/mlagents/trainers/torch/attention.py
@@ -1,26 +1,36 @@
from mlagents.torch_utils import torch
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer
from mlagents.trainers.torch.layers import (
LinearEncoder,
Initialization,
linear_layer,
LayerNorm,
)
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.exception import UnityTrainerException


class MultiHeadAttention(torch.nn.Module):
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
"""

NEG_INF = -1e6

def __init__(self, embedding_size: int, num_heads: int):
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
:param embedding_size: The size of the embeddings that will be generated (should be
divisible by the num_heads)
:param total_max_elements: The maximum total number of entities that can be passed to
the module
:param num_heads: The number of heads of the attention module
"""
super().__init__()
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads
@@ -77,98 +87,63 @@ def forward(
return value_attention, att
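
For orientation, a small sketch of the documented shape contract. The positional arguments after value (number of queries, number of keys, then the key mask) are inferred from the ResidualSelfAttention.forward call later in this diff, so treat the exact signature as an assumption rather than a guarantee:

import torch
from mlagents.trainers.torch.attention import MultiHeadAttention

batch, n_q, n_k, emb, heads = 8, 2, 5, 64, 4
mha = MultiHeadAttention(embedding_size=emb, num_heads=heads)

query = torch.rand((batch, n_q, emb))
key = torch.rand((batch, n_k, emb))
value = torch.rand((batch, n_k, emb))
key_mask = torch.zeros((batch, n_k))  # 1.0 would mark a padded key (see get_masks below)

output, attention_matrix = mha(query, key, value, n_q, n_k, key_mask)
assert output.shape == (batch, n_q, emb)                   # (batch, number_of_queries, embedding_size)
assert attention_matrix.shape == (batch, heads, n_q, n_k)  # (batch, num_heads, n_q, n_k)
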


class EntityEmbeddings(torch.nn.Module):
"""
A module used to embed entities before passing them to a self-attention block.
Used in conjunction with ResidualSelfAttention to encode information about a self
and additional entities. Can also concatenate self to entities for ego-centric self-
attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
"""

class EntityEmbedding(torch.nn.Module):
def __init__(
self,
x_self_size: int,
entity_sizes: List[int],
entity_size: int,
entity_num_max_elements: Optional[int],
embedding_size: int,
entity_num_max_elements: Optional[List[int]] = None,
concat_self: bool = True,
):
"""
Constructs an EntityEmbeddings module.
:param x_self_size: Size of "self" entity.
:param entity_sizes: List of sizes for other entities. Should be of length
equivalent to the number of entities.
:param embedding_size: Embedding size for entity encoders.
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
:param entity_size: Size of each entity other than self.
:param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
:param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
:param embedding_size: Embedding size for the entity encoder.
:param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
self-attention.
"""
super().__init__()
self.self_size: int = x_self_size
self.entity_sizes: List[int] = entity_sizes
self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
self.entity_size: int = entity_size
self.entity_num_max_elements: int = -1
if entity_num_max_elements is not None:
self.entity_num_max_elements = entity_num_max_elements

self.concat_self: bool = concat_self
# If not concatenating self, input to encoder is just entity size
if not concat_self:
self.self_size = 0
self.ent_encoders = torch.nn.ModuleList(
[
LinearEncoder(
self.self_size + ent_size,
1,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
for ent_size in self.entity_sizes
]
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.ent_encoder = LinearEncoder(
self.self_size + self.entity_size,
1,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)

def forward(
self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> Tuple[torch.Tensor, int]:
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
if self.concat_self:
num_entities = self.entity_num_max_elements
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = entities.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# Concatenate all observations with self
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entity_num_max_elements, entities):
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = ent.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
else:
self_and_ent = entities
# Encode and concatenate entites
encoded_entities = torch.cat(
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
dim=1,
)
entities = torch.cat([expanded_self, entities], dim=2)
# Encode entities
encoded_entities = self.ent_encoder(entities)
return encoded_entities
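
To make the singular, ego-centric path concrete, a small shape sketch (my own illustration, assuming the constructor arguments shown above):

import torch
from mlagents.trainers.torch.attention import EntityEmbedding

batch, n_entities, x_self_size, entity_size, emb = 4, 3, 6, 4, 32
embed = EntityEmbedding(x_self_size, entity_size, n_entities, emb)  # concat_self=True by default

x_self = torch.rand((batch, x_self_size))
entities = torch.rand((batch, n_entities, entity_size))

# x_self is tiled across the entity axis, concatenated with each entity,
# and pushed through a single LinearEncoder shared by all entities.
encoded = embed(x_self, entities)
assert encoded.shape == (batch, n_entities, emb)
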

@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
for ent in observations
]
return key_masks


class ResidualSelfAttention(torch.nn.Module):
"""
@@ -182,7 +157,7 @@ class ResidualSelfAttention(torch.nn.Module):
def __init__(
self,
embedding_size: int,
entity_num_max_elements: Optional[List[int]] = None,
entity_num_max_elements: Optional[int] = None,
num_heads: int = 4,
):
"""
@@ -198,13 +173,13 @@ def __init__(
super().__init__()
self.max_num_ent: Optional[int] = None
if entity_num_max_elements is not None:
_entity_num_max_elements = entity_num_max_elements
self.max_num_ent = sum(_entity_num_max_elements)
self.max_num_ent = entity_num_max_elements

self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size
)

# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.fc_q = linear_layer(
embedding_size,
embedding_size,
@@ -229,10 +204,14 @@ def __init__(
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.embedding_norm = LayerNorm()
self.residual_norm = LayerNorm()

def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
# Gather the maximum number of entities information
mask = torch.cat(key_masks, dim=1)

inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)
@@ -252,9 +231,24 @@ def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
# Residual
output = self.fc_out(output) + inp
output = self.residual_norm(output)
# Average Pooling
numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
output = numerator / denominator
# Residual between x_self and the output of the module
return output

@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensor for each entity tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
for ent in observations
]
return key_masks
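
To illustrate the mask path end to end (a sketch of my own, not part of the PR): entity rows that are all zeros are flagged by get_masks, excluded from attention, and ignored by the masked average pooling at the end of ResidualSelfAttention.forward.

import torch
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

batch, max_entities, entity_size, emb = 2, 4, 3, 64
entities = torch.rand((batch, max_entities, entity_size))
entities[:, 2:, :] = 0.0  # last two entity slots are zero padding

masks = ResidualSelfAttention.get_masks([entities])
assert masks[0].shape == (batch, max_entities)
assert (masks[0][:, 2:] == 1.0).all()  # padded slots get a mask value of 1

x_self = torch.rand((batch, entity_size))
embed = EntityEmbedding(entity_size, entity_size, max_entities, emb)
attn = ResidualSelfAttention(emb, max_entities)

# embedding LayerNorm -> attention -> residual + LayerNorm -> masked mean over entities
pooled = attn(embed(x_self, entities), masks)
assert pooled.shape == (batch, emb)
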
14 changes: 14 additions & 0 deletions ml-agents/mlagents/trainers/torch/layers.py
@@ -115,6 +115,20 @@ def forward(
pass


class LayerNorm(torch.nn.Module):
"""
A vanilla implementation of layer normalization https://arxiv.org/pdf/1607.06450.pdf
norm_x = (x - mean) / sqrt(mean((x - mean) ^ 2) + epsilon)
This does not include the trainable parameters gamma and beta for speed.
Typically, the full version is norm_x * gamma + beta.
"""

def forward(self, layer_activations: torch.Tensor) -> torch.Tensor:
mean = torch.mean(layer_activations, dim=-1, keepdim=True)
var = torch.mean((layer_activations - mean) ** 2, dim=-1, keepdim=True)
return (layer_activations - mean) / (torch.sqrt(var + 1e-5))
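
As a quick sanity check (my own sketch), this parameter-free LayerNorm should match torch.nn.functional.layer_norm with no affine parameters and the same epsilon:

import torch
from mlagents.trainers.torch.layers import LayerNorm

x = torch.rand((4, 7, 32))
ours = LayerNorm()(x)
reference = torch.nn.functional.layer_norm(x, normalized_shape=(32,), eps=1e-5)
assert torch.allclose(ours, reference, atol=1e-5)
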


class LinearEncoder(torch.nn.Module):
"""
Linear layers.
26 changes: 20 additions & 6 deletions ml-agents/mlagents/trainers/torch/networks.py
@@ -14,7 +14,7 @@
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbeddings
from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbedding


ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
@@ -139,9 +139,13 @@ def __init__(
+ sum(self.action_spec.discrete_branches)
+ self.action_spec.continuous_size
)
self.entity_encoder = EntityEmbeddings(
0, [obs_only_ent_size, q_ent_size], self.h_size, concat_self=False
self.obs_encoder = EntityEmbedding(
0, obs_only_ent_size, None, self.h_size, concat_self=False
)
self.obs_action_encoder = EntityEmbedding(
0, q_ent_size, None, self.h_size, concat_self=False
)

self.self_attn = ResidualSelfAttention(self.h_size)

encoder_input_size = self.h_size
@@ -223,7 +227,9 @@ def forward(
value_masks = self._get_masks_from_nans(value_inputs)
q_masks = self._get_masks_from_nans(q_inputs)

encoded_entity = self.entity_encoder(None, [value_input_concat, q_input_concat])
encoded_obs = self.obs_encoder(None, value_input_concat)
encoded_obs_action = self.obs_action_encoder(None, q_input_concat)
encoded_entity = torch.cat([encoded_obs, encoded_obs_action], dim=1)
encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks])

if len(concat_encoded_obs) == 0:
@@ -643,7 +649,11 @@ def critic_pass(
all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length
)
value_outputs, critic_mem_out = self.critic(
critic_obs, [inputs], [actions], memories=critic_mem, sequence_length=sequence_length
critic_obs,
[inputs],
[actions],
memories=critic_mem,
sequence_length=sequence_length,
)
if mar_value_outputs is None:
mar_value_outputs = value_outputs
Expand Down Expand Up @@ -678,7 +688,11 @@ def get_stats_and_value(
)

value_outputs, critic_mem_outs = self.critic(
[inputs], critic_obs, actions, memories=critic_mem, sequence_length=sequence_length
[inputs],
critic_obs,
actions,
memories=critic_mem,
sequence_length=sequence_length,
)

return log_probs, entropies, value_outputs, mar_value_outputs
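
Schematically, the critic now builds one EntityEmbedding per input group and concatenates their outputs along the entity axis before the shared ResidualSelfAttention, instead of a single plural EntityEmbeddings over a list. A standalone sketch with hypothetical sizes (the real obs_only_ent_size, q_ent_size, and h_size come from the network configuration):

import torch
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

batch, h_size = 4, 128
obs_only_ent_size, q_ent_size = 10, 14  # hypothetical entity sizes
n_obs_only, n_q = 3, 2                  # hypothetical group sizes

obs_encoder = EntityEmbedding(0, obs_only_ent_size, None, h_size, concat_self=False)
obs_action_encoder = EntityEmbedding(0, q_ent_size, None, h_size, concat_self=False)
self_attn = ResidualSelfAttention(h_size)

value_input_concat = torch.rand((batch, n_obs_only, obs_only_ent_size))
q_input_concat = torch.rand((batch, n_q, q_ent_size))
value_masks = ResidualSelfAttention.get_masks([value_input_concat])[0]
q_masks = ResidualSelfAttention.get_masks([q_input_concat])[0]

encoded_obs = obs_encoder(None, value_input_concat)              # (batch, n_obs_only, h_size)
encoded_obs_action = obs_action_encoder(None, q_input_concat)    # (batch, n_q, h_size)
encoded_entity = torch.cat([encoded_obs, encoded_obs_action], dim=1)
encoded_state = self_attn(encoded_entity, [value_masks, q_masks])
assert encoded_state.shape == (batch, h_size)
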