ref: added data connector #3285

Merged · 10 commits · Aug 31, 2020
8 changes: 8 additions & 0 deletions docs/source/converting.rst
@@ -1,3 +1,11 @@
.. testsetup:: *

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.trainer.trainer import Trainer

.. _converting:

**************************************
How to organize PyTorch into Lightning
**************************************
18 changes: 6 additions & 12 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -15,20 +15,14 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden


class ConfigValidator(object):

def __init__(self, trainer):
self.trainer = trainer

def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def verify_loop_configurations(self, model: LightningModule):
r"""
Checks that the model is configured correctly before training or testing is started.
@@ -48,7 +42,7 @@ def __verify_train_loop_configuration(self, model):
# -----------------------------------
# verify model has a training step
# -----------------------------------
has_training_step = self.trainer.is_overridden('training_step', model)
has_training_step = is_overridden('training_step', model)
if not has_training_step:
raise MisconfigurationException(
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
@@ -58,7 +52,7 @@ def __verify_train_loop_configuration(self, model):
# -----------------------------------
# verify model has a train dataloader
# -----------------------------------
has_train_dataloader = self.trainer.is_overridden('train_dataloader', model)
has_train_dataloader = is_overridden('train_dataloader', model)
if not has_train_dataloader:
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
@@ -68,7 +62,7 @@ def __verify_train_loop_configuration(self, model):
# -----------------------------------
# verify model has optimizer
# -----------------------------------
has_optimizers = self.trainer.is_overridden('configure_optimizers', model)
has_optimizers = is_overridden('configure_optimizers', model)
if not has_optimizers:
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
@@ -83,8 +77,8 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name):
if eval_loop_name == 'validation':
loader_name = 'val_dataloader'

has_loader = self.trainer.is_overridden(loader_name, model)
has_step = self.trainer.is_overridden(step_name, model)
has_loader = is_overridden(loader_name, model)
has_step = is_overridden(step_name, model)

if has_loader and not has_step:
rank_zero_warn(
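
For reference, a minimal sketch (not part of this PR) of how the check above surfaces to a user, assuming a toy module that deliberately omits `training_step()`; the class name and dummy dataset are illustrative, and only the exception message comes from the validator:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class NoTrainingStepModel(LightningModule):
    # illustrative model: defines a dataloader and optimizer but no training_step()
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4), torch.zeros(8).long()), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


try:
    Trainer(max_epochs=1).fit(NoTrainingStepModel())
except MisconfigurationException as err:
    print(err)  # expected: "No `training_step()` method defined. ..."
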
100 changes: 100 additions & 0 deletions pytorch_lightning/trainer/data_connector.py
@@ -0,0 +1,100 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import List, Union
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.model_utils import is_overridden


class DataConnector(object):

def __init__(self, trainer):
self.trainer = trainer

def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None

self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)

# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule, 'fit')

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
model.test_dataloader = _PatchDataLoader(test_dataloaders)

def attach_datamodule(self, model, datamodule, stage):

# We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

# If we have a datamodule, attach necessary hooks + dataloaders
if datamodule:

# Override loader hooks
if is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if is_overridden('val_dataloader', datamodule):
model.val_dataloader = datamodule.val_dataloader
if is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader

# Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
if is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device

self.trainer.datamodule = datamodule


class _PatchDataLoader(object):
r"""
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.

Args:
dataloader: Dataloader object to return when called.

"""

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader
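
A minimal sketch of the patching mechanism, assuming a throwaway model class (`DummyModel` and the dummy loader are illustrative, not part of this PR); it shows what `DataConnector.attach_dataloaders()` does with a loader passed to `fit()` and the `patch_loader_code` marker that `is_overridden()` later inspects:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.data_connector import _PatchDataLoader


class DummyModel(LightningModule):
    # illustrative placeholder; only the dataloader hook matters here
    def forward(self, x):
        return x


loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
model = DummyModel()

# this is what attach_dataloaders() does for a train_dataloader passed to fit():
model.train_dataloader = _PatchDataLoader(loader)

# the patched hook simply returns the loader that was passed in ...
assert model.train_dataloader() is loader

# ... and the stringified __code__ marker is how is_overridden() detects that the hook
# was patched rather than implemented on the model or inherited from the base class
print(model.train_dataloader.patch_loader_code)

Making the patch a small callable class rather than a lambda is what keeps the patched model picklable, as the docstring above points out.
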
104 changes: 21 additions & 83 deletions pytorch_lightning/trainer/trainer.py
@@ -55,6 +55,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import is_remote_path
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.data_connector import DataConnector

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -607,6 +608,7 @@ def __init__(
# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.accelerator_backend = None

# loops
@@ -974,18 +976,8 @@ def fit(
"""
results = None

# bind logger and other properties
self.copy_trainer_model_properties(model)

# clean hparams
if hasattr(model, 'hparams'):
parsing.clean_namespace(model.hparams)

# links data to the trainer
self.attach_data(model, train_dataloader, val_dataloaders, datamodule)

# check that model is configured correctly
self.config_validator.verify_loop_configurations(model)
# setup data, etc...
self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

# hook
self.call_hook('on_fit_start', model)
@@ -1031,6 +1023,20 @@ def fit(
# used for testing or when we need to know that training succeeded
return results or 1

def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
# bind logger and other properties
self.copy_trainer_model_properties(model)

# clean hparams
if hasattr(model, 'hparams'):
parsing.clean_namespace(model.hparams)

# links data to the trainer
self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

# check that model is configured correctly
self.config_validator.verify_loop_configurations(model)

def prepare_data(self, model):
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
@@ -1040,18 +1046,6 @@ def prepare_data(self, model):
model.prepare_data()
self._is_data_prepared = True

def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None

self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)

# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
self.__attach_datamodule(model, datamodule, 'fit')

def select_accelerator(self):
# SLURM ddp
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1105,40 +1099,6 @@ def can_prepare_data(self):
else:
return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data

def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
model.test_dataloader = _PatchDataLoader(test_dataloaders)

def __attach_datamodule(self, model, datamodule, stage):

# We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

# If we have a datamodule, attach necessary hooks + dataloaders
if datamodule:

# Override loader hooks
if self.is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if self.is_overridden('val_dataloader', datamodule):
model.val_dataloader = datamodule.val_dataloader
if self.is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader

# Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
if self.is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device

self.datamodule = datamodule

def run_pretrain_routine(self, model: LightningModule):
"""Sanity check a few things before starting actual training.

@@ -1348,7 +1308,7 @@ def test(
)

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.__attach_datamodule(model or self.get_model(), datamodule, 'test')
self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

if model is not None:
results = self.__test_given_model(model, test_dataloaders)
@@ -1386,7 +1346,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):

# attach dataloaders
if test_dataloaders is not None:
self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

# run tests
self.tested_ckpt_path = ckpt_path
@@ -1408,7 +1368,7 @@ def __test_given_model(self, model, test_dataloaders):

# attach data
if test_dataloaders is not None:
self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

# run test
# sets up testing so we short circuit to eval
@@ -1472,28 +1432,6 @@ def call_hook(self, hook_name, *args, **kwargs):
return output


class _PatchDataLoader(object):
r"""
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.

Args:
dataloader: Dataloader object to return when called.

"""

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
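
To make the refactor concrete, here is a hedged sketch of the `fit()` call patterns that now go through `DataConnector.attach_data()`; the `BoringModel` and `BoringDataModule` classes and their data are illustrative stand-ins, not code from this PR:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _toy_loader():
    return DataLoader(TensorDataset(torch.randn(8, 4), torch.zeros(8).long()), batch_size=4)


class BoringDataModule(LightningDataModule):
    # illustrative datamodule: only train_dataloader is implemented
    def train_dataloader(self):
        return _toy_loader()


class BoringModel(LightningModule):
    # illustrative model with the minimum hooks the validator requires
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dm = BoringDataModule()

# the datamodule may be passed positionally (attach_data() reassigns it from
# train_dataloader to datamodule) or by keyword -- both end up on trainer.datamodule
Trainer(max_epochs=1).fit(BoringModel(), dm)
Trainer(max_epochs=1).fit(BoringModel(), datamodule=dm)

# mixing explicit dataloaders with a datamodule is rejected by the connector
try:
    Trainer(max_epochs=1).fit(BoringModel(), train_dataloader=_toy_loader(), datamodule=dm)
except MisconfigurationException as err:
    print(err)
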
29 changes: 29 additions & 0 deletions pytorch_lightning/utilities/model_utils.py
@@ -0,0 +1,29 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule


def is_overridden(method_name: str, model: LightningModule) -> bool:
# if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
# TODO - refactor this function to accept model_name, instance, parent so it makes more sense
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule

# assert model, 'no model passes'

if not hasattr(model, method_name):
# in case of calling deprecated method
return False

instance_attr = getattr(model, method_name)
if not instance_attr:
return False
super_attr = getattr(super_object, method_name)

# when code pointers are different, it was implemented
if hasattr(instance_attr, 'patch_loader_code'):
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overridden = instance_attr.__code__ is not super_attr.__code__
return is_overridden
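
Finally, a small usage sketch of `is_overridden()`, assuming two toy subclasses (`PlainModel` and `StepModel` are illustrative names): the helper reports False when a hook is only inherited from `LightningModule` and True when the user actually implements it.

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.model_utils import is_overridden


class PlainModel(LightningModule):
    # illustrative model that defines no hooks of its own
    def forward(self, x):
        return x


class StepModel(LightningModule):
    # illustrative model that overrides training_step
    def forward(self, x):
        return x

    def training_step(self, batch, batch_idx):
        return batch


assert not is_overridden('training_step', PlainModel())  # inherited stub -> False
assert is_overridden('training_step', StepModel())       # user implementation -> True

The same helper also accepts a `LightningDataModule` instance (it swaps the base class it compares against), which is why the `DataConnector.attach_datamodule()` added above can use it on datamodules as well.
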