Skip to content

Commit

Permalink
fixed JSON error in run_qa with fp16 (#9186)
Browse files Browse the repository at this point in the history
  • Loading branch information
WissamAntoun authored Dec 18, 2020
1 parent 66a14a2 commit fd7b6a5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/question-answering/utils_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def postprocess_qa_predictions(

# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]

Expand Down Expand Up @@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search(

# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]

Expand Down

0 comments on commit fd7b6a5

Please sign in to comment.