From 24d26dea618f507e26d5e065427674930d679f00 Mon Sep 17 00:00:00 2001 From: xiemoyuan Date: Mon, 8 Feb 2021 16:54:28 +0800 Subject: [PATCH 1/4] Support 'bool' and 'int' for attention mask. --- python/paddle/nn/layer/transformer.py | 117 ++++++++++++++++---------- 1 file changed, 72 insertions(+), 45 deletions(-) diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index 75f998b037e30..0996888d0f033 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -34,6 +34,7 @@ from ...fluid import layers from ...fluid.dygraph import Layer, LayerList from ...fluid.param_attr import ParamAttr +from ...fluid.data_feeder import convert_dtype def _convert_param_attr_to_list(param_attr, n): @@ -331,11 +332,13 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape - broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`, - where the unwanted positions have `-INF` values and the others - have 0 values. The data type should be float32 or float64. It can - be None when nothing wanted or needed to be prevented attention to. - Default None + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have 'True' values. When the data type is + int, the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): It is a namedtuple with `k` and `v` as fields, and stores tensors shaped `[batch_size, num_heads, length, embed_dim]` which are results @@ -374,7 +377,12 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): product = layers.matmul( x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) if attn_mask is not None: - # TODO(guosheng): support bool mask + # Support bool or int mask + attn_mask_dtype = convert_dtype(attn_mask.dtype) + if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype: + attn_mask = (paddle.cast(attn_mask, product.dtype) - 1.0) * 1e9 + else: + attn_mask = paddle.cast(attn_mask, product.dtype) product = product + attn_mask weights = F.softmax(product) if self.dropout: @@ -509,11 +517,13 @@ def forward(self, src, src_mask=None, cache=None): src_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape - broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`, - where the unwanted positions have `-INF` values and the others - have 0 values. The data type should be float32 or float64. It can - be None when nothing wanted or needed to be prevented attention to. - Default None + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have 'True' values. When the data type is + int, the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`. See `TransformerEncoderLayer.gen_cache` for more details. It is only used for inference and should be None for training. Default @@ -531,7 +541,7 @@ def forward(self, src, src_mask=None, cache=None): residual = src if self.normalize_before: src = self.norm1(src) - # TODO(guosheng): Add cache for encoder for the usage like UniLM + # Add cache for encoder for the usage like UniLM if cache is None: src = self.self_attn(src, src, src, src_mask) else: @@ -622,11 +632,13 @@ def forward(self, src, src_mask=None, cache=None): src_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape - broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`, - where the unwanted positions have `-INF` values and the others - have 0 values. The data type should be float32 or float64. It can - be None when nothing wanted or needed to be prevented attention to. - Default None + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have 'True' values. When the data type is + int, the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. cache (list, optional): It is a list, and each element in the list is `incremental_cache` produced by `TransformerEncoderLayer.gen_cache`. See `TransformerEncoder.gen_cache` for more details. It is only @@ -808,18 +820,23 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): tgt_mask (Tensor, optional): A tensor used in self attention to prevents attention to some unwanted positions, usually the the subsequent positions. It is a tensor with shape broadcasted - to `[batch_size, n_head, target_length, target_length]`, - where the unwanted positions have `-INF` values and the others - have 0 values. The data type should be float32 or float64. It can - be None when nothing wanted or needed to be prevented attention to. - Default None + to `[batch_size, n_head, target_length, target_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have 'True' values. When the data type is + int, the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. memory_mask (Tensor, optional): A tensor used in decoder-encoder cross attention to prevents attention to some unwanted positions, - usually the paddings. It is a tensor with shape broadcasted to - `[batch_size, n_head, target_length, source_length]`, where the - unwanted positions have `-INF` values and the others have 0 values. - The data type should be float32 or float64. It can be None when - nothing wanted or needed to be prevented attention to. Default None + usually the paddings. It is a tensor with shape broadcasted to + `[batch_size, n_head, target_length, source_length]`. When the + data type is bool, the unwanted positions have `False` values + and the others have 'True' values. When the data type is int, + the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. cache (tuple, optional): It is a tuple( :code:`(incremental_cache, static_cache)` ), `incremental_cache` is an instance of `MultiHeadAttention.Cache`, `static_cache` is an instance of `MultiHeadAttention.StaticCache. @@ -958,18 +975,23 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): tgt_mask (Tensor, optional): A tensor used in self attention to prevents attention to some unwanted positions, usually the the subsequent positions. It is a tensor with shape broadcasted - to `[batch_size, n_head, target_length, target_length]`, - where the unwanted positions have `-INF` values and the others - have 0 values. The data type should be float32 or float64. It can - be None when nothing wanted or needed to be prevented attention to. - Default None + to `[batch_size, n_head, target_length, target_length]`. When + the data type is bool, the unwanted positions have `False` + values and the others have 'True' values. When the data type is + int, the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. memory_mask (Tensor, optional): A tensor used in decoder-encoder cross attention to prevents attention to some unwanted positions, usually the paddings. It is a tensor with shape broadcasted to - `[batch_size, n_head, target_length, source_length]`, where the - unwanted positions have `-INF` values and the others have 0 values. - The data type should be float32 or float64. It can be None when - nothing wanted or needed to be prevented attention to. Default None + `[batch_size, n_head, target_length, source_length]`. When the + data type is bool, the unwanted positions have `False` values + and the others have 'True' values. When the data type is int, + the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. cache (list, optional): It is a list, and each element in the list is a tuple( :code:`(incremental_cache, static_cache)` ). See `TransformerDecoder.gen_cache` for more details. It is only @@ -1225,18 +1247,23 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): tgt_mask (Tensor, optional): A tensor used in self attention to prevents attention to some unwanted positions, usually the the subsequent positions. It is a tensor with shape broadcasted - to `[batch_size, n_head, target_length, target_length]`, - where the unwanted positions have `-INF` values and the others - have 0 values. The data type should be float32 or float64. It can - be None when nothing wanted or needed to be prevented attention to. - Default None + to `[batch_size, n_head, target_length, target_length]`. When + the data type is bool, the unwanted positions have `False` + values and the others have 'True' values. When the data type is + int, the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. memory_mask (Tensor, optional): A tensor used in decoder-encoder cross attention to prevents attention to some unwanted positions, usually the paddings. It is a tensor with shape broadcasted to - `[batch_size, n_head, target_length, source_length]`, where the - unwanted positions have `-INF` values and the others have 0 values. - The data type should be float32 or float64. It can be None when - nothing wanted or needed to be prevented attention to. Default None + `[batch_size, n_head, target_length, source_length]`. When the + data type is bool, the unwanted positions have `False` values + and the others have 'True' values. When the data type is int, + the unwanted positions have 0 values and the others have 0 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. Returns: Tensor: It is a tensor that has the same shape and data type \ From ed50a5c620b8dbf60cc8fdbce12d7123f4314b8e Mon Sep 17 00:00:00 2001 From: xiemoyuan Date: Mon, 8 Feb 2021 17:08:54 +0800 Subject: [PATCH 2/4] Update docs. --- python/paddle/nn/layer/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index 0996888d0f033..731fdf63e64a5 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -106,7 +106,7 @@ class MultiHeadAttention(Layer): weight_attr(ParamAttr, optional): To specify the weight parameter property. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr` . - bias_attr (ParamAttr, optional): To specify the bias parameter property. + bias_attr (ParamAttr|bool, optional): To specify the bias parameter property. Default: None, which means the default bias parameter property is used. If it is set to False, this layer will not have trainable bias parameter. See usage for details in :code:`ParamAttr` . From 222c07acf85d8612677e949fccaa43af856eece4 Mon Sep 17 00:00:00 2001 From: xiemoyuan Date: Tue, 9 Feb 2021 12:18:37 +0800 Subject: [PATCH 3/4] Add unittest for Transformer. --- .../tests/unittests/test_transformer_api.py | 113 ++++++++++-------- 1 file changed, 64 insertions(+), 49 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_transformer_api.py b/python/paddle/fluid/tests/unittests/test_transformer_api.py index 194503b8ad2e7..587cedc6aad74 100644 --- a/python/paddle/fluid/tests/unittests/test_transformer_api.py +++ b/python/paddle/fluid/tests/unittests/test_transformer_api.py @@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention, num_heads, query_length, embed_dim, + attn_mask_type, key_length=None, value_length=None, kdim=None, @@ -58,8 +59,14 @@ def generate_query_key_value_cache(self_attention, cache=None): query = np.random.rand(batch_size, query_length, embed_dim).astype("float32") - attn_mask = np.zeros((batch_size, num_heads, query_length, key_length)) - attn_mask[0][0][0][0] = -1e9 + attn_mask = np.ones( + (batch_size, num_heads, query_length, key_length), dtype=attn_mask_type) + if attn_mask_type == 'int64': + attn_mask = np.tril(attn_mask) + elif attn_mask_type == 'float64': + attn_mask = (np.tril(attn_mask) - 1.0) * 1e9 + else: + raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") head_dim = embed_dim // num_heads if self_attention: @@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn): k = k.transpose([0, 1, 3, 2]) qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64)) if attn_mask is not None: + if attn_mask.dtype.name == 'int64': + attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9 + else: + attn_mask = attn_mask.astype(qkt.dtype) qkt += attn_mask weight = softmax(qkt) attn_heads = batch_matmul(weight, v) @@ -219,53 +230,57 @@ def multihead_attention_test_helper(self_attention, cache): # generate params for multi_head_attention batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params( "attn", self_attention) - query, key, value, attn_mask, cache_dict = generate_query_key_value_cache( - self_attention, batch_size, num_heads, query_length, - embed_dim, key_length, value_length, kdim, vdim, cache) - if cache and self_attention: - attn_mask = np.concatenate((attn_mask, attn_mask), axis=3) - need_weight, param_attr, bias_attr = False, None, None - # call paddle's function - multi_head_attn = MultiHeadAttention( - embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight, - param_attr, bias_attr) - # construct cache object - cache_obj = None - if cache_dict: - if 'k' and 'v' in cache_dict: - cache_obj = multi_head_attn.Cache( - paddle.to_tensor(cache_dict['k']), - paddle.to_tensor(cache_dict['v'])) - elif 'static_k' and 'static_v' in cache_dict: - cache_obj = multi_head_attn.StaticCache( - paddle.to_tensor(cache_dict['static_k']), - paddle.to_tensor(cache_dict['static_v'])) - if attn_mask is not None: - attn_output = multi_head_attn( - paddle.to_tensor(query), - paddle.to_tensor(key), - paddle.to_tensor(value), - paddle.to_tensor(attn_mask), cache_obj) - else: - attn_output = multi_head_attn( - paddle.to_tensor(query), - paddle.to_tensor(key), - paddle.to_tensor(value), attn_mask, cache_obj) - attn_output = attn_output[0] if cache_dict else attn_output - - # implementation by numpy - # compute q, k, v - q, k, v, _ = prepare_qkv(query, key, value, num_heads, - embed_dim, self_attention, - multi_head_attn, cache_dict) - # scale dot product attention - attn_heads = scaled_dot_product_attention( - q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn) - out_proj_weight = multi_head_attn.out_proj.weight.numpy() - reference = fc(attn_heads, out_proj_weight) - - np.testing.assert_allclose( - attn_output.numpy(), reference, atol=1e-6) + for attn_mask_type in ['int64', 'float64']: + query, key, value, attn_mask, cache_dict = generate_query_key_value_cache( + self_attention, batch_size, num_heads, query_length, + embed_dim, attn_mask_type, key_length, value_length, + kdim, vdim, cache) + if cache and self_attention: + attn_mask = np.concatenate( + (attn_mask, attn_mask), axis=3) + need_weight, param_attr, bias_attr = False, None, None + # call paddle's function + multi_head_attn = MultiHeadAttention( + embed_dim, num_heads, attn_dropout, kdim, vdim, + need_weight, param_attr, bias_attr) + # construct cache object + cache_obj = None + if cache_dict: + if 'k' and 'v' in cache_dict: + cache_obj = multi_head_attn.Cache( + paddle.to_tensor(cache_dict['k']), + paddle.to_tensor(cache_dict['v'])) + elif 'static_k' and 'static_v' in cache_dict: + cache_obj = multi_head_attn.StaticCache( + paddle.to_tensor(cache_dict['static_k']), + paddle.to_tensor(cache_dict['static_v'])) + if attn_mask is not None: + attn_output = multi_head_attn( + paddle.to_tensor(query), + paddle.to_tensor(key), + paddle.to_tensor(value), + paddle.to_tensor(attn_mask), cache_obj) + else: + attn_output = multi_head_attn( + paddle.to_tensor(query), + paddle.to_tensor(key), + paddle.to_tensor(value), attn_mask, cache_obj) + attn_output = attn_output[0] if cache_dict else attn_output + + # implementation by numpy + # compute q, k, v + q, k, v, _ = prepare_qkv(query, key, value, num_heads, + embed_dim, self_attention, + multi_head_attn, cache_dict) + # scale dot product attention + attn_heads = scaled_dot_product_attention( + q, k, v, embed_dim // num_heads, attn_mask, + multi_head_attn) + out_proj_weight = multi_head_attn.out_proj.weight.numpy() + reference = fc(attn_heads, out_proj_weight) + + np.testing.assert_allclose( + attn_output.numpy(), reference, atol=1e-6) multihead_attention_test_helper(True, True) multihead_attention_test_helper(True, False) From 0724c6bfc81e12f5a7a22bbf18519b9baa4cd01c Mon Sep 17 00:00:00 2001 From: xiemoyuan Date: Thu, 18 Feb 2021 11:29:46 +0800 Subject: [PATCH 4/4] fix bugs. --- python/paddle/nn/layer/transformer.py | 95 ++++++++++++++++++++------- 1 file changed, 72 insertions(+), 23 deletions(-) diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index 731fdf63e64a5..5aded4949e2d7 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -83,6 +83,35 @@ def _convert_param_attr_to_list(param_attr, n): return param_attrs +def _convert_attention_mask(attn_mask, dtype): + """ + Convert the attention mask to the target dtype we expect. + + Parameters: + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. + dtype (VarType): The target type of `attn_mask` we expect. + + Returns: + Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`. + """ + if attn_mask is not None and attn_mask.dtype != dtype: + attn_mask_dtype = convert_dtype(attn_mask.dtype) + if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype: + attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9 + else: + attn_mask = paddle.cast(attn_mask, dtype) + return attn_mask + + class MultiHeadAttention(Layer): """ Attention mapps queries and a set of key-value pairs to outputs, and @@ -334,8 +363,8 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the data type is bool, the unwanted positions have `False` - values and the others have 'True' values. When the data type is - int, the unwanted positions have 0 values and the others have 0 + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -378,11 +407,7 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) if attn_mask is not None: # Support bool or int mask - attn_mask_dtype = convert_dtype(attn_mask.dtype) - if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype: - attn_mask = (paddle.cast(attn_mask, product.dtype) - 1.0) * 1e9 - else: - attn_mask = paddle.cast(attn_mask, product.dtype) + attn_mask = _convert_attention_mask(attn_mask, product.dtype) product = product + attn_mask weights = F.softmax(product) if self.dropout: @@ -519,8 +544,8 @@ def forward(self, src, src_mask=None, cache=None): paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the data type is bool, the unwanted positions have `False` - values and the others have 'True' values. When the data type is - int, the unwanted positions have 0 values and the others have 0 + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -538,6 +563,8 @@ def forward(self, src, src_mask=None, cache=None): incremental length. See `MultiHeadAttention.gen_cache` and \ `MultiHeadAttention.forward` for more details. """ + src_mask = _convert_attention_mask(src_mask, src.dtype) + residual = src if self.normalize_before: src = self.norm1(src) @@ -634,8 +661,8 @@ def forward(self, src, src_mask=None, cache=None): paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the data type is bool, the unwanted positions have `False` - values and the others have 'True' values. When the data type is - int, the unwanted positions have 0 values and the others have 0 + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -653,6 +680,8 @@ def forward(self, src, src_mask=None, cache=None): See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \ for more details. """ + src_mask = _convert_attention_mask(src_mask, src.dtype) + output = src new_caches = [] for i, mod in enumerate(self.layers): @@ -822,8 +851,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, target_length, target_length]`. When the data type is bool, the unwanted positions have `False` - values and the others have 'True' values. When the data type is - int, the unwanted positions have 0 values and the others have 0 + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -832,8 +861,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): usually the paddings. It is a tensor with shape broadcasted to `[batch_size, n_head, target_length, source_length]`. When the data type is bool, the unwanted positions have `False` values - and the others have 'True' values. When the data type is int, - the unwanted positions have 0 values and the others have 0 + and the others have `True` values. When the data type is int, + the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -853,6 +882,9 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \ for more details. """ + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + memory_mask = _convert_attention_mask(memory_mask, memory.dtype) + residual = tgt if self.normalize_before: tgt = self.norm1(tgt) @@ -977,8 +1009,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, target_length, target_length]`. When the data type is bool, the unwanted positions have `False` - values and the others have 'True' values. When the data type is - int, the unwanted positions have 0 values and the others have 0 + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -987,8 +1019,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): usually the paddings. It is a tensor with shape broadcasted to `[batch_size, n_head, target_length, source_length]`. When the data type is bool, the unwanted positions have `False` values - and the others have 'True' values. When the data type is int, - the unwanted positions have 0 values and the others have 0 + and the others have `True` values. When the data type is int, + the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -1006,6 +1038,9 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \ for more details. """ + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + memory_mask = _convert_attention_mask(memory_mask, memory.dtype) + output = tgt new_caches = [] for i, mod in enumerate(self.layers): @@ -1244,13 +1279,23 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): memory (Tensor): The output of Transformer encoder. It is a tensor with shape `[batch_size, source_length, d_model]`. The data type should be float32 or float64. + src_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when + nothing wanted or needed to be prevented attention to. Default None. tgt_mask (Tensor, optional): A tensor used in self attention to prevents attention to some unwanted positions, usually the the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, target_length, target_length]`. When the data type is bool, the unwanted positions have `False` - values and the others have 'True' values. When the data type is - int, the unwanted positions have 0 values and the others have 0 + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -1259,8 +1304,8 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): usually the paddings. It is a tensor with shape broadcasted to `[batch_size, n_head, target_length, source_length]`. When the data type is bool, the unwanted positions have `False` values - and the others have 'True' values. When the data type is int, - the unwanted positions have 0 values and the others have 0 + and the others have `True` values. When the data type is int, + the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. @@ -1269,7 +1314,11 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): Tensor: It is a tensor that has the same shape and data type \ as `tgt`, representing the output of Transformer decoder. """ + src_mask = _convert_attention_mask(src_mask, src.dtype) memory = self.encoder(src, src_mask=src_mask) + + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + memory_mask = _convert_attention_mask(memory_mask, memory.dtype) output = self.decoder( tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask) return output