From dba96002ea07e740a199e379403d6c8994ae1c94 Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Sat, 9 Jan 2021 02:29:42 +0900 Subject: [PATCH 1/5] Update run_glue for do_predict with local test data (#9442) --- examples/text-classification/run_glue.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 0dad31e04965..8ed6f9491d40 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -93,6 +93,7 @@ class DataTrainingArguments: validation_file: Optional[str] = field( default=None, metadata={"help": "A csv or a json file containing the validation data."} ) + test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) def __post_init__(self): if self.task_name is not None: @@ -413,6 +414,25 @@ def compute_metrics(p: EvalPrediction): if training_args.do_predict: logger.info("*** Test ***") + # Get the datasets: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + + if data_args.task_name is None and data_args.test_file is not None: + extension = data_args.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + if data_args.test_file.endswith(".csv"): + # Loading a dataset from local csv files + test_dataset = load_dataset("csv", data_files={"test": data_args.test_file}) + else: + # Loading a dataset from local json files + test_dataset = load_dataset("json", data_files={"test": data_args.test_file}) + test_dataset = test_dataset.map( + preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache + ) + test_dataset = test_dataset["test"] + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + # Loop to handle MNLI double evaluation (matched, mis-matched) tasks = [data_args.task_name] test_datasets = [test_dataset] From a77675e2c1b7ec0239903bdf2dc82761597c77a7 Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Sat, 9 Jan 2021 02:48:59 +0900 Subject: [PATCH 2/5] Update run_glue (#9442): fix comments ('files' to 'a file') --- examples/text-classification/run_glue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 8ed6f9491d40..1773e19e9d69 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -421,10 +421,10 @@ def compute_metrics(p: EvalPrediction): extension = data_args.test_file.split(".")[-1] assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." if data_args.test_file.endswith(".csv"): - # Loading a dataset from local csv files + # Loading a dataset from a local csv file test_dataset = load_dataset("csv", data_files={"test": data_args.test_file}) else: - # Loading a dataset from local json files + # Loading a dataset from a local json file test_dataset = load_dataset("json", data_files={"test": data_args.test_file}) test_dataset = test_dataset.map( preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache From b1e10acc57fe0aea1defb0250f55d4e4a46310c7 Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Sat, 9 Jan 2021 07:24:10 +0000 Subject: [PATCH 3/5] Update run_glue (#9442): reflect the code review --- examples/text-classification/run_glue.py | 57 ++++++++++++------------ 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 1773e19e9d69..67d7f3d2079b 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -206,16 +206,34 @@ def main(): if data_args.task_name is not None: # Downloading and loading a dataset from the hub. datasets = load_dataset("glue", data_args.task_name) - elif data_args.train_file.endswith(".csv"): - # Loading a dataset from local csv files - datasets = load_dataset( - "csv", data_files={"train": data_args.train_file, "validation": data_args.validation_file} - ) else: - # Loading a dataset from local json files - datasets = load_dataset( - "json", data_files={"train": data_args.train_file, "validation": data_args.validation_file} - ) + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. + data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + extension = data_args.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset( + "csv", data_files=data_files + ) + else: + # Loading a dataset from local json files + datasets = load_dataset( + "json", data_files=data_files + ) # See more about loading any type of standard or custom dataset at # https://huggingface.co/docs/datasets/loading_datasets.html. @@ -326,7 +344,7 @@ def preprocess_function(examples): train_dataset = datasets["train"] eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] - if data_args.task_name is not None: + if data_args.task_name is not None or data_args.test_file is not None: test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] # Log a few random samples from the training set: @@ -414,25 +432,6 @@ def compute_metrics(p: EvalPrediction): if training_args.do_predict: logger.info("*** Test ***") - # Get the datasets: you can provide your own CSV/JSON test file (see below) - # when you use `do_predict` without specifying a GLUE benchmark task. - - if data_args.task_name is None and data_args.test_file is not None: - extension = data_args.test_file.split(".")[-1] - assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." - if data_args.test_file.endswith(".csv"): - # Loading a dataset from a local csv file - test_dataset = load_dataset("csv", data_files={"test": data_args.test_file}) - else: - # Loading a dataset from a local json file - test_dataset = load_dataset("json", data_files={"test": data_args.test_file}) - test_dataset = test_dataset.map( - preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache - ) - test_dataset = test_dataset["test"] - else: - raise ValueError("Need either a GLUE task or a test file for `do_predict`.") - # Loop to handle MNLI double evaluation (matched, mis-matched) tasks = [data_args.task_name] test_datasets = [test_dataset] From b2936c342c0179c11518ec098fc3a10c988d38f5 Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Sat, 9 Jan 2021 07:28:55 +0000 Subject: [PATCH 4/5] Update run_glue (#9442): auto format --- examples/text-classification/run_glue.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 67d7f3d2079b..fb2f156b136c 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -226,14 +226,10 @@ def main(): if data_args.train_file.endswith(".csv"): # Loading a dataset from local csv files - datasets = load_dataset( - "csv", data_files=data_files - ) + datasets = load_dataset("csv", data_files=data_files) else: # Loading a dataset from local json files - datasets = load_dataset( - "json", data_files=data_files - ) + datasets = load_dataset("json", data_files=data_files) # See more about loading any type of standard or custom dataset at # https://huggingface.co/docs/datasets/loading_datasets.html. From 37531ab5a3a2348f6894b82e88fe7a3e1515fbc9 Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Mon, 11 Jan 2021 16:37:48 +0000 Subject: [PATCH 5/5] Update run_glue (#9442): reflect the code review --- examples/text-classification/run_glue.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index fb2f156b136c..7cb7f0c8ff3f 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -103,10 +103,12 @@ def __post_init__(self): elif self.train_file is None or self.validation_file is None: raise ValueError("Need either a GLUE task or a training/validation file.") else: - extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." - extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + train_extension = self.train_file.split(".")[-1] + assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." @dataclass @@ -215,8 +217,11 @@ def main(): # when you use `do_predict` without specifying a GLUE benchmark task. if training_args.do_predict: if data_args.test_file is not None: - extension = data_args.test_file.split(".")[-1] - assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." data_files["test"] = data_args.test_file else: raise ValueError("Need either a GLUE task or a test file for `do_predict`.")