diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aa4e24f7dd..5c0c7546aa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -47,7 +47,7 @@ jobs: run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report run: | - pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py + pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py build_without_ops: runs-on: ubuntu-latest diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index bf8792fae9..b2f160d9af 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -33,6 +33,7 @@ DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config) + from .parrots_jit import jit, skip_no_elena, skip_no_parrots from .registry import Registry, build_from_cfg __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', @@ -48,5 +49,6 @@ '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', 'deprecated_api_warning', 'digit_version', - 'get_git_hash', 'import_modules_from_strings' + 'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena', + 'skip_no_parrots' ] diff --git a/mmcv/utils/parrots_jit.py b/mmcv/utils/parrots_jit.py new file mode 100644 index 0000000000..552fc999eb --- /dev/null +++ b/mmcv/utils/parrots_jit.py @@ -0,0 +1,47 @@ +import pytest +import torch + +TORCH_VERSION = torch.__version__ + +if TORCH_VERSION == 'parrots': + from parrots.jit import pat as jit +else: + + def jit(func=None, + check_input=None, + full_shape=True, + derivate=False, + coderize=False, + optimize=False): + + def wrapper(func): + + def wrapper_inner(*args, **kargs): + return func(*args, **kargs) + + return wrapper_inner + + if func is None: + return wrapper + else: + return func + + +if TORCH_VERSION == 'parrots': + from parrots.utils.tester import skip_no_elena +else: + + def skip_no_elena(func): + + def wrapper(*args, **kargs): + return func(*args, **kargs) + + return wrapper + + +def is_using_parrots(): + return TORCH_VERSION == 'parrots' + + +skip_no_parrots = pytest.mark.skipif( + not is_using_parrots(), reason='test case under parrots environment') diff --git a/tests/test_utils/test_parrots_jit.py b/tests/test_utils/test_parrots_jit.py new file mode 100644 index 0000000000..66c85e4298 --- /dev/null +++ b/tests/test_utils/test_parrots_jit.py @@ -0,0 +1,272 @@ +import pytest +import torch + +import mmcv + + +class TestJit(object): + + def test_add_dict(self): + + @mmcv.jit + def add_dict(oper): + rets = oper['x'] + oper['y'] + return {'result': rets} + + def add_dict_pyfunc(oper): + rets = oper['x'] + oper['y'] + return {'result': rets} + + a = torch.rand((3, 4)) + b = torch.rand((3, 4)) + oper = {'x': a, 'y': b} + + rets_t = add_dict(oper) + rets = add_dict_pyfunc(oper) + assert 'result' in rets + assert (rets_t['result'] == rets['result']).all() + + def test_add_list(self): + + @mmcv.jit + def add_list(oper, x, y): + rets = {} + for idx, pair in enumerate(oper): + rets[f'k{idx}'] = pair['x'] + pair['y'] + rets[f'k{len(oper)}'] = x + y + return rets + + def add_list_pyfunc(oper, x, y): + rets = {} + for idx, pair in enumerate(oper): + rets[f'k{idx}'] = pair['x'] + pair['y'] + rets[f'k{len(oper)}'] = x + y + return rets + + pair_num = 3 + oper = [] + for _ in range(pair_num): + oper.append({'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))}) + a = torch.rand((3, 4)) + b = torch.rand((3, 4)) + rets = add_list_pyfunc(oper, x=a, y=b) + rets_t = add_list(oper, x=a, y=b) + for idx in range(pair_num + 1): + assert f'k{idx}' in rets_t + assert (rets[f'k{idx}'] == rets_t[f'k{idx}']).all() + + @mmcv.skip_no_parrots + def test_jit_cache(self): + + @mmcv.jit + def func(oper): + if oper['const'] > 1: + return oper['x'] * 2 + oper['y'] + else: + return oper['x'] * 2 - oper['y'] + + def pyfunc(oper): + if oper['const'] > 1: + return oper['x'] * 2 + oper['y'] + else: + return oper['x'] * 2 - oper['y'] + + assert len(func._cache._cache) == 0 + + oper = {'const': 2, 'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))} + rets_plus = pyfunc(oper) + rets_plus_t = func(oper) + assert (rets_plus == rets_plus_t).all() + assert len(func._cache._cache) == 1 + + oper['const'] = 0.5 + rets_minus = pyfunc(oper) + rets_minus_t = func(oper) + assert (rets_minus == rets_minus_t).all() + assert len(func._cache._cache) == 2 + + rets_a = (rets_minus_t + rets_plus_t) / 4 + assert torch.allclose(oper['x'], rets_a) + + @mmcv.skip_no_parrots + def test_jit_shape(self): + + @mmcv.jit + def func(a): + return a + 1 + + assert len(func._cache._cache) == 0 + + a = torch.ones((3, 4)) + r = func(a) + assert r.shape == (3, 4) + assert (r == 2).all() + assert len(func._cache._cache) == 1 + + a = torch.ones((2, 3, 4)) + r = func(a) + assert r.shape == (2, 3, 4) + assert (r == 2).all() + assert len(func._cache._cache) == 2 + + @mmcv.skip_no_parrots + def test_jit_kwargs(self): + + @mmcv.jit + def func(a, b): + return torch.mean((a - b) * (a - b)) + + assert len(func._cache._cache) == 0 + x = torch.rand((16, 32)) + y = torch.rand((16, 32)) + func(x, y) + assert len(func._cache._cache) == 1 + func(x, b=y) + assert len(func._cache._cache) == 1 + func(b=y, a=x) + assert len(func._cache._cache) == 1 + + def test_jit_derivate(self): + + @mmcv.jit(derivate=True) + def func(x, y): + return (x + 2) * (y - 2) + + a = torch.rand((3, 4)) + b = torch.rand((3, 4)) + a.requires_grad = True + + c = func(a, b) + assert c.requires_grad + d = torch.empty_like(c) + d.fill_(1.0) + c.backward(d) + assert torch.allclose(a.grad, (b - 2)) + assert b.grad is None + + a.grad = None + c = func(a, b) + assert c.requires_grad + d = torch.empty_like(c) + d.fill_(2.7) + c.backward(d) + assert torch.allclose(a.grad, 2.7 * (b - 2)) + assert b.grad is None + + def test_jit_optimize(self): + + @mmcv.jit(optimize=True) + def func(a, b): + return torch.mean((a - b) * (a - b)) + + def pyfunc(a, b): + return torch.mean((a - b) * (a - b)) + + a = torch.rand((16, 32)) + b = torch.rand((16, 32)) + + c = func(a, b) + d = pyfunc(a, b) + assert torch.allclose(c, d) + + @mmcv.skip_no_elena + def test_jit_coderize(self): + if not torch.cuda.is_available(): + return + + @mmcv.jit(coderize=True) + def func(a, b): + return (a + b) * (a - b) + + def pyfunc(a, b): + return (a + b) * (a - b) + + a = torch.rand((16, 32), device='cuda') + b = torch.rand((16, 32), device='cuda') + + c = func(a, b) + d = pyfunc(a, b) + assert torch.allclose(c, d) + + def test_jit_value_dependent(self): + + @mmcv.jit + def func(a, b): + torch.nonzero(a) + return torch.mean((a - b) * (a - b)) + + def pyfunc(a, b): + torch.nonzero(a) + return torch.mean((a - b) * (a - b)) + + a = torch.rand((16, 32)) + b = torch.rand((16, 32)) + + c = func(a, b) + d = pyfunc(a, b) + assert torch.allclose(c, d) + + @mmcv.skip_no_parrots + def test_jit_check_input(self): + + def func(x): + y = torch.rand_like(x) + return x + y + + a = torch.ones((3, 4)) + with pytest.raises(AssertionError): + func = mmcv.jit(func, check_input=(a, )) + + @mmcv.skip_no_parrots + def test_jit_partial_shape(self): + + @mmcv.jit(full_shape=False) + def func(a, b): + return torch.mean((a - b) * (a - b)) + + def pyfunc(a, b): + return torch.mean((a - b) * (a - b)) + + a = torch.rand((3, 4)) + b = torch.rand((3, 4)) + assert torch.allclose(func(a, b), pyfunc(a, b)) + assert len(func._cache._cache) == 1 + + a = torch.rand((6, 5)) + b = torch.rand((6, 5)) + assert torch.allclose(func(a, b), pyfunc(a, b)) + assert len(func._cache._cache) == 1 + + a = torch.rand((3, 4, 5)) + b = torch.rand((3, 4, 5)) + assert torch.allclose(func(a, b), pyfunc(a, b)) + assert len(func._cache._cache) == 2 + + a = torch.rand((1, 9, 8)) + b = torch.rand((1, 9, 8)) + assert torch.allclose(func(a, b), pyfunc(a, b)) + assert len(func._cache._cache) == 2 + + def test_instance_method(self): + + class T(object): + + def __init__(self, shape): + self._c = torch.rand(shape) + + @mmcv.jit + def test_method(self, x, y): + return (x * self._c) + y + + shape = (16, 32) + t = T(shape) + a = torch.rand(shape) + b = torch.rand(shape) + res = (a * t._c) + b + jit_res = t.test_method(a, b) + assert torch.allclose(res, jit_res) + + t = T(shape) + res = (a * t._c) + b + jit_res = t.test_method(a, b) + assert torch.allclose(res, jit_res)