Attach data refactor and tuner bugs [4/n] #7258

Merged: 38 commits, Apr 30, 2021

Changes from all commits (38 commits):
395ce2a
`_fit_impl` refactor and types
carmocca Apr 27, 2021
65aa690
Fix return
carmocca Apr 27, 2021
b3f3314
Remove return docstring
carmocca Apr 27, 2021
e7c3657
Fixes
carmocca Apr 27, 2021
8060424
Fixes
carmocca Apr 27, 2021
c9a9ca9
Merge branch 'master' into use-fit-impl
carmocca Apr 27, 2021
67d5ca8
Undo results change
carmocca Apr 27, 2021
d7fbc5d
Revert changes for a separate PR
carmocca Apr 28, 2021
de7937e
WIP
carmocca Apr 28, 2021
449041c
Merge branch 'master' into attach-data-refactor
carmocca Apr 28, 2021
405082b
Progress
carmocca Apr 28, 2021
c255388
Deprecation messages
carmocca Apr 28, 2021
abc024f
Merge branch 'master' into attach-data-refactor
carmocca Apr 28, 2021
d2a54d6
Fixes
carmocca Apr 28, 2021
8641504
Move `copy_trainer_model_properties`
carmocca Apr 28, 2021
d8b0bf5
Code cleaning in preparation for 7258
carmocca Apr 28, 2021
1eeedac
Update CHANGELOG
carmocca Apr 28, 2021
5c62b3a
Fix test
carmocca Apr 28, 2021
ae89837
Merge branch 'changes-for-7258' into attach-data-refactor
carmocca Apr 28, 2021
d30467d
Update CHANGELOG
carmocca Apr 28, 2021
08e7a54
Update docs
carmocca Apr 28, 2021
236b9f4
Fix test?
carmocca Apr 28, 2021
35e6258
Merge branch 'master' into attach-data-refactor
carmocca Apr 29, 2021
e2d90c3
Fix docs
carmocca Apr 29, 2021
869c857
Dict return for trainer.tune
carmocca Apr 29, 2021
668e68e
Undo some changes
carmocca Apr 29, 2021
617e9e5
Undo some changes
carmocca Apr 29, 2021
d60eb20
Undo some changes
carmocca Apr 29, 2021
4502329
Apply suggestions from code review
carmocca Apr 29, 2021
f2ccf21
Undo deprecation
carmocca Apr 30, 2021
9cbc154
Fix docs
carmocca Apr 30, 2021
46a4e92
Fix docs
carmocca Apr 30, 2021
175e2cd
Apply suggestions from code review
carmocca Apr 30, 2021
23f2da2
Apply suggestions from code review
carmocca Apr 30, 2021
f2ab21c
Merge branch 'master' into attach-data-refactor
carmocca Apr 30, 2021
841381c
Update scale_batch_size mode docstring
carmocca Apr 30, 2021
ae9d7e0
Merge branch 'master' into attach-data-refactor
carmocca Apr 30, 2021
70a16ea
Merge branch 'master' into attach-data-refactor
carmocca Apr 30, 2021
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -133,6 +133,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BasePredictionWriter` callback to implement prediction saving ([#7127](https://github.com/PyTorchLightning/pytorch-lightning/pull/7127))


- Added `trainer.tune(scale_batch_size_kwargs, lr_find_kwargs)` arguments to configure the tuning algorithms ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258))


- Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241))


@@ -178,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937))


- `trainer.tune()` now returns the tuning result ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258))


- `LightningModule.from_datasets()` now accepts `IterableDataset` instances as training datasets. ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))


@@ -325,6 +331,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


- Fixed `trainer.tuner.{lr_find,scale_batch_size}` not setting the `Trainer` state properly ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258))


- Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879))


6 changes: 3 additions & 3 deletions docs/source/advanced/lr_finder.rst
@@ -73,9 +73,9 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn
trainer.tune(model)


If you want to inspect the results of the learning rate finder or just play around
with the parameters of the algorithm, this can be done by invoking the ``lr_find``
method of the trainer. A typical example of this would look like
You can also inspect the results of the learning rate finder or just play around
with the parameters of the algorithm. This can be done by invoking the
:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like:

.. code-block:: python

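For reference, the following is a minimal, self-contained sketch of the `Tuner.lr_find` usage the paragraph above refers to. It is not the collapsed example from the docs; the module, layer sizes, and dataloader are invented for illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Illustrative module exposing a tunable ``lr`` attribute."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


train_loader = DataLoader(
    TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,))), batch_size=16
)

model = TinyModel()
trainer = pl.Trainer(max_epochs=1)

# run the learning rate finder and inspect what it collected
lr_finder = trainer.tuner.lr_find(model, train_dataloader=train_loader)
print(lr_finder.results)           # lrs and losses recorded during the sweep
model.lr = lr_finder.suggestion()  # apply the suggested learning rate

trainer.fit(model, train_dataloader=train_loader)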
8 changes: 2 additions & 6 deletions docs/source/advanced/training_tricks.rst
@@ -112,7 +112,7 @@ search for batch sizes larger than the size of the training dataset.
to `.fit()`.

The scaling algorithm has a number of parameters that the user can control by
invoking the trainer method `.scale_batch_size` themself (see description below).
invoking the :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` method:

.. code-block:: python

@@ -123,7 +123,7 @@ invoking the trainer method `.scale_batch_size` themself (see description below)
# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)

# Override old batch size
# Override old batch size (this is done automatically)
model.hparams.batch_size = new_batch_size

# Fit as normal
@@ -142,10 +142,6 @@ The algorithm in short works by:
3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size`
4. Restore the initial state of model and trainer

.. autoclass:: pytorch_lightning.tuner.tuning.Tuner
:noindex:
:members: scale_batch_size

.. warning:: Batch size finder is not supported for DDP yet, it is coming soon.


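To make the algorithm description above concrete, here is a hedged, self-contained sketch of calling the batch size scaler directly. The model, dataset, and the `mode`/`max_trials` values are illustrative; the one thing the tuner relies on is that the dataloader is built from the `batch_size` attribute it searches over.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ScalableModel(pl.LightningModule):
    """Illustrative module whose dataloader is built from ``self.batch_size``."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)
        self.dataset = TensorDataset(torch.randn(4096, 32), torch.randint(0, 2, (4096,)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)


model = ScalableModel()
trainer = pl.Trainer(max_epochs=1)

# `mode` and `max_trials` are regular scale_batch_size parameters; the values are examples only
new_batch_size = trainer.tuner.scale_batch_size(model, mode="binsearch", max_trials=5)

# the tuner already writes the result back to `model.batch_size`,
# so this explicit override mirrors the docs snippet above and is optional
model.batch_size = new_batch_size
trainer.fit(model)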
6 changes: 3 additions & 3 deletions docs/source/api_references.rst
@@ -153,14 +153,14 @@ Trainer API
Tuner API
---------

.. currentmodule:: pytorch_lightning.tuner
.. currentmodule:: pytorch_lightning.tuner.tuning

.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst

batch_size_scaling
lr_finder
Tuner

Utilities API
-------------
2 changes: 1 addition & 1 deletion pytorch_lightning/core/datamodule.py
@@ -319,7 +319,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):

Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`LightningDataModule`.
parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid DataModule arguments.

37 changes: 19 additions & 18 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -17,7 +17,6 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -68,24 +67,26 @@ def can_prepare_data(self):
else:
return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data

def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None

self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)

def attach_data(
self,
model: 'pl.LightningModule',
train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional['pl.LightningDataModule'] = None
) -> None:
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)
self.attach_dataloaders(
model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
test_dataloaders=test_dataloaders,
predict_dataloaders=predict_dataloaders,
)
self.attach_datamodule(model, datamodule=datamodule)
# set local properties on the model
self.trainer.model_connector.copy_trainer_model_properties(model)
carmocca marked this conversation as resolved.

def attach_dataloaders(
self,
95 changes: 61 additions & 34 deletions pytorch_lightning/trainer/trainer.py
@@ -56,8 +56,9 @@
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -409,21 +410,15 @@ def __init__(
# Callback system
self.on_init_end()

def _run(
self,
model: LightningModule,
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# set local properties on the model
self.model_connector.copy_trainer_model_properties(model)
def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)

# ----------------------------
# LINK DATA
# ----------------------------
# setup data, etc...
self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
self.config_validator.verify_loop_configurations(model)

# attach model log function to callback
self.callback_connector.attach_model_logging_functions(model)

# hook
self.data_connector.prepare_data(model)
@@ -848,14 +843,29 @@ def fit(
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

datamodule: A instance of :class:`LightningDataModule`.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
Trainer._log_api_event("fit")

self.state = TrainerState.FITTING
self.training = True

self._run(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`'
)

# links data to the trainer
self.data_connector.attach_data(
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self._run(model)

assert self.state.stopped
self.training = False
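
As a usage note, the block above means `fit` accepts a datamodule either by keyword or positionally in the second slot, but never together with explicit dataloaders. A hedged sketch follows; `TinyModel` is the illustrative module from the `lr_find` sketch earlier, and the datamodule is invented.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class RandomDataModule(pl.LightningDataModule):
    # minimal datamodule for illustration only
    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)


model = TinyModel()   # illustrative LightningModule from the earlier sketch
dm = RandomDataModule()
trainer = pl.Trainer(max_epochs=1)

trainer.fit(model, datamodule=dm)  # keyword form
# trainer.fit(model, dm)           # equivalent: a datamodule in the second argument is detected

# supplying a datamodule together with explicit dataloaders is rejected
try:
    trainer.fit(model, train_dataloader=dm.train_dataloader(), datamodule=dm)
except MisconfigurationException as err:
    print(err)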
@@ -883,7 +893,7 @@ def validate(

verbose: If True, prints the validation results.

datamodule: A instance of :class:`LightningDataModule`.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

Returns:
The dictionary with final validation results returned by validation_epoch_end.
@@ -908,10 +918,8 @@
model_provided = model is not None
model = model or self.lightning_module

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)
# links data to the trainer
self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule)
Comment on lines +921 to +922

Contributor:

Side question: In trainer.fit we allow trainer.fit(datamodule) but not for test and validate. Do you know why? Did we decide explicit naming by keyword is better?

Contributor Author (carmocca):

Good eye. I actually plan to open a PR adding this.

Not doing it yet as it would give myself conflicts

if not model_provided:
self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)
@@ -948,7 +956,7 @@ def test(

verbose: If True, prints the test results.

datamodule: A instance of :class:`LightningDataModule`.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

Returns:
Returns a list of dictionaries, one for each test dataloader containing their respective metrics.
@@ -969,10 +977,8 @@
model_provided = model is not None
model = model or self.lightning_module

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
# links data to the trainer
self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule)

if not model_provided:
self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)
@@ -1063,10 +1069,8 @@
if dataloaders is not None and datamodule:
raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`')

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
# links data to the trainer
self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

results = self._run(model)

@@ -1081,7 +1085,9 @@ def tune(
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> None:
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
r"""
Runs routines to tune hyperparameters before training.

@@ -1094,17 +1100,38 @@
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

datamodule: A instance of :class:`LightningDataModule`.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
"""
Trainer._log_api_event("tune")
self.state = TrainerState.TUNING
self.tuning = True

self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule)
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`'
)

# links data to the trainer
self.data_connector.attach_data(
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
)
SkafteNicki marked this conversation as resolved.

result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs)

assert self.state.stopped
self.tuning = False

return result
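
A hedged sketch of how the extended `tune` signature might be called. The result-dict keys (`'scale_batch_size'`, `'lr_find'`) are assumed from the tuner implementation rather than spelled out in this hunk, and the model is a placeholder expected to expose `lr` and `batch_size` attributes (for example a combination of the sketches shown earlier).

import pytorch_lightning as pl

# `model` is a placeholder LightningModule with `lr` and `batch_size` attributes
# and a train_dataloader() built from `batch_size`
trainer = pl.Trainer(auto_scale_batch_size=True, auto_lr_find=True, max_epochs=1)

result = trainer.tune(
    model,
    scale_batch_size_kwargs={"mode": "binsearch", "max_trials": 5},
    lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1e-1, "num_training": 50},
)

# assumed keys: one entry per tuning algorithm that actually ran
new_batch_size = result.get("scale_batch_size")  # int, or None if the scaler was skipped
lr_finder = result.get("lr_find")                # _LRFinder instance, or None
if lr_finder is not None:
    model.lr = lr_finder.suggestion()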

def call_setup_hook(self, model: LightningModule) -> None:
assert self.state.running, f"TrainerState: {self.state}"
state = self._setup_state
16 changes: 1 addition & 15 deletions pytorch_lightning/trainer/training_loop.py
@@ -24,7 +24,7 @@
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
@@ -91,20 +91,6 @@ def on_train_start(self):
# hook
self.trainer.call_hook("on_train_start")

def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)

# links data to the trainer
self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

# check that model is configured correctly
self.trainer.config_validator.verify_loop_configurations(model)

# attach model log function to callback
self.trainer.callback_connector.attach_model_logging_functions(model)

def on_train_end(self):
if self._teardown_already_run:
return