Attach data refactor and tuner bugs [4/n] #7258
Changes from all commits
395ce2a
65aa690
b3f3314
e7c3657
8060424
c9a9ca9
67d5ca8
d7fbc5d
de7937e
449041c
405082b
c255388
abc024f
d2a54d6
8641504
d8b0bf5
1eeedac
5c62b3a
ae89837
d30467d
08e7a54
236b9f4
35e6258
e2d90c3
869c857
668e68e
617e9e5
d60eb20
4502329
f2ccf21
9cbc154
46a4e92
175e2cd
23f2da2
f2ab21c
841381c
ae9d7e0
70a16ea
@@ -56,8 +56,9 @@
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -409,21 +410,15 @@ def __init__(
        # Callback system
        self.on_init_end()

    def _run(
        self,
        model: LightningModule,
        train_dataloader: Any = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(model)
    def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

        # ----------------------------
        # LINK DATA
        # ----------------------------
        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
        self.config_validator.verify_loop_configurations(model)

        # attach model log function to callback
        self.callback_connector.attach_model_logging_functions(model)

        # hook
        self.data_connector.prepare_data(model)
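For reference, `parsing.clean_namespace` (now called inside `_run`) strips entries from the model's `hparams` that cannot be pickled. A minimal sketch of that behavior, assuming the utility works as its name suggests; the exact warning text and the example output are assumptions, not taken from this PR:

```python
from argparse import Namespace

from pytorch_lightning.utilities import parsing

# an hparams namespace with one picklable and one unpicklable entry
hparams = Namespace(lr=0.01, debug_fn=lambda x: x)

# clean_namespace mutates the namespace in place, dropping unpicklable entries
parsing.clean_namespace(hparams)
print(vars(hparams))  # expected: {'lr': 0.01} (the lambda is removed with a warning)
```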
@@ -848,14 +843,29 @@ def fit(
            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

            datamodule: A instance of :class:`LightningDataModule`.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
        """
        Trainer._log_api_event("fit")

        self.state = TrainerState.FITTING
        self.training = True

        self._run(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)
        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloader, LightningDataModule):
            datamodule = train_dataloader
            train_dataloader = None
        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`'
            )

        # links data to the trainer
        self.data_connector.attach_data(
            model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        self._run(model)

        assert self.state.stopped
        self.training = False
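From the user's side, the reworked `fit` accepts a `LightningDataModule` either through the `datamodule` keyword or positionally in the dataloader slot, and rejects mixing a datamodule with explicit dataloaders. A minimal sketch of the resulting call patterns; `RandomDataset`, `TinyModel`, `TinyDataModule`, and the `Trainer` flags are illustrative placeholders, not code from this PR:

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class TinyDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)


# keyword form
pl.Trainer(max_epochs=1, logger=False).fit(TinyModel(), datamodule=TinyDataModule())

# positional form: the datamodule arrives in the `train_dataloader` slot and is redirected
pl.Trainer(max_epochs=1, logger=False).fit(TinyModel(), TinyDataModule())

# mixing both now raises MisconfigurationException:
# pl.Trainer().fit(TinyModel(), train_dataloader=DataLoader(RandomDataset()), datamodule=TinyDataModule())
```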
@@ -883,7 +893,7 @@ def validate(

            verbose: If True, prints the validation results.

            datamodule: A instance of :class:`LightningDataModule`.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        Returns:
            The dictionary with final validation results returned by validation_epoch_end.

@@ -908,10 +918,8 @@ def validate(
        model_provided = model is not None
        model = model or self.lightning_module

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model, datamodule)
        # Attach dataloaders (if given)
        self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)
        # links data to the trainer
        self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule)
Comment on lines +921 to +922:

Reviewer: Side question: in trainer.fit we allow trainer.fit(datamodule) but not for test and validate. Do you know why? Did we decide explicit naming by keyword is better?

Author: Good eye. I actually plan to open a PR adding this; not doing it yet, as it would give me conflicts.
        if not model_provided:
            self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)
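To make the asymmetry raised in the review thread concrete: after this PR, `fit` and `tune` redirect a datamodule passed in the first dataloader position, while `validate`, `test`, and `predict` still expect it through the `datamodule` keyword. A self-contained sketch of just that argument handling; `FakeDataModule` and both helper functions are illustrative stand-ins, not Lightning code:

```python
class FakeDataModule:
    """Stand-in for LightningDataModule; only the isinstance check matters here."""


def fit_style_args(train_dataloader=None, val_dataloaders=None, datamodule=None):
    # mirrors the redirection `fit`/`tune` perform in this diff
    if isinstance(train_dataloader, FakeDataModule):
        datamodule = train_dataloader
        train_dataloader = None
    return train_dataloader, val_dataloaders, datamodule


def validate_style_args(val_dataloaders=None, datamodule=None):
    # `validate`/`test` do no redirection: a positional datamodule would land
    # in the dataloaders argument unchanged
    return val_dataloaders, datamodule


dm = FakeDataModule()
print(fit_style_args(dm))        # (None, None, dm): redirected to `datamodule`
print(validate_style_args(dm))   # (dm, None): treated as dataloaders
```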
@@ -948,7 +956,7 @@ def test(

            verbose: If True, prints the test results.

            datamodule: A instance of :class:`LightningDataModule`.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        Returns:
            Returns a list of dictionaries, one for each test dataloader containing their respective metrics.

@@ -969,10 +977,8 @@ def test(
        model_provided = model is not None
        model = model or self.lightning_module

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model, datamodule)
        # Attach dataloaders (if given)
        self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
        # links data to the trainer
        self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule)

        if not model_provided:
            self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)
@@ -1063,10 +1069,8 @@ def predict(
        if dataloaders is not None and datamodule:
            raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`')

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model, datamodule)
        # Attach dataloaders (if given)
        self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
        # links data to the trainer
        self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

        results = self._run(model)
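The `DataConnector.attach_data` method itself is not part of this excerpt. Judging from the call sites it replaces, it presumably folds the previous `attach_datamodule` and `attach_dataloaders` calls into a single entry point. A hedged sketch under that assumption; the argument names mirror the calls above, and the `train_dataloader` keyword on `attach_dataloaders` is assumed by analogy with the others rather than shown in this diff:

```python
from typing import List, Optional, Union

from torch.utils.data import DataLoader

from pytorch_lightning import LightningDataModule, LightningModule


def attach_data_sketch(
    data_connector,  # the trainer's DataConnector instance
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    datamodule: Optional[LightningDataModule] = None,
) -> None:
    # attach the datamodule first so its setup/prepare_data hooks get linked to the model
    data_connector.attach_datamodule(model, datamodule)
    # then attach whichever dataloaders were passed explicitly
    data_connector.attach_dataloaders(
        model,
        train_dataloader=train_dataloader,
        val_dataloaders=val_dataloaders,
        test_dataloaders=test_dataloaders,
        predict_dataloaders=predict_dataloaders,
    )
```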
@@ -1081,7 +1085,9 @@ def tune(
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ) -> None:
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
        r"""
        Runs routines to tune hyperparameters before training.

@@ -1094,17 +1100,38 @@ def tune(
            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

            datamodule: A instance of :class:`LightningDataModule`.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

            scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

            lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
        """
        Trainer._log_api_event("tune")
        self.state = TrainerState.TUNING
        self.tuning = True

        self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule)
        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloader, LightningDataModule):
            datamodule = train_dataloader
            train_dataloader = None
        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`'
            )

        # links data to the trainer
        self.data_connector.attach_data(
            model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs)

        assert self.state.stopped
        self.tuning = False

        return result

    def call_setup_hook(self, model: LightningModule) -> None:
        assert self.state.running, f"TrainerState: {self.state}"
        state = self._setup_state
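With the new signature, `tune` returns a dictionary instead of `None`. A usage sketch; the result keys (`"scale_batch_size"`, `"lr_find"`), the `auto_*` Trainer flags, and the attributes the tuners look for on the model are assumptions based on the tuner's documented behavior, not on this diff:

```python
import pytorch_lightning as pl

# `TinyModel` / `TinyDataModule` are the placeholder classes from the earlier sketch.
# Batch-size scaling and the LR finder tune attributes on the model, so the sketch
# assumes `batch_size` and `learning_rate` exist there.
model = TinyModel()
model.batch_size = 2
model.learning_rate = 0.1
dm = TinyDataModule()

trainer = pl.Trainer(auto_scale_batch_size=True, auto_lr_find=True, logger=False)
result = trainer.tune(
    model,
    datamodule=dm,
    scale_batch_size_kwargs={"init_val": 2, "max_trials": 3},
    lr_find_kwargs={"min_lr": 1e-5, "max_lr": 1e-1},
)

new_batch_size = result["scale_batch_size"]  # int, or None if scaling was not run
lr_finder = result["lr_find"]                # _LRFinder object, or None
if lr_finder is not None:
    print(lr_finder.suggestion())
```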