Update run_glue for do_predict with local test data (#9442) #9486

Merged
Changes from 2 commits
examples/text-classification/run_glue.py (20 additions, 0 deletions)
@@ -93,6 +93,7 @@ class DataTrainingArguments:
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        if self.task_name is not None:
@@ -413,6 +414,25 @@ def compute_metrics(p: EvalPrediction):
    if training_args.do_predict:
        logger.info("*** Test ***")

        # Get the datasets: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if data_args.task_name is None and data_args.test_file is not None:
            extension = data_args.test_file.split(".")[-1]
            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
            if data_args.test_file.endswith(".csv"):
                # Loading a dataset from a local csv file
                test_dataset = load_dataset("csv", data_files={"test": data_args.test_file})
            else:
                # Loading a dataset from a local json file
                test_dataset = load_dataset("json", data_files={"test": data_args.test_file})
Collaborator:

Can we put those lines earlier, with the validation dataset? That way the map will be done together with the other datasets. I think we can do something nice by creating data_files={"train": data_args.train_file, "validation": data_args.validation_file} and then adding the key "test" if a test_file is passed.
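A minimal sketch of what this suggestion could look like (hypothetical; not the exact code merged in the PR):

    # Build data_files once, adding the "test" key only when a test file is
    # passed, so a single load_dataset call (and one map) covers every split.
    data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
    # Assumes all files share the train file's extension ("csv" or "json").
    extension = data_args.train_file.split(".")[-1]
    datasets = load_dataset(extension, data_files=data_files)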

Contributor (author):

Thank you for your comment! I've applied the review suggestion.

The nesting of if statements in the code has increased, but I think readability may actually have improved, since the top-level branch now reads as "use a GLUE task or use local files". What do you think?
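A rough sketch of the nesting the author describes (hypothetical; simplified from the actual change):

    if data_args.task_name is not None:
        # Top-level branch 1: use a GLUE benchmark task
        datasets = load_dataset("glue", data_args.task_name)
    else:
        # Top-level branch 2: use local files
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
        if training_args.do_predict:
            if data_args.test_file is not None:
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
        extension = data_args.train_file.split(".")[-1]
        datasets = load_dataset(extension, data_files=data_files)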

Contributor (author):

I also added logger output to make sure that the local files a user wants to use are loaded correctly. If this is superfluous, please let me know and I will remove it.
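For example, the added logging might look something like this (hypothetical message wording, assuming a data_files dict as in the sketch above):

    # Log each local file path so the user can verify the right splits loaded.
    for key, path in data_files.items():
        logger.info(f"load a local file for {key}: {path}")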

            test_dataset = test_dataset.map(
                preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
            )
            test_dataset = test_dataset["test"]
        else:
            raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        test_datasets = [test_dataset]