Skip to content

Commit

Permalink
Fixes an issue in text-classification where MNLI eval/test datasets are not being preprocessed. (#10621)
Browse files Browse the repository at this point in the history

* Fix MNLI tests

* Linter fix
  • Loading branch information
allenwang28 authored Mar 10, 2021
1 parent 72d9e03 commit 6f52fce
Showing 1 changed file with 5 additions and 18 deletions.
23 changes: 5 additions & 18 deletions examples/text-classification/run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,45 +374,32 @@ def preprocess_function(examples):
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result

datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
if training_args.do_train:
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_eval:
if "validation" not in datasets and "validation_matched" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in datasets and "test_matched" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)

# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
if training_args.do_train:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

# Get the metric function
if data_args.task_name is not None:
Expand Down Expand Up @@ -447,7 +434,7 @@ def compute_metrics(p: EvalPrediction):
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
Expand Down

0 comments on commit 6f52fce

Please sign in to comment.