
Add LambdaCallback #5347

Merged

Commits (44)
0df6bd9
Add LambdaCallback
marload Jan 4, 2021
aa13ddf
docs
marload Jan 4, 2021
01bd0a5
add pr link
Borda Jan 4, 2021
b0953dd
convention
marload Jan 4, 2021
7863e67
Fix Callback Typo
marload Jan 4, 2021
6792408
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 4, 2021
d934a23
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 4, 2021
9fc981a
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 4, 2021
a93e468
use Misconfigureation
marload Jan 5, 2021
2ef199f
update docs
marload Jan 5, 2021
cb294e0
sort export
marload Jan 5, 2021
aadde9e
use inspect
marload Jan 5, 2021
8c10b14
string fill
marload Jan 5, 2021
39b1970
use fast dev run
marload Jan 5, 2021
dc11767
isort
marload Jan 5, 2021
0cfef59
remove unused import
marload Jan 5, 2021
6835771
sort
marload Jan 5, 2021
0263e3a
hilightning
marload Jan 5, 2021
7249a10
highlighting
marload Jan 5, 2021
3038d2f
highlighting
marload Jan 5, 2021
c400b98
remove debug log
marload Jan 5, 2021
8518382
eq
marload Jan 5, 2021
8bfe53e
res
marload Jan 5, 2021
9fd4c6b
results
marload Jan 5, 2021
c4563c7
add misconfig exception test
marload Jan 5, 2021
a329d4a
use pytest raises
marload Jan 5, 2021
571b941
Merge remote-tracking branch 'upstream/release/1.2-dev' into feature/…
marload Jan 5, 2021
d1f8d4a
fix
marload Jan 5, 2021
7293115
Apply suggestions from code review
Borda Jan 6, 2021
4d85f59
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 6, 2021
c9ecb8a
hc
marload Jan 6, 2021
2044291
rm pt
marload Jan 6, 2021
5359ce6
Merge branch 'release/1.2-dev' into feature/lambdacallback
tchaton Jan 6, 2021
556ea09
fix
marload Jan 8, 2021
d190e15
try fix
rohitgr7 Jan 9, 2021
d7bfc4a
Merge branch 'release/1.2-dev' into feature/lambdacallback
rohitgr7 Jan 9, 2021
a27dbff
whitespace
rohitgr7 Jan 9, 2021
d1bd19a
new hook
rohitgr7 Jan 9, 2021
afe018a
add raise
marload Jan 10, 2021
709fb5b
fix
marload Jan 10, 2021
9b93a2c
remove unused
marload Jan 10, 2021
72f3f0c
rename
marload Jan 12, 2021
7ed3eea
Merge branch 'release/1.2-dev' into feature/lambdacallback
SkafteNicki Jan 12, 2021
2ce0131
Merge branch 'release/1.2-dev' into feature/lambdacallback
SkafteNicki Jan 13, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))

- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347))


### Changed

1 change: 1 addition & 0 deletions docs/source/callbacks.rst
@@ -98,6 +98,7 @@ Lightning has a few built-in callbacks.
EarlyStopping
GPUStatsMonitor
GradientAccumulationScheduler
LambdaCallback
LearningRateMonitor
ModelCheckpoint
ProgressBar
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
@@ -15,16 +15,17 @@
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.lambda_cb import LambdaCallback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase


__all__ = [
'Callback',
'EarlyStopping',
'GPUStatsMonitor',
'GradientAccumulationScheduler',
'LambdaCallback',
'LearningRateMonitor',
'ModelCheckpoint',
'ProgressBar',
155 changes: 155 additions & 0 deletions pytorch_lightning/callbacks/lambda_cb.py
@@ -0,0 +1,155 @@
# Copyright The PyTorch Lightning team.
Member comment: should the file be renamed to lambda.py to be consistent with the other callbacks?

Contributor comment: +1 for renaming to lambda.py and test_lambda.py

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Lambda Callback
^^^^^^^^^^^^^^^

Create a simple callback on the fly using lambda functions.

"""

from typing import Callable, Optional

from pytorch_lightning.callbacks.base import Callback


class LambdaCallback(Callback):
r"""
Create a simple callback on the fly using lambda functions.

Args:
**kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LambdaCallback
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
"""

def __init__(
self,
setup: Optional[Callable] = None,
teardown: Optional[Callable] = None,
on_init_start: Optional[Callable] = None,
on_init_end: Optional[Callable] = None,
on_fit_start: Optional[Callable] = None,
on_fit_end: Optional[Callable] = None,
on_sanity_check_start: Optional[Callable] = None,
on_sanity_check_end: Optional[Callable] = None,
on_train_batch_start: Optional[Callable] = None,
on_train_batch_end: Optional[Callable] = None,
on_train_epoch_start: Optional[Callable] = None,
on_train_epoch_end: Optional[Callable] = None,
on_validation_epoch_start: Optional[Callable] = None,
on_validation_epoch_end: Optional[Callable] = None,
on_test_epoch_start: Optional[Callable] = None,
on_test_epoch_end: Optional[Callable] = None,
on_epoch_start: Optional[Callable] = None,
on_epoch_end: Optional[Callable] = None,
on_batch_start: Optional[Callable] = None,
on_validation_batch_start: Optional[Callable] = None,
on_validation_batch_end: Optional[Callable] = None,
on_test_batch_start: Optional[Callable] = None,
on_test_batch_end: Optional[Callable] = None,
on_batch_end: Optional[Callable] = None,
on_train_start: Optional[Callable] = None,
on_train_end: Optional[Callable] = None,
on_pretrain_routine_start: Optional[Callable] = None,
on_pretrain_routine_end: Optional[Callable] = None,
on_validation_start: Optional[Callable] = None,
on_validation_end: Optional[Callable] = None,
on_test_start: Optional[Callable] = None,
on_test_end: Optional[Callable] = None,
on_keyboard_interrupt: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
):
if setup is not None:
self.setup = setup
if teardown is not None:
self.teardown = teardown
if on_init_start is not None:
self.on_init_start = on_init_start
if on_init_end is not None:
self.on_init_end = on_init_end
if on_fit_start is not None:
self.on_fit_start = on_fit_start
if on_fit_end is not None:
self.on_fit_end = on_fit_end
if on_sanity_check_start is not None:
self.on_sanity_check_start = on_sanity_check_start
if on_sanity_check_end is not None:
self.on_sanity_check_end = on_sanity_check_end
if on_train_batch_start is not None:
self.on_train_batch_start = on_train_batch_start
if on_train_batch_end is not None:
self.on_train_batch_end = on_train_batch_end
if on_train_epoch_start is not None:
self.on_train_epoch_start = on_train_epoch_start
if on_train_epoch_end is not None:
self.on_train_epoch_end = on_train_epoch_end
if on_validation_epoch_start is not None:
self.on_validation_epoch_start = on_validation_epoch_start
if on_validation_epoch_end is not None:
self.on_validation_epoch_end = on_validation_epoch_end
if on_test_epoch_start is not None:
self.on_test_epoch_start = on_test_epoch_start
if on_test_epoch_end is not None:
self.on_test_epoch_end = on_test_epoch_end
if on_epoch_start is not None:
self.on_epoch_start = on_epoch_start
if on_epoch_end is not None:
self.on_epoch_end = on_epoch_end
if on_batch_start is not None:
self.on_batch_start = on_batch_start
if on_validation_batch_start is not None:
self.on_validation_batch_start = on_validation_batch_start
if on_validation_batch_end is not None:
self.on_validation_batch_end = on_validation_batch_end
if on_test_batch_start is not None:
self.on_test_batch_start = on_test_batch_start
if on_test_batch_end is not None:
self.on_test_batch_end = on_test_batch_end
if on_batch_end is not None:
self.on_batch_end = on_batch_end
if on_train_start is not None:
self.on_train_start = on_train_start
if on_train_end is not None:
self.on_train_end = on_train_end
if on_pretrain_routine_start is not None:
self.on_pretrain_routine_start = on_pretrain_routine_start
if on_pretrain_routine_end is not None:
self.on_pretrain_routine_end = on_pretrain_routine_end
if on_validation_start is not None:
self.on_validation_start = on_validation_start
if on_validation_end is not None:
self.on_validation_end = on_validation_end
if on_test_start is not None:
self.on_test_start = on_test_start
if on_test_end is not None:
self.on_test_end = on_test_end
if on_keyboard_interrupt is not None:
self.on_keyboard_interrupt = on_keyboard_interrupt
if on_save_checkpoint is not None:
self.on_save_checkpoint = on_save_checkpoint
if on_load_checkpoint is not None:
self.on_load_checkpoint = on_load_checkpoint
if on_after_backward is not None:
self.on_after_backward = on_after_backward
if on_before_zero_grad is not None:
self.on_before_zero_grad = on_before_zero_grad
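
For context, LambdaCallback works by assigning each callable passed to __init__ as an instance attribute, which shadows the corresponding no-op hook on the base Callback, so the Trainer invokes it like any other callback method. A minimal usage sketch (not part of this diff; the Trainer arguments and print messages are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LambdaCallback

# Each lambda receives the same arguments as the hook it replaces,
# e.g. (trainer, pl_module) for the two hooks used here.
log_epochs = LambdaCallback(
    on_train_epoch_start=lambda trainer, pl_module: print(f"epoch {trainer.current_epoch} started"),
    on_train_end=lambda trainer, pl_module: print("training finished"),
)

trainer = Trainer(max_epochs=2, callbacks=[log_epochs])
# trainer.fit(model)  # `model` would be any LightningModule
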
63 changes: 63 additions & 0 deletions tests/callbacks/test_lambda_cb.py
@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LambdaCallback
from tests.base.boring_model import BoringModel


def test_lambda_call(tmpdir):
seed_everything(42)

checker = set()

hooks = [
"setup",
"teardown",
"on_init_start",
"on_init_end",
"on_fit_start",
"on_fit_end",
"on_train_batch_start",
"on_train_batch_end",
"on_train_epoch_start",
"on_train_epoch_end",
"on_validation_epoch_start",
"on_validation_epoch_end",
"on_test_epoch_start",
"on_test_epoch_end",
"on_epoch_start",
"on_epoch_end",
"on_batch_start",
"on_batch_end",
"on_validation_batch_start",
"on_validation_batch_end",
"on_test_batch_start",
"on_test_batch_end",
"on_train_start",
"on_train_end",
"on_test_start",
"on_test_end",
]
model = BoringModel()

hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks}
test_callback = LambdaCallback(**hooks_args)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[test_callback])

results = trainer.fit(model)
trainer.test(model)

assert results
assert checker == set(hooks)
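
One detail worth noting in the test above: the immediately invoked outer lambda in hooks_args is there to work around Python's late-binding closures. Without it, every hook would capture the same loop variable and record only the last name in hooks. A standalone sketch of the difference (not part of the PR; the names are illustrative):

names = ["a", "b", "c"]

# Late binding: every lambda closes over the loop variable itself,
# so after the comprehension finishes they all see its last value.
late = {n: (lambda *args: n) for n in names}

# Immediately invoking an outer lambda binds the *current* value of n.
bound = {n: (lambda x: lambda *args: x)(n) for n in names}

assert {k: f() for k, f in late.items()} == {"a": "c", "b": "c", "c": "c"}
assert {k: f() for k, f in bound.items()} == {"a": "a", "b": "b", "c": "c"}

A default argument (lambda *args, name=h: checker.add(name)) would achieve the same early binding.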