Commit

[cherry-pick] PART-2 : Add Dataset and Module for Sequence Classification task of Ernie model (#945)

* Support ERNIE Export (#934)

* [Ernie] PART-2 : Add Dataset and Module for Sequence Classification task of Ernie model (#935)

* add ernie export yaml and shell

* update

* add dataset and module

* add tokenizer

* add training/validation step

Co-authored-by: Chang Xu <molixu7@gmail.com>
haohongxiang and RachelXu7 authored Nov 28, 2022
1 parent f787795 commit 6ffc799
Showing 15 changed files with 706 additions and 22 deletions.
37 changes: 37 additions & 0 deletions ppfleetx/configs/nlp/ernie/finetune_ernie_345M_single_card.yaml
@@ -0,0 +1,37 @@
_base_: ./finetune_ernie_base.yaml

Global:
  global_batch_size:
  local_batch_size: 8
  micro_batch_size: 8


Model:
  vocab_size: 40000
  hidden_size: 1024
  num_hidden_layers: 24
  num_attention_heads: 16
  intermediate_size:
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 512
  type_vocab_size: 4
  initializer_range: 0.02
  pad_token_id: 0
  task_type_vocab_size: 3
  task_id: 0
  use_task_id: True
  use_recompute: False


Distributed:
  dp_degree:
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
    reduce_overlap: False
    broadcast_overlap: False
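The three batch-size fields are related: the local (per-device) batch size is the micro batch size times the gradient-accumulation steps, and the effective global batch size is the local batch size scaled by the data-parallel degree. A minimal sketch of that bookkeeping, assuming the usual fleet-style convention (this helper is illustrative, not part of the commit):

def resolve_batch_sizes(micro_bs, accumulate_steps=1, dp_degree=1):
    """Illustrative helper: derive local/global batch sizes the way
    fleet-style configs usually do (assumed convention, not from this commit)."""
    local_bs = micro_bs * accumulate_steps
    global_bs = local_bs * dp_degree
    return local_bs, global_bs

# For this single-card config: micro=8, accumulate_steps=1 (from the base
# yaml), dp_degree unset (treated as 1) -> local=8, global=8.
assert resolve_batch_sizes(8) == (8, 8)
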
107 changes: 107 additions & 0 deletions ppfleetx/configs/nlp/ernie/finetune_ernie_base.yaml
@@ -0,0 +1,107 @@
Global:
  device: gpu
  seed: 1024
  binary_head: True

  global_batch_size:
  local_batch_size: 16
  micro_batch_size: 16


Engine:
  max_steps: 500000
  num_train_epochs: 1
  accumulate_steps: 1
  logging_freq: 1
  eval_freq: 500000
  eval_iters: 10
  test_iters: -1
  mix_precision:
    use_pure_fp16: False
    scale_loss: 32768.0
    custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
    custom_white_list: ["lookup_table", "lookup_table_v2"]
  save_load:
    save_steps: 50000
    save_epoch: 1
    output_dir: ./output
    ckpt_dir:


Model:
  module: "ErnieSeqClsModule"
  name: "Ernie"
  hidden_size: 768
  num_hidden_layers: 12
  num_attention_heads: 12
  intermediate_size: 3072
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 512
  type_vocab_size: 2
  initializer_range: 0.02
  pad_token_id: 0
  task_type_vocab_size: 3
  task_id: 0
  use_task_id: False
  use_recompute: False


Data:
  Train:
    dataset:
      name: ErnieSeqClsDataset
      dataset_type: chnsenticorp_v2
      tokenizer_type: ernie-1.0-base-zh-cw
      max_seq_len: 512
    sampler:
      name: GPTBatchSampler
      shuffle: False
      drop_last: True
    loader:
      num_workers: 0
      return_list: False
      collate_fn:
        name: DataCollatorWithPadding

  Eval:
    dataset:
      name: ErnieSeqClsDataset
      dataset_type: chnsenticorp_v2
      tokenizer_type: ernie-1.0-base-zh-cw
      max_seq_len: 512
    sampler:
      name: GPTBatchSampler
      shuffle: False
      drop_last: True
    loader:
      num_workers: 0
      return_list: False
      collate_fn:
        name: DataCollatorWithPadding


Optimizer:
  name: FusedAdamW
  weight_decay: 0.01
  beta1: 0.9
  beta2: 0.999
  epsilon: 1.0e-8
  lr:
    name: CosineAnnealingWithWarmupDecay
    decay_steps: 990000
    warmup_rate: 0.01
    max_lr: 0.0001
    min_lr: 5e-05
  grad_clip:
    name: "ClipGradByGlobalNorm"
    clip_norm: 1.0
  tensor_fusion: False


Profiler:
  enable: False
  scheduler: [1, 5]
  profiler_log: profiler_log
  detailed: False
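The `_base_:` key at the top of the single-card yaml means that file's values are overlaid onto this base config. A minimal sketch of how such inheritance can be resolved, assuming a simple recursive deep-merge (an illustrative loader; the actual ppfleetx config loader may differ):

import os
import yaml  # requires PyYAML


def _deep_update(dst, src):
    """Recursively overlay `src` onto `dst`: nested dicts merge, scalars replace."""
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            _deep_update(dst[key], value)
        else:
            dst[key] = value
    return dst


def load_config(path):
    """Illustrative `_base_:` resolver (not the real ppfleetx implementation)."""
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    base_ref = cfg.pop("_base_", None)
    if base_ref is None:
        return cfg
    base_cfg = load_config(os.path.join(os.path.dirname(path), base_ref))
    return _deep_update(base_cfg, cfg)

# cfg = load_config("ppfleetx/configs/nlp/ernie/finetune_ernie_345M_single_card.yaml")
# cfg["Model"]["hidden_size"]  -> 1024 (overridden by the child yaml)
# cfg["Model"]["module"]       -> "ErnieSeqClsModule" (inherited from this base)
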
18 changes: 18 additions & 0 deletions ppfleetx/configs/nlp/ernie/inference_ernie_345M_single_card.yaml
@@ -0,0 +1,18 @@
_base_: ./pretrain_ernie_base_345M_single_card.yaml


Inference:
  model_dir: ./output
  mp_degree: 1


Distributed:
  dp_degree:
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
    reduce_overlap: False
    broadcast_overlap: False
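`Inference.model_dir` points at the directory where the exported model is expected (the export step comes from the cherry-picked #934). A hedged sketch of consuming such an export with the standard Paddle Inference API; the artifact file names under ./output are assumptions, the actual export script determines them:

from paddle import inference

# Assumed artifact names; the export step decides the real ones.
config = inference.Config("./output/model.pdmodel", "./output/model.pdiparams")
config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool on device 0
predictor = inference.create_predictor(config)
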
2 changes: 1 addition & 1 deletion ppfleetx/data/dataset/__init__.py
@@ -21,4 +21,4 @@
from .multimodal_dataset import ImagenDataset
from .gpt_dataset import GPTDataset, LM_Eval_Dataset, Lambada_Eval_Dataset
from .glue_dataset import *
from .ernie.ernie_dataset import ErnieDataset
from .ernie.ernie_dataset import ErnieDataset, ErnieSeqClsDataset
158 changes: 158 additions & 0 deletions ppfleetx/data/dataset/ernie/ernie_dataset.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import numpy as np
import re
import copy
from functools import partial
import paddle

from .dataset_utils import (
@@ -29,6 +30,7 @@
make_indexed_dataset,
get_indexed_dataset_, )
from paddlenlp.transformers import ErnieTokenizer
from paddlenlp.datasets.dataset import MapDataset, IterableDataset, SimpleBuilder, load_dataset


def get_local_rank():
@@ -38,6 +40,7 @@ def get_local_rank():
print_rank_0 = print

mode_to_index = {"Train": 0, "Eval": 1, "Test": 2}
mode_to_key = {"Train": "train", "Eval": "dev", "Test": "test"}


class ErnieDataset(paddle.io.Dataset):
@@ -319,3 +322,158 @@ def get_train_valid_test_split_(splits, size):
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index


class ErnieSeqClsDataset(paddle.io.Dataset):
    def __init__(self, dataset_type, tokenizer_type, max_seq_len, mode):
        self.dataset = dataset_type
        self.max_seq_len = max_seq_len
        self.mode = mode_to_key[mode]

        from ppfleetx.data.tokenizers import get_ernie_tokenizer
        self.tokenizer = get_ernie_tokenizer(tokenizer_type)

        # `dataset_type` may carry an optional subset name, e.g. "clue tnews".
        dataset_config = self.dataset.split(" ")
        raw_datasets = load_dataset(
            dataset_config[0],
            None if len(dataset_config) <= 1 else dataset_config[1], )
        self.label_list = getattr(raw_datasets['train'], "label_list", None)

        # Define the dataset pre-process function
        if "clue" in self.dataset:
            trans_fn = partial(self._clue_trans_fn)
        else:
            trans_fn = partial(self._seq_trans_fn)

        self.seqcls_dataset = raw_datasets[self.mode].map(trans_fn)

    def __getitem__(self, idx):
        return self.seqcls_dataset.__getitem__(idx)

    def __len__(self):
        return self.seqcls_dataset.__len__()

    def _seq_trans_fn(self, example):
        return self._convert_example(
            example,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_len, )

    def _clue_trans_fn(self, example):
        return self._convert_clue(
            example,
            label_list=self.label_list,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_len, )

    def _convert_example(self,
                         example,
                         tokenizer,
                         max_seq_length=512,
                         is_test=False):
        # A sample without a label is treated as a test sample, regardless
        # of what the caller passed for `is_test`.
        is_test = 'label' not in example.keys()

        if "text_b" in example.keys():
            text = example["text_a"]
            text_pair = example["text_b"]
        else:
            text = example["text"]
            text_pair = None

        encoded_inputs = tokenizer(
            text=text, text_pair=text_pair, max_seq_len=max_seq_length)
        input_ids = encoded_inputs["input_ids"]
        token_type_ids = encoded_inputs["token_type_ids"]

        if is_test:
            return {
                "input_ids": input_ids,
                "token_type_ids": token_type_ids,
            }
        else:
            label = int(example["label"])
            return {
                "input_ids": input_ids,
                "token_type_ids": token_type_ids,
                "labels": label
            }

    # Data pre-process function for CLUE benchmark datasets
    def _convert_clue(self,
                      example,
                      label_list,
                      tokenizer=None,
                      max_seq_length=512,
                      **kwargs):
        """Convert a CLUE example into the necessary features."""
        is_test = 'label' not in example.keys()

        if not is_test:
            # `label_list == None` is for regression tasks
            label_dtype = "int64" if label_list else "float32"
            # Get the label
            example['label'] = int(example[
                "label"]) if label_dtype != "float32" else float(example[
                    "label"])
            label = example['label']
        # Convert raw text to features
        if 'keyword' in example:  # CSL
            sentence1 = " ".join(example['keyword'])
            example = {
                'sentence1': sentence1,
                'sentence2': example['abst'],
                'label': example['label']
            }
        elif 'target' in example:  # WSC
            text, query, pronoun, query_idx, pronoun_idx = example[
                'text'], example['target']['span1_text'], example['target'][
                    'span2_text'], example['target']['span1_index'], example[
                        'target']['span2_index']
            text_list = list(text)
            assert text[pronoun_idx:(pronoun_idx + len(
                pronoun))] == pronoun, "pronoun: {}".format(pronoun)
            assert text[query_idx:(query_idx + len(query)
                                   )] == query, "query: {}".format(query)
            # Mark the pronoun with [ ] and the query span with _ _ in the text.
            if pronoun_idx > query_idx:
                text_list.insert(query_idx, "_")
                text_list.insert(query_idx + len(query) + 1, "_")
                text_list.insert(pronoun_idx + 2, "[")
                text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
            else:
                text_list.insert(pronoun_idx, "[")
                text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
                text_list.insert(query_idx + 2, "_")
                text_list.insert(query_idx + len(query) + 2 + 1, "_")
            text = "".join(text_list)
            example['sentence'] = text

        if tokenizer is None:
            return example
        if 'sentence' in example:
            example = tokenizer(
                example['sentence'], max_seq_len=max_seq_length)
        elif 'sentence1' in example:
            example = tokenizer(
                example['sentence1'],
                text_pair=example['sentence2'],
                max_seq_len=max_seq_length)

        if not is_test:
            if "token_type_ids" in example:
                return {
                    "input_ids": example['input_ids'],
                    "token_type_ids": example['token_type_ids'],
                    "labels": label
                }
            else:
                return {"input_ids": example['input_ids'], "labels": label}
        else:
            return {
                "input_ids": example['input_ids'],
                "token_type_ids": example['token_type_ids']
            }
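Put together, the new dataset resolves the PaddleNLP dataset named by `dataset_type`, tokenizes the requested split with the shared ERNIE tokenizer, and exposes plain dict samples. A small usage sketch matching the YAML configs above (illustrative, not part of the commit; chnsenticorp_v2 is downloaded on first use):

from ppfleetx.data.dataset import ErnieSeqClsDataset

train_ds = ErnieSeqClsDataset(
    dataset_type="chnsenticorp_v2",
    tokenizer_type="ernie-1.0-base-zh-cw",
    max_seq_len=512,
    mode="Train")

sample = train_ds[0]
print(len(train_ds), sample["input_ids"][:8], sample["labels"])
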
1 change: 1 addition & 0 deletions ppfleetx/data/tokenizers/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

from .gpt_tokenizer import GPTTokenizer
from .ernie_tokenizer import get_ernie_tokenizer
25 changes: 25 additions & 0 deletions ppfleetx/data/tokenizers/ernie_tokenizer.py
@@ -0,0 +1,25 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlenlp.transformers import ErnieTokenizer

tokenizer = None


def get_ernie_tokenizer(tokenizer_type):
    # Lazily build a process-wide singleton; `tokenizer_type` is only
    # honored on the first call, later calls return the cached instance.
    global tokenizer
    if tokenizer is None:
        tokenizer = ErnieTokenizer.from_pretrained(tokenizer_type)

    return tokenizer
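Because the module caches a single tokenizer instance, repeated calls are cheap; a side effect worth knowing is that a second call with a different `tokenizer_type` still returns the first tokenizer. A quick illustration of that caching behavior (the second model name is only an example):

from ppfleetx.data.tokenizers import get_ernie_tokenizer

tok_a = get_ernie_tokenizer("ernie-1.0-base-zh-cw")
tok_b = get_ernie_tokenizer("ernie-3.0-medium-zh")  # type is ignored here
assert tok_a is tok_b  # the cached instance wins
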
