Major refactoring of ModelTrainer #3182
Conversation
The introduction of `train_custom`, in combination with the other functions that offer a set of presets, makes the Trainer more flexible. Overall, the train functions are much more compact.

I am a bit unsure about the events which are not currently used in any of the plugins. I can see that one might want to keep only those that are referenced by plugins in the repository. However, I have already started using some of them in my experiments and think they offer much greater potential for extending the train function (without the need to derive a trainer class with an overridden `train` function).
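To make the event-based extension idea concrete, here is a minimal sketch of a plugin that reacts to per-epoch events instead of overriding `train`. The import path, the `TrainerPlugin` base class, the `@TrainerPlugin.hook` decorator, the event names, and the callback signatures are assumptions about the plugin system discussed in this PR, not a verbatim reference.

```python
# Illustrative sketch only: import path, base class, hook decorator, event
# names, and callback signatures are assumptions about the new plugin API.
import time

from flair.trainers.plugins import TrainerPlugin


class EpochTimerPlugin(TrainerPlugin):
    """Hypothetical plugin that logs how long each training epoch takes."""

    @TrainerPlugin.hook
    def before_training_epoch(self, epoch, **kwargs):
        # remember when the epoch started
        self._epoch_start = time.time()

    @TrainerPlugin.hook
    def after_training_epoch(self, epoch, **kwargs):
        # report the elapsed time once the epoch has finished
        print(f"epoch {epoch} took {time.time() - self._epoch_start:.1f}s")
```

Such a plugin could then be handed to a train call, e.g. `trainer.train_custom(..., plugins=[EpochTimerPlugin()])` (assuming the train methods accept a `plugins` argument), rather than subclassing the trainer.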
This PR is mostly based on #3084 by @plonerma and executes a major refactoring of the `ModelTrainer`.

Why? Since Flair 0.1, a very large number of features has been added to the `ModelTrainer`, causing the class to become very bloated and difficult to extend and debug.

This PR refactors the trainer in several ways: the `train` and `fine_tune` methods are now "best practice" configurations for model training, and a new `train_custom` method is added alongside them that gives users maximal flexibility when doing model training.

We hope that a more structured `ModelTrainer` will allow us to better integrate new features in the near future, such as multi-GPU and more logging support.
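As a rough illustration of the split between the preset methods and `train_custom`, the sketch below shows how a user might call them. The model setup uses standard Flair classes; the exact parameter names, values, and defaults of `fine_tune` and `train_custom` are assumptions based on the description above and should be checked against the final API.

```python
# Sketch of intended usage; parameter names and values are illustrative assumptions.
from flair.datasets import UD_ENGLISH
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

corpus = UD_ENGLISH()
label_dict = corpus.make_label_dictionary(label_type="upos")

embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)
tagger = SequenceTagger(embeddings=embeddings, tag_dictionary=label_dict, tag_type="upos")

trainer = ModelTrainer(tagger, corpus)

# Preset: "best practice" fine-tuning configuration with sensible defaults.
trainer.fine_tune("resources/taggers/upos", learning_rate=5e-5, mini_batch_size=16)

# Full control: spell out the training configuration explicitly
# (argument names here are assumed for illustration).
trainer.train_custom(
    "resources/taggers/upos-custom",
    learning_rate=5e-5,
    mini_batch_size=16,
    max_epochs=10,
)
```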