diff --git a/pytorch_lightning/trainer/data_connector.py b/pytorch_lightning/trainer/data_connector.py
index 3e03a91cff5ce..5b7556a4bc07a 100644
--- a/pytorch_lightning/trainer/data_connector.py
+++ b/pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,25 @@ class DataConnector(object):
     def __init__(self, trainer):
         self.trainer = trainer
 
+    def prepare_data(self, model):
+        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+        if self.can_prepare_data():
+            if self.trainer.datamodule is not None:
+                self.trainer.datamodule.prepare_data()
+            model.prepare_data()
+            self.trainer._is_data_prepared = True
+
+    def can_prepare_data(self):
+        should_call_dm_prepare_data = True
+        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
+            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data
+
+        if self.trainer.prepare_data_per_node:
+            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
+        else:
+            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
+
     def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloader, LightningDataModule):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b4bbc43d96345..a0878cedeee56 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -944,7 +944,7 @@ def tune(
         self.call_hook('on_fit_start', model)
 
         # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)
 
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
@@ -1014,7 +1014,7 @@ def fit(
         self.call_hook('on_fit_start', model)
 
         # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)
 
         # set testing if set in environ
         self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
@@ -1056,15 +1056,6 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)
 
-    def prepare_data(self, model):
-        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-        if self.can_prepare_data():
-            if self.datamodule is not None:
-                self.datamodule.prepare_data()
-            model.prepare_data()
-            self._is_data_prepared = True
-
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1108,16 +1099,6 @@ def select_accelerator(self):
 
         return accelerator_backend
 
-    def can_prepare_data(self):
-        should_call_dm_prepare_data = True
-        if self.datamodule is not None and is_overridden('prepare_data', self.datamodule):
-            should_call_dm_prepare_data = not self.datamodule.has_prepared_data
-
-        if self.prepare_data_per_node:
-            return self.local_rank == 0 and should_call_dm_prepare_data
-        else:
-            return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data
-
     def setup_training(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index c6ca3943740c6..b00587b0616c3 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -23,26 +23,26 @@ def test_can_prepare_data(tmpdir):
     # local rank = 0   (True)
     trainer.prepare_data_per_node = True
     trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()
 
     # local rank = 1   (False)
     trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
 
     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0   (True)
     trainer.prepare_data_per_node = False
     trainer.node_rank = 0
     trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()
 
     # global rank = 1   (False)
     trainer.node_rank = 1
     trainer.local_rank = 0
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
     trainer.node_rank = 0
     trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
 
     # 2 dm
     # prepar per node = True
@@ -54,17 +54,17 @@ def test_can_prepare_data(tmpdir):
     # has been called
     # False
     dm._has_prepared_data = True
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
 
     # has not been called
     # True
     dm._has_prepared_data = False
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()
 
     # is_overridden prepare data = False
     # True
     dm.prepare_data = None
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()
 
 
 def test_base_datamodule(tmpdir):
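
For reference, a minimal standalone sketch of the gating rule that the relocated can_prepare_data encodes: with prepare_data_per_node=True only local_rank 0 on each node prepares data, otherwise only the single process with node_rank 0 and local_rank 0 does, and a DataModule that has already run prepare_data is skipped. The FakeDataModule class and the free-standing can_prepare_data function below are hypothetical stand-ins for illustration only, not part of the Lightning API.

class FakeDataModule:
    # hypothetical stand-in for a LightningDataModule's prepare_data bookkeeping
    def __init__(self, has_prepared_data=False, overrides_prepare_data=True):
        self.has_prepared_data = has_prepared_data
        self.overrides_prepare_data = overrides_prepare_data


def can_prepare_data(prepare_data_per_node, node_rank, local_rank, datamodule=None):
    # skip the datamodule's prepare_data if it has already been called
    should_call_dm_prepare_data = True
    if datamodule is not None and datamodule.overrides_prepare_data:
        should_call_dm_prepare_data = not datamodule.has_prepared_data

    if prepare_data_per_node:
        # each node prepares its own data: only local_rank 0 on every node
        return local_rank == 0 and should_call_dm_prepare_data
    # otherwise a single process prepares for the whole cluster
    return node_rank == 0 and local_rank == 0 and should_call_dm_prepare_data


if __name__ == "__main__":
    # per-node preparation: rank 0 on any node may prepare
    assert can_prepare_data(True, node_rank=1, local_rank=0)
    assert not can_prepare_data(True, node_rank=1, local_rank=1)

    # cluster-wide preparation: only node 0 / local rank 0 may prepare
    assert not can_prepare_data(False, node_rank=1, local_rank=0)
    assert can_prepare_data(False, node_rank=0, local_rank=0)

    # a datamodule that already ran prepare_data is not prepared again
    dm = FakeDataModule(has_prepared_data=True)
    assert not can_prepare_data(True, node_rank=0, local_rank=0, datamodule=dm)
    print("all gating checks passed")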