A Multi-phase, Scheduled Finetuning Callback #10197
Comments
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, PyTorch Lightning Team!
Unless there's objection, I think we can close this now that the requested functionality is available via the Finetuning Scheduler extension (used in this lightning tutorial). On a related note, any sense of when the _notebooks submodule will be updated? Looks like it was last updated in mid-Jan, so wondering if there will be an update tied to the 1.7 release. Nice work on 1.7 so far btw! ⚡ 🚀 🎉
🚀 Feature
A callback that enables multi-phase, scheduled finetuning of foundational models.
Motivation
Gradual unfreezing/thawing can help maximize foundational model knowledge retention while allowing (typically upper layers of) the model to optimally adapt to new tasks during transfer learning [1, 2, 3].
When tuning pre-trained large language (aka foundational) models for downstream tasks over the last couple of years, I've personally observed the benefits of this technique in multiple project contexts and have been using/refining code to expedite the application of this pattern. Given that this approach to finetuning continues to be widely used and that multi-phase finetuning has been requested by others in the PL community, I thought it would be immensely useful to provide a callback (extending BaseFinetuning) for this purpose. I've created said callback (named FinetuningScheduler), have been using it from my PL fork to great effect over the last few months, and am hoping others may find it similarly useful.
Pitch
Though approaches to leveraging foundational models for downstream tasks are continually evolving (e.g. prompt/prefix-tuning etc.), finetuning w/ gradual unfreezing continues to be widely used, and multi-phase finetuning has been requested by the PL community (#2006) to boot. I think multi-phase finetuning is a natural extension of the BaseFinetuning functionality that PL provides and comports nicely with its aspiration to decouple the science from the engineering.
The PR I'm submitting includes a fully-functional, tested and documented beta version of the FinetuningScheduler (fts) callback as well as a new example (in ./pl_examples/basic_examples/fts) demonstrating a few use cases as applied to a SuperGLUE benchmark task using the LightningCLI. Given the nature of this callback, I thought a LightningCLI-based example was better suited than a notebook-based one.
Rather than reiterate the documentation here in detail, I think the best way to get a sense of the potential utility of this callback is to review the documentation I've provided in the PR and execute the new example. At a high level, though, this callback essentially implements gradual unfreezing of foundational models via either explicit or implicit finetuning schedules. Explicit finetuning mode unfreezes/thaws layers based upon user-defined layer groupings. Schedule definition is facilitated via a method that dumps a default finetuning schedule, which the user can adjust as desired and subsequently pass to the callback. Implicit finetuning mode generates the default schedule and proceeds to finetune according to that generated schedule.
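To give a rough flavor of the explicit mode described above, here is a minimal, illustrative sketch (not the PR's actual schedule format or API — the parameter names and schedule structure below are hypothetical) of how a user-defined schedule could map finetuning phases to the parameters thawed at each phase:

```python
# Illustrative sketch only: the real FinetuningScheduler schedule format and
# thawing logic are defined in the PR. All parameter names are hypothetical.

# An explicit schedule maps a phase index to the parameter names that become
# trainable (are "thawed") when that phase begins.
FT_SCHEDULE = {
    0: ["classifier.weight", "classifier.bias"],              # head first
    1: ["encoder.layer.11.weight", "encoder.layer.11.bias"],  # then top layer
    2: ["encoder.layer.10.weight", "encoder.layer.10.bias"],  # then the next
}

def trainable_params(phase: int, schedule=FT_SCHEDULE) -> set:
    """Names of all parameters unfrozen once `phase` has been reached."""
    return {n for p, names in schedule.items() if p <= phase for n in names}
```

In implicit mode, a default schedule of this general shape would be generated automatically (e.g. one layer group per phase) rather than authored by the user, with training proceeding phase by phase as each milestone is reached.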
Alternatives
Since this pattern is so commonly used, I think it makes sense to have it available in the PL framework as a callback rather than have it implemented in each user's LightningModule or as a community example. I'd note that I did consider modifying ModelCheckpoint to accommodate this callback, but ultimately decided, given the extensive usage of that callback, that it would be more prudent to extend ModelCheckpoint via a separate FTSCheckpoint subclass, at least while FinetuningScheduler is in beta.
Additional context
Feel free to look at the TensorBoard experiment demo I've linked to in the documentation. While I've made other minor contributions to PyTorch Lightning, this is my first feature contribution, so please bear with me if there are any shortcomings wrt my contribution. Thank you so much to everyone in the PL community for contributing to this awesome framework! I've found it immensely useful and plan to continue using it (and evangelizing about it) in the future.
cc @Borda