diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6195660283..0bdc457628 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
 - Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))
 - Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407))
+- Added gradient verification callback ([#465](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/465))
 
 ### Changed
diff --git a/docs/source/info_callbacks.rst b/docs/source/info_callbacks.rst
index cfeed28817..b65136675d 100644
--- a/docs/source/info_callbacks.rst
+++ b/docs/source/info_callbacks.rst
@@ -64,3 +64,65 @@ You can track all or just a selection of submodules:
 
 This is especially useful for debugging the data flow in complex models and to identify numerical instabilities.
+
+
+---------------
+
+Model Verification
+------------------
+
+
+Gradient-Check for Batch-Optimization
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Gradient descent over a batch of samples not only benefits the optimization but also leverages data parallelism.
+However, one has to be careful not to mix data across the batch dimension.
+Even a small error in a reshape or permutation operation can cause the optimization to get stuck, and you won't
+even get a runtime error. How can one tell if the model mixes data in the batch?
+A simple trick is to do the following:
+
+1. run the model on an example batch (can be random data)
+2. get the output batch and select the n-th sample (choose n)
+3. compute a dummy loss value of only that sample and compute the gradient w.r.t. the entire input batch
+4. observe that only the n-th sample in the input batch has a non-zero gradient
+
+|
+
+If the gradient is non-zero for the other samples in the batch, it means the forward pass of the model is mixing data!
+The :class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerificationCallback`
+does all of that for you before training begins.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pl_bolts.callbacks import BatchGradientVerificationCallback
+
+    model = YourLightningModule()
+    verification = BatchGradientVerificationCallback()
+    trainer = Trainer(callbacks=[verification])
+    trainer.fit(model)
+
+If data mixing inside the batch is detected, this callback warns the user with the following message:
+
+.. code-block::
+
+    Your model is mixing data across the batch dimension.
+    This can lead to wrong gradient updates in the optimizer.
+    Check the operations that reshape and permute tensor dimensions in your model.
+
+
+A non-callback version
+:class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerification`
+that works with any PyTorch :class:`~torch.nn.Module` is also available:
+
+.. code-block:: python
+
+    import torch
+
+    from pl_bolts.utils import BatchGradientVerification
+
+    model = YourPyTorchModel()
+    verification = BatchGradientVerification(model)
+    valid = verification.check(input_array=torch.rand(2, 3, 4), sample_idx=1)
+
+In this example, we run the test on a batch of size 2 by inspecting the gradients of the second sample (``sample_idx=1``).
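A quick aside on the gradient-check trick added to the docs above: it can be reproduced by hand in a few lines of plain PyTorch. The following minimal sketch is for illustration only; the model, shapes, and sample index are invented and are not part of this PR:

.. code-block:: python

    import torch
    import torch.nn as nn

    # A well-behaved toy model: flattening starts at dim 1, so samples stay independent.
    model = nn.Sequential(nn.Flatten(), nn.Linear(12, 4))

    batch = torch.rand(8, 3, 4, requires_grad=True)  # step 1: run on an example batch
    output = model(batch)

    n = 3  # step 2: pick the n-th output sample
    output[n].sum().backward()  # step 3: dummy loss on that sample only

    # step 4: only the n-th input sample should have a non-zero gradient
    per_sample_grad = batch.grad.abs().sum(dim=(1, 2))
    assert per_sample_grad[n] > 0
    assert torch.all(per_sample_grad[torch.arange(8) != n] == 0)

A model that mixes data across the batch dimension would fail the last assertion; the ``BatchGradientVerification`` class introduced below automates exactly this check.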
diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py
index 0e56662a63..866ecfcdfa 100644
--- a/pl_bolts/callbacks/__init__.py
+++ b/pl_bolts/callbacks/__init__.py
@@ -6,10 +6,12 @@ from pl_bolts.callbacks.printing import PrintTableMetricsCallback  # noqa: F401
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator  # noqa: F401
 from pl_bolts.callbacks.variational import LatentDimInterpolator  # noqa: F401
+from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback  # type: ignore
 from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback  # noqa: F401
 from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler  # noqa: F401
 
 __all__ = [
+    "BatchGradientVerificationCallback",
     "BYOLMAWeightUpdate",
     "ModuleDataMonitor",
     "TrainingDataMonitor",
diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py
index e95becd2bc..dfb57661b0 100644
--- a/pl_bolts/callbacks/data_monitor.py
+++ b/pl_bolts/callbacks/data_monitor.py
@@ -18,7 +18,7 @@
     import wandb
 else:  # pragma: no cover
     warn_missing_pkg("wandb")
-    wandb = None
+    wandb = None  # type: ignore
 
 
 class DataMonitorBase(Callback):
diff --git a/pl_bolts/callbacks/verification/__init__.py b/pl_bolts/callbacks/verification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pl_bolts/callbacks/verification/base.py b/pl_bolts/callbacks/verification/base.py
new file mode 100644
index 0000000000..75f88f4a63
--- /dev/null
+++ b/pl_bolts/callbacks/verification/base.py
@@ -0,0 +1,123 @@
+# type: ignore
+from abc import abstractmethod
+from copy import deepcopy
+from typing import Any, Optional
+
+import torch.nn as nn
+from pytorch_lightning import Callback
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
+
+
+class VerificationBase:
+    """
+    Base class for model verification.
+    All verifications should run with any :class:`torch.nn.Module` unless otherwise stated.
+    """
+
+    def __init__(self, model: nn.Module):
+        """
+        Arguments:
+            model: The model to run verification for.
+        """
+        super().__init__()
+        self.model = model
+
+    @abstractmethod
+    def check(self, *args: Any, **kwargs: Any) -> bool:
+        """ Runs the actual test on the model. All verification classes must implement this.
+
+        Arguments:
+            *args: Any positional arguments that are needed to run the test
+            **kwargs: Any keyword arguments that are needed to run the test
+
+        Returns:
+            ``True`` if the test passes, and ``False`` otherwise. Some verifications can only be performed
+            with a heuristic accuracy, thus the return value may not always reflect the true state of
+            the system in these cases.
+        """
+
+    def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
+        """
+        Returns a deep copy of the example input array in cases where it is expected that the
+        input changes during the verification process.
+
+        Arguments:
+            input_array: The input to clone.
+        """
+        if input_array is None and isinstance(self.model, LightningModule):
+            input_array = self.model.example_input_array
+        input_array = deepcopy(input_array)
+
+        if isinstance(self.model, LightningModule):
+            input_array = self.model.transfer_batch_to_device(input_array, self.model.device)
+        else:
+            input_array = move_data_to_device(input_array, device=next(self.model.parameters()).device)
+
+        return input_array
+
+    def _model_forward(self, input_array: Any) -> Any:
+        """
+        Feeds the input array to the model via the ``__call__`` method.
+
+        Arguments:
+            input_array: The input that goes into the model. If it is a tuple, it gets
+                interpreted as the sequence of positional arguments and is passed in by tuple unpacking.
+                If it is a dict, the contents get passed in as named parameters by unpacking the dict.
+                Otherwise, the input array gets passed in as a single argument.
+
+        Returns:
+            The output of the model.
+        """
+        if isinstance(input_array, tuple):
+            return self.model(*input_array)
+        if isinstance(input_array, dict):
+            return self.model(**input_array)
+        return self.model(input_array)
+
+
+class VerificationCallbackBase(Callback):
+    """
+    Base class for model verification in the form of a callback.
+    This type of verification is expected to only work with
+    :class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
+    from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.
+    """
+
+    def __init__(self, warn: bool = True, error: bool = False) -> None:
+        """
+        Arguments:
+            warn: If ``True``, emits a warning message when verification fails. Default: ``True``.
+            error: If ``True``, raises a ``RuntimeError`` with an error message when verification fails.
+                Default: ``False``.
+        """
+        self._raise_warning = warn
+        self._raise_error = error
+
+    def message(self, *args: Any, **kwargs: Any) -> str:
+        """
+        The message to be printed when the model does not pass the verification.
+        If the messages for warning and error differ, override the
+        :meth:`warning_message` and :meth:`error_message`
+        methods directly.
+
+        Arguments:
+            *args: Any positional arguments that are needed to construct the message.
+            **kwargs: Any keyword arguments that are needed to construct the message.
+
+        Returns:
+            The message as a string.
+        """
+
+    def warning_message(self, *args: Any, **kwargs: Any) -> str:
+        """ The warning message printed when the model does not pass the verification. """
+        return self.message(*args, **kwargs)
+
+    def error_message(self, *args: Any, **kwargs: Any) -> str:
+        """ The error message printed when the model does not pass the verification. """
""" + return self.message(*args, **kwargs) + + def _raise(self, *args: Any, **kwargs: Any) -> None: + if self._raise_error: + raise RuntimeError(self.error_message(*args, **kwargs)) + if self._raise_warning: + rank_zero_warn(self.warning_message(*args, **kwargs)) diff --git a/pl_bolts/callbacks/verification/batch_gradient.py b/pl_bolts/callbacks/verification/batch_gradient.py new file mode 100644 index 0000000000..b8ec9963af --- /dev/null +++ b/pl_bolts/callbacks/verification/batch_gradient.py @@ -0,0 +1,192 @@ +# type: ignore +from typing import Any, Callable, List, Optional + +import torch +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase + + +class BatchGradientVerification(VerificationBase): + """ + Checks if a model mixes data across the batch dimension. + This can happen if reshape- and/or permutation operations are carried out in the wrong order or + on the wrong tensor dimensions. + """ + + def check( + self, + input_array: Any, + input_mapping: Optional[Callable] = None, + output_mapping: Optional[Callable] = None, + sample_idx: int = 0, + ) -> bool: + """ + Runs the test for data mixing across the batch. + + Arguments: + input_array: A dummy input for the model. Can be a tuple or dict in case the model takes + multiple positional or named arguments. + input_mapping: An optional input mapping that returns all batched tensors in a input collection. + By default, we handle nested collections (tuples, lists, dicts) of tensors and pull them + out. If your batch is a custom object, you need to provide this input mapping yourself. + See :func:`default_input_mapping` for more information on the default behavior. + output_mapping: An optional output mapping that combines all batched tensors in the output + collection into one big batch of shape (B, N), where N is the total number of dimensions + that follow the batch dimension in each tensor. By default, we handle nested collections + (tuples, lists, dicts) of tensors and combine them automatically. See + :func:`default_output_mapping` for more information on the default behavior. + sample_idx: + The index `i` of the batch sample to run the test for. When computing the gradient of + a loss value on the `i-th` output w.r.t. the whole input, we expect the gradient to be + non-zero only on the `i-th` input sample and zero gradient on the rest of the batch. + + Returns: + ``True`` if the data in the batch does not mix during the forward pass, and ``False`` otherwise. 
+        """
+        input_mapping = input_mapping or default_input_mapping
+        output_mapping = output_mapping or default_output_mapping
+        input_array = self._get_input_array_copy(input_array)
+        input_batches = input_mapping(input_array)
+
+        if input_batches[0].size(0) < 2:
+            raise MisconfigurationException("Batch size must be greater than 1 to run verification.")
+
+        for input_batch in input_batches:
+            input_batch.requires_grad = True
+
+        self.model.zero_grad()
+        output = self._model_forward(input_array)
+
+        # backward on the i-th sample should lead to gradient only in the i-th input slice
+        output_mapping(output)[sample_idx].sum().backward()
+
+        zero_grad_inds = list(range(len(input_batches[0])))
+        zero_grad_inds.pop(sample_idx)
+
+        has_grad_outside_sample = [input_batch.grad[zero_grad_inds].abs().sum().item() for input_batch in input_batches]
+        has_grad_inside_sample = [input_batch.grad[sample_idx].abs().sum().item() for input_batch in input_batches]
+        return not any(has_grad_outside_sample) and all(has_grad_inside_sample)
+
+
+class BatchGradientVerificationCallback(VerificationCallbackBase):
+    """
+    The callback version of the :class:`BatchGradientVerification` test.
+    Verification is performed right before training begins.
+    """
+
+    def __init__(
+        self,
+        input_mapping: Optional[Callable] = None,
+        output_mapping: Optional[Callable] = None,
+        sample_idx: int = 0,
+        **kwargs: Any,
+    ):
+        """
+        Arguments:
+            input_mapping: An optional input mapping that returns all batched tensors in an input collection.
+                See :meth:`BatchGradientVerification.check` for more information.
+            output_mapping: An optional output mapping that combines all batched tensors in the output
+                collection into one big batch. See :meth:`BatchGradientVerification.check` for more information.
+            sample_idx: The index of the batch sample to run the test for.
+                See :meth:`BatchGradientVerification.check` for more information.
+            **kwargs: Additional arguments for the base class :class:`VerificationCallbackBase`.
+        """
+        super().__init__(**kwargs)
+        self._input_mapping = input_mapping
+        self._output_mapping = output_mapping
+        self._sample_idx = sample_idx
+
+    def message(self, *args: Any, **kwargs: Any) -> str:
+        message = (
+            "Your model is mixing data across the batch dimension."
+            " This can lead to wrong gradient updates in the optimizer."
+            " Check the operations that reshape and permute tensor dimensions in your model."
+        )
+        return message
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        verification = BatchGradientVerification(pl_module)
+        result = verification.check(
+            input_array=pl_module.example_input_array,
+            input_mapping=self._input_mapping,
+            output_mapping=self._output_mapping,
+            sample_idx=self._sample_idx,
+        )
+        if not result:
+            self._raise()
+
+
+def default_input_mapping(data: Any) -> List[torch.Tensor]:
+    """
+    Finds all tensors in a (nested) collection that have the same batch size.
+
+    Args:
+        data: a tensor or a collection of tensors (tuple, list, dict, etc.).
+
+    Returns:
+        A list of all tensors with the same batch size. If the input was already a tensor, a one-
+        element list with the tensor is returned.
+
+    Example:
+        >>> data = (torch.zeros(3, 1), "foo", torch.ones(3, 2), torch.rand(2))
+        >>> result = default_input_mapping(data)
+        >>> len(result)
+        2
+        >>> result[0].shape
+        torch.Size([3, 1])
+        >>> result[1].shape
+        torch.Size([3, 2])
+    """
+    tensors = collect_tensors(data)
+    batches: List[torch.Tensor] = []
+    for tensor in tensors:
+        if tensor.ndim > 0 and (not batches or tensor.size(0) == batches[0].size(0)):
+            batches.append(tensor)
+    return batches
+
+
+def default_output_mapping(data: Any) -> torch.Tensor:
+    """
+    Pulls out all tensors in an output collection and combines them into one big batch
+    for verification.
+
+    Args:
+        data: a tensor or a (nested) collection of tensors (tuple, list, dict, etc.).
+
+    Returns:
+        A float tensor with shape (B, N) where B is the batch size and N is the sum of (flattened)
+        dimensions of all tensors in the collection. If the input was already a tensor, the tensor
+        itself is returned.
+
+    Example:
+        >>> data = (torch.rand(3, 5), "foo", torch.rand(3, 2, 4))
+        >>> result = default_output_mapping(data)
+        >>> result.shape
+        torch.Size([3, 13])
+        >>> data = {"one": torch.rand(3, 5), "two": torch.rand(3, 2, 1)}
+        >>> result = default_output_mapping(data)
+        >>> result.shape
+        torch.Size([3, 7])
+    """
+    if isinstance(data, torch.Tensor):
+        return data
+
+    batches = default_input_mapping(data)
+    # cannot use .flatten(1) because of tensors with shape (B, )
+    batches = [batch.view(batch.size(0), -1).float() for batch in batches]
+    combined = torch.cat(batches, 1)  # combined batch has shape (B, N)
+    return combined
+
+
+def collect_tensors(data: Any) -> List[torch.Tensor]:
+    """ Filters all tensors in a collection and returns them in a list. """
+    tensors = []
+
+    def collect_batches(tensor: torch.Tensor) -> torch.Tensor:
+        tensors.append(tensor)
+        return tensor
+
+    apply_to_collection(data, dtype=torch.Tensor, function=collect_batches)
+    return tensors
diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 795e0080c5..0a49d730c4 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -1,6 +1,8 @@
 import torch
 from pytorch_lightning.utilities import _module_available
 
+from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification  # type: ignore
+
 _NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 
 _TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
@@ -10,3 +12,5 @@
 _OPENCV_AVAILABLE: bool = _module_available("cv2")
 _WANDB_AVAILABLE: bool = _module_available("wandb")
 _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
+
+__all__ = ["BatchGradientVerification"]
diff --git a/tests/callbacks/verification/__init__.py b/tests/callbacks/verification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/callbacks/verification/test_base.py b/tests/callbacks/verification/test_base.py
new file mode 100644
index 0000000000..b8f21124e1
--- /dev/null
+++ b/tests/callbacks/verification/test_base.py
@@ -0,0 +1,96 @@
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+import torch.nn as nn
+from pytorch_lightning import LightningModule
+from pytorch_lightning.utilities import move_data_to_device
+
+from pl_bolts.callbacks.verification.base import VerificationBase
+
+
+class TrivialVerification(VerificationBase):
+
+    def check(self, *args, **kwargs):
+        return True
+
+
+class PyTorchModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(5, 2)
+
+    def forward(self, *args):
+        return args
+
+
+class LitModel(LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.example_input_array = None
+        self.model = PyTorchModel()
+
+    def forward(self, *args):
+        return self.model(*args)
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        pytest.param(torch.device("cpu")),
+        pytest.param(
+            torch.device("cuda", 0),
+            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU"),
+        ),
+    ],
+)
+def test_verification_base_get_input_array(device):
+    """ Test that the base class calls the correct methods to transfer the input to the device the model is on. """
+    model = PyTorchModel().to(device)
+    verification = TrivialVerification(model)
+    input_tensor = torch.rand(5)
+    assert verification.model == model
+
+    # for a PyTorch model, the user must provide the input array
+    with patch(
+        "pl_bolts.callbacks.verification.base.move_data_to_device",
+        wraps=move_data_to_device,
+    ) as mocked:
+        copied_tensor = verification._get_input_array_copy(input_array=input_tensor)
+        mocked.assert_called_once()
+    assert copied_tensor.device == device
+    assert torch.allclose(input_tensor, copied_tensor.cpu())
+
+    model = LitModel().to(device)
+    model.example_input_array = input_tensor
+    verification = TrivialVerification(model)
+
+    # for a LightningModule, the user can rely on the example_input_array
+    with patch.object(model, "transfer_batch_to_device", wraps=model.transfer_batch_to_device) as mocked:
+        copied_tensor = verification._get_input_array_copy(input_array=None)
+        mocked.assert_called_once()
+    assert copied_tensor.device == model.device == device
+    assert torch.allclose(model.example_input_array, copied_tensor.cpu())
+
+
+def test_verification_base_model_forward_for_input_array():
+    """ Test that the input_array is correctly fed to the forward method depending on its type. """
+    model = Mock()
+    verification = TrivialVerification(model)
+
+    # tuple must be passed as positional args
+    input_array = (1, torch.tensor(2), None)
+    verification._model_forward(input_array)
+    model.assert_called_with(1, torch.tensor(2), None)
+
+    # dict must be passed as keyword args
+    input_array = {"one": 1, "two": torch.tensor(2), "three": None}
+    verification._model_forward(input_array)
+    model.assert_called_with(one=1, two=torch.tensor(2), three=None)
+
+    # everything else will be passed directly
+    input_array = torch.rand(2)
+    verification._model_forward(input_array)
+    model.assert_called_with(input_array)
diff --git a/tests/callbacks/verification/test_batch_gradient.py b/tests/callbacks/verification/test_batch_gradient.py
new file mode 100644
index 0000000000..0f9e10405e
--- /dev/null
+++ b/tests/callbacks/verification/test_batch_gradient.py
@@ -0,0 +1,257 @@
+from unittest.mock import Mock
+
+import pytest
+import torch
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch import nn as nn
+
+from pl_bolts.callbacks import BatchGradientVerificationCallback
+from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping
+from pl_bolts.utils import BatchGradientVerification
+
+
+class TemplateModel(nn.Module):
+
+    def __init__(self, mix_data=False):
+        """ Base model for testing. The setting ``mix_data=True`` simulates a wrong implementation. """
""" + super().__init__() + self.mix_data = mix_data + self.linear = nn.Linear(10, 5) + self.input_array = torch.rand(10, 5, 2) + + def forward(self, *args, **kwargs): + return self.forward__standard(*args, **kwargs) + + def forward__standard(self, x): + # x: (B, 5, 2) + if self.mix_data: + x = x.view(10, -1).permute(1, 0).view(-1, 10) # oops! + else: + x = x.view(-1, 10) # good! + return self.linear(x) + + +class MultipleInputModel(TemplateModel): + """ Base model for testing verification when forward accepts multiple arguments. """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_array = (torch.rand(10, 5, 2), torch.rand(10, 5, 2)) + + def forward(self, x, y, some_kwarg=True): + out = super().forward(x) + super().forward(y) + return out + + +class MultipleOutputModel(TemplateModel): + """ Base model for testing verification when forward has multiple outputs. """ + + def forward(self, x): + out = super().forward(x) + return None, out, out, False + + +class DictInputDictOutputModel(TemplateModel): + """ Base model for testing verification when forward has a collection of outputs. """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_array = { + "w": 42, + "x": { + "a": torch.rand(3, 5, 2) + }, + "y": torch.rand(3, 1, 5, 2), + "z": torch.tensor(2), + } + + def forward(self, y, x, z, w): + out1 = super().forward(x["a"]) + out2 = super().forward(y) + out3 = out1 + out2 + out = {1: out1, 2: out2, 3: [out1, out3]} + return out + + +class LitModel(LightningModule): + """ Base model for testing verification with LightningModules. """ + + def __init__(self, *args, **kwargs): + super().__init__() + self.model = DictInputDictOutputModel(*args, **kwargs) + self.example_input_array = self.model.input_array + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + +@pytest.mark.parametrize( + "model_class", + [TemplateModel, MultipleInputModel, MultipleOutputModel, DictInputDictOutputModel], +) +@pytest.mark.parametrize("mix_data", [True, False]) +@pytest.mark.parametrize( + "device", + [ + pytest.param(torch.device("cpu")), + pytest.param( + torch.device("cuda", 0), + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU"), + ), + ], +) +def test_batch_gradient_verification(model_class, mix_data, device): + """ Test detection of batch gradient mixing with different PyTorch models. """ + model = model_class(mix_data).to(device) + is_valid = not mix_data + verification = BatchGradientVerification(model) + result = verification.check(input_array=model.input_array) + assert result == is_valid + + +@pytest.mark.parametrize("mix_data", [True, False]) +@pytest.mark.parametrize( + "device", + [ + pytest.param(torch.device("cpu")), + pytest.param( + torch.device("cuda", 0), + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU"), + ), + ], +) +def test_batch_gradient_verification_pl_module(mix_data, device): + """ Test detection of batch gradient mixing with a LightningModule. 
""" + model = LitModel(mix_data).to(device) + is_valid = not mix_data + verification = BatchGradientVerification(model) + result = verification.check(input_array=None) + assert result == is_valid + + +@pytest.mark.parametrize( + "gpus", + [ + pytest.param(0), + pytest.param(1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU")), + ], +) +def test_batch_gradient_verification_callback(gpus): + """ Test detection of batch gradient mixing with the callback implementation. """ + trainer = Trainer(gpus=gpus) + model = LitModel(mix_data=True) + + expected = "Your model is mixing data across the batch dimension." + + callback = BatchGradientVerificationCallback() + with pytest.warns(UserWarning, match=expected): + callback.on_train_start(trainer, model) + + callback = BatchGradientVerificationCallback(error=True) + with pytest.raises(RuntimeError, match=expected): + callback.on_train_start(trainer, model) + + +def test_batch_verification_raises_on_batch_size_1(): + """ Test that batch gradient verification only works with batch size greater than one. """ + model = TemplateModel() + verification = BatchGradientVerification(model) + small_batch = model.input_array[0:1] + with pytest.raises(MisconfigurationException, match="Batch size must be greater than 1"): + verification.check(input_array=small_batch) + + +def test_batch_verification_calls_custom_input_output_mappings(): + """ Test that batch gradient verification can support different input and outputs with user-provided mappings. """ + model = MultipleInputModel() + + def input_mapping(inputs): + assert isinstance(inputs, tuple) and len(inputs) == 2 + return [inputs[0]] + + def output_mapping(outputs): + assert isinstance(outputs, torch.Tensor) + return torch.cat((outputs, outputs), 1) + + mocked_input_mapping = Mock(wraps=input_mapping) + mocked_output_mapping = Mock(wraps=output_mapping) + verification = BatchGradientVerification(model) + verification.check( + model.input_array, + input_mapping=mocked_input_mapping, + output_mapping=mocked_output_mapping, + ) + mocked_input_mapping.assert_called_once() + mocked_output_mapping.assert_called_once() + + +def test_default_input_mapping(): + """ Test the data types and nesting the default input mapping can handle. """ + b = 3 + tensor0 = torch.rand(b, 2, 5) + tensor1 = torch.rand(b, 9) + tensor2 = torch.rand(b, 5, 1) + + # Tensor + data = tensor0.double() + output = default_input_mapping(data) + assert len(output) == 1 + assert output[0] is data + + # tuple + data = ("foo", tensor1, tensor2, []) + out1, out2 = default_input_mapping(data) + assert out1 is tensor1 + assert out2 is tensor2 + + # dict + nesting + data = { + "one": ["foo", tensor2], + "two": tensor0, + } + out2, out0 = default_input_mapping(data) + assert out2 is tensor2 + assert out0 is tensor0 + + +def test_default_output_mapping(): + """ Test the data types and nesting the default output mapping can handle. 
""" + b = 3 + tensor0 = torch.rand(b, 2, 5) + tensor1 = torch.rand(b, 9) + tensor2 = torch.rand(b, 5, 1) + tensor3 = torch.rand(b) + scalar = torch.tensor(3.14) + + # Tensor + data = tensor0.double() + output = default_output_mapping(data) + assert output is data + + # tuple + nesting + data = (tensor0, None, tensor1, "foo", [tensor2]) + expected = torch.cat((tensor0.view(b, -1), tensor1.view(b, -1), tensor2.view(b, -1)), dim=1) + output = default_output_mapping(data) + assert torch.all(output == expected) + + # dict + nesting + data = { + "one": tensor1, + "two": { + "three": tensor3.double() + }, # will convert to float + "four": scalar, # ignored + "five": [tensor0, tensor0], + } + expected = torch.cat( + ( + tensor1.view(b, -1), + tensor3.view(b, -1), + tensor0.view(b, -1), + tensor0.view(b, -1), + ), + dim=1, + ) + output = default_output_mapping(data) + assert torch.all(output == expected)