Skip to content

Commit

Permalink
fixes logger crash on ddp (#2388)
Browse files Browse the repository at this point in the history
* remove warnings

* remove warnings

* remove warnings

* remove warnings

* remove warnings

* remove warnings

* remove warnings

* remove warnings

* remove warnings

* remove warnings
  • Loading branch information
williamFalcon authored Jun 27, 2020
1 parent 41f5df1 commit 90f641a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 10 deletions.
3 changes: 0 additions & 3 deletions pl_examples/models/lightning_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,12 @@ def setup(self, stage):
self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform)

def train_dataloader(self):
log.info('Training data loader called.')
return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

def val_dataloader(self):
log.info('Validation data loader called.')
return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

def test_dataloader(self):
log.info('Test data loader called.')
return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

@staticmethod
Expand Down
9 changes: 2 additions & 7 deletions pytorch_lightning/trainer/callback_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def configure_checkpoint_callback(self):
ckpt_path = self.default_root_dir
if self.checkpoint_callback:
# init a default one
if self.logger is not None:
if self.logger is not None and self.logger.experiment is not None:
save_dir = (getattr(self.logger, 'save_dir', None) or
getattr(self.logger, '_save_dir', None) or
self.default_root_dir)
Expand All @@ -53,12 +53,7 @@ def configure_checkpoint_callback(self):

version = self.logger.version if isinstance(
self.logger.version, str) else f'version_{self.logger.version}'
ckpt_path = os.path.join(
save_dir,
self.logger.name,
version,
"checkpoints"
)
ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints")
else:
ckpt_path = os.path.join(self.default_root_dir, "checkpoints")

Expand Down

0 comments on commit 90f641a

Please sign in to comment.