From 3887f3b8036a1de0f8e17acbae0103344058d121 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 3 Sep 2024 16:39:35 +0100 Subject: [PATCH 1/4] fix (nn/conv): Fix regression introduced in #1017 --- src/brevitas/nn/quant_conv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 5c6c5afea..18fb3282e 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -46,7 +46,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): padding = 0 is_same_padded_strided = True else: @@ -132,7 +132,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): padding = 0 is_same_padded_strided = True else: @@ -220,7 +220,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): padding = 0 is_same_padded_strided = True else: From bc703e5615996545004b34bea167b8142de2f5a4 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 3 Sep 2024 17:37:50 +0100 Subject: [PATCH 2/4] fix/test (nn/conv): Fixed conv instantiation and added extra tests --- src/brevitas/nn/quant_conv.py | 9 ++++++--- tests/brevitas/nn/test_conv2d.py | 12 +++++++++--- tests/brevitas/nn/test_conv3d.py | 11 +++++++++-- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 18fb3282e..2385ad822 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -46,7 +46,8 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): + if padding_mode == 'zeros' and padding == 'same' and stride > 1 if isinstance( + stride, int) else any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: @@ -132,7 +133,8 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): + if padding_mode == 'zeros' and padding == 'same' and stride > 1 if isinstance( + stride, int) else any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: @@ -220,7 +222,8 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): + if padding_mode == 'zeros' and padding == 'same' and stride > 1 if isinstance( + stride, int) else any(map(lambda x: x > 1, stride)): padding = 0 is_same_padded_strided = True else: diff --git a/tests/brevitas/nn/test_conv2d.py b/tests/brevitas/nn/test_conv2d.py index cc4e49558..b37fcb254 100644 --- a/tests/brevitas/nn/test_conv2d.py +++ b/tests/brevitas/nn/test_conv2d.py @@ -1,10 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import pytest_cases import torch from torch.nn import BatchNorm2d from torch.nn import Conv2d -from torch.nn import Module from brevitas.inject.defaults import Int8BiasPerTensorFloatInternalScaling from brevitas.nn import QuantConv2d @@ -18,12 +18,18 @@ class TestQuantConv2d: - def test_module_init(self): + @pytest_cases.parametrize( + 'kwargs', [{}, { + 'padding': 'same', 'stride': 1}, { + 'padding': 'same', 'stride': (1, 1)}], + ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1)']) + def test_module_init(self, kwargs): mod = QuantConv2d( out_channels=OUTPUT_CHANNELS, in_channels=INPUT_CHANNELS, kernel_size=KERNEL_SIZE, - bias=False) + bias=False, + **kwargs) def test_fp_quant_module(self): float_mod = Conv2d( diff --git a/tests/brevitas/nn/test_conv3d.py b/tests/brevitas/nn/test_conv3d.py index fdd35524f..a3b912dd5 100644 --- a/tests/brevitas/nn/test_conv3d.py +++ b/tests/brevitas/nn/test_conv3d.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import pytest_cases import torch from torch.nn import BatchNorm3d from torch.nn import Conv3d @@ -17,12 +18,18 @@ class TestQuantConv3d: - def test_module_init(self): + @pytest_cases.parametrize( + 'kwargs', [{}, { + 'padding': 'same', 'stride': 1}, { + 'padding': 'same', 'stride': (1, 1, 1)}], + ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1,1)']) + def test_module_init(self, kwargs): mod = QuantConv3d( out_channels=OUTPUT_CHANNELS, in_channels=INPUT_CHANNELS, kernel_size=KERNEL_SIZE, - bias=False) + bias=False, + **kwargs) def test_fp_quant_module(self): float_mod = Conv3d( From 9de92f213995f2aac5bd36602417d974f29cb3fb Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 4 Sep 2024 12:14:00 +0100 Subject: [PATCH 3/4] fix (nn/convxd): typo fix on conditional --- src/brevitas/nn/quant_conv.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 2385ad822..1cca32516 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -46,8 +46,8 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1 if isinstance( - stride, int) else any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance( + stride, int) else any(map(lambda x: x > 1, stride))): padding = 0 is_same_padded_strided = True else: @@ -133,8 +133,8 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1 if isinstance( - stride, int) else any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance( + stride, int) else any(map(lambda x: x > 1, stride))): padding = 0 is_same_padded_strided = True else: @@ -222,8 +222,8 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and stride > 1 if isinstance( - stride, int) else any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance( + stride, int) else any(map(lambda x: x > 1, stride))): padding = 0 is_same_padded_strided = True else: From d1710727c8b0f36b0470aa9fe664064aa52808e9 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 4 Sep 2024 12:18:01 +0100 Subject: [PATCH 4/4] test (nn/convxd): Added tests for padding issue --- tests/brevitas/nn/test_conv2d.py | 37 +++++++++++++++++++++++++++++--- tests/brevitas/nn/test_conv3d.py | 37 +++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/tests/brevitas/nn/test_conv2d.py b/tests/brevitas/nn/test_conv2d.py index b37fcb254..0fe795d24 100644 --- a/tests/brevitas/nn/test_conv2d.py +++ b/tests/brevitas/nn/test_conv2d.py @@ -19,10 +19,16 @@ class TestQuantConv2d: @pytest_cases.parametrize( - 'kwargs', [{}, { + 'kwargs', + [{}, { 'padding': 'same', 'stride': 1}, { - 'padding': 'same', 'stride': (1, 1)}], - ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1)']) + 'padding': 'same', 'stride': (1, 1)}, { + 'padding': 'same', 'stride': (2, 1)}], + ids=[ + 'defaults', + 'padding="same",stride=1', + 'padding="same",stride=(1,1)', + 'padding="same",stride=(2,1)']) def test_module_init(self, kwargs): mod = QuantConv2d( out_channels=OUTPUT_CHANNELS, @@ -108,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self): merge_bn(mod, bn) inp = torch.randn(1, INPUT_CHANNELS, 20, 20) mod(inp) + + @pytest_cases.parametrize( + 'kwargs', + [{ + 'is_same_padded_strided': False}, { + 'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, { + 'padding': 'same', 'stride': (1, 1), 'is_same_padded_strided': False}, { + 'padding': 'same', 'stride': (2, 1), 'is_same_padded_strided': True}, { + 'padding': 0, 'stride': (2, 1), 'is_same_padded_strided': False}], + ids=[ + 'defaults', + 'padding="same",stride=1', + 'padding="same",stride=(1,1)', + 'padding="same",stride=(2,1)', + 'padding=0,stride=(2,1)']) + def test_is_same_padded_strided(self, kwargs): + is_same_padded_strided = kwargs['is_same_padded_strided'] + del kwargs['is_same_padded_strided'] + mod = QuantConv2d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False, + **kwargs) + assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}" diff --git a/tests/brevitas/nn/test_conv3d.py b/tests/brevitas/nn/test_conv3d.py index a3b912dd5..ba3d78b7d 100644 --- a/tests/brevitas/nn/test_conv3d.py +++ b/tests/brevitas/nn/test_conv3d.py @@ -19,10 +19,16 @@ class TestQuantConv3d: @pytest_cases.parametrize( - 'kwargs', [{}, { + 'kwargs', + [{}, { 'padding': 'same', 'stride': 1}, { - 'padding': 'same', 'stride': (1, 1, 1)}], - ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1,1)']) + 'padding': 'same', 'stride': (1, 1, 1)}, { + 'padding': 'same', 'stride': (2, 1, 1)}], + ids=[ + 'defaults', + 'padding="same",stride=1', + 'padding="same",stride=(1,1,1)', + 'padding="same",stride=(2,1,1)']) def test_module_init(self, kwargs): mod = QuantConv3d( out_channels=OUTPUT_CHANNELS, @@ -108,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self): merge_bn(mod, bn) inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20) mod(inp) + + @pytest_cases.parametrize( + 'kwargs', + [{ + 'is_same_padded_strided': False}, { + 'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, { + 'padding': 'same', 'stride': (1, 1, 1), 'is_same_padded_strided': False}, { + 'padding': 'same', 'stride': (2, 1, 1), 'is_same_padded_strided': True}, { + 'padding': 0, 'stride': (2, 1, 1), 'is_same_padded_strided': False}], + ids=[ + 'defaults', + 'padding="same",stride=1', + 'padding="same",stride=(1,1,1)', + 'padding="same",stride=(2,1,1)', + 'padding=0,stride=(2,1,1)']) + def test_is_same_padded_strided(self, kwargs): + is_same_padded_strided = kwargs['is_same_padded_strided'] + del kwargs['is_same_padded_strided'] + mod = QuantConv3d( + out_channels=OUTPUT_CHANNELS, + in_channels=INPUT_CHANNELS, + kernel_size=KERNEL_SIZE, + bias=False, + **kwargs) + assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"