diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6b024fcd7c..ee90b0e9a1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,7 +45,22 @@ jobs: - name: Run unittests and generate coverage report run: | pip install -r requirements/test.txt - pytest tests/ --ignore=tests/test_runner --ignore=tests/test_device/test_ipu --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py + pytest tests/ \ + --ignore=tests/test_runner \ + --ignore=tests/test_device/test_ipu \ + --ignore=tests/test_optimizer.py \ + --ignore=tests/test_cnn \ + --ignore=tests/test_parallel.py \ + --ignore=tests/test_ops \ + --ignore=tests/test_load_model_zoo.py \ + --ignore=tests/test_utils/test_logging.py \ + --ignore=tests/test_image/test_io.py \ + --ignore=tests/test_utils/test_registry.py \ + --ignore=tests/test_utils/test_parrots_jit.py \ + --ignore=tests/test_utils/test_trace.py \ + --ignore=tests/test_utils/test_hub.py \ + --ignore=tests/test_device/test_mlu/test_mlu_parallel.py \ + --ignore=tests/test_utils/test_torch_ops.py build_without_ops: runs-on: ubuntu-18.04 diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index e0825ed670..059ae746da 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -53,6 +53,7 @@ # yapf: enable from .registry import Registry, build_from_cfg from .seed import worker_init_fn + from .torch_ops import torch_meshgrid from .trace import is_jit_tracing __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', @@ -74,5 +75,6 @@ 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE', - 'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE' + 'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE', + 'torch_meshgrid' ] diff --git a/mmcv/utils/torch_ops.py b/mmcv/utils/torch_ops.py new file mode 100644 index 0000000000..b4f2213a43 --- /dev/null +++ b/mmcv/utils/torch_ops.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from .parrots_wrapper import TORCH_VERSION +from .version_utils import digit_version + +_torch_version_meshgrid_indexing = ( + 'parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) + + +def torch_meshgrid(*tensors): + """A wrapper of torch.meshgrid to compat different PyTorch versions. + + Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``. + So we implement a wrapper here to avoid warning when using high-version + PyTorch and avoid compatibility issues when using previous versions of + PyTorch. + + Args: + tensors (List[Tensor]): List of scalars or 1 dimensional tensors. + + Returns: + Sequence[Tensor]: Sequence of meshgrid tensors. + """ + if _torch_version_meshgrid_indexing: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) # Uses indexing='ij' by default diff --git a/tests/test_utils/test_torch_ops.py b/tests/test_utils/test_torch_ops.py new file mode 100644 index 0000000000..e8752e0fd6 --- /dev/null +++ b/tests/test_utils/test_torch_ops.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmcv.utils import torch_meshgrid + + +def test_torch_meshgrid(): + # torch_meshgrid should not throw warning + with pytest.warns(None) as record: + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5, 6]) + grid_x, grid_y = torch_meshgrid(x, y) + + assert len(record) == 0