Skip to content

Commit

Permalink
test (nn/convxd): Added tests for padding issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Sep 4, 2024
1 parent 9de92f2 commit d171072
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
37 changes: 34 additions & 3 deletions tests/brevitas/nn/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@
class TestQuantConv2d:

@pytest_cases.parametrize(
'kwargs', [{}, {
'kwargs',
[{}, {
'padding': 'same', 'stride': 1}, {
'padding': 'same', 'stride': (1, 1)}],
ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1)'])
'padding': 'same', 'stride': (1, 1)}, {
'padding': 'same', 'stride': (2, 1)}],
ids=[
'defaults',
'padding="same",stride=1',
'padding="same",stride=(1,1)',
'padding="same",stride=(2,1)'])
def test_module_init(self, kwargs):
mod = QuantConv2d(
out_channels=OUTPUT_CHANNELS,
Expand Down Expand Up @@ -108,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self):
merge_bn(mod, bn)
inp = torch.randn(1, INPUT_CHANNELS, 20, 20)
mod(inp)

@pytest_cases.parametrize(
'kwargs',
[{
'is_same_padded_strided': False}, {
'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, {
'padding': 'same', 'stride': (1, 1), 'is_same_padded_strided': False}, {
'padding': 'same', 'stride': (2, 1), 'is_same_padded_strided': True}, {
'padding': 0, 'stride': (2, 1), 'is_same_padded_strided': False}],
ids=[
'defaults',
'padding="same",stride=1',
'padding="same",stride=(1,1)',
'padding="same",stride=(2,1)',
'padding=0,stride=(2,1)'])
def test_is_same_padded_strided(self, kwargs):
is_same_padded_strided = kwargs['is_same_padded_strided']
del kwargs['is_same_padded_strided']
mod = QuantConv2d(
out_channels=OUTPUT_CHANNELS,
in_channels=INPUT_CHANNELS,
kernel_size=KERNEL_SIZE,
bias=False,
**kwargs)
assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"
37 changes: 34 additions & 3 deletions tests/brevitas/nn/test_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@
class TestQuantConv3d:

@pytest_cases.parametrize(
'kwargs', [{}, {
'kwargs',
[{}, {
'padding': 'same', 'stride': 1}, {
'padding': 'same', 'stride': (1, 1, 1)}],
ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1,1)'])
'padding': 'same', 'stride': (1, 1, 1)}, {
'padding': 'same', 'stride': (2, 1, 1)}],
ids=[
'defaults',
'padding="same",stride=1',
'padding="same",stride=(1,1,1)',
'padding="same",stride=(2,1,1)'])
def test_module_init(self, kwargs):
mod = QuantConv3d(
out_channels=OUTPUT_CHANNELS,
Expand Down Expand Up @@ -108,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self):
merge_bn(mod, bn)
inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20)
mod(inp)

@pytest_cases.parametrize(
'kwargs',
[{
'is_same_padded_strided': False}, {
'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, {
'padding': 'same', 'stride': (1, 1, 1), 'is_same_padded_strided': False}, {
'padding': 'same', 'stride': (2, 1, 1), 'is_same_padded_strided': True}, {
'padding': 0, 'stride': (2, 1, 1), 'is_same_padded_strided': False}],
ids=[
'defaults',
'padding="same",stride=1',
'padding="same",stride=(1,1,1)',
'padding="same",stride=(2,1,1)',
'padding=0,stride=(2,1,1)'])
def test_is_same_padded_strided(self, kwargs):
is_same_padded_strided = kwargs['is_same_padded_strided']
del kwargs['is_same_padded_strided']
mod = QuantConv3d(
out_channels=OUTPUT_CHANNELS,
in_channels=INPUT_CHANNELS,
kernel_size=KERNEL_SIZE,
bias=False,
**kwargs)
assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"

0 comments on commit d171072

Please sign in to comment.