diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py
index 5c6c5afea..1cca32516 100644
--- a/src/brevitas/nn/quant_conv.py
+++ b/src/brevitas/nn/quant_conv.py
@@ -46,7 +46,8 @@ def __init__(
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
         # avoid an init error in the super class by setting padding to 0
-        if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
+        if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+                stride, int) else any(map(lambda x: x > 1, stride))):
             padding = 0
             is_same_padded_strided = True
         else:
@@ -132,7 +133,8 @@ def __init__(
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
         # avoid an init error in the super class by setting padding to 0
-        if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
+        if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+                stride, int) else any(map(lambda x: x > 1, stride))):
             padding = 0
             is_same_padded_strided = True
         else:
@@ -220,7 +222,8 @@ def __init__(
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
         # avoid an init error in the super class by setting padding to 0
-        if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
+        if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+                stride, int) else any(map(lambda x: x > 1, stride))):
             padding = 0
             is_same_padded_strided = True
         else:
diff --git a/tests/brevitas/nn/test_conv2d.py b/tests/brevitas/nn/test_conv2d.py
index cc4e49558..0fe795d24 100644
--- a/tests/brevitas/nn/test_conv2d.py
+++ b/tests/brevitas/nn/test_conv2d.py
@@ -1,10 +1,10 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import pytest_cases
 import torch
 from torch.nn import BatchNorm2d
 from torch.nn import Conv2d
-from torch.nn import Module
 
 from brevitas.inject.defaults import Int8BiasPerTensorFloatInternalScaling
 from brevitas.nn import QuantConv2d
@@ -18,12 +18,24 @@
 
 class TestQuantConv2d:
 
-    def test_module_init(self):
+    @pytest_cases.parametrize(
+        'kwargs',
+        [{}, {
+            'padding': 'same', 'stride': 1}, {
+                'padding': 'same', 'stride': (1, 1)}, {
+                    'padding': 'same', 'stride': (2, 1)}],
+        ids=[
+            'defaults',
+            'padding="same",stride=1',
+            'padding="same",stride=(1,1)',
+            'padding="same",stride=(2,1)'])
+    def test_module_init(self, kwargs):
         mod = QuantConv2d(
             out_channels=OUTPUT_CHANNELS,
             in_channels=INPUT_CHANNELS,
             kernel_size=KERNEL_SIZE,
-            bias=False)
+            bias=False,
+            **kwargs)
 
     def test_fp_quant_module(self):
         float_mod = Conv2d(
@@ -102,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self):
         merge_bn(mod, bn)
         inp = torch.randn(1, INPUT_CHANNELS, 20, 20)
         mod(inp)
+
+    @pytest_cases.parametrize(
+        'kwargs',
+        [{
+            'is_same_padded_strided': False}, {
+                'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, {
+                    'padding': 'same', 'stride': (1, 1), 'is_same_padded_strided': False}, {
+                        'padding': 'same', 'stride': (2, 1), 'is_same_padded_strided': True}, {
+                            'padding': 0, 'stride': (2, 1), 'is_same_padded_strided': False}],
+        ids=[
+            'defaults',
+            'padding="same",stride=1',
+            'padding="same",stride=(1,1)',
+            'padding="same",stride=(2,1)',
+            'padding=0,stride=(2,1)'])
+    def test_is_same_padded_strided(self, kwargs):
+        is_same_padded_strided = kwargs['is_same_padded_strided']
+        del kwargs['is_same_padded_strided']
+        mod = QuantConv2d(
+            out_channels=OUTPUT_CHANNELS,
+            in_channels=INPUT_CHANNELS,
+            kernel_size=KERNEL_SIZE,
+            bias=False,
+            **kwargs)
+        assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"
diff --git a/tests/brevitas/nn/test_conv3d.py b/tests/brevitas/nn/test_conv3d.py
index fdd35524f..ba3d78b7d 100644
--- a/tests/brevitas/nn/test_conv3d.py
+++ b/tests/brevitas/nn/test_conv3d.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import pytest_cases
 import torch
 from torch.nn import BatchNorm3d
 from torch.nn import Conv3d
@@ -17,12 +18,24 @@
 
 class TestQuantConv3d:
 
-    def test_module_init(self):
+    @pytest_cases.parametrize(
+        'kwargs',
+        [{}, {
+            'padding': 'same', 'stride': 1}, {
+                'padding': 'same', 'stride': (1, 1, 1)}, {
+                    'padding': 'same', 'stride': (2, 1, 1)}],
+        ids=[
+            'defaults',
+            'padding="same",stride=1',
+            'padding="same",stride=(1,1,1)',
+            'padding="same",stride=(2,1,1)'])
+    def test_module_init(self, kwargs):
         mod = QuantConv3d(
             out_channels=OUTPUT_CHANNELS,
             in_channels=INPUT_CHANNELS,
             kernel_size=KERNEL_SIZE,
-            bias=False)
+            bias=False,
+            **kwargs)
 
     def test_fp_quant_module(self):
         float_mod = Conv3d(
@@ -101,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self):
         merge_bn(mod, bn)
         inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20)
         mod(inp)
+
+    @pytest_cases.parametrize(
+        'kwargs',
+        [{
+            'is_same_padded_strided': False}, {
+                'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, {
+                    'padding': 'same', 'stride': (1, 1, 1), 'is_same_padded_strided': False}, {
+                        'padding': 'same', 'stride': (2, 1, 1), 'is_same_padded_strided': True}, {
+                            'padding': 0, 'stride': (2, 1, 1), 'is_same_padded_strided': False}],
+        ids=[
+            'defaults',
+            'padding="same",stride=1',
+            'padding="same",stride=(1,1,1)',
+            'padding="same",stride=(2,1,1)',
+            'padding=0,stride=(2,1,1)'])
+    def test_is_same_padded_strided(self, kwargs):
+        is_same_padded_strided = kwargs['is_same_padded_strided']
+        del kwargs['is_same_padded_strided']
+        mod = QuantConv3d(
+            out_channels=OUTPUT_CHANNELS,
+            in_channels=INPUT_CHANNELS,
+            kernel_size=KERNEL_SIZE,
+            bias=False,
+            **kwargs)
+        assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"
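
Reviewer note: the core of the quant_conv.py change is the stride check. The previous expression, any(map(lambda x: x > 1, stride)), assumed stride was iterable and raised a TypeError when padding='same' was combined with an integer stride. Below is a minimal standalone sketch of the fixed condition, extracted into a hypothetical helper (needs_same_padded_strided) that does not exist in Brevitas; in the actual patch the expression stays inlined in each QuantConvNd.__init__.

from typing import Tuple, Union


def needs_same_padded_strided(
        padding_mode: str,
        padding: Union[int, str, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]]) -> bool:
    """Hypothetical helper mirroring the patched condition: zero-padded 'same'
    convolutions with any stride > 1 need the custom padded-strided path,
    whether stride is given as an int or as a tuple."""
    if padding_mode != 'zeros' or padding != 'same':
        return False
    # The fix: accept a plain int stride instead of assuming an iterable.
    return stride > 1 if isinstance(stride, int) else any(s > 1 for s in stride)


# Mirrors the new test cases: an int stride no longer raises a TypeError.
assert needs_same_padded_strided('zeros', 'same', 1) is False
assert needs_same_padded_strided('zeros', 'same', (1, 1)) is False
assert needs_same_padded_strided('zeros', 'same', (2, 1)) is True
assert needs_same_padded_strided('zeros', 0, (2, 1)) is False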