Eval einsum in bigbird (PaddlePaddle#314)
* matmul->einsum

* fix rand_mask_idx_list

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
joey12300 and ZeyuChen authored Apr 28, 2021
1 parent 04ab7af commit a79c9cb
Showing 2 changed files with 33 additions and 34 deletions.
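The recurring pattern in this commit swaps paddle.matmul for paddlenlp.ops.einsum with an explicit subscript string. Below is a minimal equivalence sketch, assuming a PaddleNLP build from this commit's era that still ships paddlenlp.ops.einsum (newer stacks expose paddle.einsum directly); the shapes are illustrative, not the real BigBird config:

    # Equivalence check for the matmul -> einsum rewrite.
    # Assumes 2021-era PaddleNLP with paddlenlp.ops.einsum available.
    import numpy as np
    import paddle
    import paddlenlp

    q = paddle.rand([2, 4, 8, 16])   # e.g. [batch, heads, query_len, d_head]
    k = paddle.rand([2, 4, 8, 16])   # e.g. [batch, heads, key_len, d_head]

    ref = paddle.matmul(q, k, transpose_y=True)
    out = paddlenlp.ops.einsum("bhqd,bhkd->bhqk", q, k)
    np.testing.assert_allclose(ref.numpy(), out.numpy(), rtol=1e-5)

Spelling out the subscripts documents which axis is contracted and, as the second file shows, lets a single einsum replace a chain of unsqueeze/expand/reshape/matmul calls.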
2 changes: 1 addition & 1 deletion examples/language_model/bigbird/run_classifier.py
@@ -186,7 +186,7 @@ def do_evalute(model, criterion, metric, test_data_loader):
         global_steps += 1
         input_ids, labels = batch[:2]
         rand_mask_idx_list = batch[2:]
-        output = model(input_ids, None, rand_mask_idx_list)
+        output = model(input_ids, rand_mask_idx_list=rand_mask_idx_list)
         loss = criterion(output, labels)
         correct = metric.compute(output, labels)
         metric.update(correct)
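This is the "fix rand_mask_idx_list" part of the commit: passing the argument by keyword keeps it from sliding into the wrong positional slot when the forward signature has several optional parameters. A sketch of the hazard with a hypothetical signature (the real BigBird forward may order its optional parameters differently):

    def forward(input_ids, token_type_ids=None, attention_mask=None,
                rand_mask_idx_list=None):
        # Hypothetical signature for illustration only.
        return token_type_ids, attention_mask, rand_mask_idx_list

    # Positional: the list lands in attention_mask's slot.
    tt, am, rm = forward("ids", None, ["idx"])
    assert am == ["idx"] and rm is None
    # Keyword: immune to the ordering of defaulted parameters.
    tt, am, rm = forward("ids", rand_mask_idx_list=["idx"])
    assert am is None and rm == ["idx"]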
65 changes: 32 additions & 33 deletions paddlenlp/transformers/attention_utils.py
@@ -22,6 +22,7 @@
 
 from paddle.nn import Linear, Dropout, LayerNorm, LayerList, Layer
 from paddle import ParamAttr
+import paddlenlp
 
 
 class Registry(object):
@@ -276,14 +277,13 @@ def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
                                          [B, L - G, bs, 1])
         temp_key_mask_front = paddle.reshape(blocked_key_mask[:, :GF],
                                              [B, 1, 1, GF * bs])
-        global_block_mask_front = paddle.matmul(temp_query_mask,
-                                                temp_key_mask_front)
+        global_block_mask_front = paddlenlp.ops.einsum(
+            "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_front)
 
         temp_key_mask_back = paddle.reshape(blocked_key_mask[:, -GB:],
                                             [B, 1, 1, GB * bs])
-        global_block_mask_back = paddle.matmul(temp_query_mask,
-                                               temp_key_mask_back)
-
+        global_block_mask_back = paddlenlp.ops.einsum(
+            "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_back)
         # create window block mask
         key_mask_list = []
         for query_block_id in range(GF, GF + W // 2):
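Here temp_query_mask is [B, L-G, bs, 1] and temp_key_mask_front is [B, 1, 1, GF*bs], so "blqd,bmdk->blqk" sums out only the singleton m and d axes and reproduces the broadcasted matmul it replaces. A toy-size sanity check, assuming paddlenlp.ops.einsum follows NumPy einsum semantics:

    # "blqd,bmdk->blqk" with m = d = 1 reduces to the broadcasted
    # matmul of [B, L-G, bs, 1] x [B, 1, 1, GF*bs]. Illustrative sizes.
    import numpy as np
    import paddle
    import paddlenlp

    B, Lg, bs, GFbs = 2, 3, 4, 5
    q_mask = paddle.rand([B, Lg, bs, 1])
    k_mask = paddle.rand([B, 1, 1, GFbs])

    ref = paddle.matmul(q_mask, k_mask)   # broadcasts over the middle dims
    out = paddlenlp.ops.einsum("blqd,bmdk->blqk", q_mask, k_mask)
    np.testing.assert_allclose(ref.numpy(), out.numpy(), rtol=1e-5)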
@@ -326,8 +326,8 @@ def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
             [roll_key_mask1, window_key_mask, roll_key_mask2], axis=1)
         window_key_mask = paddle.unsqueeze(window_key_mask, axis=2)
         # [B, L-G, bs, 1] * [B, L-G, 1, W*bs] -> [B, L-G, bs, W*bs]
-        window_block_mask = paddle.matmul(temp_query_mask, window_key_mask)
-
+        window_block_mask = paddlenlp.ops.einsum(
+            "blkd,bldq->blkq", temp_query_mask, window_key_mask)
         band_mask = paddle.concat(
             [
                 global_block_mask_front, window_block_mask,
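The window-mask einsum "blkd,bldq->blkq" is the batched outer product that the shape comment above it announces: per (batch, block) pair, a [bs, 1] column against a [1, W*bs] row. Toy-size check, dimensions illustrative only:

    import numpy as np
    import paddle
    import paddlenlp

    B, Lg, bs, W = 2, 3, 4, 3       # W: window size in blocks
    q_mask = paddle.rand([B, Lg, bs, 1])
    w_mask = paddle.rand([B, Lg, 1, W * bs])

    ref = paddle.matmul(q_mask, w_mask)              # [B, Lg, bs, W*bs]
    out = paddlenlp.ops.einsum("blkd,bldq->blkq", q_mask, w_mask)
    np.testing.assert_allclose(ref.numpy(), out.numpy(), rtol=1e-5)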
@@ -435,17 +435,13 @@ def _get_rand_mask(self, blocked_query_mask, blocked_key_mask,
             for b in range(B)
         ]
         temp_block_key_mask = paddle.concat(temp_block_key_mask_list, 0)
-        temp_block_key_mask = paddle.reshape(temp_block_key_mask,
-                                             [B, H, L - G, 1, R * bs])
-
-        temp_blocked_query_mask = paddle.unsqueeze(
-            blocked_query_mask[:, GF:-GB], 1)
-        temp_blocked_query_mask = paddle.expand(temp_blocked_query_mask,
-                                                [B, H, L - G, -1])
-        temp_blocked_query_mask = paddle.reshape(temp_blocked_query_mask,
-                                                 [B, H, L - G, bs, 1])
-
-        rand_mask = paddle.matmul(temp_blocked_query_mask, temp_block_key_mask)
+        temp_block_key_mask = paddle.reshape(temp_block_key_mask, [
+            B, temp_block_key_mask.shape[0] // B // (L - GF - GB) // R,
+            L - GF - GB, -1
+        ])
+        rand_mask = paddlenlp.ops.einsum("blq,bhlk->bhlqk",
+                                         blocked_query_mask[:, GF:-GB],
+                                         temp_block_key_mask)
         return rand_mask
 
     def _gather_random_key_value(self, blocked_matrix, rand_mask_idx, B, T):
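Note that "blq,bhlk->bhlqk" contracts nothing; every subscript survives into the output, so the op is a pure broadcasted outer product standing in for the deleted unsqueeze/expand/reshape/matmul chain. A toy-size sketch (dimensions are illustrative, not the real config):

    # No subscript is summed: the einsum is an outer product broadcast
    # across the new h, q, and k axes.
    import numpy as np
    import paddle
    import paddlenlp

    B, H, Lg, bs, Rbs = 2, 3, 4, 5, 6
    q_mask = paddle.rand([B, Lg, bs])        # blocked_query_mask[:, GF:-GB]
    k_mask = paddle.rand([B, H, Lg, Rbs])    # gathered random key mask

    out = paddlenlp.ops.einsum("blq,bhlk->bhlqk", q_mask, k_mask)
    # The same thing written with explicit broadcasting:
    ref = q_mask.unsqueeze(1).unsqueeze(-1) * k_mask.unsqueeze(3)
    np.testing.assert_allclose(ref.numpy(), out.numpy(), rtol=1e-5)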
@@ -575,35 +571,38 @@ def forward(self,
             [band_value_matrix, random_values], axis=3)
         second_top_value_matrix, second_middle_value_matrix, second_bottom_value_matrix = \
             self._get_splited_matrix(second_value_matrix)
-
-        second_product = paddle.matmul(
-            second_query_matrix, second_key_matrix, transpose_y=True)
+        second_product = paddlenlp.ops.einsum(
+            "bhlqd,bhlkd->bhlqk", second_query_matrix, second_key_matrix)
         second_product = second_product * (d_head**-0.5)
         second_product += (1 - second_mask) * -1e6
         second_weights = F.softmax(second_product)
 
         second_top_weights, second_middle_weights, second_bottom_weights = \
             self._get_splited_matrix(second_weights)
-        second_top_out = paddle.matmul(second_top_weights,
-                                       second_top_value_matrix)
-        second_middle_out = paddle.matmul(
+        second_top_out = paddlenlp.ops.einsum(
+            "bhlqk,bhlkd->bhlqd", second_top_weights, second_top_value_matrix)
+
+        second_middle_out = paddlenlp.ops.einsum(
+            "bhlqk,bhlkd->bhlqd",
             second_middle_weights[:, :, :, :, GF * bs:-(GB + R) * bs],
             second_middle_value_matrix[:, :, :, GF * bs:-(GB + R) * bs])
         # add global block attention
-        second_middle_out += paddle.matmul(
-            second_middle_weights[:, :, :, :, :GF * bs],
-            blocked_value_matrix[:, :, 0:GF])
-        second_middle_out += paddle.matmul(
+        second_middle_out += paddlenlp.ops.einsum(
+            "bhlqk,bhkd->bhlqd", second_middle_weights[:, :, :, :, :GF * bs],
+            blocked_value_matrix[:, :, 0])
+        second_middle_out += paddlenlp.ops.einsum(
+            "bhlqk,bhkd->bhlqd",
             second_middle_weights[:, :, :, :, -(GB + R) * bs:-R * bs],
-            blocked_value_matrix[:, :, -GB:])
+            blocked_value_matrix[:, :, -GB])
         # add random block attention
-        second_middle_out += paddle.matmul(
-            second_middle_weights[:, :, :, :, -R * bs:],
+        second_middle_out += paddlenlp.ops.einsum(
+            "...qk,...kd->...qd", second_middle_weights[:, :, :, :, -R * bs:],
            random_values[:, :, GF:-GB])
 
-        second_bottom_out = paddle.matmul(second_bottom_weights,
-                                          second_bottom_value_matrix)
+        second_bottom_out = paddlenlp.ops.einsum("bhlqk,bhlkd->bhlqd",
+                                                 second_bottom_weights,
+                                                 second_bottom_value_matrix)
 
         second_out = paddle.concat(
             [second_top_out, second_middle_out, second_bottom_out], axis=2)
         second_out = paddle.reshape(second_out, [B, H, (L - G) * bs, -1])
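In the forward pass, "bhlqd,bhlkd->bhlqk" contracts the head dimension d exactly like matmul with transpose_y=True, and "...qk,...kd->...qd" uses an ellipsis for whatever leading batch axes precede the matrix product (the commit itself relies on ellipsis support in paddlenlp.ops.einsum). A toy equivalence check under the same era assumptions as above:

    # Illustrative shapes: [batch, heads, blocks, block_seq, d_head].
    import numpy as np
    import paddle
    import paddlenlp

    B, H, L, q, k, d = 2, 2, 3, 4, 5, 8
    Q = paddle.rand([B, H, L, q, d])
    K = paddle.rand([B, H, L, k, d])
    V = paddle.rand([B, H, L, k, d])

    # Scores: einsum contracts d, matching matmul(Q, K, transpose_y=True).
    scores = paddlenlp.ops.einsum("bhlqd,bhlkd->bhlqk", Q, K)
    np.testing.assert_allclose(
        scores.numpy(), paddle.matmul(Q, K, transpose_y=True).numpy(),
        rtol=1e-5)

    # Context: the ellipsis stands for the leading b, h, l axes.
    attn = paddle.nn.functional.softmax(scores)
    ctx = paddlenlp.ops.einsum("...qk,...kd->...qd", attn, V)
    np.testing.assert_allclose(
        ctx.numpy(), paddle.matmul(attn, V).numpy(), rtol=1e-5)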
