Skip to content

Commit

Permalink
Add warning message for run_qa.py (#29867)
Browse files Browse the repository at this point in the history
* improve: error message for best model metric

* update: raise warning instead of error
  • Loading branch information
jla524 authored and ArthurZucker committed Apr 22, 2024
1 parent a7c5311 commit ca56b09
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions examples/pytorch/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,14 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

if data_args.version_2_with_negative:
accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1")
else:
accepted_best_metrics = ("exact_match", "f1")

if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics:
warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}")

metric = evaluate.load(
"squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
)
Expand Down

0 comments on commit ca56b09

Please sign in to comment.