From 9e34becb3b57e7209897714c5535bd41f5a00595 Mon Sep 17 00:00:00 2001 From: Rist115 Date: Tue, 13 Sep 2022 14:35:56 +0900 Subject: [PATCH] add silu --- mmcv/cnn/bricks/activation.py | 3 +++ tests/test_cnn/test_silu.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 tests/test_cnn/test_silu.py diff --git a/mmcv/cnn/bricks/activation.py b/mmcv/cnn/bricks/activation.py index 23e6272277..50414aee0f 100644 --- a/mmcv/cnn/bricks/activation.py +++ b/mmcv/cnn/bricks/activation.py @@ -14,6 +14,9 @@ ]: ACTIVATION_LAYERS.register_module(module=module) +if digit_version(torch.__version__) >= digit_version('1.7.0'): + ACTIVATION_LAYERS.register_module(module=nn.SiLU) + @ACTIVATION_LAYERS.register_module(name='Clip') @ACTIVATION_LAYERS.register_module() diff --git a/tests/test_cnn/test_silu.py b/tests/test_cnn/test_silu.py new file mode 100644 index 0000000000..e202559bb4 --- /dev/null +++ b/tests/test_cnn/test_silu.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch +import torch.nn.functional as F + +from mmcv.cnn.bricks import build_activation_layer +from mmcv.utils import digit_version + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.7.0'), + reason='torch.nn.SiLU is not available before 1.7.0') +def test_silu(): + act = build_activation_layer(dict(type='SiLU')) + input = torch.randn(1, 3, 64, 64) + expected_output = F.silu(input) + output = act(input) + # test output shape + assert output.shape == expected_output.shape + # test output value + assert torch.equal(output, expected_output)