diff --git a/examples/glue/run_glue.py b/examples/glue/run_glue.py index b2baee5f747c16..1ef08166d4436b 100644 --- a/examples/glue/run_glue.py +++ b/examples/glue/run_glue.py @@ -26,8 +26,8 @@ from paddle.io import DataLoader from paddle.metric import Metric, Accuracy, Precision, Recall -from paddlenlp.datasets import GlueCoLA, GlueSST2, GlueMRPC, GlueSTSB, GlueQQP, GlueMNLI, GlueQNLI, GlueRTE -from paddlenlp.data import Stack, Tuple, Pad +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad, Dict from paddlenlp.data.sampler import SamplerHelper from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer @@ -40,14 +40,14 @@ logger = logging.getLogger(__name__) TASK_CLASSES = { - "cola": (GlueCoLA, Mcc), - "sst-2": (GlueSST2, Accuracy), - "mrpc": (GlueMRPC, AccuracyAndF1), - "sts-b": (GlueSTSB, PearsonAndSpearman), - "qqp": (GlueQQP, AccuracyAndF1), - "mnli": (GlueMNLI, Accuracy), - "qnli": (GlueQNLI, Accuracy), - "rte": (GlueRTE, Accuracy), + "cola": Mcc, + "sst-2": Accuracy, + "mrpc": AccuracyAndF1, + "sts-b": PearsonAndSpearman, + "qqp": AccuracyAndF1, + "mnli": Accuracy, + "qnli": Accuracy, + "rte": Accuracy, } MODEL_CLASSES = { @@ -211,66 +211,25 @@ def convert_example(example, max_seq_length=512, is_test=False): """convert a glue example into necessary features""" - - def _truncate_seqs(seqs, max_seq_length): - if len(seqs) == 1: # single sentence - # Account for [CLS] and [SEP] with "- 2" - seqs[0] = seqs[0][0:(max_seq_length - 2)] - else: # Sentence pair - # Account for [CLS], [SEP], [SEP] with "- 3" - tokens_a, tokens_b = seqs - max_seq_length -= 3 - while True: # Truncate with longest_first strategy - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_seq_length: - break - if len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() - return seqs - - def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1): - concat = sum((seq + sep for sep, seq in zip(separators, seqs)), []) - segment_ids = sum( - ([i] * (len(seq) + len(sep)) - for i, (sep, seq) in enumerate(zip(separators, seqs))), []) - if isinstance(seq_mask, int): - seq_mask = [[seq_mask] * len(seq) for seq in seqs] - if isinstance(separator_mask, int): - separator_mask = [[separator_mask] * len(sep) for sep in separators] - p_mask = sum((s_mask + mask - for sep, seq, s_mask, mask in zip( - separators, seqs, seq_mask, separator_mask)), []) - return concat, segment_ids, p_mask - if not is_test: # `label_list == None` is for regression task label_dtype = "int64" if label_list else "float32" # Get the label - label = example[-1] - example = example[:-1] - # Create label maps if classification task - if label_list: - label_map = {} - for (i, l) in enumerate(label_list): - label_map[l] = i - label = label_map[label] + label = example['labels'] label = np.array([label], dtype=label_dtype) - - # Tokenize raw text - if len(example) == 1: - example = tokenizer(example[0], max_seq_len=max_seq_length) + # Convert raw text to feature + if len(example) == 2: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) else: example = tokenizer( - example[0], text_pair=example[1], max_seq_len=max_seq_length) + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) if not is_test: - return example['input_ids'], example['token_type_ids'], len(example[ - 'input_ids']), label + return example['input_ids'], example['token_type_ids'], label else: - return example['input_ids'], example['token_type_ids'], len(example[ - 'input_ids']) + return example['input_ids'], example['token_type_ids'] def do_train(args): @@ -281,69 +240,67 @@ def do_train(args): set_seed(args) args.task_name = args.task_name.lower() - dataset_class, metric_class = TASK_CLASSES[args.task_name] + metric_class = TASK_CLASSES[args.task_name] args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - train_dataset = dataset_class.get_datasets(["train"]) + train_ds = load_dataset('glue', args.task_name, splits="train") tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) trans_func = partial( convert_example, tokenizer=tokenizer, - label_list=train_dataset.get_labels(), + label_list=train_ds.label_list, max_seq_length=args.max_seq_length) - train_dataset = train_dataset.apply(trans_func, lazy=True) + train_ds = train_ds.map(trans_func, lazy=True) train_batch_sampler = paddle.io.DistributedBatchSampler( - train_dataset, batch_size=args.batch_size, shuffle=True) + train_ds, batch_size=args.batch_size, shuffle=True) batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # input Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment - Stack(), # length - Stack(dtype="int64" if train_dataset.get_labels() else "float32") # label - ): [data for i, data in enumerate(fn(samples)) if i != 2] + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) train_data_loader = DataLoader( - dataset=train_dataset, + dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True) if args.task_name == "mnli": - dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets( - ["dev_matched", "dev_mismatched"]) - dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True) - dev_dataset_mismatched = dev_dataset_mismatched.apply( - trans_func, lazy=True) + dev_ds_matched, dev_ds_mismatched = load_dataset( + 'glue', args.task_name, splits=["dev_matched", "dev_mismatched"]) + + dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True) + dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True) dev_batch_sampler_matched = paddle.io.BatchSampler( - dev_dataset_matched, batch_size=args.batch_size, shuffle=False) + dev_ds_matched, batch_size=args.batch_size, shuffle=False) dev_data_loader_matched = DataLoader( - dataset=dev_dataset_matched, + dataset=dev_ds_matched, batch_sampler=dev_batch_sampler_matched, collate_fn=batchify_fn, num_workers=0, return_list=True) dev_batch_sampler_mismatched = paddle.io.BatchSampler( - dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False) + dev_ds_mismatched, batch_size=args.batch_size, shuffle=False) dev_data_loader_mismatched = DataLoader( - dataset=dev_dataset_mismatched, + dataset=dev_ds_mismatched, batch_sampler=dev_batch_sampler_mismatched, collate_fn=batchify_fn, num_workers=0, return_list=True) else: - dev_dataset = dataset_class.get_datasets(["dev"]) - dev_dataset = dev_dataset.apply(trans_func, lazy=True) + dev_ds = load_dataset('glue', args.task_name, splits='dev') + dev_ds = dev_ds.map(trans_func, lazy=True) dev_batch_sampler = paddle.io.BatchSampler( - dev_dataset, batch_size=args.batch_size, shuffle=False) + dev_ds, batch_size=args.batch_size, shuffle=False) dev_data_loader = DataLoader( - dataset=dev_dataset, + dataset=dev_ds, batch_sampler=dev_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True) - num_classes = 1 if train_dataset.get_labels() == None else len( - train_dataset.get_labels()) + num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list) model = model_class.from_pretrained( args.model_name_or_path, num_classes=num_classes) if paddle.distributed.get_world_size() > 1: @@ -368,8 +325,8 @@ def do_train(args): if not any(nd in n for nd in ["bias", "norm"]) ]) - loss_fct = paddle.nn.loss.CrossEntropyLoss() if train_dataset.get_labels( - ) else paddle.nn.loss.MSELoss() + loss_fct = paddle.nn.loss.CrossEntropyLoss( + ) if train_ds.label_list else paddle.nn.loss.MSELoss() metric = metric_class() @@ -378,6 +335,7 @@ def do_train(args): for epoch in range(args.num_train_epochs): for step, batch in enumerate(train_data_loader): global_step += 1 + input_ids, segment_ids, labels = batch logits = model(input_ids, segment_ids) loss = loss_fct(logits, labels) @@ -392,7 +350,7 @@ def do_train(args): paddle.distributed.get_rank(), loss, optimizer.get_lr(), args.logging_steps / (time.time() - tic_train))) tic_train = time.time() - if global_step % args.save_steps == 0: + if global_step % args.save_steps == 0 or global_step == num_training_steps: tic_eval = time.time() if args.task_name == "mnli": evaluate(model, loss_fct, metric, dev_data_loader_matched) diff --git a/examples/machine_reading_comprehension/DuReader-robust/run_du.py b/examples/machine_reading_comprehension/DuReader-robust/run_du.py index d139de74cf8b4d..6c70f82428a137 100644 --- a/examples/machine_reading_comprehension/DuReader-robust/run_du.py +++ b/examples/machine_reading_comprehension/DuReader-robust/run_du.py @@ -123,7 +123,6 @@ def prepare_train_features(examples): questions, contexts, stride=args.doc_stride, - pad_to_max_seq_len=True, max_seq_len=args.max_seq_length) # Let's label those examples! @@ -154,9 +153,11 @@ def prepare_train_features(examples): token_start_index += 1 # End token index of the current span in the text. - token_end_index = len(input_ids) - 2 + token_end_index = len(input_ids) - 1 while sequence_ids[token_end_index] != 1: token_end_index -= 1 + # Minus one more to reach actual text + token_end_index -= 1 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). if not (offsets[token_start_index][0] <= start_char and diff --git a/examples/machine_reading_comprehension/SQuAD/run_squad.py b/examples/machine_reading_comprehension/SQuAD/run_squad.py index 579d6687234fb9..800ff8a5ee89d0 100644 --- a/examples/machine_reading_comprehension/SQuAD/run_squad.py +++ b/examples/machine_reading_comprehension/SQuAD/run_squad.py @@ -158,9 +158,11 @@ def prepare_train_features(examples): token_start_index += 1 # End token index of the current span in the text. - token_end_index = len(input_ids) - 2 + token_end_index = len(input_ids) - 1 while sequence_ids[token_end_index] != 1: token_end_index -= 1 + # Minus one more to reach actual text + token_end_index -= 1 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). if not (offsets[token_start_index][0] <= start_char and diff --git a/paddlenlp/datasets/experimental/cmrc2018.py b/paddlenlp/datasets/experimental/cmrc2018.py index d7312c5c6f76e0..cc58a5caedb907 100644 --- a/paddlenlp/datasets/experimental/cmrc2018.py +++ b/paddlenlp/datasets/experimental/cmrc2018.py @@ -31,11 +31,10 @@ def _get_data(self, mode, **kwargs): if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): get_path_from_url(URL, default_root) - fullname = os.path.join(default_root, filename) return fullname - def _read(self, filename): + def _read(self, filename, *args): with open(filename, "r", encoding="utf8") as f: input_data = json.load(f)["data"] for entry in input_data: diff --git a/paddlenlp/datasets/experimental/dataset.py b/paddlenlp/datasets/experimental/dataset.py index 7f8616e679f95a..6655b0a363747c 100644 --- a/paddlenlp/datasets/experimental/dataset.py +++ b/paddlenlp/datasets/experimental/dataset.py @@ -53,11 +53,14 @@ def import_main_class(module_path): return module_main_cls -def load_dataset(name, data_files=None, splits=None, lazy=None): - module_path = DATASETS_MODULE_PATH + name +def load_dataset(path, name=None, data_files=None, splits=None, lazy=None): + module_path = DATASETS_MODULE_PATH + path reader_cls = import_main_class(module_path) - reader_instance = reader_cls(lazy) + if not name: + reader_instance = reader_cls(lazy=lazy) + else: + reader_instance = reader_cls(lazy=lazy, name=name) datasets = reader_instance.read_datasets( data_files=data_files, splits=splits) @@ -317,25 +320,28 @@ class DatasetBuilder: """ lazy = False - def __init__(self, lazy=None, max_examples: Optional[int]=None): + def __init__(self, lazy=None, name=None): if lazy is not None: self.lazy = lazy - self.max_examples = max_examples + self.name = name def read_datasets(self, splits=None, data_files=None): datasets = [] assert splits or data_files, "`data_files` and `splits` can not both be None." if data_files: - assert isinstance(data_files, str) or ( - isinstance(data_files, list) and isinstance(data_files[0], str) - ) or ( - isinstance(data_files, tuple) and isinstance(data_files[0], str) - ), "`data_files` should be a string or list of string or a tuple of string." + assert isinstance(data_files, str) or isinstance( + data_files, dict + ), "`data_files` should be a string or a dictionary whose key is split name ande value is a path of data file." if isinstance(data_files, str): - datasets.append(self.read(data_files)) + split = 'train' + datasets.append(self.read(filename=data_files, split=split)) else: - datasets += [self.read(data_file) for data_file in data_files] + datasets += [ + self.read( + filename=filename, split=split) + for split, filename in data_files.items() + ] if splits: assert isinstance(splits, str) or ( @@ -344,16 +350,16 @@ def read_datasets(self, splits=None, data_files=None): isinstance(splits, tuple) and isinstance(splits[0], str) ), "`splits` should be a string or list of string or a tuple of string." if isinstance(splits, str): - root = self._get_data(splits) - datasets.append(self.read(root)) + filename = self._get_data(splits) + datasets.append(self.read(filename=filename, split=splits)) else: for split in splits: - root = self._get_data(split) - datasets.append(self.read(root)) + filename = self._get_data(split) + datasets.append(self.read(filename=filename, split=split)) return datasets if len(datasets) > 1 else datasets[0] - def read(self, root): + def read(self, filename, split='train'): """ Returns an dataset containing all the examples that can be read from the file path. If `self.lazy` is `False`, this eagerly reads all instances from `self._read()` @@ -367,35 +373,34 @@ def read(self, root): if self.lazy: label_list = self.get_labels() - if label_list is not None: - label_dict = {} - for i, label in enumerate(label_list): - label_dict[label] = i - - def generate_examples(): - for example in self._read(root): - if 'labels' not in example.keys(): - raise ValueError( - "Keyword 'labels' should be in example if get_label() is specified." - ) - else: + def generate_examples(): + generator = self._read( + filename, split + ) if self._read.__code__.co_argcount > 2 else self._read( + filename) + for example in generator: + if label_list is not None and 'labels' in example.keys(): + label_dict = {} + for i, label in enumerate(label_list): + label_dict[label] = i + if isinstance(example['labels'], list) or isinstance( + examples[idx]['labels'], tuple): for label_idx in range(len(example['labels'])): example['labels'][label_idx] = label_dict[ example['labels'][label_idx]] + else: + example['labels'] = label_dict[example['labels']] - yield example - - return IterDataset(generate_examples, label_list=label_list) - else: - - def generate_examples(): - for example in self._read(root): + yield example + else: yield example - return IterDataset(generate_examples) - + return IterDataset(generate_examples, label_list=label_list) else: - examples = self._read(root) + examples = self._read( + filename, + split) if self._read.__code__.co_argcount > 2 else self._read( + filename) # Then some validation. if not isinstance(examples, list): @@ -404,30 +409,27 @@ def generate_examples(): if not examples: raise ValueError( "No instances were read from the given filepath {}. " - "Is the path correct?".format(root)) + "Is the path correct?".format(filename)) label_list = self.get_labels() # Convert class label to label ids. - if label_list is not None: - if 'labels' not in examples[0].keys(): - raise ValueError( - "Key 'labels' should be in example if get_label() is specified." - ) - + if label_list is not None and 'labels' in examples[0].keys(): label_dict = {} for i, label in enumerate(label_list): label_dict[label] = i - for idx in range(len(examples)): - for label_idx in range(len(examples[idx]['labels'])): - examples[idx]['labels'][label_idx] = label_dict[ - examples[idx]['labels'][label_idx]] - - return MapDataset( - examples, - label_list=label_list) if label_list else MapDataset(examples) - - def _read(self, file_path: str): + if isinstance(examples[idx]['labels'], list) or isinstance( + examples[idx]['labels'], tuple): + for label_idx in range(len(examples[idx]['labels'])): + examples[idx]['labels'][label_idx] = label_dict[ + examples[idx]['labels'][label_idx]] + else: + examples[idx]['labels'] = label_dict[examples[idx][ + 'labels']] + + return MapDataset(examples, label_list=label_list) + + def _read(self, filename: str, *args): """ Reads examples from the given file_path and returns them as an `Iterable` (which could be a list or could be a generator). diff --git a/paddlenlp/datasets/experimental/drcd.py b/paddlenlp/datasets/experimental/drcd.py index cacb8cc6a78ab1..d67a9597aa4d04 100644 --- a/paddlenlp/datasets/experimental/drcd.py +++ b/paddlenlp/datasets/experimental/drcd.py @@ -31,11 +31,10 @@ def _get_data(self, mode, **kwargs): if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): get_path_from_url(URL, default_root) - fullname = os.path.join(default_root, filename) return fullname - def _read(self, filename): + def _read(self, filename, *args): with open(filename, "r", encoding="utf8") as f: input_data = json.load(f)["data"] for entry in input_data: diff --git a/paddlenlp/datasets/experimental/dureader_robust.py b/paddlenlp/datasets/experimental/dureader_robust.py index a22b0a4a4bb614..e1b699d2c4ee8c 100644 --- a/paddlenlp/datasets/experimental/dureader_robust.py +++ b/paddlenlp/datasets/experimental/dureader_robust.py @@ -32,12 +32,11 @@ def _get_data(self, mode, **kwargs): fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - get_path_from_url(self.URL, default_root) - fullname = os.path.join(default_root, filename) + get_path_from_url(self.URL, default_root, self.MD5) return fullname - def _read(self, filename): + def _read(self, filename, *args): with open(filename, "r", encoding="utf8") as f: input_data = json.load(f)["data"] for entry in input_data: diff --git a/paddlenlp/datasets/experimental/glue.py b/paddlenlp/datasets/experimental/glue.py new file mode 100644 index 00000000000000..328ba3ccdd425b --- /dev/null +++ b/paddlenlp/datasets/experimental/glue.py @@ -0,0 +1,322 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + + +class Glue(DatasetBuilder): + BUILDER_CONFIGS = { + 'cola': { + 'url': "https://dataset.bj.bcebos.com/glue/CoLA.zip", + 'md5': 'b178a7c2f397b0433c39c7caf50a3543', + 'splits': { + 'train': [ + os.path.join('CoLA', 'train.tsv'), + 'c79d4693b8681800338aa044bf9e797b', (3, 1), 0 + ], + 'dev': [ + os.path.join('CoLA', 'dev.tsv'), + 'c5475ccefc9e7ca0917294b8bbda783c', (3, 1), 0 + ], + 'test': [ + os.path.join('CoLA', 'test.tsv'), + 'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'sst-2': { + 'url': "https://dataset.bj.bcebos.com/glue/SST.zip", + 'md5': '9f81648d4199384278b86e315dac217c', + 'splits': { + 'train': [ + os.path.join('SST-2', 'train.tsv'), + 'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1 + ], + 'dev': [ + os.path.join('SST-2', 'dev.tsv'), + '268856b487b2a31a28c0a93daaff7288', (0, 1), 1 + ], + 'test': [ + os.path.join('SST-2', 'test.tsv'), + '3230e4efec76488b87877a56ae49675a', (1, ), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'sts-b': { + 'url': 'https://dataset.bj.bcebos.com/glue/STS.zip', + 'md5': 'd573676be38f1a075a5702b90ceab3de', + 'splits': { + 'train': [ + os.path.join('STS-B', 'train.tsv'), + '4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1 + ], + 'dev': [ + os.path.join('STS-B', 'dev.tsv'), + '5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1 + ], + 'test': [ + os.path.join('STS-B', 'test.tsv'), + '339b5817e414d19d9bb5f593dd94249c', (7, 8), 1 + ] + }, + 'labels': None + }, + 'qqp': { + 'url': 'https://dataset.bj.bcebos.com/glue/QQP.zip', + 'md5': '884bf26e39c783d757acc510a2a516ef', + 'splits': { + 'train': [ + os.path.join('QQP', 'train.tsv'), + 'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1 + ], + 'dev': [ + os.path.join('QQP', 'dev.tsv'), + 'cff6a448d1580132367c22fc449ec214', (3, 4, 5), 1 + ], + 'test': [ + os.path.join('QQP', 'test.tsv'), + '73de726db186b1b08f071364b2bb96d0', (1, 2), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'mnli': { + 'url': 'https://dataset.bj.bcebos.com/glue/MNLI.zip', + 'md5': 'e343b4bdf53f927436d0792203b9b9ff', + 'splits': { + 'train': [ + os.path.join('MNLI', 'train.tsv'), + '220192295e23b6705f3545168272c740', (8, 9, 11), 1 + ], + 'dev_matched': [ + os.path.join('MNLI', 'dev_matched.tsv'), + 'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1 + ], + 'dev_mismatched': [ + os.path.join('MNLI', 'dev_mismatched.tsv'), + 'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1 + ], + 'test_matched': [ + os.path.join('MNLI', 'test_matched.tsv'), + '33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1 + ], + 'test_mismatched': [ + os.path.join('MNLI', 'test_mismatched.tsv'), + '7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1 + ] + }, + 'labels': ["contradiction", "entailment", "neutral"] + }, + 'qnli': { + 'url': 'https://dataset.bj.bcebos.com/glue/QNLI.zip', + 'md5': 'b4efd6554440de1712e9b54e14760e82', + 'splits': { + 'train': [ + os.path.join('QNLI', 'train.tsv'), + '5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1 + ], + 'dev': [ + os.path.join('QNLI', 'dev.tsv'), + '1e81e211959605f144ba6c0ad7dc948b', (1, 2, 3), 1 + ], + 'test': [ + os.path.join('QNLI', 'test.tsv'), + 'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1 + ] + }, + 'labels': ["entailment", "not_entailment"] + }, + 'rte': { + 'url': 'https://dataset.bj.bcebos.com/glue/RTE.zip', + 'md5': 'bef554d0cafd4ab6743488101c638539', + 'splits': { + 'train': [ + os.path.join('RTE', 'train.tsv'), + 'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1 + ], + 'dev': [ + os.path.join('RTE', 'dev.tsv'), + '973cb4178d4534cf745a01c309d4a66c', (1, 2, 3), 1 + ], + 'test': [ + os.path.join('RTE', 'test.tsv'), + '6041008f3f3e48704f57ce1b88ad2e74', (1, 2), 1 + ] + }, + 'labels': ["entailment", "not_entailment"] + }, + 'wnli': { + 'url': 'https://dataset.bj.bcebos.com/glue/WNLI.zip', + 'md5': 'a1b4bd2861017d302d29e42139657a42', + 'splits': { + 'train': [ + os.path.join('WNLI', 'train.tsv'), + '5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1 + ], + 'dev': [ + os.path.join('WNLI', 'dev.tsv'), + 'a79a6dd5d71287bcad6824c892e517ee', (1, 2, 3), 1 + ], + 'test': [ + os.path.join('WNLI', 'test.tsv'), + 'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'mrpc': { + 'url': { + 'train_data': + 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_train.txt', + 'dev_id': 'https://dataset.bj.bcebos.com/glue/mrpc/dev_ids.tsv', + 'test_data': + 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_test.txt' + }, + 'md5': { + 'train_data': '793daf7b6224281e75fe61c1f80afe35', + 'dev_id': '7ab59a1b04bd7cb773f98a0717106c9b', + 'test_data': 'e437fdddb92535b820fe8852e2df8a49' + }, + 'splits': { + 'train': [ + os.path.join('MRPC', 'train.tsv'), + 'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1 + ], + 'dev': [ + os.path.join('MRPC', 'dev.tsv'), + '185958e46ba556b38c6a7cc63f3a2135', (3, 4, 0), 1 + ], + 'test': [ + os.path.join('MRPC', 'test.tsv'), + '4825dab4b4832f81455719660b608de5', (3, 4), 1 + ] + }, + 'labels': ["0", "1"] + } + } + + def _get_data(self, mode, **kwargs): + builder_config = self.BUILDER_CONFIGS[self.name] + if self.name != 'mrpc': + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, _, _ = builder_config['splits'][mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or ( + data_hash and not md5file(fullname) == data_hash): + get_path_from_url(builder_config['url'], default_root, + builder_config['md5']) + + else: + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, _, _ = builder_config['splits'][mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or ( + data_hash and not md5file(fullname) == data_hash): + if mode in ('train', 'dev'): + dev_id_path = get_path_from_url( + builder_config['url']['dev_id'], + os.path.join(default_root, 'MRPC'), + builder_config['md5']['dev_id']) + train_data_path = get_path_from_url( + builder_config['url']['train_data'], + os.path.join(default_root, 'MRPC'), + builder_config['md5']['train_data']) + # read dev data ids + dev_ids = [] + print(dev_id_path) + with open(dev_id_path, encoding='utf-8') as ids_fh: + for row in ids_fh: + dev_ids.append(row.strip().split('\t')) + + # generate train and dev set + train_path = os.path.join(default_root, 'MRPC', 'train.tsv') + dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv') + with open(train_data_path, encoding='utf-8') as data_fh: + with open( + train_path, 'w', encoding='utf-8') as train_fh: + with open(dev_path, 'w', encoding='utf8') as dev_fh: + header = data_fh.readline() + train_fh.write(header) + dev_fh.write(header) + for row in data_fh: + label, id1, id2, s1, s2 = row.strip().split( + '\t') + example = '%s\t%s\t%s\t%s\t%s\n' % ( + label, id1, id2, s1, s2) + if [id1, id2] in dev_ids: + dev_fh.write(example) + else: + train_fh.write(example) + + else: + test_data_path = get_path_from_url( + builder_config['url']['test_data'], + os.path.join(default_root, 'MRPC'), + builder_config['md5']['test_data']) + test_path = os.path.join(default_root, 'MRPC', 'test.tsv') + with open(test_data_path, encoding='utf-8') as data_fh: + with open(test_path, 'w', encoding='utf-8') as test_fh: + header = data_fh.readline() + test_fh.write( + 'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n') + for idx, row in enumerate(data_fh): + label, id1, id2, s1, s2 = row.strip().split( + '\t') + test_fh.write('%d\t%s\t%s\t%s\t%s\n' % + (idx, id1, id2, s1, s2)) + + return fullname + + def _read(self, filename, split): + _, _, field_indices, num_discard_samples = self.BUILDER_CONFIGS[ + self.name]['splits'][split] + with open(filename, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if idx < num_discard_samples: + continue + line_stripped = line.strip().split('\t') + if not line_stripped: + break + example = [line_stripped[indice] for indice in field_indices] + if self.name in ['cola', 'sst-2']: + yield { + 'sentence': example[0] + } if 'test' in split else { + 'sentence': example[0], + 'labels': example[-1] + } + else: + yield { + 'sentence1': example[0], + 'sentence2': example[1] + } if 'test' in split else { + 'sentence1': example[0], + 'sentence2': example[1], + 'labels': example[-1] + } + + def get_labels(self): + """ + Return labels of the Glue task. + """ + return self.BUILDER_CONFIGS[self.name]['labels'] diff --git a/paddlenlp/datasets/experimental/ptb.py b/paddlenlp/datasets/experimental/ptb.py index adceb57510443a..28541579d68f2f 100644 --- a/paddlenlp/datasets/experimental/ptb.py +++ b/paddlenlp/datasets/experimental/ptb.py @@ -35,11 +35,10 @@ def _get_data(self, mode, **kwargs): not md5file(fullname) == data_hash): get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) return fullname - def _read(self, filename): + def _read(self, filename, *args): with open(filename, 'r', encoding='utf-8') as f: for line in f: line_stripped = line.strip() diff --git a/paddlenlp/datasets/experimental/squad.py b/paddlenlp/datasets/experimental/squad.py index 1949f75b33b8be..50d47adccc252c 100644 --- a/paddlenlp/datasets/experimental/squad.py +++ b/paddlenlp/datasets/experimental/squad.py @@ -34,11 +34,10 @@ def _get_data(self, mode, **kwargs): if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): get_path_from_url(URL, default_root) - fullname = os.path.join(default_root, filename) return fullname - def _read(self, filename): + def _read(self, filename, *args): with open(filename, "r", encoding="utf8") as f: input_data = json.load(f)["data"] for entry in input_data: diff --git a/paddlenlp/datasets/experimental/wmt14ende.py b/paddlenlp/datasets/experimental/wmt14ende.py index 74551fecdd87db..5b1effec7946fb 100644 --- a/paddlenlp/datasets/experimental/wmt14ende.py +++ b/paddlenlp/datasets/experimental/wmt14ende.py @@ -82,7 +82,7 @@ def _get_data(self, mode, **kwargs): return src_fullname, tgt_fullname - def _read(self, filename): + def _read(self, filename, *args): src_filename, tgt_filename = filename with open(src_filename, 'r', encoding='utf-8') as src_f: with open(tgt_filename, 'r', encoding='utf-8') as tgt_f: