
defaults changed #7600

Merged
merged 4 commits into from Oct 3, 2023
30 changes: 11 additions & 19 deletions scripts/metric_calculation/peft_metric_calc.py
@@ -21,26 +21,18 @@

````diff
 
 
 """
-This script can be used to calcualte exact match and F1 scores for many different tasks, not just squad.
-
-Example command for T5 Preds
-
-```
-python squad_metric_calc.py \
-    --ground-truth squad_test_gt.jsonl \
-    --preds squad_preds_t5.txt
-```
+This script can be used to calcualte exact match and F1 scores for many different tasks.
+The file "squad_test_predictions.jsonl" is assumed to be generated by the
+`examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py` script

 Example command for GPT Preds

 ```
-python squad_metric_calc.py \
-    --ground-truth squad_test_gt.jsonl \
-    --preds squad_preds_gpt.txt \
-    --split-string "answer:"
+python peft_metric_calc.py \
+    --pred_file squad_test_predictions.jsonl \
+    --label_field "original_answers" \
 ```

-In this case, the prediction file will be split on "answer: " when looking for the LM's predicted answer.
-
 """
````
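For context, a minimal sketch (not part of this PR) of what a single line in the prediction JSONL file might look like under the new defaults, with a `pred` field holding the model output and a `label` field (the new `--label_field` default) holding one or more ground-truth answers. The field names match the script's defaults; the sample values and the `input` key are hypothetical:

```python
import json

# Hypothetical example of one line in the prediction JSONL file read by
# peft_metric_calc.py. "pred" and "label" are the script's default field
# names; the content shown here is made up for illustration.
sample_line = json.dumps(
    {
        "input": "context: ... question: ...",  # extra fields are ignored
        "pred": "Denver Broncos",
        "label": ["Denver Broncos", "The Broncos"],
    }
)

record = json.loads(sample_line)
pred_answer = record["pred"]
true_answers = record["label"]
# Mirrors the script's normalization: a bare string becomes a one-item list.
if not isinstance(true_answers, list):
    true_answers = [true_answers]
```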

@@ -92,21 +84,21 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):

```diff
 def main():
     parser = argparse.ArgumentParser(description='Process some integers.')
     parser.add_argument(
-        '--pred-file',
+        '--pred_file',
         type=str,
         help="Text file with test set prompts + model predictions. Prediction file can be made by running NeMo/examples/nlp/language_modeling/megatron_gpt_prompt_learning_eval.py",
     )
     parser.add_argument(
-        '--pred-field',
+        '--pred_field',
         type=str,
         help="The field in the json file that contains the prediction tokens",
         default="pred",
     )
     parser.add_argument(
-        '--ground-truth-field',
+        '--label_field',
         type=str,
         help="The field in the json file that contains the ground truth tokens",
-        default="original_answers",
+        default="label",
     )

     args = parser.parse_args()
```

@@ -120,7 +112,7 @@ def main():

```diff
         pred_line = json.loads(preds[i])

         pred_answer = pred_line[args.pred_field]
-        true_answers = pred_line[args.ground_truth_field]
+        true_answers = pred_line[args.label_field]
         if not isinstance(true_answers, list):
             true_answers = [true_answers]
```
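The `metric_max_over_ground_truths` helper named in the hunk header is not shown in this diff. A minimal sketch of how SQuAD-style EM/F1 scoring over multiple ground truths is conventionally implemented (an assumption based on the standard SQuAD evaluation recipe, not the file's exact code):

```python
import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop articles and punctuation, collapse whitespace (SQuAD-style)."""
    s = s.lower()
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    s = "".join(ch for ch in s if ch not in string.punctuation)
    return " ".join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # Score the prediction against every acceptable answer and keep the best,
    # matching how the script compares "pred" against a list of labels.
    return max(metric_fn(prediction, gt) for gt in ground_truths)


# Normalization makes "the Denver Broncos" and "Denver Broncos" identical.
best_f1 = metric_max_over_ground_truths(
    f1_score, "the Denver Broncos", ["Denver Broncos"]
)
```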