From cabf3ab5bed12b6dfc3021489a81534318acdbcb Mon Sep 17 00:00:00 2001 From: kinghuin Date: Fri, 16 Apr 2021 16:03:31 +0800 Subject: [PATCH] fix lac log and optimize waybill (#265) * fix lac log and optimize waybill * fix typo * note for drop cls * fix waybill cuda error * use ernie_crf_result * add copyright --- .../waybill_ie/README.md | 20 ++- .../information_extraction/waybill_ie/data.py | 100 +++++++++++++ .../waybill_ie/model.py | 49 +++++++ .../waybill_ie/run_bigru_crf.py | 67 +-------- .../waybill_ie/run_ernie.py | 91 +++--------- .../waybill_ie/run_ernie_crf.py | 132 ++++++++++++++++++ examples/lexical_analysis/train.py | 3 +- paddlenlp/layers/crf.py | 5 +- 8 files changed, 323 insertions(+), 144 deletions(-) create mode 100644 examples/information_extraction/waybill_ie/data.py create mode 100644 examples/information_extraction/waybill_ie/model.py create mode 100644 examples/information_extraction/waybill_ie/run_ernie_crf.py diff --git a/examples/information_extraction/waybill_ie/README.md b/examples/information_extraction/waybill_ie/README.md index fba0a31c123ec..dc7656ae42b78 100644 --- a/examples/information_extraction/waybill_ie/README.md +++ b/examples/information_extraction/waybill_ie/README.md @@ -41,17 +41,29 @@ python download.py --data_dir ./ #### 启动BiGRU + CRF训练 ```bash -export CUDA_VISIBLE_DEVICES=0 # 只支持单卡训练 +export CUDA_VISIBLE_DEVICES=0 python run_bigru_crf.py ``` -更多详细教程请参考:[基于Bi-GRU+CRF的快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1317771) #### 启动ERNIE + FC训练 ```bash -export CUDA_VISIBLE_DEVICES=0 # 只支持单卡训练 +export CUDA_VISIBLE_DEVICES=0 python run_ernie.py ``` -更多详细教程请参考:[使用PaddleNLP预训练模型ERNIE优化快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1329361) + +#### 启动ERNIE + CRF训练 + + +```bash +export CUDA_VISIBLE_DEVICES=0 +python run_ernie_crf.py +``` + +更多详细教程请参考: + +[基于Bi-GRU+CRF的快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1317771) + +[使用PaddleNLP预训练模型ERNIE优化快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/1329361) diff --git a/examples/information_extraction/waybill_ie/data.py b/examples/information_extraction/waybill_ie/data.py new file mode 100644 index 0000000000000..9deaba4afe41f --- /dev/null +++ b/examples/information_extraction/waybill_ie/data.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
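+
+# Shared data helpers for the waybill information extraction examples:
+# load_dict/load_dataset read the tag dictionary and the \002-separated
+# dataset files, convert_tokens_to_ids/convert_ernie_example build model
+# inputs, and parse_decodes converts padded predictions back to tagged text.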
+ +from paddlenlp.datasets import MapDataset + + +def load_dict(dict_path): + vocab = {} + i = 0 + with open(dict_path, 'r', encoding='utf-8') as fin: + for line in fin: + key = line.strip('\n') + vocab[key] = i + i += 1 + return vocab + + +def load_dataset(datafiles): + def read(data_path): + with open(data_path, 'r', encoding='utf-8') as fp: + next(fp) # Skip header + for line in fp.readlines(): + words, labels = line.strip('\n').split('\t') + words = words.split('\002') + labels = labels.split('\002') + yield words, labels + + if isinstance(datafiles, str): + return MapDataset(list(read(datafiles))) + elif isinstance(datafiles, list) or isinstance(datafiles, tuple): + return [MapDataset(list(read(datafile))) for datafile in datafiles] + + +def convert_tokens_to_ids(tokens, vocab, oov_token=None): + token_ids = [] + oov_id = vocab.get(oov_token) if oov_token else None + for token in tokens: + token_id = vocab.get(token, oov_id) + token_ids.append(token_id) + return token_ids + + +def convert_ernie_example(example, tokenizer, label_vocab): + tokens, labels = example + tokenized_input = tokenizer( + tokens, return_length=True, is_split_into_words=True) + # Token '[CLS]' and '[SEP]' will get label 'O' + labels = ['O'] + labels + ['O'] + tokenized_input['labels'] = [label_vocab[x] for x in labels] + return tokenized_input['input_ids'], tokenized_input[ + 'token_type_ids'], tokenized_input['seq_len'], tokenized_input['labels'] + + +def parse_decodes(sentences, predictions, lengths, label_vocab): + """Parse the padding result + + Args: + sentences (list): the tagging sentences. + predictions (list): the prediction tags. + lengths (list): the valid length of each sentence. + label_vocab (dict): the label vocab. + + Returns: + outputs (list): the formatted output. + """ + predictions = [x for batch in predictions for x in batch] + lengths = [x for batch in lengths for x in batch] + id_label = dict(zip(label_vocab.values(), label_vocab.keys())) + + outputs = [] + for idx, end in enumerate(lengths): + sent = sentences[idx][:end] + tags = [id_label[x] for x in predictions[idx][:end]] + sent_out = [] + tags_out = [] + words = "" + for s, t in zip(sent, tags): + if t.endswith('-B') or t == 'O': + if len(words): + sent_out.append(words) + tags_out.append(t.split('-')[0]) + words = s + else: + words += s + if len(sent_out) < len(tags_out): + sent_out.append(words) + outputs.append(''.join( + [str((s, t)) for s, t in zip(sent_out, tags_out)])) + return outputs diff --git a/examples/information_extraction/waybill_ie/model.py b/examples/information_extraction/waybill_ie/model.py new file mode 100644 index 0000000000000..4693f2a5a7bca --- /dev/null +++ b/examples/information_extraction/waybill_ie/model.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
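+
+# ErnieCrfForTokenClassification chains an ERNIE token-classification head
+# with a linear-chain CRF: the token logits feed LinearChainCrfLoss during
+# training and ViterbiDecoder at prediction time.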
+ +import paddle.nn as nn +from paddlenlp.transformers import ErniePretrainedModel +from paddlenlp.layers.crf import LinearChainCrf, ViterbiDecoder, LinearChainCrfLoss + + +class ErnieCrfForTokenClassification(nn.Layer): + def __init__(self, ernie, crf_lr=100): + super().__init__() + self.num_classes = ernie.num_classes + self.ernie = ernie # allow ernie to be config + self.crf = LinearChainCrf( + self.num_classes, crf_lr=crf_lr, with_start_stop_tag=False) + self.crf_loss = LinearChainCrfLoss(self.crf) + self.viterbi_decoder = ViterbiDecoder( + self.crf.transitions, with_start_stop_tag=False) + + def forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None, + lengths=None, + labels=None): + logits = self.ernie( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids) + + if labels is not None: + loss = self.crf_loss(logits, lengths, labels) + return loss + else: + _, prediction = self.viterbi_decoder(logits, lengths) + return prediction diff --git a/examples/information_extraction/waybill_ie/run_bigru_crf.py b/examples/information_extraction/waybill_ie/run_bigru_crf.py index b458a6dd8d478..9b68715395cc2 100644 --- a/examples/information_extraction/waybill_ie/run_bigru_crf.py +++ b/examples/information_extraction/waybill_ie/run_bigru_crf.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,67 +21,7 @@ from paddlenlp.metrics import ChunkEvaluator from paddlenlp.embeddings import TokenEmbedding - -def parse_decodes(ds, decodes, lens, label_vocab): - decodes = [x for batch in decodes for x in batch] - lens = [x for batch in lens for x in batch] - id_label = dict(zip(label_vocab.values(), label_vocab.keys())) - - outputs = [] - for idx, end in enumerate(lens): - sent = ds.data[idx][0][:end] - tags = [id_label[x] for x in decodes[idx][:end]] - sent_out = [] - tags_out = [] - words = "" - for s, t in zip(sent, tags): - if t.endswith('-B') or t == 'O': - if len(words): - sent_out.append(words) - tags_out.append(t.split('-')[0]) - words = s - else: - words += s - if len(sent_out) < len(tags_out): - sent_out.append(words) - outputs.append(''.join( - [str((s, t)) for s, t in zip(sent_out, tags_out)])) - return outputs - - -def convert_tokens_to_ids(tokens, vocab, oov_token=None): - token_ids = [] - oov_id = vocab.get(oov_token) if oov_token else None - for token in tokens: - token_id = vocab.get(token, oov_id) - token_ids.append(token_id) - return token_ids - - -def load_dict(dict_path): - vocab = {} - i = 0 - for line in open(dict_path, 'r', encoding='utf-8'): - key = line.strip('\n') - vocab[key] = i - i += 1 - return vocab - - -def load_dataset(datafiles): - def read(data_path): - with open(data_path, 'r', encoding='utf-8') as fp: - next(fp) - for line in fp.readlines(): - words, labels = line.strip('\n').split('\t') - words = words.split('\002') - labels = labels.split('\002') - yield words, labels - - if isinstance(datafiles, str): - return MapDataset(list(read(datafiles))) - elif isinstance(datafiles, list) or isinstance(datafiles, tuple): - return [MapDataset(list(read(datafile))) for datafile in datafiles] +from data import load_dict, load_dataset, convert_tokens_to_ids, parse_decodes class BiGRUWithCRF(nn.Layer): @@ -178,7 +118,8 @@ def convert_example(example): 
model.evaluate(eval_data=test_loader) outputs, lens, decodes = model.predict(test_data=test_loader) - preds = parse_decodes(test_ds, decodes, lens, label_vocab) + sentences = [example[0] for example in test_ds.data] + preds = parse_decodes(sentences, decodes, lens, label_vocab) file_path = "bigru_results.txt" with open(file_path, "w", encoding="utf8") as fout: diff --git a/examples/information_extraction/waybill_ie/run_ernie.py b/examples/information_extraction/waybill_ie/run_ernie.py index db4c96c71eceb..88713b070af0e 100644 --- a/examples/information_extraction/waybill_ie/run_ernie.py +++ b/examples/information_extraction/waybill_ie/run_ernie.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,40 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from functools import partial import paddle -from paddlenlp.datasets import MapDataset from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification from paddlenlp.metrics import ChunkEvaluator - -def parse_decodes(ds, decodes, lens, label_vocab): - decodes = [x for batch in decodes for x in batch] - lens = [x for batch in lens for x in batch] - id_label = dict(zip(label_vocab.values(), label_vocab.keys())) - - outputs = [] - for idx, end in enumerate(lens): - sent = ds.data[idx][0][:end] - tags = [id_label[x] for x in decodes[idx][1:end]] - sent_out = [] - tags_out = [] - words = "" - for s, t in zip(sent, tags): - if t.endswith('-B') or t == 'O': - if len(words): - sent_out.append(words) - tags_out.append(t.split('-')[0]) - words = s - else: - words += s - if len(sent_out) < len(tags_out): - sent_out.append(words) - outputs.append(''.join( - [str((s, t)) for s, t in zip(sent_out, tags_out)])) - return outputs +from data import load_dict, load_dataset, convert_ernie_example, parse_decodes @paddle.no_grad() @@ -54,60 +29,28 @@ def evaluate(model, metric, data_loader): for input_ids, seg_ids, lens, labels in data_loader: logits = model(input_ids, seg_ids) preds = paddle.argmax(logits, axis=-1) - n_infer, n_label, n_correct = metric.compute(None, lens, preds, labels) + n_infer, n_label, n_correct = metric.compute(lens, preds, labels) metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy()) precision, recall, f1_score = metric.accumulate() print("eval precision: %f - recall: %f - f1: %f" % (precision, recall, f1_score)) + model.train() +@paddle.no_grad() def predict(model, data_loader, ds, label_vocab): - pred_list = [] - len_list = [] + all_preds = [] + all_lens = [] for input_ids, seg_ids, lens, labels in data_loader: logits = model(input_ids, seg_ids) - pred = paddle.argmax(logits, axis=-1) - pred_list.append(pred.numpy()) - len_list.append(lens.numpy()) - preds = parse_decodes(ds, pred_list, len_list, label_vocab) - return preds - - -def convert_example(example, tokenizer, label_vocab): - tokens, labels = example - tokenized_input = tokenizer( - tokens, return_length=True, is_split_into_words=True) - # Token '[CLS]' and '[SEP]' will get label 'O' - labels = ['O'] + labels + ['O'] - tokenized_input['labels'] = [label_vocab[x] for x in labels] - return tokenized_input['input_ids'], tokenized_input[ - 
'token_type_ids'], tokenized_input['seq_len'], tokenized_input['labels'] - - -def load_dict(dict_path): - vocab = {} - i = 0 - for line in open(dict_path, 'r', encoding='utf-8'): - key = line.strip('\n') - vocab[key] = i - i += 1 - return vocab - - -def load_dataset(datafiles): - def read(data_path): - with open(data_path, 'r', encoding='utf-8') as fp: - next(fp) # Skip header - for line in fp.readlines(): - words, labels = line.strip('\n').split('\t') - words = words.split('\002') - labels = labels.split('\002') - yield words, labels - - if isinstance(datafiles, str): - return MapDataset(list(read(datafiles))) - elif isinstance(datafiles, list) or isinstance(datafiles, tuple): - return [MapDataset(list(read(datafile))) for datafile in datafiles] + preds = paddle.argmax(logits, axis=-1) + # Drop CLS prediction + preds = [pred[1:] for pred in preds.numpy()] + all_preds.append(preds) + all_lens.append(lens) + sentences = [example[0] for example in ds.data] + results = parse_decodes(sentences, all_preds, all_lens, label_vocab) + return results if __name__ == '__main__': @@ -121,7 +64,7 @@ def read(data_path): tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') trans_func = partial( - convert_example, tokenizer=tokenizer, label_vocab=label_vocab) + convert_ernie_example, tokenizer=tokenizer, label_vocab=label_vocab) train_ds.map(trans_func) dev_ds.map(trans_func) diff --git a/examples/information_extraction/waybill_ie/run_ernie_crf.py b/examples/information_extraction/waybill_ie/run_ernie_crf.py new file mode 100644 index 0000000000000..cb065bf416cd2 --- /dev/null +++ b/examples/information_extraction/waybill_ie/run_ernie_crf.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
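+
+# Fine-tune ErnieForTokenClassification wrapped by ErnieCrfForTokenClassification
+# on the waybill dataset, evaluate with ChunkEvaluator after each epoch, and
+# write the decoded test-set predictions to ernie_crf_results.txt.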
+ +from functools import partial + +import paddle +from paddlenlp.data import Stack, Tuple, Pad +from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification +from paddlenlp.metrics import ChunkEvaluator + +from model import ErnieCrfForTokenClassification +from data import load_dict, load_dataset, convert_ernie_example, parse_decodes + + +@paddle.no_grad() +def evaluate(model, metric, data_loader): + model.eval() + metric.reset() + for input_ids, seg_ids, lens, labels in data_loader: + preds = model(input_ids, seg_ids, lengths=lens) + n_infer, n_label, n_correct = metric.compute(lens, preds, labels) + metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy()) + precision, recall, f1_score = metric.accumulate() + print("eval precision: %f - recall: %f - f1: %f" % + (precision, recall, f1_score)) + model.train() + + +@paddle.no_grad() +def predict(model, data_loader, ds, label_vocab): + all_preds = [] + all_lens = [] + for input_ids, seg_ids, lens, labels in data_loader: + preds = model(input_ids, seg_ids, lengths=lens) + # Drop CLS prediction + preds = [pred[1:] for pred in preds.numpy()] + all_preds.append(preds) + all_lens.append(lens) + sentences = [example[0] for example in ds.data] + results = parse_decodes(sentences, all_preds, all_lens, label_vocab) + return results + + +if __name__ == '__main__': + paddle.set_device('gpu') + + # Create dataset, tokenizer and dataloader. + train_ds, dev_ds, test_ds = load_dataset(datafiles=( + './data/train.txt', './data/dev.txt', './data/test.txt')) + + label_vocab = load_dict('./data/tag.dic') + tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') + + trans_func = partial( + convert_ernie_example, tokenizer=tokenizer, label_vocab=label_vocab) + + train_ds.map(trans_func) + dev_ds.map(trans_func) + test_ds.map(trans_func) + + ignore_label = -1 + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'), # input_ids + Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'), # token_type_ids + Stack(dtype='int64'), # seq_len + Pad(axis=0, pad_val=ignore_label, dtype='int64') # labels + ): fn(samples) + + train_loader = paddle.io.DataLoader( + dataset=train_ds, + batch_size=200, + return_list=True, + collate_fn=batchify_fn) + dev_loader = paddle.io.DataLoader( + dataset=dev_ds, + batch_size=200, + return_list=True, + collate_fn=batchify_fn) + test_loader = paddle.io.DataLoader( + dataset=test_ds, + batch_size=200, + return_list=True, + collate_fn=batchify_fn) + + # Define the model netword and its loss + ernie = ErnieForTokenClassification.from_pretrained( + "ernie-1.0", num_classes=len(label_vocab)) + model = ErnieCrfForTokenClassification(ernie) + + metric = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True) + optimizer = paddle.optimizer.AdamW( + learning_rate=2e-5, parameters=model.parameters()) + + step = 0 + for epoch in range(10): + # Switch the model to training mode + model.train() + for idx, (input_ids, token_type_ids, lengths, + labels) in enumerate(train_loader): + loss = model( + input_ids, token_type_ids, lengths=lengths, labels=labels) + avg_loss = paddle.mean(loss) + avg_loss.backward() + optimizer.step() + optimizer.clear_grad() + step += 1 + print("epoch:%d - step:%d - loss: %f" % (epoch, step, avg_loss)) + evaluate(model, metric, dev_loader) + + paddle.save(model.state_dict(), + './ernie_crf_result/model_%d.pdparams' % step) + + preds = predict(model, test_loader, test_ds, label_vocab) + file_path = "ernie_crf_results.txt" + with open(file_path, 
"w", encoding="utf8") as fout: + fout.write("\n".join(preds)) + # Print some examples + print( + "The results have been saved in the file: %s, some examples are shown below: " + % file_path) + print("\n".join(preds[:10])) diff --git a/examples/lexical_analysis/train.py b/examples/lexical_analysis/train.py index 01aa579178a89..833e2529a6efb 100644 --- a/examples/lexical_analysis/train.py +++ b/examples/lexical_analysis/train.py @@ -134,8 +134,8 @@ def train(args): train_reader_cost = 0.0 train_run_cost = 0.0 total_samples = 0 + reader_start = time.time() for epoch in range(args.epochs): - reader_start = time.time() for step, batch in enumerate(train_loader): train_reader_cost += time.time() - reader_start global_step += 1 @@ -165,6 +165,7 @@ def train(args): paddle.save(model.state_dict(), os.path.join(args.model_save_dir, "model_%d.pdparams" % global_step)) + reader_start = time.time() if __name__ == "__main__": diff --git a/paddlenlp/layers/crf.py b/paddlenlp/layers/crf.py index 63a2012dc0d24..2f7840018bd39 100644 --- a/paddlenlp/layers/crf.py +++ b/paddlenlp/layers/crf.py @@ -48,6 +48,8 @@ def __init__(self, num_labels, crf_lr=0.1, with_start_stop_tag=True): attr=paddle.ParamAttr(learning_rate=crf_lr), shape=[self.num_tags, self.num_tags], dtype='float32') + with paddle.no_grad(): + self.flattened_transition_params = paddle.flatten(self.transitions) self.with_start_stop_tag = with_start_stop_tag self._initial_alpha = None @@ -211,10 +213,9 @@ def _trans_score(self, labels, lengths): # Encode the indices in a flattened representation. transition_indices = start_tag_indices * self.num_tags + stop_tag_indices flattened_transition_indices = transition_indices.reshape([-1]) - flattened_transition_params = self.transitions.reshape([-1]) scores = paddle.gather( - flattened_transition_params, + self.flattened_transition_params, flattened_transition_indices).reshape([batch_size, -1]) mask_scores = scores * mask[:, 1:]