Skip to content

Commit cdd01f3

Browse files
mauvilsacarmocca
andauthored
LightningCLI support for argument links applied on instantiation (#7895)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
1 parent 6856cce commit cdd01f3

File tree

5 files changed

+99
-27
lines changed

5 files changed

+99
-27
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7777
- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
7878

7979

80+
- Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895))
81+
82+
8083
### Changed
8184

8285
- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)

docs/source/common/lightning_cli.rst

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919
):
2020
pass
2121

22+
class MyClassModel(LightningModule):
23+
def __init__(self, num_classes: int):
24+
pass
25+
2226
class MyDataModule(LightningDataModule):
2327
def __init__(self, batch_size: int = 8):
24-
pass
28+
self.num_classes = 5
2529

2630
def send_email(address, message):
2731
pass
@@ -402,6 +406,22 @@ The linking of arguments is observed in the help of the tool, which for this exa
402406
model.batch_size <-- data.batch_size
403407
Number of samples in a batch (type: int)
404408
409+
Sometimes a parameter value is only available after class instantiation. An example could be that your model requires the number of classes to instantiate its fully connected layer (for a classification task) but the value is not available until the data module has been instantiated.
410+
The code below illustrates how to address this.
411+
412+
.. testcode::
413+
414+
from pytorch_lightning.utilities.cli import LightningCLI
415+
416+
class MyLightningCLI(LightningCLI):
417+
418+
def add_arguments_to_parser(self, parser):
419+
parser.link_arguments('data.num_classes', 'model.num_classes', apply_on='instantiate')
420+
421+
cli = MyLightningCLI(MyClassModel, MyDataModule)
422+
423+
Instantiation links are used to automatically determine the order of instantiation, in this case data first.
424+
405425
.. tip::
406426

407427
The linking of arguments can be used for more complex cases. For example to derive a value via a function that takes

pytorch_lightning/utilities/cli.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ def add_lightning_class_args(
6666
assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule))
6767
if subclass_mode:
6868
return self.add_subclass_arguments(lightning_class, nested_key, required=True)
69-
return self.add_class_arguments(lightning_class, nested_key, fail_untyped=False)
69+
return self.add_class_arguments(
70+
lightning_class,
71+
nested_key,
72+
fail_untyped=False,
73+
instantiate=not issubclass(lightning_class, Trainer),
74+
)
7075

7176

7277
class SaveConfigCallback(Callback):
@@ -212,27 +217,11 @@ def before_instantiate_classes(self) -> None:
212217

213218
def instantiate_classes(self) -> None:
214219
"""Instantiates the classes using settings from self.config"""
215-
self.config_init = self.parser.instantiate_subclasses(self.config)
216-
self.instantiate_datamodule()
217-
self.instantiate_model()
220+
self.config_init = self.parser.instantiate_classes(self.config)
221+
self.datamodule = self.config_init.get('data')
222+
self.model = self.config_init['model']
218223
self.instantiate_trainer()
219224

220-
def instantiate_datamodule(self) -> None:
221-
"""Instantiates the datamodule using self.config_init['data'] if given"""
222-
if self.datamodule_class is None:
223-
self.datamodule = None
224-
elif self.subclass_mode_data:
225-
self.datamodule = self.config_init['data']
226-
else:
227-
self.datamodule = self.datamodule_class(**self.config_init.get('data', {}))
228-
229-
def instantiate_model(self) -> None:
230-
"""Instantiates the model using self.config_init['model']"""
231-
if self.subclass_mode_model:
232-
self.model = self.config_init['model']
233-
else:
234-
self.model = self.model_class(**self.config_init.get('model', {}))
235-
236225
def instantiate_trainer(self) -> None:
237226
"""Instantiates the trainer using self.config_init['trainer']"""
238227
if self.config_init['trainer'].get('callbacks') is None:

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ torchtext>=0.5
77
# onnx>=1.7.0
88
onnxruntime>=1.3.0
99
hydra-core>=1.0
10-
jsonargparse[signatures]>=3.13.1
10+
jsonargparse[signatures]>=3.14.0

tests/utilities/test_cli.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -446,11 +446,9 @@ def __init__(
446446
with mock.patch('sys.argv', ['any.py'] + cli_args):
447447
cli = LightningCLI(MainModule)
448448

449-
assert cli.config_init['model']['main_param'] == 2
450-
assert cli.model.submodule1 == cli.config_init['model']['submodule1']
451-
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
452-
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
453-
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)
449+
assert cli.config['model']['main_param'] == 2
450+
assert isinstance(cli.model.submodule1, BoringModel)
451+
assert isinstance(cli.model.submodule2, BoringModel)
454452

455453

456454
@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required')
@@ -497,3 +495,65 @@ def __init__(
497495
assert cli.model.activation.negative_slope == 0.2
498496
assert len(cli.model.transform) == 2
499497
assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
498+
499+
500+
class BoringModelRequiredClasses(BoringModel):
501+
502+
def __init__(
503+
self,
504+
num_classes: int,
505+
batch_size: int = 8,
506+
):
507+
super().__init__()
508+
self.num_classes = num_classes
509+
self.batch_size = batch_size
510+
511+
512+
class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
513+
514+
def __init__(
515+
self,
516+
batch_size: int = 8,
517+
):
518+
super().__init__()
519+
self.batch_size = batch_size
520+
self.num_classes = 5 # only available after instantiation
521+
522+
523+
def test_lightning_cli_link_arguments(tmpdir):
524+
525+
class MyLightningCLI(LightningCLI):
526+
527+
def add_arguments_to_parser(self, parser):
528+
parser.link_arguments('data.batch_size', 'model.batch_size')
529+
parser.link_arguments('data.num_classes', 'model.num_classes', apply_on='instantiate')
530+
531+
cli_args = [
532+
f'--trainer.default_root_dir={tmpdir}',
533+
'--trainer.max_epochs=1',
534+
'--data.batch_size=12',
535+
]
536+
537+
with mock.patch('sys.argv', ['any.py'] + cli_args):
538+
cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses)
539+
540+
assert cli.model.batch_size == 12
541+
assert cli.model.num_classes == 5
542+
543+
class MyLightningCLI(LightningCLI):
544+
545+
def add_arguments_to_parser(self, parser):
546+
parser.link_arguments('data.batch_size', 'model.init_args.batch_size')
547+
parser.link_arguments('data.num_classes', 'model.init_args.num_classes', apply_on='instantiate')
548+
549+
cli_args[-1] = '--model=tests.utilities.test_cli.BoringModelRequiredClasses'
550+
551+
with mock.patch('sys.argv', ['any.py'] + cli_args):
552+
cli = MyLightningCLI(
553+
BoringModelRequiredClasses,
554+
BoringDataModuleBatchSizeAndClasses,
555+
subclass_mode_model=True,
556+
)
557+
558+
assert cli.model.batch_size == 8
559+
assert cli.model.num_classes == 5

0 commit comments

Comments
 (0)