Skip to content

Commit

Permalink
Update notebook.py to support multi eval datasets (huggingface#25796)
Browse files Browse the repository at this point in the history
* Update notebook.py

Fix handling of multiple eval datasets

* Update notebook.py

* Update notebook.py

Reformat using `black`

* Update notebook.py

Keep the "Validation Loss" column name for the single-dataset case

* Update notebook.py

reformat

* Update notebook.py
  • Loading branch information
matrix1001 authored and blbadger committed Nov 8, 2023
1 parent c95a3e6 commit 38fa6a1
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions src/transformers/utils/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,25 @@ def write_line(self, values):
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
if len(self.inner_table) == 1:
# We give a chance to update the column names at the first iteration
for key in values.keys():
if key not in columns:
columns.append(key)
self.inner_table[0] = columns
self.inner_table.append([values[c] for c in columns])
for key in values.keys():
if key not in columns:
columns.append(key)
self.inner_table[0] = columns
if len(self.inner_table) > 1:
last_values = self.inner_table[-1]
first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
else:
# update last line
new_values = values
for c in columns:
if c not in new_values.keys():
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
self.inner_table.append([values[c] for c in columns])

def add_child(self, total, prefix=None, width=300):
"""
Expand Down Expand Up @@ -341,12 +353,12 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
for k, v in metrics.items():
if k == f"{metric_key_prefix}_loss":
values["Validation Loss"] = v
else:
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
values[name] = v
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
if name == "Loss":
# Single dataset
name = "Validation Loss"
values[name] = v
self.training_tracker.write_line(values)
self.training_tracker.remove_child()
self.prediction_bar = None
Expand Down

0 comments on commit 38fa6a1

Please sign in to comment.