Handle batchnorms in BatchGradientVerification #569

Merged (6 commits) on Mar 9, 2021

Batch normalization computes statistics across the batch dimension, so the per-sample gradient check in BatchGradientVerification flags data mixing even for correct models. This PR switches batch norm (and group norm) layers to eval mode for the duration of the verification forward pass.
pl_bolts/callbacks/verification/batch_gradient.py (36 additions, 2 deletions)
@@ -1,7 +1,9 @@
 # type: ignore
-from typing import Any, Callable, List, Optional
+from contextlib import contextmanager
+from typing import Any, Callable, Iterable, List, Optional, Type

 import torch
+import torch.nn as nn
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -16,6 +18,14 @@ class BatchGradientVerification(VerificationBase):
     on the wrong tensor dimensions.
     """

+    NORM_LAYER_CLASSES = (
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.SyncBatchNorm,
+        torch.nn.GroupNorm,
+    )
+
     def check(
         self,
         input_array: Any,
@@ -58,7 +68,8 @@ def check(
         input_batch.requires_grad = True

         self.model.zero_grad()
-        output = self._model_forward(input_array)
+        with selective_eval(self.model, self.NORM_LAYER_CLASSES):
+            output = self._model_forward(input_array)

         # backward on the i-th sample should lead to gradients only in the i-th input slice
         output_mapping(output)[sample_idx].sum().backward()
@@ -190,3 +201,26 @@ def collect_batches(tensor: torch.Tensor) -> torch.Tensor:

     apply_to_collection(data, dtype=torch.Tensor, function=collect_batches)
     return tensors
+
+
+@contextmanager
+def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None:
+    """
+    A context manager that temporarily sets all layers of the requested types to eval mode. The match is an
+    ``isinstance`` check, so subclasses are affected as well.
+
+    Args:
+        model: A model whose layers may need to be set to eval mode.
+        layer_types: The layer classes whose instances will be set to eval mode.
+    """
+    to_revert = []
+    try:
+        for module in model.modules():
+            if isinstance(module, tuple(layer_types)):
+                if module.training:
+                    module.eval()
+                    to_revert.append(module)
+        yield
+    finally:
+        for module in to_revert:
+            module.train()
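
For context, here is a minimal standalone sketch (not part of the PR, and the toy Sequential model is an assumption for illustration) of the failure mode this change addresses: a train-mode BatchNorm normalizes with statistics of the whole batch, so backpropagating from one sample's output produces gradients on the other samples' inputs and the check reports mixing even though the model is correct. Running the forward under the selective_eval helper added above, as check now does, restores per-sample independence.

import torch
import torch.nn as nn

from pl_bolts.callbacks.verification.batch_gradient import selective_eval

# hypothetical toy model: a norm layer followed by a per-sample linear layer
model = nn.Sequential(nn.BatchNorm1d(10), nn.Linear(10, 5))
model.train()

x = torch.rand(4, 10, requires_grad=True)

# Train mode: batch statistics couple the samples, so sample 0's output
# depends on every input row and cross-sample gradients appear.
model(x)[0].sum().backward()
assert x.grad[1:].abs().sum() > 0  # spurious "mixing"

# Eval mode (running statistics): rows are processed independently, so the
# gradient of sample 0's output w.r.t. the other rows is exactly zero.
x.grad = None
with selective_eval(model, [nn.BatchNorm1d]):
    out = model(x)
out[0].sum().backward()
assert torch.all(x.grad[1:] == 0)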
tests/callbacks/verification/test_batch_gradient.py (42 additions, 2 deletions)
@@ -7,7 +7,7 @@
 from torch import nn as nn

 from pl_bolts.callbacks import BatchGradientVerificationCallback
-from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping
+from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval
 from pl_bolts.utils import BatchGradientVerification

@@ -18,6 +18,7 @@ def __init__(self, mix_data=False):
         super().__init__()
         self.mix_data = mix_data
         self.linear = nn.Linear(10, 5)
+        self.bn = nn.BatchNorm1d(10)
         self.input_array = torch.rand(10, 5, 2)

     def forward(self, *args, **kwargs):
@@ -29,7 +30,7 @@ def forward__standard(self, x):
             x = x.view(10, -1).permute(1, 0).view(-1, 10)  # oops!
         else:
             x = x.view(-1, 10)  # good!
-        return self.linear(x)
+        return self.linear(self.bn(x))


 class MultipleInputModel(TemplateModel):
@@ -255,3 +256,42 @@ def test_default_output_mapping():
     )
     output = default_output_mapping(data)
     assert torch.all(output == expected)
+
+
+class BatchNormModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.batch_norm0 = nn.BatchNorm1d(2)
+        self.batch_norm1 = nn.BatchNorm1d(3)
+        self.instance_norm = nn.InstanceNorm1d(4)
+
+
+def test_selective_eval():
+    """ Test that the selective_eval context manager only applies to the selected layer types. """
+    model = BatchNormModel()
+    model.train()
+    with selective_eval(model, [nn.BatchNorm1d]):
+        assert not model.batch_norm0.training
+        assert not model.batch_norm1.training
+        assert model.instance_norm.training
+
+    assert model.batch_norm0.training
+    assert model.batch_norm1.training
+    assert model.instance_norm.training
+
+
+def test_selective_eval_invariant():
+    """ Test that selective_eval does not set layers that were already in eval mode back to training mode. """
+    model = BatchNormModel()
+    model.train()
+    model.batch_norm1.eval()
+    assert model.batch_norm0.training
+    assert not model.batch_norm1.training
+
+    with selective_eval(model, [nn.BatchNorm1d]):
+        assert not model.batch_norm0.training
+        assert not model.batch_norm1.training
+
+    assert model.batch_norm0.training
+    assert not model.batch_norm1.training
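
With the change in place, the end-to-end check can be exercised on a model containing batch normalization. A minimal usage sketch, assuming the BatchGradientVerification constructor takes the model and that check returns a bool (consistent with the check signature in the diff above); the toy model is hypothetical:

import torch
import torch.nn as nn

from pl_bolts.utils import BatchGradientVerification

# hypothetical model: BatchNorm followed by a per-sample linear head
model = nn.Sequential(nn.BatchNorm1d(10), nn.Linear(10, 5))
verification = BatchGradientVerification(model)

# Before this PR the train-mode BatchNorm mixed samples and the check failed
# spuriously; with norm layers switched to eval during the forward, it passes.
assert verification.check(input_array=torch.rand(8, 10))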