From 6ffc799847a5ddaca6085e9e2e96a265ac68d835 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 28 Nov 2022 11:13:08 +0800 Subject: [PATCH] [cherry-pick] PART-2 : Add Dataset and Module for Sequence Classification task of Ernie model (#945) * Support ERNIE Export (#934) * [Ernie] PART-2 : Add Dataset and Module for Sequence Classification task of Ernie model (#935) * add ernie export yaml and shell * update * add dataset and module * add tokenizer * add training/valuation step Co-authored-by: Chang Xu --- .../finetune_ernie_345M_single_card.yaml | 37 ++++ .../nlp/ernie/finetune_ernie_base.yaml | 107 +++++++++++ .../inference_ernie_345M_single_card.yaml | 18 ++ ppfleetx/data/dataset/__init__.py | 2 +- ppfleetx/data/dataset/ernie/ernie_dataset.py | 158 ++++++++++++++++ ppfleetx/data/tokenizers/__init__.py | 1 + ppfleetx/data/tokenizers/ernie_tokenizer.py | 25 +++ ppfleetx/data/utils/batch_collate_fn.py | 42 +++++ ppfleetx/models/__init__.py | 2 +- .../models/language_model/ernie/__init__.py | 2 +- .../ernie/dygraph/hybrid_model.py | 125 ++++++++++++- .../ernie/dygraph/single_model.py | 8 +- .../language_model/ernie/ernie_module.py | 169 +++++++++++++++++- .../ernie/layers/transformer.py | 13 +- .../ernie/export_ernie_345M_single_card.sh | 19 ++ 15 files changed, 706 insertions(+), 22 deletions(-) create mode 100644 ppfleetx/configs/nlp/ernie/finetune_ernie_345M_single_card.yaml create mode 100644 ppfleetx/configs/nlp/ernie/finetune_ernie_base.yaml create mode 100644 ppfleetx/configs/nlp/ernie/inference_ernie_345M_single_card.yaml create mode 100644 ppfleetx/data/tokenizers/ernie_tokenizer.py create mode 100644 projects/ernie/export_ernie_345M_single_card.sh diff --git a/ppfleetx/configs/nlp/ernie/finetune_ernie_345M_single_card.yaml b/ppfleetx/configs/nlp/ernie/finetune_ernie_345M_single_card.yaml new file mode 100644 index 000000000..9488ac904 --- /dev/null +++ b/ppfleetx/configs/nlp/ernie/finetune_ernie_345M_single_card.yaml @@ -0,0 +1,37 @@ +_base_: ./finetune_ernie_base.yaml + +Global: + global_batch_size: + local_batch_size: 8 + micro_batch_size: 8 + + +Model: + vocab_size: 40000 + hidden_size: 1024 + num_hidden_layers: 24 + num_attention_heads: 16 + intermediate_size: + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + max_position_embeddings: 512 + type_vocab_size: 4 + initializer_range: 0.02 + pad_token_id: 0 + task_type_vocab_size: 3 + task_id: 0 + use_task_id: True + use_recompute: False + + +Distributed: + dp_degree: + mp_degree: 1 + pp_degree: 1 + sharding: + sharding_degree: 1 + sharding_stage: 1 + sharding_offload: False + reduce_overlap: False + broadcast_overlap: False diff --git a/ppfleetx/configs/nlp/ernie/finetune_ernie_base.yaml b/ppfleetx/configs/nlp/ernie/finetune_ernie_base.yaml new file mode 100644 index 000000000..02793fa85 --- /dev/null +++ b/ppfleetx/configs/nlp/ernie/finetune_ernie_base.yaml @@ -0,0 +1,107 @@ +Global: + device: gpu + seed: 1024 + binary_head: True + + global_batch_size: + local_batch_size: 16 + micro_batch_size: 16 + + +Engine: + max_steps: 500000 + num_train_epochs: 1 + accumulate_steps: 1 + logging_freq: 1 + eval_freq: 500000 + eval_iters: 10 + test_iters: -1 + mix_precision: + use_pure_fp16: False + scale_loss: 32768.0 + custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"] + custom_white_list: ["lookup_table", "lookup_table_v2"] + save_load: + save_steps: 50000 + save_epoch: 1 + output_dir: ./output + ckpt_dir: + + +Model: + 
module: "ErnieSeqClsModule" + name: "Ernie" + hidden_size: 768 + num_hidden_layers: 12 + num_attention_heads: 12 + intermediate_size: 3072 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + max_position_embeddings: 512 + type_vocab_size: 2 + initializer_range: 0.02 + pad_token_id: 0 + task_type_vocab_size: 3 + task_id: 0 + use_task_id: False + use_recompute: False + + +Data: + Train: + dataset: + name: ErnieSeqClsDataset + dataset_type: chnsenticorp_v2 + tokenizer_type: ernie-1.0-base-zh-cw + max_seq_len: 512 + sampler: + name: GPTBatchSampler + shuffle: False + drop_last: True + loader: + num_workers: 0 + return_list: False + collate_fn: + name: DataCollatorWithPadding + + Eval: + dataset: + name: ErnieSeqClsDataset + dataset_type: chnsenticorp_v2 + tokenizer_type: ernie-1.0-base-zh-cw + max_seq_len: 512 + sampler: + name: GPTBatchSampler + shuffle: False + drop_last: True + loader: + num_workers: 0 + return_list: False + collate_fn: + name: DataCollatorWithPadding + + +Optimizer: + name: FusedAdamW + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + epsilon: 1.0e-8 + lr: + name: CosineAnnealingWithWarmupDecay + decay_steps: 990000 + warmup_rate: 0.01 + max_lr: 0.0001 + min_lr: 5e-05 + grad_clip: + name: "ClipGradByGlobalNorm" + clip_norm: 1.0 + tensor_fusion: False + + +Profiler: + enable: False + scheduler: [1, 5] + profiler_log: profiler_log + detailed: False diff --git a/ppfleetx/configs/nlp/ernie/inference_ernie_345M_single_card.yaml b/ppfleetx/configs/nlp/ernie/inference_ernie_345M_single_card.yaml new file mode 100644 index 000000000..79929d69f --- /dev/null +++ b/ppfleetx/configs/nlp/ernie/inference_ernie_345M_single_card.yaml @@ -0,0 +1,18 @@ +_base_: ./pretrain_ernie_base_345M_single_card.yaml + + +Inference: + model_dir: ./output + mp_degree: 1 + + +Distributed: + dp_degree: + mp_degree: 1 + pp_degree: 1 + sharding: + sharding_degree: 1 + sharding_stage: 1 + sharding_offload: False + reduce_overlap: False + broadcast_overlap: False diff --git a/ppfleetx/data/dataset/__init__.py b/ppfleetx/data/dataset/__init__.py index 27ecf2c11..43356a82b 100644 --- a/ppfleetx/data/dataset/__init__.py +++ b/ppfleetx/data/dataset/__init__.py @@ -21,4 +21,4 @@ from .multimodal_dataset import ImagenDataset from .gpt_dataset import GPTDataset, LM_Eval_Dataset, Lambada_Eval_Dataset from .glue_dataset import * -from .ernie.ernie_dataset import ErnieDataset +from .ernie.ernie_dataset import ErnieDataset, ErnieSeqClsDataset diff --git a/ppfleetx/data/dataset/ernie/ernie_dataset.py b/ppfleetx/data/dataset/ernie/ernie_dataset.py index 370bc5329..aebe9d983 100644 --- a/ppfleetx/data/dataset/ernie/ernie_dataset.py +++ b/ppfleetx/data/dataset/ernie/ernie_dataset.py @@ -18,6 +18,7 @@ import numpy as np import re import copy +from functools import partial import paddle from .dataset_utils import ( @@ -29,6 +30,7 @@ make_indexed_dataset, get_indexed_dataset_, ) from paddlenlp.transformers import ErnieTokenizer +from paddlenlp.datasets.dataset import MapDataset, IterableDataset, SimpleBuilder, load_dataset def get_local_rank(): @@ -38,6 +40,7 @@ def get_local_rank(): print_rank_0 = print mode_to_index = {"Train": 0, "Eval": 1, "Test": 2} +mode_to_key = {"Train": "train", "Eval": "dev", "Test": "test"} class ErnieDataset(paddle.io.Dataset): @@ -319,3 +322,158 @@ def get_train_valid_test_split_(splits, size): assert len(splits_index) == 4 assert splits_index[-1] == size return splits_index + + +class ErnieSeqClsDataset(paddle.io.Dataset): + def __init__(self, dataset_type, 
tokenizer_type, max_seq_len, mode): + self.dataset = dataset_type + self.max_seq_len = max_seq_len + self.mode = mode_to_key[mode] + + from ppfleetx.data.tokenizers import get_ernie_tokenizer + self.tokenizer = get_ernie_tokenizer(tokenizer_type) + + dataset_config = self.dataset.split(" ") + raw_datasets = load_dataset( + dataset_config[0], + None if len(dataset_config) <= 1 else dataset_config[1], ) + self.label_list = getattr(raw_datasets['train'], "label_list", None) + + # Define dataset pre-process function + if "clue" in self.dataset: + trans_fn = partial(self._clue_trans_fn) + else: + trans_fn = partial(self._seq_trans_fn) + + self.seqcls_dataset = raw_datasets[self.mode].map(trans_fn) + + def __getitem__(self, idx): + return self.seqcls_dataset.__getitem__(idx) + + def __len__(self): + return self.seqcls_dataset.__len__() + + def _seq_trans_fn(self, example): + return self._convert_example( + example, + tokenizer=self.tokenizer, + max_seq_length=self.max_seq_len, ) + + def _clue_trans_fn(self, example): + return self._convert_clue( + example, + label_list=self.label_list, + tokenizer=self.tokenizer, + max_seq_length=self.max_seq_len, ) + + def _convert_example(self, + example, + tokenizer, + max_seq_length=512, + is_test=False): + is_test = True + if 'label' in example.keys(): + is_test = False + + if "text_b" in example.keys(): + text = example["text_a"] + text_pair = example["text_b"] + else: + text = example["text"] + text_pair = None + + encoded_inputs = tokenizer( + text=text, text_pair=text_pair, max_seq_len=max_seq_length) + input_ids = encoded_inputs["input_ids"] + token_type_ids = encoded_inputs["token_type_ids"] + + if is_test: + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + } + else: + # label = np.array([example["label"]], dtype="int64") + label = int(example["label"]) + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "labels": label + } + + # Data pre-process function for clue benchmark datatset + def _convert_clue(self, + example, + label_list, + tokenizer=None, + max_seq_length=512, + **kwargs): + """convert a glue example into necessary features""" + is_test = False + if 'label' not in example.keys(): + is_test = True + + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + example['label'] = int(example[ + "label"]) if label_dtype != "float32" else float(example[ + "label"]) + label = example['label'] + # Convert raw text to feature + if 'keyword' in example: # CSL + sentence1 = " ".join(example['keyword']) + example = { + 'sentence1': sentence1, + 'sentence2': example['abst'], + 'label': example['label'] + } + elif 'target' in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = example[ + 'text'], example['target']['span1_text'], example['target'][ + 'span2_text'], example['target']['span1_index'], example[ + 'target']['span2_index'] + text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len( + pronoun))] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + 
text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example['sentence'] = text + + if tokenizer is None: + return example + if 'sentence' in example: + example = tokenizer( + example['sentence'], max_seq_len=max_seq_length) + elif 'sentence1' in example: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + + if not is_test: + if "token_type_ids" in example: + return { + "input_ids": example['input_ids'], + "token_type_ids": example['token_type_ids'], + "labels": label + } + else: + return {"input_ids": example['input_ids'], "labels": label} + else: + return { + "input_ids": example['input_ids'], + "token_type_ids": example['token_type_ids'] + } diff --git a/ppfleetx/data/tokenizers/__init__.py b/ppfleetx/data/tokenizers/__init__.py index b01b642a3..e5779db4d 100644 --- a/ppfleetx/data/tokenizers/__init__.py +++ b/ppfleetx/data/tokenizers/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .gpt_tokenizer import GPTTokenizer +from .ernie_tokenizer import get_ernie_tokenizer diff --git a/ppfleetx/data/tokenizers/ernie_tokenizer.py b/ppfleetx/data/tokenizers/ernie_tokenizer.py new file mode 100644 index 000000000..f1e0a3114 --- /dev/null +++ b/ppfleetx/data/tokenizers/ernie_tokenizer.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddlenlp.transformers import ErnieTokenizer + +tokenizer = None + + +def get_ernie_tokenizer(tokenizer_type): + global tokenizer + if tokenizer is None: + tokenizer = ErnieTokenizer.from_pretrained(tokenizer_type) + + return tokenizer diff --git a/ppfleetx/data/utils/batch_collate_fn.py b/ppfleetx/data/utils/batch_collate_fn.py index c75b612d6..eca743ba4 100644 --- a/ppfleetx/data/utils/batch_collate_fn.py +++ b/ppfleetx/data/utils/batch_collate_fn.py @@ -18,6 +18,7 @@ import sys import numbers import numpy as np +from dataclasses import dataclass try: from collections.abc import Sequence, Mapping @@ -145,6 +146,47 @@ def __call__(self, data): return all_data +@dataclass +class DataCollatorWithPadding: + """ + Data collator that will dynamically pad the inputs to the longest sequence in the batch. + + Args: + tokenizer_type (str): The type of tokenizer used for encoding the data. 
+ """ + + def __init__(self, + tokenizer_type, + padding=True, + max_length=None, + pad_to_multiple_of=None, + return_tensors="pd", + return_attention_mask=None): + from ppfleetx.data.tokenizers import get_ernie_tokenizer + self.tokenizer = get_ernie_tokenizer(tokenizer_type) + self.padding = padding + self.max_length = max_length + self.pad_to_multiple_of = pad_to_multiple_of + self.return_tensors = return_tensors + self.return_attention_mask = return_attention_mask + + def __call__(self, features): + batch = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + return_attention_mask=self.return_attention_mask) + if "label" in batch: + batch["labels"] = batch["label"] + del batch["label"] + if "label_ids" in batch: + batch["labels"] = batch["label_ids"] + del batch["label_ids"] + return batch + + def imagen_collate_fn(batch): """ collate for imagen base64 """ text_embs = [] diff --git a/ppfleetx/models/__init__.py b/ppfleetx/models/__init__.py index 60417e9fe..483697c3a 100644 --- a/ppfleetx/models/__init__.py +++ b/ppfleetx/models/__init__.py @@ -21,7 +21,7 @@ from ppfleetx.models.vision_model.general_classification_module import GeneralClsModule from ppfleetx.models.vision_model.moco_module import MOCOModule, MOCOClsModule from ppfleetx.models.multimodal_model.multimodal_module import ImagenModule -from ppfleetx.models.language_model.ernie import ErnieModule +from ppfleetx.models.language_model.ernie import ErnieModule, ErnieSeqClsModule from ppfleetx.models.language_model.language_module import MoEModule from ppfleetx.models.multimodal_model.multimodal_module import ImagenModule diff --git a/ppfleetx/models/language_model/ernie/__init__.py b/ppfleetx/models/language_model/ernie/__init__.py index 14ebfae6b..336cbb6ef 100644 --- a/ppfleetx/models/language_model/ernie/__init__.py +++ b/ppfleetx/models/language_model/ernie/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .ernie_module import ErnieModule +from .ernie_module import ErnieModule, ErnieSeqClsModule diff --git a/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py b/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py index ceea28c1f..09dbc78b8 100644 --- a/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py +++ b/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py @@ -26,7 +26,8 @@ from ..layers.model_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput, - ErnieForPreTrainingOutput, ) + ErnieForPreTrainingOutput, + SequenceClassifierOutput, ) from ..layers.distributed_transformer import TransformerEncoderLayer, TransformerEncoder from paddle.distributed import fleet @@ -242,6 +243,7 @@ def __init__(self, self.hidden_size = hidden_size self.vocab_size = vocab_size self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob weight_attr = paddle.ParamAttr( initializer=nn.initializer.TruncatedNormal( @@ -867,3 +869,124 @@ def __init__(self, "offload": False, "partition": False }) + + +class ErnieForSequenceClassificationHybrid(nn.Layer): + """ + Ernie Model with a linear layer on top of the output layer, + designed for sequence classification/regression tasks like GLUE tasks. + + Args: + ernie (:class:`ErnieModel`): + An instance of ErnieModel. + num_classes (int, optional): + The number of classes. Defaults to `2`. 
+ dropout (float, optional): + The dropout probability for output of ERNIE. + If None, use the same value as `hidden_dropout_prob` of `ErnieModel` + instance `ernie`. Defaults to None. + """ + + def __init__(self, ernie, num_classes=2, dropout=None): + super(ErnieForSequenceClassificationHybrid, self).__init__() + self.num_classes = num_classes + self.ernie = ernie # allow ernie to be config + self.dropout = nn.Dropout(dropout if dropout is not None else + self.ernie.hidden_dropout_prob) + self.classifier = nn.Linear(self.ernie.hidden_size, num_classes) + self.apply(self.init_weights) + + def forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None, + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): + r""" + The ErnieForSequenceClassification forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`ErnieModelHybrid`. + token_type_ids (Tensor, optional): + See :class:`ErnieModelHybrid`. + position_ids(Tensor, optional): + See :class:`ErnieModelHybrid`. + attention_mask (Tensor, optional): + See :class:`ErnieModelHybrid`. + labels (Tensor of shape `(batch_size,)`, optional): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1` + a regression loss is computed (Mean-Square loss), If `num_classes > 1` + a classification loss is computed (Cross-Entropy). + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~ppfleetx.models.language_model.ernie.layers.model_outputs.SequenceClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. + + Returns: + An instance of :class:`~ppfleetx.models.language_model.ernie.layers.model_outputs.SequenceClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~ppfleetx.models.language_model.ernie.layers.model_outputs.SequenceClassifierOutput`. 
+ + """ + + outputs = self.ernie( + input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_classes == 1: + loss_fct = paddle.nn.MSELoss() + loss = loss_fct(logits, labels) + elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct( + logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + else: + loss_fct = paddle.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else ( + output[0] if len(output) == 1 else output) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) + + def init_weights(self, layer): + """ Initialization hook """ + if isinstance(layer, (nn.Linear, nn.Embedding)): + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.initializer_range + if hasattr(self, "initializer_range") else + self.ernie.initializer_range, + shape=layer.weight.shape)) + elif isinstance(layer, nn.LayerNorm): + layer._epsilon = 1e-12 diff --git a/ppfleetx/models/language_model/ernie/dygraph/single_model.py b/ppfleetx/models/language_model/ernie/dygraph/single_model.py index 3e234a87e..45f033e0e 100644 --- a/ppfleetx/models/language_model/ernie/dygraph/single_model.py +++ b/ppfleetx/models/language_model/ernie/dygraph/single_model.py @@ -206,6 +206,7 @@ def __init__(self, self.hidden_size = hidden_size self.vocab_size = vocab_size self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob weight_attr = paddle.ParamAttr( initializer=nn.initializer.TruncatedNormal( @@ -664,9 +665,8 @@ def __init__(self, ernie, num_classes=2, dropout=None): self.num_classes = num_classes self.ernie = ernie # allow ernie to be config self.dropout = nn.Dropout(dropout if dropout is not None else - self.ernie.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.ernie.config["hidden_size"], - num_classes) + self.ernie.hidden_dropout_prob) + self.classifier = nn.Linear(self.ernie.hidden_size, num_classes) self.apply(self.init_weights) def forward(self, @@ -759,7 +759,7 @@ def init_weights(self, layer): mean=0.0, std=self.initializer_range if hasattr(self, "initializer_range") else - self.ernie.config["initializer_range"], + self.ernie.initializer_range, shape=layer.weight.shape)) elif isinstance(layer, nn.LayerNorm): layer._epsilon = 1e-12 diff --git a/ppfleetx/models/language_model/ernie/ernie_module.py b/ppfleetx/models/language_model/ernie/ernie_module.py index ab624459f..565e2752b 100644 --- a/ppfleetx/models/language_model/ernie/ernie_module.py +++ b/ppfleetx/models/language_model/ernie/ernie_module.py @@ -14,15 +14,25 @@ import sys import copy +from collections.abc import Mapping import paddle +from paddle.static import InputSpec +import paddle.nn as nn from ppfleetx.core.module.basic_module import BasicModule import ppfleetx.models.language_model.gpt as gpt from ppfleetx.utils.log import logger -from .dygraph.single_model import ErnieModel, ErnieForPretraining, ErniePretrainingCriterion -from .dygraph.hybrid_model 
import ErnieModelHybrid, ErnieForPretrainingHybrid, ErniePretrainingCriterionHybrid, ErnieForPretrainingPipe +from .dygraph.single_model import ( + ErnieModel, + ErnieForPretraining, + ErniePretrainingCriterion, + ErnieForSequenceClassification, ) +from .dygraph.hybrid_model import (ErnieModelHybrid, ErnieForPretrainingHybrid, + ErniePretrainingCriterionHybrid, + ErnieForPretrainingPipe, + ErnieForSequenceClassificationHybrid) from ppfleetx.models.language_model.utils import process_configs @@ -162,3 +172,158 @@ def training_step_end(self, log_dict): "ips_total: %.0f tokens/s, ips: %.0f tokens/s, learning rate: %.5e" % (log_dict['epoch'], log_dict['batch'], log_dict['loss'], log_dict['train_cost'], speed, speed * default_global_tokens_num, speed * default_global_tokens_num / self.nranks, log_dict['lr'])) + + def input_spec(self): + return [ + InputSpec( + shape=[None, None], dtype='int64'), InputSpec( + shape=[None, None], dtype='int64'), InputSpec( + shape=[None, None], dtype='int64') + ] + + +class ErnieSeqClsModule(BasicModule): + def __init__(self, configs): + self.nranks = paddle.distributed.get_world_size() + super(ErnieSeqClsModule, self).__init__(configs) + + self.criterion = nn.loss.CrossEntropyLoss( + ) # if data_args.label_list else nn.loss.MSELoss() + + self.past_index = -1 + self.past = None + self.label_names = (["start_positions", "end_positions"] \ + if "QusetionAnswering" in type(self.model).__name__ else ["labels"]) + + def process_configs(self, configs): + process_model_configs(configs) + + cfg_global = configs['Global'] + cfg_data = configs['Data'] + + for mode in ("Train", "Eval", "Test"): + if mode in cfg_data.keys(): + cfg_data[mode]['dataset']['mode'] = mode + cfg_data[mode]['sampler']['batch_size'] = cfg_global[ + 'local_batch_size'] + cfg_data[mode]['loader']['collate_fn'].setdefault( + 'tokenizer_type', + cfg_data[mode]['dataset']['tokenizer_type']) + + return configs + + def get_model(self): + model_setting = copy.deepcopy(self.configs.Model) + model_setting.pop("module") + model_setting.pop("name") + + if self.nranks > 1: + model_setting[ + 'num_partitions'] = self.configs.Distributed.mp_degree + + if self.configs.Distributed.pp_degree == 1: + model = ErnieForSequenceClassificationHybrid( + ErnieModelHybrid(**model_setting)) + else: + raise ValueError( + "Pipeline Parallelism is not supported in Sequence \ + Classification task of Ernie model.") + else: + model = ErnieForSequenceClassification(ErnieModel(**model_setting)) + + return model + + def prepare_input(self, data): + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. 
+ """ + if isinstance(data, Mapping): + return type(data)( + {k: self.prepare_input(v) + for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self.prepare_input(v) for v in data) + elif isinstance(data, paddle.Tensor): + # kwargs = dict(device=self.args.current_device) + # update data type for pure fp16 + return data + # return data.to(**kwargs) + return data + + def pretreating_batch(self, batch): + self.has_labels = all( + batch.get(k) is not None for k in self.label_names) + + batch = self.prepare_input(batch) + if self.past_index >= 0 and self.past is not None: + batch["mems"] = self.past + + return batch + + def forward(self, inputs): + return self.model(**inputs) + + def compute_loss(self, inputs, return_outputs=False): + if "labels" in inputs: + labels = inputs.pop("labels") + elif "start_positions" in inputs and "end_positions" in inputs: + labels = (inputs.pop("start_positions"), + inputs.pop("end_positions")) + elif "generator_labels" in inputs: + labels = inputs["generator_labels"] + else: + labels = None + outputs = self(inputs) + + loss = self.criterion(outputs, labels) + outputs = (loss, outputs) + + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.past_index >= 0: + self.past = outputs[self.args.past_index] + + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def training_step(self, batch): + return self.compute_loss(batch) + + def training_step_end(self, log_dict): + speed = 1. / log_dict['train_cost'] + default_global_tokens_num = self.configs.Global.global_batch_size * \ + self.configs.Data.Train.dataset.max_seq_len + + logger.info( + "[train] epoch: %d, batch: %d, loss: %.9f, avg_batch_cost: %.5f sec, speed: %.2f step/s, " \ + "ips_total: %.0f tokens/s, ips: %.0f tokens/s, learning rate: %.5e" + % (log_dict['epoch'], log_dict['batch'], log_dict['loss'], log_dict['train_cost'], speed, + speed * default_global_tokens_num, speed * default_global_tokens_num / self.nranks, log_dict['lr'])) + + def input_spec(self): + input_spec = [ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int64") # segment_ids + ] + return input_spec + + def validation_step(self, inputs): + if self.has_labels: + loss, outputs = self.compute_loss(inputs, return_outputs=True) + loss = loss.mean().detach() + + else: + loss = None + + return loss + + def validation_step_end(self, log_dict): + speed = 1. 
/ log_dict['eval_cost'] + logger.info( + "[eval] epoch: %d, batch: %d, loss: %.9f, avg_eval_cost: %.5f sec, speed: %.2f step/s" + % (log_dict['epoch'], log_dict['batch'], log_dict['loss'], + log_dict['eval_cost'], speed)) diff --git a/ppfleetx/models/language_model/ernie/layers/transformer.py b/ppfleetx/models/language_model/ernie/layers/transformer.py index f5f6f0b9d..7aa32f968 100644 --- a/ppfleetx/models/language_model/ernie/layers/transformer.py +++ b/ppfleetx/models/language_model/ernie/layers/transformer.py @@ -758,18 +758,7 @@ def forward(self, all_hidden_states[-1] = output if not return_dict: - outputs = tuple( - tuple(v) if isinstance(v, list) else v - for v in [ - output, - new_caches, - all_hidden_states, - all_attentions, - ] if v is not None) - if len(outputs) == 1: - return output - else: - return outputs + return output return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=output, diff --git a/projects/ernie/export_ernie_345M_single_card.sh b/projects/ernie/export_ernie_345M_single_card.sh new file mode 100644 index 000000000..4015d95e3 --- /dev/null +++ b/projects/ernie/export_ernie_345M_single_card.sh @@ -0,0 +1,19 @@ +#! /bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +export CUDA_VISIBLE_DEVICES=0 +python ./tools/export.py -c ./ppfleetx/configs/nlp/ernie/inference_ernie_345M_single_card.yaml
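
Reviewer note (not part of the patch): a minimal sketch of how the data pieces added above fit together outside the config-driven Engine. The "chnsenticorp_v2" dataset name, the "ernie-1.0-base-zh-cw" tokenizer name and max_seq_len=512 come from finetune_ernie_base.yaml; the hand-rolled batch and the direct call to the collator are illustrative assumptions, not code from this PR.

    import paddle

    from ppfleetx.data.dataset import ErnieSeqClsDataset
    from ppfleetx.data.utils.batch_collate_fn import DataCollatorWithPadding

    # ErnieSeqClsDataset loads chnsenticorp_v2 through paddlenlp's load_dataset and
    # maps every example to input_ids / token_type_ids (plus labels for train/dev).
    train_ds = ErnieSeqClsDataset(
        dataset_type="chnsenticorp_v2",
        tokenizer_type="ernie-1.0-base-zh-cw",
        max_seq_len=512,
        mode="Train")

    # DataCollatorWithPadding pads every field to the longest sequence in the batch
    # and renames "label"/"label_ids" to "labels", which ErnieSeqClsModule expects.
    collate_fn = DataCollatorWithPadding(tokenizer_type="ernie-1.0-base-zh-cw")

    features = [train_ds[i] for i in range(8)]  # hypothetical hand-built batch
    batch = collate_fn(features)
    print(sorted(batch.keys()))
    print(batch["input_ids"].shape)

In the config-driven path the same collator is created by ErnieSeqClsModule.process_configs, which copies the dataset's tokenizer_type into the loader's collate_fn settings.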
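Reviewer note (not part of the patch): a sketch of the single-card model path that ErnieSeqClsModule.get_model and compute_loss wire up. ErnieModel is built from the Model section of the YAML in the same way get_model does (module/name are popped, the rest is passed as keyword arguments); the sizes below are the finetune_ernie_base.yaml values plus the 40000-token vocab_size from the 345M config, and the random batch is purely illustrative.

    import paddle
    import paddle.nn as nn

    from ppfleetx.models.language_model.ernie.dygraph.single_model import (
        ErnieModel, ErnieForSequenceClassification)

    # Model keys from finetune_ernie_base.yaml (vocab_size from the 345M config);
    # ErnieSeqClsModule.get_model forwards these to ErnieModel as keyword arguments.
    model = ErnieForSequenceClassification(
        ErnieModel(
            vocab_size=40000,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            pad_token_id=0,
            task_type_vocab_size=3,
            task_id=0,
            use_task_id=False),
        num_classes=2)

    # Illustrative batch: token ids in [1, vocab_size), all tokens in segment 0.
    input_ids = paddle.randint(low=1, high=40000, shape=[2, 16], dtype="int64")
    token_type_ids = paddle.zeros([2, 16], dtype="int64")
    labels = paddle.to_tensor([0, 1], dtype="int64")

    # Mirrors ErnieSeqClsModule.compute_loss: the model output (logits) goes
    # straight into CrossEntropyLoss.
    logits = model(input_ids, token_type_ids)
    loss = nn.CrossEntropyLoss()(logits, labels)
    print(logits.shape, float(loss))

With mp_degree > 1 and pp_degree == 1 the module swaps in ErnieForSequenceClassificationHybrid / ErnieModelHybrid instead; pipeline parallelism is explicitly rejected for this task.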