Skip to content

Commit

Permalink
Optimize CRF Perf (PaddlePaddle#357)
Browse files Browse the repository at this point in the history
* optimize crf loss

* fix ernie crf negative label; remove paddle.no_grad

* fix negative ignore label
  • Loading branch information
joey12300 authored May 17, 2021
1 parent 2563a9d commit 8e5fb8f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
6 changes: 3 additions & 3 deletions examples/information_extraction/DuEE/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,14 @@ def do_train():
def do_predict():
paddle.set_device(args.device)

no_entity_label = "O"
ignore_label = -1

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
label_map = load_dict(args.tag_path)
id2label = {val: key for key, val in label_map.items()}
model = ErnieForTokenClassification.from_pretrained("ernie-1.0", num_classes=len(label_map))

no_entity_label = "O"
ignore_label = len(label_map)

print("============start predict==========")
if not args.init_ckpt or not os.path.isfile(args.init_ckpt):
raise Exception("init checkpoints {} not exist".format(args.init_ckpt))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def predict(model, data_loader, ds, label_vocab):
dev_ds.map(trans_func)
test_ds.map(trans_func)

ignore_label = -1
ignore_label = len(label_vocab)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input_ids
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type_ids
Expand Down
26 changes: 14 additions & 12 deletions paddlenlp/layers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
__all__ = ['LinearChainCrf', 'LinearChainCrfLoss', 'ViterbiDecoder']


def log_sum_exp(vec, dim=0):
    """Compute ``log(sum(exp(vec)))`` along axis ``dim`` in a stable way.

    The maximum along ``dim`` is subtracted before exponentiating, which
    avoids overflow in ``exp``, and is added back to the result afterwards.

    Args:
        vec (Tensor): The input tensor.
        dim (int, optional): The axis to reduce over. Defaults to 0.

    Returns:
        Tensor: The reduced tensor, with axis ``dim`` removed.
    """
    # Keep the reduced axis so the maximum broadcasts against `vec` for any
    # `dim`; re-inserting it with unsqueeze(-1) is only correct when `dim`
    # happens to be the last axis.
    max_num = paddle.max(vec, axis=dim, keepdim=True)
    summed = paddle.sum(paddle.exp(vec - max_num), axis=dim, keepdim=True)
    return (max_num + paddle.log(summed)).squeeze(dim)


class LinearChainCrf(nn.Layer):
"""
LinearChainCrf is a linear chain Conditional Random Field layer. It can implement sequential dependencies in the predictions.
Expand Down Expand Up @@ -50,8 +57,6 @@ def __init__(self, num_labels, crf_lr=0.1, with_start_stop_tag=True):
attr=paddle.ParamAttr(learning_rate=crf_lr),
shape=[self.num_tags, self.num_tags],
dtype='float32')
with paddle.no_grad():
self.flattened_transition_params = paddle.flatten(self.transitions)
self.with_start_stop_tag = with_start_stop_tag

self._initial_alpha = None
Expand Down Expand Up @@ -110,11 +115,9 @@ def forward(self, inputs, lengths):
The normalizers tensor. Its dtype is float32 and has a shape of `[batch_size]`.
"""
batch_size, seq_len, n_labels = inputs.shape
inputs_t_exp = inputs.transpose([1, 0, 2]).unsqueeze(-1).expand(
[seq_len, batch_size, n_labels, n_labels])
inputs_t_exp = inputs.transpose([1, 0, 2]).unsqueeze(-1)
# trans_exp: batch_size, num_tags, num_tags
trans_exp = self.transitions.unsqueeze(0).expand(
[batch_size, n_labels, n_labels])
trans_exp = self.transitions.unsqueeze(0)

all_alpha = []
if self.with_start_stop_tag:
Expand All @@ -126,11 +129,10 @@ def forward(self, inputs, lengths):
if i == 0 and not self.with_start_stop_tag:
alpha = inputs[:, 0]
else:
alpha_exp = alpha.unsqueeze(1).expand(
[batch_size, n_labels, n_labels])
alpha_exp = alpha.unsqueeze(1)
# F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
mat = input_exp + trans_exp + alpha_exp
alpha = paddle.logsumexp(mat, 2)
alpha = log_sum_exp(mat, 2).squeeze(-1)
all_alpha.append(alpha)

# Get the valid alpha
Expand All @@ -143,7 +145,7 @@ def forward(self, inputs, lengths):
if self.with_start_stop_tag:
# The last one step
alpha += self.transitions[self.stop_idx].unsqueeze(0)
norm_score = paddle.logsumexp(alpha, 1)
norm_score = log_sum_exp(alpha, 1).squeeze(-1)
return norm_score

def gold_score(self, inputs, labels, lengths):
Expand Down Expand Up @@ -219,9 +221,9 @@ def _trans_score(self, labels, lengths):
# Encode the indices in a flattened representation.
transition_indices = start_tag_indices * self.num_tags + stop_tag_indices
flattened_transition_indices = transition_indices.reshape([-1])

flattened_transition_params = paddle.flatten(self.transitions)
scores = paddle.gather(
self.flattened_transition_params,
flattened_transition_params,
flattened_transition_indices).reshape([batch_size, -1])
mask_scores = scores * mask[:, 1:]

Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/layers/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def sequence_mask(seq_ids, valid_lengths):
mask (`Tensor`):
The output sequence mask. Its dtype is ``bool`` and has a shape of [batch_size, sequence_length].
"""
lengths_exp = valid_lengths.unsqueeze(1).expand_as(seq_ids)
lengths_exp = valid_lengths.unsqueeze(1)
mask = seq_ids < lengths_exp

return mask

0 comments on commit 8e5fb8f

Please sign in to comment.