Extend the LightningCLI to register models, datamodules, and callbacks. #7250

Closed
tchaton opened this issue Apr 28, 2021 · 6 comments · Fixed by #10011
Labels: argparse (removed), design, discussion, feature, help wanted, priority: 0

Comments

@tchaton
Contributor

tchaton commented Apr 28, 2021

🚀 Feature

Implement a CLI store/provider/registry for available LightningModules, LightningDataModules, and Callbacks.

Implementation

class LightningCLI:
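    # `Registry` is not defined in this proposal; it stands in for a Flash-style registry like the one linked below
    store = Registry()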

    @classmethod
    def register(cls, *args, **kwargs):
        cls.store.register(*args, **kwargs)

The registry idea is heavily influenced by this implementation:

https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/registry.py
https://github.com/PyTorchLightning/lightning-flash/blob/master/tests/core/test_registry.py
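The snippet above leaves `Registry` undefined. Here is a minimal sketch of what it could look like, loosely following the idea of the Flash registry linked above. This is not Flash's actual code: the `Registry` class, its `get` and `available_keys` methods, the `override` flag, and the stand-in `LightningCLI` are all assumptions for illustration.

```python
from typing import Any, Dict, List


class Registry:
    """Hypothetical name -> object store, loosely modeled on the Flash registry idea."""

    def __init__(self) -> None:
        self._items: Dict[str, Any] = {}

    def register(self, name: str, obj: Any, override: bool = False) -> Any:
        # refuse to silently replace an existing entry unless asked to
        if name in self._items and not override:
            raise KeyError(f"{name!r} is already registered")
        self._items[name] = obj
        return obj

    def get(self, name: str) -> Any:
        return self._items[name]

    def available_keys(self) -> List[str]:
        return sorted(self._items)


class LightningCLI:
    """Stand-in for the real LightningCLI; covers only the registry part of the proposal."""

    # one store shared by every LightningCLI, as in the proposal
    store = Registry()

    @classmethod
    def register(cls, *args: Any, **kwargs: Any) -> Any:
        return cls.store.register(*args, **kwargs)
```

With this, `LightningCLI.register('my-callback', CallbackA)` stores `CallbackA` under `'my-callback'`, and `available_keys()` is what a help command could list. Because `store` is a class attribute, registrations made in different files all land in the same registry.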

Usage

# note: these `register` calls could be in different files
LightningCLI.register('my-callback', CallbackA)

LightningCLI.register('vision-model', VisionModel)

LightningCLI.register('mnist-data', MnistDataModule)
LightningCLI.register('cifar10-data', CIFAR10DataModule)
LightningCLI.register('imagenet-data', ImageNetDataModule)

# API might change - naming the model-data pair as an experiment
LightningCLI.register_experiment(VisionModel, MnistDataModule)
LightningCLI.register_experiment(VisionModel, CIFAR10DataModule)
LightningCLI.register_experiment(VisionModel, ImageNetDataModule)

# or with a decorator.
# `name` could be taken from the class or function
# `arg` could be inferred from the parent class
@LightningCLI.register(name='some_model', arg='model')
class SomeModel(LightningModule):
    ...

# in train.py
cli = LightningCLI()
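The decorator form (name taken from the class, `arg` inferred from the parent class) and `register_experiment` could share one store. Below is a minimal sketch under these assumptions: stand-in base classes replace the real Lightning ones, and `STORE`, `EXPERIMENTS`, `CATEGORIES`, `register`, and `register_experiment` are hypothetical names, not existing Lightning APIs. The optional `name` on `register_experiment` is an addition, since the help output below shows named experiments such as `'vision-mnist'`.

```python
from typing import Callable, Dict, Optional


class LightningModule: ...  # stand-in for pytorch_lightning.LightningModule
class LightningDataModule: ...  # stand-in for pytorch_lightning.LightningDataModule
class Callback: ...  # stand-in for pytorch_lightning.Callback


# category -> base class used to infer the category (the `arg` idea above)
CATEGORIES: Dict[str, type] = {
    "model": LightningModule,
    "data": LightningDataModule,
    "callback": Callback,
}
STORE: Dict[str, Dict[str, type]] = {category: {} for category in CATEGORIES}
EXPERIMENTS: Dict[str, Dict[str, type]] = {}


def register(name: Optional[str] = None, arg: Optional[str] = None) -> Callable[[type], type]:
    """Hypothetical decorator: `name` defaults to the class name, `arg` to its base class."""

    def decorator(cls: type) -> type:
        category = arg or next(
            (key for key, base in CATEGORIES.items() if issubclass(cls, base)), None
        )
        if category is None:
            raise TypeError(f"cannot infer a category for {cls.__name__}")
        STORE[category][name or cls.__name__] = cls
        return cls  # return the class unchanged so it stays usable

    return decorator


def register_experiment(model: type, data: type, name: Optional[str] = None) -> None:
    """Hypothetical: store a model/datamodule pair under an explicit or derived name."""
    EXPERIMENTS[name or f"{model.__name__}-{data.__name__}"] = {"model": model, "data": data}


@register(name="some_model")
class SomeModel(LightningModule):
    ...


@register()  # name inferred as "MnistDataModule", category inferred as "data"
class MnistDataModule(LightningDataModule):
    ...


register_experiment(SomeModel, MnistDataModule, name="some-mnist")
```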

Console interaction

Basic help

python train.py -h 
callbacks: {
    'my-callback': CallbackA,
}
models: {
    'vision-model': VisionModel,
}
datamodules: {
    'mnist-data': MnistDataModule,
    'cifar10-data': CIFAR10DataModule,
    'imagenet-data': ImageNetDataModule,
}
experiments: {
    'vision-mnist': {'model': VisionModel, 'data': MnistDataModule},
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
    'vision-imagenet': {'model': VisionModel, 'data': ImageNetDataModule},  
}

Note that the real help output should look nicer; raw dicts are shown here for simplicity.
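As one possible "nicer" rendering, here is a sketch that prints each category as an aligned name/class listing instead of a raw dict. The `format_help` helper and the per-category store layout are assumptions, and the classes are empty stand-ins for the registered ones.

```python
from typing import Dict


class CallbackA: ...  # stand-ins for the registered classes
class VisionModel: ...
class MnistDataModule: ...
class CIFAR10DataModule: ...


def format_help(stores: Dict[str, Dict[str, type]]) -> str:
    """Render each category as an aligned 'registered-name  ClassName' listing."""
    lines = []
    for category, entries in stores.items():
        lines.append(f"{category}:")
        # pad names within a category so the class names line up
        width = max((len(name) for name in entries), default=0)
        for name, cls in entries.items():
            lines.append(f"  {name.ljust(width)}  {cls.__name__}")
    return "\n".join(lines)


stores = {
    "callbacks": {"my-callback": CallbackA},
    "models": {"vision-model": VisionModel},
    "datamodules": {"mnist-data": MnistDataModule, "cifar10-data": CIFAR10DataModule},
}
print(format_help(stores))
```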

Filter by category

python train.py -h model=VisionModel
# alternatively
python train.py -h model='vision-model'
callbacks: {
    'my-callback': CallbackA,
}
datamodules: {
    'mnist-data': MnistDataModule,
    'cifar10-data': CIFAR10DataModule,
    'imagenet-data': ImageNetDataModule,
}
experiments: {
    'vision-mnist': {'model': VisionModel, 'data': MnistDataModule},
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
    'vision-imagenet': {'model': VisionModel, 'data': ImageNetDataModule},  
}
python train.py -h data=CIFAR10DataModule
callbacks: {
    'my-callback': CallbackA,
}
models: {
    'vision-model': VisionModel,
}
datamodules: {
    'cifar10-data': CIFAR10DataModule,
}
experiments: {
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
}
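The filtering above can be derived from the registered experiments: keep the experiments that use the selected class, then keep only the models and datamodules that appear in them. Below is a minimal sketch, assuming module-level dicts as the stores and accepting either the registered name or the class name. `resolve`, `filter_by`, and the dict names are hypothetical. Callbacks are left out because they are not filtered in the examples. For simplicity, the selected category is kept but reduced to the selection (as in the `data=` example), whereas the `model=` example hides it.

```python
from typing import Dict


class VisionModel: ...  # stand-ins for the registered classes
class MnistDataModule: ...
class CIFAR10DataModule: ...


MODELS = {"vision-model": VisionModel}
DATAMODULES = {"mnist-data": MnistDataModule, "cifar10-data": CIFAR10DataModule}
EXPERIMENTS = {
    "vision-mnist": {"model": VisionModel, "data": MnistDataModule},
    "vision-cifar10": {"model": VisionModel, "data": CIFAR10DataModule},
}


def resolve(value: str, registry: Dict[str, type]) -> type:
    """Accept either a registered name ('vision-model') or a class name ('VisionModel')."""
    if value in registry:
        return registry[value]
    for cls in registry.values():
        if cls.__name__ == value:
            return cls
    raise KeyError(value)


def filter_by(category: str, value: str) -> Dict[str, Dict[str, object]]:
    """Keep experiments using the selected class, plus the models/datamodules they pair."""
    registry = MODELS if category == "model" else DATAMODULES
    selected = resolve(value, registry)
    experiments = {n: e for n, e in EXPERIMENTS.items() if e[category] is selected}
    used = {cls for e in experiments.values() for cls in e.values()}
    return {
        "models": {n: c for n, c in MODELS.items() if c in used},
        "datamodules": {n: c for n, c in DATAMODULES.items() if c in used},
        "experiments": experiments,
    }
```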

Basic LightningCLI usage (similar to what's currently implemented)

python train.py -h model=VisionModel data=ImageNetDataModule
# same as
python train.py -h experiment='vision-imagenet'
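The equivalence between `experiment=...` and an explicit `model=... data=...` pair amounts to expanding the experiment into its two entries before anything else runs. Here is a minimal sketch of that expansion, assuming `key=value` arguments and a dict of experiments; `parse_selection` and `EXPERIMENTS` are hypothetical. The quotes in `experiment='vision-imagenet'` are removed by the shell before Python sees them.

```python
from typing import Dict, List


class VisionModel: ...  # stand-ins for the registered classes
class ImageNetDataModule: ...


EXPERIMENTS = {"vision-imagenet": {"model": VisionModel, "data": ImageNetDataModule}}


def parse_selection(argv: List[str]) -> Dict[str, str]:
    """Collect key=value arguments, expanding experiment=... into model=/data= class names."""
    selection = dict(arg.split("=", 1) for arg in argv if "=" in arg)
    experiment = selection.pop("experiment", None)
    if experiment is not None:
        # explicit model=/data= arguments take precedence over the experiment's pair
        for key, cls in EXPERIMENTS[experiment].items():
            selection.setdefault(key, cls.__name__)
    return selection
```

Letting explicit `model=`/`data=` arguments win over the experiment is a design choice made here. Raising an error on a conflict would be an equally reasonable option.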
tchaton added the feature and help wanted labels on Apr 28, 2021
tchaton added this to the v1.4 milestone on Apr 28, 2021
@mauvilsa
Contributor

@tchaton Thank you for the proposal. I haven't thought about it yet, but I will, and I'll comment later.

@carmocca
Contributor

carmocca commented May 3, 2021

@tchaton @mauvilsa I modified the original post.

carmocca changed the title from "LightningCLIStore" to "Extend the LightningCLI to register models, datamodules, and callbacks." on May 3, 2021
carmocca added the argparse (removed) label on May 3, 2021
@mpariente
Contributor

This looks very nice; Asteroid would most likely use it!

carmocca added the design and discussion labels on May 7, 2021
edenlightning added the priority: 0 label on May 9, 2021
edgarriba self-assigned this on May 10, 2021
edenlightning changed the milestone from v1.4 to v1.5 on Jul 8, 2021
@Andry-Bal

This feature would also make specifying arguments on the command line more convenient. Currently, if you use subclasses for either the model or the data, every argument on the command line has to go through init_args, e.g. --data.init_args.batch_size 64, which is a bit annoying.
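The `init_args` nesting comes from how subclass arguments are represented: a `class_path` naming the class plus an `init_args` mapping of its constructor arguments. As a result, every override has to walk through `init_args`. Here is a minimal standard-library sketch of that representation. It is not jsonargparse's implementation: `instantiate` and `apply_override` are hypothetical helpers, and `datetime.timedelta` merely stands in for a datamodule class.

```python
import importlib
from typing import Any, Dict


def instantiate(config: Dict[str, Any]) -> Any:
    """Build an object from a {"class_path": ..., "init_args": {...}} mapping."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("init_args", {}))


def apply_override(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Apply a CLI-style override such as 'data.init_args.batch_size' in place."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value


config = {"data": {"class_path": "datetime.timedelta", "init_args": {"days": 1}}}
# the equivalent of passing --data.init_args.hours 2 on the command line
apply_override(config, "data.init_args.hours", 2)
data = instantiate(config["data"])
```

Dropping `init_args` from the command line would mean guessing which level a key like `data.hours` belongs to. That is a parsing question independent of registration, which fits the reply below.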

@mauvilsa
Copy link
Contributor

Being able to register models or data modules does not imply that the arguments can be overridden without using init_args. That is a completely different feature. If you like, you can create a feature request for it. It could go in jsonargparse, since it does not really relate to Lightning.

@tchaton
Copy link
Contributor Author

tchaton commented Jul 28, 2021

@register_model('image')
class ModelA(LightningModule):
    ...

@register_model('image')
class ModelB(LightningModule):
    ...

@register_datamodule('image')
class DataModuleA(LightningDataModule):
    ...

@register_datamodule('image')
class DataModuleB(LightningDataModule):
    ...

model = register.create_model()
datamodule = register.create_datamodule()
callbacks = register.create_callbacks()
trainer = register.create_trainer()
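Here is a minimal sketch of how group-scoped registration and the `create_*` calls could fit together. Unlike the comment above, it uses methods on a `register` object (`register.model(...)`, `register.datamodule(...)`) instead of free `register_model`/`register_datamodule` functions. It also passes the group, class name, and init args explicitly where a real CLI would take them from the parsed arguments. All names are hypothetical, the classes are plain stand-ins for Lightning subclasses, and callbacks and the trainer are omitted for brevity.

```python
from typing import Any, Callable, Dict


class Register:
    """Hypothetical grouped registry: group -> {class name -> class}."""

    def __init__(self) -> None:
        self.models: Dict[str, Dict[str, type]] = {}
        self.datamodules: Dict[str, Dict[str, type]] = {}

    def _decorator(self, table: Dict[str, Dict[str, type]], group: str) -> Callable[[type], type]:
        def wrap(cls: type) -> type:
            table.setdefault(group, {})[cls.__name__] = cls
            return cls
        return wrap

    def model(self, group: str) -> Callable[[type], type]:
        return self._decorator(self.models, group)

    def datamodule(self, group: str) -> Callable[[type], type]:
        return self._decorator(self.datamodules, group)

    def create_model(self, group: str, name: str, **init_args: Any) -> Any:
        # in a real CLI, group, name and init_args would come from the parsed arguments
        return self.models[group][name](**init_args)

    def create_datamodule(self, group: str, name: str, **init_args: Any) -> Any:
        return self.datamodules[group][name](**init_args)


register = Register()


@register.model("image")
class ModelA:  # stand-in for a LightningModule subclass
    def __init__(self, lr: float = 1e-3) -> None:
        self.lr = lr


@register.datamodule("image")
class DataModuleA:  # stand-in for a LightningDataModule subclass
    def __init__(self, batch_size: int = 32) -> None:
        self.batch_size = batch_size


model = register.create_model("image", "ModelA", lr=0.01)
datamodule = register.create_datamodule("image", "DataModuleA", batch_size=64)
```

Grouping by task (`'image'`) narrows the choices a user sees for a given problem, which matches the filtering idea from the original post.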

8 participants