
Easier way to configure optimizers and schedulers in the CLI #7576

Closed
jlperla opened this issue May 17, 2021 · 14 comments · Fixed by #8093 or #9565
Labels
argparse (removed): Related to argument parsing (argparse, Hydra, ...)
discussion: In a discussion stage
feature: Is an improvement or enhancement

Comments

@jlperla

jlperla commented May 17, 2021

🚀 Feature

Right now if you want to work with multiple optimizers and/or learning rate schedulers you need to write a whole bunch of ugly boilerplate. I suggest that the LightningCLI, in one way or another, enable that sort of optimizer configuration in a clean way, which would take care of most use cases.

Motivation

Right now I have the following code in my configure_optimizers hook:

    # Configuration.  Add more for learning schedulers, etc.?
    def configure_optimizers(self):
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.hparams.learning_rate,
                weight_decay=self.hparams.weight_decay,
            )
        elif self.hparams.optimizer == "SGD":
            # Left out the momentum options for now
            optimizer = torch.optim.SGD(
                self.parameters(),
                lr=self.hparams.learning_rate,
                weight_decay=self.hparams.weight_decay,
            )
        elif self.hparams.optimizer == "LBFGS":
            optimizer = torch.optim.LBFGS(
                self.parameters(),
                # or can have self.hparams.learning_rate with warning if too low.
                lr=1,
                tolerance_grad=1e-5,  # can add to parameters if useful.
                tolerance_change=1e-9,  # can add to parameters if useful.
            )
        else:
            print("Invalid optimizer.  See --help")
            sys.exit()

        if self.hparams.lr_scheduler is None:
            return optimizer

        # Setup the scheduler
        scheduler = None
        try:
            if self.hparams.lr_scheduler == "StepLR":
                step_size = self.hparams.StepLR_step_size
                gamma = self.hparams.StepLR_gamma
                scheduler = torch.optim.lr_scheduler.StepLR(
                    optimizer, step_size, gamma)
            elif self.hparams.lr_scheduler == "ReduceLROnPlateau":
                factor = self.hparams.lr_factor
                patience = self.hparams.lr_patience
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, mode="min", factor=factor, patience=patience
                )
                return {
                    "optimizer": optimizer,
                    "lr_scheduler": scheduler,
                    "monitor": self.hparams.LRScheduler_metric,
                }
            elif self.hparams.lr_scheduler == "OneCycleLR":
                max_lr = self.hparams.OneCycleLR_max_lr
                epochs = self.hparams.OneCycleLR_epochs
                steps_per_epoch = self.hparams.train_trajectories * (
                    self.hparams.T + 1
                )  # I think this is the same as len(train_data_loader)? #TODO
                scheduler = torch.optim.lr_scheduler.OneCycleLR(
                    optimizer,
                    max_lr=max_lr,
                    epochs=epochs,
                    steps_per_epoch=steps_per_epoch,
                )
        except:
            print("Invalid scheduler configuration.  See --help")
            raise
        return [optimizer], [scheduler]

And the commandline configuration looks like

        parser.add_argument("--learning_rate", type=float, default=5e-3)
        parser.add_argument("--LRScheduler_metric",
                            type=str, default="val_loss")
        parser.add_argument(
            "--lr_scheduler",
            default="ReduceLROnPlateau",
            help="Learning rate scheduler name, it can be StepLR, ReduceLROnPlateau, or, OneCycleLR. Defaults to None.",
        )
        parser.add_argument(
            "--optimizer",
            type=str,
            default="Adam",
            help="Choice optimizer: Adam, SGD, or LBFGS.",
        )
        parser.add_argument(
            "--StepLR_step_size",
            type=int,
            default=30,  # TODO there was no actual default in the source code
            help="Step size for StepLR scheduler. Defaults to 30.",
        )
        parser.add_argument(
            "--StepLR_gamma",
            type=float,
            default=0.1,
            help="Gamma for StepLR scheduler. Defaults to 0.1 ",
        )
        # plateau based learning rate
        parser.add_argument(
            "--lr_factor",
            type=float,
            default=0.1,
            help="Factor by which the learning rate will be reduced. new_lr = lr * factor.",
        )
        parser.add_argument(
            "--lr_patience",
            type=int,
            default=5,
            help="Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then",
        )
        parser.add_argument(
            "--OneCycleLR_max_lr",
            type=float,  # TODO no default on the source code
            default=0.1,
            help="Upper learning rate boundary in the cycle. Note that it cannot be a list",
        )
        parser.add_argument(
            "--OneCycleLR_epochs",
            type=int,
            default=10,
            help="The number of epochs to train for. This is used along with steps_per_epoch in order to infer the total number of steps in the cycle if a value for total_steps is not provided. Default: None",
        )

or maybe even more....

Pitch

I'm not exactly sure how you are planning to do the recursive configuration stuff, but I can imagine the CLI having a baseline configure_optimizers of its own so that the user doesn't need to write it for simple patterns. Then maybe have defaults in the CLI arguments so you could do things like

python train.py --model.learning_rate 1e-4 --model.lr_scheduler.type ReduceLROnPlateau --model.lr_scheduler.factor 0.1

and

python train.py --model.learning_rate 1e-4 --model.lr_scheduler.type ReduceLROnPlateau --model.lr_scheduler.factor 0.1 --model.optimizer.type Adam --model.optimizer.weight_decay 1e-5

or whatever. The user just wouldn't implement configure_optimizers, or at least would have something simple to call.

Writing this out as a configuration, I have in mind that the user could have something like

model:
  decoder_layers:
  - 2
  - 4
  encoder_layers: 12
optimizer:
  Adam:
    weight_decay: 1e-5
lr_scheduler:
   ReduceLROnPlateau:
     factor: 0.1     
trainer:
  accelerator: null
  accumulate_grad_batches: 1
  amp_backend: native
  amp_level: O2

or maybe instead:

model:
  decoder_layers:
  - 2
  - 4
  encoder_layers: 12
  optimizer:
    Adam:
      weight_decay: 1e-5
  lr_scheduler:
     ReduceLROnPlateau:
       factor: 0.1     
trainer:
  accelerator: null
  accumulate_grad_batches: 1
  amp_backend: native
  amp_level: O2

etc.

Alternatives

See the above monstrosity of what is done now.

@jlperla added the feature and help wanted labels May 17, 2021
@justusschock
Member

justusschock commented May 17, 2021

cc @carmocca @mauvilsa

@carmocca
Contributor

carmocca commented May 17, 2021

but I can imagine the CLI having a baseline configure_optimizers on its own so that the user doesn't need to write it for simple patterns.

The CLI should not modify or override any code inside the LightningModule. To accomplish what you want, there are several ways you could go about it. What I would do is:

from enum import Enum

import torch

import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import LightningCLI


# could be just strings but enum forces the set of choices
class OptimizerEnum(str, Enum):
    Adam = "Adam"
    SGD = "SGD"
    LBFGS = "LBFGS"


class LRSchedulerEnum(str, Enum):
    ...


class MyModel(pl.LightningModule):
    def configure_optimizer(self, optimizer: OptimizerEnum, learning_rate: float = 1e-3, weight_decay: float = 0.0):
        if optimizer == "Adam":
            return torch.optim.Adam(
                self.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
            )
        elif optimizer == "SGD":
            return torch.optim.SGD(
                self.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
            )
        elif optimizer == "LBFGS":
            return torch.optim.LBFGS(self.parameters(), lr=1)
        raise ValueError(f"Invalid optimizer {optimizer}. See --help")

    def configure_scheduler(self, optimizer: torch.optim.Optimizer, lr_scheduler: LRSchedulerEnum):
        # same structure as `configure_optimizer`
        ...

    def configure_optimizers(self):
        optimizer = self.configure_optimizer(
            self.hparams.optimizer,
            learning_rate=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        if self.hparams.lr_scheduler is None:
            return optimizer
        scheduler = self.configure_scheduler(optimizer, self.hparams.lr_scheduler)
        return [optimizer], [scheduler]


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # add the `configure_optimizer` arguments under the key optimizer
        parser.add_method_arguments(self.model_class, "configure_optimizer", nested_key="optimizer")
        # add the `configure_scheduler` arguments under the key lr_scheduler
        parser.add_method_arguments(self.model_class, "configure_scheduler", nested_key="lr_scheduler", skip={"optimizer"})


MyLightningCLI(MyModel)

And the commandline configuration looks like

Now:

  • all those defaults and types are part of the function signatures
  • the help messages are part of the docstrings (not included in the ported code)
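
For illustration, the command line would then look something like the following (the exact flag names depend on how jsonargparse exposes the nested keys, so take this as an assumption rather than verified output):

python train.py --optimizer.optimizer Adam --optimizer.learning_rate 1e-4 --optimizer.weight_decay 1e-5 --lr_scheduler.lr_scheduler ReduceLROnPlateau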

@carmocca added the argparse (removed) label May 17, 2021
@edenlightning added the discussion and working as intended labels and removed the help wanted and working as intended labels May 19, 2021
@jlperla
Author

jlperla commented May 31, 2021

I know this is "working as intended" but I also think you are missing a key feature. The code given doesn't tackle the bigger issue of configuring the learning rate schedulers with their different parameters, etc. I am not even sure how much wrapper code that would take or how to write it.

Also, the optimizers don't always have the same options, so I think the architecture of your code example just wouldn't work with more generality. Not to mention, how are new optimizers that Lightning adds going to be supported? It would basically require hardcoding the dispatching in user code rather than an OO design within Lightning or the Lightning CLI code.

Something is missing. @zippeurfou, was there a structure that works for this with Hydra?

This is very un-Lightning in the quantity of copy-paste boilerplate required for something where almost everyone wants the same features to work. I am positive there is an approach which would require zero code for users who want the CLI to configure standard optimizer and scheduler usage.

@mauvilsa
Contributor

@jlperla I agree this is a missing feature. One of the objectives of LightningCLI is that it can be used without people having to learn about a configuration framework, at least for the most common cases, since the possibilities of what people might need are endless.

I have some ideas and will describe them here in a bit. Though I will not manage to do this today.

@zippeurfou
Contributor

Hey @jlperla, you have good points and I agree with you that writing "if" statements and manually linking your hparams for each of the different options for optimizer and LR scheduler is not a great experience.
@mauvilsa has some ideas; let's wait and see if they work.

@jlperla
Author

jlperla commented Jun 1, 2021

Thanks. Yeah, the other consideration is that if you can move this stuff into your libraries in one form or another, it makes user code far less fragile to any refactoring or additional PyTorch optimizer options. Otherwise there will be half a dozen almost identical versions of the same code out there... none of them really customized for the user.

@mauvilsa
Contributor

mauvilsa commented Jun 1, 2021

I give here some initial ideas to start the discussion. Some objectives could be:

  • It should be possible to add groups in the config for the settings of optimizers/schedulers. The groups are dicts that can be used to instantiate; no automatic instantiation.
  • There could be one or more independent optimizers/schedulers.
  • A single optimizer/scheduler group can be configured to accept settings for one class or multiple classes using class_path and init_args to follow the same pattern that LightningCLI already uses.
  • There would be a function to ease instantiation. This is particularly important when a class is defined using class_path and init_args.

My initial idea is the following (not yet fully thought through). There would be a method in LightningArgumentParser to add groups in the config, e.g. add_optimizer_args. This method could be used to configure a single optimizer class, e.g. parser.add_optimizer_args(Adam, 'optimizer'), or multiple, e.g. parser.add_optimizer_args(BaseOptimizer, 'optimizer', subclass_mode=True) or parser.add_optimizer_args(Union[Adam, SGD], 'optimizer'). These groups would need to be linked to arguments of the module. Would be up to the user to store them in the class and then use them in configure_optimizers to instantiate.

@jlperla
Author

jlperla commented Jun 1, 2021

Would be up to the user to store them in the class and then use them in configure_optimizers to instantiate.

I think I would have to get a sense of what you mean by that, but sounds like boilerplate to me :-) The source of the difficulty in all of this stuff is writing the code to configure the class and store the parameters. I still don't understand why the user would need to write a single line of code for this since it is always the same stuff (except for weird cases where you need two optimizers/etc. but users can then do what they want for that). Why can't PL and the CLI just do everything internally and have things nice and consistent?

But... the stuff in the parser groups seems to make sense to me, and the grouping "dicts" as constructors for the optimizers and schedulers also makes plenty of sense. Now I think it just needs to be taken to its inevitable conclusion: an option which formalizes that stuff inside of PL itself :-)

@mauvilsa
Contributor

mauvilsa commented Jun 2, 2021

Would be up to the user to store them in the class and then use them in configure_optimizers to instantiate.

I think I would have to get a sense of what you mean by that, but sounds like boilerplate to me :-) The source of the difficulty in all of this stuff is writing the code to configure the class and store the parameters. I still don't understand why the user would need to write a single line of code for this since it is always the same stuff (except for weird cases where you need two optimizers/etc. but users can then do what they want for that). Why can't PL and the CLI just do everything internally and have things nice and consistent?

You are right. It wouldn't be difficult to automatically implement the model's configure_optimizers in the case of a single optimizer and scheduler. It could be as simple as

parser.add_optimizer_args(Adam)  # 'optimizer' key would be default
parser.add_scheduler_args(StepLR)  # 'scheduler' key would be default

In the case of multiple optimizers/schedulers it becomes manual. The parser could be configured like

parser.add_optimizer_args(Adam, 'optimizer1', link_to='model.optim1')
parser.add_optimizer_args(Adam, 'optimizer2', link_to='model.optim2')

Then the user needs to store the parameters and implement configure_optimizers.
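
As a rough sketch of that manual part, assuming the linked groups arrive in the model as plain keyword dicts (the attribute names below are just illustrative, not a final API):

import torch
import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):
    # `optim1` and `optim2` are filled through link_to='model.optim1' / 'model.optim2';
    # with a single allowed class (Adam) each group would simply be its init arguments
    def __init__(self, optim1: dict, optim2: dict):
        super().__init__()
        self.optim1_kwargs = optim1
        self.optim2_kwargs = optim2

    def configure_optimizers(self):
        # in a real model each optimizer would usually get a different parameter group
        opt1 = torch.optim.Adam(self.parameters(), **self.optim1_kwargs)
        opt2 = torch.optim.Adam(self.parameters(), **self.optim2_kwargs)
        return opt1, opt2

The parser would then own the argument definitions and linking, and the module only decides how to combine the resulting groups.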

@jlperla
Author

jlperla commented Jun 2, 2021

It wouldn't be difficult to automatically implement the model's configure_optimizers in the case of a single optimizer and scheduler

I am not sure I completely follow, but if it means I can have a configurable and swappable single optimizer and single scheduler in my code without any manual boilerplate, then I am happy. But just to confirm, the key here is that it is all swappable and configurable in the CLI without me needing to manually write a bunch of elif etc. statements and manipulate arguments in my model.

That is, if you are saying that I have to put parser.add_scheduler_args(StepLR) in my code, then it seems like you are forcing me to choose a scheduler at parse time? Whereas what we really want is to let users swap out different optimizers and schedulers in the CLI or JSON configuration so we can try different things in different runs to see what works. I still see no reason that all of that stuff can't be basically automatic for single optimizer/single scheduler setups. At least if the user uses the

class MyModel(pl.LightningModule): 
    def __init__(self, **kwargs): 
        super().__init__()
        self.save_hyperparameters()

Then if the user doesn't provide a def configure_optimizers(self): method, the default one could look into the module to see if it has the standard thing stored, e.g. maybe it stores an optimizer_dict and lr_scheduler_dict or whatever, which get populated from the CLI config:

model:
  decoder_layers:
  - 2
  - 4
  encoder_layers: 12
  optimizer:
    Adam:
      weight_decay: 1e-5
  lr_scheduler:
     ReduceLROnPlateau:
       factor: 0.1     
trainer:
  accelerator: null
  accumulate_grad_batches: 1
  amp_backend: native
  amp_level: O2

And if they wanted to try a different optimizer, they would just pass in a different YAML file or CLI arguments.

  optimizer:
    SGD:
      weight_decay: 1e-5

etc. The user wouldn't need to write any code specific to the optimizers or lr_schedulers if they just want to use the basic method which looks for the variables associated with those?

Then you could either use dispatching on the type (e.g. the Adam) or fill things in with a dictionary and a bunch of internal elif.
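
As a rough illustration of what that built-in default could do (optimizer_dict and lr_scheduler_dict are just the attribute names suggested above, nothing that exists in Lightning today):

import torch


def default_configure_optimizers(module):
    # hypothetical default: each dict holds a single {Name: kwargs} entry populated from the config above
    opt_name, opt_kwargs = next(iter(module.optimizer_dict.items()))
    optimizer = getattr(torch.optim, opt_name)(module.parameters(), **opt_kwargs)

    if not getattr(module, "lr_scheduler_dict", None):
        return optimizer

    sched_name, sched_kwargs = next(iter(module.lr_scheduler_dict.items()))
    scheduler = getattr(torch.optim.lr_scheduler, sched_name)(optimizer, **sched_kwargs)
    # ReduceLROnPlateau would additionally need a "monitor" entry in whatever gets returned here
    return [optimizer], [scheduler]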

@mauvilsa
Contributor

mauvilsa commented Jun 2, 2021

It wouldn't be difficult to automatically implement the model's configure_optimizers in the case of a single optimizer and scheduler

[...] But just to confirm, the key here is that it is all swappable and configurable in the CLI without me needing to manually write a bunch of elif etc. statements and manipulate arguments in my model.

This is why I said a single optimizer class or multiple. If it is configured to allow multiple classes then it would follow the same pattern that LightningCLI already uses. So in the config it would be like

optimizer:
  class_path: torch.optim.Adam
  init_args:
    weight_decay: 1e-5
scheduler:
  class_path: torch.optim.lr_scheduler.ReduceLROnPlateau
  init_args:
    factor: 0.1     
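
And instantiating such a group inside configure_optimizers could then be roughly the following (instantiate_from_dict is only an illustration here, not an existing utility):

import importlib


def instantiate_from_dict(init: dict, *args):
    # resolve 'class_path' to a class and call it with the given positional args plus 'init_args'
    module_name, class_name = init["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(*args, **init.get("init_args", {}))


# inside configure_optimizers, with the parsed groups stored on the module:
# optimizer = instantiate_from_dict(self.optimizer_init, self.parameters())
# scheduler = instantiate_from_dict(self.scheduler_init, optimizer)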

@mauvilsa
Contributor

mauvilsa commented Jun 2, 2021

[...] I still see no reason that all of that stuff can't be basically automatic for single optimizer/single scheduler setups. At least if the user uses the

class MyModel(pl.LightningModule): 
    def __init__(self, **kwargs): 
        super().__init__()
        self.save_hyperparameters()

Yes, for a single optimizer/scheduler it would be automatic, and it shouldn't depend on the use of self.save_hyperparameters(). I guess it would make sense that if the module already implements configure_optimizers then there wouldn't be an automatic implementation, or it would fail with a meaningful error message.
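
A minimal sketch of how that check could be made, using a plain Python comparison rather than any existing Lightning helper:

import pytorch_lightning as pl


def has_user_configure_optimizers(model: pl.LightningModule) -> bool:
    # True when the module overrides configure_optimizers itself, in which case the CLI
    # would skip the automatic implementation (or raise a meaningful error instead)
    return type(model).configure_optimizers is not pl.LightningModule.configure_optimizers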

@stale

stale bot commented Jul 31, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix label Jul 31, 2021
@stale stale bot closed this as completed Aug 7, 2021
@tchaton tchaton reopened this Aug 22, 2021
@stale stale bot removed the won't fix label Aug 22, 2021
@blacksnail789521

blacksnail789521 commented Nov 20, 2022

def configure_optimizers(self):
    optimizer = getattr(torch.optim, self.hparams.optimizer)(
        self.parameters(),
        lr=self.hparams.lr,
        weight_decay=self.hparams.weight_decay,
    )
    scheduler = self.configure_scheduler(optimizer, self.hparams.lr_scheduler)
    return [optimizer], [scheduler]
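
For completeness, a configure_scheduler in the same getattr style could look roughly like this (sketch only; lr_scheduler_kwargs is an assumed hyperparameter here, and ReduceLROnPlateau would additionally need a monitor key in the value returned by configure_optimizers):

def configure_scheduler(self, optimizer, lr_scheduler_name):
    # look the scheduler class up by name, mirroring the getattr pattern used for the optimizer
    scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler_name)
    return scheduler_cls(optimizer, **self.hparams.lr_scheduler_kwargs)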
