ref: move prepare_data to data connector (#3307)
* ref: moved argparse code to central class

williamFalcon authored Sep 1, 2020
1 parent 3910ad0 commit 7d57f8d
Showing 3 changed files with 29 additions and 29 deletions.
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,25 @@ class DataConnector(object):
     def __init__(self, trainer):
         self.trainer = trainer

+    def prepare_data(self, model):
+        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+        if self.can_prepare_data():
+            if self.trainer.datamodule is not None:
+                self.trainer.datamodule.prepare_data()
+            model.prepare_data()
+            self.trainer._is_data_prepared = True
+
+    def can_prepare_data(self):
+        should_call_dm_prepare_data = True
+        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
+            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data
+
+        if self.trainer.prepare_data_per_node:
+            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
+        else:
+            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
+
     def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloader, LightningDataModule):
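For reference, a minimal stand-alone sketch of the rank gating moved into the connector above. The datamodule bookkeeping is omitted and a SimpleNamespace stands in for the Trainer, so this illustrates the rule rather than Lightning's API:

from types import SimpleNamespace

def can_prepare_data(trainer):
    # Re-statement of DataConnector.can_prepare_data above,
    # minus the datamodule.has_prepared_data bookkeeping.
    if trainer.prepare_data_per_node:
        # every node prepares its own copy: gate on local rank only
        return trainer.local_rank == 0
    # shared filesystem: only global rank zero (node 0, local 0) prepares
    return trainer.node_rank == 0 and trainer.local_rank == 0

# prepare_data_per_node=True: local rank 0 of *every* node prepares
assert can_prepare_data(SimpleNamespace(prepare_data_per_node=True, node_rank=1, local_rank=0))
# prepare_data_per_node=False: only node_rank=0 / local_rank=0 prepares
assert not can_prepare_data(SimpleNamespace(prepare_data_per_node=False, node_rank=1, local_rank=0))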
23 changes: 2 additions & 21 deletions pytorch_lightning/trainer/trainer.py
@@ -944,7 +944,7 @@ def tune(
         self.call_hook('on_fit_start', model)

         # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)

         # Run auto batch size scaling
         if self.auto_scale_batch_size:
@@ -1014,7 +1014,7 @@ def fit(
         self.call_hook('on_fit_start', model)

         # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)

         # set testing if set in environ
         self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
@@ -1056,15 +1056,6 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)

-    def prepare_data(self, model):
-        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-        if self.can_prepare_data():
-            if self.datamodule is not None:
-                self.datamodule.prepare_data()
-            model.prepare_data()
-            self._is_data_prepared = True
-
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1108,16 +1099,6 @@ def select_accelerator(self):

         return accelerator_backend

-    def can_prepare_data(self):
-        should_call_dm_prepare_data = True
-        if self.datamodule is not None and is_overridden('prepare_data', self.datamodule):
-            should_call_dm_prepare_data = not self.datamodule.has_prepared_data
-
-        if self.prepare_data_per_node:
-            return self.local_rank == 0 and should_call_dm_prepare_data
-        else:
-            return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data
-
     def setup_training(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
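The net effect on trainer.py is pure delegation: fit() and tune() now call into the data connector instead of hosting the prepare_data logic themselves. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the real Trainer/LightningModule API:

class _DataConnector:
    """Stand-in for pytorch_lightning.trainer.data_connector.DataConnector."""

    def __init__(self, trainer):
        self.trainer = trainer

    def prepare_data(self, model):
        # rank gating (can_prepare_data) omitted for brevity
        model.prepare_data()
        self.trainer._is_data_prepared = True


class _Trainer:
    """Stand-in Trainer: the hook logic moves, the call site in fit() just delegates."""

    def __init__(self):
        self._is_data_prepared = False
        self.data_connector = _DataConnector(self)

    def fit(self, model):
        # before this commit: self.prepare_data(model)
        self.data_connector.prepare_data(model)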
16 changes: 8 additions & 8 deletions tests/core/test_datamodules.py
@@ -23,26 +23,26 @@ def test_can_prepare_data(tmpdir):
     # local rank = 0 (True)
     trainer.prepare_data_per_node = True
     trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

     # local rank = 1 (False)
     trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0 (True)
     trainer.prepare_data_per_node = False
     trainer.node_rank = 0
     trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

     # global rank = 1 (False)
     trainer.node_rank = 1
     trainer.local_rank = 0
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
     trainer.node_rank = 0
     trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

     # 2 dm
     # prepar per node = True
@@ -54,17 +54,17 @@
     # has been called
     # False
     dm._has_prepared_data = True
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

     # has not been called
     # True
     dm._has_prepared_data = False
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

     # is_overridden prepare data = False
     # True
     dm.prepare_data = None
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()


 def test_base_datamodule(tmpdir):
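If the rank cases in test_can_prepare_data grow further, they could also be expressed as a parametrized test. A sketch along those lines, not part of this commit; it assumes, as the test above does, that a Trainer is constructed with default_root_dir=tmpdir and that node_rank/local_rank can be assigned directly:

import pytest
from pytorch_lightning import Trainer


# Hypothetical parametrized restatement of the "no datamodule" cases above.
@pytest.mark.parametrize(
    "per_node, node_rank, local_rank, expected",
    [
        (True, 0, 0, True),    # per-node prep: local rank 0 prepares
        (True, 0, 1, False),   # per-node prep: non-zero local rank skips
        (False, 0, 0, True),   # shared prep: only global rank 0 prepares
        (False, 1, 0, False),
        (False, 0, 1, False),
    ],
)
def test_can_prepare_data_ranks(tmpdir, per_node, node_rank, local_rank, expected):
    trainer = Trainer(default_root_dir=tmpdir)
    trainer.prepare_data_per_node = per_node
    trainer.node_rank = node_rank
    trainer.local_rank = local_rank
    assert trainer.data_connector.can_prepare_data() == expected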
