Skip to content

Commit

Permalink
Add torch_meshgrid wrapper due to PyTorch change (open-mmlab#2044)
Browse files Browse the repository at this point in the history
* Add torch_meshgrid_ij wrapper due to PyTorch change

* Update torch_meshgrid name/doc/version implementation

* Make imports local

* add ut

* ignore ut when torch is not available

Co-authored-by: zhouzaida <zhouzaida@163.com>
  • Loading branch information
pallgeuer and zhouzaida authored Jun 15, 2022
1 parent b062468 commit f5425ab
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 2 deletions.
17 changes: 16 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,22 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_device/test_ipu --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py
pytest tests/ \
--ignore=tests/test_runner \
--ignore=tests/test_device/test_ipu \
--ignore=tests/test_optimizer.py \
--ignore=tests/test_cnn \
--ignore=tests/test_parallel.py \
--ignore=tests/test_ops \
--ignore=tests/test_load_model_zoo.py \
--ignore=tests/test_utils/test_logging.py \
--ignore=tests/test_image/test_io.py \
--ignore=tests/test_utils/test_registry.py \
--ignore=tests/test_utils/test_parrots_jit.py \
--ignore=tests/test_utils/test_trace.py \
--ignore=tests/test_utils/test_hub.py \
--ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
--ignore=tests/test_utils/test_torch_ops.py
build_without_ops:
runs-on: ubuntu-18.04
Expand Down
4 changes: 3 additions & 1 deletion mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
# yapf: enable
from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
Expand All @@ -74,5 +75,6 @@
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE'
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
'torch_meshgrid'
]
29 changes: 29 additions & 0 deletions mmcv/utils/torch_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .parrots_wrapper import TORCH_VERSION
from .version_utils import digit_version

_torch_version_meshgrid_indexing = (
'parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))


def torch_meshgrid(*tensors):
"""A wrapper of torch.meshgrid to compat different PyTorch versions.
Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
So we implement a wrapper here to avoid warning when using high-version
PyTorch and avoid compatibility issues when using previous versions of
PyTorch.
Args:
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.
Returns:
Sequence[Tensor]: Sequence of meshgrid tensors.
"""
if _torch_version_meshgrid_indexing:
return torch.meshgrid(*tensors, indexing='ij')
else:
return torch.meshgrid(*tensors) # Uses indexing='ij' by default
15 changes: 15 additions & 0 deletions tests/test_utils/test_torch_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.utils import torch_meshgrid


def test_torch_meshgrid():
# torch_meshgrid should not throw warning
with pytest.warns(None) as record:
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
grid_x, grid_y = torch_meshgrid(x, y)

assert len(record) == 0

0 comments on commit f5425ab

Please sign in to comment.