
Update notebook.py to support multi eval datasets #25796

Merged
7 commits merged into huggingface:main on Sep 15, 2023

Conversation

matrix1001
Contributor

Fix key error when using multiple evaluation datasets

Code triggering the error

Any code that passes multiple evaluation datasets via eval_dataset triggers the error.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset={
        'valid': valid_dataset,
        'test': test_dataset,
    },
    compute_metrics=compute_metrics,
)

Before fix

(screenshot)

Here's the detailed error message:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 80
     68 model = BertForRelationExtraction.from_pretrained('bert-base-cased', config=config)
     69 trainer = Trainer(
     70     model=model,
     71     args=training_args,
   (...)
     77     compute_metrics=compute_metrics
     78 )
---> 80 trainer.train()

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/trainer.py:1664, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1659     self.model_wrapped = self.model
   1661 inner_training_loop = find_executable_batch_size(
   1662     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1663 )
-> 1664 return inner_training_loop(
   1665     args=args,
   1666     resume_from_checkpoint=resume_from_checkpoint,
   1667     trial=trial,
   1668     ignore_keys_for_eval=ignore_keys_for_eval,
   1669 )

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/trainer.py:2019, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2016     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2017     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2019     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2020 else:
   2021     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/trainer.py:2293, in Trainer._maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2291     metrics = {}
   2292     for eval_dataset_name, eval_dataset in self.eval_dataset.items():
-> 2293         dataset_metrics = self.evaluate(
   2294             eval_dataset=eval_dataset,
   2295             ignore_keys=ignore_keys_for_eval,
   2296             metric_key_prefix=f"eval_{eval_dataset_name}",
   2297         )
   2298         metrics.update(dataset_metrics)
   2299 else:

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/trainer.py:3057, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3053 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   3054     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
   3055     xm.master_print(met.metrics_report())
-> 3057 self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
   3059 self._memory_tracker.stop_and_update_metrics(output.metrics)
   3061 return output.metrics

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/trainer_callback.py:379, in CallbackHandler.on_evaluate(self, args, state, control, metrics)
    377 def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
    378     control.should_evaluate = False
--> 379     return self.call_event("on_evaluate", args, state, control, metrics=metrics)

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/trainer_callback.py:397, in CallbackHandler.call_event(self, event, args, state, control, **kwargs)
    395 def call_event(self, event, args, state, control, **kwargs):
    396     for callback in self.callbacks:
--> 397         result = getattr(callback, event)(
    398             args,
    399             state,
    400             control,
    401             model=self.model,
    402             tokenizer=self.tokenizer,
    403             optimizer=self.optimizer,
    404             lr_scheduler=self.lr_scheduler,
    405             train_dataloader=self.train_dataloader,
    406             eval_dataloader=self.eval_dataloader,
    407             **kwargs,
    408         )
    409         # A Callback can skip the return of `control` if it doesn't change it.
    410         if result is not None:

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/utils/notebook.py:350, in NotebookProgressCallback.on_evaluate(self, args, state, control, metrics, **kwargs)
    348         name = " ".join([part.capitalize() for part in splits[1:]])
    349         values[name] = v
--> 350 self.training_tracker.write_line(values)
    351 self.training_tracker.remove_child()
    352 self.prediction_bar = None

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/utils/notebook.py:244, in NotebookTrainingTracker.write_line(self, values)
    242             columns.append(key)
    243     self.inner_table[0] = columns
--> 244 self.inner_table.append([values[c] for c in columns])

File ~/anaconda3/envs/pytorch2/lib/python3.11/site-packages/transformers/utils/notebook.py:244, in <listcomp>(.0)
    242             columns.append(key)
    243     self.inner_table[0] = columns
--> 244 self.inner_table.append([values[c] for c in columns])

KeyError: 'Valid Accuracy'
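
For context, the KeyError comes from the way NotebookProgressCallback turns metric keys into column names, visible in the frames above: it splits the key on "_", drops the leading "eval", and capitalizes the rest. A simplified reconstruction of that step:

# "eval_valid_accuracy" -> "Valid Accuracy", a column that was never registered in the
# predefined header ("Epoch"/"Step", "Training Loss", "Validation Loss"), so the
# lookup values[c] in write_line fails.
metric_key = "eval_valid_accuracy"
splits = metric_key.split("_")
name = " ".join(part.capitalize() for part in splits[1:])
print(name)  # Valid Accuracy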

After fix

(screenshot)

Some explanation

I removed the predefined "Validation Loss" key. However, this gives the following result when there is only one eval_dataset:
(screenshot)
I think we don't have to name it "Validation Loss", do we?

My modification allows dynamic columns and updates the existing values when multiple calls to NotebookProgressCallback.on_evaluate correspond to the same epoch or step.
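
For readers skimming the diff, the idea can be sketched roughly as follows. This is an illustrative reconstruction, not the exact code merged in this PR; the standalone write_line helper below stands in for the real NotebookTrainingTracker.write_line method:

def write_line(inner_table, values):
    # inner_table is a list of rows whose first row is the header; values maps
    # column name -> value for one on_log / on_evaluate call.
    if not inner_table:
        inner_table.append(list(values.keys()))
        inner_table.append(list(values.values()))
        return
    columns = inner_table[0]
    for key in values:  # dynamically grow the header when new metric columns appear
        if key not in columns:
            columns.append(key)
    last_row = inner_table[-1]
    if len(inner_table) > 1 and last_row[0] == values.get(columns[0]):
        # Same Epoch/Step as the previous call: merge the new metrics into that row.
        merged = dict(zip(columns, last_row))
        merged.update(values)
        inner_table[-1] = [merged.get(c, "No log") for c in columns]
    else:
        inner_table.append([values.get(c, "No log") for c in columns])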

fix multi eval datasets
using `black` to reformat
@matrix1001
Contributor Author

Note that my last commit only addresses the workflow check. I don't think it's necessary.

Collaborator

@ArthurZucker left a comment


Hey! Thanks for opening the fix. I think there was an issue related to this; if you could link it, that would be great!
To fix the CIs you can run `make style`. (That will also make the required changes easier to review!)

It already looks nice; it would be better if we could keep "Validation Loss" in the single-dataset case!

Comment on lines 144 to 148
elif (
    force_update
    or self.first_calls > 0
    or value >= min(self.last_value + self.wait_for, self.total)
):
Collaborator


This will be fixed by `make style` (same for the changes that break lines).

@ArthurZucker ArthurZucker requested a review from muellerzr August 29, 2023 12:38
@muellerzr muellerzr linked an issue Aug 29, 2023 that may be closed by this pull request
Contributor

@muellerzr left a comment


Thanks for the improvement! Looks good to me, but as mentioned, let's see if we can keep "Validation Loss" as a column when we have a single dataset :)

Comment on lines 283 to 284
if args.evaluation_strategy != IntervalStrategy.NO:
    column_names.append("Validation Loss")
Contributor


If we don't have multiple datasets, let's keep this in please :)
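
One way to honour this (purely a sketch of the suggestion, not the code that ended up being merged) is to keep the classic column name when the metrics come from a single, unnamed eval dataset, since only then is the metric key plain eval_loss. The metrics and values names below mirror the on_evaluate context quoted in this review:

# Single dataset -> the key is "eval_loss", so keep the "Validation Loss" column;
# named datasets -> keys like "eval_valid_loss" map to per-dataset columns instead.
if "eval_loss" in metrics:
    values["Validation Loss"] = metrics["eval_loss"]
else:
    for key, value in metrics.items():
        if key.startswith("eval_") and key.endswith("_loss"):
            name = " ".join(part.capitalize() for part in key.split("_")[1:])  # "Valid Loss"
            values[name] = value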

@@ -320,7 +362,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):

def on_evaluate(self, args, state, control, metrics=None, **kwargs):
    if self.training_tracker is not None:
        values = {"Training Loss": "No log", "Validation Loss": "No log"}
Contributor


Same here

src/transformers/utils/notebook.py (resolved conversation)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker
Collaborator

@matrix1001 do you want me to merge this or do you still have modifications to include?

@matrix1001
Contributor Author

@ArthurZucker merge.

@ArthurZucker ArthurZucker merged commit ebd21e9 into huggingface:main Sep 15, 2023
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* Update notebook.py

fix multi eval datasets

* Update notebook.py

* Update notebook.py

using `black` to reformat

* Update notebook.py

support Validation Loss

* Update notebook.py

reformat

* Update notebook.py
@puneetdabulya

import evaluate
import numpy as np

f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
accuracy_metric = evaluate.load("accuracy")
average_method = 'weighted'

def compute_metrics(eval_pred):
    results = {}
    predictions = np.argmax(eval_pred.predictions, axis=1)
    labels = eval_pred.label_ids
    results.update(f1_metric.compute(predictions=predictions, references=labels, average=average_method))
    results.update(precision_metric.compute(predictions=predictions, references=labels, average=average_method))
    results.update(recall_metric.compute(predictions=predictions, references=labels, average=average_method))
    results.update(accuracy_metric.compute(predictions=predictions, references=labels))
    return results

With this code, I am still seeing the following on the master branch:

----> 1 train_results = trainer.train()

File ~/.local/lib/python3.8/site-packages/transformers/trainer.py:1591, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1589         hf_hub_utils.enable_progress_bars()
   1590 else:
-> 1591     return inner_training_loop(
   1592         args=args,
   1593         resume_from_checkpoint=resume_from_checkpoint,
   1594         trial=trial,
   1595         ignore_keys_for_eval=ignore_keys_for_eval,
   1596     )

File ~/.local/lib/python3.8/site-packages/transformers/trainer.py:1999, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1996     self.control.should_training_stop = True
   1998 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 1999 self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2001 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   2002     if is_torch_tpu_available():
   2003         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File ~/.local/lib/python3.8/site-packages/transformers/trainer.py:2339, in Trainer._maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2336         self.lr_scheduler.step(metrics[metric_to_check])
   2338 if self.control.should_save:
-> 2339     self._save_checkpoint(model, trial, metrics=metrics)
   2340     self.control = self.callback_handler.on_save(self.args, self.state, self.control)

File ~/.local/lib/python3.8/site-packages/transformers/trainer.py:2458, in Trainer._save_checkpoint(self, model, trial, metrics)
   2456 if not metric_to_check.startswith("eval_"):
   2457     metric_to_check = f"eval_{metric_to_check}"
-> 2458 metric_value = metrics[metric_to_check]
   2460 operator = np.greater if self.args.greater_is_better else np.less
   2461 if (
   2462     self.state.best_metric is None
   2463     or self.state.best_model_checkpoint is None
   2464     or operator(metric_value, self.state.best_metric)
   2465 ):

KeyError: 'eval_accuracy'

@matrix1001
Contributor Author

@puneetdabulya This PR only fixes notebook.py. You may need to fix it in another PR.
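
For anyone hitting the same KeyError: 'eval_accuracy': assuming multiple eval datasets are being passed (as in this PR's example), each dataset's metrics carry the dataset name in the key (eval_valid_accuracy, eval_test_accuracy), so the best-model tracking in trainer.py cannot find a plain eval_accuracy. A workaround sketch, with "valid" as an assumed dataset name, is to point metric_for_best_model at one specific dataset:

from transformers import TrainingArguments

# Not part of this PR: with eval_dataset={"valid": ..., "test": ...} the metric key
# for best-model selection must name the dataset explicitly. "valid" is illustrative.
training_args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_valid_accuracy",  # instead of just "accuracy"
    greater_is_better=True,
)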

blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
Development

Successfully merging this pull request may close these issues.

Trainer explodes with multiple validation sets used
5 participants