Skip to content

Commit

Permalink
Style
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Apr 6, 2021
1 parent aef4cf8 commit fd338ab
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
17 changes: 5 additions & 12 deletions examples/question-answering/run_qa_beam_search_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def parse_args():
parser.add_argument(
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
)
parser.add_argument(
"--do_predict", action="store_true", help="Eval the question answering model"
)
parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
)
Expand Down Expand Up @@ -284,7 +282,7 @@ def main():
# Preprocessing the datasets.
# Preprocessing is slighlty different for training and evaluation.
column_names = raw_datasets["train"].column_names

question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]
Expand Down Expand Up @@ -396,7 +394,6 @@ def prepare_train_features(examples):

return tokenized_examples


if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
Expand Down Expand Up @@ -481,7 +478,6 @@ def prepare_validation_features(examples):

return tokenized_examples


if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_examples = raw_datasets["validation"]
Expand Down Expand Up @@ -539,11 +535,8 @@ def prepare_validation_features(examples):
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
)


eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
eval_dataloader = DataLoader(
eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

if args.do_predict:
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
Expand Down Expand Up @@ -605,8 +598,8 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if step + batch_size < len(dataset):
logits_concat[step : step + batch_size, :cols] = output_logit
else:
logits_concat[step:, :cols] = output_logit[:len(dataset) - step]
logits_concat[step:, :cols] = output_logit[: len(dataset) - step]

step += batch_size

return logits_concat
Expand Down
10 changes: 3 additions & 7 deletions examples/question-answering/run_qa_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def parse_args():
parser.add_argument(
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
)
parser.add_argument(
"--do_predict", action="store_true", help="Eval the question answering model"
)
parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
)
Expand Down Expand Up @@ -543,9 +541,7 @@ def prepare_validation_features(examples):
)

eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
eval_dataloader = DataLoader(
eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

if args.do_predict:
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
Expand Down Expand Up @@ -607,7 +603,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if step + batch_size < len(dataset):
logits_concat[step : step + batch_size, :cols] = output_logit
else:
logits_concat[step:, :cols] = output_logit[:len(dataset) - step]
logits_concat[step:, :cols] = output_logit[: len(dataset) - step]

step += batch_size

Expand Down

0 comments on commit fd338ab

Please sign in to comment.