diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index 69956129bdde6..b1ebada717476 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -95,10 +95,10 @@ def _plugins(self) -> list: return plugins - def create_trainer(self) -> Trainer: + def create_trainer(self, callbacks=(CustomProgressBar(),)) -> Trainer: strategy = self._training_strategy() plugins = self._plugins() - return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=[CustomProgressBar()]) + return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks) class MegatronBertTrainerBuilder(MegatronTrainerBuilder):