Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Add Pytorch HardSwish assertion in unit test #1294

Merged
merged 6 commits into from
Feb 16, 2022
12 changes: 10 additions & 2 deletions tests/test_models/test_backbones/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def test_inv_residualv3():
assert output.shape == (1, 32, 64, 64)

# test with se_cfg and with_expand_conv
# Note: Use PyTorch official HSwish when torch>=1.7 after MMCV >= 1.4.5.
# Hardswish is not supported when PyTorch version < 1.6.
# And Hardswish in PyTorch 1.6 does not support inplace.
# More details could be found from:
# https://github.com/open-mmlab/mmcv/pull/1709
se_cfg = dict(
channels=16,
ratio=4,
Expand All @@ -108,15 +113,18 @@ def test_inv_residualv3():
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
assert inv_module.expand_conv.conv.stride == (1, 1)
assert inv_module.expand_conv.conv.padding == (0, 0)
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
assert isinstance(inv_module.expand_conv.activate,
(mmcv.cnn.HSwish, torch.nn.Hardswish))

assert isinstance(inv_module.depthwise_conv.conv,
mmcv.cnn.bricks.Conv2dAdaptivePadding)
assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
assert inv_module.depthwise_conv.conv.stride == (2, 2)
assert inv_module.depthwise_conv.conv.padding == (0, 0)
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)

assert isinstance(inv_module.depthwise_conv.activate,
(mmcv.cnn.HSwish, torch.nn.Hardswish))
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
assert inv_module.linear_conv.conv.stride == (1, 1)
assert inv_module.linear_conv.conv.padding == (0, 0)
Expand Down