Transfer learning phases #2006

Closed
lgvaz opened this issue May 29, 2020 · 12 comments
Labels
design (Includes a design discussion), discussion (In a discussion stage), feature (Is an improvement or enhancement), help wanted (Open to be worked on)
Milestone

Comments

@lgvaz
Contributor

lgvaz commented May 29, 2020

🚀 Feature

When doing transfer learning we need to switch between phases.

Normally, the first phase is to freeze all but the head of the model and train only that.

After a predefined number of epochs, we unfreeze the rest of our model (or a part of it) and start training again (possibly with the help of differential learning rates, described in #2005). We can repeat this phase as many times as we like.

We should implement a class that handles all of that for us. This includes:

  • Unfreeze part of our model
  • Reset and change the lr_scheduler parameters between phases
  • If LearningRateLogger is being used, register the new lr_scheduler

#2005 will take care of the parameter groups; this issue will take care of what I call "phase switches".
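As a minimal sketch of the "unfreeze part of our model" step: in PyTorch, freezing amounts to toggling `requires_grad` on a module's parameters. The helper names `set_trainable` and `count_trainable` are hypothetical, and the tiny `Param`/`Module` stand-ins exist only so the snippet runs without PyTorch; a real `torch.nn.Module` exposes the same `parameters()` / `requires_grad` interface.

```python
from dataclasses import dataclass, field
from typing import Iterable, List


def set_trainable(modules: Iterable, trainable: bool) -> None:
    """Freeze (trainable=False) or unfreeze (trainable=True) every parameter
    of the given modules by toggling its requires_grad flag."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad = trainable


def count_trainable(modules: Iterable) -> int:
    """Number of parameters that would currently receive gradients."""
    return sum(1 for m in modules for p in m.parameters() if p.requires_grad)


# Stand-ins for torch.nn.Parameter / torch.nn.Module (hypothetical, demo only).
@dataclass
class Param:
    requires_grad: bool = True


@dataclass
class Module:
    params: List[Param] = field(default_factory=lambda: [Param(), Param()])

    def parameters(self):
        return iter(self.params)


body, head = Module(), Module()
set_trainable([body], False)          # phase 1: only the head trains
after_phase1 = count_trainable([body, head])
set_trainable([body], True)           # phase 2: unfreeze the body as well
after_phase2 = count_trainable([body, head])
```

A phase switch is then "call `set_trainable` on the right modules, then build a fresh LR scheduler for the new phase".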

Proposals

There are some ways of achieving this:

Logic inside on_epoch_start

def on_epoch_start(self):
    # LightningModule hook; N_FREEZE_EPOCHS is a user-chosen constant
    if self.current_epoch == 0:
        self.freeze()
        self.trainer.lr_schedulers = ... # Define new scheduler

    if self.current_epoch == N_FREEZE_EPOCHS:
        self.unfreeze() # Or partially unfreeze
        self.trainer.lr_schedulers = ... # Define new scheduler

We can keep adding as many milestones as we want this way, but it's important to note that they all have to be defined beforehand.
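Because the milestones are fixed up front, the if-chain can be replaced by a sorted table, with the active phase looked up from the current epoch. This also finds the right phase when training starts (or resumes) in the middle of a phase, which a strict `==` check would miss. A minimal sketch; `MILESTONES`, the phase names and both helpers are hypothetical:

```python
import bisect

# Hypothetical milestone table: (start_epoch, phase_name), sorted by epoch.
MILESTONES = [(0, "head_only"), (3, "unfreeze_top"), (6, "unfreeze_all")]
_STARTS = [start for start, _ in MILESTONES]


def active_phase(epoch: int) -> str:
    """Return the phase whose start epoch is the last one <= epoch."""
    idx = bisect.bisect_right(_STARTS, epoch) - 1
    if idx < 0:
        raise ValueError(f"no phase starts at or before epoch {epoch}")
    return MILESTONES[idx][1]


def phase_changed(epoch: int) -> bool:
    """True on the first epoch of a phase, i.e. the moment to (un)freeze
    and swap the LR scheduler inside on_epoch_start."""
    return epoch in _STARTS
```

When resuming from a checkpoint, one would apply `active_phase(epoch)` once at the start of the run instead of relying only on `phase_changed`.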

Multiple calls to Trainer.fit

# fit_one_cycle is the proposed API (mirrors fastai's flow)
model.freeze()
trainer.fit_one_cycle(model, n_epochs=2, lr=1e-3, pct_start=0.9)
model.unfreeze()
trainer.fit_one_cycle(model, n_epochs=5, lr=slice(5e-6, 5e-4), pct_start=0.2)

This is exactly the flow in fastai. This way of training a model is excellent for iterative training, such as in a notebook or a REPL.

fit_one_cycle assumes that we are using the OneCycleLR scheduler, that each call is a continuation of the last one, and that we want to reset our schedule.
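To make "reset our schedule" concrete, here is a simplified pure-Python cosine one-cycle curve in which every `fit_one_cycle`-style call starts a fresh cycle. It is modeled loosely on PyTorch's `OneCycleLR` (borrowing its default `div_factor=25` and `final_div_factor=1e4`) but is not claimed to match it step for step; the helper names are hypothetical, and treating each epoch as one scheduler step is a simplification.

```python
import math
from typing import List, Sequence, Tuple


def one_cycle_lr(step: int, total_steps: int, max_lr: float, pct_start: float,
                 div_factor: float = 25.0, final_div_factor: float = 1e4) -> float:
    """LR at `step` of one cosine warm-up / cosine anneal cycle."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup = max(1, int(pct_start * total_steps))
    if step < warmup:
        pct = step / warmup                                  # 0 -> 1 during warm-up
        return max_lr + (initial_lr - max_lr) / 2 * (1 + math.cos(math.pi * pct))
    pct = (step - warmup) / max(1, total_steps - warmup - 1)  # 0 -> 1 afterwards
    return min_lr + (max_lr - min_lr) / 2 * (1 + math.cos(math.pi * pct))


def lrs_for_phases(phases: Sequence[Tuple[int, float, float]]) -> List[float]:
    """Concatenate one cycle per phase; each phase restarts at step 0."""
    lrs: List[float] = []
    for n_steps, max_lr, pct_start in phases:
        lrs.extend(one_cycle_lr(s, n_steps, max_lr, pct_start) for s in range(n_steps))
    return lrs


# Mirrors the two fit_one_cycle calls above, using the top LR of the slice.
schedule = lrs_for_phases([(2, 1e-3, 0.9), (5, 5e-4, 0.2)])
```

The jump from the tiny final LR of phase 1 back to `max_lr / div_factor` at the start of phase 2 is exactly the reset that `fit_one_cycle` implies.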

When we pass a slice to lr, we are asking for an interpolation of values between the trainable layer groups.
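One way a `slice(start, stop)` could be expanded into one LR per parameter group is to space the values geometrically, from `start` (earliest, most generic layers) to `stop` (the head). To my understanding fastai does something similar, but `spread_lrs` below is a hypothetical illustration, not fastai's code.

```python
from typing import List


def spread_lrs(lr: slice, n_groups: int) -> List[float]:
    """Expand slice(start, stop) into n_groups learning rates, geometrically
    spaced so the first group gets `start` and the last group gets `stop`."""
    if n_groups < 1:
        raise ValueError("need at least one parameter group")
    start, stop = lr.start, lr.stop
    if n_groups == 1:
        return [stop]
    ratio = (stop / start) ** (1 / (n_groups - 1))
    return [start * ratio ** i for i in range(n_groups)]


lrs = spread_lrs(slice(5e-6, 5e-4), 3)   # roughly [5e-06, 5e-05, 5e-04]
```

With PyTorch, each value would then go into the `lr` key of the matching optimizer param group (defining those groups is what #2005 covers).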

Implement a new scheduler (suggested by @williamFalcon)

The scheduler receives a list of dicts; each dict specifies the duration of the phase and its configuration (which layers to freeze, which lrs to use, ...).

scheduler = FineTuneScheduler([
   {'params': [nn.Sequential(self.c_d1, self.c_d1_bn), self.c_d2], 'action': 'freeze', 'epoch': 0},
   {'params': [self.c_d2], 'action': 'unfreeze', 'epoch': 2},
])

Then we can just pass the scheduler to the Trainer.
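A minimal sketch of how such a config-driven scheduler could dispatch its entries, assuming the dict keys from the example above (`params`, `action`, `epoch`). The class name `ConfigFineTuneScheduler` and the injected `set_trainable(params, trainable)` callable are hypothetical. Injecting the freeze logic keeps the dispatch testable without PyTorch, so plain strings stand in for modules in the demo.

```python
from typing import Any, Callable, Dict, List

_ACTIONS = {"freeze": False, "unfreeze": True}


class ConfigFineTuneScheduler:
    """Apply freeze/unfreeze entries when their epoch starts (hypothetical)."""

    def __init__(self, config: List[Dict[str, Any]],
                 set_trainable: Callable[[List[Any], bool], None]):
        for entry in config:
            if entry["action"] not in _ACTIONS:
                raise ValueError(f"unknown action: {entry['action']!r}")
        # Stable sort: entries sharing an epoch keep their listed order.
        self.config = sorted(config, key=lambda e: e["epoch"])
        self.set_trainable = set_trainable

    def on_epoch_start(self, epoch: int) -> None:
        for entry in self.config:
            if entry["epoch"] == epoch:
                self.set_trainable(entry["params"], _ACTIONS[entry["action"]])


calls = []
sched = ConfigFineTuneScheduler(
    [{"params": ["body", "head"], "action": "freeze", "epoch": 0},
     {"params": ["head"], "action": "unfreeze", "epoch": 2}],
    set_trainable=lambda params, flag: calls.append((tuple(params), flag)),
)
for epoch in range(4):
    sched.on_epoch_start(epoch)
```

Validating the actions in `__init__` means a typo in the config fails before training starts rather than several epochs in.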

Notes

In all cases, the flow should be the same for all standard areas (vision, NLP, time series, ...).

The only things we assume are:

  • You want to train your model in multiple phases
  • The phases are a continuation of each other
@lgvaz added the feature and help wanted labels May 29, 2020
@lgvaz
Contributor Author

lgvaz commented May 30, 2020

I personally prefer the approach of calling Trainer.fit (or some variation of it) multiple times.

It gives me more control over how I train my model. Transfer learning usually happens on small datasets, so the user can train for a few epochs, see what happens, and only then decide whether it's time to unfreeze some layers or run more epochs with the current configuration.

@lgvaz
Contributor Author

lgvaz commented May 30, 2020

Added a new proposal to OP, the scheduler interface suggested by @williamFalcon

I think the main benefit of this approach is that it's easily reproducible: because we are using a list of dicts (configs), we could even store the scheduler as a config file in the future.
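One way to make the config storable, sketched under the assumption that modules are referenced by attribute name rather than by object (module objects themselves are not JSON-serializable): the phases round-trip through `json` and are resolved against the model only when applied. `resolve_params` is a hypothetical helper, and `SimpleNamespace` stands in for a `LightningModule`.

```python
import json
from types import SimpleNamespace
from typing import Any, Dict, List

phases: List[Dict[str, Any]] = [
    {"params": ["c_d1", "c_d1_bn", "c_d2"], "action": "freeze", "epoch": 0},
    {"params": ["c_d2"], "action": "unfreeze", "epoch": 2},
]


def resolve_params(model: Any, names: List[str]) -> List[Any]:
    """Turn stored attribute names back into the model's submodules."""
    return [getattr(model, name) for name in names]


text = json.dumps(phases, indent=2)   # what would be written to a config file
restored = json.loads(text)

model = SimpleNamespace(c_d1="conv1", c_d1_bn="bn1", c_d2="conv2")  # stand-in
resolved = resolve_params(model, restored[1]["params"])
```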

@lgvaz
Contributor Author

lgvaz commented May 30, 2020

Another option with the scheduler would be to pass it a function instead of predefined actions. It would look something like this:

from torch.optim.lr_scheduler import OneCycleLR

def phase1(trainer, model):
    model.freeze()
    sched = OneCycleLR(...)
    trainer.new_schedule(sched)  # proposed Trainer method, not an existing API

def phase2(trainer, model):
    model.unfreeze()
    sched = OneCycleLR(...) # Differential LRs can be introduced here
    trainer.new_schedule(sched)

sched = FineTuneScheduler([
    {'func': phase1, 'epoch': 0},
    {'func': phase2, 'epoch': 5},
])

This gives the user full control over what happens in these phases.

If you think about it, this is not even a specific FineTuneScheduler; it's more like a LambdaScheduler. You can inject any functionality you want with it, which is very powerful.

We can then implement helper functions to make defining differential learning rates and resetting schedulers easier. But it would be up to the user to construct what they want =)
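A minimal sketch of that LambdaScheduler idea, assuming the `{'func': ..., 'epoch': ...}` entries shown above and a hook that the training loop calls at the start of every epoch. `LambdaPhaseScheduler` is a hypothetical name, and the logging lambdas are placeholders for `phase1`/`phase2`.

```python
from typing import Any, Dict, List


class LambdaPhaseScheduler:
    """Run arbitrary user functions at predefined epochs (hypothetical)."""

    def __init__(self, phases: List[Dict[str, Any]]):
        # Sorted by start epoch so phases fire in chronological order.
        self.phases = sorted(phases, key=lambda p: p["epoch"])

    def on_epoch_start(self, trainer: Any, model: Any, epoch: int) -> None:
        for phase in self.phases:
            if phase["epoch"] == epoch:
                phase["func"](trainer, model)


log: List[str] = []
sched = LambdaPhaseScheduler([
    {"func": lambda trainer, model: log.append("phase1"), "epoch": 0},
    {"func": lambda trainer, model: log.append("phase2"), "epoch": 5},
])
for epoch in range(8):
    sched.on_epoch_start(trainer=None, model=None, epoch=epoch)
```

The scheduler itself knows nothing about freezing or learning rates; everything phase-specific lives in the user's functions.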

@lgvaz
Contributor Author

lgvaz commented May 30, 2020

One thing I don't currently like about it, though, is that when creating a new scheduler I also need to know the duration of the phase. Maybe we can change its signature to:

def phase(trainer, model, n_epochs): ...
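If phases are given only by their start epochs, each phase's `n_epochs` can be derived from the next phase's start (and from `max_epochs` for the last phase), so the user never has to state durations twice. A sketch assuming the `phase(trainer, model, n_epochs)` signature above; `phase_durations` is a hypothetical helper.

```python
from typing import List


def phase_durations(starts: List[int], max_epochs: int) -> List[int]:
    """n_epochs for each phase, given sorted start epochs and the total run length."""
    if starts != sorted(starts) or len(set(starts)) != len(starts):
        raise ValueError("start epochs must be strictly increasing")
    if starts and starts[-1] >= max_epochs:
        raise ValueError("last phase starts at or after max_epochs")
    ends = starts[1:] + [max_epochs]
    return [end - start for start, end in zip(starts, ends)]


durations = phase_durations([0, 5], max_epochs=12)   # [5, 7]
```

Each value could then be passed as `n_epochs`, for example to size a `OneCycleLR` through its `epochs` and `steps_per_epoch` arguments.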

@lgvaz
Contributor Author

lgvaz commented May 30, 2020

And then, as @williamFalcon suggested again, we can implement a scheduler that is really specific to the standard transfer learning case:

class FineTuneScheduler(Scheduler):
    def __init__(self, pretrained, head, head_unfreeze_epoch):
        ...

# unfreeze head after 1 epoch
sched = FineTuneScheduler(nn.Sequential(self.c_d1, self.c_d1_bn), self.c_d2, 1)

# unfreeze head after 10 epochs
sched = FineTuneScheduler(nn.Sequential(self.c_d1, self.c_d1_bn), self.c_d2, 10)

This can be easily built on top of LambdaScheduler
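A sketch of building such a scheduler on top of lambda-style phases: it simply generates the two phase functions. The snippet above names the argument `head_unfreeze_epoch`; this sketch instead assumes the common recipe (train the head first, then unfreeze the pretrained body at the given epoch). That interpretation, the name `fine_tune_phases` and the dict format are all assumptions. Any object with a `parameters()` method whose items have `requires_grad` works, including `torch.nn.Module`. `SimpleNamespace` stand-ins keep the demo free of PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


def _set_trainable(module: Any, trainable: bool) -> None:
    for param in module.parameters():
        param.requires_grad = trainable


def fine_tune_phases(pretrained: Any, head: Any, unfreeze_epoch: int) -> List[Dict[str, Any]]:
    """Phase list for a lambda-style scheduler: head only, then everything."""
    if unfreeze_epoch < 1:
        raise ValueError("unfreeze_epoch must be >= 1 so the head trains first")

    def head_only(trainer: Any, model: Any) -> None:
        _set_trainable(pretrained, False)
        _set_trainable(head, True)

    def unfreeze_all(trainer: Any, model: Any) -> None:
        _set_trainable(pretrained, True)

    return [{"func": head_only, "epoch": 0},
            {"func": unfreeze_all, "epoch": unfreeze_epoch}]


# Stand-in modules: parameters() yields objects with a requires_grad flag.
def fake_module(n: int) -> Any:
    params = [SimpleNamespace(requires_grad=True) for _ in range(n)]
    return SimpleNamespace(parameters=lambda: iter(params), params=params)


body, top = fake_module(3), fake_module(1)
phases = fine_tune_phases(body, top, unfreeze_epoch=1)
phases[0]["func"](None, None)
body_grad_epoch0 = [p.requires_grad for p in body.params]
phases[1]["func"](None, None)
body_grad_epoch1 = [p.requires_grad for p in body.params]
```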

@Borda
Member

Borda commented Jun 3, 2020

I would go the scheduler way with a dict config, as it can be easily stored, and even without loading/running it you can see what you did in the past, like history notes.

@Borda
Member

Borda commented Jun 3, 2020

@PyTorchLightning/core-contributors any other thoughts?

@reactivetype

reactivetype commented Jun 5, 2020

When restoring a checkpoint for finetuning a model, users still need a way to reset the current_epoch and global_step to 0.

Do we still need a GH issue to handle this, aside from the param groups and differential learning rate features?

A hack for this was described by @lgvaz:

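from pytorch_lightning import LightningModule, Trainer

class MyTrainer(Trainer):
    def restore_weights(self, model: LightningModule):
        res = super().restore_weights(model)
        self.reset_lr_schedulers()
        return res

    def reset_lr_schedulers(self):
        # Restart every scheduler's epoch counter after the checkpoint is restored
        for sched in self.lr_schedulers:
            sched['scheduler'].last_epoch = 0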

Is there a better way? If we pass both resume_from_checkpoint and lr_schedulers params to the Trainer, will the new lr_schedulers override the ones saved from the saved checkpoint’s training state along with the scheduler's last_epoch?

@stale

stale bot commented Aug 4, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale bot added the won't fix label Aug 4, 2020
@Borda added the design, discussion and Important labels and removed the won't fix label Aug 4, 2020
@edenlightning modified the milestones: 0.9.x, 1.1 Sep 17, 2020
@edenlightning modified the milestones: 1.1, 1.2 Oct 19, 2020
@edenlightning modified the milestones: 1.2, 1.3 Feb 8, 2021
@tchaton
Contributor

tchaton commented Mar 9, 2021

Dear @lgvaz,

This logic can easily be built on top of the BaseFinetuning callback.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from torch.optim import Optimizer
from torch.optim.lr_scheduler import OneCycleLR


def phase1(trainer, model):
    model.freeze()
    sched = OneCycleLR(...)
    trainer.new_schedule(sched)  # proposed Trainer method, not an existing API

def phase2(trainer, model):
    model.unfreeze()
    sched = OneCycleLR(...) # Differential LRs can be introduced here
    trainer.new_schedule(sched)


class FinetuneScheduler(BaseFinetuning):

    def __init__(self, phases, train_bn: bool = True):
        super().__init__()
        self.phases = phases
        self.train_bn = train_bn

    @property
    def max_epochs(self):
        # return total number of epochs to run from phases.
        ...

    def freeze_before_training(self, pl_module: pl.LightningModule):
        self.freeze(modules=pl_module, train_bn=self.train_bn)

    def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        # Logic to extract the phase and apply it
        ...


cb = FinetuneScheduler([
    {'func': phase1, 'epoch': 0},
    {'func': phase2, 'epoch': 5},
])

trainer = pl.Trainer(callbacks=[cb], max_epochs=cb.max_epochs)

https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.callbacks.BaseFinetuning.html?highlight=BaseFinetuning

If you do implement a nice Finetuning Callback, please make a PR so the community can try it out :)

Best,
T.C

@tchaton added the waiting on author label Mar 9, 2021
@tchaton added the priority: 1 label Mar 9, 2021
@lgvaz
Contributor Author

lgvaz commented Mar 27, 2021

Hi @tchaton thanks for the update! Unfortunately I don't have the time to try this out right now =/

Should we leave this issue open or should we close it?

@edenlightning modified the milestones: v1.3, v1.4 Apr 27, 2021
@edenlightning removed the Important, priority: 1 and waiting on author labels May 9, 2021
@edenlightning modified the milestones: v1.4, v1.5 Jun 30, 2021
@awaelchli modified the milestones: v1.5, v1.6 Nov 4, 2021
@carmocca
Contributor

carmocca commented Feb 1, 2022

We are not looking to add any more callbacks to core that are too opinionated, research-y, or just not applicable to most users. We suggest developing this callback in your own repository.

Development

No branches or pull requests

7 participants