Skip to content

Commit

Permalink
Extract metric_key_prefix during NotebookProgressCallback.on_evaluate (
Browse files Browse the repository at this point in the history
…#11347)

* Pass metric_key_prefix as kwarg to on_evaluate

* Replace eval_loss with metric_key_prefix_loss

* Default to "eval" if metric_key_prefix not in kwargs

* Add kwargs to CallbackHandler.on_evaluate signature

* Revert "Add kwargs to CallbackHandler.on_evaluate signature"

This reverts commit 8d4c85e.

* Revert "Pass metric_key_prefix as kwarg to on_evaluate"

This reverts commit 7766bfe.

* Extract metric_key_prefix from metrics
  • Loading branch information
lewtun authored Apr 21, 2021
1 parent dabeb15 commit 41f3133
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/transformers/utils/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import collections
import re
import time
from typing import Optional

Expand Down Expand Up @@ -308,7 +309,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):

def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if self.training_tracker is not None:
values = {"Training Loss": "No log"}
values = {"Training Loss": "No log", "Validation Loss": "No log"}
for log in reversed(state.log_history):
if "loss" in log:
values["Training Loss"] = log["loss"]
Expand All @@ -318,13 +319,16 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
values["Epoch"] = int(state.epoch)
else:
values["Step"] = state.global_step
values["Validation Loss"] = metrics["eval_loss"]
metric_key_prefix = "eval"
for k in metrics:
if k.endswith("_loss"):
metric_key_prefix = re.sub(r"\_loss$", "", k)
_ = metrics.pop("total_flos", None)
_ = metrics.pop("epoch", None)
_ = metrics.pop("eval_runtime", None)
_ = metrics.pop("eval_samples_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
for k, v in metrics.items():
if k == "eval_loss":
if k == f"{metric_key_prefix}_loss":
values["Validation Loss"] = v
else:
splits = k.split("_")
Expand Down

0 comments on commit 41f3133

Please sign in to comment.