From c0cbbbceccd759d3018cb7eb123a0765f053cf56 Mon Sep 17 00:00:00 2001
From: huangqipeng
Date: Mon, 5 Jun 2023 16:47:39 +0800
Subject: [PATCH 1/5] [Feature] Support modulated_deform_conv on 2.x branch
 with Cambricon MLU backend

---
 docs/en/understand_mmcv/ops.md               |  2 +-
 docs/zh_cn/understand_mmcv/ops.md            |  2 +-
 mmcv/ops/__init__.py                         |  6 ++
 mmcv/ops/modulated_deform_conv.py            | 66 ++++++++++++++++++++
 tests/test_ops/test_modulated_deform_conv.py | 55 +++++++++++-----
 5 files changed, 114 insertions(+), 17 deletions(-)

diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index c4212742d8..51ef29a6de 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
 | MaskedConv | | √ | √ | | √ |
 | MergeCells | | √ | | | |
 | MinAreaPolygon | | √ | | | |
-| ModulatedDeformConv2d | √ | √ | | | √ |
+| ModulatedDeformConv2d | √ | √ | √ | | √ |
 | MultiScaleDeformableAttn | | √ | √ | | |
 | NMS | √ | √ | √ | | √ |
 | NMSRotated | √ | √ | √ | | √ |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 3e34793685..d09ac92b30 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -32,7 +32,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | MaskedConv | | √ | √ | | √ |
 | MergeCells | | √ | | | |
 | MinAreaPolygon | | √ | | | |
-| ModulatedDeformConv2d | √ | √ | | | √ |
+| ModulatedDeformConv2d | √ | √ | √ | | √ |
 | MultiScaleDeformableAttn | | √ | √ | | |
 | NMS | √ | √ | √ | | √ |
 | NMSRotated | √ | √ | √ | | √ |
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index cffbd23fd4..eb1f18b6bb 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import IS_MLU_AVAILABLE
 from .active_rotated_filter import active_rotated_filter
 from .assign_score_withk import assign_score_withk
 from .ball_query import ball_query
@@ -109,3 +110,8 @@
     'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
     'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
 ]
+
+if IS_MLU_AVAILABLE:
+    from .modulated_deform_conv import \
+        ModulatedDeformConv2dPack_MLU  # noqa:F401
+    __all__.append('ModulatedDeformConv2dPack_MLU')
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index dcdada47bd..b162a0e9dc 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -11,6 +11,7 @@
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
 
+from mmcv.utils import IS_MLU_AVAILABLE
 from ..utils import ext_loader
 
 ext_module = ext_loader.load_ext(
@@ -358,3 +359,68 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    import torchvision
+    from mmengine.utils import digit_version
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @MODELS.register_module('DCNv2', force=True)
+    class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
+        """This class is the DCNv2 implementation for the MLU device.
+
+        The MLU backend of this operator is implemented in torchvision,
+        so this class reuses the torchvision implementation of DCNv2
+        through the mmcv registration mechanism.
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int): Same as nn.Conv2d, while tuple is not supported. + padding (int): Same as nn.Conv2d, while tuple is not supported. + dilation (int): Same as nn.Conv2d, while tuple is not supported. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by + the norm_cfg. Bias will be set as True if norm_cfg is None, + otherwise False. + """ + + def __init__(self, *args, **kwargs): + assert digit_version(torchvision.__version__) >= digit_version( + '0.10.0a0'), 'the version of torchvision should be >= 0.10.0' + super().__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 3 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True) + self.init_weights() + + def init_weights(self): + super().init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + x = x.type_as(offset) + weight = self.weight.type_as(x) + mask = mask.type_as(x) + return tv_deform_conv2d( + x, + offset, + weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + mask=mask) diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py index b7e48edef0..5c9ccbf23b 100644 --- a/tests/test_ops/test_modulated_deform_conv.py +++ b/tests/test_ops/test_modulated_deform_conv.py @@ -7,7 +7,7 @@ from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION -from mmcv.utils import IS_CUDA_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE try: # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast @@ -44,7 +44,12 @@ class TestMdconv: def _test_mdconv(self, dtype=torch.float, device='cuda'): if not torch.cuda.is_available() and device == 'cuda': pytest.skip('test requires GPU') - from mmcv.ops import ModulatedDeformConv2dPack + if device == 'mlu': + from mmcv.ops import \ + ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack + else: + from mmcv.ops import ModulatedDeformConv2dPack + input = torch.tensor(input_t, dtype=dtype, device=device) input.requires_grad = True @@ -55,10 +60,7 @@ def _test_mdconv(self, dtype=torch.float, device='cuda'): stride=1, padding=1, deform_groups=1, - bias=False) - - if device == 'cuda': - dcn.cuda() + bias=False).to(device) dcn.weight.data.fill_(1.) dcn.type(dtype) @@ -75,7 +77,7 @@ def _test_mdconv(self, dtype=torch.float, device='cuda'): assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(), dcn_offset_b_grad, 1e-2) - def _test_amp_mdconv(self, input_dtype=torch.float): + def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'): """The function to test amp released on pytorch 1.6.0. The type of input data might be torch.float or torch.half, @@ -85,10 +87,15 @@ def _test_amp_mdconv(self, input_dtype=torch.float): Args: input_dtype: torch.float or torch.half. 
""" - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and device == 'cuda': return - from mmcv.ops import ModulatedDeformConv2dPack - input = torch.tensor(input_t).cuda().type(input_dtype) + if device == 'mlu': + from mmcv.ops import \ + ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack + else: + from mmcv.ops import ModulatedDeformConv2dPack + + input = torch.tensor(input_t).to(device).type(input_dtype) input.requires_grad = True dcn = ModulatedDeformConv2dPack( @@ -98,7 +105,7 @@ def _test_amp_mdconv(self, input_dtype=torch.float): stride=1, padding=1, deform_groups=1, - bias=False).cuda() + bias=False).to(device) dcn.weight.data.fill_(1.) output = dcn(input) output.sum().backward() @@ -119,6 +126,10 @@ def _test_amp_mdconv(self, input_dtype=torch.float): 'cuda', marks=pytest.mark.skipif( not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), ]) def test_mdconv_float(self, device): self._test_mdconv(dtype=torch.float, device=device) @@ -129,16 +140,30 @@ def test_mdconv_float(self, device): 'cuda', marks=pytest.mark.skipif( not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), ]) def test_mdconv_double(self, device): self._test_mdconv(dtype=torch.double, device=device) - def test_mdconv_half(self): - self._test_mdconv(torch.half) + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), + ]) + def test_mdconv_half(self, device): + self._test_mdconv(torch.half, device=device) # test amp when torch version >= '1.6.0', the type of # input data for mdconv might be torch.float or torch.half if (TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): with autocast(enabled=True): - self._test_amp_mdconv(torch.float) - self._test_amp_mdconv(torch.half) + self._test_amp_mdconv(torch.float, device=device) + self._test_amp_mdconv(torch.half, device=device) From 7b7da213156c9e5d4e109f94b22759a796dd8621 Mon Sep 17 00:00:00 2001 From: huangqipeng Date: Tue, 6 Jun 2023 14:56:04 +0800 Subject: [PATCH 2/5] [Feature] Support deform_conv on 2.x branch with cambricon MLU backend --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- mmcv/ops/__init__.py | 3 +- mmcv/ops/deform_conv.py | 61 ++++++++++++++++++++++++++++++ tests/test_ops/test_deform_conv.py | 41 ++++++++++++++------ 5 files changed, 94 insertions(+), 15 deletions(-) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 51ef29a6de..c460a243ea 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc. 
 | ConvexIoU | | √ | | | |
 | CornerPool | | √ | | | |
 | Correlation | | √ | | | |
-| Deformable Convolution v1/v2 | √ | √ | | | √ |
+| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | | | |
 | DynamicScatter | | √ | √ | | |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index d09ac92b30..892548cad6 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | ConvexIoU | | √ | | | |
 | CornerPool | | √ | | | |
 | Correlation | | √ | | | |
-| Deformable Convolution v1/v2 | √ | √ | | | √ |
+| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | | | |
 | DynamicScatter | | √ | √ | | |
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index eb1f18b6bb..ffad9b2bfd 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -112,6 +112,7 @@
 ]
 
 if IS_MLU_AVAILABLE:
+    from .deform_conv import DeformConv2dPack_MLU  # noqa:F401
     from .modulated_deform_conv import \
         ModulatedDeformConv2dPack_MLU  # noqa:F401
-    __all__.append('ModulatedDeformConv2dPack_MLU')
+    __all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index eb4b4e6f5a..8251bc7328 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -12,6 +12,7 @@
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
 
+from mmcv.utils import IS_MLU_AVAILABLE
 from ..utils import ext_loader
 from .modulated_deform_conv import ModulatedDeformConv2dFunction
 
@@ -438,3 +439,63 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    import torchvision
+    from mmengine.utils import digit_version
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @MODELS.register_module('DCN', force=True)
+    class DeformConv2dPack_MLU(DeformConv2d):
+        """This class is the DCN implementation for the MLU device. The
+        MLU backend of this operator is implemented in torchvision, so
+        this class reuses the torchvision implementation of DCN through
+        the mmcv registration mechanism.
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d, while tuple is not supported.
+            padding (int): Same as nn.Conv2d, while tuple is not supported.
+            dilation (int): Same as nn.Conv2d, while tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            bias (bool or str): If specified as `auto`, it will be decided by
+                the norm_cfg. Bias will be set as True if norm_cfg is None,
+                otherwise False.
+            im2col_step (int): Number of samples processed by
+                im2col_cuda_kernel per call. It will work when ``batch_size``
+                > ``im2col_step``, but ``batch_size`` must be divisible by
+                ``im2col_step``. Default: 32. `New in version 1.7.2.
+ Currently not supported on MLU devices.` + """ + + def __init__(self, *args, **kwargs): + assert digit_version(torchvision.__version__) >= digit_version( + '0.10.0a0'), 'the version of torchvision should be >= 0.10.0' + super().__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 2 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x: Tensor) -> Tensor: # type: ignore + cur_im2col_step = min(self.im2col_step, x.size(0)) + assert (x.size(0) % cur_im2col_step + ) == 0, 'batch size must be divisible by im2col_step' + offset = self.conv_offset(x) + x = x.type_as(offset) + weight = self.weight.type_as(x) + return tv_deform_conv2d(x, offset, weight, None, self.stride, + self.padding, self.dilation) diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index 64dcccfdef..483dccd3d1 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -5,6 +5,8 @@ from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION +from mmcv.utils import IS_MLU_AVAILABLE + try: # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast # would be imported and used; we should test if our modules support it. @@ -45,7 +47,10 @@ def _test_deformconv(self, im2col_step=2): if not torch.cuda.is_available() and device == 'cuda': pytest.skip('test requires GPU') - from mmcv.ops import DeformConv2dPack + if device == 'mlu': + from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack + else: + from mmcv.ops import DeformConv2dPack c_in = 1 c_out = 1 batch_size = 10 @@ -69,6 +74,8 @@ def _test_deformconv(self, torch.Tensor(deform_weight).reshape(1, 1, 2, 2)) if device == 'cuda': model.cuda() + elif device == 'mlu': + model.mlu() model.type(dtype) out = model(x) @@ -108,6 +115,7 @@ def _test_deformconv(self, def _test_amp_deformconv(self, input_dtype, threshold=1e-3, + device='cuda', batch_size=10, im2col_step=2): """The function to test amp released on pytorch 1.6.0. @@ -120,15 +128,18 @@ def _test_amp_deformconv(self, input_dtype: torch.float or torch.half. threshold: the same as above function. 
""" - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and device == 'cuda': return - from mmcv.ops import DeformConv2dPack + if device == 'mlu': + from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack + else: + from mmcv.ops import DeformConv2dPack c_in = 1 c_out = 1 repeated_input = np.repeat(input, batch_size, axis=0) repeated_gt_out = np.repeat(gt_out, batch_size, axis=0) repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0) - x = torch.Tensor(repeated_input).cuda().type(input_dtype) + x = torch.Tensor(repeated_input).to(device).type(input_dtype) x.requires_grad = True model = DeformConv2dPack( in_channels=c_in, @@ -143,7 +154,10 @@ def _test_amp_deformconv(self, torch.Tensor(offset_bias).reshape(8)) model.weight.data = torch.nn.Parameter( torch.Tensor(deform_weight).reshape(1, 1, 2, 2)) - model.cuda() + if device == 'cuda': + model.cuda() + elif device == 'mlu': + model.mlu() out = model(x) out.backward(torch.ones_like(out)) @@ -180,21 +194,24 @@ def _test_amp_deformconv(self, def test_deformconv(self): self._test_deformconv(torch.double, device='cpu') self._test_deformconv(torch.float, device='cpu', threshold=1e-1) - self._test_deformconv(torch.double) - self._test_deformconv(torch.float) - self._test_deformconv(torch.half, threshold=1e-1) + device = 'mlu' if IS_MLU_AVAILABLE else 'cuda' + self._test_deformconv(torch.double, device=device) + self._test_deformconv(torch.float, device=device) + self._test_deformconv(torch.half, threshold=1e-1, device=device) # test batch_size < im2col_step - self._test_deformconv(torch.float, batch_size=1, im2col_step=2) + self._test_deformconv( + torch.float, batch_size=1, im2col_step=2, device=device) # test bach_size % im2col_step != 0 with pytest.raises( AssertionError, match='batch size must be divisible by im2col_step'): - self._test_deformconv(torch.float, batch_size=10, im2col_step=3) + self._test_deformconv( + torch.float, batch_size=10, im2col_step=3, device=device) # test amp when torch version >= '1.6.0', the type of # input data for deformconv might be torch.float or torch.half if (TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): with autocast(enabled=True): - self._test_amp_deformconv(torch.float, 1e-1) - self._test_amp_deformconv(torch.half, 1e-1) + self._test_amp_deformconv(torch.float, 1e-1, device) + self._test_amp_deformconv(torch.half, 1e-1, device) From 7b9a4922729646c04a20b54a0628fd0d2c22f44d Mon Sep 17 00:00:00 2001 From: huangqipeng Date: Fri, 16 Jun 2023 14:34:35 +0800 Subject: [PATCH 3/5] [Feature] Refactor test_deform_conv --- tests/test_ops/test_deform_conv.py | 58 ++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index 483dccd3d1..7f381c8ead 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -5,7 +5,7 @@ from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION -from mmcv.utils import IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE try: # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast @@ -191,13 +191,21 @@ def _test_amp_deformconv(self, with pytest.raises(AssertionError): model = DeformConv2d(3, 4, 3, groups=3) - def test_deformconv(self): - self._test_deformconv(torch.double, device='cpu') - self._test_deformconv(torch.float, device='cpu', threshold=1e-1) - device = 'mlu' if IS_MLU_AVAILABLE else 'cuda' - 
self._test_deformconv(torch.double, device=device) - self._test_deformconv(torch.float, device=device) - self._test_deformconv(torch.half, threshold=1e-1, device=device) + @pytest.mark.parametrize('device, threshold', [ + ('cpu', 1e-1), + pytest.param( + 'cuda', + 1e-3, + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + 1e-3, + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), + ]) + def test_deformconv_float(self, device, threshold): + self._test_deformconv(torch.float, device=device, threshold=threshold) # test batch_size < im2col_step self._test_deformconv( torch.float, batch_size=1, im2col_step=2, device=device) @@ -208,10 +216,40 @@ def test_deformconv(self): self._test_deformconv( torch.float, batch_size=10, im2col_step=3, device=device) + @pytest.mark.parametrize('device', [ + 'cpu', + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), + ]) + def test_deformconv_double(self, device): + self._test_deformconv(torch.double, device=device) + + @pytest.mark.parametrize('device, threshold', [ + pytest.param( + 'cuda', + 1e-1, + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + 1e-1, + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')), + ]) + def test_deformconv_half(self, device, threshold): + self._test_deformconv(torch.half, device=device, threshold=threshold) # test amp when torch version >= '1.6.0', the type of # input data for deformconv might be torch.float or torch.half if (TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): with autocast(enabled=True): - self._test_amp_deformconv(torch.float, 1e-1, device) - self._test_amp_deformconv(torch.half, 1e-1, device) + self._test_amp_deformconv( + torch.float, device=device, threshold=threshold) + self._test_amp_deformconv( + torch.half, device=device, threshold=threshold) From 8fbe0070afa63021fa454c8f47050bfb27674986 Mon Sep 17 00:00:00 2001 From: huangqipeng Date: Fri, 30 Jun 2023 12:52:58 +0800 Subject: [PATCH 4/5] [Feature] fix test_deform_conv --- tests/test_ops/test_deform_conv.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index 7f381c8ead..7f2801fdc0 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -7,6 +7,9 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +if IS_MLU_AVAILABLE: + torch.backends.cnnl.allow_tf32 = False + try: # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast # would be imported and used; we should test if our modules support it. 
From 96bb3089af6fbef97ab440fafe3b52cbf29e7b8f Mon Sep 17 00:00:00 2001
From: huangqipeng
Date: Fri, 30 Jun 2023 15:24:59 +0800
Subject: [PATCH 5/5] [Fix] Disable TF32 in test_masked_conv2d on MLU

---
 tests/test_ops/test_masked_conv2d.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/test_ops/test_masked_conv2d.py b/tests/test_ops/test_masked_conv2d.py
index a292f6a4fd..c949e0b71c 100644
--- a/tests/test_ops/test_masked_conv2d.py
+++ b/tests/test_ops/test_masked_conv2d.py
@@ -5,6 +5,10 @@
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
 
+if IS_MLU_AVAILABLE:
+    torch.backends.cnnl.allow_tf32 = False
+    torch.backends.mlu.matmul.allow_tf32 = False
+
 
 class TestMaskedConv2d:
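---

A minimal usage sketch of the two classes this series adds. It assumes an
MLU-enabled PyTorch build (torch_mlu) with a visible MLU device and
torchvision >= 0.10.0, as the constructors assert; the tensor shapes and
channel counts below are illustrative only:

    import torch

    from mmcv.ops import DeformConv2dPack_MLU, ModulatedDeformConv2dPack_MLU

    x = torch.randn(2, 3, 16, 16).to('mlu')

    # DCNv1: offsets are predicted by the internal conv_offset layer.
    dcn_v1 = DeformConv2dPack_MLU(
        in_channels=3, out_channels=8, kernel_size=3, padding=1,
        deform_groups=1).to('mlu')

    # DCNv2: additionally applies a sigmoid-activated modulation mask.
    dcn_v2 = ModulatedDeformConv2dPack_MLU(
        in_channels=3, out_channels=8, kernel_size=3, padding=1,
        deform_groups=1).to('mlu')

    out_v1 = dcn_v1(x)  # both dispatch to torchvision's deform_conv2d
    out_v2 = dcn_v2(x)
    print(out_v1.shape, out_v2.shape)  # torch.Size([2, 8, 16, 16]) twice

Because both classes are registered with force=True under the existing
'DCN' and 'DCNv2' names, downstream configs that build these layers through
the MODELS registry pick up the torchvision-backed MLU implementations
without any config changes.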