
Commit
fix bugs.
xiemoyuan committed Feb 18, 2021
1 parent 222c07a commit 0724c6b
Showing 1 changed file with 72 additions and 23 deletions.
95 changes: 72 additions & 23 deletions python/paddle/nn/layer/transformer.py
@@ -83,6 +83,35 @@ def _convert_param_attr_to_list(param_attr, n):
    return param_attrs


def _convert_attention_mask(attn_mask, dtype):
"""
Convert the attention mask to the target dtype we expect.
Parameters:
attn_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
dtype (VarType): The target type of `attn_mask` we expect.
Returns:
Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
"""
    if attn_mask is not None and attn_mask.dtype != dtype:
        attn_mask_dtype = convert_dtype(attn_mask.dtype)
        if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
            # 1/True marks positions to keep: map to an additive mask of
            # 0 (keep) / -1e9 (drop) in the target dtype.
            attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
        else:
            # Float masks are already additive; only align the dtype.
            attn_mask = paddle.cast(attn_mask, dtype)
    return attn_mask
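
For illustration, a minimal sketch of what the helper produces for a bool mask, assuming a Paddle 2.x environment and that `_convert_attention_mask` is in scope as defined above (the values in the comments are the expected results, not captured output):

import paddle

# Bool mask over two positions: position 1 may not attend to position 0.
mask = paddle.to_tensor([[True, True], [False, True]])
additive = _convert_attention_mask(mask, paddle.float32)
# additive == [[0.0, 0.0], [-1e9, 0.0]]; adding it to the attention logits
# drives the masked position's softmax weight to (near) zero.

Note the helper emits -1e9 rather than a literal `-INF`, which keeps the subsequent softmax numerically safe.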


class MultiHeadAttention(Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
@@ -334,8 +363,8 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
-               values and the others have 'True' values. When the data type is
-               int, the unwanted positions have 0 values and the others have 0
+               values and the others have `True` values. When the data type is
+               int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -378,11 +407,7 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
        if attn_mask is not None:
            # Support bool or int mask
-           attn_mask_dtype = convert_dtype(attn_mask.dtype)
-           if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
-               attn_mask = (paddle.cast(attn_mask, product.dtype) - 1.0) * 1e9
-           else:
-               attn_mask = paddle.cast(attn_mask, product.dtype)
+           attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
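
This is the only point where the mask enters the math: it is added to the scaled dot-product logits before the softmax. A self-contained sketch of that step under assumed shapes (the names `q`, `k`, and `causal` are illustrative, not the layer's internals):

import paddle
import paddle.nn.functional as F

batch, n_head, seq, head_dim = 2, 2, 4, 8
q = paddle.randn([batch, n_head, seq, head_dim])
k = paddle.randn([batch, n_head, seq, head_dim])

# Scaled dot-product logits, mirroring the layer above.
product = paddle.matmul(q, k, transpose_y=True) * head_dim**-0.5

# Causal bool mask, broadcastable to [batch, n_head, seq, seq].
causal = paddle.tril(paddle.ones([seq, seq])).astype('bool')
product = product + _convert_attention_mask(causal, product.dtype)

weights = F.softmax(product)  # masked logits sit at -1e9, so ~0 weight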
@@ -519,8 +544,8 @@ def forward(self, src, src_mask=None, cache=None):
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
-               values and the others have 'True' values. When the data type is
-               int, the unwanted positions have 0 values and the others have 0
+               values and the others have `True` values. When the data type is
+               int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -538,6 +563,8 @@ def forward(self, src, src_mask=None, cache=None):
                incremental length. See `MultiHeadAttention.gen_cache` and \
                `MultiHeadAttention.forward` for more details.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)

residual = src
if self.normalize_before:
src = self.norm1(src)
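
With the conversion now done inside `forward`, a bool or int mask can be handed straight to the layer. A minimal usage sketch, assuming this fixed version of Paddle (all sizes are arbitrary):

import paddle
from paddle.nn import TransformerEncoderLayer

enc_layer = TransformerEncoderLayer(d_model=128, nhead=2, dim_feedforward=512)
src = paddle.rand([2, 4, 128])                      # [batch_size, seq_len, d_model]
src_mask = paddle.ones([2, 2, 4, 4], dtype='bool')  # True = position may be attended
out = enc_layer(src, src_mask=src_mask)             # same shape as src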
@@ -634,8 +661,8 @@ def forward(self, src, src_mask=None, cache=None):
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
-               values and the others have 'True' values. When the data type is
-               int, the unwanted positions have 0 values and the others have 0
+               values and the others have `True` values. When the data type is
+               int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -653,6 +680,8 @@ def forward(self, src, src_mask=None, cache=None):
                See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                for more details.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)

output = src
new_caches = []
for i, mod in enumerate(self.layers):
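
Converting once at the stack level means each layer in the loop above receives a mask already in the compute dtype, so the per-layer conversion reduces to a no-op passthrough. A sketch of the stack-level call, this time with an int mask (sizes are illustrative):

import paddle
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

encoder = TransformerEncoder(
    TransformerEncoderLayer(d_model=128, nhead=2, dim_feedforward=512),
    num_layers=3)
src = paddle.rand([2, 4, 128])
src_mask = paddle.ones([2, 2, 4, 4], dtype='int32')  # 1 = attend, 0 = mask out
out = encoder(src, src_mask=src_mask)  # mask converted once, reused by all layers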
@@ -822,8 +851,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                the subsequent positions. It is a tensor with shape broadcasted
                to `[batch_size, n_head, target_length, target_length]`.
                When the data type is bool, the unwanted positions have `False`
-               values and the others have 'True' values. When the data type is
-               int, the unwanted positions have 0 values and the others have 0
+               values and the others have `True` values. When the data type is
+               int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -832,8 +861,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                usually the paddings. It is a tensor with shape broadcasted to
                `[batch_size, n_head, target_length, source_length]`. When the
                data type is bool, the unwanted positions have `False` values
-               and the others have 'True' values. When the data type is int,
-               the unwanted positions have 0 values and the others have 0
+               and the others have `True` values. When the data type is int,
+               the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -853,6 +882,9 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                for more details.
"""
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
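
Since both masks are normalized up front and `None` passes through the helper untouched, the two masks may use different dtypes or be omitted independently. A hedged sketch of a decoder layer with a causal bool `tgt_mask` and no `memory_mask` (sizes are arbitrary):

import paddle
from paddle.nn import TransformerDecoderLayer

dec_layer = TransformerDecoderLayer(d_model=128, nhead=2, dim_feedforward=512)
tgt = paddle.rand([2, 4, 128])     # [batch_size, target_length, d_model]
memory = paddle.rand([2, 6, 128])  # [batch_size, source_length, d_model]

# Causal mask: True on and below the diagonal, broadcast over batch and heads.
tgt_mask = paddle.tril(paddle.ones([4, 4])).astype('bool')
out = dec_layer(tgt, memory, tgt_mask=tgt_mask)  # memory_mask=None is fine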
@@ -977,8 +1009,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                the subsequent positions. It is a tensor with shape broadcasted
                to `[batch_size, n_head, target_length, target_length]`. When
                the data type is bool, the unwanted positions have `False`
-               values and the others have 'True' values. When the data type is
-               int, the unwanted positions have 0 values and the others have 0
+               values and the others have `True` values. When the data type is
+               int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -987,8 +1019,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                usually the paddings. It is a tensor with shape broadcasted to
                `[batch_size, n_head, target_length, source_length]`. When the
                data type is bool, the unwanted positions have `False` values
-               and the others have 'True' values. When the data type is int,
-               the unwanted positions have 0 values and the others have 0
+               and the others have `True` values. When the data type is int,
+               the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -1006,6 +1038,9 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
                See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                for more details.
"""
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

output = tgt
new_caches = []
for i, mod in enumerate(self.layers):
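
The decoder stack mirrors the encoder stack: both masks are converted once here and then reused by every layer in the loop above. A minimal sketch under the same assumptions as the layer example:

import paddle
from paddle.nn import TransformerDecoder, TransformerDecoderLayer

decoder = TransformerDecoder(
    TransformerDecoderLayer(d_model=128, nhead=2, dim_feedforward=512),
    num_layers=2)
tgt = paddle.rand([2, 4, 128])
memory = paddle.rand([2, 6, 128])
tgt_mask = paddle.tril(paddle.ones([4, 4])).astype('bool')
out = decoder(tgt, memory, tgt_mask=tgt_mask)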
@@ -1244,13 +1279,23 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
            memory (Tensor): The output of the Transformer encoder. It is a tensor
                with shape `[batch_size, source_length, d_model]`. The data type
                should be float32 or float64.
            src_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
            tgt_mask (Tensor, optional): A tensor used in self attention
                to prevent attention to some unwanted positions, usually
                the subsequent positions. It is a tensor with shape broadcasted
                to `[batch_size, n_head, target_length, target_length]`. When
                the data type is bool, the unwanted positions have `False`
-               values and the others have 'True' values. When the data type is
-               int, the unwanted positions have 0 values and the others have 0
+               values and the others have `True` values. When the data type is
+               int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -1259,8 +1304,8 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
                usually the paddings. It is a tensor with shape broadcasted to
                `[batch_size, n_head, target_length, source_length]`. When the
                data type is bool, the unwanted positions have `False` values
-               and the others have 'True' values. When the data type is int,
-               the unwanted positions have 0 values and the others have 0
+               and the others have `True` values. When the data type is int,
+               the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.
@@ -1269,7 +1314,11 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
            Tensor: It is a tensor that has the same shape and data type \
                as `tgt`, representing the output of the Transformer decoder.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)
memory = self.encoder(src, src_mask=src_mask)

tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
output = self.decoder(
tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
return output
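
End to end, each of the three masks may now use whichever dtype is convenient, since every entry point normalizes them. A sketch with a bool padding mask on the encoder side and an already-additive float causal mask on the decoder side (all sizes are illustrative):

import paddle
from paddle.nn import Transformer

model = Transformer(d_model=128, nhead=2, num_encoder_layers=2,
                    num_decoder_layers=2, dim_feedforward=512)
src = paddle.rand([2, 6, 128])
tgt = paddle.rand([2, 4, 128])

src_mask = paddle.ones([2, 2, 6, 6], dtype='bool')         # bool: True = attend
tgt_mask = (paddle.tril(paddle.ones([4, 4])) - 1.0) * 1e9  # float: 0 / -1e9
out = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)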
