From 90f641af0d509645ecd679d00f1213f68d4a44ad Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 27 Jun 2020 15:08:22 -0400 Subject: [PATCH] fixes logger crash on ddp (#2388) * remove warnings * remove warnings * remove warnings * remove warnings * remove warnings * remove warnings * remove warnings * remove warnings * remove warnings * remove warnings --- pl_examples/models/lightning_template.py | 3 --- pytorch_lightning/trainer/callback_config.py | 9 ++------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/pl_examples/models/lightning_template.py b/pl_examples/models/lightning_template.py index af2f29f79a032..c79405c04bb13 100644 --- a/pl_examples/models/lightning_template.py +++ b/pl_examples/models/lightning_template.py @@ -150,15 +150,12 @@ def setup(self, stage): self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform) def train_dataloader(self): - log.info('Training data loader called.') return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4) def val_dataloader(self): - log.info('Validation data loader called.') return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4) def test_dataloader(self): - log.info('Test data loader called.') return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4) @staticmethod diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 5e490a106826b..80088db362de4 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -42,7 +42,7 @@ def configure_checkpoint_callback(self): ckpt_path = self.default_root_dir if self.checkpoint_callback: # init a default one - if self.logger is not None: + if self.logger is not None and self.logger.experiment is not None: save_dir = (getattr(self.logger, 'save_dir', None) or getattr(self.logger, '_save_dir', None) or self.default_root_dir) @@ -53,12 +53,7 @@ def configure_checkpoint_callback(self): version = self.logger.version if isinstance( self.logger.version, str) else f'version_{self.logger.version}' - ckpt_path = os.path.join( - save_dir, - self.logger.name, - version, - "checkpoints" - ) + ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints") else: ckpt_path = os.path.join(self.default_root_dir, "checkpoints")