Optimization of Transformer API #30957

Merged: 4 commits, merged on Feb 23, 2021
Changes from 3 commits
113 changes: 64 additions & 49 deletions python/paddle/fluid/tests/unittests/test_transformer_api.py
@@ -51,15 +51,22 @@ def generate_query_key_value_cache(self_attention,
num_heads,
query_length,
embed_dim,
attn_mask_type,
key_length=None,
value_length=None,
kdim=None,
vdim=None,
cache=None):
query = np.random.rand(batch_size, query_length,
embed_dim).astype("float32")
attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
attn_mask[0][0][0][0] = -1e9
attn_mask = np.ones(
(batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
if attn_mask_type == 'int64':
attn_mask = np.tril(attn_mask)
elif attn_mask_type == 'float64':
attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")

head_dim = embed_dim // num_heads
if self_attention:
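
For reference, a minimal NumPy sketch (not part of the diff; toy shapes assumed, no batch or head dimensions) of what the two mask variants built in this hunk look like:

import numpy as np

length = 3  # assumed toy query_length == key_length

# 'int64' variant: lower-triangular 0/1 mask, 1 means the position may be attended to.
int_mask = np.tril(np.ones((length, length), dtype='int64'))
# [[1 0 0]
#  [1 1 0]
#  [1 1 1]]

# 'float64' variant: additive mask, 0.0 where attention is allowed, -1e9 where it is masked.
float_mask = (np.tril(np.ones((length, length), dtype='float64')) - 1.0) * 1e9
# [[ 0.0 -1e9 -1e9]
#  [ 0.0  0.0 -1e9]
#  [ 0.0  0.0  0.0]]
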
@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
k = k.transpose([0, 1, 3, 2])
qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
if attn_mask is not None:
if attn_mask.dtype.name == 'int64':
attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
else:
attn_mask = attn_mask.astype(qkt.dtype)
qkt += attn_mask
weight = softmax(qkt)
attn_heads = batch_matmul(weight, v)
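
As an aside, a quick check (hypothetical values, not from the test) of the conversion applied here: mapping a 0/1 int64 mask through (mask - 1.0) * 1e9 leaves allowed positions at 0.0 and pushes masked positions to -1e9, so those logits are effectively removed by the softmax:

import numpy as np

qkt = np.array([2.0, 1.0, 3.0])                  # hypothetical attention logits for one query
mask = np.array([1, 1, 0], dtype='int64')        # 1 = attend, 0 = masked
additive = (mask.astype(qkt.dtype) - 1.0) * 1e9  # -> [0.0, 0.0, -1e9]
masked_logits = qkt + additive                   # masked entry sits near -1e9; its softmax weight is ~0
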
@@ -219,53 +230,57 @@ def multihead_attention_test_helper(self_attention, cache):
# generate params for multi_head_attention
batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
"attn", self_attention)
query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
self_attention, batch_size, num_heads, query_length,
embed_dim, key_length, value_length, kdim, vdim, cache)
if cache and self_attention:
attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
need_weight, param_attr, bias_attr = False, None, None
# call paddle's function
multi_head_attn = MultiHeadAttention(
embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
param_attr, bias_attr)
# construct cache object
cache_obj = None
if cache_dict:
if 'k' and 'v' in cache_dict:
cache_obj = multi_head_attn.Cache(
paddle.to_tensor(cache_dict['k']),
paddle.to_tensor(cache_dict['v']))
elif 'static_k' and 'static_v' in cache_dict:
cache_obj = multi_head_attn.StaticCache(
paddle.to_tensor(cache_dict['static_k']),
paddle.to_tensor(cache_dict['static_v']))
if attn_mask is not None:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value),
paddle.to_tensor(attn_mask), cache_obj)
else:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value), attn_mask, cache_obj)
attn_output = attn_output[0] if cache_dict else attn_output

# implementation by numpy
# compute q, k, v
q, k, v, _ = prepare_qkv(query, key, value, num_heads,
embed_dim, self_attention,
multi_head_attn, cache_dict)
# scale dot product attention
attn_heads = scaled_dot_product_attention(
q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
out_proj_weight = multi_head_attn.out_proj.weight.numpy()
reference = fc(attn_heads, out_proj_weight)

np.testing.assert_allclose(
attn_output.numpy(), reference, atol=1e-6)
for attn_mask_type in ['int64', 'float64']:
query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
self_attention, batch_size, num_heads, query_length,
embed_dim, attn_mask_type, key_length, value_length,
kdim, vdim, cache)
if cache and self_attention:
attn_mask = np.concatenate(
(attn_mask, attn_mask), axis=3)
need_weight, param_attr, bias_attr = False, None, None
# call paddle's function
multi_head_attn = MultiHeadAttention(
embed_dim, num_heads, attn_dropout, kdim, vdim,
need_weight, param_attr, bias_attr)
# construct cache object
cache_obj = None
if cache_dict:
if 'k' and 'v' in cache_dict:
cache_obj = multi_head_attn.Cache(
paddle.to_tensor(cache_dict['k']),
paddle.to_tensor(cache_dict['v']))
elif 'static_k' and 'static_v' in cache_dict:
cache_obj = multi_head_attn.StaticCache(
paddle.to_tensor(cache_dict['static_k']),
paddle.to_tensor(cache_dict['static_v']))
if attn_mask is not None:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value),
paddle.to_tensor(attn_mask), cache_obj)
else:
attn_output = multi_head_attn(
paddle.to_tensor(query),
paddle.to_tensor(key),
paddle.to_tensor(value), attn_mask, cache_obj)
attn_output = attn_output[0] if cache_dict else attn_output

# implementation by numpy
# compute q, k, v
q, k, v, _ = prepare_qkv(query, key, value, num_heads,
embed_dim, self_attention,
multi_head_attn, cache_dict)
# scale dot product attention
attn_heads = scaled_dot_product_attention(
q, k, v, embed_dim // num_heads, attn_mask,
multi_head_attn)
out_proj_weight = multi_head_attn.out_proj.weight.numpy()
reference = fc(attn_heads, out_proj_weight)

np.testing.assert_allclose(
attn_output.numpy(), reference, atol=1e-6)

multihead_attention_test_helper(True, True)
multihead_attention_test_helper(True, False)
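
Finally, for orientation, a stripped-down sketch of the NumPy reference path that the loop above checks the framework output against (single head and toy sizes assumed; the real test goes through prepare_qkv, scaled_dot_product_attention, and fc, which also handle batching and multiple heads):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_len, head_dim, model_dim = 4, 8, 8           # assumed toy sizes
q = rng.standard_normal((seq_len, head_dim))
k = rng.standard_normal((seq_len, head_dim))
v = rng.standard_normal((seq_len, head_dim))
attn_mask = (np.tril(np.ones((seq_len, seq_len))) - 1.0) * 1e9  # float additive causal mask

# scaled dot-product attention, mirroring scaled_dot_product_attention above
weights = softmax(q @ k.T / np.sqrt(head_dim) + attn_mask)
attn_heads = weights @ v

# output projection, mirroring fc(attn_heads, out_proj_weight)
out_proj_weight = rng.standard_normal((head_dim, model_dim))
reference = attn_heads @ out_proj_weight

# the test then asserts the framework result matches this reference:
# np.testing.assert_allclose(attn_output, reference, atol=1e-6)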