-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add callback system + associated test * Add trainer and pl_module args to callback methods * typing * typo in docstring * Switch to on_.*_start() * fix on_test_start * fix the mess after rebasing
- Loading branch information
Showing
14 changed files
with
407 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,4 @@ Loggers | |
_save_model, | ||
on_epoch_end, | ||
on_train_end, | ||
on_epoch_begin, | ||
on_epoch_start, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import Callable | ||
from abc import ABC | ||
|
||
from pytorch_lightning.callbacks import Callback | ||
|
||
|
||
class TrainerCallbackHookMixin(ABC): | ||
|
||
def __init__(self): | ||
# this is just a summary on variables used in this abstract class, | ||
# the proper values/initialisation should be done in child class | ||
self.callbacks: list[Callback] = [] | ||
self.get_model: Callable = ... | ||
|
||
def on_init_start(self): | ||
"""Called when the trainer initialization begins.""" | ||
for callback in self.callbacks: | ||
callback.on_init_start(self, None) | ||
|
||
def on_init_end(self): | ||
"""Called when the trainer initialization ends.""" | ||
for callback in self.callbacks: | ||
callback.on_init_end(self, self.get_model()) | ||
|
||
def on_fit_start(self): | ||
"""Called when the fit begins.""" | ||
for callback in self.callbacks: | ||
callback.on_fit_start(self, self.get_model()) | ||
|
||
def on_fit_end(self): | ||
"""Called when the fit ends.""" | ||
for callback in self.callbacks: | ||
callback.on_fit_end(self, self.get_model()) | ||
|
||
def on_epoch_start(self): | ||
"""Called when the epoch begins.""" | ||
for callback in self.callbacks: | ||
callback.on_epoch_start(self, self.get_model()) | ||
|
||
def on_epoch_end(self): | ||
"""Called when the epoch ends.""" | ||
for callback in self.callbacks: | ||
callback.on_epoch_end(self, self.get_model()) | ||
|
||
def on_train_start(self): | ||
"""Called when the train begins.""" | ||
for callback in self.callbacks: | ||
callback.on_train_start(self, self.get_model()) | ||
|
||
def on_train_end(self): | ||
"""Called when the train ends.""" | ||
for callback in self.callbacks: | ||
callback.on_train_end(self, self.get_model()) | ||
|
||
def on_batch_start(self): | ||
"""Called when the training batch begins.""" | ||
for callback in self.callbacks: | ||
callback.on_batch_start(self, self.get_model()) | ||
|
||
def on_batch_end(self): | ||
"""Called when the training batch ends.""" | ||
for callback in self.callbacks: | ||
callback.on_batch_end(self, self.get_model()) | ||
|
||
def on_validation_start(self): | ||
"""Called when the validation loop begins.""" | ||
for callback in self.callbacks: | ||
callback.on_validation_start(self, self.get_model()) | ||
|
||
def on_validation_end(self): | ||
"""Called when the validation loop ends.""" | ||
for callback in self.callbacks: | ||
callback.on_validation_end(self, self.get_model()) | ||
|
||
def on_test_start(self): | ||
"""Called when the test begins.""" | ||
for callback in self.callbacks: | ||
callback.on_test_start(self, self.get_model()) | ||
|
||
def on_test_end(self): | ||
"""Called when the test ends.""" | ||
for callback in self.callbacks: | ||
callback.on_test_end(self, self.get_model()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.