From de3abf3e6a7a575f7101f8d6139d2c1f7358d6ca Mon Sep 17 00:00:00 2001
From: LZR
Date: Mon, 15 Apr 2024 05:53:21 -0700
Subject: [PATCH] round epoch only in console (#30237)

---
 src/transformers/trainer.py          | 2 +-
 src/transformers/trainer_callback.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5b8ffeafc7c8ea..45b45992bf425a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3048,7 +3048,7 @@ def log(self, logs: Dict[str, float]) -> None:
                 The values to log.
         """
         if self.state.epoch is not None:
-            logs["epoch"] = round(self.state.epoch, 2)
+            logs["epoch"] = self.state.epoch
         if self.args.include_num_input_tokens_seen:
             logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
 
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index f5bbcdbd4218d5..225f645d631e41 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -15,6 +15,7 @@
 """
 Callbacks to use with the Trainer class and customize the training loop.
 """
+import copy
 import dataclasses
 import json
 from dataclasses import dataclass
@@ -520,7 +521,12 @@ def on_predict(self, args, state, control, **kwargs):
 
     def on_log(self, args, state, control, logs=None, **kwargs):
         if state.is_world_process_zero and self.training_bar is not None:
+            # avoid modifying the logs object as it is shared between callbacks
+            logs = copy.deepcopy(logs)
             _ = logs.pop("total_flos", None)
+            # round numbers so that it looks better in console
+            if "epoch" in logs:
+                logs["epoch"] = round(logs["epoch"], 2)
             self.training_bar.write(str(logs))
 
     def on_train_end(self, args, state, control, **kwargs):
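
For context, a minimal standalone sketch (not part of the patch) of the pattern the change relies on: the console callback rounds "epoch" on a private deep copy of the shared logs dict, so every other callback still receives the full-precision value. The ConsoleCallback and JsonCallback classes below are hypothetical stand-ins, not the real Trainer callbacks:

import copy

class ConsoleCallback:
    # mimics ProgressCallback.on_log after this patch
    def on_log(self, logs):
        # avoid modifying the logs object as it is shared between callbacks
        logs = copy.deepcopy(logs)
        logs.pop("total_flos", None)
        # round only for console display
        if "epoch" in logs:
            logs["epoch"] = round(logs["epoch"], 2)
        print("console:", logs)

class JsonCallback:
    # any other consumer of the same logs dict
    def on_log(self, logs):
        # still sees the unmodified dict, including the exact epoch
        print("logged: ", logs)

shared_logs = {"loss": 0.4213, "epoch": 1.3333333333333333, "total_flos": 1e15}
for cb in (ConsoleCallback(), JsonCallback()):
    cb.on_log(shared_logs)
# console output shows epoch=1.33, while JsonCallback receives 1.3333333333333333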