3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
 
 
+- Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895))
+
+
 ### Changed
 
 - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
22 changes: 21 additions & 1 deletion docs/source/common/lightning_cli.rst
@@ -19,9 +19,13 @@
         ):
             pass
 
+    class MyClassModel(LightningModule):
+        def __init__(self, num_classes: int):
+            pass
+
     class MyDataModule(LightningDataModule):
         def __init__(self, batch_size: int = 8):
-            pass
+            self.num_classes = 5
 
     def send_email(address, message):
         pass
@@ -402,6 +406,22 @@ The linking of arguments is observed in the help of the tool, which for this example
         model.batch_size <-- data.batch_size
                               Number of samples in a batch (type: int)
 
+Sometimes a parameter value is only available after class instantiation. An example is a model that
+requires the number of classes to instantiate its fully connected layer (for a classification task),
+but the value is only available once the data module has been instantiated.
+The code below illustrates how to address this.
+
+.. testcode::
+
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments('data.num_classes', 'model.num_classes', apply_on='instantiate')
+
+    cli = MyLightningCLI(MyClassModel, MyDataModule)
+
+Instantiation links are used to automatically determine the order of instantiation, in this case the
+data module first.
 
 .. tip::
 
     The linking of arguments can be used for more complex cases. For example to derive a value via a function that takes
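The practical effect of the instantiation link above: the model's num_classes never has to be given
on the command line. A hypothetical usage sketch reusing the MyClassModel and MyDataModule names
from the docs example (mirroring the style of the tests added further down, not code from this PR):

from unittest import mock

# No --model.num_classes is passed: MyDataModule.__init__ sets
# self.num_classes = 5, and the instantiation link feeds that value to
# MyClassModel.__init__(num_classes=5) when the model is constructed.
with mock.patch('sys.argv', ['trainer.py', '--data.batch_size=8']):
    cli = MyLightningCLI(MyClassModel, MyDataModule)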
29 changes: 9 additions & 20 deletions pytorch_lightning/utilities/cli.py
@@ -66,7 +66,12 @@ def add_lightning_class_args(
         assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule))
         if subclass_mode:
             return self.add_subclass_arguments(lightning_class, nested_key, required=True)
-        return self.add_class_arguments(lightning_class, nested_key, fail_untyped=False)
+        return self.add_class_arguments(
+            lightning_class,
+            nested_key,
+            fail_untyped=False,
+            instantiate=not issubclass(lightning_class, Trainer),
+        )
 
 
 class SaveConfigCallback(Callback):
@@ -212,27 +217,11 @@ def before_instantiate_classes(self) -> None:

     def instantiate_classes(self) -> None:
         """Instantiates the classes using settings from self.config"""
-        self.config_init = self.parser.instantiate_subclasses(self.config)
-        self.instantiate_datamodule()
-        self.instantiate_model()
+        self.config_init = self.parser.instantiate_classes(self.config)
+        self.datamodule = self.config_init.get('data')
+        self.model = self.config_init['model']
         self.instantiate_trainer()
 
-    def instantiate_datamodule(self) -> None:
-        """Instantiates the datamodule using self.config_init['data'] if given"""
-        if self.datamodule_class is None:
-            self.datamodule = None
-        elif self.subclass_mode_data:
-            self.datamodule = self.config_init['data']
-        else:
-            self.datamodule = self.datamodule_class(**self.config_init.get('data', {}))
-
-    def instantiate_model(self) -> None:
-        """Instantiates the model using self.config_init['model']"""
-        if self.subclass_mode_model:
-            self.model = self.config_init['model']
-        else:
-            self.model = self.model_class(**self.config_init.get('model', {}))
-
     def instantiate_trainer(self) -> None:
         """Instantiates the trainer using self.config_init['trainer']"""
         if self.config_init['trainer'].get('callbacks') is None:
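For context on the two jsonargparse features the new code relies on: `add_class_arguments(..., instantiate=False)`
keeps a class's settings as plain config when `instantiate_classes` runs (the Trainer is excluded so that
`instantiate_trainer` can still build it by hand and inject callbacks), while `apply_on='instantiate'` links
are resolved between those constructions, which also determines the construction order. A minimal standalone
sketch, assuming jsonargparse >= 3.14.0 as pinned below; the Data and Model classes are illustrative
stand-ins, not part of this PR:

from jsonargparse import ArgumentParser


class Data:
    def __init__(self, batch_size: int = 8):
        self.batch_size = batch_size
        self.num_classes = 5  # only known once the instance exists


class Model:
    def __init__(self, num_classes: int):
        self.num_classes = num_classes


parser = ArgumentParser()
parser.add_class_arguments(Data, 'data')    # instantiate=True is the default
parser.add_class_arguments(Model, 'model')
# The source is read from the Data *instance*, so the parser knows it must
# construct 'data' before 'model'.
parser.link_arguments('data.num_classes', 'model.num_classes', apply_on='instantiate')

cfg = parser.parse_args(['--data.batch_size=12'])  # no --model.num_classes required
init = parser.instantiate_classes(cfg)
assert init['data'].batch_size == 12
assert init['model'].num_classes == 5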
2 changes: 1 addition & 1 deletion requirements/extra.txt
@@ -7,4 +7,4 @@ torchtext>=0.5
 # onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
-jsonargparse[signatures]>=3.13.1
+jsonargparse[signatures]>=3.14.0
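The minimum jsonargparse version is raised presumably because `link_arguments(..., apply_on='instantiate')`,
`ArgumentParser.instantiate_classes` and the `instantiate` flag of `add_class_arguments` used above first
became available around 3.14.0.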
70 changes: 65 additions & 5 deletions tests/utilities/test_cli.py
Expand Up @@ -446,11 +446,9 @@ def __init__(
     with mock.patch('sys.argv', ['any.py'] + cli_args):
         cli = LightningCLI(MainModule)
 
-    assert cli.config_init['model']['main_param'] == 2
-    assert cli.model.submodule1 == cli.config_init['model']['submodule1']
-    assert cli.model.submodule2 == cli.config_init['model']['submodule2']
-    assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
-    assert isinstance(cli.config_init['model']['submodule2'], BoringModel)
+    assert cli.config['model']['main_param'] == 2
+    assert isinstance(cli.model.submodule1, BoringModel)
+    assert isinstance(cli.model.submodule2, BoringModel)
 
 
 @pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required')
@@ -497,3 +495,65 @@ def __init__(
     assert cli.model.activation.negative_slope == 0.2
     assert len(cli.model.transform) == 2
     assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
+
+
+class BoringModelRequiredClasses(BoringModel):
+
+    def __init__(
+        self,
+        num_classes: int,
+        batch_size: int = 8,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.batch_size = batch_size
+
+
+class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
+
+    def __init__(
+        self,
+        batch_size: int = 8,
+    ):
+        super().__init__()
+        self.batch_size = batch_size
+        self.num_classes = 5  # only available after instantiation
+
+
+def test_lightning_cli_link_arguments(tmpdir):
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments('data.batch_size', 'model.batch_size')
+            parser.link_arguments('data.num_classes', 'model.num_classes', apply_on='instantiate')
+
+    cli_args = [
+        f'--trainer.default_root_dir={tmpdir}',
+        '--trainer.max_epochs=1',
+        '--data.batch_size=12',
+    ]
+
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses)
+
+    assert cli.model.batch_size == 12
+    assert cli.model.num_classes == 5
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments('data.batch_size', 'model.init_args.batch_size')
+            parser.link_arguments('data.num_classes', 'model.init_args.num_classes', apply_on='instantiate')
+
+    cli_args[-1] = '--model=tests.utilities.test_cli.BoringModelRequiredClasses'
+
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        cli = MyLightningCLI(
+            BoringModelRequiredClasses,
+            BoringDataModuleBatchSizeAndClasses,
+            subclass_mode_model=True,
+        )
+
+    assert cli.model.batch_size == 8
+    assert cli.model.num_classes == 5
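A note on the second scenario's expected values: reassigning cli_args[-1] swaps out
'--data.batch_size=12', so the parse-time link now carries the data module's default batch_size of 8,
while num_classes is still copied from the instance at instantiation time. In subclass mode the model's
settings are nested under class_path/init_args, which is why the link targets change to
model.init_args.*. A rough illustration of the resulting model section of the config (hypothetical
shape for readability, not the exact Namespace repr produced by the test):

# Illustrative shape of the model config in subclass mode.
expected_model_config = {
    'class_path': 'tests.utilities.test_cli.BoringModelRequiredClasses',
    'init_args': {
        'batch_size': 8,  # linked from data.batch_size at parse time (the default,
                          # since '--data.batch_size=12' was replaced above)
        # num_classes is absent here: it is injected from the instantiated
        # data module only when the model is constructed
    },
}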