From c27c5ef2b6569fbce8849b52a6a0547ad66e97af Mon Sep 17 00:00:00 2001 From: Michal Futrega Date: Tue, 9 Jan 2024 17:55:36 +0100 Subject: [PATCH] Enhance flexibility by passing callbacks as method argument (#8015) * Enhance flexibility by passing callbacks as method argument Signed-off-by: Michal Futrega * Set callbacks default to None Signed-off-by: Michal Futrega --------- Signed-off-by: Michal Futrega --- nemo/collections/nlp/parts/megatron_trainer_builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index ff34f78fd183..c58e0be4a508 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -118,10 +118,12 @@ def _plugins(self) -> list: return plugins - def create_trainer(self) -> Trainer: + def create_trainer(self, callbacks=None) -> Trainer: strategy = self._training_strategy() plugins = self._plugins() - return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=[CustomProgressBar()]) + if callbacks is None: + callbacks = [CustomProgressBar()] + return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks) class MegatronBertTrainerBuilder(MegatronTrainerBuilder):