From b235e590a66e1367c2fbadece8e4a89a263eb4b1 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 2 Sep 2020 14:50:59 -0700 Subject: [PATCH] [TESTS] Refactor tests to run on either the GPU or CPU. (#6331) Much of the time spent in testing is duplicated work between CPU and GPU test nodes. The main reason is that there is no way to control which TVM devices are enabled at runtime, so tests that use LLVM will run on both GPU and CPU nodes. This patch adds an environment variable, TVM_TEST_DEVICES, which controls which TVM devices should be used by tests. Devices not in TVM_TEST_DEVICES can still be used, so tests must be careful to check that the desired device is enabled with `tvm.testing.device_enabled` or by enumerating all devices with `tvm.testing.enabled_devices`. All tests have been retrofitted with these checks. This patch also provides the decorator `@tvm.testing.gpu` to mark a test as possibly using the gpu. Tests that require the gpu can use `@tvm.testing.requires_gpu`. Tests without these flags will not be run on GPU nodes. --- apps/extension/tests/test_ext.py | 5 +- conftest.py | 29 ++ docs/contribute/code_guide.rst | 14 + .../tvm/relay/testing/config.py => pytest.ini | 25 +- python/tvm/relay/testing/__init__.py | 4 +- python/tvm/testing.py | 383 +++++++++++++++++- tests/lint/check_file_type.py | 2 + tests/python/contrib/test_cblas.py | 7 +- tests/python/contrib/test_cublas.py | 13 +- tests/python/contrib/test_cudnn.py | 13 +- tests/python/contrib/test_gemm_acc32_vnni.py | 4 +- tests/python/contrib/test_miopen.py | 4 +- tests/python/contrib/test_mps.py | 8 +- tests/python/contrib/test_nnpack.py | 9 +- tests/python/contrib/test_random.py | 28 +- tests/python/contrib/test_rocblas.py | 4 +- tests/python/frontend/caffe2/test_forward.py | 10 +- tests/python/frontend/coreml/test_forward.py | 63 +-- tests/python/frontend/keras/test_forward.py | 5 +- tests/python/frontend/mxnet/test_forward.py | 198 ++++++--- tests/python/frontend/onnx/test_forward.py | 195 ++++++--- tests/python/frontend/pytorch/test_forward.py | 132 +++++- .../frontend/tensorflow/test_bn_dynamic.py | 2 +- .../frontend/tensorflow/test_forward.py | 37 +- tests/python/frontend/tflite/test_forward.py | 2 +- tests/python/integration/test_dot.py | 5 +- tests/python/integration/test_ewise.py | 25 +- tests/python/integration/test_ewise_fpga.py | 8 +- tests/python/integration/test_gemm.py | 4 +- tests/python/integration/test_reduce.py | 27 +- tests/python/integration/test_scan.py | 4 +- tests/python/integration/test_tuning.py | 30 +- .../integration/test_winograd_nnpack.py | 5 +- .../test_quantization_accuracy.py | 2 + .../relay/dyn/test_dynamic_op_level10.py | 9 +- .../relay/dyn/test_dynamic_op_level2.py | 4 +- .../relay/dyn/test_dynamic_op_level3.py | 10 +- .../relay/dyn/test_dynamic_op_level5.py | 6 +- .../relay/dyn/test_dynamic_op_level6.py | 5 +- .../relay/test_backend_compile_engine.py | 5 +- .../relay/test_backend_graph_runtime.py | 5 +- .../python/relay/test_backend_interpreter.py | 2 +- tests/python/relay/test_cpp_build_module.py | 56 ++- tests/python/relay/test_op_grad_level1.py | 9 +- tests/python/relay/test_op_grad_level2.py | 15 +- tests/python/relay/test_op_grad_level3.py | 6 +- tests/python/relay/test_op_level1.py | 27 +- tests/python/relay/test_op_level10.py | 45 +- tests/python/relay/test_op_level2.py | 91 +++-- tests/python/relay/test_op_level3.py | 60 ++- tests/python/relay/test_op_level4.py | 30 +- tests/python/relay/test_op_level5.py | 63 +-- tests/python/relay/test_op_level6.py | 8 +- .../python/relay/test_pass_alter_op_layout.py | 6 +- tests/python/relay/test_pass_annotation.py | 37 +- .../relay/test_pass_dynamic_to_static.py | 21 +- tests/python/relay/test_pass_fuse_ops.py | 4 +- .../relay/test_pass_lazy_gradient_init.py | 1 + tests/python/relay/test_pass_manager.py | 14 +- tests/python/relay/test_vm.py | 20 +- tests/python/topi/python/common.py | 14 - tests/python/topi/python/test_fifo_buffer.py | 24 +- .../topi/python/test_topi_batch_matmul.py | 13 +- .../python/topi/python/test_topi_broadcast.py | 57 +-- tests/python/topi/python/test_topi_clip.py | 13 +- tests/python/topi/python/test_topi_conv1d.py | 12 +- .../python/test_topi_conv1d_transpose_ncw.py | 12 +- .../topi/python/test_topi_conv2d_NCHWc.py | 5 +- .../topi/python/test_topi_conv2d_hwcn.py | 4 +- .../test_topi_conv2d_hwnc_tensorcore.py | 7 +- .../topi/python/test_topi_conv2d_int8.py | 12 +- .../topi/python/test_topi_conv2d_nchw.py | 7 +- .../topi/python/test_topi_conv2d_nhwc.py | 5 +- .../python/test_topi_conv2d_nhwc_pack_int8.py | 2 +- .../test_topi_conv2d_nhwc_tensorcore.py | 5 +- .../python/test_topi_conv2d_nhwc_winograd.py | 15 +- .../python/test_topi_conv2d_transpose_nchw.py | 13 +- .../topi/python/test_topi_conv2d_winograd.py | 4 +- .../topi/python/test_topi_conv3d_ncdhw.py | 14 +- .../topi/python/test_topi_conv3d_ndhwc.py | 12 +- .../test_topi_conv3d_ndhwc_tensorcore.py | 9 +- .../test_topi_conv3d_transpose_ncdhw.py | 14 +- .../topi/python/test_topi_conv3d_winograd.py | 5 +- .../topi/python/test_topi_correlation.py | 14 +- .../python/test_topi_deformable_conv2d.py | 5 +- tests/python/topi/python/test_topi_dense.py | 19 +- .../topi/python/test_topi_dense_tensorcore.py | 12 +- .../topi/python/test_topi_depth_to_space.py | 14 +- .../topi/python/test_topi_depthwise_conv2d.py | 25 +- .../test_topi_depthwise_conv2d_back_input.py | 4 +- .../test_topi_depthwise_conv2d_back_weight.py | 4 +- .../topi/python/test_topi_group_conv2d.py | 9 +- .../test_topi_group_conv2d_NCHWc_int8.py | 5 +- tests/python/topi/python/test_topi_image.py | 56 +-- tests/python/topi/python/test_topi_lrn.py | 4 +- tests/python/topi/python/test_topi_math.py | 41 +- tests/python/topi/python/test_topi_pooling.py | 69 ++-- tests/python/topi/python/test_topi_reduce.py | 12 +- tests/python/topi/python/test_topi_relu.py | 14 +- tests/python/topi/python/test_topi_reorg.py | 4 +- tests/python/topi/python/test_topi_softmax.py | 22 +- tests/python/topi/python/test_topi_sort.py | 9 +- .../topi/python/test_topi_space_to_depth.py | 13 +- tests/python/topi/python/test_topi_sparse.py | 24 +- tests/python/topi/python/test_topi_tensor.py | 9 +- .../python/topi/python/test_topi_transform.py | 283 +++++-------- .../topi/python/test_topi_upsampling.py | 24 +- tests/python/topi/python/test_topi_util.py | 2 +- tests/python/topi/python/test_topi_vision.py | 32 +- .../unittest/test_auto_scheduler_measure.py | 13 +- .../test_auto_scheduler_search_policy.py | 15 +- .../test_auto_scheduler_sketch_generation.py | 25 +- .../unittest/test_autotvm_index_tuner.py | 2 +- .../unittest/test_hybrid_error_report.py | 2 +- tests/python/unittest/test_runtime_graph.py | 10 +- .../unittest/test_runtime_graph_debug.py | 7 +- .../test_runtime_module_based_interface.py | 30 +- .../unittest/test_runtime_module_export.py | 11 +- .../unittest/test_runtime_module_load.py | 12 +- tests/python/unittest/test_runtime_ndarray.py | 20 +- tests/python/unittest/test_runtime_rpc.py | 12 +- .../unittest/test_target_codegen_blob.py | 9 +- .../unittest/test_target_codegen_bool.py | 8 +- .../test_target_codegen_cross_llvm.py | 4 +- .../unittest/test_target_codegen_cuda.py | 351 +++++++--------- .../unittest/test_target_codegen_device.py | 13 +- .../unittest/test_target_codegen_extern.py | 8 +- .../unittest/test_target_codegen_llvm.py | 54 ++- .../unittest/test_target_codegen_opencl.py | 19 +- .../unittest/test_target_codegen_rocm.py | 10 +- .../unittest/test_target_codegen_vm_basic.py | 3 +- .../unittest/test_target_codegen_vulkan.py | 17 +- tests/python/unittest/test_te_autodiff.py | 9 +- .../python/unittest/test_te_hybrid_script.py | 49 ++- ...hedule_postproc_rewrite_for_tensor_core.py | 18 +- .../unittest/test_te_schedule_tensor_core.py | 18 +- .../unittest/test_te_tensor_overload.py | 13 +- tests/python/unittest/test_testing.py | 18 +- .../test_tir_analysis_verify_gpu_code.py | 22 +- .../test_tir_analysis_verify_memory.py | 43 +- tests/python/unittest/test_tir_buffer.py | 10 +- tests/python/unittest/test_tir_ir_builder.py | 6 +- .../unittest/test_tir_transform_hoist_if.py | 4 +- ...tir_transform_instrument_bound_checkers.py | 19 +- .../test_tir_transform_lower_intrin.py | 6 +- .../test_tir_transform_lower_warp_memory.py | 21 +- .../test_tir_transform_thread_sync.py | 2 + tests/scripts/setup-pytest-env.sh | 4 +- tests/scripts/task_python_frontend.sh | 2 + tests/scripts/task_python_frontend_cpu.sh | 2 + tests/scripts/task_python_integration.sh | 2 +- .../task_python_integration_gpuonly.sh | 4 + tests/scripts/task_python_unittest_gpuonly.sh | 3 + tutorials/frontend/deploy_ssd_gluoncv.py | 10 +- 154 files changed, 2214 insertions(+), 1598 deletions(-) create mode 100644 conftest.py rename python/tvm/relay/testing/config.py => pytest.ini (62%) diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index f7e17d2fdc62..defac94528a3 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -17,6 +17,7 @@ import tvm_ext import tvm import tvm._ffi.registry +import tvm.testing from tvm import te import numpy as np @@ -32,7 +33,7 @@ def test_ext_dev(): B = te.compute((n,), lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) def check_llvm(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return f = tvm.build(s, [A, B], "ext_dev", "llvm") ctx = tvm.ext_dev(0) @@ -77,7 +78,7 @@ def test_extern_call(): s = te.create_schedule(B.op) def check_llvm(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return f = tvm.build(s, [A, B], "llvm") ctx = tvm.cpu(0) diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000000..edf1a73f0b43 --- /dev/null +++ b/conftest.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm.testing +from pytest import ExitCode + +def pytest_configure(config): + print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets()))) + print("pytest marker:", config.option.markexpr) + +def pytest_sessionfinish(session, exitstatus): + # Don't exit with an error if we select a subset of tests that doesn't + # include anything + if session.config.option.markexpr != '': + if exitstatus == ExitCode.NO_TESTS_COLLECTED: + session.exitstatus = ExitCode.OK diff --git a/docs/contribute/code_guide.rst b/docs/contribute/code_guide.rst index c932e93a11f1..c0b022bcf549 100644 --- a/docs/contribute/code_guide.rst +++ b/docs/contribute/code_guide.rst @@ -82,6 +82,20 @@ Python Code Styles - Stick to language features as in ``python 3.5`` +Writing Python Tests +-------------------- +We use `pytest `_ for all python testing. ``tests/python`` contains all the tests. + +If you want your test to run over a variety of targets, use the :py:func:`tvm.testing.parametrize_targets` decorator. For example: + +.. code:: python + + @tvm.testing.parametrize_targets + def test_mytest(target, ctx): + ... + +will run `test_mytest` with `target="llvm"`, `target="cuda"`, and few others. This also ensures that your test is run on the correct hardware by the CI. If you only want to test against a couple targets use `@tvm.testing.parametrize_targets("target_1", "target_2")`. If you want to test on a single target, use the associated decorator from :py:func:`tvm.testing`. For example, CUDA tests use the `@tvm.testing.requires_cuda` decorator. + Handle Integer Constant Expression ---------------------------------- We often need to handle constant integer expressions in TVM. Before we do so, the first question we want to ask is that is it really necessary to get a constant integer. If symbolic expression also works and let the logic flow, we should use symbolic expression as much as possible. So the generated code works for shapes that are not known ahead of time. diff --git a/python/tvm/relay/testing/config.py b/pytest.ini similarity index 62% rename from python/tvm/relay/testing/config.py rename to pytest.ini index 93a08db32d2c..675f8fe9b5a0 100644 --- a/python/tvm/relay/testing/config.py +++ b/pytest.ini @@ -14,18 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Configuration about tests""" -from __future__ import absolute_import as _abs - -import os -import tvm - - -def ctx_list(): - """Get context list for testcases""" - device_list = os.environ.get("RELAY_TEST_TARGETS", "") - device_list = (device_list.split(",") if device_list - else ["llvm", "cuda"]) - device_list = set(device_list) - res = [(device, tvm.context(device, 0)) for device in device_list] - return [x for x in res if x[1].exist] +[pytest] +markers = + gpu: mark a test as requiring a gpu + tensorcore: mark a test as requiring a tensorcore + cuda: mark a test as requiring cuda + opencl: mark a test as requiring opencl + rocm: mark a test as requiring rocm + vulkan: mark a test as requiring vulkan + metal: mark a test as requiring metal + llvm: mark a test as requiring llvm diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 904e4d7baf28..534015fdf888 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -25,6 +25,7 @@ import tvm.relay as relay import tvm.relay.op as op from tvm.relay import Prelude +from tvm.testing import enabled_targets from . import mlp from . import resnet @@ -41,7 +42,6 @@ from . import temp_op_attr from . import synthetic -from .config import ctx_list from .init import create_workload from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr from .py_converter import to_python, run_as_python @@ -125,7 +125,7 @@ def check_grad(func, if test_inputs is None: test_inputs = inputs - for target, ctx in ctx_list(): + for target, ctx in enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) # Get analytic gradients. diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 7483a9fb4cf8..0a568b02fc9d 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -16,14 +16,54 @@ # under the License. # pylint: disable=invalid-name,unnecessary-comprehension -""" TVM testing utilities """ +""" TVM testing utilities + +Testing Markers +*************** + +We use pytest markers to specify the requirements of test functions. Currently +there is a single distinction that matters for our testing environment: does +the test require a gpu. For tests that require just a gpu or just a cpu, we +have the decorator :py:func:`requires_gpu` that enables the test when a gpu is +available. To avoid running tests that don't require a gpu on gpu nodes, this +decorator also sets the pytest marker `gpu` so we can use select the gpu subset +of tests (using `pytest -m gpu`). + +Unfortunately, many tests are written like this: + +.. python:: + + def test_something(): + for target in all_targets(): + do_something() + +The test uses both gpu and cpu targets, so the test needs to be run on both cpu +and gpu nodes. But we still want to only run the cpu targets on the cpu testing +node. The solution is to mark these tests with the gpu marker so they will be +run on the gpu nodes. But we also modify all_targets (renamed to +enabled_targets) so that it only returns gpu targets on gpu nodes and cpu +targets on cpu nodes (using an environment variable). + +Instead of using the all_targets function, future tests that would like to +test against a variety of targets should use the +:py:func:`tvm.testing.parametrize_targets` functionality. This allows us +greater control over which targets are run on which testing nodes. + +If in the future we want to add a new type of testing node (for example +fpgas), we need to add a new marker in `tests/python/pytest.ini` and a new +function in this module. Then targets using this node should be added to the +`TVM_TEST_TARGETS` environment variable in the CI. +""" import logging +import os +import pytest import numpy as np import tvm import tvm.arith import tvm.tir import tvm.te import tvm._ffi +from tvm.contrib import nvcc def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): @@ -285,4 +325,345 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): constraints_trans.dst_to_src, constraints_trans.src_to_dst) +def _get_targets(): + target_str = os.environ.get("TVM_TEST_TARGETS", "") + if len(target_str) == 0: + target_str = DEFAULT_TEST_TARGETS + targets = { + dev + for dev in target_str.split(";") + if len(dev) > 0 and tvm.context(dev, 0).exist and tvm.runtime.enabled(dev) + } + if len(targets) == 0: + logging.warning( + "None of the following targets are supported by this build of TVM: %s." + " Try setting TVM_TEST_TARGETS to a supported target. Defaulting to llvm.", + target_str, + ) + return {"llvm"} + return targets + + +DEFAULT_TEST_TARGETS = ( + "llvm;cuda;opencl;metal;rocm;vulkan;nvptx;" + "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu" +) + + +def device_enabled(target): + """Check if a target should be used when testing. + + It is recommended that you use :py:func:`tvm.testing.parametrize_targets` + instead of manually checking if a target is enabled. + + This allows the user to control which devices they are testing against. In + tests, this should be used to check if a device should be used when said + device is an optional part of the test. + + Parameters + ---------- + target : str + Target string to check against + + Returns + ------- + bool + Whether or not the device associated with this target is enabled. + + Example + ------- + >>> @tvm.testing.uses_gpu + >>> def test_mytest(): + >>> for target in ["cuda", "llvm"]: + >>> if device_enabled(target): + >>> test_body... + + Here, `test_body` will only be reached by with `target="cuda"` on gpu test + nodes and `target="llvm"` on cpu test nodes. + """ + assert isinstance(target, str), "device_enabled requires a target as a string" + target_kind = target.split(" ")[ + 0 + ] # only check if device name is found, sometime there are extra flags + return any([target_kind in test_target for test_target in _get_targets()]) + + +def enabled_targets(): + """Get all enabled targets with associated contexts. + + In most cases, you should use :py:func:`tvm.testing.parametrize_targets` instead of + this function. + + In this context, enabled means that TVM was built with support for this + target and the target name appears in the TVM_TEST_TARGETS environment + variable. If TVM_TEST_TARGETS is not set, it defaults to variable + DEFAULT_TEST_TARGETS in this module. + + If you use this function in a test, you **must** decorate the test with + :py:func:`tvm.testing.uses_gpu` (otherwise it will never be run on the gpu). + + Returns + ------- + targets: list + A list of pairs of all enabled devices and the associated context + """ + return [(tgt, tvm.context(tgt)) for tgt in _get_targets()] + + +def _compose(args, decs): + """Helper to apply multiple markers + """ + if len(args) > 0: + f = args[0] + for d in reversed(decs): + f = d(f) + return f + return decs + + +def uses_gpu(*args): + """Mark to differentiate tests that use the GPU is some capacity. + + These tests will be run on CPU-only test nodes and on test nodes with GPUS. + To mark a test that must have a GPU present to run, use + :py:func:`tvm.testing.requires_gpu`. + + Parameters + ---------- + f : function + Function to mark + """ + _uses_gpu = [pytest.mark.gpu] + return _compose(args, _uses_gpu) + + +def requires_gpu(*args): + """Mark a test as requiring a GPU to run. + + Tests with this mark will not be run unless a gpu is present. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_gpu = [ + pytest.mark.skipif(not tvm.gpu().exist, reason="No GPU present"), + *uses_gpu(), + ] + return _compose(args, _requires_gpu) + + + + +def requires_cuda(*args): + """Mark a test as requiring the CUDA runtime. + + This also marks the test as requiring a gpu. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_cuda = [ + pytest.mark.cuda, + pytest.mark.skipif( + not device_enabled("cuda"), reason="CUDA support not enabled" + ), + *requires_gpu(), + ] + return _compose(args, _requires_cuda) + + + + +def requires_opencl(*args): + """Mark a test as requiring the OpenCL runtime. + + This also marks the test as requiring a gpu. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_opencl = [ + pytest.mark.opencl, + pytest.mark.skipif( + not device_enabled("opencl"), reason="OpenCL support not enabled" + ), + *requires_gpu(), + ] + return _compose(args, _requires_opencl) + + + + +def requires_rocm(*args): + """Mark a test as requiring the rocm runtime. + + This also marks the test as requiring a gpu. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_rocm = [ + pytest.mark.rocm, + pytest.mark.skipif( + not device_enabled("rocm"), reason="rocm support not enabled" + ), + *requires_gpu(), + ] + return _compose(args, _requires_rocm) + + + + +def requires_metal(*args): + """Mark a test as requiring the metal runtime. + + This also marks the test as requiring a gpu. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_metal = [ + pytest.mark.metal, + pytest.mark.skipif( + not device_enabled("metal"), reason="metal support not enabled" + ), + *requires_gpu(), + ] + return _compose(args, _requires_metal) + + + + +def requires_vulkan(*args): + """Mark a test as requiring the vulkan runtime. + + This also marks the test as requiring a gpu. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_vulkan = [ + pytest.mark.vulkan, + pytest.mark.skipif( + not device_enabled("vulkan"), reason="vulkan support not enabled" + ), + *requires_gpu(), + ] + return _compose(args, _requires_vulkan) + + + + +def requires_tensorcore(*args): + """Mark a test as requiring a tensorcore to run. + + Tests with this mark will not be run unless a tensorcore is present. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_tensorcore = [ + pytest.mark.tensorcore, + pytest.mark.skipif( + not tvm.gpu().exist or not nvcc.have_tensorcore(tvm.gpu(0).compute_version), + reason="No tensorcore present", + ), + *requires_gpu(), + ] + return _compose(args, _requires_tensorcore) + + + + +def requires_llvm(*args): + """Mark a test as requiring llvm to run. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_llvm = [ + pytest.mark.llvm, + pytest.mark.skipif( + not device_enabled("llvm"), reason="LLVM support not enabled" + ), + ] + return _compose(args, _requires_llvm) + + +def _target_to_requirement(target): + # mapping from target to decorator + if target.startswith("cuda"): + return requires_cuda() + if target.startswith("rocm"): + return requires_rocm() + if target.startswith("vulkan"): + return requires_vulkan() + if target.startswith("nvptx"): + return [*requires_llvm(), *requires_gpu()] + if target.startswith("metal"): + return requires_metal() + if target.startswith("opencl"): + return requires_opencl() + if target.startswith("llvm"): + return requires_llvm() + return [] + + +def parametrize_targets(*args): + """Parametrize a test over all enabled targets. + + Use this decorator when you want your test to be run over a variety of + targets and devices (including cpu and gpu devices). + + Parameters + ---------- + f : function + Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, ctx)`:, + where `xxxxxxxxx` is any name. + targets : list[str], optional + Set of targets to run against. If not supplied, + :py:func:`tvm.testing.enabled_targets` will be used. + + Example + ------- + >>> @tvm.testing.parametrize + >>> def test_mytest(target, ctx): + >>> ... # do something + + Or + + >>> @tvm.testing.parametrize("llvm", "cuda") + >>> def test_mytest(target, ctx): + >>> ... # do something + """ + def wrap(targets): + def func(f): + params = [ + pytest.param(target, tvm.context(target, 0), marks=_target_to_requirement(target)) + for target in targets + ] + return pytest.mark.parametrize("target,ctx", params)(f) + return func + if len(args) == 1 and callable(args[0]): + targets = [t for t, _ in enabled_targets()] + return wrap(targets)(args[0]) + return wrap(args) + + tvm._ffi._init_api("testing", __name__) diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index f803647d91a1..9c0a607002a3 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -122,6 +122,8 @@ "docs/_static/css/tvm_theme.css", "docs/_static/img/tvm-logo-small.png", "docs/_static/img/tvm-logo-square.png", + # pytest config + "pytest.ini", } diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index e1c1c7125536..7247ab7c6beb 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -22,6 +22,7 @@ from tvm.contrib import cblas from tvm.contrib import mkl from tvm.contrib import mkldnn +import tvm.testing def verify_matmul_add(m, l, n, lib, transa=False, transb=False, dtype="float32"): bias = te.var('bias', dtype=dtype) @@ -41,7 +42,7 @@ def get_numpy(a, b, bb, transa, transb): return np.dot(a, b) + bb def verify(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return if not tvm.get_global_func(lib.__name__ + ".matmul", True): @@ -107,7 +108,7 @@ def get_numpy(a, b, bb, transa, transb): return np.dot(a, b) + bb def verify(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return if not tvm.get_global_func("tvm.contrib.mkl.matmul_u8s8s32", True): @@ -153,7 +154,7 @@ def get_numpy(a, b, transa, transb): return tvm.topi.testing.batch_matmul(a, b) def verify(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return if not tvm.get_global_func(lib.__name__ + ".matmul", True): diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index 517e6e124030..f387f35925b3 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -19,6 +19,7 @@ import numpy as np from tvm.contrib import cublas from tvm.contrib import cublaslt +import tvm.testing def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): n = 1024 @@ -30,9 +31,6 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): s = te.create_schedule(C.op) def verify(target="cuda"): - if not tvm.runtime.enabled(target): - print("skip because %s is not enabled..." % target) - return if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): print("skip because extern function is not available") return @@ -64,9 +62,6 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5): s = te.create_schedule(C.op) def verify(target="cuda"): - if not tvm.runtime.enabled(target): - print("skip because %s is not enabled..." % target) - return if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True): print("skip because extern function is not available") return @@ -115,9 +110,6 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5): s = te.create_schedule(C.op) def verify(target="cuda"): - if not tvm.runtime.enabled(target): - print("skip because %s is not enabled..." % target) - return if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): print("skip because extern function is not available") return @@ -132,15 +124,18 @@ def verify(target="cuda"): b.asnumpy().astype(C.dtype)).astype(C.dtype), rtol=rtol) verify() +@tvm.testing.requires_cuda def test_matmul_add(): verify_matmul_add('float', 'float', rtol=1e-3) verify_matmul_add('float16', 'float') verify_matmul_add('float16', 'float16', rtol=1e-2) verify_matmul_add('int8', 'int32') +@tvm.testing.requires_cuda def test_matmul_add_igemm(): verify_matmul_add_igemm('int8', 'int32') +@tvm.testing.requires_cuda def test_batch_matmul(): verify_batch_matmul('float', 'float') verify_batch_matmul('float16', 'float') diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 61822c849a7e..5777c3b73b9d 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -20,6 +20,7 @@ from tvm.contrib.nvcc import have_fp16 import numpy as np import tvm.topi.testing +import tvm.testing def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 @@ -36,9 +37,6 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): height = 32 width = 32 - if not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled...") - return if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): print("skip because cudnn is not enabled...") return @@ -87,6 +85,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): f(x, w, y) tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-2, rtol=1e-2) +@tvm.testing.requires_gpu def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) @@ -118,9 +117,6 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): height = 32 width = 32 - if not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled...") - return if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): print("skip because cudnn is not enabled...") return @@ -161,6 +157,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): f(x, w, y) tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4) +@tvm.testing.requires_gpu def test_conv3d(): verify_conv3d("float32", "float32", tensor_format=0) verify_conv3d("float32", "float32", tensor_format=0, groups=2) @@ -195,10 +192,8 @@ def verify_softmax_4d(shape, dtype="float32"): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3) +@tvm.testing.requires_gpu def test_softmax(): - if not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled...") - return if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): print("skip because cudnn is not enabled...") return diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py b/tests/python/contrib/test_gemm_acc32_vnni.py index 37101a80ea77..538004024d75 100644 --- a/tests/python/contrib/test_gemm_acc32_vnni.py +++ b/tests/python/contrib/test_gemm_acc32_vnni.py @@ -17,6 +17,7 @@ # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition import tvm +import tvm.testing from tvm import te import numpy as np from tvm.topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake @@ -24,6 +25,7 @@ import pytest +@tvm.testing.requires_llvm @pytest.mark.skip("skip because feature not enabled") def test_fc_int8_acc32(): m = 1024 @@ -42,7 +44,7 @@ def test_fc_int8_acc32(): # (ignoring processor)" error with the following setting. After LLVM 8.0 is enabled in the # test, we should use cascadelake setting. def verify(target="llvm -mcpu=cascadelake"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py index deffbe9f4980..e8d348e0e365 100644 --- a/tests/python/contrib/test_miopen.py +++ b/tests/python/contrib/test_miopen.py @@ -20,6 +20,7 @@ import numpy as np +@tvm.testing.requires_rocm def test_conv2d(): in_channel = 3 out_channel = 64 @@ -33,9 +34,6 @@ def test_conv2d(): dilation_w = 1 xshape = [1, in_channel, 128, 128] - if not tvm.runtime.enabled("rocm"): - print("skip because rocm is not enabled...") - return if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True): print("skip because miopen is not enabled...") return diff --git a/tests/python/contrib/test_mps.py b/tests/python/contrib/test_mps.py index b5243659c1d5..1f0906e73602 100644 --- a/tests/python/contrib/test_mps.py +++ b/tests/python/contrib/test_mps.py @@ -19,10 +19,8 @@ import numpy as np from tvm.contrib import mps +@tvm.testing.requires_metal def test_matmul(): - if not tvm.runtime.enabled("metal"): - print("skip because %s is not enabled..." % "metal") - return n = 1024 l = 128 m = 256 @@ -62,10 +60,8 @@ def verify(A, B, D, s, target="metal"): c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5) verify(A, B, D, s) +@tvm.testing.requires_metal def test_conv2d(): - if not tvm.runtime.enabled("metal"): - print("skip because %s is not enabled..." % "metal") - return n = 1 h = 14 w = 14 diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py index 81fcb123ebc1..bbee2b65fab1 100644 --- a/tests/python/contrib/test_nnpack.py +++ b/tests/python/contrib/test_nnpack.py @@ -23,6 +23,7 @@ import pytest +@tvm.testing.requires_llvm def test_fully_connected_inference(): n = 1024 l = 128 @@ -35,8 +36,6 @@ def test_fully_connected_inference(): s = te.create_schedule(D.op) def verify(target="llvm"): - if not tvm.runtime.enabled(target): - pytest.skip("%s is not enabled..." % target) if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): pytest.skip("extern function is not available") if not nnpack.is_available(): @@ -82,6 +81,7 @@ def np_conv(na, nw, padding, stride=1): nb[n, f] += out[::stride, ::stride] return nb +@tvm.testing.requires_llvm def test_convolution_inference(): BATCH = 8 IH = 48 @@ -105,8 +105,6 @@ def test_convolution_inference(): def verify(target="llvm", algorithm=nnpack.ConvolutionAlgorithm.AUTO, with_bias=True): - if not tvm.runtime.enabled(target): - pytest.skip("%s is not enabled..." % target) if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): pytest.skip("extern function is not available") if not nnpack.is_available(): @@ -144,6 +142,7 @@ def verify(target="llvm", verify(algorithm=algorithm, with_bias=with_bias) +@tvm.testing.requires_llvm def test_convolution_inference_without_weight_transform(): BATCH = 6 IH = 48 @@ -167,8 +166,6 @@ def test_convolution_inference_without_weight_transform(): def verify(target="llvm", algorithm=nnpack.ConvolutionAlgorithm.AUTO, with_bias=True): - if not tvm.runtime.enabled(target): - pytest.skip("%s is not enabled..." % target) if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): pytest.skip("extern function is not available") if not nnpack.is_available(): diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index e61030bd9835..c3601c7d6101 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -19,21 +19,7 @@ import numpy as np from tvm.contrib import random from tvm import rpc - -def enabled_ctx_list(): - ctx_list = [('cpu', tvm.cpu(0)), - ('gpu', tvm.gpu(0)), - ('cl', tvm.opencl(0)), - ('metal', tvm.metal(0)), - ('rocm', tvm.rocm(0)), - ('vulkan', tvm.vulkan(0)), - ('vpi', tvm.vpi(0))] - for k, v in ctx_list: - assert tvm.context(k, 0) == v - ctx_list = [x[1] for x in ctx_list if x[1].exist] - return ctx_list - -ENABLED_CTX_LIST = enabled_ctx_list() +import tvm.testing def test_randint(): m = 10240 @@ -42,7 +28,7 @@ def test_randint(): s = te.create_schedule(A.op) def verify(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return if not tvm.get_global_func("tvm.contrib.random.randint", True): @@ -66,7 +52,7 @@ def test_uniform(): s = te.create_schedule(A.op) def verify(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return if not tvm.get_global_func("tvm.contrib.random.uniform", True): @@ -90,7 +76,7 @@ def test_normal(): s = te.create_schedule(A.op) def verify(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return if not tvm.get_global_func("tvm.contrib.random.normal", True): @@ -105,6 +91,7 @@ def verify(target="llvm"): assert abs(np.std(na) - 4) < 1e-2 verify() +@tvm.testing.uses_gpu def test_random_fill(): def test_local(ctx, dtype): if not tvm.get_global_func("tvm.contrib.random.random_fill", True): @@ -125,7 +112,7 @@ def test_rpc(dtype): if not tvm.get_global_func("tvm.contrib.random.random_fill", True): print("skip because extern function is not available") return - if not tvm.runtime.enabled("rpc") or not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"): return np_ones = np.ones((512, 512), dtype=dtype) server = rpc.Server("localhost") @@ -142,7 +129,7 @@ def test_rpc(dtype): for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "int32", "int64", "uint64", "float16", "float32", "float64"]: - for ctx in ENABLED_CTX_LIST: + for _, ctx in tvm.testing.enabled_targets(): test_local(ctx, dtype) test_rpc(dtype) @@ -151,3 +138,4 @@ def test_rpc(dtype): test_uniform() test_normal() test_random_fill() + diff --git a/tests/python/contrib/test_rocblas.py b/tests/python/contrib/test_rocblas.py index af9d6ddf8dc9..f5ec5be165db 100644 --- a/tests/python/contrib/test_rocblas.py +++ b/tests/python/contrib/test_rocblas.py @@ -19,6 +19,7 @@ import numpy as np from tvm.contrib import rocblas +@tvm.testing.requires_rocm def test_matmul_add(): n = 1024 l = 128 @@ -29,9 +30,6 @@ def test_matmul_add(): s = te.create_schedule(C.op) def verify(target="rocm"): - if not tvm.runtime.enabled(target): - print("skip because %s is not enabled..." % target) - return if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True): print("skip because extern function is not available") return diff --git a/tests/python/frontend/caffe2/test_forward.py b/tests/python/frontend/caffe2/test_forward.py index 50a878180ac9..84d03d905060 100644 --- a/tests/python/frontend/caffe2/test_forward.py +++ b/tests/python/frontend/caffe2/test_forward.py @@ -18,12 +18,12 @@ import tvm from tvm import te from tvm.contrib import graph_runtime -from tvm.relay.testing.config import ctx_list from tvm import relay from model_zoo import c2_squeezenet, c2_resnet50, c2_vgg19 from caffe2.python import workspace, core from caffe2.proto import caffe2_pb2 from collections import namedtuple +import tvm.testing def get_tvm_output(model, @@ -84,19 +84,22 @@ def verify_caffe2_forward_impl(model, data_shape, out_shape): dtype = 'float32' data = np.random.uniform(size=data_shape).astype(dtype) c2_out = get_caffe2_output(model, data, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, data, target, ctx, out_shape, dtype) tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_squeezenet1_1(): verify_caffe2_forward_impl(c2_squeezenet, (1, 3, 224, 224), (1, 1000, 1, 1)) +@tvm.testing.uses_gpu def test_forward_resnet50(): verify_caffe2_forward_impl(c2_resnet50, (1, 3, 224, 224), (1, 1000)) +@tvm.testing.uses_gpu def test_forward_vgg19(): verify_caffe2_forward_impl(c2_vgg19, (1, 3, 224, 224), (1, 1000)) @@ -104,6 +107,7 @@ def test_forward_vgg19(): Model = namedtuple('Model', ['init_net', 'predict_net']) +@tvm.testing.uses_gpu def test_elementwise_add(): data_shape = (1, 16, 9, 9) init_net = caffe2_pb2.NetDef() @@ -142,6 +146,7 @@ def test_elementwise_add(): verify_caffe2_forward_impl(model, data_shape, data_shape) +@tvm.testing.uses_gpu def test_elementwise_add_with_broadcast(): data_shape = (1, 16, 9, 9) init_net = caffe2_pb2.NetDef() @@ -181,6 +186,7 @@ def test_elementwise_add_with_broadcast(): verify_caffe2_forward_impl(model, data_shape, data_shape) +@tvm.testing.uses_gpu def test_normalize_yuv(): data_shape = (1, 3, 96, 96) init_net = caffe2_pb2.NetDef() diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 5ae7a6cc875f..d3a31fe6fa16 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -25,11 +25,11 @@ from tvm import topi import tvm.topi.testing from tvm import relay -from tvm.relay.testing.config import ctx_list from tvm.topi.testing import conv2d_nchw_python import coremltools as cm import model_zoo +import tvm.testing def get_tvm_output(func, x, params, target, ctx, out_shape=(1, 1000), input_name='image', dtype='float32'): @@ -50,15 +50,17 @@ def run_model_checkonly(model_file, model_name='', input_name='image'): shape_dict = {input_name : x.shape} # Some Relay passes change operators on the fly. Ensuring that we generate # new graph for each target. - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): mod, params = relay.frontend.from_coreml(model, shape_dict) tvm_output = get_tvm_output(mod["main"], x, params, target, ctx) print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat)) +@tvm.testing.uses_gpu def test_mobilenet_checkonly(): model_file = model_zoo.get_mobilenet() run_model_checkonly(model_file, 'mobilenet') +@tvm.testing.uses_gpu def test_resnet50_checkonly(): model_file = model_zoo.get_resnet50() run_model_checkonly(model_file, 'resnet50') @@ -122,10 +124,11 @@ def verify_AddLayerParams(input_dim, alpha=2): output_name='output', mode='ADD') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_AddLayerParams(): verify_AddLayerParams((1, 2, 2), 0) verify_AddLayerParams((1, 2, 2), 1) @@ -148,10 +151,11 @@ def verify_MultiplyLayerParams(input_dim, alpha): output_name='output', mode='MULTIPLY') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_MultiplyLayerParams(): verify_MultiplyLayerParams((1, 2, 2), 0) verify_MultiplyLayerParams((1, 2, 2), 1) @@ -173,10 +177,11 @@ def verify_ConcatLayerParams(input1_dim, input2_dim): output_name='output', mode='CONCAT') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_ConcatLayerParams(): verify_ConcatLayerParams((1, 1, 2, 2), (1, 2, 2, 2)) verify_ConcatLayerParams((1, 2, 4, 4), (1, 3, 4, 4)) @@ -203,10 +208,11 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): output_name='output') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, a_np, 'input', b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_UpsampleLayerParams(): verify_UpsampleLayerParams((1, 16, 32, 32), 2, 'NN') verify_UpsampleLayerParams((1, 4, 6, 6), 3, 'BILINEAR') @@ -223,10 +229,11 @@ def verify_l2_normalize(input_dim, eps): builder.add_l2_normalize(name='L2', epsilon=eps, input_name='input', output_name='output') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, a_np, 'input', b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_l2_normalize(): verify_l2_normalize((1, 3, 20, 20), 0.001) @@ -248,10 +255,11 @@ def verify_lrn(input_dim, size, bias, alpha, beta): local_size=size) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, a_np, 'input', b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_lrn(): verify_lrn((1, 3, 10, 20), 3, 1.0, 1.0, 0.5) @@ -272,10 +280,11 @@ def verify_average(input_dim1, input_dim2, axis=0): output_name='output', mode='AVE') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np1, a_np2], ['input1', 'input2'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_average(): verify_average((1, 3, 20, 20), (1, 3, 20, 20)) verify_average((3, 20, 20), (1, 3, 20, 20)) @@ -300,11 +309,12 @@ def verify_max(input_dim): output_name='output', mode='MAX') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np1, a_np2, a_np3], ['input1', 'input2', 'input3'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_max(): verify_max((1, 3, 20, 20)) verify_max((20, 20)) @@ -328,11 +338,12 @@ def verify_min(input_dim): output_name='output', mode='MIN') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np1, a_np2, a_np3], ['input1', 'input2', 'input3'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_min(): verify_min((1, 3, 20, 20)) verify_min((20, 20)) @@ -353,7 +364,7 @@ def verify_unary_sqrt(input_dim): mode='sqrt') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -375,7 +386,7 @@ def verify_unary_rsqrt(input_dim, epsilon=0): epsilon=epsilon) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -397,7 +408,7 @@ def verify_unary_inverse(input_dim, epsilon=0): epsilon=epsilon) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -419,7 +430,7 @@ def verify_unary_power(input_dim, alpha): alpha=alpha) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -440,7 +451,7 @@ def verify_unary_exp(input_dim): mode='exp') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -461,7 +472,7 @@ def verify_unary_log(input_dim): mode='log') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -482,7 +493,7 @@ def verify_unary_abs(input_dim): mode='abs') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -504,12 +515,13 @@ def verify_unary_threshold(input_dim, alpha): alpha=alpha) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_unary(): verify_unary_sqrt((1, 3, 20, 20)) verify_unary_rsqrt((1, 3, 20, 20)) @@ -525,6 +537,7 @@ def test_forward_unary(): verify_unary_threshold((1, 3, 20, 20), alpha=5.0) +@tvm.testing.uses_gpu def test_forward_reduce(): from enum import Enum class ReduceAxis(Enum): @@ -565,7 +578,7 @@ def _verify_reduce(input_dim, mode, axis, ref_func, dtype='float32'): mode=mode) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5, atol=1e-5) @@ -602,7 +615,7 @@ def verify_reshape(input_dim, target_shape, mode): mode=mode) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], ref_val.shape, dtype) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -637,7 +650,7 @@ def verify_split(input_dim, nOutputs): output_names=output_names) model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input'], output_shapes, [dtype] * len(output_shapes)) tvm.testing.assert_allclose(out, ref_val, rtol=1e-5) @@ -673,11 +686,12 @@ def verify_image_scaler(input_dim, blue_bias=0.0, green_bias=0.0, red_bias=0.0, builder.add_elementwise(name='add', input_names=['input1', 'input2'], output_name='output', alpha=0, mode='ADD') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np, a_np], ['input1', 'input2'], b_np.shape, dtype) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_image_scaler(): verify_image_scaler((3, 224, 224), image_scale=0.17) verify_image_scaler((3, 224, 224), @@ -705,11 +719,12 @@ def verify_convolution(input_dim, filter, padding): input_name='input1', output_name='output') model = cm.models.MLModel(builder.spec) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): out = run_tvm_graph(model, target, ctx, [a_np], ['input1'], output_shape=None) tvm.testing.assert_allclose(out, b_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_forward_convolution(): verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding='VALID') verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding='SAME') diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index f9402554d53c..94822303c3b4 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -19,8 +19,8 @@ from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.testing.config import ctx_list import keras +import tvm.testing try: import tensorflow.compat.v1 as tf @@ -104,7 +104,7 @@ def to_channels_last(arr): xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] keras_out = get_keras_output(xs) keras_out = keras_out if isinstance(keras_out, list) else [keras_out] - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): inputs = [to_channels_first(x) for x in xs] if need_transpose else xs tvm_out = get_tvm_output(inputs, target, ctx) for kout, tout in zip(keras_out, tvm_out): @@ -113,6 +113,7 @@ def to_channels_last(arr): tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu class TestKeras: scenarios = [using_classic_keras, using_tensorflow_keras] diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 594ffe72faf0..bc5cbebeba1d 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -20,15 +20,16 @@ import tvm from tvm import te from tvm.contrib import graph_runtime -from tvm.relay.testing.config import ctx_list from tvm import relay import mxnet as mx from mxnet import gluon from mxnet.gluon.model_zoo import vision -import model_zoo import random import pytest +import model_zoo + +import tvm.testing def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), @@ -82,32 +83,36 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): x = np.random.uniform(size=data_shape) if gluon_impl: gluon_out, gluon_sym = get_gluon_output(name, x) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype) tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5) else: mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype) assert "data" not in args - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype) tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_mlp(): mlp = model_zoo.mx_mlp() verify_mxnet_frontend_impl(mlp, data_shape=(1, 1, 28, 28), out_shape=(1, 10)) +@tvm.testing.uses_gpu def test_forward_vgg(): for n in [11]: mx_sym = model_zoo.mx_vgg(n) verify_mxnet_frontend_impl(mx_sym) +@tvm.testing.uses_gpu def test_forward_resnet(): for n in [18]: mx_sym = model_zoo.mx_resnet(18) verify_mxnet_frontend_impl(mx_sym) +@tvm.testing.uses_gpu def test_forward_leaky_relu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly @@ -116,36 +121,42 @@ def test_forward_leaky_relu(): mx_sym = mx.sym.LeakyReLU(data, act_type='leaky') verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_elu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.LeakyReLU(data, act_type='elu') verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_rrelu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7) verify_mxnet_frontend_impl(mx_sym[0], (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_prelu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.LeakyReLU(data, act_type='prelu') verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_gelu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.LeakyReLU(data, act_type='gelu') verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_softrelu(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.Activation(data, act_type='softrelu') verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_fc_flatten(): # test flatten=True option in mxnet 0.11.1 data = mx.sym.var('data') @@ -157,27 +168,32 @@ def test_forward_fc_flatten(): except: pass +@tvm.testing.uses_gpu def test_forward_clip(): data = mx.sym.var('data') data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.clip(data, a_min=0, a_max=1) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +@tvm.testing.uses_gpu def test_forward_split(): data = mx.sym.var('data') mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False) verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1)) +@tvm.testing.uses_gpu def test_forward_split_squeeze(): data = mx.sym.var('data') mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True) verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1)) +@tvm.testing.uses_gpu def test_forward_expand_dims(): data = mx.sym.var('data') mx_sym = mx.sym.expand_dims(data, axis=1) verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4)) +@tvm.testing.uses_gpu def test_forward_pooling(): data = mx.sym.var('data') mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg') @@ -186,6 +202,7 @@ def test_forward_pooling(): mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) +@tvm.testing.uses_gpu def test_forward_pooling3d(): data = mx.sym.var('data') mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg') @@ -194,6 +211,7 @@ def test_forward_pooling3d(): mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max') verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) +@tvm.testing.uses_gpu def test_forward_adaptive_pooling(): data = mx.sym.var('data') mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,)) @@ -202,49 +220,58 @@ def test_forward_adaptive_pooling(): mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(3, 3)) verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 3, 3)) +@tvm.testing.uses_gpu def test_forward_lrn(): data = mx.sym.var('data') mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5) verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24)) +@tvm.testing.uses_gpu def test_forward_ones(): data = mx.sym.var('data') ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32') mx_sym = mx.sym.elemwise_add(data, ones) verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) +@tvm.testing.uses_gpu def test_forward_zeros(): data = mx.sym.var('data') zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32') mx_sym = mx.sym.elemwise_add(data, zeros) verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) +@tvm.testing.uses_gpu def test_forward_ones_like(): data = mx.sym.var('data') mx_sym = mx.sym.ones_like(data, dtype='float32') verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) +@tvm.testing.uses_gpu def test_forward_make_loss(): data = mx.sym.var('data') ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32') mx_sym = mx.sym.make_loss((data-ones)**2/2, dtype='float32') verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) +@tvm.testing.uses_gpu def test_forward_zeros_like(): data = mx.sym.var('data') mx_sym = mx.sym.zeros_like(data, dtype='float32') verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) +@tvm.testing.uses_gpu def test_forward_argmax(): data = mx.sym.var('data') mx_sym = mx.sym.argmax(data, axis=1) verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,)) +@tvm.testing.uses_gpu def test_forward_argmin(): data = mx.sym.var('data') mx_sym = mx.sym.argmin(data, axis=0) verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,)) +@tvm.testing.uses_gpu def test_forward_slice(): data = mx.sym.var('data') mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4)) @@ -252,6 +279,7 @@ def test_forward_slice(): mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2)) verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2)) +@tvm.testing.uses_gpu def test_forward_where(): cond = mx.sym.var('cond') x = mx.sym.var('x') @@ -273,13 +301,14 @@ def test_forward_where(): mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy() mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(np_cond, np_x, np_y) tvm.testing.assert_allclose(op_res.asnumpy(), mx_out) +@tvm.testing.uses_gpu def test_forward_arange(): def _mx_symbol(F, start, stop, step): if start is None and step is None: @@ -296,7 +325,7 @@ def verify(start, stop, step): ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy() mx_sym = _mx_symbol(mx.sym, start, stop, step) mod, _ = relay.frontend.from_mxnet(mx_sym, {}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()() @@ -315,6 +344,7 @@ def _mx_symbol(F, op_name, inputs): op = getattr(F, op_name) return op(*inputs) +@tvm.testing.uses_gpu def test_forward_broadcast_ops(): for op in ["broadcast_add", "broadcast_plus", @@ -349,12 +379,13 @@ def test_forward_broadcast_ops(): ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) shapes = {'a': a_shape, 'b': b_shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np, b_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +@tvm.testing.uses_gpu def test_forward_elemwise_ops(): for op in ["elemwise_add", "elemwise_sub", "elemwise_mul", "elemwise_div", "maximum", "minimum", @@ -372,13 +403,14 @@ def test_forward_elemwise_ops(): ref_res = op(mx.nd.array(a_np), mx.nd.array(b_np)) shapes = {'a': shape, 'b': shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np, b_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +@tvm.testing.uses_gpu def test_forward_softmin(): data = mx.sym.var('data') mx_sym = mx.sym.softmin(data) @@ -388,6 +420,7 @@ def test_forward_softmin(): verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) +@tvm.testing.uses_gpu def test_forward_unary_ops(): for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc", "softsign", "hard_sigmoid", @@ -402,13 +435,14 @@ def test_forward_unary_ops(): ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np)]) shapes = {'a': shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_scalar_ops(): for op in [operator.add, operator.sub, operator.mul, operator.truediv, operator.pow, operator.lt, operator.le, operator.eq, @@ -421,7 +455,7 @@ def test_forward_scalar_ops(): ref_res = op(mx.nd.array(a_np), b_scalar) shapes = {'a': a_shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np) @@ -435,19 +469,20 @@ def test_forward_scalar_ops(): ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar]) shapes = {'a': a_shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +@tvm.testing.uses_gpu def test_forward_slice_axis(): def verify(shape, axis, begin, end): data_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end) mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data_np) @@ -458,6 +493,7 @@ def verify(shape, axis, begin, end): verify((3, 4), 1, -3, -1) verify((3, 4), -1, -3, -1) +@tvm.testing.uses_gpu def test_forward_slice_like(): def verify(x_shape, y_shape, axes): x_np = np.random.uniform(size=x_shape).astype("float32") @@ -469,7 +505,7 @@ def verify(x_shape, y_shape, axes): ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np), axes=axes) mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"), axes=axes) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np, y_np) @@ -479,6 +515,7 @@ def verify(x_shape, y_shape, axes): verify((3, 4), (2, 3), (0)) verify((3, 4), (2, 3), (-1)) +@tvm.testing.uses_gpu def test_forward_sequence_reverse(): def verify(shape, seq_lengths, use_seq_lengths, seq_axis): data_np = np.random.uniform(size=shape).astype("float32") @@ -500,7 +537,7 @@ def verify(shape, seq_lengths, use_seq_lengths, seq_axis): mx_sym = mx.sym.SequenceReverse(*mx_sym_args) mod, _ = relay.frontend.from_mxnet(mx_sym, *from_mxnet_args) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(*in_data) @@ -512,18 +549,20 @@ def verify(shape, seq_lengths, use_seq_lengths, seq_axis): # MXNet accepts axis value as 0 only # verify((3, 4, 5, 6), None, False, 2) +@tvm.testing.uses_gpu def test_forward_l2_normalize(): data = mx.sym.var('data') mx_sym = mx.sym.L2Normalization(data, mode="channel") verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) +@tvm.testing.uses_gpu def test_forward_shape_array(): def verify(shape): x_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.shape_array(mx.nd.array(x_np)) mx_sym = mx.sym.shape_array(mx.sym.var("x")) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -532,6 +571,7 @@ def verify(shape): verify((3, 4, 5)) verify((3, 4, 5, 6)) +@tvm.testing.uses_gpu def test_forward_squeeze(): def verify(shape, axis): x_np = np.random.uniform(size=shape).astype("float32") @@ -542,7 +582,7 @@ def verify(shape, axis): ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis) mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -552,6 +592,7 @@ def verify(shape, axis): verify((1, 3, 1), 2) verify((1, 3, 1), (0, 2)) +@tvm.testing.uses_gpu def test_forward_broadcast_axis(): def verify(shape, axis, size): x_np = np.random.uniform(size=shape).astype("float32") @@ -560,7 +601,7 @@ def verify(shape, axis, size): mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'),axis,size]) ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np),axis,size]) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -570,13 +611,14 @@ def verify(shape, axis, size): verify((1, 2, 1), (0, 2), (2, 3)) +@tvm.testing.uses_gpu def test_forward_broadcast_to(): def verify(input_shape, shape): x_np = np.random.uniform(size=input_shape).astype("float32") ref_res = mx.nd.broadcast_to(mx.nd.array(x_np), shape=shape) mx_sym = mx.sym.broadcast_to(mx.sym.var("x"), shape=shape) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -586,6 +628,7 @@ def verify(input_shape, shape): verify((4, 1, 32, 32), (4, 8, 32, 32)) +@tvm.testing.uses_gpu def test_forward_logical_not(): a_shape = (3, 4, 5) dtype = 'float32' @@ -594,20 +637,21 @@ def test_forward_logical_not(): ref_res = mx.nd.logical_not(mx.nd.array(a_np)) shapes = {'a': a_shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +@tvm.testing.uses_gpu def test_forward_full(): def verify(val, shape, dtype): ctx = mx.cpu() ref_res = mx.nd.full(shape, val, dtype=dtype) mx_sym = mx.sym.full(shape, val, dtype=dtype) mod, _ = relay.frontend.from_mxnet(mx_sym, {}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): # Skip testing graph runtime because this op will be optimized out # by constant folding. for kind in ["debug"]: @@ -618,6 +662,7 @@ def verify(val, shape, dtype): verify(2, (3, 4), "int32") verify(3.5, (1, 3, 4), "float32") +@tvm.testing.uses_gpu def test_forward_embedding(): def verify(data_shape, weight_shape): in_dim, out_dim = weight_shape @@ -629,7 +674,7 @@ def verify(data_shape, weight_shape): input_dim=in_dim, output_dim=out_dim) mod, _ = relay.frontend.from_mxnet( mx_sym, {"x": data_shape, "w": weight_shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x=x_np, w=w_np) @@ -637,6 +682,7 @@ def verify(data_shape, weight_shape): verify((2, 2), (4, 5)) verify((2, 3, 4), (4, 5)) +@tvm.testing.uses_gpu def test_forward_smooth_l1(): data = mx.sym.var('data') mx_sym = mx.sym.smooth_l1(data) @@ -644,6 +690,7 @@ def test_forward_smooth_l1(): mx_sym = mx.sym.smooth_l1(data, scalar=1.0) verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4)) +@tvm.testing.uses_gpu def test_forward_take(): def verify(shape, indices_src, axis, mode="clip"): x_np = np.random.uniform(size=shape).astype("float32") @@ -651,7 +698,7 @@ def verify(shape, indices_src, axis, mode="clip"): ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode) mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np, indices_np) @@ -664,13 +711,14 @@ def verify(shape, indices_src, axis, mode="clip"): verify((3,4), [-1, 5], 1) verify((3,4), [-1, 5], 1, mode="wrap") +@tvm.testing.uses_gpu def test_forward_gather_nd(): def verify(xshape, yshape, y_data, error=False): x_data = np.random.uniform(size=xshape).astype("float32") ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data)) mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data")) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_data, y_data) @@ -682,12 +730,14 @@ def verify(xshape, yshape, y_data, error=False): verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) verify((1, 4), (1, 1), [[0]]) +@tvm.testing.uses_gpu def test_forward_bilinear_resize(): # add tests including scale_height and scale_width when mxnet is updated to version 1.5 data = mx.sym.var('data') mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10) verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10)) +@tvm.testing.uses_gpu def test_forward_grid_generator(): def verify(shape, transform_type, target_shape): x = np.random.uniform(size=shape).astype("float32") @@ -695,7 +745,7 @@ def verify(shape, transform_type, target_shape): mx_sym = mx.sym.GridGenerator(mx.sym.var("x"), transform_type, target_shape) shape_dict = {"x": x.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor( kind, mod=mod, ctx=ctx, target=target) @@ -706,6 +756,7 @@ def verify(shape, transform_type, target_shape): verify((4, 2, 16, 16), 'warp', None) verify((1, 2, 16, 16), 'warp', None) +@tvm.testing.uses_gpu def test_forward_bilinear_sampler(): def verify(data_shape, grid_shape): data = np.random.uniform(size=data_shape).astype("float32") @@ -714,7 +765,7 @@ def verify(data_shape, grid_shape): mx_sym = mx.sym.BilinearSampler(mx.sym.var("data"), mx.sym.var("grid")) shape_dict = {"data": data.shape, "grid": grid.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor( kind, mod=mod, ctx=ctx, target=target) @@ -724,6 +775,7 @@ def verify(data_shape, grid_shape): verify((4, 4, 16, 32), (4, 2, 8, 8)) verify((4, 4, 16, 32), (4, 2, 32, 32)) +@tvm.testing.uses_gpu def test_forward_rnn_layer(): def verify(mode, seq_len, input_size, hidden_size, num_layers, batch=1, init_states=True, bidirectional=False): @@ -768,7 +820,7 @@ def verify(mode, seq_len, input_size, hidden_size, num_layers, mod, params = relay.frontend.from_mxnet( mx_sym, shape=shape_dict, arg_params=mx_params) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): # only test graph runtime because debug runtime is too slow for kind in ["graph"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) @@ -792,6 +844,7 @@ def verify(mode, seq_len, input_size, hidden_size, num_layers, # verify(mode, 10, 64, 64, 3, init_states=False) # verify(mode, 10, 64, 64, 3, batch=2, bidirectional=True, init_states=False) +@tvm.testing.uses_gpu def test_forward_Crop(): def verify(xshape, yshape, offset=None): x_data = np.random.uniform(size=xshape).astype("float32") @@ -803,7 +856,7 @@ def verify(xshape, yshape, offset=None): mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset) ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) if offset is None or offset == (0, 0): @@ -817,13 +870,14 @@ def verify(xshape, yshape, offset=None): verify((5, 32, 40, 40), (5, 32, 25, 25)) verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5)) +@tvm.testing.uses_gpu def test_forward_argsort(): def verify(shape, axis, is_ascend, dtype="float32"): x_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype) mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -832,6 +886,7 @@ def verify(shape, axis, is_ascend, dtype="float32"): verify((1, 4, 6), axis=1, is_ascend=True) verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32") +@tvm.testing.uses_gpu def test_forward_topk(): def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): x_np = np.random.uniform(size=shape).astype("float32") @@ -840,7 +895,7 @@ def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -856,6 +911,7 @@ def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True) verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32") +@tvm.testing.uses_gpu def test_forward_sequence_mask(): def verify(shape, use_sequence_length, value, axis, dtype, itype): data_np = np.random.uniform(size=shape).astype(dtype) @@ -885,7 +941,7 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype): value=value, axis=axis) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}, dtype={"data": dtype}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ['graph', 'debug']: if use_sequence_length is False and kind == 'graph': # Disable the test for 'graph' when it's identity. @@ -901,13 +957,14 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype): verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64') verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32') +@tvm.testing.uses_gpu def test_forward_contrib_div_sqrt_dim(): def verify(shape): x_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x_np)) mx_sym = mx.sym.contrib.div_sqrt_dim(mx.sym.var("x")) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) @@ -915,6 +972,7 @@ def verify(shape): verify((3, 4)) verify((3, 4, 5)) +@tvm.testing.uses_gpu def test_forward_batch_norm(): def verify(shape, axis=1, fix_gamma=False): x = np.random.uniform(size=shape).astype("float32") @@ -934,7 +992,7 @@ def verify(shape, axis=1, fix_gamma=False): "mean": moving_mean.shape, "var": moving_var.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) #print(mod) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x, gamma, beta, moving_mean, moving_var) @@ -945,6 +1003,7 @@ def verify(shape, axis=1, fix_gamma=False): verify((2, 3, 4, 5), fix_gamma=True) +@tvm.testing.uses_gpu def test_forward_instance_norm(): def verify(shape, axis=1, epsilon=1e-5): x = np.random.uniform(size=shape).astype("float32") @@ -954,7 +1013,7 @@ def verify(shape, axis=1, epsilon=1e-5): mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon) shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x, gamma, beta) @@ -965,6 +1024,7 @@ def verify(shape, axis=1, epsilon=1e-5): verify((8, 7, 6, 5, 4)) +@tvm.testing.uses_gpu def test_forward_layer_norm(): def verify(shape, axis=-1): x = np.random.uniform(size=shape).astype("float32") @@ -976,7 +1036,7 @@ def verify(shape, axis=-1): mx.sym.var("beta"), axis=axis) shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x, gamma, beta) @@ -985,6 +1045,7 @@ def verify(shape, axis=-1): verify((2, 5), axis=0) verify((2, 5, 6)) +@tvm.testing.uses_gpu def test_forward_one_hot(): def verify(indices_shape, depth, on_value, off_value, dtype): x = np.random.randint(0, 5, size=indices_shape) @@ -992,7 +1053,7 @@ def verify(indices_shape, depth, on_value, off_value, dtype): mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype) shape_dict = {"x": x.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x.astype("float32")) @@ -1004,6 +1065,7 @@ def verify(indices_shape, depth, on_value, off_value, dtype): verify((3, 2, 4, 5), 6, 1, 0, "int32") verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32") +@tvm.testing.uses_gpu def test_forward_pad(): def verify(data_shape, out_shape, mode, pad_width, constant_value=0.0): data = mx.sym.var('data') @@ -1028,6 +1090,7 @@ def verify(data_shape, out_shape, mode, pad_width, constant_value=0.0): pad_width=(0,0,0,0,1,2,3,4,5,6)) +@tvm.testing.uses_gpu def test_forward_slice(): def verify(data_shape, out_shape, begin, end): data = mx.sym.var('data') @@ -1038,6 +1101,7 @@ def verify(data_shape, out_shape, begin, end): verify(data_shape=(1,1,10), out_shape=(1,1,8), begin=(None, None, 2), end=(None, None, None)) +@tvm.testing.uses_gpu def test_forward_convolution(): def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False): if is_depthwise: @@ -1057,7 +1121,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False) pad=pad, num_filter=num_filter, num_group=groups) shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x, weight, bias) @@ -1078,6 +1142,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False) verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2) verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) +@tvm.testing.uses_gpu def test_forward_deconvolution(): def verify(data_shape, kernel_size, stride, pad, num_filter): weight_shape=(data_shape[1], num_filter) + kernel_size @@ -1092,7 +1157,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): pad=pad, num_filter=num_filter, no_bias=False) shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x, weight, bias) @@ -1107,6 +1172,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) +@tvm.testing.uses_gpu def test_forward_cond(): def verify(a_np, b_np): a_nd, b_nd = mx.nd.array(a_np), mx.nd.array(b_np) @@ -1123,7 +1189,7 @@ def verify(a_np, b_np): shape_dict = {"a": a_np.shape, "b": b_np.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["debug", "vm"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np, b_np) @@ -1132,6 +1198,7 @@ def verify(a_np, b_np): verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32')) verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32')) +@tvm.testing.uses_gpu def test_forward_amp_cast(): def verify(from_dtype, to_dtype): from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype) @@ -1140,7 +1207,7 @@ def verify(from_dtype, to_dtype): shape_dict = {'x': (1,3,18)} dtype_dict = {'x': from_dtype} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(from_np) @@ -1150,6 +1217,7 @@ def verify(from_dtype, to_dtype): verify('float32', 'float16') verify('float16', 'float32') +@tvm.testing.uses_gpu def test_forward_amp_multicast(): def verify(dtypes, cast_narrow, expected_dtype): x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes] @@ -1162,7 +1230,7 @@ def verify(dtypes, cast_narrow, expected_dtype): shape_dict[str(i)] = (1,3,18) dtype_dict[str(i)] = dtype mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(*x_nps) @@ -1178,6 +1246,7 @@ def verify(dtypes, cast_narrow, expected_dtype): verify(['float16', 'float16'], True, 'float16') +@tvm.testing.uses_gpu def test_forward_unravel_index(): def verify(x, shape, dtype): a_np = np.array(x).astype(dtype) @@ -1186,7 +1255,7 @@ def verify(x, shape, dtype): shapes = {'a': a_np.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(a_np) @@ -1204,6 +1273,7 @@ def verify(x, shape, dtype): # verify([0, 1, 2, 5], [2, 2], dtype) +@tvm.testing.uses_gpu def test_forward_swap_axis(): def _verify_swap_axis(in_shape, out_shape, dim1, dim2): data = mx.sym.var('data') @@ -1216,6 +1286,7 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): # _verify_swap_axis((4, 5), (5, 4), 0, 0) +@tvm.testing.uses_gpu def test_forward_depth_to_space(): def verify(shape, blocksize=2): x = np.random.uniform(size=shape).astype("float32") @@ -1223,7 +1294,7 @@ def verify(shape, blocksize=2): mx_sym = mx.sym.depth_to_space(mx.sym.var("x"), blocksize) shape_dict = {"x": x.shape, } mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x) @@ -1232,6 +1303,7 @@ def verify(shape, blocksize=2): verify((1, 18, 3, 3), 3) +@tvm.testing.uses_gpu def test_forward_space_to_depth(): def verify(shape, blocksize=2): x = np.random.uniform(size=shape).astype("float32") @@ -1239,7 +1311,7 @@ def verify(shape, blocksize=2): mx_sym = mx.sym.space_to_depth(mx.sym.var("x"), blocksize) shape_dict = {"x": x.shape, } mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x) @@ -1248,6 +1320,7 @@ def verify(shape, blocksize=2): verify((1, 1, 9, 9), 3) +@tvm.testing.uses_gpu def test_forward_correlation(): def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply): @@ -1263,7 +1336,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size is_multiply=is_multiply) shape_dict = {"data1": data1.shape, "data2": data2.shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data1, data2) @@ -1280,6 +1353,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False) +@tvm.testing.uses_gpu def test_forward_arange_like(): def verify(data_shape, start=None, step=None, axis=None): attrs = {} @@ -1295,7 +1369,7 @@ def verify(data_shape, start=None, step=None, axis=None): mx_sym = mx.sym.contrib.arange_like(data, **attrs) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()() @@ -1307,6 +1381,7 @@ def verify(data_shape, start=None, step=None, axis=None): verify(data_shape=(3, 4, 5), start=2., step=3., axis=1) +@tvm.testing.uses_gpu def test_forward_interleaved_matmul_selfatt_qk(): def verify(batch, seq_length, num_heads, head_dim): data_shape = (seq_length, batch, num_heads * head_dim * 3) @@ -1317,7 +1392,7 @@ def verify(batch, seq_length, num_heads, head_dim): mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_qk(data, heads=num_heads) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data_np) @@ -1327,6 +1402,7 @@ def verify(batch, seq_length, num_heads, head_dim): verify(3, 10, 6, 8) +@tvm.testing.uses_gpu def test_forward_interleaved_matmul_selfatt_valatt(): def verify(batch, seq_length, num_heads, head_dim): data_shape = (seq_length, batch, num_heads * head_dim * 3) @@ -1342,7 +1418,7 @@ def verify(batch, seq_length, num_heads, head_dim): data, weight, heads=num_heads) mod, _ = relay.frontend.from_mxnet( mx_sym, {"data": data_shape, "weight": weight_shape}) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data=data_np, weight=weight_np) @@ -1352,6 +1428,7 @@ def verify(batch, seq_length, num_heads, head_dim): verify(3, 10, 6, 8) +@tvm.testing.uses_gpu def test_forward_box_decode(): def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corner"): dtype = "float32" @@ -1361,7 +1438,7 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn mx_sym = mx.sym.contrib.box_decode(mx.sym.var("data"), mx.sym.var("anchors"), stds[0], stds[1], stds[2], stds[3], clip, in_format) shape_dict = {"data": data_shape, "anchors": anchor_shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data, anchors) @@ -1374,6 +1451,7 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn verify((1, 10, 4), (1, 10, 4), in_format="center") +@tvm.testing.uses_gpu def test_forward_softmax(): def verify(data_shape, axis, use_length, length): dtype = "float32" @@ -1394,7 +1472,7 @@ def verify(data_shape, axis, use_length, length): shape_dict = {"data": data_shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) if use_length: @@ -1419,7 +1497,7 @@ def verify(data_shape, axis, use_length, length): @pytest.mark.parametrize("mode", ["constant", "edge", "reflect"]) @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32']) @pytest.mark.parametrize("constant_value", [0.0, 3.0]) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value,target, ctx, kind): data_np = np.random.uniform(size=data_shape).astype(dtype) @@ -1435,12 +1513,12 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value,targ op_res = intrp.evaluate()(data_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) - + @pytest.mark.skipif(not hasattr(mx.sym.np, 'pad'), reason="test'll abort with Mxnet 1.x, skip for now") @pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2)]) @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool']) @pytest.mark.parametrize("axes", [(1,0,2),None]) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_transpose(data_shape, axes, dtype,target, ctx, kind): data_np = np.random.uniform(size=data_shape).astype(dtype) @@ -1458,7 +1536,7 @@ def test_forward_npi_transpose(data_shape, axes, dtype,target, ctx, kind): [((2,2),(2,2),1),((2,4),(2,3),1),((1,3,2),(1,3,5),2),((1,3,3),(1,3,3),1),((1,3),(1,3),0)] ) @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype,target, ctx, kind): data_np1 = np.random.uniform(size=data_shape1).astype(dtype) @@ -1475,7 +1553,7 @@ def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype,target, c @pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8)]) @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_np_copy(data_shape,dtype,target, ctx, kind): data_np = np.random.uniform(size=data_shape).astype(dtype) @@ -1489,7 +1567,7 @@ def test_forward_np_copy(data_shape,dtype,target, ctx, kind): @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) @pytest.mark.parametrize("data_shape,out_shape,reverse", [((2, 3, 8),(-2, -2, 2, -1),False), @@ -1510,7 +1588,7 @@ def test_forward_npx_reshape(data_shape,out_shape,dtype,target,reverse, ctx, kin @pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)]) @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_binary(data_shape,dtype,target, ctx, kind): ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.less] @@ -1535,7 +1613,7 @@ def test_forward_npi_binary(data_shape,dtype,target, ctx, kind): @pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)]) @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("scalar", [1.0,2.0,3.0,4.0]) @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_binary_scalar(data_shape,dtype,scalar,target, ctx, kind): @@ -1559,7 +1637,7 @@ def test_forward_npi_binary_scalar(data_shape,dtype,scalar,target, ctx, kind): @pytest.mark.parametrize("data_shape", [(2,2,2),(2,7,2),(2,2,2,1,2,3,1),(1,8),(2,2),(1,3)]) @pytest.mark.parametrize("dtype", ['float64', 'float32']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_tanh(data_shape,dtype,target, ctx, kind): data_np1 = np.random.uniform(size=data_shape).astype(dtype) @@ -1577,7 +1655,7 @@ def test_forward_npi_tanh(data_shape,dtype,target, ctx, kind): @pytest.mark.parametrize("data_dtype", ['float64', 'float32', 'int64', 'int32', 'bool']) @pytest.mark.parametrize("cond_dtype", ['float64', 'float32', 'int64', 'int32', 'bool']) @pytest.mark.parametrize("scalar", [1.0,2.0]) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_where_rscalar(data_shape,cond_dtype,data_dtype,scalar,target, ctx, kind): if data_dtype == 'bool': @@ -1600,7 +1678,7 @@ def test_forward_npi_where_rscalar(data_shape,cond_dtype,data_dtype,scalar,targe @pytest.mark.parametrize("dtype", ['float64', 'float32', 'int64', 'int32', 'bool']) -@pytest.mark.parametrize("target, ctx", ctx_list()) +@tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) @pytest.mark.parametrize("data_shape, axis, indices_or_sections, squeeze_axis", [((3,2,1),1,2,False),((3,2,1),0,3,False),((3,2,1),0,3,True),((3,2,1),0,(1,2),False)]) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c09580e57301..5921c0df23ea 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -26,8 +26,8 @@ from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.testing.config import ctx_list import scipy +import tvm.testing def get_input_data_shape_dict(graph_def, input_data): @@ -117,11 +117,12 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): x = np.random.uniform(size=data_shape) model = onnx.load_model(graph_file) c2_out = get_onnxruntime_output(model, x, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_reshape(): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -145,13 +146,14 @@ def test_reshape(): model = helper.make_model(graph, producer_name='reshape_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('int32') tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') tvm.testing.assert_allclose(ref_shape, tvm_out.shape) +@tvm.testing.uses_gpu def test_expand(): def _test_expand(name, data, shape, ref_data): @@ -174,7 +176,7 @@ def _test_expand(name, data, shape, ref_data): model = helper.make_model(graph, producer_name=name) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, 'float32') tvm.testing.assert_allclose(ref_data, tvm_out) @@ -205,13 +207,14 @@ def verify_depth_to_space(inshape, outshape, mode, blockSize): model = helper.make_model(graph, producer_name='depth_to_space_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=inshape).astype('float32') tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32') onnx_out = get_onnxruntime_output(model, x, 'float32') tvm.testing.assert_allclose(onnx_out, tvm_out) +@tvm.testing.uses_gpu def test_depth_to_space(): # current onnx.checker use OpSet-1 version of DepthToSpace, which doesn't have a mode argument. # TO-DO, we can add mode arguement to test CRD mode and DCR mode @@ -232,17 +235,19 @@ def verify_space_to_depth(inshape, outshape, blockSize): model = helper.make_model(graph, producer_name='space_to_depth_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=inshape).astype('float32') tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32') onnx_out = get_onnxruntime_output(model, x, 'float32') tvm.testing.assert_allclose(onnx_out, tvm_out) +@tvm.testing.uses_gpu def test_space_to_depth(): verify_space_to_depth((1, 1, 4, 6), (1, 4, 2, 3), 2) +@tvm.testing.uses_gpu def test_shape(): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -268,7 +273,7 @@ def test_shape(): model = helper.make_model(graph, producer_name='shape_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('int32') tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32') @@ -297,17 +302,19 @@ def _test_power_iteration(x_shape, y_shape): model = helper.make_model(graph, producer_name='power_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape) tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_power(): _test_power_iteration((1, 3), (1)) _test_power_iteration((2, 3), (2, 3)) _test_power_iteration((2, 3), (1, 3)) +@tvm.testing.uses_gpu def test_squeeze(): in_shape = (1, 3, 1, 3, 1, 1) out_shape = (3, 3) @@ -322,13 +329,14 @@ def test_squeeze(): model = helper.make_model(graph, producer_name='squeeze_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('float32') tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_shape, tvm_out.shape) +@tvm.testing.uses_gpu def test_flatten(): in_shape = (1, 3, 4, 4) @@ -346,13 +354,14 @@ def test_flatten(): model = helper.make_model(graph, producer_name='flatten_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('int32') tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') tvm.testing.assert_allclose(ref_shape, tvm_out.shape) +@tvm.testing.uses_gpu def test_unsqueeze(): in_shape = (3, 3) axis = (0, 3, 4) @@ -368,7 +377,7 @@ def test_unsqueeze(): model = helper.make_model(graph, producer_name='squeeze_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('float32') tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32') @@ -392,12 +401,13 @@ def verify_gather(in_shape, indices, axis, dtype): TensorProto.FLOAT, list(out_np.shape))]) model = helper.make_model(graph, producer_name='gather_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x, indices], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out) +@tvm.testing.uses_gpu def test_gather(): verify_gather((4,), [1], 0, 'int32') verify_gather((1, 4), [0], 0, 'int32') @@ -427,12 +437,13 @@ def verify_scatter(in_shape, indices, axis): model = helper.make_model(graph, producer_name='scatter_test') onnx_out = get_onnxruntime_output(model, [x, indices, updates]) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x, indices, updates], target, ctx, onnx_out[0].shape) tvm.testing.assert_allclose(onnx_out[0], tvm_out) +@tvm.testing.uses_gpu def test_scatter(): verify_scatter((4,), [1], 0) verify_scatter((1, 4), [[0]], 0) @@ -459,7 +470,7 @@ def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): model = helper.make_model(graph, producer_name='slice_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, indata, target, ctx, outdata.shape, 'float32', opset=1) @@ -547,7 +558,7 @@ def add_noop_to_input_attr(attr_name, attr): initializer=initializer) model = helper.make_model(graph, producer_name='slice_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, indata, target, @@ -559,6 +570,7 @@ def add_noop_to_input_attr(attr_name, attr): tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_slice(): x = np.random.randn(20, 10, 5).astype(np.float32) _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) @@ -595,22 +607,25 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): model = helper.make_model(graph, producer_name=opname+'_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, indata, target, ctx, outdata.shape, dtype) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_floor(): _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {}) +@tvm.testing.uses_gpu def test_ceil(): _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {}) +@tvm.testing.uses_gpu def test_clip(): _test_onnx_op_elementwise((2, 4, 5, 6), np.clip, @@ -620,7 +635,7 @@ def test_clip(): {'min': -1.0, 'max': 1.0}) - +@tvm.testing.uses_gpu def test_round(): _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, 'float32', 'Round', {}) @@ -640,17 +655,19 @@ def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs): model = helper.make_model(graph, producer_name=opname+'_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, indata, target, ctx, outdata.shape, dtype) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_isinf(): _test_finite_ops((2, 4, 5, 6), np.isinf, {}, 'float32', 'IsInf', {}) +@tvm.testing.uses_gpu def test_isnan(): _test_finite_ops((2, 4, 5, 6), np.isnan, {}, 'float32', 'IsNaN', {}) @@ -672,18 +689,20 @@ def verify_gather_nd(in_shape, indices, dtype): TensorProto.FLOAT, list(out_np.shape))]) model = helper.make_model(graph, producer_name='gather_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x, indices], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out) +@tvm.testing.uses_gpu def test_gather_nd(): verify_gather_nd((2, 2), [[0,0],[1,1]], 'int32') verify_gather_nd((3, 3, 3), [[0,1],[1,0]] , 'float32') verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32') +@tvm.testing.uses_gpu def test_onehot(): indices_shape = [10] indices_array = np.random.randint( @@ -709,12 +728,13 @@ def test_onehot(): model = helper.make_model(graph, producer_name="onehot_test") - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [indices_array], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_matmul(): a_shape = (4, 3) b_shape = (3, 4) @@ -736,7 +756,7 @@ def test_matmul(): model = helper.make_model(graph, producer_name='matmul_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_array, b_array], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -759,11 +779,12 @@ def verify_batch_matmul(a_shape, b_shape): model = helper.make_model(graph, producer_name='matmul_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_array, b_array], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_batch_matmul(): verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4)) verify_batch_matmul((2, 4, 3), (3, 4)) @@ -800,7 +821,7 @@ def _get_python_lrn(): py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta) return py_out - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): input_name = model.graph.input[0].name py_out = _get_python_lrn() tvm_out = get_tvm_output( @@ -808,6 +829,7 @@ def _get_python_lrn(): tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_lrn(): verify_lrn((5, 5, 5, 5), 3, 'float32') verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0) @@ -845,12 +867,13 @@ def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5): helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))], outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))]) model = helper.make_model(graph, producer_name='instance_norm_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x, gamma, beta], target, ctx, shape, 'float32') tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_instance_norm(): verify_instance_norm((2, 3, 4, 5)) verify_instance_norm((32, 64, 80, 64)) @@ -877,7 +900,7 @@ def _test_upsample_nearest(): model = helper.make_model(graph, producer_name='upsample_nearest_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out) @@ -902,7 +925,7 @@ def _test_upsample3d_nearest(): model = helper.make_model(graph, producer_name='upsample_nearest_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out) @@ -926,7 +949,7 @@ def _test_upsample_bilinear(): model = helper.make_model(graph, producer_name='upsample_bilinear_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) @@ -961,7 +984,7 @@ def _test_upsample_bilinear_opset9(): model = helper.make_model( graph, producer_name='upsample_bilinear_opset9_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) @@ -995,11 +1018,12 @@ def _test_upsample3d_trilinear(): model = helper.make_model( graph, producer_name='upsample_trilinear_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_upsample(): _test_upsample_nearest() _test_upsample_bilinear() @@ -1026,12 +1050,13 @@ def _test_softmax(inshape, axis): model = helper.make_model(graph, producer_name=opname+'_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, indata, target, ctx, outshape, 'float32') tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_softmax(): _test_softmax((1, 10), None) _test_softmax((1, 10), 1) @@ -1061,12 +1086,13 @@ def verify_min(input_dim): model = helper.make_model(graph, producer_name='Min_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_min(): verify_min((1, 3, 20, 20)) verify_min((20, 20)) @@ -1096,12 +1122,13 @@ def verify_max(input_dim): model = helper.make_model(graph, producer_name='Max_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_max(): verify_max((1, 3, 20, 20)) verify_max((20, 20)) @@ -1131,12 +1158,13 @@ def verify_mean(input_dim): model = helper.make_model(graph, producer_name='Mean_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_mean(): verify_mean((1, 3, 20, 20)) verify_mean((20, 20)) @@ -1161,11 +1189,12 @@ def verify_hardsigmoid(input_dim, alpha, beta): model = helper.make_model(graph, producer_name='HardSigmoid_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_hardsigmoid(): verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6) verify_hardsigmoid((20, 20), 0.3, 0.4) @@ -1212,7 +1241,7 @@ def _argmin_numpy(data, axis=0, keepdims=True): model = helper.make_model(graph, producer_name='argmin_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_np1], target, ctx, b_np.shape, b_np.dtype) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -1260,12 +1289,13 @@ def _argmax_numpy(data, axis=0, keepdims=True): model = helper.make_model(graph, producer_name='argmax_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [a_np1], target, ctx, b_np.shape, b_np.dtype) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_forward_arg_min_max(): '''Verify argmin and argmax''' verify_argmin([3, 4, 4]) @@ -1309,12 +1339,13 @@ def verify_constantofshape(input_dim, value, dtype): model = helper.make_model(graph, producer_name='fill_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [], target, ctx, out.shape) tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_constantofshape(): verify_constantofshape((2, 3, 4, 5), 10, 'float32') verify_constantofshape((3, 3), 0, 'int32') @@ -1355,7 +1386,7 @@ def verify_pad(indata, pads, mode='constant', value=0.0): TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name='pad_test') # tvm result - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, indata, target, ctx, outdata.shape, 'float32', opset=2) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) @@ -1411,12 +1442,13 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0): TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name='pad_test') # tvm result - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, inputs, target, ctx, outdata.shape, 'float32', opset=11) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_pad(): verify_pad(np.random.randn(2, 2).astype( np.float32), [0, 1, 0, 0], 'constant', 0.0) @@ -1465,10 +1497,11 @@ def verify_reduce_func(func, data, axis, keepdims): model = helper.make_model(graph, producer_name='reduce_test') onnx_out = get_onnxruntime_output(model, data, 'float32') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, data, target, ctx, outshape, 'float32') tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_all_reduce_funcs(): funcs = ["ReduceMax", "ReduceMean", @@ -1532,7 +1565,7 @@ def verify_split(indata, outdatas, split, axis=0): ]) model = helper.make_model(graph, producer_name='split_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): output_shape = [o.shape for o in outdatas] output_type = ['float32', 'float32', 'float32'] tvm_out = get_tvm_output( @@ -1541,6 +1574,7 @@ def verify_split(indata, outdatas, split, axis=0): tvm.testing.assert_allclose(o, t) +@tvm.testing.uses_gpu def test_split(): # 1D verify_split([1., 2., 3., 4., 5., 6.], [ @@ -1554,6 +1588,7 @@ def test_split(): verify_split([1, 2, 3], [[1], [2], [3]], False) +@tvm.testing.uses_gpu def test_binary_ops(): in_shape = (1, 2, 3, 3) dtype = "float32" @@ -1573,7 +1608,7 @@ def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=No outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) model = helper.make_model(graph, producer_name='_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -1595,6 +1630,7 @@ def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=No verify_binary_ops("Equal", x, y, x == y, broadcast=True) +@tvm.testing.uses_gpu def test_single_ops(): in_shape = (1, 2, 3, 3) dtype = "float32" @@ -1609,7 +1645,7 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) model = helper.make_model(graph, producer_name='_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol) @@ -1639,6 +1675,7 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): verify_single_ops("SoftPlus", x, np.log(1 + np.exp(x))) +@tvm.testing.uses_gpu def test_leaky_relu(): def leaky_relu_x(x, alpha): return np.where(x >= 0, x, x * alpha) @@ -1650,6 +1687,7 @@ def leaky_relu_x(x, alpha): {'alpha': 0.25}) +@tvm.testing.uses_gpu def test_elu(): def elu_x(x, alpha): return np.where(x > 0, x, alpha * (np.exp(x) - 1.0)) @@ -1661,6 +1699,7 @@ def elu_x(x, alpha): {'alpha': 0.25}) +@tvm.testing.uses_gpu def test_selu(): def selu_x(x, alpha, gamma): return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0)) @@ -1672,6 +1711,7 @@ def selu_x(x, alpha, gamma): {'alpha': 0.25, 'gamma': 0.3}) +@tvm.testing.uses_gpu def test_prelu(): def verify_prelu(x_shape, a_shape): node = helper.make_node('PRelu', @@ -1700,6 +1740,7 @@ def verify_prelu(x_shape, a_shape): verify_prelu([2,12,16,16], [1, 12, 1, 1]) +@tvm.testing.uses_gpu def test_ThresholdedRelu(): def ThresholdedRelu_x(x, alpha): out_np = np.clip(x, alpha, np.inf) @@ -1713,6 +1754,7 @@ def ThresholdedRelu_x(x, alpha): {'alpha': 0.25}) +@tvm.testing.uses_gpu def test_ScaledTanh(): def ScaledTanh_x(x, alpha, beta): return alpha * np.tanh(beta * x) @@ -1724,6 +1766,7 @@ def ScaledTanh_x(x, alpha, beta): {'alpha': 0.25, 'beta': 0.3}) +@tvm.testing.uses_gpu def test_ParametricSoftplus(): def ParametricSoftplus_x(x, alpha, beta): return alpha * np.log(np.exp(beta * x) + 1) @@ -1735,6 +1778,7 @@ def ParametricSoftplus_x(x, alpha, beta): {'alpha': 0.25, 'beta': 0.3}) +@tvm.testing.uses_gpu def test_Scale(): def Scale_x(x, scale): return scale * x @@ -1746,6 +1790,7 @@ def Scale_x(x, scale): {'scale': 0.25}) +@tvm.testing.uses_gpu def test_LogSoftmax(): _test_onnx_op_elementwise((1, 4), tvm.topi.testing.log_softmax_python, @@ -1762,13 +1807,14 @@ def check_torch_conversion(model, input_size): torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) onnx_model = onnx.load(file_name) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): input_data = np.random.uniform(size=input_size).astype('int32') c2_out = get_onnxruntime_output(onnx_model, input_data) tvm_out = get_tvm_output(onnx_model, input_data, target, ctx) tvm.testing.assert_allclose(c2_out, tvm_out) +@tvm.testing.uses_gpu def test_resnet(): check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224)) # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224)) @@ -1787,10 +1833,12 @@ def test_resnet(): # check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) +@tvm.testing.uses_gpu def test_densenet(): check_torch_conversion(torchvision.models.densenet161, (1, 3, 224, 224)) +@tvm.testing.uses_gpu def test_inception(): check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224)) @@ -1803,6 +1851,7 @@ def test_inception(): # check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) +@tvm.testing.uses_gpu def test_sign(): def Sign_x(x): return np.sign(x) @@ -1828,11 +1877,12 @@ def verify_not(indata, dtype): model = helper.make_model(graph, producer_name='not_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_not(): # 2d verify_not(indata=(np.random.randn(3, 4) > 0), dtype=bool) @@ -1857,11 +1907,12 @@ def verify_and(indata, dtype): model = helper.make_model(graph, producer_name='and_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_and(): # 2d x = (np.random.randn(3, 4) > 0) @@ -1899,7 +1950,7 @@ def verify_tile_v1(indata, outdata, **kwargs): model = helper.make_model(graph, producer_name='tile_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [indata], target, ctx, outdata.shape, opset=1) tvm.testing.assert_allclose(outdata, tvm_out) @@ -1929,7 +1980,7 @@ def verify_tile_v6(indata, repeats, outdata): model = helper.make_model(graph, producer_name='tile_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [indata], target, ctx, @@ -1938,6 +1989,7 @@ def verify_tile_v6(indata, repeats, outdata): tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_tile(): x = np.random.rand(2, 3, 4, 5).astype(np.float32) repeats = np.random.randint( @@ -1956,11 +2008,12 @@ def verify_erf(indata, outdata): outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name='erf_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_erf(): x = np.random.rand(2, 3, 4, 6).astype(np.float32) z = scipy.special.erf(x) @@ -1977,11 +2030,12 @@ def verify_where(condition, x, y, dtype, outdata): outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))]) model = helper.make_model(graph, producer_name='where_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_where(): condition = np.array([[1, 0], [1, 1]], dtype=np.bool) x = np.array([[1, 2], [3, 4]], dtype=np.int64) @@ -2031,11 +2085,12 @@ def verify_or(indata, dtype): model = helper.make_model(graph, producer_name='or_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) tvm.testing.assert_allclose(outdata, tvm_out) +@tvm.testing.uses_gpu def test_or(): # 2d x = (np.random.randn(3, 4) > 0) @@ -2063,6 +2118,7 @@ def test_or(): verify_or(indata=[x, y], dtype=bool) +@tvm.testing.uses_gpu def test_batch_norm(): def verify_batch_norm(in_shape): batchnorm = onnx.helper.make_node('BatchNormalization', @@ -2087,7 +2143,7 @@ def verify_batch_norm(in_shape): model = helper.make_model(graph, producer_name='batchnorm_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('float32') scale = np.random.uniform(size=in_shape[1]).astype('float32') b = np.random.uniform(size=in_shape[1]).astype('float32') @@ -2104,6 +2160,7 @@ def verify_batch_norm(in_shape): verify_batch_norm([16, 16, 10, 10]) +@tvm.testing.uses_gpu def test_batch_norm_dynamic_subgraph(): def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): batchnorm = onnx.helper.make_node('BatchNormalization', @@ -2132,7 +2189,7 @@ def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): model = helper.make_model(graph, producer_name='batchnorm_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=in_shape).astype('float32') inp = np.random.uniform(size=o_shape).astype('float32') scale = np.random.uniform(size=in_shape[1]).astype('float32') @@ -2186,7 +2243,7 @@ def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilat model = helper.make_model(graph, producer_name='conv_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=x_shape).astype('float32') W = np.random.uniform(size=w_shape).astype('float32') tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape) @@ -2194,6 +2251,7 @@ def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilat tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_conv(): def repeat(N, D): return tuple([N for _ in range(D)]) @@ -2276,7 +2334,7 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p): model = helper.make_model(graph, producer_name='convtranspose_trest') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=x_shape).astype('float32') W = np.random.uniform(size=w_shape).astype('float32') tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape) @@ -2284,6 +2342,7 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p): tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_convtranspose(): # Convolution Transpose with padding # (1, 1, 3, 3) input tensor @@ -2293,6 +2352,7 @@ def test_convtranspose(): verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2]) +@tvm.testing.uses_gpu def test_unsqueeze_constant(): from torch.nn import Linear, Sequential, Module class Flatten(Module): @@ -2343,13 +2403,14 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p model = helper.make_model(graph, producer_name='pooling_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): onnx_out = get_onnxruntime_output(model, x_np, 'float32') tvm_out = get_tvm_output( model, [x_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_pooling(): for mode in ['max', 'average']: # Pool1D @@ -2440,12 +2501,13 @@ def verify_mod(x_shape, y_shape, fmod, out_shape, dtype='float32'): onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0] - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x_np, y_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_mod(): # Mod verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32") @@ -2481,12 +2543,13 @@ def verify_xor(x_shape, y_shape): onnx_dtype, list(out_shape))]) model = helper.make_model(graph, producer_name='xor_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x_np, y_np], target, ctx, out_shape) tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_xor(): # XOR verify_xor(x_shape=[1, 32, 32], y_shape=[1, 32, 32]) @@ -2523,12 +2586,13 @@ def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_sh model = helper.make_model(graph, producer_name='pool_test') onnx_out = get_onnxruntime_output(model, [x_np, rois_np], 'float32')[0] - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output( model, [x_np, rois_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_max_roi_pool(): verify_max_roi_pool(x_shape=[1, 3, 6, 6], rois_shape=[3, 5], @@ -2572,13 +2636,14 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" model = helper.make_model(graph, producer_name='lppool_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): onnx_out = get_onnxruntime_output(model, x_np, 'float32') tvm_out = get_tvm_output( model, [x_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_lppool(): # Pool1D verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1], @@ -2728,7 +2793,7 @@ def verify_rnn(seq_length, model = helper.make_model(graph, producer_name='rnn_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): onnx_out = get_onnxruntime_output(model, input_values, 'float32') tvm_out = get_tvm_output( model, @@ -2741,6 +2806,7 @@ def verify_rnn(seq_length, tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3) +@tvm.testing.uses_gpu def test_lstm(): # No bias. verify_rnn( @@ -2845,6 +2911,7 @@ def test_lstm(): rnn_type='LSTM') +@tvm.testing.uses_gpu def test_gru(): # No bias. verify_rnn( @@ -2940,6 +3007,7 @@ def test_gru(): rnn_type='GRU') +@tvm.testing.uses_gpu def test_resize(): def make_constant_node(name, data_type, dims, vals): return helper.make_node('Constant', @@ -2977,7 +3045,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name='resize_test') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=ishape).astype('float32') onnx_out = get_onnxruntime_output(model, x, 'float32') tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11) @@ -2997,6 +3065,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel") +@tvm.testing.uses_gpu def test_nonzero(): def verify_nonzero(indata, outdata, dtype): @@ -3025,6 +3094,7 @@ def verify_nonzero(indata, outdata, dtype): result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]] verify_nonzero(input_data, result, dtype=np.int64) +@tvm.testing.uses_gpu def test_topk(): def verify_topk(input_dims, K, axis=-1): output_dims = list(input_dims) @@ -3063,6 +3133,7 @@ def verify_topk(input_dims, K, axis=-1): verify_topk([n, n, n], 5, 2) +@tvm.testing.uses_gpu def test_roi_align(): def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0): output_dims = [num_roi, input_dims[1], output_height, output_width] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6840ca333005..cfe9507f27c7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -28,7 +28,7 @@ from tvm import relay from tvm.contrib import graph_runtime from tvm.contrib.nvcc import have_fp16 -from tvm.relay.testing.config import ctx_list +import tvm.testing sys.setrecursionlimit(10000) @@ -152,7 +152,6 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40): def verify_model(model_name, input_data=[], custom_convert_map={}, - ctx_list=ctx_list(), rtol=1e-5, atol=1e-5): """Assert that the output of a compiled model matches with that of its baseline.""" @@ -198,7 +197,7 @@ def verify_model(model_name, input_data=[], [inp.cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3): - for target, ctx in ctx_list: + for target, ctx in tvm.testing.enabled_targets(): relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) relay_model.set_input(**relay_params) @@ -218,6 +217,7 @@ def verify_model(model_name, input_data=[], torch.cuda.empty_cache() # Single operator tests +@tvm.testing.uses_gpu def test_forward_add(): torch.set_grad_enabled(False) input_shape = [10] @@ -250,6 +250,7 @@ def forward(self, *args): verify_model(Add3().float().eval(), input_data=input_data) verify_model(Add4().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_subtract(): torch.set_grad_enabled(False) input_shape = [10] @@ -282,6 +283,7 @@ def forward(self, *args): verify_model(Subtract3().float().eval(), input_data=input_data) verify_model(Subtract4().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_multiply(): torch.set_grad_enabled(False) input_shape = [10] @@ -315,6 +317,7 @@ def forward(self, *args): verify_model(Multiply4().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_min_max(): class Max(Module): def forward(self, inp): @@ -352,6 +355,7 @@ def forward(self, lhs, rhs): verify_model(Min3(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_reciprocal(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] @@ -362,6 +366,7 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Reciprocal1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_repeat(): torch.set_grad_enabled(False) input_shape = [1, 3] @@ -382,6 +387,7 @@ def forward(self, *args): verify_model(Repeat2().float().eval(), input_data=input_data) verify_model(Repeat3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_repeat_interleave(): torch.set_grad_enabled(False) input_shape = [2, 2, 3] @@ -407,6 +413,7 @@ def forward(self, *args): verify_model(RepeatInterleave3().float().eval(), input_data=input_data) verify_model(RepeatInterleave4().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_unsqueeze(): torch.set_grad_enabled(False) input_shape = [10, 10] @@ -418,6 +425,7 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Unsqueeze1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_squeeze(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] @@ -434,6 +442,7 @@ def forward(self, *args): verify_model(Squeeze1().float().eval(), input_data=input_data) verify_model(Squeeze2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_arange(): torch.set_grad_enabled(False) @@ -508,6 +517,7 @@ def forward(self, *args): verify_model(Arange11().float().eval()) verify_model(Arange12().float().eval()) +@tvm.testing.uses_gpu def test_forward_mesh_grid(): torch.set_grad_enabled(False) @@ -528,6 +538,7 @@ def forward(self, *args): verify_model(MeshGrid1().float().eval()) verify_model(MeshGrid2().float().eval()) +@tvm.testing.uses_gpu def test_forward_abs(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] @@ -539,6 +550,7 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Abs1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_concatenate(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -558,18 +570,21 @@ def forward(self, *args): verify_model(Concatenate1().float().eval(), input_data=input_data) verify_model(Concatenate2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_relu(): torch.set_grad_enabled(False) input_shape = [10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.ReLU().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_prelu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_leakyrelu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -579,6 +594,7 @@ def test_forward_leakyrelu(): verify_model(torch.nn.LeakyReLU(negative_slope=1.0, inplace=True).eval(), input_data=input_data) verify_model(torch.nn.LeakyReLU(negative_slope=1.25, inplace=True).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_elu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -588,6 +604,7 @@ def test_forward_elu(): verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data) verify_model(torch.nn.ELU(alpha=1.3).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_celu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -597,18 +614,21 @@ def test_forward_celu(): verify_model(torch.nn.CELU(alpha=1.0).eval(), input_data=input_data) verify_model(torch.nn.CELU(alpha=1.3).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_gelu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.GELU().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_selu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.SELU().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_softplus(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -617,18 +637,21 @@ def test_forward_softplus(): verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data) verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_softsign(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.Softsign().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_log_sigmoid(): torch.set_grad_enabled(False) input_shape = [10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.LogSigmoid().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_adaptiveavgpool(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -636,6 +659,7 @@ def test_forward_adaptiveavgpool(): verify_model(torch.nn.AdaptiveAvgPool2d([1, 1]).eval(), input_data=input_data) verify_model(torch.nn.AdaptiveAvgPool2d([10, 10]).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_maxpool2d(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -661,6 +685,7 @@ def forward(self, *args): verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_maxpool1d(): torch.set_grad_enabled(False) input_shape = [1, 3, 10] @@ -675,6 +700,7 @@ def test_forward_maxpool1d(): stride=2).eval(), input_data) +@tvm.testing.uses_gpu def test_forward_maxpool3d(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10, 10] @@ -689,6 +715,7 @@ def test_forward_maxpool3d(): stride=2).eval(), input_data) +@tvm.testing.uses_gpu def test_forward_split(): torch.set_grad_enabled(False) input_shape = [4, 10] @@ -712,6 +739,7 @@ def forward(self, *args): verify_model(Split([2, 3, 5], 1).float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_avgpool(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -724,6 +752,7 @@ def forward(self, *args): verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data) verify_model(AvgPool2D2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_avgpool3d(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10, 10] @@ -736,12 +765,14 @@ def forward(self, *args): verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data) verify_model(AvgPool3D1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_hardtanh(): torch.set_grad_enabled(False) input_shape = [10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.Hardtanh().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_conv(): torch.set_grad_enabled(False) conv1d_input_shape = [1, 3, 10] @@ -816,6 +847,7 @@ def forward(self, *args): verify_model(Conv1D2().float().eval(), input_data=conv1d_input_data) verify_model(Conv1D3().float().eval(), input_data=conv1d_input_data) +@tvm.testing.uses_gpu def test_forward_conv_transpose(): torch.set_grad_enabled(False) conv2d_input_shape = [1, 3, 10, 10] @@ -829,12 +861,14 @@ def test_forward_conv_transpose(): verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data) +@tvm.testing.uses_gpu def test_forward_threshold(): torch.set_grad_enabled(False) input_shape = [1, 3] input_data = torch.rand(input_shape).float() verify_model(torch.nn.Threshold(0, 0).float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_contiguous(): torch.set_grad_enabled(False) input_shape = [10] @@ -847,6 +881,7 @@ def forward(self, *args): verify_model(Contiguous1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_batchnorm(): def init_weight(m): torch.nn.init.normal_(m.weight, 0, 0.01) @@ -861,6 +896,7 @@ def init_weight(m): verify_model(bn.eval(), input_data=inp) +@tvm.testing.uses_gpu def test_forward_instancenorm(): inp_2d = torch.rand((1, 16, 10, 10)) inp_3d = torch.rand((1, 16, 10, 10, 10)) @@ -869,6 +905,7 @@ def test_forward_instancenorm(): (torch.nn.InstanceNorm3d(16), inp_3d)]: verify_model(ins_norm.eval(), input_data=inp) +@tvm.testing.uses_gpu def test_forward_layernorm(): def init_weight(m): torch.nn.init.normal_(m.weight, 0, 0.01) @@ -882,6 +919,7 @@ def init_weight(m): verify_model(ln.eval(), input_data=inp) +@tvm.testing.uses_gpu def test_forward_groupnorm(): input_shape = [10, 6, 5, 5] input_data = torch.rand(input_shape).float() @@ -903,6 +941,7 @@ def test_forward_groupnorm(): verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_reshape(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] @@ -920,6 +959,7 @@ def forward(self, *args): verify_model(Reshape2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_flatten(): class Flatten(Module): def forward(self, x): @@ -934,6 +974,7 @@ def forward(self, x): verify_model(BatchFlatten(), input_data=inp) +@tvm.testing.uses_gpu def test_forward_transpose(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -955,6 +996,7 @@ def forward(self, *args): verify_model(Transpose2().float().eval(), input_data=input_data) verify_model(Transpose3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_size(): torch.set_grad_enabled(False) input_shape = [1, 3] @@ -967,6 +1009,7 @@ def forward(self, *args): verify_model(Size1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_type_as(): torch.set_grad_enabled(False) input_shape = [1, 3] @@ -1004,6 +1047,7 @@ def forward(self, *args): verify_model(_create_module(torch.float16), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_view(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1026,7 +1070,7 @@ def forward(self, *args): verify_model(View2().float().eval(), input_data=input_data) verify_model(View3().float().eval(), input_data=input_data) - +@tvm.testing.uses_gpu def test_forward_select(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1055,6 +1099,7 @@ def forward(self, index): verify_model(IndexedSelect(x, 1).eval(), input_data=indices) +@tvm.testing.uses_gpu def test_forward_clone(): torch.set_grad_enabled(False) input_shape = [10] @@ -1067,6 +1112,7 @@ def forward(self, *args): verify_model(Clone1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_gather(): torch.set_grad_enabled(False) @@ -1105,6 +1151,7 @@ def forward(self, *args): verify_model(Gather3().float().eval(), input_data=[input_data, index]) +@tvm.testing.uses_gpu def test_forward_logsoftmax(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1117,6 +1164,7 @@ def forward(self, *args): verify_model(LogSoftmax1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_norm(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1174,6 +1222,7 @@ def forward(self, *args): verify_model(Norm10().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_frobenius_norm(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1201,12 +1250,14 @@ def forward(self, *args): verify_model(FroNorm4().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_sigmoid(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.Sigmoid().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_dense(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1236,6 +1287,7 @@ def forward(self, *args): ) assert not any([op.name == "multiply" for op in list_ops(mod['main'])]) +@tvm.testing.uses_gpu def test_forward_dropout(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1245,6 +1297,7 @@ def test_forward_dropout(): verify_model(torch.nn.Dropout3d(p=0.5).eval(), input_data=input_data) verify_model(torch.nn.AlphaDropout(p=0.5).eval(), input_data=input_data[0, 0]) +@tvm.testing.uses_gpu def test_forward_slice(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1269,6 +1322,7 @@ def forward(self, *args): verify_model(Slice3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_mean(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1280,6 +1334,7 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Mean1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_expand(): torch.set_grad_enabled(False) @@ -1300,6 +1355,7 @@ def forward(self, *args): verify_model(Expand2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_pow(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1311,6 +1367,7 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Pow1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_chunk(): torch.set_grad_enabled(False) input_shape = [1, 3, 14, 14] @@ -1323,6 +1380,7 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Chunk1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_upsample(): class Upsample(Module): def __init__(self, size=None, scale=None, @@ -1346,6 +1404,7 @@ def forward(self, x): verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp) verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp) +@tvm.testing.uses_gpu def test_to(): """ test for aten::to(...) """ class ToCPU(Module): @@ -1377,6 +1436,7 @@ def forward(self, x): verify_model(ToDouble().eval(), torch.tensor(0.8)) +@tvm.testing.uses_gpu def test_adaptive_pool3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), @@ -1390,6 +1450,7 @@ def test_adaptive_pool3d(): verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_functional_pad(): torch.set_grad_enabled(False) pad = (0, 0) @@ -1408,12 +1469,14 @@ def forward(self, *args): verify_model(Pad1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_zero_pad2d(): inp = torch.rand((1, 1, 3, 3)) verify_model(torch.nn.ZeroPad2d(2).eval(), inp) verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_constant_pad1d(): inp = torch.rand((1, 2, 4)) verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) @@ -1422,18 +1485,21 @@ def test_forward_constant_pad1d(): verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp) +@tvm.testing.uses_gpu def test_forward_constant_pad2d(): inp = torch.rand((1, 2, 2, 2)) verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp) +@tvm.testing.uses_gpu def test_forward_constant_pad3d(): inp = torch.rand((1, 3, 2, 2, 2)) verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp) verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp) +@tvm.testing.uses_gpu def test_forward_reflection_pad1d(): inp = torch.rand((1, 2, 4)) verify_model(torch.nn.ReflectionPad1d(2).eval(), inp) @@ -1443,6 +1509,7 @@ def test_forward_reflection_pad1d(): verify_model(torch.nn.ReflectionPad1d((2, 3)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_reflection_pad2d(): inp = torch.rand((1, 1, 3, 3)) verify_model(torch.nn.ReflectionPad2d(2).eval(), inp) @@ -1452,6 +1519,7 @@ def test_forward_reflection_pad2d(): verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_replication_pad1d(): inp = torch.rand((1, 2, 4)) verify_model(torch.nn.ReplicationPad1d(2).eval(), inp) @@ -1461,6 +1529,7 @@ def test_forward_replication_pad1d(): verify_model(torch.nn.ReplicationPad1d((2, 3)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_replication_pad2d(): inp = torch.rand((1, 1, 3, 3)) verify_model(torch.nn.ReplicationPad2d(2).eval(), inp) @@ -1470,6 +1539,7 @@ def test_forward_replication_pad2d(): verify_model(torch.nn.ReplicationPad2d((1, 3, 2, 4)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_replication_pad3d(): inp = torch.rand((1, 1, 3, 3, 3)) verify_model(torch.nn.ReplicationPad3d(3).eval(), inp) @@ -1479,6 +1549,7 @@ def test_forward_replication_pad3d(): verify_model(torch.nn.ReplicationPad3d((2, 3, 2, 5, 1, 4)).eval(), inp) +@tvm.testing.uses_gpu def test_forward_upsample3d(): inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2) verify_model(torch.nn.Upsample(scale_factor=2, mode='nearest').eval(), inp) @@ -1511,6 +1582,7 @@ def _gen_rand_inputs(num_boxes): verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores]) +@tvm.testing.uses_gpu def test_conv3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), @@ -1529,6 +1601,7 @@ def test_conv3d(): inp) +@tvm.testing.uses_gpu def test_conv3d_transpose(): for ishape in [(1, 8, 10, 5, 10), (1, 8, 5, 8, 8), @@ -1557,53 +1630,65 @@ def test_conv3d_transpose(): # Model tests +@tvm.testing.uses_gpu def test_resnet18(): torch.set_grad_enabled(False) verify_model("resnet18", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_squeezenet1_0(): torch.set_grad_enabled(False) verify_model("squeezenet1_0", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_squeezenet1_1(): torch.set_grad_enabled(False) verify_model("squeezenet1_1", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_densenet121(): torch.set_grad_enabled(False) verify_model("densenet121", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_inception_v3(): torch.set_grad_enabled(False) verify_model("inception_v3", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_googlenet(): torch.set_grad_enabled(False) verify_model("googlenet", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_mnasnet0_5(): torch.set_grad_enabled(False) verify_model("mnasnet0_5", atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_mobilenet_v2(): torch.set_grad_enabled(False) verify_model("mobilenet_v2", atol=1e-4, rtol=1e-4) """ #TODO: Fix VGG and AlexNet issues (probably due to pooling) +@tvm.testing.uses_gpu def test_alexnet(): torch.set_grad_enabled(False) verify_model("alexnet") +@tvm.testing.uses_gpu def test_vgg11(): torch.set_grad_enabled(False) verify_model("vgg11") +@tvm.testing.uses_gpu def test_vgg11_bn(): torch.set_grad_enabled(False) verify_model("vgg11_bn") """ +@tvm.testing.uses_gpu def test_custom_conversion_map(): def get_roi_align(): pool_size = 5 @@ -1633,6 +1718,7 @@ def _impl(inputs, input_types): verify_model(model, inputs, custom_map) +@tvm.testing.uses_gpu def test_segmentaton_models(): class SegmentationModelWrapper(Module): def __init__(self, model): @@ -1652,6 +1738,7 @@ def forward(self, inp): verify_model(SegmentationModelWrapper(deeplab.eval()), inp, atol=1e-4, rtol=1e-4) +@tvm.testing.uses_gpu def test_3d_models(): input_shape = (1, 3, 4, 56, 56) resnet3d = torchvision.models.video.r3d_18(pretrained=True).eval() @@ -1700,6 +1787,7 @@ def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None): rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_control_flow(): class SimpleIf(torch.nn.Module): def __init__(self, N, M): @@ -1813,6 +1901,7 @@ def forward(self, inp): verify_script_model(pt_model.eval(), [(10, 20)]) +@tvm.testing.uses_gpu def test_simple_rnn(): # The mixed tracing and scripting example from # https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing @@ -1850,6 +1939,7 @@ def forward(self, xs): verify_script_model(RNNLoop().eval(), [(10, 10, 4)]) +@tvm.testing.uses_gpu def test_forward_reduce_sum(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1882,6 +1972,7 @@ def forward(self, *args): verify_model(ReduceSum5().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_reduce_prod(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1904,6 +1995,7 @@ def forward(self, *args): verify_model(ReduceProd3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_argmin(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1926,6 +2018,7 @@ def forward(self, *args): verify_model(ArgMin3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_argmax(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1948,6 +2041,7 @@ def forward(self, *args): verify_model(ArgMax3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_std(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2000,6 +2094,7 @@ def forward(self, *args): verify_model(Std9().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_variance(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2052,6 +2147,7 @@ def forward(self, *args): verify_model(Variance9().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_rsub(): torch.set_grad_enabled(False) @@ -2072,6 +2168,7 @@ def forward(self, *args): verify_model(Rsub2().float().eval(), input_data=[d1, d3]) +@tvm.testing.uses_gpu def test_forward_embedding(): torch.set_grad_enabled(False) @@ -2085,6 +2182,7 @@ def test_forward_embedding(): verify_model(torch.nn.Embedding(4, 5, sparse=True).float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_onehot(): torch.set_grad_enabled(False) @@ -2103,6 +2201,7 @@ def forward(self, *args): verify_model(OneHot2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_isfinite(): torch.set_grad_enabled(False) @@ -2114,6 +2213,7 @@ def forward(self, *args): verify_model(IsFinite1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_isnan(): torch.set_grad_enabled(False) @@ -2125,6 +2225,7 @@ def forward(self, *args): verify_model(IsNan1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_isinf(): torch.set_grad_enabled(False) @@ -2136,6 +2237,7 @@ def forward(self, *args): verify_model(IsInf1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_clamp(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2158,6 +2260,7 @@ def forward(self, *args): verify_model(Clamp3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_ones(): torch.set_grad_enabled(False) @@ -2168,6 +2271,7 @@ def forward(self, *args): verify_model(Ones1().float().eval(), input_data=[]) +@tvm.testing.uses_gpu def test_forward_ones_like(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2190,6 +2294,7 @@ def forward(self, *args): verify_model(OnesLike3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_zeros(): torch.set_grad_enabled(False) @@ -2200,6 +2305,7 @@ def forward(self, *args): verify_model(Zeros1().float().eval(), input_data=[]) +@tvm.testing.uses_gpu def test_forward_zeros_like(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2222,6 +2328,7 @@ def forward(self, *args): verify_model(ZerosLike3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_full(): torch.set_grad_enabled(False) @@ -2237,6 +2344,7 @@ def forward(self, *args): verify_model(Full2().float().eval(), input_data=[]) +@tvm.testing.uses_gpu def test_forward_full_like(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2258,6 +2366,7 @@ def forward(self, *args): verify_model(FullLike2().float().eval(), input_data=input_data) verify_model(FullLike3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_linspace(): torch.set_grad_enabled(False) @@ -2296,6 +2405,7 @@ def forward(self, *args): verify_model(Linspace8().float().eval()) +@tvm.testing.uses_gpu def test_forward_take(): torch.set_grad_enabled(False) class Take1(Module): @@ -2315,6 +2425,7 @@ def forward(self, *args): verify_model(Take2().float().eval(), input_data=[input_data, indices]) +@tvm.testing.uses_gpu def test_forward_topk(): torch.set_grad_enabled(False) class Topk1(Module): @@ -2351,6 +2462,7 @@ def forward(self, *args): verify_model(Topk6().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_logical_not(): torch.set_grad_enabled(False) @@ -2371,6 +2483,7 @@ def forward(self, *args): verify_model(LogicalNot1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_bitwise_not(): torch.set_grad_enabled(False) @@ -2388,6 +2501,7 @@ def forward(self, *args): verify_model(BitwiseNot1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_bitwise_xor(): torch.set_grad_enabled(False) @@ -2414,6 +2528,7 @@ def forward(self, *args): verify_model(BitwiseXor2().float().eval(), input_data=[lhs]) +@tvm.testing.uses_gpu def test_forward_logical_xor(): torch.set_grad_enabled(False) @@ -2440,6 +2555,7 @@ def forward(self, *args): verify_model(LogicalXor2().float().eval(), input_data=[lhs]) +@tvm.testing.uses_gpu def test_forward_unary(): torch.set_grad_enabled(False) @@ -2562,6 +2678,7 @@ def forward(self, *args): verify_model(Neg1().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu def test_forward_where(): torch.set_grad_enabled(False) @@ -2582,6 +2699,7 @@ def forward(self, *args): verify_model(Where2().float().eval(), input_data=[x, y]) +@tvm.testing.uses_gpu def test_forward_addcdiv(): torch.set_grad_enabled(False) @@ -2605,6 +2723,7 @@ def forward(self, *args): verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2]) +@tvm.testing.uses_gpu def test_forward_addcmul(): torch.set_grad_enabled(False) @@ -2627,6 +2746,7 @@ def forward(self, *args): t2 = torch.rand([1, 3]).float() verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2]) +@tvm.testing.uses_gpu def test_forward_traced_function(): def fn(t1, t2): return t1 + t2 @@ -2635,6 +2755,7 @@ def fn(t1, t2): tensor2 = torch.randn(3, 4) verify_model(fn, input_data=[tensor1, tensor2]) +@tvm.testing.uses_gpu def test_forward_dtypes(): def fn(t1, t2): return 2.5 * t1 + t2 @@ -2658,12 +2779,14 @@ def forward(self, x): verify_model(ModuleWithIntParameters(param), input_data=inp) +@tvm.testing.uses_gpu def test_weight_names(): tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)]) mod, params = relay.frontend.from_pytorch(tm, [('input', (2, 3))]) assert set(params.keys()) == set(n for n, p in tm.named_parameters()) +@tvm.testing.uses_gpu def test_duplicate_weight_use(): # The test cases doesn't make any sense as a neural network, # the issue popped up in shared input/output embeddings of bert, @@ -2681,6 +2804,7 @@ def forward(self, x): verify_model(Test(), input_data=[torch.randn(5, 5)]) +@tvm.testing.uses_gpu def test_forward_matmul(): torch.set_grad_enabled(False) diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py index e80d774408a3..010899cabc74 100644 --- a/tests/python/frontend/tensorflow/test_bn_dynamic.py +++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py @@ -45,7 +45,7 @@ def verify_fused_batch_norm(shape): for device in ["llvm"]: ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue mod, params = relay.frontend.from_tensorflow(constant_graph, diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 799d9c20058a..37a32be34699 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -49,6 +49,8 @@ from tvm.runtime.vm import VirtualMachine from packaging import version as package_version +import tvm.testing + ####################################################################### # Generic run functions for TVM & tensorflow # ------------------------------------------ @@ -198,7 +200,7 @@ def name_without_num(name): for device in ["llvm", "cuda"]: ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue if no_gpu and device == 'cuda': @@ -262,6 +264,7 @@ def _test_pooling(input_shape, **kwargs): _test_pooling_iteration(input_shape, **kwargs) +@tvm.testing.uses_gpu def test_forward_pooling(): """ Pooling """ # TensorFlow only supports NDHWC for max_pool3d on CPU @@ -408,6 +411,7 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, 'Placeholder:0', 'DepthwiseConv2dNative:0') +@tvm.testing.uses_gpu def test_forward_convolution(): if is_gpu_available(): _test_convolution('conv', [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW') @@ -526,6 +530,7 @@ def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes, compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), 'Placeholder:0', 'Conv3D:0', cuda_layout="NCDHW") +@tvm.testing.uses_gpu def test_forward_convolution3d(): if is_gpu_available(): _test_convolution3d('conv', [4, 176, 8, 8, 8], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW') @@ -569,6 +574,7 @@ def _test_convolution3d_transpose(data_shape, filter_shape, strides, compare_tf_with_tvm(data_array, 'Placeholder:0', 'conv3d_transpose:0', cuda_layout="NDHWC") +@tvm.testing.uses_gpu def test_forward_convolution3d_transpose(): if is_gpu_available(): _test_convolution3d_transpose(data_shape=[1, 10, 8, 8, 8], @@ -655,6 +661,7 @@ def _test_biasadd(tensor_in_sizes, data_format): 'Placeholder:0', 'BiasAdd:0') +@tvm.testing.uses_gpu def test_forward_biasadd(): if is_gpu_available(): _test_biasadd([4, 176, 8, 8], 'NCHW') @@ -1230,7 +1237,8 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) -def test_read_variable_op(): +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_read_variable_op(target, ctx): """ Read Variable op test """ tf.reset_default_graph() @@ -1270,18 +1278,12 @@ def test_read_variable_op(): out_node, ) - for device in ["llvm", "cuda"]: - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - continue - - tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, - target=device, out_names=out_name, - num_output=len(out_name)) - for i in range(len(tf_output)): - tvm.testing.assert_allclose( - tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5) + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, + target=target, out_names=out_name, + num_output=len(out_name)) + for i in range(len(tf_output)): + tvm.testing.assert_allclose( + tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5) sess.close() @@ -2382,6 +2384,7 @@ def test_forward_mobilenet(): # -------- +@tvm.testing.requires_gpu def test_forward_resnetv2(): '''test resnet model''' if is_gpu_available(): @@ -2399,7 +2402,7 @@ def test_forward_resnetv2(): sess, data, 'input_tensor:0', out_node + ':0') for device in ["llvm", "cuda"]: ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), @@ -2431,7 +2434,7 @@ def _test_ssd_impl(): # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready. for device in ["llvm"]: ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue tvm_output = run_tvm_graph(graph_def, data, in_node, len(out_node), @@ -3754,7 +3757,7 @@ def test_forward_dynamic_input_shape(): # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready. for device in ["llvm"]: ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue tvm_output = run_tvm_graph(graph_def, np_data, ["data"], 1, diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 5496dfe96241..3e950003b6bf 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -278,7 +278,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, for device in ["llvm"]: ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py index 27f3788fef5c..0bfe61a1de6c 100644 --- a/tests/python/integration/test_dot.py +++ b/tests/python/integration/test_dot.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te import numpy as np +@tvm.testing.requires_llvm def test_dot(): nn = 12 n = tvm.runtime.convert(nn) @@ -29,9 +31,6 @@ def test_dot(): s = te.create_schedule(C.op) def verify(target): - if not tvm.runtime.enabled(target): - print("Target %s is not enabled" % target) - return f = tvm.driver.build(s, [A, B, C], target) # verify ctx = tvm.cpu(0) diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index dfa247e5a09a..d2fb503c3975 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -19,7 +19,9 @@ from tvm.contrib import nvcc import numpy as np import time +import tvm.testing +@tvm.testing.requires_gpu def test_exp(): # graph n = tvm.runtime.convert(1024) @@ -34,11 +36,9 @@ def test_exp(): # one line to build the function. def check_device(device, host="stackvm"): - if not tvm.runtime.enabled(host): + if not tvm.testing.device_enabled(host): return ctx = tvm.context(device, 0) - if not ctx.exist: - return fexp = tvm.build(s, [A, B], device, host, name="myexp") @@ -55,6 +55,7 @@ def check_device(device, host="stackvm"): check_device("cuda", "llvm") check_device("vulkan") +@tvm.testing.requires_gpu def test_fmod(): # graph def run(dtype): @@ -69,7 +70,7 @@ def run(dtype): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return target = tvm.target.create(device) @@ -102,6 +103,7 @@ def check_device(device): run("float32") +@tvm.testing.requires_gpu def test_multiple_cache_write(): # graph n = tvm.runtime.convert(1024) @@ -123,10 +125,10 @@ def test_multiple_cache_write(): s[C].bind(tx, te.thread_axis("threadIdx.x")) # one line to build the function. def check_device(device, host="stackvm"): - if not tvm.runtime.enabled(host): + if not tvm.testing.device_enabled(host): return ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): return func = tvm.build(s, [A0, A1, C], device, host, @@ -155,7 +157,7 @@ def test_log_pow_llvm(): # create iter var and assign them tags. bx, tx = s[B].split(B.op.axis[0], factor=32) # one line to build the function. - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return flog = tvm.build(s, [A, B], @@ -173,6 +175,7 @@ def test_log_pow_llvm(): b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5) +@tvm.testing.uses_gpu def test_popcount(): def run(dtype): # graph @@ -186,7 +189,7 @@ def run(dtype): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return target = tvm.target.create(device) @@ -212,6 +215,7 @@ def check_device(device): run('uint64') +@tvm.testing.requires_gpu def test_add(): def run(dtype): # graph @@ -235,7 +239,7 @@ def run(dtype): # one line to build the function. def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return fadd = tvm.build(s, [A, B, C], @@ -264,6 +268,7 @@ def check_device(device): run("uint64") +@tvm.testing.requires_gpu def try_warp_memory(): """skip this in default test because it require higher arch""" m = 128 @@ -289,7 +294,7 @@ def tvm_callback_cuda_compile(code): # one line to build the function. def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return f = tvm.build(s, [A, B], device) diff --git a/tests/python/integration/test_ewise_fpga.py b/tests/python/integration/test_ewise_fpga.py index 7883a4cc4dce..abcddc452910 100644 --- a/tests/python/integration/test_ewise_fpga.py +++ b/tests/python/integration/test_ewise_fpga.py @@ -40,11 +40,9 @@ def test_exp(): # one line to build the function. def check_device(device, host="llvm"): - if not tvm.runtime.enabled(host): + if not tvm.testing.device_enabled(device): return ctx = tvm.context(device, 0) - if not ctx.exist: - return fexp = tvm.build(s, [A, B], device, host, name="myexp") @@ -79,11 +77,9 @@ def test_multi_kernel(): # one line to build the function. def check_device(device, host="llvm"): - if not tvm.runtime.enabled(host): + if not tvm.testing.device_enabled(device): return ctx = tvm.context(device, 0) - if not ctx.exist: - return fadd = tvm.build(s, [A, B, C, D], device, host, name="myadd") diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index 12026da61394..1b7d54e2177f 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -18,8 +18,10 @@ from tvm import te import numpy as np import time +import tvm.testing +@tvm.testing.requires_gpu def test_gemm(): # graph nn = 1024 @@ -82,7 +84,7 @@ def test_gemm(): # one line to build the function. def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 67f6dcf9ce8c..35980ed602c1 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -17,8 +17,10 @@ import tvm from tvm import te import numpy as np +import tvm.testing +@tvm.testing.requires_gpu def test_reduce_prims(): def test_prim(reducer, np_reducer): # graph @@ -40,9 +42,7 @@ def test_prim(reducer, np_reducer): # one line to build the function. def check_device(device, host="llvm"): ctx = tvm.context(device, 0) - if not tvm.runtime.enabled(host): - return - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return freduce = tvm.build(s, @@ -140,7 +140,7 @@ def test_rfactor(): s[BF].parallel(BF.op.axis[0]) # one line to build the function. def check_target(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return ctx = tvm.cpu(0) fapi = tvm.lower(s, args=[A, B]) @@ -204,7 +204,7 @@ def test_rfactor_factor_axis(): s[BF].parallel(BF.op.axis[0]) # one line to build the function. def check_target(target="llvm"): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return ctx = tvm.cpu(0) fapi = tvm.lower(s, args=[A, B]) @@ -223,6 +223,7 @@ def check_target(target="llvm"): check_target() +@tvm.testing.requires_gpu def test_rfactor_threads(): nn = 1027 mm = 10 @@ -248,7 +249,7 @@ def test_rfactor_threads(): # one line to build the function. def check_target(device, host="stackvm"): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return @@ -273,6 +274,7 @@ def check_target(device, host="stackvm"): check_target("opencl") check_target("rocm") +@tvm.testing.requires_gpu def test_rfactor_elemwise_threads(): n = 1025 m = 10 @@ -303,7 +305,7 @@ def test_rfactor_elemwise_threads(): # one line to build the function. def check_target(device, host="stackvm"): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return fapi = tvm.lower(s, args=[A, C]) @@ -346,7 +348,7 @@ def fidentity(t0, t1): def check_target(): device = 'cpu' - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return ctx = tvm.context(device, 0) @@ -371,6 +373,7 @@ def check_target(): check_target() +@tvm.testing.requires_gpu def test_rfactor_argmax(): def fcombine(x, y): lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) @@ -409,7 +412,7 @@ def fidentity(t0, t1): def check_target(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return fapi = tvm.lower(s, args=[A0, A1, B0, B1]) @@ -432,6 +435,7 @@ def check_target(device): check_target("vulkan") check_target("rocm") +@tvm.testing.requires_gpu def test_warp_reduction1(): nthx = 32 nthy = 4 @@ -441,7 +445,7 @@ def test_warp_reduction1(): def check_target(device, m, n): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return @@ -478,6 +482,7 @@ def check_target(device, m, n): # This is a bug in normal reduction. # check_target("cuda", m=10, n=37) +@tvm.testing.requires_gpu def test_warp_reduction2(): def fcombine(x, y): return x[0] + y[0], x[1] * y[1] @@ -503,7 +508,7 @@ def fidentity(t0, t1): def check_target(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return diff --git a/tests/python/integration/test_scan.py b/tests/python/integration/test_scan.py index 99553c3579d5..9a61e6086d3c 100644 --- a/tests/python/integration/test_scan.py +++ b/tests/python/integration/test_scan.py @@ -17,7 +17,9 @@ import tvm from tvm import te import numpy as np +import tvm.testing +@tvm.testing.requires_gpu def test_scan(): m = te.size_var("m") n = te.size_var("n") @@ -47,7 +49,7 @@ def test_scan(): # one line to build the function. def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return fscan = tvm.build(s, [X, res], diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 95b94f693b39..5f45119b57bd 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -26,6 +26,8 @@ from tvm import autotvm from tvm.autotvm.tuner import RandomTuner +import tvm.testing + @autotvm.template("testing/conv2d_no_batching") def conv2d_no_batching(N, H, W, CI, CO, KH, KW): """An example template for testing""" @@ -120,26 +122,18 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None): target=target, target_host=target_host) return task, target -def test_tuning(): - def check(target, target_host): - ctx = tvm.context(target, 0) - if not ctx.exist: - logging.info("Skip test because %s is not available" % target) - return - - # init task - task, target = get_sample_task(target, target_host) - logging.info("%s", task.config_space) - - measure_option = autotvm.measure_option( - autotvm.LocalBuilder(), - autotvm.LocalRunner()) +@tvm.testing.parametrize_targets("cuda", "opencl") +def test_tuning(target, ctx): + # init task + task, target = get_sample_task(target, None) + logging.info("%s", task.config_space) - tuner = RandomTuner(task) - tuner.tune(n_trial=20, measure_option=measure_option) + measure_option = autotvm.measure_option( + autotvm.LocalBuilder(), + autotvm.LocalRunner()) - check("cuda", None) - check("opencl", None) + tuner = RandomTuner(task) + tuner.tune(n_trial=20, measure_option=measure_option) if __name__ == "__main__": # only print log when invoked from main diff --git a/tests/python/integration/test_winograd_nnpack.py b/tests/python/integration/test_winograd_nnpack.py index 994a047df742..e6841ddc8132 100644 --- a/tests/python/integration/test_winograd_nnpack.py +++ b/tests/python/integration/test_winograd_nnpack.py @@ -25,6 +25,7 @@ import tvm.topi.testing from tvm.topi.util import get_const_tuple from pytest import skip +import tvm.testing def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, @@ -60,8 +61,8 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: - skip("s is not enabled" % device) + if not tvm.testing.device_enabled(device): + print("Skipping %s becuase it is not enabled" % device) print("Running on target: %s" % device) with tvm.target.create(device): C = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NCHW', out_dtype=dtype) diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py index d4b55f14100b..ada9a96d3d95 100644 --- a/tests/python/nightly/quantization/test_quantization_accuracy.py +++ b/tests/python/nightly/quantization/test_quantization_accuracy.py @@ -23,6 +23,7 @@ from mxnet import gluon import logging import os +import tvm.testing logging.basicConfig(level=logging.INFO) @@ -112,6 +113,7 @@ def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5) return top1 +@tvm.testing.requires_gpu def test_quantize_acc(cfg, rec_val): qconfig = qtz.qconfig(skip_conv_layers=[0], nbit_input=cfg.nbit_input, diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index 95a030f5b8ae..0097a4eed9dc 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -22,11 +22,13 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type import tvm.topi.testing import random +import tvm.testing +@tvm.testing.uses_gpu def test_dyn_broadcast_to(): dtype = 'uint8' rank = 3 @@ -44,7 +46,7 @@ def test_dyn_broadcast_to(): x = np.random.uniform(size=x_shape).astype(dtype) dyn_shape = (1, ) * rank ref_res = np.broadcast_to(x, dyn_shape) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) @@ -53,6 +55,7 @@ def test_dyn_broadcast_to(): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_dyn_one_hot(): def _get_oshape(indices_shape, depth, axis): oshape = [] @@ -77,7 +80,7 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): func = relay.Function([indices, depth_var], out) indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index e1a0d284d9bf..bab4869b3fe0 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from tvm import te -from tvm.relay.testing import ctx_list +from tvm.relay.testing import enabled_targets import random from test_dynamic_op_level3 import verify_func import tvm.topi.testing @@ -51,7 +51,7 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa zz = run_infer_type(z) func = relay.Function([x, scale_h_var, scale_w_var], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 91e9cc77fe99..193de85a5242 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -22,11 +22,12 @@ from tvm import te from tvm import relay from tvm.relay import create_executor, transform -from tvm.relay.testing import ctx_list, check_grad, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type +import tvm.testing def verify_func(func, data, ref_res): assert isinstance(data, list) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): #TODO(mbrookhart): enable Cuda tests onces the VM supports dynamic shapes if "llvm" not in target: continue for kind in ["vm", "debug"]: @@ -36,6 +37,7 @@ def verify_func(func, data, ref_res): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) relay.backend.compile_engine.get().clear() +@tvm.testing.uses_gpu def test_dyn_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -60,6 +62,7 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20)) verify_reshape((2, 3, 4), (0, -3), (2, 12)) +@tvm.testing.uses_gpu def test_dyn_shape_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -76,6 +79,7 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +@tvm.testing.uses_gpu def test_dyn_tile(): def verify_tile(dshape, reps): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -92,6 +96,7 @@ def verify_tile(dshape, reps): verify_tile((2, 3), (3, 2, 1)) +@tvm.testing.uses_gpu def test_dyn_zeros_ones(): def verify_zeros_ones(shape, dtype): for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: @@ -107,6 +112,7 @@ def verify_zeros_ones(shape, dtype): verify_zeros_ones((1, 3), 'int64') verify_zeros_ones((8, 9, 1, 2), 'float32') +@tvm.testing.uses_gpu def test_dyn_full(): def verify_full(fill_value, src_shape, dtype): x = relay.var("x", relay.scalar_type(dtype)) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 8dcfd1fd5778..226bbfe2678e 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -22,8 +22,9 @@ from tvm import te from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type import tvm.topi.testing +import tvm.testing def test_resize_infer_type(): @@ -35,6 +36,7 @@ def test_resize_infer_type(): assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") +@tvm.testing.uses_gpu def test_resize(): def verify_resize(dshape, scale, method, layout): if layout == "NHWC": @@ -57,7 +59,7 @@ def verify_resize(dshape, scale, method, layout): zz = run_infer_type(z) func = relay.Function([x, size_var], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py index ddfab552ed83..6dcde953710d 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level6.py +++ b/tests/python/relay/dyn/test_dynamic_op_level6.py @@ -21,8 +21,9 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import ctx_list +import tvm.testing +@tvm.testing.uses_gpu def test_dynamic_topk(): def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) @@ -51,7 +52,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): np_values[i, :] = np_data[i, np_indices[i, :]] np_indices = np_indices.astype(dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index 6bc170d2b5af..906882d774fd 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -23,6 +23,7 @@ from tvm import topi from tvm.relay.testing import run_infer_type from tvm.relay.testing.temp_op_attr import TempOpAttr +import tvm.testing @autotvm.register_topi_compute("test/conv2d_1") @@ -161,14 +162,14 @@ def get_func(shape): z3 = engine.lower(get_func(()), "llvm") assert z1.same_as(z2) assert not z3.same_as(z1) - if tvm.context("cuda").exist: + if tvm.testing.device_enabled("cuda"): z4 = engine.lower(get_func(()), "cuda") assert not z3.same_as(z4) # Test JIT target for target in ["llvm"]: ctx = tvm.context(target) - if ctx.exist: + if tvm.testing.device_enabled(target): f = engine.jit(get_func((10,)), target) x = tvm.nd.array(np.ones(10).astype("float32"), ctx=ctx) y = tvm.nd.empty((10,), ctx=ctx) diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index f0785bcf1c09..70a6fb13de73 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -20,7 +20,7 @@ from tvm import relay from tvm.contrib import graph_runtime from tvm.relay.op import add -from tvm.relay.testing.config import ctx_list +import tvm.testing # @tq, @jr should we put this in testing ns? def check_rts(expr, args, expected_result, mod=None): @@ -141,6 +141,7 @@ def test_plan_memory(): assert len(device_types) == 1 +@tvm.testing.uses_gpu def test_gru_like(): def unit(rnn_dim): X = relay.var("X", shape=(1, rnn_dim)) @@ -165,7 +166,7 @@ def unit_numpy(X, W): out_shape = (1, rnn_dim) z = unit(rnn_dim) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): with tvm.transform.PassContext(opt_level=2): graph, lib, params = relay.build(tvm.IRModule.from_expr(z), target) m = graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 360b6bd20416..41a07e4dac51 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -30,7 +30,7 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): # TODO(tqchen) add more types once the schedule register is fixed. for target in ["llvm"]: ctx = tvm.context(target, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(target): return intrp = create_executor(mod=mod, ctx=ctx, target=target) result = intrp.evaluate(expr)(*args) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index fa56eb0eef29..faf68674028a 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -20,6 +20,7 @@ from tvm import te from tvm import relay from tvm.contrib.nvcc import have_fp16 +import tvm.testing def test_basic_build(): @@ -64,13 +65,10 @@ def test_basic_build(): atol=1e-5, rtol=1e-5) +@tvm.testing.requires_cuda def test_fp16_build(): dtype = "float16" - if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: - print("skip because cuda is not enabled.") - return - ctx = tvm.gpu(0) if dtype == "float16" and not have_fp16(ctx.compute_version): print("skip because gpu does not support fp16") @@ -100,40 +98,34 @@ def test_fp16_build(): atol=1e-5, rtol=1e-5) -def test_fp16_conversion(): - def check_conversion(tgt, ctx): - if not tvm.runtime.enabled(tgt): - print("skip because {} is not enabled.".format(tgt)) - return - elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version): - print("skip because gpu does not support fp16") - return - - n = 10 +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_fp16_conversion(target, ctx): + if target == "cuda" and not have_fp16(ctx.compute_version): + print("skip because gpu does not support fp16") + return - for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]: - x = relay.var("x", relay.TensorType((n,), src)) - y = x.astype(dst) - func = relay.Function([x], y) + n = 10 - # init input - X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2) + for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]: + x = relay.var("x", relay.TensorType((n,), src)) + y = x.astype(dst) + func = relay.Function([x], y) - # build - with tvm.transform.PassContext(opt_level=1): - g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), tgt) + # init input + X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2) - # test - rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) - rt.set_input("x", X) - rt.run() - out = rt.get_output(0) + # build + with tvm.transform.PassContext(opt_level=1): + g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), target) - np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst), - atol=1e-5, rtol=1e-5) + # test + rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) + rt.set_input("x", X) + rt.run() + out = rt.get_output(0) - for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]: - check_conversion(target, ctx) + np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst), + atol=1e-5, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 437901ee95fc..3847e18c7c68 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -20,8 +20,9 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import check_grad, ctx_list, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type from tvm.relay.transform import gradient +import tvm.testing def sigmoid(x): @@ -35,6 +36,7 @@ def relu(x): return x_copy +@tvm.testing.uses_gpu def test_unary_op(): def check_single_op(opfunc, ref, dtype): shape = (10, 4) @@ -49,7 +51,7 @@ def check_single_op(opfunc, ref, dtype): fwd_func = run_infer_type(fwd_func) bwd_func = run_infer_type(gradient(fwd_func)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) @@ -79,6 +81,7 @@ def check_single_op(opfunc, ref, dtype): check_single_op(opfunc, ref, dtype) +@tvm.testing.uses_gpu def test_binary_op(): def inst(vars, sh): return [vars.get(s, s) for s in sh] @@ -97,7 +100,7 @@ def check_binary_op(opfunc, ref, dtype): fwd_func = run_infer_type(fwd_func) bwd_func = run_infer_type(gradient(fwd_func)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data) np.testing.assert_allclose(op_grad0.asnumpy(), ref_grad0, rtol=0.01) diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 50e358564fbb..396e43dad1ec 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -21,8 +21,9 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import check_grad, ctx_list, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type from tvm.relay.transform import gradient +import tvm.testing def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode): @@ -43,12 +44,13 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode): padding=[ph, pw, ph, pw], pool_type='max', ceil_mode=ceil_mode) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) +@tvm.testing.uses_gpu def test_max_pool2d_grad(): verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False) verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False) @@ -72,11 +74,12 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun padding=[ph, pw, ph, pw], pool_type='avg', ceil_mode=ceil_mode) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) +@tvm.testing.uses_gpu def test_avg_pool2d_grad(): verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False, count_include_pad=True) @@ -100,11 +103,12 @@ def verify_global_avg_pool2d_grad(x_shape): strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg', ceil_mode=False) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) +@tvm.testing.uses_gpu def test_global_avg_pool2d_grad(): verify_global_avg_pool2d_grad((1, 4, 16, 16)) verify_global_avg_pool2d_grad((1, 8, 8, 24)) @@ -139,7 +143,7 @@ def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mod .detach().numpy() - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): data = tvm.nd.array(data_pt.detach().numpy(), ctx) weight = tvm.nd.array(weight_pt.detach().numpy(), ctx) intrp = relay.create_executor(ctx=ctx, target=target) @@ -148,6 +152,7 @@ def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mod np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4) +@tvm.testing.uses_gpu def test_conv2d_grad(): verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1]) verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1]) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 8ca1eaea5ed2..a63ec6ee0902 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -20,10 +20,12 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import check_grad, ctx_list, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type from tvm.relay.transform import gradient +import tvm.testing +@tvm.testing.uses_gpu def test_clip(): for dtype in ('float32', 'float64'): ref = (lambda x: np.where(x > 10.0, np.zeros_like(x), @@ -37,7 +39,7 @@ def test_clip(): fwd_func = run_infer_type(fwd_func) bwd_func = run_infer_type(gradient(fwd_func)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor(ctx=ctx, target=target) op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 4616a14dcdde..086a880ab9ed 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -21,9 +21,10 @@ import scipy from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type import tvm.topi.testing from tvm.contrib.nvcc import have_fp16 +import tvm.testing def sigmoid(x): @@ -39,6 +40,7 @@ def rsqrt(x): one = np.ones_like(x) return one / np.sqrt(x) +@tvm.testing.uses_gpu def test_unary_op(): def check_single_op(opfunc, ref, dtype): shape = (10, 4) @@ -56,7 +58,7 @@ def check_single_op(opfunc, ref, dtype): data = np.random.rand(*shape).astype(dtype) ref_res = ref(data) func = relay.Function([x], y) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): # use graph by execuor default for testing, as we need # create function explicitly to avoid constant-folding. if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): @@ -82,6 +84,7 @@ def check_single_op(opfunc, ref, dtype): check_single_op(opfunc, ref, dtype) +@tvm.testing.uses_gpu def test_binary_op(): def inst(vars, sh): return [vars.get(s, s) for s in sh] @@ -112,7 +115,7 @@ def check_binary_op(opfunc, ref, dtype): ref_res = ref(x_data, y_data) func = relay.Function([x, y], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): # use graph by execuor default for testing, as we need # create function explicitly to avoid constant-folding. if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): @@ -131,12 +134,13 @@ def check_binary_op(opfunc, ref, dtype): check_binary_op(opfunc, ref, dtype) +@tvm.testing.uses_gpu def test_expand_dims(): # based on topi test def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): x = relay.Var("x", relay.TensorType(dshape, dtype)) func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): continue data = np.random.uniform(size=dshape).astype(dtype) @@ -149,6 +153,7 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): verify_expand_dims((3, 10), dtype, (1, 3, 10), -3, 1) +@tvm.testing.uses_gpu def test_bias_add(): for dtype in ['float16', 'float32']: xshape=(10, 2, 3, 4) @@ -165,7 +170,7 @@ def test_bias_add(): x_data = np.random.uniform(size=xshape).astype(dtype) y_data = np.random.uniform(size=bshape).astype(dtype) ref_res = x_data + y_data.reshape((2, 1, 1)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): continue intrp = relay.create_executor("graph", ctx=ctx, target=target) @@ -183,6 +188,7 @@ def test_expand_dims_infer_type(): assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype) +@tvm.testing.uses_gpu def test_softmax(): for dtype in ['float16', 'float32']: # Softmax accuracy for float16 is poor @@ -197,12 +203,13 @@ def test_softmax(): func = relay.Function([x], y) x_data = np.random.uniform(size=shape).astype(dtype) ref_res = tvm.topi.testing.softmax_python(x_data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_log_softmax(): for dtype in ['float16', 'float32']: # Softmax accuracy for float16 is poor @@ -217,12 +224,13 @@ def test_log_softmax(): func = relay.Function([x], y) x_data = np.random.uniform(size=shape).astype(dtype) ref_res = tvm.topi.testing.log_softmax_python(x_data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_concatenate(): for dtype in ['float16', 'float32']: n, t, d = te.size_var("n"), te.size_var("t"), 100 @@ -266,7 +274,7 @@ def test_concatenate(): t_data = np.random.uniform(size=()).astype(dtype) ref_res = np.concatenate((x_data, y_data), axis=1) + t_data - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): continue intrp1 = relay.create_executor("graph", ctx=ctx, target=target) @@ -345,6 +353,7 @@ def test_dense_type_check(): y = relay.nn.dense(x, w) yy = run_infer_type(y) +@tvm.testing.uses_gpu def test_dense(): for dtype in ['float16', 'float32']: # Dense accuracy for float16 is poor @@ -383,7 +392,7 @@ def test_dense(): w_data = np.random.rand(2, 5).astype(dtype) ref_res = np.dot(x_data, w_data.T) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data, w_data) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index a65b17fe1ad8..3aaa76d771d3 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -22,11 +22,13 @@ import tvm.topi.testing from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type from tvm import topi import tvm.topi.testing +import tvm.testing +@tvm.testing.uses_gpu def test_checkpoint(): dtype = "float32" xs = [relay.var("x{}".format(i), dtype) for i in range(4)] @@ -38,7 +40,7 @@ def test_checkpoint(): assert f.checked_type == f_checkpoint.checked_type inputs = [np.random.uniform() for _ in range(len(xs))] - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) f_res = intrp.evaluate(f)(*inputs) @@ -148,6 +150,7 @@ def test_checkpoint_alpha_equal_tuple(): tvm.ir.assert_structural_equal(df, df_parsed) +@tvm.testing.uses_gpu def test_collapse_sum_like(): shape = (3, 4, 5, 6) shape_like = (4, 5, 6) @@ -162,13 +165,14 @@ def test_collapse_sum_like(): x = np.random.uniform(size=shape).astype(dtype) y = np.random.uniform(size=shape_like).astype(dtype) ref_res = np.sum(x, 0) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x, y) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_collapse_sum_to(): shape = (3, 4, 5, 6) shape_to = (4, 5, 6) @@ -181,13 +185,14 @@ def test_collapse_sum_to(): func = relay.Function([x], z) x = np.random.uniform(size=shape).astype(dtype) ref_res = np.sum(x, 0) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_broadcast_to(): shape = (4, 1, 6) shape_like = (3, 4, 5, 6) @@ -200,12 +205,13 @@ def test_broadcast_to(): func = relay.Function([x], z) x = np.random.uniform(size=shape).astype(dtype) ref_res = np.broadcast_to(x, shape_like) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_broadcast_to_like(): shape = (4, 1, 6) shape_like = (3, 4, 5, 6) @@ -222,7 +228,7 @@ def test_broadcast_to_like(): y = np.random.uniform(size=shape_like).astype(dtype) ref_res = np.broadcast_to(x, shape_like) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x, y) @@ -266,12 +272,13 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): y_data = np.random.uniform(size=slice_like).astype(dtype) ref_res = np_slice_like(x_data, y_data, axes) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_slice_like(): d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3)) @@ -286,6 +293,7 @@ def test_slice_like(): axes=(2, 3), output=(1, 3, 112, 112)) +@tvm.testing.uses_gpu def test_reverse_reshape(): def verify_reverse_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -297,7 +305,7 @@ def verify_reverse_reshape(shape, newshape, oshape): func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") ref_res = np.reshape(x_data, oshape) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -320,12 +328,13 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): y_np = np.random.uniform(size=y_shape).astype(dtype) z_np = tvm.topi.testing.batch_matmul(x_np, y_np) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) z = intrp.evaluate(func)(x_np, y_np) tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_batch_matmul(): b, m, n, k = te.size_var("b"), te.size_var("m"), te.size_var("n"), te.size_var("k") x = relay.var("x", relay.TensorType((b, m, k), "float32")) @@ -339,13 +348,14 @@ def test_batch_matmul(): verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) +@tvm.testing.uses_gpu def test_shape_of(): shape = (10, 5, 12) x = relay.var("x", shape=shape) func = relay.Function([x], relay.op.shape_of(x)) func = run_infer_type(func) x_data = np.random.rand(*shape).astype('float32') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): # Because using graph executor, this op will be optimized after # constant folding pass, here we only test with interpreter for kind in ["debug"]: @@ -354,6 +364,7 @@ def test_shape_of(): tvm.testing.assert_allclose(op_res.asnumpy(), np.array(shape).astype('int32')) +@tvm.testing.uses_gpu def test_ndarray_size(): def verify_ndarray_size(shape): x = relay.var("x", shape=shape) @@ -362,7 +373,7 @@ def verify_ndarray_size(shape): x_data = np.random.uniform(size=shape).astype("float32") ref_res = np.size(x_data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -380,7 +391,7 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc): np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) relay_out = intrp1.evaluate(func)(np_data) tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5) @@ -396,6 +407,7 @@ def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="fl verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc) +@tvm.testing.uses_gpu def test_adaptive_pool(): verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max") verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg") @@ -409,6 +421,7 @@ def test_adaptive_pool(): verify_adaptive_pool3d((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC") +@tvm.testing.uses_gpu def test_sequence_mask(): def _verify(data_shape, mask_value, axis, dtype, itype): max_length = data_shape[axis] @@ -423,7 +436,7 @@ def _verify(data_shape, mask_value, axis, dtype, itype): valid_length_np = np.random.randint(0, max_length, size=nbatch).astype(itype) gt_out_np = tvm.topi.testing.sequence_mask(data_np, valid_length_np, mask_value, axis) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) out_relay = intrp.evaluate(func)(data_np, valid_length_np) @@ -432,6 +445,7 @@ def _verify(data_shape, mask_value, axis, dtype, itype): _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64') _verify((5, 8, 3), 0.1, 1, 'float64', 'float32') +@tvm.testing.uses_gpu def test_one_hot(): def _get_oshape(indices_shape, depth, axis): oshape = [] @@ -458,7 +472,7 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) out_relay = intrp.evaluate(func)(indices_np) @@ -471,6 +485,7 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") +@tvm.testing.uses_gpu def test_matrix_set_diag(): def _verify(input_shape, dtype): diagonal_shape = list(input_shape[:-2]) @@ -488,7 +503,7 @@ def _verify(input_shape, dtype): diagonal_np = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) out_np = tvm.topi.testing.matrix_set_diag(input_np, diagonal_np) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) out_relay = intrp.evaluate(func)(input_np, diagonal_np) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 6258d8c9aaf8..93eecfc0ee8b 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -22,12 +22,14 @@ from tvm import autotvm from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type from tvm.contrib import util import tvm.topi.testing from tvm.topi.cuda.conv3d_winograd import _infer_tile_size +import tvm.testing +@tvm.testing.uses_gpu def test_conv1d_infer_type(): # symbolic in batch dimension n, c, w = te.var("n"), 10, 224 @@ -78,6 +80,7 @@ def test_conv1d_infer_type(): (n, w, 16), "int32") +@tvm.testing.uses_gpu def test_conv1d_run(): def run_test_conv1d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), @@ -100,9 +103,10 @@ def run_test_conv1d(dtype, out_dtype, scale, dshape, kshape, ref_res = tvm.topi.testing.conv1d_ncw_python( data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, dilation) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target in except_targets: continue + ctx = tvm.context(target, 0) intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -122,6 +126,7 @@ def run_test_conv1d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), channels=10, kernel_size=3, dilation=3) +@tvm.testing.uses_gpu def test_conv2d_infer_type(): # symbolic in batch dimension n, c, h, w = te.size_var("n"), 10, 224, 224 @@ -189,6 +194,7 @@ def test_conv2d_infer_type(): (n, h, w, 16), "int32") +@tvm.testing.uses_gpu def test_conv2d_run(): def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), @@ -219,9 +225,10 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target in except_targets: continue + ctx = tvm.context(target, 0) intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4, atol=1e-4) @@ -314,6 +321,7 @@ def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape, run_test_conv2d("float32", "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3 ,3), dilation=(3, 3)) +@tvm.testing.uses_gpu def test_conv2d_winograd(): class WinogradFallback(autotvm.FallbackContext): def _query_inside(self, target, workload): @@ -357,9 +365,10 @@ def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape, groups=groups) with WinogradFallback(), tvm.transform.PassContext(opt_level=3): - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target != 'cuda': continue + ctx = tvm.context(target, 0) params = {'w': tvm.nd.array(kernel)} graph, lib, params = relay.build_module.build(mod, target=target, params=params) module = tvm.contrib.graph_runtime.create(graph, lib, ctx) @@ -385,6 +394,7 @@ def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(7, 7)) +@tvm.testing.uses_gpu def test_conv3d_infer_type(): # symbolic in batch dimension n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 @@ -435,6 +445,7 @@ def test_conv3d_infer_type(): (n, d, h, w, 16), "int32") +@tvm.testing.uses_gpu def test_conv3d_run(): def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1, 1), @@ -465,9 +476,10 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target in except_targets: continue + ctx = tvm.context(target, 0) intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) @@ -479,6 +491,7 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, run_test_conv3d("float32", "float32", 1, dshape, kshape, padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3)) +@tvm.testing.uses_gpu def test_conv3d_ndhwc_run(): def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1, 1), @@ -509,9 +522,10 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target in except_targets: continue + ctx = tvm.context(target, 0) intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) @@ -523,6 +537,7 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, run_test_conv3d("float32", "float32", 1, dshape, kshape, padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3), except_targets=["cuda"]) +@tvm.testing.uses_gpu def test_conv3d_winograd(): class WinogradFallback(autotvm.FallbackContext): def _query_inside(self, target, workload): @@ -579,9 +594,10 @@ def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape, groups=groups) with WinogradFallback(), tvm.transform.PassContext(opt_level=3): - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target != 'cuda': continue + ctx = tvm.context(target, 0) params = {'w': tvm.nd.array(kernel)} graph, lib, params = relay.build_module.build(mod, target=target, params=params) module = tvm.contrib.graph_runtime.create(graph, lib, ctx) @@ -612,6 +628,7 @@ def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape, padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5)) +@tvm.testing.uses_gpu def test_conv3d_transpose_infer_type(): # symbolic in batch dimension n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 @@ -649,6 +666,7 @@ def test_conv3d_transpose_infer_type(): (n, 12, 226, 226, 226), "int32") +@tvm.testing.uses_gpu def test_conv3d_transpose_ncdhw_run(): dshape = (1, 3, 24, 24, 24) kshape = (3, 4, 2, 2, 2) @@ -665,12 +683,13 @@ def test_conv3d_transpose_ncdhw_run(): kernel = np.random.uniform(size=kshape).astype(dtype) ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1, 0) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_conv2d_transpose_infer_type(): # symbolic in batch dimension n, c, h, w = te.size_var("n"), 10, 10, 12 @@ -700,6 +719,7 @@ def test_conv2d_transpose_infer_type(): (n, 15, 15, 11), "float32") +@tvm.testing.uses_gpu def test_conv2d_transpose_nchw_run(): dshape = (1, 3, 18, 18) kshape = (3, 10, 3, 3) @@ -716,12 +736,13 @@ def test_conv2d_transpose_nchw_run(): ref_res = tvm.topi.testing.conv2d_transpose_nchw_python( data, kernel, 2, 1, (1, 1)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_conv2d_transpose_nhwc_run(): dshape_nhwc = (1, 18, 18, 3) kshape_hwoi = (3, 3, 10, 3) @@ -743,12 +764,13 @@ def test_conv2d_transpose_nhwc_run(): ref_res = tvm.topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1, output_padding=(1, 1)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_conv1d_transpose_ncw_run(): dshape = (1, 3, 18) kshape = (3, 10, 3) @@ -765,12 +787,13 @@ def test_conv1d_transpose_ncw_run(): ref_res = tvm.topi.testing.conv1d_transpose_ncw_python( data, kernel, 2, 1, output_padding=(1,)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, kernel) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_upsampling_infer_type(): n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") scale = tvm.tir.const(2.0, "float64") @@ -787,6 +810,7 @@ def test_upsampling_infer_type(): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") +@tvm.testing.uses_gpu def test_upsampling3d_infer_type(): n, c, d, h, w = te.size_var("n"), te.size_var("c"),\ te.size_var("d"), te.size_var("h"), te.size_var("w") @@ -820,7 +844,7 @@ def _test_pool2d(opfunc, reffunc, pool_size=(2, 2), strides=(2, 2), padding=(0, func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -840,7 +864,7 @@ def _test_pool2d_int(opfunc, reffunc, dtype): func = relay.Function([x], y) data = np.random.randint(low=-128, high=128, size=dshape) ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5)).astype(dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -865,12 +889,13 @@ def _test_global_pool2d(opfunc, reffunc): func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data, axis=(2,3), keepdims=True) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_pool2d(): _test_pool2d(relay.nn.max_pool2d, np.max) _test_pool2d(relay.nn.max_pool2d, np.max, pool_size=2, strides=2, padding=0) @@ -882,6 +907,7 @@ def test_pool2d(): _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean) +@tvm.testing.uses_gpu def test_pool1d(): def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)): @@ -901,7 +927,7 @@ def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)): data = np.random.uniform(size=dshape).astype(dtype) ref_res = tvm.topi.testing.pool1d_ncw_python(data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -912,6 +938,7 @@ def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)): _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0) +@tvm.testing.uses_gpu def test_pool3d(): def _test_pool3d(opfunc, @@ -939,7 +966,7 @@ def _test_pool3d(opfunc, data = np.random.uniform(size=dshape).astype(dtype) ref_res = tvm.topi.testing.pool3d_ncdhw_python(data, pool_size, strides, padding, out_shape, pool_type, False) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -956,6 +983,7 @@ def _test_pool3d(opfunc, _test_pool3d(relay.nn.avg_pool3d, pool_size=2, padding=0, strides=2) +@tvm.testing.uses_gpu def test_avg_pool2d_no_count_pad(): kh, kw = (4, 4) sh, sw = (2, 2) @@ -985,11 +1013,12 @@ def test_avg_pool2d_no_count_pad(): ref_res = np.maximum(b_np, 0.0) data = a_np - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_flatten_infer_type(): d1, d2, d3, d4 = te.size_var("d1"), te.size_var("d2"), te.size_var("d3"), te.size_var("d4") x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32")) @@ -1018,7 +1047,7 @@ def test_flatten_infer_type(): x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) ref_res = x_data.flatten().reshape(o_shape) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data) @@ -1026,6 +1055,7 @@ def test_flatten_infer_type(): op_res2 = intrp2.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_pad_infer_type(): # entirely concrete case n, c, h, w = 1, 2, 3, 4 @@ -1042,6 +1072,7 @@ def test_pad_infer_type(): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32") +@tvm.testing.uses_gpu def test_pad_run(): def _test_run(dtype): dshape = (4, 10, 7, 7) @@ -1050,7 +1081,7 @@ def _test_run(dtype): func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) ref_res = np.pad(data, ((1, 1), (2, 2), (3, 3), (4, 4)), 'constant') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -1058,6 +1089,7 @@ def _test_run(dtype): _test_run('float32') _test_run('int32') +@tvm.testing.uses_gpu def test_lrn(): n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", shape=(n, c , h, w)) @@ -1081,7 +1113,7 @@ def test_lrn(): x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) ref_res = tvm.topi.testing.lrn_python(x_data, size, axis, bias, alpha, beta) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data) @@ -1089,6 +1121,7 @@ def test_lrn(): op_res2 = intrp2.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_l2_normalize(): n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", shape=(n, c , h, w)) @@ -1109,7 +1142,7 @@ def test_l2_normalize(): x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) ref_res = tvm.topi.testing.l2_normalize_python(x_data, eps, axis) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data) @@ -1126,6 +1159,7 @@ def batch_flatten(data): return np.reshape(data, (shape[0], target_dim)) +@tvm.testing.uses_gpu def test_batch_flatten(): t1 = relay.TensorType((5, 10, 5)) x = relay.Var("x", t1) @@ -1133,7 +1167,7 @@ def test_batch_flatten(): data = np.random.rand(5, 10, 5).astype(t1.dtype) ref_res = batch_flatten(data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) @@ -1166,12 +1200,13 @@ def get_shape(): else: ref = tvm.topi.testing.bilinear_resize_python(data, (int(round(h*scale_h)), int(round(w*scale_w))), layout) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", ctx=ctx, target=target) out = executor.evaluate(func)(data) tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_upsampling(): _test_upsampling("NCHW", "nearest_neighbor") _test_upsampling("NCHW", "bilinear", True) @@ -1212,17 +1247,19 @@ def get_shape(): ref = tvm.topi.testing.trilinear_resize3d_python(data, (int(round(d*scale_d)),\ int(round(h*scale_h)),\ int(round(w*scale_w))), layout) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", ctx=ctx, target=target) out = executor.evaluate(func)(data) tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu def test_upsampling3d(): _test_upsampling3d("NCDHW", "nearest_neighbor") _test_upsampling3d("NCDHW", "trilinear", "align_corners") _test_upsampling3d("NDHWC", "nearest_neighbor") _test_upsampling3d("NDHWC", "trilinear", "align_corners") +@tvm.testing.uses_gpu def test_conv2d_int8_intrinsics(): def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): input_dtype, weight_dtype, output_dtype = dtypes @@ -1347,6 +1384,7 @@ def _has_fast_int8_instructions(asm, target): assert "vpmulld" in asm and "vpadd" in asm +@tvm.testing.uses_gpu def test_depthwise_conv2d_int8(): input_dtype = 'uint8' weight_dtype = 'int8' @@ -1376,6 +1414,7 @@ def test_depthwise_conv2d_int8(): graph, lib, params = relay.build(func, target, params=parameters) +@tvm.testing.uses_gpu def test_bitserial_conv2d_infer_type(): # Basic shape test with ambiguous batch. n, c, h, w = te.size_var("n"), 32, 224, 224 @@ -1388,6 +1427,7 @@ def test_bitserial_conv2d_infer_type(): (n, 32, 222, 222), "int16") +@tvm.testing.uses_gpu def test_bitpack_infer_type(): # Test axis packing shape inference. o, i, h, w = 32, 32, 128, 128 @@ -1400,6 +1440,7 @@ def test_bitpack_infer_type(): # TODO(@jwfromm): Need to add bitserial_conv2d & bitpack run test cases +@tvm.testing.uses_gpu def test_correlation(): def _test_correlation(data_shape, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, dtype='float32'): data1 = relay.var("data1", relay.ty.TensorType(data_shape, dtype)) @@ -1422,7 +1463,7 @@ def _test_correlation(data_shape, kernel_size, max_displacement, stride1, stride data2_np = np.random.uniform(size=data_shape).astype(dtype) ref_res = tvm.topi.testing.correlation_nchw_python(data1_np, data2_np, kernel_size, max_displacement, stride1, stride2, padding, is_multiply) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data1_np, data2_np) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 745130da0296..940bb70c53a0 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -23,7 +23,8 @@ from tvm import relay from tvm.error import TVMError from tvm.relay import create_executor, transform -from tvm.relay.testing import ctx_list, check_grad, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type +import tvm.testing def test_zeros_ones(): @@ -199,6 +200,7 @@ def test_transpose_infer_type(): (100, t, n), "float32") +@tvm.testing.uses_gpu def test_transpose(): def verify_transpose(dshape, axes): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -208,7 +210,7 @@ def verify_transpose(dshape, axes): x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") ref_res = np.transpose(x_data, axes=axes) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -250,6 +252,7 @@ def test_reshape_infer_type(): assert yy.checked_type == relay.TensorType( (n, t, 2000), "float32") +@tvm.testing.uses_gpu def test_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -262,7 +265,7 @@ def verify_reshape(shape, newshape, oshape): check_grad(func) x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") ref_res = np.reshape(x_data, oshape) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -307,6 +310,7 @@ def test_reshape_like_infer_type(): assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") +@tvm.testing.uses_gpu def test_reshape_like(): def verify_reshape_like(shape, oshape): x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") @@ -321,7 +325,7 @@ def verify_reshape_like(shape, oshape): func = relay.Function([x, y], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) @@ -347,6 +351,7 @@ def verify_take(dshape, indices_shape, oshape, axis=None): verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1) verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) +@tvm.testing.uses_gpu def test_take(): def verify_take(src_shape, indices_src, axis=None, mode="clip"): src_dtype = "float32" @@ -361,7 +366,7 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): np_mode = "raise" if mode == "fast" else mode ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, indices_src) @@ -448,13 +453,14 @@ def test_full_infer_type(): assert yy.checked_type == relay.TensorType((1, 2), "int8") +@tvm.testing.uses_gpu def test_full(): def verify_full(fill_value, src_shape, dtype): x = relay.var("x", relay.scalar_type(dtype)) z = relay.full(x, src_shape, dtype) func = relay.Function([x], z) ref_res = np.full(src_shape, fill_value) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(np.array(fill_value, dtype)) @@ -481,6 +487,7 @@ def test_full_like_infer_type(): assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") +@tvm.testing.uses_gpu def test_full_like(): def verify_full_like(base, fill_value, dtype): x_data = np.random.uniform(low=-1, high=1, size=base).astype(dtype) @@ -491,7 +498,7 @@ def verify_full_like(base, fill_value, dtype): func = relay.Function([x, y], z) ref_res = np.full_like(x_data, fill_value) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, np.array(fill_value, dtype)) @@ -500,6 +507,7 @@ def verify_full_like(base, fill_value, dtype): verify_full_like((1, 1), 44.0, "float32") +@tvm.testing.uses_gpu def test_infer_type_leaky_relu(): n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) @@ -519,7 +527,7 @@ def test_infer_type_leaky_relu(): x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) ref_res = np.where(x_data > 0, x_data, x_data * 0.1) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data) @@ -555,7 +563,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): else: ref_res = (x_data < 0) * (x_data * a_data.reshape(1, 1, 3)) + (x_data>=0) * x_data - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data, a_data) @@ -564,6 +572,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_infer_type_prelu(): n, c , h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w)) @@ -576,6 +585,7 @@ def test_infer_type_prelu(): verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3)) +@tvm.testing.uses_gpu def test_arange(): def verify_arange(start, stop, step): dtype = "float32" @@ -596,7 +606,7 @@ def verify_arange(start, stop, step): ref_res = np.arange(start, stop, step).astype(dtype) func = relay.Function([], x) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)() @@ -613,6 +623,7 @@ def verify_arange(start, stop, step): # arange doesnt' support floating point right now, see type relation # verify_arange(20, 1, -1.5) +@tvm.testing.uses_gpu def test_meshgrid(): def verify_meshgrid(lengths, indexing="ij"): input_vars = [] @@ -632,7 +643,7 @@ def verify_meshgrid(lengths, indexing="ij"): # Get ref ref_res = np.meshgrid(*input_data, indexing=indexing) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(*input_data) @@ -646,6 +657,7 @@ def verify_meshgrid(lengths, indexing="ij"): # Length 0 signifies scalar. verify_meshgrid([3, 5, 0]) +@tvm.testing.uses_gpu def test_tile(): def verify_tile(dshape, reps): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -655,7 +667,7 @@ def verify_tile(dshape, reps): x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") ref_res = np.tile(x_data, reps=reps) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -664,13 +676,14 @@ def verify_tile(dshape, reps): verify_tile((2, 3, 4), (1, 2)) verify_tile((2, 3), (3, 2, 1)) +@tvm.testing.uses_gpu def test_repeat(): def verify_repeat(dshape, repeats, axis): x = relay.Var("x", relay.TensorType(dshape, "float32")) func = relay.Function([x], relay.repeat(x, repeats, axis)) data = np.random.uniform(size=dshape).astype("float32") ref_res = np.repeat(data, repeats, axis) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data) @@ -679,6 +692,7 @@ def verify_repeat(dshape, repeats, axis): verify_repeat((3, 10), 2, -1) verify_repeat((3, 2, 4), 3, 1) +@tvm.testing.uses_gpu def test_stack(): def verify_stack(dshapes, axis): y = [] @@ -691,7 +705,7 @@ def verify_stack(dshapes, axis): x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes] ref_res = np.stack(x_data, axis=axis) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(*x_data) @@ -702,6 +716,7 @@ def verify_stack(dshapes, axis): verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1) +@tvm.testing.uses_gpu def test_reverse(): def verify_reverse(dshape, axis): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -711,7 +726,7 @@ def verify_reverse(dshape, axis): func = relay.Function([x], z) x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") ref_res = np.flip(x_data, axis) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -721,6 +736,7 @@ def verify_reverse(dshape, axis): verify_reverse((2, 3, 4), -1) +@tvm.testing.uses_gpu def test_reverse_sequence(): def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res): seq_lengths_data = np.array(seq_lengths).astype("int32") @@ -730,7 +746,7 @@ def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res): assert zz.checked_type == x.type_annotation func = relay.Function([x], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -881,6 +897,7 @@ def verify_scatter_add(dshape, ishape, axis=0): verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3) +@tvm.testing.uses_gpu def test_gather(): def verify_gather(data, axis, indices, ref_res): data = np.asarray(data, dtype='float32') @@ -893,7 +910,7 @@ def verify_gather(data, axis, indices, ref_res): func = relay.Function([d, i], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data, indices) @@ -933,6 +950,7 @@ def verify_gather(data, axis, indices, ref_res): [-0.5700, 0.1558, -0.5700, 0.1558]]]) +@tvm.testing.uses_gpu def test_gather_nd(): def verify_gather_nd(xshape, yshape, y_data): x = relay.var("x", relay.TensorType(xshape, "float32")) @@ -943,7 +961,7 @@ def verify_gather_nd(xshape, yshape, y_data): x_data = np.random.uniform(size=xshape).astype("float32") ref_res = x_data[tuple(y_data)] - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) @@ -981,6 +999,7 @@ def test_isinf(): _verify_infiniteness_ops(relay.isinf, np.isinf) +@tvm.testing.uses_gpu def test_unravel_index(): def verify_unravel_index(indices, shape, dtype): x_data = np.array(indices).astype(dtype) @@ -999,7 +1018,7 @@ def verify_unravel_index(indices, shape, dtype): func = relay.Function([x, y], z) ref_res = np.unravel_index(x_data, y_data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) @@ -1017,6 +1036,7 @@ def verify_unravel_index(indices, shape, dtype): # output which is inline with Tensorflow # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype) +@tvm.testing.uses_gpu def test_sparse_to_dense(): def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected): sparse_indices_data = np.array(sparse_indices) @@ -1037,7 +1057,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ assert zz.checked_type == relay.ty.TensorType(output_shape, str(sparse_values_data.dtype)) func = relay.Function(args, d) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) if default_value is None: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 8e01fa2a89cd..af3826448edb 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -19,10 +19,12 @@ import numpy as np from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type import tvm.topi.testing +import tvm.testing +@tvm.testing.uses_gpu def test_binary_op(): def check_binary_op(opfunc, ref): n = te.size_var("n") @@ -47,7 +49,7 @@ def check_binary_op(opfunc, ref): ref_res = ref(x_data, y_data) func = relay.Function([x, y], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) @@ -56,6 +58,7 @@ def check_binary_op(opfunc, ref): check_binary_op(opfunc, ref) +@tvm.testing.uses_gpu def test_cmp_type(): for op, ref in ((relay.greater, np.greater), (relay.greater_equal, np.greater_equal), @@ -82,12 +85,13 @@ def test_cmp_type(): ref_res = ref(x_data, y_data) func = relay.Function([x, y], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) +@tvm.testing.uses_gpu def test_binary_int_broadcast_1(): for op, ref in [(relay.right_shift, np.right_shift), (relay.left_shift, np.left_shift)]: @@ -107,11 +111,12 @@ def test_binary_int_broadcast_1(): func = relay.Function([x, y], z) ref_res = ref(x_data, y_data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) +@tvm.testing.uses_gpu def test_binary_int_broadcast_2(): for op, ref in [(relay.maximum, np.maximum), (relay.minimum, np.minimum), @@ -132,11 +137,12 @@ def test_binary_int_broadcast_2(): func = relay.Function([x, y], z) ref_res = ref(x_data, y_data) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) +@tvm.testing.uses_gpu def test_where(): shape = (3, 4) dtype = "float32" @@ -152,7 +158,7 @@ def test_where(): x = np.random.uniform(size=shape).astype(dtype) y = np.random.uniform(size=shape).astype(dtype) ref_res = np.where(condition, x, y) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(condition, x, y) @@ -195,7 +201,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") return ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data) @@ -203,6 +209,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") op_res2 = intrp2.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_reduce_functions(): def _with_keepdims(func): def _wrapper(data, axis=None, keepdims=False): @@ -282,7 +289,7 @@ def verify_mean_var_std(funcs, shape, axis, keepdims): ref_mean = np.mean(x_data, axis=axis, dtype=dtype, keepdims=keepdims) ref_res = ref_func(x_data, axis=axis, dtype=dtype, keepdims=keepdims) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data) @@ -292,6 +299,7 @@ def verify_mean_var_std(funcs, shape, axis, keepdims): tvm.testing.assert_allclose(op_res2[0].asnumpy(), ref_mean, rtol=1e-5) tvm.testing.assert_allclose(op_res2[1].asnumpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu def test_mean_var_std(): for func in [[relay.mean_variance, np.var], [relay.mean_std, np.std]]: @@ -307,6 +315,7 @@ def test_mean_var_std(): verify_mean_var_std(func, (128, 24, 128), (0, 2), True) +@tvm.testing.uses_gpu def test_strided_slice(): def verify(dshape, begin, end, strides, output, slice_mode="end", attr_const=True, test_ref=True, dtype="int32"): @@ -349,7 +358,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", if not test_ref: return - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) @@ -371,6 +380,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) +@tvm.testing.uses_gpu def test_strided_set(): def verify(dshape, begin, end, strides, vshape, test_ref=True): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -394,7 +404,7 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): v_data = np.random.uniform(size=vshape).astype("float32") ref_res = tvm.topi.testing.strided_set_python( x_data, v_data, begin, end, strides) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, v_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 254bab5e1692..25e9ac0ce5a8 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -22,8 +22,9 @@ from tvm import te from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type import tvm.topi.testing +import tvm.testing def test_resize_infer_type(): @@ -40,6 +41,7 @@ def test_resize_infer_type(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") +@tvm.testing.uses_gpu def test_resize(): def verify_resize(dshape, scale, method, layout, coord_trans): if layout == "NHWC": @@ -61,7 +63,7 @@ def verify_resize(dshape, scale, method, layout, coord_trans): assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -87,7 +89,8 @@ def test_resize3d_infer_type(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8") -def test_resize3d(): +@tvm.testing.parametrize_targets +def test_resize3d(target, ctx): def verify_resize(dshape, scale, method, layout): if layout == "NDHWC": size = (dshape[1] * scale, dshape[2] * scale, dshape[3] * scale) @@ -106,15 +109,15 @@ def verify_resize(dshape, scale, method, layout): assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, ctx in ctx_list(): - for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x_data) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4) + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4) for method in ["trilinear", "nearest_neighbor"]: for layout in ["NDHWC", "NCDHW"]: verify_resize((1, 4, 4, 4, 4), 2, method, layout) +@tvm.testing.uses_gpu def test_crop_and_resize(): def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size, layout, method, extrapolation_value=0.0): @@ -138,7 +141,7 @@ def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size, assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([img, bx, bx_idx], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(image_data, boxes, box_indices) @@ -157,6 +160,7 @@ def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size, verify_crop_and_resize((5, 3, 255, 255), boxes_nchw, indices_nchw, size_nchw, 'NCHW', method, 0.1) +@tvm.testing.uses_gpu def test_multibox_prior(): def get_ref_result(dshape, sizes=(1.0,), ratios=(1.0,), steps=(-1.0, -1.0), @@ -213,7 +217,7 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") func = relay.Function([x], z) func = run_infer_type(func) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) @@ -242,6 +246,7 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True) +@tvm.testing.uses_gpu def test_get_valid_counts(): def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" @@ -271,7 +276,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): assert "score_threshold" in z.astext() func = relay.Function([x], z.astuple()) func = run_infer_type(func) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("debug", ctx=ctx, target=target) out = intrp.evaluate(func)(np_data) tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) @@ -287,6 +292,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): verify_get_valid_counts((16, 500, 5), 0.95, -1, 0) +@tvm.testing.uses_gpu def test_non_max_suppression(): def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res, ref_indices_res, iou_threshold=0.5, force_suppress=False, @@ -319,7 +325,7 @@ def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res, func = run_infer_type(func) func_indices = relay.Function([x0, x1, x2, x3], z_indices) func_indices = run_infer_type(func_indices) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) @@ -366,6 +372,7 @@ def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res, np_indices_result, top_k=2) +@tvm.testing.uses_gpu def test_multibox_transform_loc(): def test_default_value(): num_anchors = 3 @@ -408,7 +415,7 @@ def test_default_value(): nms = relay.vision.non_max_suppression(mtl[0], mtl[1], mtl[0], return_indices=False) func = relay.Function([cls_prob, loc_pred, anchors], nms) func = run_infer_type(func) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) @@ -450,6 +457,7 @@ def test_threshold(): test_threshold() +@tvm.testing.uses_gpu def test_roi_align(): def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio): data = relay.var("data", relay.ty.TensorType(data_shape, "float32")) @@ -471,7 +479,7 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ ref_res = tvm.topi.testing.roi_align_nchw_python(np_data, np_rois, pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=sample_ratio) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(np_data, np_rois) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4) @@ -483,6 +491,7 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2) +@tvm.testing.uses_gpu def test_roi_pool(): def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale): data = relay.var("data", relay.ty.TensorType(data_shape, "float32")) @@ -502,7 +511,7 @@ def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale): np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi).astype('float32') ref_res = tvm.topi.testing.roi_pool_nchw_python(np_data, np_rois, pooled_size=pooled_size, spatial_scale=spatial_scale) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(np_data, np_rois) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4) @@ -514,6 +523,7 @@ def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale): verify_roi_pool((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5) +@tvm.testing.uses_gpu def test_proposal(): def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): cls_prob = relay.var("cls_prob", relay.ty.TensorType(np_cls_prob.shape, "float32")) @@ -526,7 +536,7 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): func = relay.Function([cls_prob, bbox_pred, im_info], z) func = run_infer_type(func) for target in ['llvm', 'cuda']: - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): print("Skip test because %s is not enabled." % target) continue ctx = tvm.context(target, 0) @@ -592,6 +602,7 @@ def verify_yolo_reorg(shape, stride, out_shape): verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2))) +@tvm.testing.uses_gpu def test_yolo_reorg(): def verify_yolo_reorg(shape, stride): x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") @@ -605,7 +616,7 @@ def verify_yolo_reorg(shape, stride): func = relay.Function([x], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -615,6 +626,7 @@ def verify_yolo_reorg(shape, stride): verify_yolo_reorg((1, 4, 6, 6), 2) +@tvm.testing.uses_gpu def test_deformable_conv2d(): def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups): data_shape = (batch, in_channel, size, size) @@ -665,7 +677,7 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): kernel = np.random.uniform(size=kernel_shape).astype(dtype) ref_res = tvm.topi.testing.deformable_conv2d_nchw_python(data, offset, kernel, stride=(1, 1), padding=(1, 1), dilation=(1, 1), deformable_groups=deformable_groups, groups=groups) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp1 = relay.create_executor(kind, ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data, offset, kernel) @@ -674,6 +686,7 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): test_run(2, 4, 16, 4, 4, 1) +@tvm.testing.uses_gpu def test_depth_to_space(): def verify_depth_to_space(dshape, block_size, layout, mode): if layout == "NHWC": @@ -696,7 +709,7 @@ def verify_depth_to_space(dshape, block_size, layout, mode): assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -706,6 +719,7 @@ def verify_depth_to_space(dshape, block_size, layout, mode): verify_depth_to_space((1, 4, 4, 4), 2, layout, mode) +@tvm.testing.uses_gpu def test_space_to_depth(): def verify_space_to_depth(dshape, block_size, layout): if layout == "NHWC": @@ -728,7 +742,7 @@ def verify_space_to_depth(dshape, block_size, layout): assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -753,6 +767,7 @@ def test_dilation2d_infer_type(): (n, 10, 217, 217), "float32") +@tvm.testing.uses_gpu def test_dilation2d_run(): def run_test_dilation2d(indata, kernel, out, dtype='float32', @@ -777,7 +792,7 @@ def run_test_dilation2d(indata, kernel, out, **attrs) func = relay.Function([x, w], y) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if target in except_targets: continue intrp = relay.create_executor("graph", ctx=ctx, target=target) @@ -844,6 +859,7 @@ def _convert_data(indata, kernel, out, layout=None): data_layout='NHWC', kernel_layout='HWI') +@tvm.testing.uses_gpu def test_affine_grid(): def verify_affine_grid(num_batch, target_shape): dtype = 'float32' @@ -857,7 +873,7 @@ def verify_affine_grid(num_batch, target_shape): data_np = np.random.uniform(size=data_shape).astype(dtype) ref_res = tvm.topi.testing.affine_grid_python(data_np, target_shape) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp1 = relay.create_executor(kind, ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data_np) @@ -867,6 +883,7 @@ def verify_affine_grid(num_batch, target_shape): verify_affine_grid(4, (16, 32)) +@tvm.testing.uses_gpu def test_grid_sample(): def verify_grid_sample(data_shape, grid_shape): dtype = 'float32' @@ -883,7 +900,7 @@ def verify_grid_sample(data_shape, grid_shape): grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) ref_res = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, method='bilinear') - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp1 = relay.create_executor(kind, ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(data_np, grid_np) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 287e80a0fab7..e683224f1811 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -20,8 +20,9 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import ctx_list +import tvm.testing +@tvm.testing.uses_gpu def test_argsort(): def verify_argsort(shape, axis, is_ascend, dtype): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -33,7 +34,7 @@ def verify_argsort(shape, axis, is_ascend, dtype): else: ref_res = np.argsort(-x_data, axis=axis) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) @@ -44,6 +45,7 @@ def verify_argsort(shape, axis, is_ascend, dtype): verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype) +@tvm.testing.uses_gpu def test_topk(): def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) @@ -70,7 +72,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): np_values[i, :] = np_data[i, np_indices[i, :]] np_indices = np_indices.astype(dtype) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(np_data) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 85dd2edb9185..0e0ab570ec10 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -21,8 +21,9 @@ from tvm import relay from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type import numpy as np +import tvm.testing def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] @@ -615,6 +616,7 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +@tvm.testing.uses_gpu def test_alter_layout_strided_slice(): """Test rewriting strided_slice during alter_iop_layout""" def before(): @@ -661,7 +663,7 @@ def expected(): mod_before['main'] = a mod_new['main'] = b with relay.build_config(opt_level=3): - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug", "vm"]: ex_before = relay.create_executor(kind, mod=mod_before, ctx=ctx, target=target) ex_new = relay.create_executor(kind, mod=mod_new, ctx=ctx, target=target) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index e95708007b95..7a2ff55790a7 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -23,7 +23,7 @@ from tvm.contrib import graph_runtime from tvm.relay.expr_functor import ExprMutator from tvm.relay import transform - +import tvm.testing def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] @@ -624,22 +624,33 @@ def expected(): tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) -def test_check_run(): - for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"), - ("opencl", str(tvm.target.intel_graphics()))]: - if not tvm.runtime.enabled(dev): - print("Skip test because %s is not enabled." % dev) - continue - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) +@tvm.testing.requires_opencl +def test_check_run_opencl(): + dev = "opencl" + tgt = "opencl" + run_fusible_network(dev, tgt) + run_unpropagatable_graph(dev, tgt) -def test_tuple_get_item(): +@tvm.testing.requires_opencl +def test_check_run_opencl_intel(): + dev = "opencl" + tgt = str(tvm.target.intel_graphics()) + run_fusible_network(dev, tgt) + run_unpropagatable_graph(dev, tgt) + + +@tvm.testing.requires_cuda +def test_check_run_cuda(): dev = "cuda" - if not tvm.runtime.enabled(dev): - print("Skip test because %s is not enabled." % dev) - return + tgt = "cuda" + run_fusible_network(dev, tgt) + run_unpropagatable_graph(dev, tgt) + +@tvm.testing.requires_cuda +def test_tuple_get_item(): + dev = "cuda" cpu_ctx = tvm.cpu(0) gpu_ctx = tvm.context(dev) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 6b422ca0d594..453c469d2c07 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -20,8 +20,9 @@ from tvm import relay from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from tvm.relay.testing import run_infer_type, create_workload, ctx_list +from tvm.relay.testing import run_infer_type, create_workload import tvm.topi.testing +import tvm.testing def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) @@ -34,7 +35,7 @@ def run_opt_pass(expr, opt_pass): def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7): assert isinstance(data, list) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) @@ -42,6 +43,7 @@ def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) +@tvm.testing.uses_gpu def test_dynamic_to_static_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -66,6 +68,7 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +@tvm.testing.uses_gpu def test_dynamic_to_static_double_reshape(): def verify_reshape(shape, newshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -90,6 +93,7 @@ def verify_reshape(shape, newshape): verify_reshape((4, 7), (2, 7, 2)) +@tvm.testing.uses_gpu def test_dynamic_to_static_quad_reshape(): def verify_reshape(shape, newshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -116,6 +120,7 @@ def verify_reshape(shape, newshape): verify_reshape((4, 7), (2, 7, 2)) +@tvm.testing.uses_gpu def test_dynamic_to_static_tile(): def verify_tile(shape, reps, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -139,6 +144,7 @@ def verify_tile(shape, reps, oshape): verify_tile((4, 7), (4, 2), (16, 14)) +@tvm.testing.uses_gpu def test_dynamic_to_static_topk(): def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) @@ -173,7 +179,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): assert isinstance(zz, relay.Call) assert zz.op == relay.op.get("topk") - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if "llvm" not in target: continue for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func2) @@ -195,6 +201,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): verify_topk(k, axis, ret_type, False, "float32") +@tvm.testing.uses_gpu def test_dynamic_to_static_broadcast_to(): def verify_broadcast_to(shape, broadcast_shape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -219,6 +226,7 @@ def verify_broadcast_to(shape, broadcast_shape): verify_broadcast_to((3, 1), (3, 3)) +@tvm.testing.uses_gpu def test_dynamic_to_static_zeros_ones(): def verify_ones_zeros(shape, dtype): for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: @@ -241,6 +249,7 @@ def verify_ones_zeros(shape, dtype): verify_ones_zeros((9, 8, 3, 4), 'float32') +@tvm.testing.uses_gpu def test_dynamic_to_static_resize(): def verify_resize(shape, scale, method, layout): if layout == "NHWC": @@ -275,6 +284,7 @@ def verify_resize(shape, scale, method, layout): verify_resize((1, 4, 4, 4), 2, method, layout) +@tvm.testing.uses_gpu def test_dynamic_to_static_one_hot(): def _verify(indices_shape, depth, on_value, off_value, axis, dtype): indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) @@ -302,6 +312,7 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") +@tvm.testing.uses_gpu def test_dynamic_to_static_full(): def verify_full(fill_value, fill_shape, dtype): x = relay.var("x", relay.scalar_type(dtype)) @@ -310,7 +321,7 @@ def verify_full(fill_value, fill_shape, dtype): func = run_infer_type(relay.Function([x, y], z)) func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) - + zz = func2.body assert isinstance(zz, relay.Call) assert zz.op == relay.op.get("full") @@ -318,7 +329,7 @@ def verify_full(fill_value, fill_shape, dtype): ref_res = np.full(fill_shape, fill_value).astype(dtype) y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64') verify_func(func2, [fill_value, y_data], ref_res) - + verify_full(4, (1, 2, 3, 4), 'int32') verify_full(4.0, (1, 2, 8, 10), 'float32') diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 90e80d8e673f..df30eb41c3f0 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -19,6 +19,7 @@ from tvm import relay from tvm.relay import transform from tvm.relay.testing import run_opt_pass +import tvm.testing def test_fuse_simple(): @@ -704,6 +705,7 @@ def expected(): assert tvm.ir.structural_equal(m["main"], after) +@tvm.testing.uses_gpu def test_fuse_bcast_reduce_scalar(): """Test fusion case with broadcast and reduction involving scalar""" @@ -726,7 +728,7 @@ def expected(): orig = before() m = fuse2(tvm.IRModule.from_expr(orig)) - for tgt, _ in tvm.relay.testing.config.ctx_list(): + for tgt, ctx in tvm.testing.enabled_targets(): relay.build(m, tgt) after = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(m["main"], after) diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index 377164e08b73..4a09e4ef25c8 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -20,6 +20,7 @@ from tvm import relay from tvm.relay import create_executor, transform from tvm.relay.testing import rand, run_infer_type +import tvm.testing from tvm.testing import assert_allclose import pytest diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 25299caae30b..9245bbde3544 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -25,7 +25,8 @@ from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform -from tvm.relay.testing import ctx_list, run_infer_type +from tvm.relay.testing import run_infer_type +import tvm.testing def get_var_func(): @@ -114,6 +115,7 @@ def check_func(func, ref_func): assert tvm.ir.structural_equal(func, ref_func) +@tvm.testing.uses_gpu def test_module_pass(): shape = (5, 10) dtype = 'float32' @@ -178,7 +180,7 @@ def test_pass_run(): x_nd = get_rand(shape, dtype) y_nd = get_rand(shape, dtype) ref_res = x_nd.asnumpy() + y_nd.asnumpy() - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): exe1 = relay.create_executor("graph", ctx=ctx, target=target) exe2 = relay.create_executor("debug", ctx=ctx, target=target) res1 = exe1.evaluate(new_add)(x_nd, y_nd) @@ -214,6 +216,7 @@ def transform_function(self, func, mod, ctx): assert tvm.ir.structural_equal(mod["main"], mod2["main"]) +@tvm.testing.uses_gpu def test_function_pass(): shape = (10, ) dtype = 'float32' @@ -271,7 +274,7 @@ def test_pass_run(): # Execute the add function. x_nd = get_rand(shape, dtype) ref_res = np.log(x_nd.asnumpy() * 2) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): exe1 = relay.create_executor("graph", ctx=ctx, target=target) exe2 = relay.create_executor("debug", ctx=ctx, target=target) res1 = exe1.evaluate(new_log)(x_nd) @@ -314,6 +317,7 @@ def test_pass_info(): assert info.name == "xyz" +@tvm.testing.uses_gpu def test_sequential_pass(): shape = (10, ) dtype = 'float32' @@ -433,7 +437,7 @@ def test_multiple_passes(): x_nd = get_rand(shape, dtype) y_nd = get_rand(shape, dtype) ref_res = np.subtract(x_nd.asnumpy() * 2, y_nd.asnumpy() * 2) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): exe1 = relay.create_executor("graph", ctx=ctx, target=target) exe2 = relay.create_executor("debug", ctx=ctx, target=target) res1 = exe1.evaluate(new_sub)(x_nd, y_nd) @@ -444,7 +448,7 @@ def test_multiple_passes(): # Execute the updated abs function. x_nd = get_rand((5, 10), dtype) ref_res = np.abs(x_nd.asnumpy() * 2) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): exe1 = relay.create_executor("graph", ctx=ctx, target=target) exe2 = relay.create_executor("debug", ctx=ctx, target=target) res1 = exe1.evaluate(new_abs)(x_nd) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a69f928bae58..710025aeadb3 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -21,10 +21,10 @@ from tvm import runtime from tvm import relay from tvm.relay.scope_builder import ScopeBuilder -from tvm.relay.testing.config import ctx_list from tvm.relay.prelude import Prelude from tvm.relay.loops import while_loop from tvm.relay import testing +import tvm.testing def check_result(args, expected_result, mod=None): """ @@ -41,7 +41,7 @@ def check_result(args, expected_result, mod=None): """ # TODO(@zhiics, @icemelon9): Disable the gpu test for now until the heterogeneous support # is ready - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): if "cuda" in target: continue vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod) @@ -91,6 +91,7 @@ def test_split_no_fuse(): res = veval(f, x_data) tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0]) +@tvm.testing.uses_gpu def test_id(): x = relay.var('x', shape=(10, 10), dtype='float64') f = relay.Function([x], x) @@ -99,6 +100,7 @@ def test_id(): mod["main"] = f check_result([x_data], x_data, mod=mod) +@tvm.testing.uses_gpu def test_op(): x = relay.var('x', shape=(10, 10)) f = relay.Function([x], x + x) @@ -111,6 +113,7 @@ def any(x): x = relay.op.nn.batch_flatten(x) return relay.op.min(x, axis=[0, 1]) +@tvm.testing.uses_gpu def test_cond(): x = relay.var('x', shape=(10, 10)) y = relay.var('y', shape=(10, 10)) @@ -127,6 +130,7 @@ def test_cond(): # diff check_result([x_data, y_data], False, mod=mod) +@tvm.testing.uses_gpu def test_simple_if(): x = relay.var('x', shape=(10, 10)) y = relay.var('y', shape=(10, 10)) @@ -162,6 +166,7 @@ def test_multiple_ifs(): res = vmobj_to_list(vm.evaluate()(False)) assert(res == [1, 0]) +@tvm.testing.uses_gpu def test_simple_call(): mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') @@ -175,6 +180,7 @@ def test_simple_call(): mod["main"] = relay.Function([iarg], sum_up(iarg)) check_result([i_data], i_data, mod=mod) +@tvm.testing.uses_gpu def test_count_loop(): mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') @@ -195,6 +201,7 @@ def test_count_loop(): tvm.testing.assert_allclose(result.asnumpy(), i_data) check_result([i_data], i_data, mod=mod) +@tvm.testing.uses_gpu def test_sum_loop(): mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') @@ -217,6 +224,7 @@ def test_sum_loop(): mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) check_result([i_data, accum_data], sum(range(1, loop_bound + 1)), mod=mod) +@tvm.testing.uses_gpu def test_tuple_fst(): ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) tup = relay.var('tup', type_annotation=ttype) @@ -227,6 +235,7 @@ def test_tuple_fst(): mod["main"] = f check_result([(i_data, j_data)], i_data, mod=mod) +@tvm.testing.uses_gpu def test_tuple_second(): ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) tup = relay.var('tup', type_annotation=ttype) @@ -259,6 +268,7 @@ def test_list_constructor(): obj = vmobj_to_list(result) tvm.testing.assert_allclose(obj, np.array([3,2,1])) +@tvm.testing.uses_gpu def test_let_tensor(): sb = relay.ScopeBuilder() shape = (1,) @@ -277,6 +287,7 @@ def test_let_tensor(): mod["main"] = f check_result([x_data], x_data + 42.0, mod=mod) +@tvm.testing.uses_gpu def test_let_scalar(): sb = relay.ScopeBuilder() @@ -545,6 +556,7 @@ def test_closure(): res = veval(main) tvm.testing.assert_allclose(res.asnumpy(), 3.0) +@tvm.testing.uses_gpu def test_add_op_scalar(): """ test_add_op_scalar: @@ -561,6 +573,7 @@ def test_add_op_scalar(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) +@tvm.testing.uses_gpu def test_add_op_tensor(): """ test_add_op_tensor: @@ -577,6 +590,7 @@ def test_add_op_tensor(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) +@tvm.testing.uses_gpu def test_add_op_broadcast(): """ test_add_op_broadcast: @@ -608,6 +622,7 @@ def test_vm_optimize(): comp = relay.vm.VMCompiler() opt_mod, _ = comp.optimize(mod, target="llvm", params=params) +@tvm.testing.uses_gpu def test_loop_free_var(): x = relay.var('x', shape=(), dtype='int32') i = relay.var('i', shape=(), dtype='int32') @@ -634,6 +649,7 @@ def body_with_free_var(i, acc): mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret) check_result(args, expected, mod=mod) +@tvm.testing.uses_gpu def test_vm_reshape_tensor(): x_np = np.random.uniform(size=(8, 16)).astype("float32") x = relay.var("x", shape=(8, 16), dtype="float32") diff --git a/tests/python/topi/python/common.py b/tests/python/topi/python/common.py index 735072c1ca4d..d63251ee99f1 100644 --- a/tests/python/topi/python/common.py +++ b/tests/python/topi/python/common.py @@ -16,22 +16,8 @@ # under the License. """Common utility for topi test""" -import tvm -from tvm import te from tvm import autotvm from tvm.autotvm.task.space import FallbackConfigEntity -from tvm import topi - -def get_all_backend(): - """return all supported target - - Returns - ------- - targets: list - A list of all supported targets - """ - return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', - 'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu'] class Int8Fallback(autotvm.FallbackContext): def _query_inside(self, target, workload): diff --git a/tests/python/topi/python/test_fifo_buffer.py b/tests/python/topi/python/test_fifo_buffer.py index 9af30f9dc779..8e69a7639358 100644 --- a/tests/python/topi/python/test_fifo_buffer.py +++ b/tests/python/topi/python/test_fifo_buffer.py @@ -19,11 +19,11 @@ import tvm from tvm import te from tvm import topi +import tvm.testing import tvm.topi.testing import numpy as np from tvm.contrib.pickle_memoize import memoize -from common import get_all_backend def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'): buffer = te.placeholder(buffer_shape, name='buffer', dtype=dtype) @@ -46,11 +46,7 @@ def get_ref_data(): # Get the test data buffer_np, data_np, out_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print(' Skip because %s is not enabled' % device) - return + def check_device(device, ctx): print(' Running on target: {}'.format(device)) with tvm.target.create(device): @@ -64,8 +60,8 @@ def check_device(device): f(data_tvm, buffer_tvm, out_tvm) tvm.testing.assert_allclose(out_tvm.asnumpy(), out_np) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_conv1d_integration(): batch_size = 1 @@ -122,11 +118,7 @@ def get_data(): # Get the test data inc_input_np, input_window_np, kernel_np, context_np, output_window_np = get_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print(' Skip because %s is not enabled' % device) - return + def check_device(device, ctx): print(' Running on target: {}'.format(device)) conv2d_nchw, schedule_conv2d_nchw = tvm.topi.testing.get_conv2d_nchw_implement(device) @@ -184,9 +176,10 @@ def check_device(device): tvm.testing.assert_allclose(output_window_tvm.asnumpy(), output_window_ref_tvm.asnumpy()) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_fifo_buffer(): for ndim in [1, 2, 3, 4, 5, 6]: for axis in range(ndim): @@ -196,6 +189,7 @@ def test_fifo_buffer(): .format(buffer_shape, data_shape, axis)) verify_fifo_buffer(buffer_shape, data_shape, axis) +@tvm.testing.uses_gpu def test_conv1d_integration(): print('Testing FIFO buffer with 1D convolution') verify_conv1d_integration() diff --git a/tests/python/topi/python/test_topi_batch_matmul.py b/tests/python/topi/python/test_topi_batch_matmul.py index c8cddb661dca..c785c6d85108 100644 --- a/tests/python/topi/python/test_topi_batch_matmul.py +++ b/tests/python/topi/python/test_topi_batch_matmul.py @@ -23,7 +23,7 @@ from tvm.topi.util import get_const_tuple from tvm.contrib.pickle_memoize import memoize -from common import get_all_backend +import tvm.testing _batch_matmul_implement = { "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul), @@ -46,11 +46,7 @@ def get_ref_data(): # get the test data a_np, b_np, c_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement) @@ -63,9 +59,10 @@ def check_device(device): f(a, b, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_batch_matmul(): verify_batch_matmul(1, 16, 16, 32) verify_batch_matmul(5, 16, 16, 32) diff --git a/tests/python/topi/python/test_topi_broadcast.py b/tests/python/topi/python/test_topi_broadcast.py index 4ac985e057b9..9826b88ee2db 100644 --- a/tests/python/topi/python/test_topi_broadcast.py +++ b/tests/python/topi/python/test_topi_broadcast.py @@ -20,7 +20,6 @@ from tvm import te from tvm import topi import tvm.topi.testing -from common import get_all_backend def verify_broadcast_to_ele(in_shape, out_shape, fbcast): @@ -30,7 +29,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -44,7 +43,7 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for target in get_all_backend(): + for target, ctx in tvm.testing.enabled_targets(): check_device(target) check_device("sdaccel") @@ -78,7 +77,7 @@ def gen_operand(shape, low, high, ctx): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -94,11 +93,12 @@ def check_device(device): foo(lhs_nd, rhs_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) - for target in get_all_backend(): + for target, ctx in tvm.testing.enabled_targets(): check_device(target) check_device("sdaccel") +@tvm.testing.uses_gpu def test_broadcast_to(): verify_broadcast_to_ele((1,), (10,), topi.broadcast_to) verify_broadcast_to_ele((), (10,), topi.broadcast_to) @@ -106,6 +106,7 @@ def test_broadcast_to(): verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to) +@tvm.testing.uses_gpu def test_add(): verify_broadcast_binary_ele( (), (), topi.add, np.add) @@ -113,6 +114,7 @@ def test_add(): (5, 2, 3), (2, 1), topi.add, np.add) +@tvm.testing.uses_gpu def test_subtract(): verify_broadcast_binary_ele( (5, 2, 3), (), topi.subtract, np.subtract) @@ -124,11 +126,13 @@ def test_subtract(): (1, 32), (64, 32), topi.subtract, np.subtract) +@tvm.testing.uses_gpu def test_multiply(): verify_broadcast_binary_ele( (5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply) +@tvm.testing.uses_gpu def test_divide(): verify_broadcast_binary_ele( None, (10,), topi.divide, np.divide, rhs_min=0.0001) @@ -137,6 +141,7 @@ def test_divide(): verify_broadcast_binary_ele( (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001) +@tvm.testing.uses_gpu def test_floor_divide(): def _canonical_floor_div(a,b): return np.floor(a / b) @@ -147,6 +152,7 @@ def _canonical_floor_div(a,b): verify_broadcast_binary_ele( (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) +@tvm.testing.uses_gpu def test_maximum_minmum(): verify_broadcast_binary_ele( (32,), (64, 32), topi.maximum, np.maximum) @@ -154,15 +160,18 @@ def test_maximum_minmum(): (1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum) +@tvm.testing.uses_gpu def test_power(): verify_broadcast_binary_ele( (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2) +@tvm.testing.uses_gpu def test_mod(): verify_broadcast_binary_ele( (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32") +@tvm.testing.uses_gpu def test_floor_mod(): def _canonical_floor_mod(a,b): return a - np.floor(a / b) * b @@ -171,6 +180,7 @@ def _canonical_floor_mod(a,b): verify_broadcast_binary_ele( (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32") +@tvm.testing.uses_gpu def test_cmp(): # explicit specify the output type def greater(x, y): @@ -208,6 +218,7 @@ def less_equal(x, y): lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32') +@tvm.testing.uses_gpu def test_shift(): # explicit specify the output type verify_broadcast_binary_ele( @@ -223,6 +234,7 @@ def test_shift(): dtype="int8", rhs_min=0, rhs_max=32) +@tvm.testing.uses_gpu def test_logical_single_ele(): def test_apply( func, @@ -238,11 +250,7 @@ def test_apply( assert (isinstance(B, tvm.tir.PrimExpr)) return - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(B) @@ -256,13 +264,14 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) test_apply(topi.logical_not, "logical_not", np.logical_not, np.array([True, False, 0, 1])) test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3)) +@tvm.testing.uses_gpu def test_bitwise_not(): def test_apply( func, @@ -279,11 +288,7 @@ def test_apply( assert (isinstance(B, tvm.tir.PrimExpr)) return - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(B) @@ -297,13 +302,14 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, ()) test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, (2, 1, 2)) +@tvm.testing.uses_gpu def test_logical_binary_ele(): def test_apply( func, @@ -321,11 +327,7 @@ def test_apply( assert (isinstance(C, tvm.tir.PrimExpr)) return - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(C) @@ -339,8 +341,8 @@ def check_device(device): foo(lhs_nd, rhs_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) test_apply(topi.logical_and, "logical_and", np.logical_and, True, False) test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False]) @@ -350,6 +352,7 @@ def check_device(device): test_apply(topi.logical_xor, "logical_xor", np.logical_xor, [True, False], [False, False]) +@tvm.testing.uses_gpu def test_bitwise_and(): verify_broadcast_binary_ele( None, None, topi.bitwise_and, np.bitwise_and, @@ -359,6 +362,7 @@ def test_bitwise_and(): dtype="int32") +@tvm.testing.uses_gpu def test_bitwise_or(): verify_broadcast_binary_ele( None, None, topi.bitwise_or, np.bitwise_or, @@ -368,6 +372,7 @@ def test_bitwise_or(): dtype="int32") +@tvm.testing.uses_gpu def test_bitwise_xor(): verify_broadcast_binary_ele( None, None, topi.bitwise_xor, np.bitwise_xor, diff --git a/tests/python/topi/python/test_topi_clip.py b/tests/python/topi/python/test_topi_clip.py index b3d95dd2e07a..70af1f84cae0 100644 --- a/tests/python/topi/python/test_topi_clip.py +++ b/tests/python/topi/python/test_topi_clip.py @@ -19,11 +19,11 @@ import tvm from tvm import te from tvm import topi +import tvm.testing import tvm.topi.testing from tvm.topi.util import get_const_tuple from tvm.contrib.pickle_memoize import memoize -from common import get_all_backend def verify_clip(N, a_min, a_max, dtype): A = te.placeholder((N, N), dtype=dtype, name='A') @@ -38,11 +38,7 @@ def get_ref_data(): return a_np, b_np a_np, b_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -53,9 +49,10 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_clip(): verify_clip(1024, -127, 127, 'float32') verify_clip(1024, -127, 127, 'int16') diff --git a/tests/python/topi/python/test_topi_conv1d.py b/tests/python/topi/python/test_topi_conv1d.py index 49f2cd1125a3..b50aa56c74fe 100644 --- a/tests/python/topi/python/test_topi_conv1d.py +++ b/tests/python/topi/python/test_topi_conv1d.py @@ -23,7 +23,6 @@ import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend _conv1d_ncw_implement = { @@ -74,11 +73,7 @@ def get_ref_data(layout): a_np, w_np, b_np = get_ref_data(layout) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): if layout == "NCW": fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv1d_ncw_implement) else: @@ -95,10 +90,11 @@ def check_device(device): func(a, w, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_conv1d(): for layout in ["NCW", "NWC"]: # Most basic test case diff --git a/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py b/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py index 7efa96d807b6..fc5819be3330 100644 --- a/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py +++ b/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py @@ -23,7 +23,7 @@ import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend +import tvm.testing _conv1d_transpose_ncw_implement = { "generic": (topi.nn.conv1d_transpose_ncw, topi.generic.schedule_conv1d_transpose_ncw), @@ -49,11 +49,8 @@ def get_ref_data(): a_np, w_np, b_np, c_np = get_ref_data() - def check_device(device): + def check_device(device, ctx): ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return with tvm.target.create(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv1d_transpose_ncw_implement) B = fcompute(A, W, stride, padding, A.dtype, output_padding) @@ -72,10 +69,11 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_conv1d_transpose_ncw(): verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0, (0,)) verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2, (0,)) diff --git a/tests/python/topi/python/test_topi_conv2d_NCHWc.py b/tests/python/topi/python/test_topi_conv2d_NCHWc.py index 95d5633bc1f8..604d09d6837a 100644 --- a/tests/python/topi/python/test_topi_conv2d_NCHWc.py +++ b/tests/python/topi/python/test_topi_conv2d_NCHWc.py @@ -21,13 +21,12 @@ from tvm import te from tvm import autotvm from tvm import topi +import tvm.testing import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.util import get_pad_tuple from tvm.topi.util import get_const_tuple -from common import get_all_backend - def _transform_data(data, bn): # NCHW -> NCHW[x]c batch_size, channel, height, width = data.shape @@ -94,7 +93,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) diff --git a/tests/python/topi/python/test_topi_conv2d_hwcn.py b/tests/python/topi/python/test_topi_conv2d_hwcn.py index 20b1b4dfa8e5..04f34b6ea673 100644 --- a/tests/python/topi/python/test_topi_conv2d_hwcn.py +++ b/tests/python/topi/python/test_topi_conv2d_hwcn.py @@ -23,6 +23,7 @@ import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple +import tvm.testing _conv2d_hwcn_implement = { @@ -58,7 +59,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -94,6 +95,7 @@ def check_device(device): check_device(device) +@tvm.testing.requires_gpu def test_conv2d_hwcn(): verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME") verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME") diff --git a/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py index 2c071c9a266b..ea1aee1cae66 100644 --- a/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py +++ b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py @@ -32,7 +32,7 @@ } def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride, - padding, dilation=1, devices='cuda', dtype='int4'): + padding, dilation=1, dtype='int4'): """Test the conv2d with tensorcore for hwnc layout""" pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) padding_sum = pad_top + pad_left + pad_bottom + pad_right @@ -89,7 +89,7 @@ def convert_int32_into_int4(a_int32): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return if not nvcc.have_tensorcore(ctx.compute_version): @@ -112,9 +112,10 @@ def check_device(device): rtol = 1e-3 tvm.testing.assert_allclose(c.asnumpy().transpose((2, 0, 1, 3)), c_np, rtol=rtol) - check_device(devices) + check_device('cuda') +@tvm.testing.requires_tensorcore def test_conv2d_hwnc_tensorcore(): """Test the conv2d with tensorcore for hwnc layout""" verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1, dtype='int8') diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 615dc515b1f4..c18946b2b933 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -28,7 +28,8 @@ from tvm.topi.util import get_const_tuple from tvm.topi.arm_cpu.conv2d_gemm import is_aarch64_arm -from common import get_all_backend, Int8Fallback +from common import Int8Fallback +import tvm.testing def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): @@ -45,7 +46,7 @@ def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, ke device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu" ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Compiling on arm AArch64 target: %s" % device) @@ -128,7 +129,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -223,7 +224,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): @@ -293,7 +294,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): @@ -327,6 +328,7 @@ def check_device(device): check_device(device) +@tvm.testing.requires_cuda def test_conv2d_nchw(): with Int8Fallback(): # ResNet18 workloads where channels in / out are multiple of oc_block_factor diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index dcdf0a776099..a306e3edae11 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -26,7 +26,7 @@ from tvm.topi.nn.util import get_pad_tuple from tvm.topi.util import get_const_tuple -from common import get_all_backend +import tvm.testing def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,\ use_cudnn=False): @@ -63,7 +63,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -97,7 +97,7 @@ def check_device(device): func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4) - for device in get_all_backend(): + for device, ctx in tvm.testing.enabled_targets(): with autotvm.tophub.context(device): # load tophub pre-tuned parameters check_device(device) @@ -105,6 +105,7 @@ def check_device(device): check_device("cuda -model=unknown -libs=cudnn") +@tvm.testing.uses_gpu def test_conv2d_nchw(): # ResNet18 workloads verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py index 7750f235c6c5..29b8634869ff 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py @@ -23,7 +23,7 @@ import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple - +import tvm.testing _conv2d_nhwc_implement = { @@ -56,7 +56,7 @@ def get_ref_data(): a_np, w_np, b_np = get_ref_data() def check_device(device): - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -76,6 +76,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_conv2d_nhwc(): verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME") verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "SAME") diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc_pack_int8.py b/tests/python/topi/python/test_topi_conv2d_nhwc_pack_int8.py index 4439d6ae13eb..019dd30fda2f 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc_pack_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc_pack_int8.py @@ -51,7 +51,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc_tensorcore.py b/tests/python/topi/python/test_topi_conv2d_nhwc_tensorcore.py index 8375df34323c..fb0167a1e045 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc_tensorcore.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc_tensorcore.py @@ -26,6 +26,7 @@ from tvm.contrib import nvcc from tvm.topi.nn.util import get_pad_tuple from tvm.topi.util import get_const_tuple +import tvm.testing _conv2d_nhwc_tensorcore_implement = { @@ -70,7 +71,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return if not nvcc.have_tensorcore(ctx.compute_version): @@ -105,6 +106,8 @@ def check_device(device): check_device(devices) +@tvm.testing.requires_cuda +@tvm.testing.requires_gpu def test_conv2d_nhwc_tensorcore(): """Test the conv2d with tensorcore for nhwc layout""" verify_conv2d_nhwc(16, 16, 14, 16, 3, 1, 1) diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc_winograd.py b/tests/python/topi/python/test_topi_conv2d_nhwc_winograd.py index 00b40bfbe826..cbcc32d0b425 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc_winograd.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc_winograd.py @@ -24,9 +24,9 @@ import tvm.topi.testing from tvm import te from tvm.contrib.pickle_memoize import memoize -from tvm.contrib import nvcc from tvm.topi.nn.util import get_pad_tuple from tvm.topi.util import get_const_tuple +import tvm.testing _conv2d_nhwc_winograd_tensorcore = { @@ -78,9 +78,6 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return print("Running on target: %s" % device) with tvm.target.create(device): if bgemm == "direct": @@ -114,6 +111,8 @@ def check_device(device): check_device(devices) +@tvm.testing.requires_cuda +@tvm.testing.requires_gpu def test_conv2d_nhwc_winograd_direct(): """Test the conv2d with winograd for nhwc layout""" # resnet 18 workloads @@ -135,13 +134,11 @@ def test_conv2d_nhwc_winograd_direct(): verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, "SAME", add_relu=True, add_bias=True) verify_conv2d_nhwc(1, 48, 35, 48, 5, 1, "VALID") + +@tvm.testing.requires_cuda +@tvm.testing.requires_tensorcore def test_conv2d_nhwc_winograd_tensorcore(): """Test the conv2d with winograd for nhwc layout""" - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): - return verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore") verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore") verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore") diff --git a/tests/python/topi/python/test_topi_conv2d_transpose_nchw.py b/tests/python/topi/python/test_topi_conv2d_transpose_nchw.py index 6c43b2d980cf..8c30f441e622 100644 --- a/tests/python/topi/python/test_topi_conv2d_transpose_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_transpose_nchw.py @@ -23,7 +23,7 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend +import tvm.testing _conv2d_transpose_nchw_implement = { @@ -57,11 +57,7 @@ def get_ref_data(): a_np, w_np, b_np, c_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_transpose_nchw_implement) @@ -83,10 +79,11 @@ def check_device(device): func2(a, w, c) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_conv2d_transpose_nchw(): verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0)) verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0)) diff --git a/tests/python/topi/python/test_topi_conv2d_winograd.py b/tests/python/topi/python/test_topi_conv2d_winograd.py index 800aaea5363a..674590a7fa0f 100644 --- a/tests/python/topi/python/test_topi_conv2d_winograd.py +++ b/tests/python/topi/python/test_topi_conv2d_winograd.py @@ -26,6 +26,7 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.util import get_pad_tuple from tvm.topi.util import get_const_tuple +import tvm.testing _conv2d_nchw_winograd_implement = { @@ -70,7 +71,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -102,6 +103,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_conv2d_nchw(): # inception v3 workloads verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=['cuda']) diff --git a/tests/python/topi/python/test_topi_conv3d_ncdhw.py b/tests/python/topi/python/test_topi_conv3d_ncdhw.py index ad2b93ce00ce..319fb723da76 100644 --- a/tests/python/topi/python/test_topi_conv3d_ncdhw.py +++ b/tests/python/topi/python/test_topi_conv3d_ncdhw.py @@ -21,13 +21,12 @@ from tvm import te from tvm import autotvm from tvm import topi +import tvm.testing import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.util import get_pad_tuple3d from tvm.topi.util import get_const_tuple -from common import get_all_backend - _conv3d_ncdhw_implement = { "generic": (topi.nn.conv3d_ncdhw, topi.generic.schedule_conv3d_ncdhw), "cpu": (topi.x86.conv3d_ncdhw, topi.x86.schedule_conv3d_ncdhw), @@ -66,11 +65,7 @@ def get_ref_data(): a_np, w_np, b_np, c_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ncdhw_implement) with tvm.target.create(device): @@ -94,10 +89,11 @@ def check_device(device): func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4) - for device in get_all_backend(): + for device, ctx in tvm.testing.enabled_targets(): with autotvm.tophub.context(device): # load tophub pre-tuned parameters - check_device(device) + check_device(device, ctx) +@tvm.testing.uses_gpu def test_conv3d_ncdhw(): #3DCNN workloads verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0) diff --git a/tests/python/topi/python/test_topi_conv3d_ndhwc.py b/tests/python/topi/python/test_topi_conv3d_ndhwc.py index b80f96bfb26d..7e330e77a365 100644 --- a/tests/python/topi/python/test_topi_conv3d_ndhwc.py +++ b/tests/python/topi/python/test_topi_conv3d_ndhwc.py @@ -24,7 +24,6 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend _conv3d_ndhwc_implement = { "generic": (topi.nn.conv3d_ndhwc, topi.generic.schedule_conv3d_ndhwc), @@ -58,11 +57,7 @@ def get_ref_data(): return a_np, w_np, b_np a_np, w_np, b_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ndhwc_implement) with tvm.target.create(device): @@ -76,10 +71,11 @@ def check_device(device): func(a, w, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_conv3d_ndhwc(): verify_conv3d_ndhwc(1, 16, 32, 16, 3, 1, "SAME") verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "SAME") diff --git a/tests/python/topi/python/test_topi_conv3d_ndhwc_tensorcore.py b/tests/python/topi/python/test_topi_conv3d_ndhwc_tensorcore.py index 2adc34864c13..9f92efa54222 100644 --- a/tests/python/topi/python/test_topi_conv3d_ndhwc_tensorcore.py +++ b/tests/python/topi/python/test_topi_conv3d_ndhwc_tensorcore.py @@ -26,6 +26,7 @@ from tvm.contrib import nvcc from tvm.topi.nn.util import get_pad_tuple3d from tvm.topi.util import get_const_tuple +import tvm.testing _conv3d_ndhwc_tensorcore_implement = { @@ -71,12 +72,6 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - if not nvcc.have_tensorcore(ctx.compute_version): - print("skip because gpu does not support Tensor Cores") - return print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ndhwc_tensorcore_implement) @@ -106,6 +101,8 @@ def check_device(device): check_device(devices) +@tvm.testing.requires_tensorcore +@tvm.testing.requires_cuda def test_conv3d_ndhwc_tensorcore(): """Test the conv3d with tensorcore for ndhwc layout""" verify_conv3d_ndhwc(16, 16, 14, 16, 3, 1, 1) diff --git a/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py b/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py index 8e9812043ce9..25d9b725dedf 100644 --- a/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py +++ b/tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py @@ -19,12 +19,11 @@ import tvm from tvm import te from tvm import topi +import tvm.testing import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend - _conv3d_transpose_ncdhw_implement = { "generic": (topi.nn.conv3d_transpose_ncdhw, topi.generic.schedule_conv3d_transpose_ncdhw), @@ -55,11 +54,7 @@ def get_ref_data(): a_np, w_np, b_np, c_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_transpose_ncdhw_implement) @@ -81,10 +76,11 @@ def check_device(device): func2(a, w, c) tvm.testing.assert_allclose(b.asnumpy(), b_np, atol=1e-4, rtol=1e-4) tvm.testing.assert_allclose(c.asnumpy(), c_np, atol=1e-4, rtol=1e-4) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_conv3d_transpose_ncdhw(): verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1, (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)) verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0)) diff --git a/tests/python/topi/python/test_topi_conv3d_winograd.py b/tests/python/topi/python/test_topi_conv3d_winograd.py index 6e261305b9a4..a6e528c85e33 100644 --- a/tests/python/topi/python/test_topi_conv3d_winograd.py +++ b/tests/python/topi/python/test_topi_conv3d_winograd.py @@ -21,12 +21,12 @@ from tvm import te from tvm import autotvm from tvm import topi +import tvm.testing import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.util import get_pad_tuple3d from tvm.topi.util import get_const_tuple -from common import get_all_backend _conv3d_ncdhw_implement = { "gpu": (topi.cuda.conv3d_ncdhw_winograd, topi.cuda.schedule_conv3d_ncdhw_winograd), @@ -78,7 +78,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -117,6 +117,7 @@ def check_device(device): check_device(device) +@tvm.testing.requires_gpu def test_conv3d_ncdhw(): # Try without depth transformation #3DCNN workloads diff --git a/tests/python/topi/python/test_topi_correlation.py b/tests/python/topi/python/test_topi_correlation.py index f5eb51c8a6af..81063925ebc3 100644 --- a/tests/python/topi/python/test_topi_correlation.py +++ b/tests/python/topi/python/test_topi_correlation.py @@ -24,9 +24,6 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend - - _correlation_implement = { "generic": (topi.nn.correlation_nchw, topi.generic.schedule_correlation_nchw), "cuda": (topi.cuda.correlation_nchw, topi.cuda.schedule_correlation_nchw), @@ -52,11 +49,7 @@ def get_ref_data(): a_np, b_np, c_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) fcompute, fschedule = tvm.topi.testing.dispatch( device, _correlation_implement) @@ -72,10 +65,11 @@ def check_device(device): func(a, b, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_correlation_nchw(): verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=4, stride1=1, stride2=1, pad_size=4, is_multiply=True) diff --git a/tests/python/topi/python/test_topi_deformable_conv2d.py b/tests/python/topi/python/test_topi_deformable_conv2d.py index a2a01fc7ea1f..5d361b4d02b4 100644 --- a/tests/python/topi/python/test_topi_deformable_conv2d.py +++ b/tests/python/topi/python/test_topi_deformable_conv2d.py @@ -23,7 +23,7 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend +import tvm.testing _deformable_conv2d_implement = { @@ -62,7 +62,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -85,6 +85,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_deformable_conv2d_nchw(): verify_deformable_conv2d_nchw(1, 16, 7, 16, 1, 1, 0, deformable_groups=4) verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4) diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 517cb4d3ecc6..e6530e751e1b 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -23,7 +23,8 @@ from tvm.topi.util import get_const_tuple from tvm.contrib.pickle_memoize import memoize -from common import get_all_backend, Int8Fallback +from common import Int8Fallback +import tvm.testing _dense_implement = { "generic": [(topi.nn.dense, topi.generic.schedule_dense)], @@ -57,11 +58,7 @@ def get_ref_data(): # get the test data a_np, b_np, c_np, d_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement): with tvm.target.create(device): @@ -76,8 +73,8 @@ def check_device(device): f(a, b, c, d) tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_dense_int8(batch, in_dim, out_dim, use_bias=True): @@ -104,9 +101,6 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): print("Skip because int8 intrinsics are not available") return @@ -128,6 +122,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_dense(): verify_dense(1, 1024, 1000, use_bias=True) verify_dense(1, 1024, 1000, use_bias=False) @@ -136,6 +131,8 @@ def test_dense(): verify_dense(128, 1024, 1000, use_bias=True) +@tvm.testing.requires_cuda +@tvm.testing.requires_gpu def test_dense_int8(): with Int8Fallback(): verify_dense_int8(2, 1024, 1000, use_bias=True) diff --git a/tests/python/topi/python/test_topi_dense_tensorcore.py b/tests/python/topi/python/test_topi_dense_tensorcore.py index 8a645e6b45ca..642d124a86a5 100644 --- a/tests/python/topi/python/test_topi_dense_tensorcore.py +++ b/tests/python/topi/python/test_topi_dense_tensorcore.py @@ -23,7 +23,7 @@ from tvm.topi.util import get_const_tuple from tvm import te from tvm.contrib.pickle_memoize import memoize -from tvm.contrib import nvcc +import tvm.testing _dense_implement = { @@ -53,12 +53,6 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - if not nvcc.have_tensorcore(ctx.compute_version): - print("skip because gpu does not support Tensor Cores") - return print("Running on target: %s" % device) for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement): with tvm.target.create(device): @@ -74,10 +68,10 @@ def check_device(device): tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-3) - for device in ['cuda']: - check_device(device) + check_device('cuda') +@tvm.testing.requires_tensorcore def test_dense_tensorcore(): """Test cases""" verify_dense(8, 16, 32, use_bias=True) diff --git a/tests/python/topi/python/test_topi_depth_to_space.py b/tests/python/topi/python/test_topi_depth_to_space.py index 380f656bf599..c94981235522 100644 --- a/tests/python/topi/python/test_topi_depth_to_space.py +++ b/tests/python/topi/python/test_topi_depth_to_space.py @@ -19,10 +19,9 @@ import tvm from tvm import te from tvm import topi +import tvm.testing import tvm.topi.testing -from common import get_all_backend - def verify_depth_to_space(block_size, batch, in_channel, in_height, in_width, layout='NCHW', mode='DCR'): out_channel = int(in_channel / (block_size * block_size)) @@ -50,11 +49,7 @@ def verify_depth_to_space(block_size, batch, in_channel, in_height, in_width, la a_np = np.transpose(a_np, axes=[0, 2, 3, 1]) b_np = np.transpose(b_np, axes=[0, 2, 3, 1]) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -64,10 +59,11 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_depth_to_space(): for layout in ['NCHW', 'NHWC']: for mode in ['DCR', 'CDR']: diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 5497e1124e70..bc804cb978f9 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -24,7 +24,7 @@ from tvm.topi.nn.util import get_pad_tuple from tvm.contrib.pickle_memoize import memoize -from common import get_all_backend +import tvm.testing _depthwise_conv2d_nchw_implement = { "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], @@ -67,11 +67,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu dtype = 'float32' - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) impl_list = tvm.topi.testing.dispatch(device, _depthwise_conv2d_nchw_implement)[:] @@ -143,9 +139,9 @@ def get_ref_data(): tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) - for device in get_all_backend(): + for device, ctx in tvm.testing.enabled_targets(): with autotvm.tophub.context(device): # load tophub pre-tuned parameters - check_device(device) + check_device(device, ctx) def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1): @@ -170,11 +166,7 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu dtype = 'float32' - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) fcompute, fschedule = tvm.topi.testing.dispatch(device, _depthwise_conv2d_nhwc_implement) @@ -243,9 +235,9 @@ def get_ref_data(): tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) - for device in get_all_backend(): + for device, ctx in tvm.testing.enabled_targets(): with autotvm.tophub.context(device): # load tophub pre-tuned parameters - check_device(device) + check_device(device, ctx) def _transform_data(data, bn): # NCHW -> NCHW[x]c @@ -298,7 +290,7 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -360,6 +352,7 @@ def get_ref_data(): check_device(device) +@tvm.testing.uses_gpu def test_depthwise_conv2d(): # mobilenet workloads depthwise_conv2d_with_workload_nchw(1, 32, 112, 1, 3, 1, "SAME") diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d_back_input.py b/tests/python/topi/python/test_topi_depthwise_conv2d_back_input.py index ba8bfcc72a4e..25ef6f10815e 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d_back_input.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d_back_input.py @@ -24,6 +24,7 @@ from tvm.topi.nn.util import get_pad_tuple import tvm.topi.testing from tvm.topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc +import tvm.testing def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h): @@ -51,7 +52,7 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -106,6 +107,7 @@ def get_ref_data(): check_device("vulkan") check_device("nvptx") +@tvm.testing.requires_gpu def test_topi_depthwise_conv2d_backward_input_nhwc(): verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1) verify_depthwise_conv2d_back_input(16, 256, 56, 2, 3, 1, 1) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py b/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py index 599225d0a667..5ebc56dbd5c1 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py @@ -24,6 +24,7 @@ from tvm.topi.util import get_const_tuple from tvm.topi.nn.util import get_pad_tuple from tvm.topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc +import tvm.testing def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multiplier, filter_h, stride_h, padding_h): @@ -51,7 +52,7 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -99,6 +100,7 @@ def get_ref_data(): check_device("vulkan") check_device("nvptx") +@tvm.testing.requires_gpu def test_topi_depthwise_conv2d_backward_weight_nhwc(): verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1) verify_depthwise_conv2d_back_weight(16, 256, 56, 2, 3, 1, 1) diff --git a/tests/python/topi/python/test_topi_group_conv2d.py b/tests/python/topi/python/test_topi_group_conv2d.py index 6050d452140c..2eea4b0ca43b 100644 --- a/tests/python/topi/python/test_topi_group_conv2d.py +++ b/tests/python/topi/python/test_topi_group_conv2d.py @@ -26,7 +26,8 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple -from common import get_all_backend, Int8Fallback +from common import Int8Fallback +import tvm.testing _group_conv2d_nchw_implement = { @@ -71,7 +72,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return @@ -148,7 +149,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): @@ -182,6 +183,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_group_conv2d_nchw(): # ResNeXt-50 workload verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32) @@ -207,6 +209,7 @@ def test_group_conv2d_nchw(): +@tvm.testing.requires_cuda def test_group_conv2d_NCHWc_int8(): with Int8Fallback(): # ResNeXt-50 workload diff --git a/tests/python/topi/python/test_topi_group_conv2d_NCHWc_int8.py b/tests/python/topi/python/test_topi_group_conv2d_NCHWc_int8.py index 6afe44e51466..c5eebf411634 100644 --- a/tests/python/topi/python/test_topi_group_conv2d_NCHWc_int8.py +++ b/tests/python/topi/python/test_topi_group_conv2d_NCHWc_int8.py @@ -27,8 +27,6 @@ from tvm.topi.util import get_const_tuple import pytest -from common import get_all_backend - def _transform_data(data, bn): # NCHW -> NCHW[x]c batch_size, channel, height, width = data.shape @@ -77,7 +75,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(ctx): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -105,6 +103,7 @@ def check_device(device): check_device(device) autotvm.GLOBAL_SCOPE.silent = False +@tvm.testing.uses_gpu @pytest.mark.skip def test_conv2d_NCHWc(): # ResNet50 workloads diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 7fce69d2300c..2fafe6c131ea 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -22,7 +22,6 @@ import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize -from common import get_all_backend def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', coord_trans="align_corners", method="bilinear"): @@ -47,11 +46,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, scale_w = out_width / in_width b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -62,10 +57,11 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_resize(): # Scale NCHW verify_resize(4, 16, 32, 32, 50, 50, 'NCHW') @@ -114,11 +110,7 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, scale_w = out_width / in_width b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -129,10 +121,11 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_resize3d(): # Trilinear verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW') @@ -147,6 +140,7 @@ def test_resize3d(): verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor") +@tvm.testing.uses_gpu def test_crop_and_resize(): def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, layout='NHWC', method="bilinear", extrapolation_value=0.0): @@ -174,11 +168,7 @@ def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, baseline_np = tvm.topi.testing.crop_and_resize_python(np_images, np_boxes, np_box_indices, np_crop_size, layout, method, extrapolation_value) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(out) @@ -191,8 +181,8 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out.asnumpy(), baseline_np, rtol=1e-3, atol=1e-3) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) boxes_1 = np.array([[.2, .3, .7, .9]], dtype="float32") boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]], dtype="float32") @@ -209,6 +199,7 @@ def check_device(device): verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW") +@tvm.testing.uses_gpu def test_affine_grid(): def verify_affine_grid(num_batch, target_shape): dtype = "float32" @@ -224,11 +215,7 @@ def get_ref_data(): data_np, out_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(out) @@ -240,13 +227,14 @@ def check_device(device): tvm.testing.assert_allclose( tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) verify_affine_grid(1, (16, 32)) verify_affine_grid(4, (16, 32)) +@tvm.testing.uses_gpu def test_grid_sample(): def verify_grid_sample(data_shape, grid_shape): dtype = "float32" @@ -264,11 +252,7 @@ def get_ref_data(): data_np, grid_np, out_np = get_ref_data() - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(out) @@ -281,8 +265,8 @@ def check_device(device): tvm.testing.assert_allclose( tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) diff --git a/tests/python/topi/python/test_topi_lrn.py b/tests/python/topi/python/test_topi_lrn.py index 2d57d078407c..13dcc715f9f5 100644 --- a/tests/python/topi/python/test_topi_lrn.py +++ b/tests/python/topi/python/test_topi_lrn.py @@ -21,6 +21,7 @@ from tvm import topi from tvm.topi.util import get_const_tuple import tvm.topi.testing +import tvm.testing _lrn_schedule = { "generic": topi.generic.schedule_lrn, @@ -41,7 +42,7 @@ def verify_lrn(shape, size, axis, bias, alpha, beta): b_np = tvm.topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta) def check_device(device): - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -58,6 +59,7 @@ def check_device(device): for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: check_device(device) +@tvm.testing.uses_gpu def test_lrn(): verify_lrn((1, 3, 5, 5), 3, 1, 1.0, 1.0, 0.5) verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5) diff --git a/tests/python/topi/python/test_topi_math.py b/tests/python/topi/python/test_topi_math.py index 8a9754ed6f96..9a7bc6ea004f 100644 --- a/tests/python/topi/python/test_topi_math.py +++ b/tests/python/topi/python/test_topi_math.py @@ -22,7 +22,6 @@ from tvm import topi import tvm.topi.testing from tvm.topi import util -from common import get_all_backend def test_util(): @@ -31,6 +30,7 @@ def test_util(): assert util.get_const_tuple((x, x)) == (100, 100) +@tvm.testing.uses_gpu def test_ewise(): def test_apply( func, @@ -57,11 +57,7 @@ def test_apply( a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4 b_np = f_numpy(a_np) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -71,8 +67,8 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for target in get_all_backend(): - check_device(target) + for target, ctx in tvm.testing.enabled_targets(): + check_device(target, ctx) def test_isnan( low, @@ -97,11 +93,7 @@ def test_isnan( a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5 b_np = np.isnan(a_np) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -111,8 +103,8 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for target in get_all_backend(): - check_device(target) + for target, ctx in tvm.testing.enabled_targets(): + check_device(target, ctx) def test_infiniteness_ops(topi_op, ref_op, name): for dtype in ['float32', 'float64', 'int32', 'int16']: @@ -128,11 +120,7 @@ def test_infiniteness_ops(topi_op, ref_op, name): a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan b_np = ref_op(a_np) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) foo = tvm.build(s, [A, B], device, name=name) @@ -141,8 +129,8 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for target in get_all_backend(): - check_device(target) + for target, ctx in tvm.testing.enabled_targets(): + check_device(target, ctx) test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100) @@ -167,6 +155,7 @@ def check_device(device): test_infiniteness_ops(topi.isinf, np.isinf, 'isinf') +@tvm.testing.uses_gpu def test_cast(): def verify(from_dtype, to_dtype, low=-100, high=100): shape = (5, 4) @@ -181,11 +170,7 @@ def verify(from_dtype, to_dtype, low=-100, high=100): a_np = a_np - a_np[2, 3] b_np = a_np.astype(to_dtype) - for device in get_all_backend(): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - continue + for device, ctx in tvm.testing.enabled_targets(): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -223,7 +208,7 @@ def test_apply( def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return with tvm.target.create(device): diff --git a/tests/python/topi/python/test_topi_pooling.py b/tests/python/topi/python/test_topi_pooling.py index b24dd85927b1..2f3a38c3df5e 100644 --- a/tests/python/topi/python/test_topi_pooling.py +++ b/tests/python/topi/python/test_topi_pooling.py @@ -21,9 +21,10 @@ import tvm from tvm import te from tvm import topi +import tvm.testing import tvm.topi.testing from tvm.topi.util import get_const_tuple -from common import get_all_backend +import tvm.testing _pool_schedule = { "generic": topi.generic.schedule_pool, @@ -91,11 +92,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ b_np[:, :, i, j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) b_np = np.maximum(b_np, 0.0) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _pool_schedule) @@ -107,8 +104,8 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=2e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True, add_relu=False): @@ -147,11 +144,7 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc if add_relu: pool_grad_np = np.maximum(pool_grad_np, 0.) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _pool_grad_schedule) @@ -164,9 +157,10 @@ def check_device(device): f(a, out_grad, pool_grad) tvm.testing.assert_allclose(pool_grad.asnumpy(), pool_grad_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_pool(): """test cases of pool""" verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) @@ -183,6 +177,7 @@ def test_pool(): verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False) verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True) +@tvm.testing.uses_gpu def test_pool_grad(): """test cases of pool_grad""" verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False) @@ -222,11 +217,7 @@ def verify_global_pool(dshape, pool_type, layout='NCHW'): b_np = np.max(a_np, axis=axis, keepdims=True) b_np = np.maximum(b_np, 0.0) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _adaptive_pool_schedule) @@ -240,9 +231,10 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_global_pool(): """test cases of global_pool""" verify_global_pool((1, 1024, 7, 7), 'avg') @@ -268,11 +260,7 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa assert len(out_size) == 3 out = topi.nn.adaptive_pool3d(data, out_size, pool_type, layout) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _adaptive_pool_schedule) @@ -286,10 +274,11 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), np_out, rtol=4e-5, atol=1e-6) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_adaptive_pool(): """test cases of adaptive_pool""" verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max") @@ -329,11 +318,7 @@ def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ref_np = tvm.topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding, output_shape, pool_type, count_include_pad, ceil_mode) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _pool_schedule) @@ -345,10 +330,11 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_pool3d(): """test cases of pool3d""" verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True) @@ -384,11 +370,7 @@ def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type, ref_np = tvm.topi.testing.pool1d_ncw_python(input_np, kernel, stride, padding, output_shape, pool_type, count_include_pad, ceil_mode) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _pool_schedule) @@ -400,10 +382,11 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_pool1d(): """test cases of pool1d""" verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True) diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py index d84182f21ffd..33706e6cdc3c 100644 --- a/tests/python/topi/python/test_topi_reduce.py +++ b/tests/python/topi/python/test_topi_reduce.py @@ -22,7 +22,6 @@ from tvm import topi import tvm.topi.testing -from common import get_all_backend def _my_npy_argmax(arr, axis, keepdims): if not keepdims: @@ -69,11 +68,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32") else: raise NotImplementedError - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_reduce_schedule(device)(B) @@ -122,10 +117,11 @@ def check_device(device): tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3) else: tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_reduce_map(): verify_reduce_map_ele(in_shape=(32,), diff --git a/tests/python/topi/python/test_topi_relu.py b/tests/python/topi/python/test_topi_relu.py index 1114b3fa3c8c..74425386bcbe 100644 --- a/tests/python/topi/python/test_topi_relu.py +++ b/tests/python/topi/python/test_topi_relu.py @@ -24,7 +24,7 @@ from tvm.topi.util import get_const_tuple from tvm.contrib.nvcc import have_fp16 -from common import get_all_backend +import tvm.testing def verify_relu(m, n, dtype="float32"): A = te.placeholder((m, n), name='A', dtype=dtype) @@ -33,11 +33,7 @@ def verify_relu(m, n, dtype="float32"): a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype) b_np = a_np * (a_np > 0) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): if dtype == "float16" and device == "cuda" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because %s does not have fp16 support" % device) return @@ -51,8 +47,8 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_leaky_relu(m, alpha): @@ -92,10 +88,12 @@ def _prelu_numpy(x, W): out_np = _prelu_numpy(x_np, w_np) tvm.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5) +@tvm.testing.uses_gpu def test_relu(): verify_relu(10, 128, "float32") verify_relu(128, 64, "float16") +@tvm.testing.uses_gpu def test_schedule_big_array(): verify_relu(1024 * 100 , 512) diff --git a/tests/python/topi/python/test_topi_reorg.py b/tests/python/topi/python/test_topi_reorg.py index e5a19474029a..2b49461bf8de 100644 --- a/tests/python/topi/python/test_topi_reorg.py +++ b/tests/python/topi/python/test_topi_reorg.py @@ -21,6 +21,7 @@ import tvm from tvm import te import tvm.topi.testing +import tvm.testing _reorg_schedule = { "generic": topi.generic.schedule_reorg, @@ -47,7 +48,7 @@ def get_ref_data_reorg(): def check_device(device): '''Cheching devices is enabled or not''' ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -63,6 +64,7 @@ def check_device(device): for device in ['llvm', 'cuda']: check_device(device) +@tvm.testing.uses_gpu def test_reorg(): verify_reorg(1, 20, 8, 2) diff --git a/tests/python/topi/python/test_topi_softmax.py b/tests/python/topi/python/test_topi_softmax.py index 1ff69be7bc87..46322ba38e50 100644 --- a/tests/python/topi/python/test_topi_softmax.py +++ b/tests/python/topi/python/test_topi_softmax.py @@ -20,11 +20,11 @@ import tvm from tvm import te from tvm import topi +import tvm.testing import tvm.topi.testing import logging from tvm.topi.util import get_const_tuple -from common import get_all_backend _softmax_schedule = { "generic": topi.generic.schedule_softmax, @@ -33,11 +33,7 @@ "hls": topi.hls.schedule_softmax, } -def check_device(A, B, a_np, b_np, device, name): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return +def check_device(A, B, a_np, b_np, device, ctx, name): print("Running on target: %s" % device) with tvm.target.create(device): s_func = tvm.topi.testing.dispatch(device, _softmax_schedule) @@ -59,8 +55,8 @@ def verify_softmax(m, n, dtype="float32"): a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) b_np = tvm.topi.testing.softmax_python(a_np) - for device in get_all_backend(): - check_device(A, B, a_np, b_np, device, "softmax") + for device, ctx in tvm.testing.enabled_targets(): + check_device(A, B, a_np, b_np, device, ctx, "softmax") def verify_softmax_4d(shape, dtype="float32"): A = te.placeholder(shape, dtype=dtype, name='A') @@ -71,9 +67,10 @@ def verify_softmax_4d(shape, dtype="float32"): b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c)) b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2) - for device in get_all_backend(): - check_device(A, B, a_np, b_np, device, "softmax") + for device, ctx in tvm.testing.enabled_targets(): + check_device(A, B, a_np, b_np, device, ctx, "softmax") +@tvm.testing.uses_gpu def test_softmax(): verify_softmax(32, 10) verify_softmax(3, 4) @@ -89,10 +86,11 @@ def verify_log_softmax(m, n, dtype="float32"): a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) b_np = tvm.topi.testing.log_softmax_python(a_np) - for device in get_all_backend(): - check_device(A, B, a_np, b_np, device, "log_softmax") + for device, ctx in tvm.testing.enabled_targets(): + check_device(A, B, a_np, b_np, device, ctx, "log_softmax") +@tvm.testing.uses_gpu def test_log_softmax(): verify_log_softmax(32, 10) verify_log_softmax(3, 4) diff --git a/tests/python/topi/python/test_topi_sort.py b/tests/python/topi/python/test_topi_sort.py index 7abfe586a4e0..603d2ef51851 100644 --- a/tests/python/topi/python/test_topi_sort.py +++ b/tests/python/topi/python/test_topi_sort.py @@ -21,6 +21,7 @@ from tvm import te from tvm import topi import tvm.topi.testing +import tvm.testing _argsort_implement = { "generic": (topi.argsort, topi.generic.schedule_argsort), @@ -52,10 +53,10 @@ def verify_argsort(axis, is_ascend): np_indices = np_indices[:, :dshape[axis]] def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return + ctx = tvm.context(device, 0) print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _argsort_implement) @@ -97,7 +98,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -124,6 +125,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_argsort(): np.random.seed(0) for axis in [0, -1, 1]: @@ -131,6 +133,7 @@ def test_argsort(): verify_argsort(axis, False) +@tvm.testing.uses_gpu def test_topk(): np.random.seed(0) for k in [0, 1, 5]: diff --git a/tests/python/topi/python/test_topi_space_to_depth.py b/tests/python/topi/python/test_topi_space_to_depth.py index f659c33d3739..509678513ddd 100644 --- a/tests/python/topi/python/test_topi_space_to_depth.py +++ b/tests/python/topi/python/test_topi_space_to_depth.py @@ -21,8 +21,6 @@ from tvm import topi import tvm.topi.testing -from common import get_all_backend - def verify_space_to_depth(block_size, batch, in_channel, in_height, in_width, layout='NCHW'): out_channel = int(in_channel * (block_size * block_size)) @@ -50,11 +48,7 @@ def verify_space_to_depth(block_size, batch, in_channel, in_height, in_width, la a_np = np.transpose(a_np, axes=[0, 2, 3, 1]) b_np = np.transpose(b_np, axes=[0, 2, 3, 1]) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -64,10 +58,11 @@ def check_device(device): f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_space_to_depth(): for layout in ['NCHW', 'NHWC']: # Simplest possible case diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index e5fd0e9e6684..f0e701b13047 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -25,6 +25,7 @@ from collections import namedtuple import time import scipy.sparse as sp +import tvm.testing _sparse_dense_implement = { "generic": (topi.nn.sparse_dense, topi.generic.schedule_sparse_dense), @@ -56,7 +57,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -100,7 +101,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -141,7 +142,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -178,7 +179,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -303,7 +304,7 @@ def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -325,11 +326,13 @@ def check_device(device): for device in ['llvm', 'cuda']: check_device(device) +@tvm.testing.uses_gpu def test_sparse_dense_bsr(): M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9 verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True) verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False) +@tvm.testing.uses_gpu def test_sparse_dense_bsr_randomized(): for _ in range(20): BS_R = np.random.randint(1, 16) @@ -351,7 +354,7 @@ def test_sparse_dense_bsr_randomized(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -372,14 +375,11 @@ def check_device(device): check_device(device) -def test_sparse_dense(): - test_sparse_dense_csr() - test_sparse_dense_bsr() - test_sparse_dense_bsr_randomized() - if __name__ == "__main__": test_csrmv() test_csrmm() test_dense() - test_sparse_dense() + test_sparse_dense_csr() + test_sparse_dense_bsr() + test_sparse_dense_bsr_randomized() test_sparse_transpose_csr() diff --git a/tests/python/topi/python/test_topi_tensor.py b/tests/python/topi/python/test_topi_tensor.py index 34442845a869..53e48bf9da18 100644 --- a/tests/python/topi/python/test_topi_tensor.py +++ b/tests/python/topi/python/test_topi_tensor.py @@ -22,6 +22,7 @@ import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.contrib.nvcc import have_fp16 +import tvm.testing def verify_elemwise_sum(num_args, dtype): shape = (3,5,4) @@ -41,7 +42,7 @@ def get_ref_data(): np_nd = get_ref_data() def check_device(device): - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return @@ -70,7 +71,7 @@ def get_ref_data(): np_nd = get_ref_data() def check_device(device): - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return @@ -89,7 +90,7 @@ def check_device(device): def verify_vectorization(n, m, dtype): def check_device(device): - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return if dtype == "float16" and device == "cuda" and not have_fp16(tvm.gpu(0).compute_version): @@ -112,6 +113,8 @@ def check_device(device): for device in ["cuda"]: check_device(device) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorization(): verify_vectorization(128, 64, "float16") diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index d8c51b8108ba..12e50b49a307 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -23,16 +23,12 @@ import tvm.topi.testing from tvm.contrib.nvcc import have_fp16 -from common import get_all_backend +import tvm.testing def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): A = te.placeholder(shape=in_shape, name="A") B = topi.expand_dims(A, axis, num_newaxis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(B) @@ -44,18 +40,14 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_reinterpret(in_shape, in_dtype, out_dtype, generator): A = te.placeholder(shape=in_shape, name="A", dtype=in_dtype) B = topi.reinterpret(A, out_dtype) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): if in_dtype == "float16" and device == 'cuda' and not have_fp16(ctx.compute_version): print("Skip because %s does not have fp16 support" % device) return @@ -70,18 +62,14 @@ def check_device(device): foo(data_nd, out_nd) np.testing.assert_equal(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_transpose(in_shape, axes): A = te.placeholder(shape=in_shape, name="A") B = topi.transpose(A, axes) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -93,18 +81,14 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_reshape(src_shape, dst_shape): A = te.placeholder(shape=src_shape, name="A") B = topi.reshape(A, dst_shape) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -116,18 +100,14 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_squeeze(src_shape, axis): A = te.placeholder(shape=src_shape, name="A") B = topi.squeeze(A, axis=axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -141,8 +121,8 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_concatenate(shapes, axis): @@ -162,11 +142,7 @@ def get_concat_schedule(target): for i, shape in enumerate(shapes): tensor_l.append(te.placeholder(shape, name="A" + str(i))) out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = get_concat_schedule(device)(out_tensor) @@ -179,19 +155,15 @@ def check_device(device): foo(*(data_nds + [out_nd])) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_stack(shapes, axis): tensor_l = [] for i, shape in enumerate(shapes): tensor_l.append(te.placeholder(shape, name="A" + str(i))) out_tensor = topi.stack(tensor_l, axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(out_tensor) @@ -204,18 +176,14 @@ def check_device(device): foo(*(data_nds + [out_nd])) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_split(src_shape, indices_or_sections, axis): A = te.placeholder(shape=src_shape, name="A") tensor_l = topi.split(A, indices_or_sections, axis=axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(tensor_l) @@ -229,8 +197,8 @@ def check_device(device): for out_nd, out_npy in zip(out_nds, out_npys): tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_expand_like(in_shape, out_shape, axis): @@ -240,9 +208,6 @@ def verify_expand_like(in_shape, out_shape, axis): s = te.create_schedule([C.op]) def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return print("Running on target: %s" % device) ctx = tvm.context(device, 0) @@ -272,7 +237,7 @@ def verify_flip(in_shape, axis): B = topi.flip(A, axis) + 1 def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -291,6 +256,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_reverse_sequence(): def verify_reverse_sequence(in_data, seq_lengths, batch_axis, seq_axis, ref_res): seq_lengths = np.array(seq_lengths).astype("int32") @@ -298,11 +264,7 @@ def verify_reverse_sequence(in_data, seq_lengths, batch_axis, seq_axis, ref_res) B = te.placeholder(shape=seq_lengths.shape, name="B", dtype=str(seq_lengths.dtype)) C = topi.reverse_sequence(A, B, seq_axis, batch_axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(C) @@ -315,8 +277,8 @@ def check_device(device): foo(data_nd, seq_lengths_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), ref_res) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32") result = [[0, 5, 10, 15], @@ -382,7 +344,7 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -417,7 +379,7 @@ def verify_strided_slice(in_shape, begin, end, strides=None): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -449,7 +411,7 @@ def verify_strided_set(in_shape, v_shape, begin, end, strides=None): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -490,11 +452,7 @@ def verify_gather(data, axis, indices): var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices") out_tensor = topi.gather(var_data, axis, var_indices) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(out_tensor) @@ -508,8 +466,8 @@ def check_device(device): func(data_nd, indices_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_gather_nd(src_shape, indices_src, indices_dtype): src_dtype = "float32" @@ -518,11 +476,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype): indices = te.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices") out_tensor = topi.gather_nd(a=A, indices=indices) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(out_tensor) @@ -540,8 +494,8 @@ def check_device(device): func(data_nd, indices_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_arange(start, stop, step): if start is None and step is None: @@ -557,11 +511,7 @@ def verify_arange(start, stop, step): A = topi.arange(start, stop, step) a_np = np.arange(start, stop, step) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(A) @@ -570,17 +520,13 @@ def check_device(device): f(a_nd) tvm.testing.assert_allclose(a_nd.asnumpy(), a_np) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_repeat(in_shape, repeats, axis): A = te.placeholder(shape=in_shape, name="A") B = topi.repeat(A, repeats, axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(B) @@ -592,17 +538,13 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_tile(in_shape, reps): A = te.placeholder(shape=in_shape, name="A") B = topi.tile(A, reps) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(B) @@ -614,8 +556,8 @@ def check_device(device): foo(data_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_where(in_shape): Cond = te.placeholder(shape=in_shape, name="cond") @@ -623,11 +565,7 @@ def verify_where(in_shape): A = te.placeholder(shape=in_shape, name="A") B = te.placeholder(shape=in_shape, name="B") C = topi.where(Cond, A, B) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_broadcast_schedule(device)(C) @@ -643,19 +581,15 @@ def check_device(device): f(cond_nd, x_nd, y_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype): indices = te.placeholder(shape=indices_shape, name="indices", dtype="int32") on_value_const = tvm.tir.const(on_value, dtype) off_value_const = tvm.tir.const(off_value, dtype) one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(one_hot_result) @@ -668,8 +602,8 @@ def check_device(device): out_topi = out_nd.asnumpy() tvm.testing.assert_allclose(out_topi, out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_unravel_index(indices, shape, dtype): @@ -684,11 +618,7 @@ def verify_unravel_index(indices, shape, dtype): Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y") Z = topi.unravel_index(X, Y) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(Z) @@ -701,8 +631,8 @@ def check_device(device): foo(datax_nd, datay_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected): sparse_indices_data = np.array(sparse_indices) @@ -720,11 +650,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ args = [A, B, C] D = topi.sparse_to_dense(A, output_shape, B, C) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(D) @@ -743,8 +669,8 @@ def check_device(device): tvm.testing.assert_allclose(out_nd.asnumpy(), np.array(xpected)) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) def verify_matrix_set_diag(input_shape, dtype): diagonal_shape = list(input_shape[:-2]) @@ -752,11 +678,8 @@ def verify_matrix_set_diag(input_shape, dtype): input = te.placeholder(shape=input_shape, name="input", dtype=dtype) diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal) - def check_device(device): + def check_device(device, ctx): ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(matrix_set_diag_result) @@ -771,10 +694,11 @@ def check_device(device): out_topi = out_nd.asnumpy() tvm.testing.assert_allclose(out_topi, out_npy) - for device in get_all_backend(): - check_device(device) + for target, ctx in tvm.testing.enabled_targets(): + check_device(target, ctx) +@tvm.testing.uses_gpu def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) @@ -784,6 +708,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) +@tvm.testing.uses_gpu def test_strided_set(): verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2]) verify_strided_set((3, 4, 3), (3, 1, 2), [0, 0, 0], [4, -5, 4], [1, -1, 2]) @@ -795,11 +720,13 @@ def test_strided_set(): verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1, 0], [4, 4, 3]) verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1], [4, 4, 3]) +@tvm.testing.uses_gpu def test_expand_dims(): verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (1, 3, 10), -3, 1) +@tvm.testing.uses_gpu def test_reinterpret(): verify_reinterpret((1000,), "float32", "int32", lambda shape: np.random.randn(*shape) * 1000) @@ -813,12 +740,14 @@ def test_reinterpret(): lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)) +@tvm.testing.uses_gpu def test_transpose(): verify_transpose((3, 10, 2), (1, 0, 2)) verify_transpose((3, 10, 5), (2, 0, 1)) verify_transpose((3, 10), None) +@tvm.testing.uses_gpu def test_reshape(): verify_reshape((1, 2, 3, 4), (2, 3, 4)) verify_reshape((4, 2, 3, 4), (2, 4, 12)) @@ -827,10 +756,12 @@ def test_reshape(): verify_reshape((4, 0), (2, 0, 2)) +@tvm.testing.uses_gpu def test_where(): verify_where((1, 2, 3, 4)) +@tvm.testing.requires_gpu def test_squeeze(): verify_squeeze((1, 2, 3, 4), 0) verify_squeeze((1, 2, 1, 4), None) @@ -843,7 +774,7 @@ def test_squeeze(): C = te.compute((1,), lambda i: E[(2 * A[0] - 1).astype('int32')]) for device in ['cuda', 'opencl']: ctx = tvm.context(device, 0) - if ctx.exist: + if tvm.testing.device_enabled(device): with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(C) func = tvm.build(s, [A, C]) @@ -853,6 +784,7 @@ def test_squeeze(): assert c.asnumpy()[0] == 2 +@tvm.testing.uses_gpu def test_concatenate(): verify_concatenate([(2,), (2,), (2,)], -1) verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1) @@ -865,6 +797,7 @@ def test_concatenate(): verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1) +@tvm.testing.uses_gpu def test_stack(): verify_stack([(2,), (2,), (2,)], -1) verify_stack([(2,), (2,), (2,)], 1) @@ -873,11 +806,13 @@ def test_stack(): verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1) +@tvm.testing.uses_gpu def test_split(): verify_split((2, 12, 3), 3, 1) verify_split((2, 12, 3), [2, 4], 1) verify_split((10, 12, 24), [5, 7, 9], -1) +@tvm.testing.uses_gpu def test_flip(): verify_flip((3, 4, 3), 1) verify_flip((3, 4, 3), 0) @@ -886,12 +821,14 @@ def test_flip(): verify_flip((3, 4, 3), -3) verify_flip((3, 4, 3), -2) +@tvm.testing.requires_llvm def test_expand_like(): verify_expand_like((3,), (2, 3), [0]) verify_expand_like((2,), (2, 3), [1]) verify_expand_like((3, 4), (3, 5, 4), [1]) verify_expand_like((5, 7), (5, 6, 7, 8), [1, 3]) +@tvm.testing.uses_gpu def test_take(): verify_take((4,), [1]) verify_take((4,), [[0,1,2,3]]) @@ -911,6 +848,7 @@ def test_take(): verify_take((3,4), [0, 2], axis=0, mode="fast") verify_take((3,4), [0, 2], axis=1, mode="fast") +@tvm.testing.uses_gpu def test_gather(): verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]]) verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(1, 7, 5))) @@ -920,6 +858,7 @@ def test_gather(): verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2))) verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10))) +@tvm.testing.uses_gpu def test_gather_nd(): for indices_dtype in ['int32', 'float32']: verify_gather_nd((4,), [[1.8]], indices_dtype) @@ -935,6 +874,7 @@ def test_gather_nd(): verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]], indices_dtype) +@tvm.testing.uses_gpu def test_arange(): verify_arange(None, 20, None) verify_arange(None, 20, 2) @@ -946,18 +886,21 @@ def test_arange(): verify_arange(20, 1, -1) verify_arange(20, 1, -1.5) +@tvm.testing.uses_gpu def test_repeat(): verify_repeat((2,), 1, 0) verify_repeat((3, 2), 2, 0) verify_repeat((3, 2, 4), 3, 1) verify_repeat((1, 3, 2, 4), 4, -1) +@tvm.testing.uses_gpu def test_tile(): verify_tile((3, 2), (2, 3)) verify_tile((3, 2, 5), (2,)) verify_tile((3, ), (2, 3, 3)) verify_tile((4, 0), (5,)) +@tvm.testing.uses_gpu def test_layout_transform(): in_shape = (1, 32, 8, 8) A = te.placeholder(shape=in_shape, dtype="float32", name="A") @@ -968,11 +911,7 @@ def test_layout_transform(): output = np.reshape(output, newshape=(1, 8, 8, 2, 16)) output = np.transpose(output, axes=(0, 3, 1, 2, 4)) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): tvm_input = tvm.nd.array(input, ctx) tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype) print("Running on target: %s" % device) @@ -982,10 +921,11 @@ def check_device(device): f(tvm_input, tvm_output) tvm.testing.assert_allclose(tvm_output.asnumpy(), output) - for backend in get_all_backend(): - check_device(backend) + for backend, ctx in tvm.testing.enabled_targets(): + check_device(backend, ctx) +@tvm.testing.uses_gpu def test_shape(): in_shape = (8, 7, 13) dtype = "int32" @@ -995,11 +935,7 @@ def test_shape(): input = np.random.uniform(size=in_shape).astype(A.dtype) output = np.asarray(in_shape).astype(dtype) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): tvm_input = tvm.nd.array(input, ctx) tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=dtype) print("Running on target: %s" % device) @@ -1009,10 +945,11 @@ def check_device(device): f(tvm_input, tvm_output) tvm.testing.assert_allclose(tvm_output.asnumpy(), output) - for backend in get_all_backend(): - check_device(backend) + for backend, ctx in tvm.testing.enabled_targets(): + check_device(backend, ctx) +@tvm.testing.uses_gpu def test_sequence_mask(): for in_shape in (5, 10), (3, 4, 5, 4): for axis in [0, 1]: @@ -1026,11 +963,7 @@ def test_sequence_mask(): B_data = np.random.randint(1, max_length, (batch_size,)).astype(np.int32) C_gt_data = tvm.topi.testing.sequence_mask(A_data, B_data, mask_value, axis) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): tvm_A = tvm.nd.array(A_data, ctx) tvm_B = tvm.nd.array(B_data, ctx) tvm_C = tvm.nd.empty(in_shape, ctx=ctx, dtype="float32") @@ -1040,9 +973,10 @@ def check_device(device): f = tvm.build(s, [A, B, C], device, name="SequenceMask") f(tvm_A, tvm_B, tvm_C) tvm.testing.assert_allclose(tvm_C.asnumpy(), C_gt_data) - for backend in get_all_backend(): - check_device(backend) + for backend, ctx in tvm.testing.enabled_targets(): + check_device(backend, ctx) +@tvm.testing.uses_gpu def test_ndarray_size(): in_shape = (5, 11, 7) dtype = "int32" @@ -1052,11 +986,7 @@ def test_ndarray_size(): input = np.random.uniform(size=in_shape).astype(A.dtype) output = np.asarray(np.size(input)).astype(dtype) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): tvm_input = tvm.nd.array(input, ctx=ctx) tvm_output = tvm.nd.empty((), ctx=ctx, dtype=B.dtype) print("Running on target: %s" % device) @@ -1066,18 +996,15 @@ def check_device(device): f(tvm_input, tvm_output) tvm.testing.assert_allclose(tvm_output.asnumpy(), output) - for backend in get_all_backend(): - check_device(backend) + for backend, ctx in tvm.testing.enabled_targets(): + check_device(backend, ctx) +@tvm.testing.uses_gpu def test_where_fusion(): """integration test that where and zeros should be properly inlined""" - def check_device(device): + def check_device(device, ctx): with tvm.target.create(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return print("Running on target: %s" % device) conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device) data = te.placeholder((2, 1, 2, 4), 'int8', 'data') @@ -1093,9 +1020,10 @@ def check_device(device): s = conv2d_schedule(outs) tvm.build(s, [data, w, add], target=backend) - for backend in get_all_backend(): - check_device(backend) + for backend, ctx in tvm.testing.enabled_targets(): + check_device(backend, ctx) +@tvm.testing.uses_gpu def test_one_hot(): verify_one_hot((3,), 3, 1, 0, -1, "int32") verify_one_hot((3,), 3, 1.0, 0.0, -1, "float32") @@ -1105,6 +1033,7 @@ def test_one_hot(): verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") +@tvm.testing.uses_gpu def test_unravel_index(): for dtype in ["int32", "int64"]: verify_unravel_index([0, 1, 2, 3], [2, 2], dtype) @@ -1112,6 +1041,7 @@ def test_unravel_index(): verify_unravel_index(144, [5, 5, 5, 2], dtype) verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype) +@tvm.testing.uses_gpu def test_sparse_to_dense(): verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) #scalar verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) #vector @@ -1134,6 +1064,7 @@ def test_sparse_to_dense(): #sparse_indices should not be > 2d tensor #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) +@tvm.testing.uses_gpu def test_matrix_set_diag(): for dtype in ['float32', 'int32']: verify_matrix_set_diag((2, 2), dtype) diff --git a/tests/python/topi/python/test_topi_upsampling.py b/tests/python/topi/python/test_topi_upsampling.py index 04cc31092402..7861a29f7950 100644 --- a/tests/python/topi/python/test_topi_upsampling.py +++ b/tests/python/topi/python/test_topi_upsampling.py @@ -23,8 +23,6 @@ import math from tvm.topi.util import nchw_pack_layout -from common import get_all_backend - def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, layout='NCHW', method="nearest_neighbor", in_batch_block = 0, in_channel_block = 0): @@ -58,11 +56,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, else: b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -73,9 +67,10 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_upsampling(): # nearest_neighbor - NCHW verify_upsampling(8, 16, 32, 32, 2.0, 2.0) @@ -141,11 +136,7 @@ def verify_upsampling3d(batch, in_channel, in_depth, in_height, in_width, scale_ else: b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return + def check_device(device, ctx): print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_injective_schedule(device)(B) @@ -156,9 +147,10 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in get_all_backend(): - check_device(device) + for device, ctx in tvm.testing.enabled_targets(): + check_device(device, ctx) +@tvm.testing.uses_gpu def test_upsampling3d(): # nearest_neighbor - NCDHW verify_upsampling3d(8, 8, 16, 16, 16, 2.0, 2.0, 2.0) diff --git a/tests/python/topi/python/test_topi_util.py b/tests/python/topi/python/test_topi_util.py index 345e7f9baf1a..a6287b137518 100644 --- a/tests/python/topi/python/test_topi_util.py +++ b/tests/python/topi/python/test_topi_util.py @@ -32,4 +32,4 @@ def test_get_shape(): verify_get_shape((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16)) if __name__ == "__main__": - test_get_shape() \ No newline at end of file + test_get_shape() diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index e0e2205ba0bf..691dcdfaf926 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -26,6 +26,8 @@ from tvm.contrib.pickle_memoize import memoize from tvm.topi.util import get_const_tuple from tvm.topi.vision import ssd, non_max_suppression, get_valid_counts +import pytest +import tvm.testing _get_valid_counts_implement = { "generic": (topi.vision.get_valid_counts, topi.generic.schedule_get_valid_counts), @@ -88,7 +90,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -114,16 +116,13 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) - """ Skip this test as it is intermittent - see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094 for device in ['llvm', 'cuda', 'opencl']: - # Disable gpu test for now - if device != "llvm": - continue check_device(device) - """ +@tvm.testing.uses_gpu +@pytest.mark.skip("Skip this test as it is intermittent." + "See https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094") def test_get_valid_counts(): verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0) verify_get_valid_counts((1, 2500, 6), 0, 0, 1) @@ -143,7 +142,7 @@ def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, n def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -179,7 +178,7 @@ def check_device(device): for device in ['llvm', 'cuda', 'opencl']: check_device(device) - +@tvm.testing.uses_gpu def test_non_max_suppression(): np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], @@ -247,7 +246,7 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -267,12 +266,14 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_multibox_prior(): verify_multibox_prior((1, 3, 50, 50)) verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5)) verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True) +@tvm.testing.uses_gpu def test_multibox_detection(): batch_size = 1 num_anchors = 3 @@ -292,7 +293,7 @@ def test_multibox_detection(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -336,7 +337,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -359,6 +360,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_roi_align(): verify_roi_align(1, 16, 32, 64, 7, 1.0, -1) verify_roi_align(4, 16, 32, 64, 7, 0.5, 2) @@ -387,7 +389,7 @@ def get_ref_data(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -409,6 +411,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_roi_pool(): verify_roi_pool(1, 4, 16, 32, 7, 1.0) verify_roi_pool(4, 4, 16, 32, 7, 0.5) @@ -421,7 +424,7 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -441,6 +444,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_proposal(): attrs = {'scales': (0.5,),'ratios': (0.5,), 'feature_stride': 16, diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 9282667c025a..c94515a2480b 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -21,6 +21,7 @@ from tvm import topi from tvm import te, auto_scheduler import tempfile +import tvm.testing from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul @@ -46,7 +47,7 @@ def record_common(dag, s): def test_record_split_reorder_fuse_annotation(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return A = te.placeholder((512, 512), name='A') @@ -80,7 +81,7 @@ def test_record_split_reorder_fuse_annotation(): def test_record_compute_at_root_inline_cache_read_write(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return A = te.placeholder((512, 512), name='A') @@ -108,7 +109,7 @@ def test_record_compute_at_root_inline_cache_read_write(): def test_record_follow_split_follow_fused_split(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return A = te.placeholder((512, 512), name='A') @@ -142,7 +143,7 @@ def test_record_follow_split_follow_fused_split(): def test_record_pragma_storage_align_rfactor(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return A = te.placeholder((512, 512), name='A') @@ -165,7 +166,7 @@ def test_record_pragma_storage_align_rfactor(): def test_measure_local_builder_runner(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return dag, s0 = get_tiled_matmul() @@ -183,7 +184,7 @@ def test_measure_local_builder_runner(): def test_measure_local_builder_rpc_runner(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return dag, s0 = get_tiled_matmul() diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 21ac9844b54c..bf7cefabc614 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -54,6 +54,9 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm", tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]) sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options) + print("*"*80) + print(target) + print("*"*80) inp, res = auto_scheduler.load_best(log_file, workload_key, target) print("==== Python Code ====") @@ -78,9 +81,8 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm", print() +@tvm.testing.requires_llvm def test_workload_registry_search_basic(): - if not tvm.runtime.enabled("llvm"): - return # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool t = PropagatingThread(target=search_common, kwargs={'seed': 944563397}) @@ -96,9 +98,8 @@ def test_workload_registry_search_basic(): t.join() +@tvm.testing.requires_llvm def test_sketch_search_policy_basic(): - if not tvm.runtime.enabled("llvm"): - return # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool t = PropagatingThread(target=search_common, @@ -107,9 +108,8 @@ def test_sketch_search_policy_basic(): t.join() +@tvm.testing.requires_llvm def test_sketch_search_policy_xgbmodel(): - if not tvm.runtime.enabled("llvm"): - return # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool t = PropagatingThread(target=search_common, @@ -119,9 +119,8 @@ def test_sketch_search_policy_xgbmodel(): t.join() +@tvm.testing.requires_cuda def test_sketch_search_policy_cuda_rpc_runner(): - if not tvm.runtime.enabled("cuda"): - return measure_ctx = auto_scheduler.LocalRPCMeasureContext() # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index f5188669def8..c35a3f75b7e2 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -18,6 +18,7 @@ """ Test sketch generation. """ import tvm +import tvm.testing from tvm import te, auto_scheduler from tvm.auto_scheduler import _ffi_api from tvm.auto_scheduler.loop_state import Stage @@ -233,10 +234,8 @@ def test_cpu_conv2d_winograd_sketch(): assert sketches[1] != sketches[2] +@tvm.testing.requires_cuda def test_cuda_matmul_sketch(): - if not tvm.context("cuda", 0).exist: - return - sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'cuda') ''' 1 multi-level tiling sketch ''' assert len(sketches) == 1 @@ -265,10 +264,8 @@ def test_cuda_matmul_sketch(): assert_is_tiled(sketches[1].stages[5]) +@tvm.testing.requires_cuda def test_cuda_conv2d_bn_relu_sketch(): - if not tvm.context("cuda", 0).exist: - return - sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test, (1, 56, 56, 512, 512, 3, 1, 1), 'cuda') ''' 1 multi-level tiling sketch ''' @@ -286,20 +283,16 @@ def test_cuda_conv2d_bn_relu_sketch(): assert_is_tiled(sketches[0].stages[12]) +@tvm.testing.requires_cuda def test_cuda_max_pool2d_sketch(): - if not tvm.context("cuda", 0).exist: - return - sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 0), 'cuda') ''' 1 default sketch ''' assert len(sketches) == 1 assert len(sketches[0].transform_steps) == 0 +@tvm.testing.requires_cuda def test_cuda_min_sketch(): - if not tvm.context("cuda", 0).exist: - return - sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'cuda') ''' 1 cross thread reuction sketch + 1 default sketch ''' assert len(sketches) == 2 @@ -309,10 +302,8 @@ def test_cuda_min_sketch(): assert len(sketches[1].transform_steps) == 0 +@tvm.testing.requires_cuda def test_cuda_softmax_sketch(): - if not tvm.context("cuda", 0).exist: - return - sketches = generate_sketches(softmax_nm_auto_scheduler_test, (2, 1024), 'cuda') ''' (1 cross thread reuction sketch + 1 default sketch) * (1 cross thread reuction sketch + 1 default sketch) ''' assert len(sketches) == (2 * 2) @@ -346,10 +337,8 @@ def test_cuda_softmax_sketch(): assert_compute_at_condition(sketches[3].stages[2], "inlined") +@tvm.testing.requires_cuda def test_cuda_conv2d_winograd_sketch(): - if not tvm.context("cuda", 0).exist: - return - sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test, (1, 28, 28, 128, 128, 3, 1, 1), 'cuda') ''' 1 multi-level tiling sketch ''' diff --git a/tests/python/unittest/test_autotvm_index_tuner.py b/tests/python/unittest/test_autotvm_index_tuner.py index c7fa2ea364b5..2875fd78ba3c 100644 --- a/tests/python/unittest/test_autotvm_index_tuner.py +++ b/tests/python/unittest/test_autotvm_index_tuner.py @@ -65,4 +65,4 @@ def test_random_tuner(): if __name__ == '__main__': test_gridsearch_tuner() - test_random_tuner() \ No newline at end of file + test_random_tuner() diff --git a/tests/python/unittest/test_hybrid_error_report.py b/tests/python/unittest/test_hybrid_error_report.py index dd5d70840943..0dfdbbd0eec0 100644 --- a/tests/python/unittest/test_hybrid_error_report.py +++ b/tests/python/unittest/test_hybrid_error_report.py @@ -102,4 +102,4 @@ def wrap_error(module, lineno): wrap_error(Module4, 60) wrap_error(Module5, 70) wrap_error(Module6, 77) - wrap_error(Module7, 84) \ No newline at end of file + wrap_error(Module7, 84) diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py index ee2cd718e45f..d718f20b5201 100644 --- a/tests/python/unittest/test_runtime_graph.py +++ b/tests/python/unittest/test_runtime_graph.py @@ -21,6 +21,7 @@ from tvm import rpc from tvm.contrib import util, graph_runtime +@tvm.testing.requires_llvm def test_graph_simple(): n = 4 A = te.placeholder((n,), name='A') @@ -52,9 +53,6 @@ def test_graph_simple(): graph = json.dumps(graph) def check_verify(): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return mlib = tvm.build(s, [A, B], "llvm", name="myadd") mod = graph_runtime.create(graph, mlib, tvm.cpu(0)) a = np.random.uniform(size=(n,)).astype(A.dtype) @@ -63,9 +61,6 @@ def check_verify(): np.testing.assert_equal(out.asnumpy(), a + 1) def check_remote(): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return mlib = tvm.build(s, [A, B], "llvm", name="myadd") server = rpc.Server("localhost") remote = rpc.connect(server.host, server.port) @@ -93,9 +88,6 @@ def check_sharing(): params = {'x': x_in} graph, lib, params = relay.build(func, target="llvm", params=params) - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0)) mod_shared.load_params(relay.save_param_dict(params)) num_mods = 10 diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index ce47b16fc4d5..f284ba6a9eb9 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -23,6 +23,7 @@ from tvm.contrib import util from tvm.contrib.debugger import debug_runtime as graph_runtime +@tvm.testing.requires_llvm def test_graph_simple(): n = 4 A = te.placeholder((n,), name='A') @@ -54,9 +55,6 @@ def test_graph_simple(): graph = json.dumps(graph) def check_verify(): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return mlib = tvm.build(s, [A, B], "llvm", name="myadd") try: mod = graph_runtime.create(graph, mlib, tvm.cpu(0)) @@ -115,9 +113,6 @@ def check_verify(): assert(not os.path.exists(directory)) def check_remote(): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return mlib = tvm.build(s, [A, B], "llvm", name="myadd") server = rpc.Server("localhost") remote = rpc.connect(server.host, server.port) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 56ae25092510..512fefdbb4b7 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -20,6 +20,7 @@ import tvm from tvm.contrib import graph_runtime from tvm.contrib.debugger import debug_runtime +import tvm.testing def input_shape(mod): return [int(x) for x in mod["main"].checked_type.arg_types[0].shape] @@ -42,7 +43,7 @@ def verify(data): return out def test_legacy_compatibility(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -58,7 +59,7 @@ def test_legacy_compatibility(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def test_cpu(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -83,10 +84,9 @@ def test_cpu(): out = gmod.get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +@tvm.testing.requires_cuda +@tvm.testing.requires_gpu def test_gpu(): - if not tvm.runtime.enabled("cuda"): - print("Skip because cuda is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) @@ -110,9 +110,10 @@ def test_gpu(): out = gmod.get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +@tvm.testing.uses_gpu def test_mod_export(): def verify_cpu_export(obj_format): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -150,7 +151,7 @@ def verify_cpu_export(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_gpu_export(obj_format): - if not tvm.runtime.enabled("cuda"): + if not tvm.testing.device_enabled("cuda"): print("Skip because cuda is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -188,7 +189,7 @@ def verify_gpu_export(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_rpc_cpu_export(obj_format): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -230,7 +231,7 @@ def verify_rpc_cpu_export(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_rpc_gpu_export(obj_format): - if not tvm.runtime.enabled("cuda"): + if not tvm.testing.device_enabled("cuda"): print("Skip because cuda is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -278,9 +279,10 @@ def verify_rpc_gpu_export(obj_format): verify_rpc_cpu_export(obj_format) verify_rpc_gpu_export(obj_format) +@tvm.testing.uses_gpu def test_remove_package_params(): def verify_cpu_remove_package_params(obj_format): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -326,7 +328,7 @@ def verify_cpu_remove_package_params(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_gpu_remove_package_params(obj_format): - if not tvm.runtime.enabled("cuda"): + if not tvm.testing.device_enabled("cuda"): print("Skip because cuda is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -372,7 +374,7 @@ def verify_gpu_remove_package_params(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_rpc_cpu_remove_package_params(obj_format): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -423,7 +425,7 @@ def verify_rpc_cpu_remove_package_params(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_rpc_gpu_remove_package_params(obj_format): - if not tvm.runtime.enabled("cuda"): + if not tvm.testing.device_enabled("cuda"): print("Skip because cuda is not enabled") return mod, params = relay.testing.synthetic.get_workload() @@ -480,7 +482,7 @@ def verify_rpc_gpu_remove_package_params(obj_format): verify_rpc_gpu_remove_package_params(obj_format) def test_debug_graph_runtime(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled") return mod, params = relay.testing.synthetic.get_workload() diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 9a859da39ae2..bc5e7fba6c22 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -19,6 +19,8 @@ import tvm from tvm import te +import tvm.testing + from tvm.contrib import util header_file_dir_path = util.tempdir() @@ -59,10 +61,11 @@ def generate_engine_module(): return csource_module +@tvm.testing.uses_gpu def test_mod_export(): def verify_gpu_mod_export(obj_format): for device in ["llvm", "cuda"]: - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled..." % device) return @@ -89,7 +92,7 @@ def verify_gpu_mod_export(obj_format): def verify_multi_dso_mod_export(obj_format): for device in ["llvm"]: - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled..." % device) return @@ -117,7 +120,7 @@ def verify_multi_dso_mod_export(obj_format): def verify_json_import_dso(obj_format): for device in ["llvm"]: - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled..." % device) return @@ -173,7 +176,7 @@ def verify_multi_c_mod_export(): print("Skip test because gcc is not available.") for device in ["llvm"]: - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled..." % device) return diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index c7a5544f4a30..6e7df062ba0e 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -22,6 +22,7 @@ import sys import numpy as np import subprocess +import tvm.testing runtime_py = """ import os @@ -42,7 +43,7 @@ """ def test_dso_module_load(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return dtype = 'int64' temp = util.tempdir() @@ -90,6 +91,7 @@ def save_object(names): shell=True) +@tvm.testing.requires_gpu def test_device_module_dump(): # graph n = tvm.runtime.convert(1024) @@ -104,7 +106,7 @@ def test_device_module_dump(): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return temp = util.tempdir() @@ -132,7 +134,7 @@ def check_device(device): def check_stackvm(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return temp = util.tempdir() @@ -161,7 +163,7 @@ def test_combine_module_llvm(): def check_llvm(): ctx = tvm.cpu(0) - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled" ) return temp = util.tempdir() @@ -186,7 +188,7 @@ def check_llvm(): def check_system_lib(): ctx = tvm.cpu(0) - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): print("Skip because llvm is not enabled" ) return temp = util.tempdir() diff --git a/tests/python/unittest/test_runtime_ndarray.py b/tests/python/unittest/test_runtime_ndarray.py index 36312959da3d..bda987db35b5 100644 --- a/tests/python/unittest/test_runtime_ndarray.py +++ b/tests/python/unittest/test_runtime_ndarray.py @@ -17,26 +17,12 @@ import tvm from tvm import te import numpy as np - -def enabled_ctx_list(): - ctx_list = [('cpu', tvm.cpu(0)), - ('gpu', tvm.gpu(0)), - ('cl', tvm.opencl(0)), - ('metal', tvm.metal(0)), - ('rocm', tvm.rocm(0)), - ('vulkan', tvm.vulkan(0)), - ('vpi', tvm.vpi(0))] - for k, v in ctx_list: - assert tvm.context(k, 0) == v - ctx_list = [x[1] for x in ctx_list if x[1].exist] - return ctx_list - -ENABLED_CTX_LIST = enabled_ctx_list() -print("Testing using contexts:", ENABLED_CTX_LIST) +import tvm.testing +@tvm.testing.uses_gpu def test_nd_create(): - for ctx in ENABLED_CTX_LIST: + for target, ctx in tvm.testing.enabled_targets(): for dtype in ["uint8", "int8", "uint16", "int16", "uint32", "int32", "float32"]: x = np.random.randint(0, 10, size=(3, 4)) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 7f01f880cd3d..50c753fafbfa 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -179,6 +179,7 @@ def test_rpc_file_exchange(): rev = remote.download("dat.bin") assert(rev == blob) +@tvm.testing.requires_llvm def test_rpc_remote_module(): if not tvm.runtime.enabled("rpc"): return @@ -197,9 +198,6 @@ def test_rpc_remote_module(): "rpc.Connect", server1.host, server1.port, "x1"]) def check_remote(remote): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return temp = util.tempdir() ctx = remote.cpu(0) f = tvm.build(s, [A, B], "llvm", name="myadd") @@ -215,9 +213,6 @@ def check_remote(remote): np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) def check_minrpc(): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return if tvm.get_global_func("rpc.PopenSession", allow_missing=True) is None: return # export to minrpc @@ -254,10 +249,7 @@ def check_remote_link_cl(remote): runtime initializes. We leave it as an example on how to do rpc when we want to do linking on remote. """ - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled") - return - if not tvm.runtime.enabled("opencl"): + if not tvm.testing.device_enabled("opencl"): print("Skip because opencl is not enabled") return temp = util.tempdir() diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py index 0059083ebdcc..758643d53320 100644 --- a/tests/python/unittest/test_target_codegen_blob.py +++ b/tests/python/unittest/test_target_codegen_blob.py @@ -22,10 +22,12 @@ import tvm from tvm import te import ctypes +import tvm.testing +@tvm.testing.uses_gpu def test_synthetic(): for device in ["llvm", "cuda"]: - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled..." % device) return @@ -70,10 +72,11 @@ def verify(data): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +@tvm.testing.uses_gpu def test_cuda_lib(): ctx = tvm.gpu(0) for device in ["llvm", "cuda"]: - if not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because %s is not enabled..." % device) return nn = 12 @@ -99,4 +102,4 @@ def test_cuda_lib(): if __name__ == "__main__": test_synthetic() - #test_system_lib() + test_cuda_lib() diff --git a/tests/python/unittest/test_target_codegen_bool.py b/tests/python/unittest/test_target_codegen_bool.py index cdb343f3530b..f8d6e329be38 100644 --- a/tests/python/unittest/test_target_codegen_bool.py +++ b/tests/python/unittest/test_target_codegen_bool.py @@ -19,7 +19,9 @@ import tvm from tvm import te import numpy as np +import tvm.testing +@tvm.testing.uses_gpu def test_cmp_load_store(): n = 32 A = te.placeholder((n,), name='A') @@ -30,7 +32,7 @@ def test_cmp_load_store(): def check_llvm(): - if not tvm.runtime.enabled("llvm"): + if not tvm.testing.device_enabled("llvm"): return s = te.create_schedule(D.op) xo, xi = s[C].split(C.op.axis[0], factor=4) @@ -48,9 +50,9 @@ def check_llvm(): d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32')) def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): return + ctx = tvm.context(device, 0) s = te.create_schedule(D.op) for stage in [C, D]: xo, xi = s[stage].split(stage.op.axis[0], factor=4) diff --git a/tests/python/unittest/test_target_codegen_cross_llvm.py b/tests/python/unittest/test_target_codegen_cross_llvm.py index 3ea413c51d84..64a10d80390b 100644 --- a/tests/python/unittest/test_target_codegen_cross_llvm.py +++ b/tests/python/unittest/test_target_codegen_cross_llvm.py @@ -23,6 +23,7 @@ from tvm.contrib import util, cc import numpy as np +@tvm.testing.requires_llvm def test_llvm_add_pipeline(): nn = 1024 n = tvm.runtime.convert(nn) @@ -43,9 +44,6 @@ def verify_elf(path, e_machine): assert struct.unpack(endian + 'h', arr[0x12:0x14])[0] == e_machine def build_i386(): - if not tvm.runtime.enabled("llvm"): - print("Skip because llvm is not enabled..") - return temp = util.tempdir() target = "llvm -mtriple=i386-pc-linux-gnu" f = tvm.build(s, [A, B, C], target) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 7fdd2592ee5f..567f5eace186 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -22,16 +22,16 @@ import unittest from tvm.contrib.nvcc import have_fp16, have_int8 from tvm.contrib import nvcc +import tvm.testing tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_vectorize_add(): num_thread = 8 def check_cuda(dtype, n, lanes): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -66,12 +66,11 @@ def check_cuda(dtype, n, lanes): check_cuda("float16", 64, 6) check_cuda("float16", 64, 8) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_multiply_add(): num_thread = 8 def check_cuda(dtype, n, lanes): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version): print("skip because gpu does not support int8") return @@ -98,12 +97,11 @@ def check_cuda(dtype, n, lanes): tvm.testing.assert_allclose(d.asnumpy(), np_d) check_cuda("int8", 64, 4) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_vectorize_load(): num_thread = 8 def check_cuda(dtype, n, lanes): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return ctx = tvm.gpu(0) A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes)) B = te.compute((n,), lambda i: A[i], name='B') @@ -123,11 +121,10 @@ def check_cuda(dtype, n, lanes): check_cuda("int8", 64, 8) check_cuda("int8", 64, 16) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_make_int8(): def check_cuda(n, value, lanes): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return dtype = 'int8' ctx = tvm.gpu(0) A = te.compute((n, lanes), lambda i,j: tvm.tir.const(value, dtype=dtype)) @@ -151,6 +148,8 @@ def check_cuda(n, value, lanes): check_cuda(64, -3, 2) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_inf_nan(): target = 'cuda' def check_inf_nan(ctx, n, value, dtype): @@ -165,10 +164,6 @@ def check_inf_nan(ctx, n, value, dtype): # Only need to test compiling here fun(a, c) - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - ctx = tvm.context(target, 0) check_inf_nan(ctx, 1, -float('inf'), 'float32') @@ -179,11 +174,9 @@ def check_inf_nan(ctx, n, value, dtype): check_inf_nan(ctx, 1, float('nan'), 'float64') +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_shuffle(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - idxm = tvm.tir.indexmod a = te.placeholder((64, ), 'int32') b = te.placeholder((64, ), 'int32') @@ -227,99 +220,82 @@ def _transform(f, *_): module(nda, ndb, ndc) tvm.testing.assert_allclose(ndc.asnumpy(), ref) -def test_crossthread_reduction1(): - def check(device): - ctx = tvm.context(device, 0) - if not ctx.exist or not tvm.runtime.enabled(device): - print("skip because", device, "is not enabled..") - return - n = te.var("n") - m = te.var("m") - A = te.placeholder((n, m), name='A') - k = te.reduce_axis((0, m), "m") - B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") - - def sched(nthd): - s = te.create_schedule(B.op) - ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd) - s[B].bind(ko, te.thread_axis("threadIdx.x")) - s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) - func = tvm.build(s, [A, B], device) - return func - - def verify(nthd): - func = sched(nthd) - nn = 3 - # checks three typical cases - vals = [nthd-1, nthd, nthd+1] - for kk in [x for x in vals]: - size = (nn, kk) - a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) - func(a, b) - tvm.testing.assert_allclose(b.asnumpy(), \ - np.sum(a.asnumpy(), axis=1), rtol=1e-3) - - verify(16) - verify(32) - verify(64) - - check("cuda") - check("rocm") - - -def test_crossthread_reduction2(): - def check(device): - ctx = tvm.context(device, 0) - if not ctx.exist or not tvm.runtime.enabled(device): - print("skip because", device, "is not enabled..") - return - - n = te.var("n") - k0 = te.var("k0") - k1 = te.var("k1") - A = te.placeholder((n, k0, k1), name='A') - k0 = te.reduce_axis((0, k0), "k0") - k1 = te.reduce_axis((0, k1), "k1") - B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B") - - def sched(nthdx, nthdy): - s = te.create_schedule(B.op) - k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx) - k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy) - s[B].bind(k0o, te.thread_axis("threadIdx.x")) - s[B].bind(k1o, te.thread_axis("threadIdx.y")) - s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) - func = tvm.build(s, [A, B], device) - return func - - def verify(nthdx, nthdy): - func = sched(nthdx, nthdy) - nn = 3 - # checks three typical cases - vx = [nthdx-1, nthdx, nthdx+1] - vy = [nthdy-1, nthdy, nthdy+1] - for kk0, kk1 in [(x, y) for x in vx for y in vy]: - size = (nn, kk0, kk1) - a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) - func(a, b) - tvm.testing.assert_allclose(b.asnumpy(), \ - np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3) - - verify(16, 16) - verify(32, 32) - verify(16, 32) - verify(32, 16) - - check("cuda") - check("rocm") +@tvm.testing.parametrize_targets("cuda", "rocm") +def test_crossthread_reduction1(target, ctx): + n = te.var("n") + m = te.var("m") + A = te.placeholder((n, m), name='A') + k = te.reduce_axis((0, m), "m") + B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") + def sched(nthd): + s = te.create_schedule(B.op) + ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd) + s[B].bind(ko, te.thread_axis("threadIdx.x")) + s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) + func = tvm.build(s, [A, B], target) + return func + + def verify(nthd): + func = sched(nthd) + nn = 3 + # checks three typical cases + vals = [nthd-1, nthd, nthd+1] + for kk in [x for x in vals]: + size = (nn, kk) + a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), \ + np.sum(a.asnumpy(), axis=1), rtol=1e-3) + + verify(16) + verify(32) + verify(64) + + +@tvm.testing.parametrize_targets("cuda", "rocm") +def test_crossthread_reduction2(target, ctx): + n = te.var("n") + k0 = te.var("k0") + k1 = te.var("k1") + A = te.placeholder((n, k0, k1), name='A') + k0 = te.reduce_axis((0, k0), "k0") + k1 = te.reduce_axis((0, k1), "k1") + B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B") + + def sched(nthdx, nthdy): + s = te.create_schedule(B.op) + k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx) + k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy) + s[B].bind(k0o, te.thread_axis("threadIdx.x")) + s[B].bind(k1o, te.thread_axis("threadIdx.y")) + s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) + func = tvm.build(s, [A, B], target) + return func + + def verify(nthdx, nthdy): + func = sched(nthdx, nthdy) + nn = 3 + # checks three typical cases + vx = [nthdx-1, nthdx, nthdx+1] + vy = [nthdy-1, nthdy, nthdy+1] + for kk0, kk1 in [(x, y) for x in vx for y in vy]: + size = (nn, kk0, kk1) + a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), \ + np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3) + + verify(16, 16) + verify(32, 32) + verify(16, 32) + verify(32, 16) + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_reduction_binding(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - k = te.reduce_axis((0, 32), 'k') A = te.placeholder((96, 32), name='A') B = te.compute( (96,), lambda m: @@ -334,46 +310,39 @@ def test_cuda_reduction_binding(): fcuda = tvm.build(s, [A, B], "cuda") -def test_rfactor_predicates(): - def check(device): - ctx = tvm.context(device, 0) - if not ctx.exist or not tvm.runtime.enabled(device): - print("skip because", device, "is not enabled..") - return - - n = te.reduce_axis((0, 129), 'n') - A = te.placeholder((129,), name='A') - B = te.compute( (1, ), lambda b: - te.sum(A[n], - axis=n), - name='B' - ) +@tvm.testing.parametrize_targets("cuda", "rocm") +def test_rfactor_predicates(target, ctx): + n = te.reduce_axis((0, 129), 'n') + A = te.placeholder((129,), name='A') + B = te.compute( (1, ), lambda b: + te.sum(A[n], + axis=n), + name='B' + ) - s = te.create_schedule(B.op) - - _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8) + s = te.create_schedule(B.op) - BF = s.rfactor(B, ni, 0) - s[B].set_store_predicate(tx.var.equal(0)) + _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8) - s[B].bind(s[B].op.reduce_axis[0], tx) - s[B].bind(s[B].op.axis[0], bx) + BF = s.rfactor(B, ni, 0) + s[B].set_store_predicate(tx.var.equal(0)) - s[BF].compute_at(s[B], s[B].op.axis[0]) + s[B].bind(s[B].op.reduce_axis[0], tx) + s[B].bind(s[B].op.axis[0], bx) - _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2) + s[BF].compute_at(s[B], s[B].op.axis[0]) - BF2 = s.rfactor(BF, noi, 0) + _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2) - s[BF].bind(s[BF].op.axis[0], tx) - s[BF2].compute_at(s[BF], s[BF].op.axis[1]) + BF2 = s.rfactor(BF, noi, 0) - fcuda = tvm.build(s, [A, B], device) + s[BF].bind(s[BF].op.axis[0], tx) + s[BF2].compute_at(s[BF], s[BF].op.axis[1]) - check("cuda") - check("rocm") + fcuda = tvm.build(s, [A, B], target) -@unittest.skipIf(not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"), "skip because cuda is not enabled..") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_const_float_to_half(): # This import is required to use nvcc to perform code gen; # otherwise it is found that the code gen is done by nvrtc. @@ -398,16 +367,14 @@ def test_cuda_const_float_to_half(): func(a, c) np.testing.assert_equal(c.asnumpy(), a_np > b.value) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_reduction(): def check(device, dtype, m=32, n=32): - ctx = tvm.context(device, 0) - if not ctx.exist or not tvm.runtime.enabled(device): - print("skip because", device, "is not enabled..") + if not tvm.testing.device_enabled(device): + print("Skipping", device) return - if dtype == "float16" and not have_fp16(ctx.compute_version): - print("Skip because gpu does not have fp16 support") - return - + ctx = tvm.context(device, 0) a = te.placeholder((m, n), name="a", dtype=dtype) b = te.placeholder((m, n), name="b", dtype=dtype) c = a + b @@ -430,12 +397,14 @@ def check(device, dtype, m=32, n=32): check("rocm", "float32") check("cuda", "float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_mix_threaded_and_normal_reduction(): def check(device, dtype, m=32, n=32): - ctx = tvm.context(device, 0) - if not ctx.exist or not tvm.runtime.enabled(device): - print("skip because", device, "is not enabled..") + if not tvm.testing.device_enabled(device): + print("Skipping", device) return + ctx = tvm.context(device, 0) if dtype == "float16" and not have_fp16(ctx.compute_version): print("Skip because gpu does not have fp16 support") return @@ -458,11 +427,9 @@ def check(device, dtype, m=32, n=32): check("rocm", "float32") check("cuda", "float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_floordiv_with_vectorization(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - with tvm.target.cuda(): # B[i] = A[floordiv(i, k)] n = 256 @@ -485,11 +452,9 @@ def test_cuda_floordiv_with_vectorization(): func(a_nd, b_nd) tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_floormod_with_vectorization(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - with tvm.target.cuda(): # B[i] = A[floormod(i, k)] n = 256 @@ -512,11 +477,9 @@ def test_cuda_floormod_with_vectorization(): func(a_nd, b_nd) tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorized_casts(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - def check(t0, t1): if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") @@ -571,6 +534,8 @@ def sched(B): s[B].bind(iio, tx) return s +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorized_intrin1(): test_funcs = [ (tvm.tir.floor, lambda x : np.floor(x)), @@ -594,9 +559,6 @@ def test_vectorized_intrin1(): (tvm.tir.sqrt, lambda x : np.sqrt(x)), ] def run_test(tvm_intrin, np_func, dtype): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -627,6 +589,8 @@ def run_test(tvm_intrin, np_func, dtype): run_test(*func, "float32") run_test(*func, "float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorized_intrin2(dtype="float32"): c2 = tvm.tir.const(2, dtype=dtype) test_funcs = [ @@ -634,10 +598,6 @@ def test_vectorized_intrin2(dtype="float32"): (tvm.tir.fmod, lambda x : np.fmod(x, 2.0)) ] def run_test(tvm_intrin, np_func): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - n = 128 A = te.placeholder((n,), dtype=dtype, name='A') B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name='B') @@ -652,6 +612,8 @@ def run_test(tvm_intrin, np_func): for func in test_funcs: run_test(*func) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorized_popcount(): def ref_popcount(x): cnt = 0 @@ -661,10 +623,6 @@ def ref_popcount(x): return cnt def run_test(dtype): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - n = 128 A = te.placeholder((n,), dtype=dtype, name='A') B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name='B') @@ -680,11 +638,10 @@ def run_test(dtype): run_test("uint32") run_test("uint64") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_cuda_vectorize_load_permute_pad(): def check_cuda(dtype, n, l, padding, lanes): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -755,23 +712,21 @@ def post_visit(stmt): tvm.tir.stmt_functor.ir_transform(stmt['main'].body, pre_visit, post_visit) - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("CUDA device not found, skip the verification.") - return - else: - tgt = tvm.target.cuda() - mod = tvm.build(s, args, tgt) - # To check if every vectorize loop transforms to correct instruction - # print(mod.imported_modules[0].get_source()) - - ctx = tvm.context("cuda", 0) - a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx) - b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx) - c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - + tgt = tvm.target.cuda() + mod = tvm.build(s, args, tgt) + # To check if every vectorize loop transforms to correct instruction + # print(mod.imported_modules[0].get_source()) + + ctx = tvm.context("cuda", 0) + a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx) + b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx) + c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorized_cooperative_fetching_x(): N = 512 A = te.placeholder((N, N), name='A', dtype='float32') @@ -821,6 +776,8 @@ def test_vectorized_cooperative_fetching_x(): vcf_check_common(s, [A, B, C]) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_vectorized_cooperative_fetching_xy(): N = 512 A = te.placeholder((N, N), name='A') @@ -874,11 +831,9 @@ def test_vectorized_cooperative_fetching_xy(): vcf_check_common(s, [A, B, C]) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_unrolled_vectorization(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - dtype = 'float32' target = 'cuda' diff --git a/tests/python/unittest/test_target_codegen_device.py b/tests/python/unittest/test_target_codegen_device.py index ddb35f31fe1d..3289e38df0bb 100644 --- a/tests/python/unittest/test_target_codegen_device.py +++ b/tests/python/unittest/test_target_codegen_device.py @@ -18,7 +18,9 @@ from tvm import te from tvm.contrib import util import numpy as np +import tvm.testing +@tvm.testing.requires_gpu def test_large_uint_imm(): value = (1 << 63) + 123 other = tvm.tir.const(3, "uint64") @@ -32,9 +34,9 @@ def test_large_uint_imm(): s[A].bind(xo, te.thread_axis("blockIdx.x")) def check_target(device): - ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): return + ctx = tvm.context(device, 0) f = tvm.build(s, [A], device) # launch the kernel. a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx) @@ -45,6 +47,7 @@ def check_target(device): check_target("vulkan") +@tvm.testing.requires_gpu def test_add_pipeline(): n = te.size_var('n') A = te.placeholder((n,), name='A') @@ -64,11 +67,9 @@ def test_add_pipeline(): s[D].bind(xo, te.thread_axis("blockIdx.x")) def check_target(device, host="stackvm"): - ctx = tvm.context(device, 0) - if not ctx.exist: - return - if not tvm.runtime.enabled(host): + if not tvm.testing.device_enabled(device) or not tvm.testing.device_enabled(host): return + ctx = tvm.context(device, 0) mhost = tvm.driver.build(s, [A, B, D], target=device, target_host=host) f = mhost.entry_func # launch the kernel. diff --git a/tests/python/unittest/test_target_codegen_extern.py b/tests/python/unittest/test_target_codegen_extern.py index 4104af864439..ef98816e24c6 100644 --- a/tests/python/unittest/test_target_codegen_extern.py +++ b/tests/python/unittest/test_target_codegen_extern.py @@ -17,7 +17,9 @@ import tvm from tvm import te import numpy as np +import tvm.testing +@tvm.testing.uses_gpu def test_add_pipeline(): nn = 64 max_threads = 4 @@ -51,7 +53,7 @@ def extern_generator_gpu(ins, outs): print(tvm.lower(s_gpu, [A, C_gpu], simple_mode=True)) def check_target(target): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return s = s_gpu if target in ['opencl', 'cuda'] else s_cpu C = C_gpu if target in ['opencl', 'cuda'] else C_cpu @@ -86,7 +88,7 @@ def my_extern_array_func1(aa, bb): def check_target(target): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return # build and invoke the kernel. f = tvm.build(s, [A, C], target) @@ -116,7 +118,7 @@ def extern_generator(ins, outs): s = te.create_schedule(C.op) def check_target(target): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return # build and invoke the kernel. f = tvm.build(s, [A, C], target) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index d690364f5c5d..fd7a764a5baa 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -24,6 +24,7 @@ import re +@tvm.testing.requires_llvm def test_llvm_intrin(): ib = tvm.tir.ir_builder.create() n = tvm.runtime.convert(4) @@ -44,6 +45,7 @@ def test_llvm_intrin(): fcode = tvm.build(mod, None, "llvm") +@tvm.testing.requires_llvm def test_llvm_void_intrin(): ib = tvm.tir.ir_builder.create() A = ib.pointer("uint8", name="A") @@ -56,6 +58,7 @@ def test_llvm_void_intrin(): fcode = tvm.build(mod, None, "llvm") +@tvm.testing.requires_llvm def test_llvm_overloaded_intrin(): # Name lookup for overloaded intrinsics in LLVM 4- requires a name # that includes the overloaded types. @@ -80,6 +83,7 @@ def use_llvm_intrinsic(A, C): f = tvm.build(s, [A, C], target = 'llvm') +@tvm.testing.requires_llvm def test_llvm_import(): # extern "C" is necessary to get the correct signature cc_code = """ @@ -93,8 +97,6 @@ def test_llvm_import(): tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0), name='B') def check_llvm(use_file): - if not tvm.runtime.enabled("llvm"): - return if not clang.find_clang(required=False): print("skip because clang is not available") return @@ -120,6 +122,7 @@ def check_llvm(use_file): +@tvm.testing.requires_llvm def test_llvm_lookup_intrin(): ib = tvm.tir.ir_builder.create() A = ib.pointer("uint8x8", name="A") @@ -132,6 +135,7 @@ def test_llvm_lookup_intrin(): fcode = tvm.build(mod, None, "llvm") +@tvm.testing.requires_llvm def test_llvm_large_uintimm(): value = (1 << 63) + 123 other = tvm.tir.const(3, "uint64") @@ -139,8 +143,6 @@ def test_llvm_large_uintimm(): s = te.create_schedule(A.op) def check_llvm(): - if not tvm.runtime.enabled("llvm"): - return f = tvm.build(s, [A], "llvm") ctx = tvm.cpu(0) # launch the kernel. @@ -151,6 +153,7 @@ def check_llvm(): check_llvm() +@tvm.testing.requires_llvm def test_llvm_add_pipeline(): nn = 1024 n = tvm.runtime.convert(nn) @@ -170,8 +173,6 @@ def test_llvm_add_pipeline(): s[C].vectorize(xi) def check_llvm(): - if not tvm.runtime.enabled("llvm"): - return # Specifically allow offset to test codepath when offset is available Ab = tvm.tir.decl_buffer( A.shape, A.dtype, @@ -194,6 +195,7 @@ def check_llvm(): check_llvm() +@tvm.testing.requires_llvm def test_llvm_persist_parallel(): n = 128 A = te.placeholder((n,), name='A') @@ -210,8 +212,6 @@ def test_llvm_persist_parallel(): s[C].pragma(xi, "parallel_stride_pattern") def check_llvm(): - if not tvm.runtime.enabled("llvm"): - return # BUILD and invoke the kernel. f = tvm.build(s, [A, C], "llvm") ctx = tvm.cpu(0) @@ -226,10 +226,9 @@ def check_llvm(): check_llvm() +@tvm.testing.requires_llvm def test_llvm_flip_pipeline(): def check_llvm(nn, base): - if not tvm.runtime.enabled("llvm"): - return n = tvm.runtime.convert(nn) A = te.placeholder((n + base), name='A') C = te.compute((n,), lambda i: A(nn + base- i - 1), name='C') @@ -253,10 +252,9 @@ def check_llvm(nn, base): check_llvm(128, 1) +@tvm.testing.requires_llvm def test_llvm_vadd_pipeline(): def check_llvm(n, lanes): - if not tvm.runtime.enabled("llvm"): - return A = te.placeholder((n,), name='A', dtype="float32x%d" % lanes) B = te.compute((n,), lambda i: A[i], name='B') C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name='C') @@ -282,10 +280,9 @@ def check_llvm(n, lanes): check_llvm(512, 2) +@tvm.testing.requires_llvm def test_llvm_madd_pipeline(): def check_llvm(nn, base, stride): - if not tvm.runtime.enabled("llvm"): - return n = tvm.runtime.convert(nn) A = te.placeholder((n + base, stride), name='A') C = te.compute((n, stride), lambda i, j: A(base + i, j) + 1, name='C') @@ -310,6 +307,7 @@ def check_llvm(nn, base, stride): check_llvm(4, 0, 3) +@tvm.testing.requires_llvm def test_llvm_temp_space(): nn = 1024 n = tvm.runtime.convert(nn) @@ -319,8 +317,6 @@ def test_llvm_temp_space(): s = te.create_schedule(C.op) def check_llvm(): - if not tvm.runtime.enabled("llvm"): - return # build and invoke the kernel. f = tvm.build(s, [A, C], "llvm") ctx = tvm.cpu(0) @@ -333,6 +329,7 @@ def check_llvm(): c.asnumpy(), a.asnumpy() + 1 + 1) check_llvm() +@tvm.testing.requires_llvm def test_multiple_func(): nn = 1024 n = tvm.runtime.convert(nn) @@ -344,8 +341,6 @@ def test_multiple_func(): s[C].parallel(xo) s[C].vectorize(xi) def check_llvm(): - if not tvm.runtime.enabled("llvm"): - return # build two functions f2 = tvm.lower(s, [A, B, C], name="fadd1") f1 = tvm.lower(s, [A, B, C], name="fadd2") @@ -369,10 +364,9 @@ def check_llvm(): +@tvm.testing.requires_llvm def test_llvm_condition(): def check_llvm(n, offset): - if not tvm.runtime.enabled("llvm"): - return A = te.placeholder((n, ), name='A') C = te.compute((n,), lambda i: tvm.tir.if_then_else(i >= offset, A[i], 0.0), name='C') s = te.create_schedule(C.op) @@ -389,10 +383,9 @@ def check_llvm(n, offset): check_llvm(64, 8) +@tvm.testing.requires_llvm def test_llvm_bool(): def check_llvm(n): - if not tvm.runtime.enabled("llvm"): - return A = te.placeholder((n, ), name='A', dtype="int32") C = te.compute((n,), lambda i: A[i].equal(1).astype("float"), name='C') s = te.create_schedule(C.op) @@ -408,10 +401,9 @@ def check_llvm(n): check_llvm(64) +@tvm.testing.requires_llvm def test_rank_zero(): def check_llvm(n): - if not tvm.runtime.enabled("llvm"): - return A = te.placeholder((n, ), name='A') scale = te.placeholder((), name='scale') k = te.reduce_axis((0, n), name="k") @@ -431,10 +423,9 @@ def check_llvm(n): tvm.testing.assert_allclose(d.asnumpy(), d_np) check_llvm(64) +@tvm.testing.requires_llvm def test_rank_zero_bound_checkers(): def check_llvm(n): - if not tvm.runtime.enabled("llvm"): - return with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}): A = te.placeholder((n, ), name='A') scale = te.placeholder((), name='scale') @@ -456,6 +447,7 @@ def check_llvm(n): check_llvm(64) +@tvm.testing.requires_llvm def test_alignment(): n = tvm.runtime.convert(1024) A = te.placeholder((n,), name='A') @@ -496,6 +488,7 @@ def has_call_to_assume(): assert has_call_to_assume() +@tvm.testing.requires_llvm def test_llvm_div(): """Check that the semantics of div and mod is correct""" def check(start, end, dstart, dend, dtype, floor_div=False): @@ -595,6 +588,7 @@ def _show_info(): check(0, 255, dstart, dend, 'uint8', floor_div=False) check(0, 255, dstart, dend, 'uint8', floor_div=True) +@tvm.testing.requires_llvm def test_llvm_fp_math(): def check_llvm_reciprocal(n): A = te.placeholder((n,), name='A') @@ -629,6 +623,7 @@ def check_llvm_sigmoid(n): check_llvm_sigmoid(16) +@tvm.testing.requires_llvm def test_dwarf_debug_information(): nn = 1024 n = tvm.runtime.convert(nn) @@ -640,8 +635,6 @@ def test_dwarf_debug_information(): s[C].parallel(xo) s[C].vectorize(xi) def check_llvm_object(): - if not tvm.runtime.enabled("llvm"): - return if tvm.target.codegen.llvm_version_major() < 5: return if tvm.target.codegen.llvm_version_major() > 6: @@ -676,8 +669,6 @@ def check_llvm_object(): assert re.search(r"""DW_AT_name.*fadd2""", str(output)) def check_llvm_ir(): - if not tvm.runtime.enabled("llvm"): - return if tvm.target.codegen.llvm_version_major() < 5: return if tvm.target.codegen.llvm_version_major() > 6: @@ -704,6 +695,7 @@ def check_llvm_ir(): check_llvm_ir() +@tvm.testing.requires_llvm def test_llvm_shuffle(): a = te.placeholder((8, ), 'int32') b = te.placeholder((8, ), 'int32') @@ -760,6 +752,7 @@ def np_bf16_cast_and_cast_back(arr): ''' Convert a numpy array of float to bf16 and cast back''' return np_bf162np_float(np_float2np_bf16(arr)) +@tvm.testing.requires_llvm def test_llvm_bf16(): def dotest(do_vectorize): np.random.seed(122) @@ -784,6 +777,7 @@ def dotest(do_vectorize): dotest(True) dotest(False) +@tvm.testing.requires_llvm def test_llvm_crt_static_lib(): A = te.placeholder((32, ), dtype='bfloat16') B = te.placeholder((32, ), dtype='bfloat16') diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index e403589dff1d..9a03a795e875 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -16,9 +16,12 @@ # under the License. import tvm from tvm import te +import tvm.testing target = 'opencl' +@tvm.testing.requires_gpu +@tvm.testing.requires_opencl def test_opencl_ternary_expression(): def check_if_then_else(ctx, n, dtype): A = te.placeholder((n,), name='A', dtype=dtype) @@ -52,10 +55,6 @@ def check_select(ctx, n, dtype): # Only need to test compiling here fun(a, c) - if not tvm.runtime.enabled(target): - print("skip because opencl is not enabled..") - return - ctx = tvm.context(target, 0) check_if_then_else(ctx, 1, 'int8') @@ -67,6 +66,8 @@ def check_select(ctx, n, dtype): check_select(ctx, 1, 'int16') check_select(ctx, 1, 'uint16') +@tvm.testing.requires_gpu +@tvm.testing.requires_opencl def test_opencl_inf_nan(): def check_inf_nan(ctx, n, value, dtype): A = te.placeholder((n,), name='A', dtype=dtype) @@ -80,10 +81,6 @@ def check_inf_nan(ctx, n, value, dtype): # Only need to test compiling here fun(a, c) - if not tvm.runtime.enabled(target): - print("skip because opencl is not enabled..") - return - ctx = tvm.context(target, 0) check_inf_nan(ctx, 1, -float('inf'), 'float32') @@ -94,6 +91,8 @@ def check_inf_nan(ctx, n, value, dtype): check_inf_nan(ctx, 1, float('nan'), 'float64') +@tvm.testing.requires_gpu +@tvm.testing.requires_opencl def test_opencl_max(): def check_max(ctx, n, dtype): A = te.placeholder((n,), name='A', dtype=dtype) @@ -109,10 +108,6 @@ def check_max(ctx, n, dtype): # Only need to test compiling here fun(a, c) - if not tvm.runtime.enabled(target): - print("skip because opencl is not enabled..") - return - ctx = tvm.context(target, 0) check_max(ctx, 1, 'int8') diff --git a/tests/python/unittest/test_target_codegen_rocm.py b/tests/python/unittest/test_target_codegen_rocm.py index 4c6304a7a31f..2adc1c89d804 100644 --- a/tests/python/unittest/test_target_codegen_rocm.py +++ b/tests/python/unittest/test_target_codegen_rocm.py @@ -24,7 +24,7 @@ bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") -@unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..") +@tvm.testing.requires_rocm def test_rocm_cross_thread_reduction(): # based on the reduction tutorial n = te.size_var("n") @@ -52,7 +52,7 @@ def test_rocm_cross_thread_reduction(): b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4) -@unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..") +@tvm.testing.requires_rocm def test_rocm_inf_nan(): def check_inf_nan(ctx, n, value, dtype): A = te.placeholder((n,), name='A', dtype=dtype) @@ -75,7 +75,7 @@ def check_inf_nan(ctx, n, value, dtype): check_inf_nan(ctx, 1, float('nan'), 'float32') check_inf_nan(ctx, 1, float('nan'), 'float64') -@unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..") +@tvm.testing.requires_rocm def test_rocm_reduction_binding(): k = te.reduce_axis((0, 32), 'k') A = te.placeholder((96, 32), name='A') @@ -89,7 +89,7 @@ def test_rocm_reduction_binding(): mo, _ = s[B].split(B.op.axis[0], 32) s[B].bind(mo, bx) -@unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..") +@tvm.testing.requires_rocm def test_rocm_copy(): def check_rocm(dtype, n): @@ -107,7 +107,7 @@ def check_rocm(dtype, n): peturb = np.random.uniform(low=0.5, high=1.5) check_rocm(dtype, int(peturb * (2 ** logN))) -@unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..") +@tvm.testing.requires_rocm def test_rocm_vectorize_add(): num_thread = 8 diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py index e03d689b6c5f..55c7c3148b79 100644 --- a/tests/python/unittest/test_target_codegen_vm_basic.py +++ b/tests/python/unittest/test_target_codegen_vm_basic.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te import numpy as np def run_jit(fapi, check): for target in ["llvm", "stackvm"]: - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): continue f = tvm.driver.build(fapi, target=target) s = f.get_source() diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 722a9ec6be15..a036cd89141d 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -20,11 +20,8 @@ import numpy as np +@tvm.testing.requires_vulkan def test_vector_comparison(): - if not tvm.runtime.enabled("vulkan"): - print("Skipping due to no Vulkan module") - return - target = 'vulkan' def check_correct_assembly(dtype): @@ -60,12 +57,10 @@ def check_correct_assembly(dtype): bx = te.thread_axis("blockIdx.x") +@tvm.testing.requires_vulkan def test_vulkan_copy(): def check_vulkan(dtype, n): - if not tvm.vulkan(0).exist or not tvm.runtime.enabled("vulkan"): - print("skip because vulkan is not enabled..") - return A = te.placeholder((n,), name='A', dtype=dtype) ctx = tvm.vulkan(0) a_np = np.random.uniform(size=(n,)).astype(A.dtype) @@ -81,13 +76,11 @@ def check_vulkan(dtype, n): check_vulkan(dtype, int(peturb * (2 ** logN))) +@tvm.testing.requires_vulkan def test_vulkan_vectorize_add(): num_thread = 8 def check_vulkan(dtype, n, lanes): - if not tvm.vulkan(0).exist or not tvm.runtime.enabled("vulkan"): - print("skip because vulkan is not enabled..") - return A = te.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes)) B = te.compute((n,), lambda i: A[i]+tvm.tir.const(1, A.dtype), name='B') s = te.create_schedule(B.op) @@ -106,6 +99,7 @@ def check_vulkan(dtype, n, lanes): check_vulkan("float16", 64, 2) +@tvm.testing.requires_vulkan def test_vulkan_stress(): """ Launch a randomized test with multiple kernels per stream, multiple uses of @@ -118,9 +112,6 @@ def test_vulkan_stress(): def run_stress(): def worker(): - if not tvm.vulkan(0).exist or not tvm.runtime.enabled("vulkan"): - print("skip because vulkan is not enabled..") - return A = te.placeholder((n,), name='A', dtype="float32") B = te.placeholder((n,), name='B', dtype="float32") functions = [ diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index 5b4a3094f791..5bebf3d98fcb 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -17,7 +17,7 @@ import tvm from tvm import te -from tvm.testing import check_numerical_grads, assert_allclose +from tvm.testing import assert_allclose from tvm import topi from tvm.topi.util import get_const_tuple import pytest @@ -30,10 +30,7 @@ def check_grad(out, inputs, args=[], data_range=(-10, 10), desired_grads=None, a def check_device(device, host="llvm"): ctx = tvm.context(device, 0) - if not tvm.runtime.enabled(host): - return - if not ctx.exist: - print("skip because %s is not enabled.." % device) + if not tvm.testing.device_enabled(host): return sout = te.create_schedule(out.op) @@ -74,7 +71,7 @@ def forward(*in_data): out_data = tvm.nd.empty(out_shape, out.dtype) mout(out_data, *[tvm.nd.array(d) for d in list(in_data)]) return out_data.asnumpy().sum() - check_numerical_grads(forward, [d.asnumpy() for d in input_data + arg_vals], g_res) + tvm.testing.check_numerical_grads(forward, [d.asnumpy() for d in input_data + arg_vals], g_res) check_device("cpu") diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 8ab65f129cc5..6640420c3cf9 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -21,6 +21,8 @@ from tvm.te.hybrid import script from tvm.te.hybrid.runtime import HYBRID_GLOBALS +import tvm.testing + @pytest.mark.skip def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): @@ -316,11 +318,9 @@ def if_and(a): run_and_check(func, ins, outs=outs) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_bind(): - if not tvm.gpu(0).exist: - print('[Warning] No GPU found! Skip bind test!') - return - @script def vec_add(a, b): c = output_tensor((1000, ), 'float32') @@ -463,6 +463,8 @@ def triangle(a, b): func, ins, outs = run_and_check(triangle, [a, b]) run_and_check(func, ins, outs=outs) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_allocate(): @te.hybrid.script def blur2d(a): @@ -482,27 +484,24 @@ def blur2d(a): func, ins, outs = run_and_check(blur2d, [a]) run_and_check(func, ins, outs=outs) - if tvm.gpu().exist: - @te.hybrid.script - def share_vec_add(a, b): - c = output_tensor((256, ), 'float32') - shared = allocate((256, ), 'float32', 'shared') - for i in bind("threadIdx.x", 256): - shared[i] = a[i] - local = allocate((256, ), 'float32', 'local') - for i in bind("threadIdx.x", 256): - local[i] = b[i] - for i in bind("threadIdx.x", 256): - c[i] = shared[i] + local[i] - return c - - a = te.placeholder((256, ), dtype='float32', name='a') - b = te.placeholder((256, ), dtype='float32', name='b') - c = share_vec_add(a, b) - func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda') - run_and_check(func, ins, outs=outs, target='cuda') - else: - print('[Warning] No GPU found! Skip shared mem test!') + @te.hybrid.script + def share_vec_add(a, b): + c = output_tensor((256, ), 'float32') + shared = allocate((256, ), 'float32', 'shared') + for i in bind("threadIdx.x", 256): + shared[i] = a[i] + local = allocate((256, ), 'float32', 'local') + for i in bind("threadIdx.x", 256): + local[i] = b[i] + for i in bind("threadIdx.x", 256): + c[i] = shared[i] + local[i] + return c + + a = te.placeholder((256, ), dtype='float32', name='a') + b = te.placeholder((256, ), dtype='float32', name='b') + c = share_vec_add(a, b) + func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda') + run_and_check(func, ins, outs=outs, target='cuda') def test_upstream(): @te.hybrid.script diff --git a/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py index 1f1791447ab1..a57a340ac063 100644 --- a/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py +++ b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py @@ -18,7 +18,7 @@ from tvm import te from tvm import topi import numpy as np -from tvm.contrib import nvcc +import tvm.testing def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96): A = te.placeholder((n, l), name='A', dtype='float16') @@ -204,26 +204,14 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2): c_np[bs, :, :] = np.dot(a_np[bs, :, :], b_np[bs, :, :]) np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3) +@tvm.testing.requires_tensorcore def test_tensor_core_matmul(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): - print("skip because gpu does not support tensor core") - return - tensor_core_matmul(16) #test with warp_tile 16x16x16 tensor_core_matmul(8) #test with warp_tile 8x32x16 tensor_core_matmul(32) #test with warp_tile 32x8x16 +@tvm.testing.requires_tensorcore def test_tensor_core_batch_matmul(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): - print("skip because gpu does not support tensor core") - return - tensor_core_batch_matmul() if __name__ == '__main__': diff --git a/tests/python/unittest/test_te_schedule_tensor_core.py b/tests/python/unittest/test_te_schedule_tensor_core.py index aa87665455df..8b70c82af6e2 100644 --- a/tests/python/unittest/test_te_schedule_tensor_core.py +++ b/tests/python/unittest/test_te_schedule_tensor_core.py @@ -18,7 +18,7 @@ from tvm import te import numpy as np from tvm.topi.testing import conv2d_nhwc_python -from tvm.contrib import nvcc +import tvm.testing VERIFY = True @@ -103,14 +103,8 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) +@tvm.testing.requires_tensorcore def test_tensor_core_batch_matmal(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): - print("skip because gpu does not support tensor core") - return - batch_size = 4 n = 512 m, l = n, n @@ -216,14 +210,8 @@ def test_tensor_core_batch_matmal(): +@tvm.testing.requires_tensorcore def test_tensor_core_batch_conv(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): - print("skip because gpu does not support tensor core") - return - # The sizes of inputs and filters batch_size = 32 height = 14 diff --git a/tests/python/unittest/test_te_tensor_overload.py b/tests/python/unittest/test_te_tensor_overload.py index 97143681891c..4b11c009d41b 100644 --- a/tests/python/unittest/test_te_tensor_overload.py +++ b/tests/python/unittest/test_te_tensor_overload.py @@ -20,6 +20,7 @@ from tvm import topi import tvm.topi.testing from tvm.topi.util import get_const_tuple +import tvm.testing def test_operator_type_and_tags(): @@ -103,10 +104,10 @@ def verify_tensor_scalar_bop(shape, typ="add"): raise NotImplementedError() def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return + ctx = tvm.context(device, 0) print("Running on target: %s" % device) with tvm.target.create(device): s = tvm.topi.testing.get_elemwise_schedule(device)(B) @@ -150,7 +151,7 @@ def verify_broadcast_bop(lhs_shape, rhs_shape, typ="add"): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -183,10 +184,11 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, stride, padding, typ="add"): def check_device(device): ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -239,6 +241,7 @@ def check_device(device): check_device(device) +@tvm.testing.uses_gpu def test_tensor_scalar_bop(): verify_tensor_scalar_bop((1,), typ="add") verify_tensor_scalar_bop((3, 5), typ="sub") @@ -246,6 +249,7 @@ def test_tensor_scalar_bop(): verify_tensor_scalar_bop((2, 3, 1, 32), typ="div") +@tvm.testing.uses_gpu def test_broadcast_bop(): verify_broadcast_bop((2, 3), (), typ="add") verify_broadcast_bop((5, 2, 3), (1,), typ="add") @@ -254,6 +258,7 @@ def test_broadcast_bop(): verify_broadcast_bop((2, 3, 1, 32), (64, 32), typ="div") +@tvm.testing.uses_gpu def test_conv2d_scalar_bop(): verify_conv2d_scalar_bop(1, 16, 4, 4, 3, 1, 1, typ="add") verify_conv2d_scalar_bop(1, 32, 2, 1, 3, 1, 1, typ="sub") diff --git a/tests/python/unittest/test_testing.py b/tests/python/unittest/test_testing.py index ea1680111ee0..c7be32552019 100644 --- a/tests/python/unittest/test_testing.py +++ b/tests/python/unittest/test_testing.py @@ -17,7 +17,7 @@ import numpy as np import tvm from tvm import te -from tvm.testing import check_numerical_grads +import tvm.testing def test_check_numerical_grads(): # Functions and their derivatives @@ -46,7 +46,7 @@ def test_check_numerical_grads(): func_forw = lambda x: np.sum(func(x)[0]) grads = [func(x_input)[1]] - check_numerical_grads(func_forw, [x_input], grads) + tvm.testing.check_numerical_grads(func_forw, [x_input], grads) # Check functions with multiple arguments for f1 in functions: @@ -57,13 +57,13 @@ def test_check_numerical_grads(): func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0]) grads = [f1(x_input)[1], f2(y_input)[1]] - check_numerical_grads(func_forw, [x_input, y_input], grads) + tvm.testing.check_numerical_grads(func_forw, [x_input, y_input], grads) # Same thing but with keyword arguments func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0]) grads = {'x': f1(x_input)[1], 'y': f2(y_input)[1]} - check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads) + tvm.testing.check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads) def _noise1(x, atol=1e-2, rtol=0.1): # We go in random direction using twice the original tolerance to be sure this @@ -93,23 +93,23 @@ def _noise2(x, atol=1e-2, rtol=0.1): grads = [_noise1(f1(x_input)[1]), _noise1(f2(y_input)[1])] try: - check_numerical_grads(func_forw, [x_input, y_input], grads) + tvm.testing.check_numerical_grads(func_forw, [x_input, y_input], grads) except AssertionError as e: pass else: - raise AssertionError("check_numerical_grads didn't raise an exception") + raise AssertionError("tvm.testing.check_numerical_grads didn't raise an exception") func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0]) grads = {'x': _noise2(f1(x_input)[1]), 'y': _noise2(f2(y_input)[1])} try: - check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads) + tvm.testing.check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads) except AssertionError as e: pass else: - raise AssertionError("check_numerical_grads didn't raise an exception") + raise AssertionError("tvm.testing.check_numerical_grads didn't raise an exception") if __name__ == "__main__": - test_check_numerical_grads() + test_tvm.testing.check_numerical_grads() diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py index 2e37de49f243..ec3c762a50c9 100644 --- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -17,6 +17,7 @@ """Test gpu code verifier""" import tvm from tvm import te +import tvm.testing def get_verify_pass(valid, **kwargs): def _fverify(f, *_): @@ -25,6 +26,7 @@ def _fverify(f, *_): return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0) +@tvm.testing.requires_gpu def test_shared_memory(): def check_shared_memory(dtype): N = 1024 @@ -47,7 +49,7 @@ def check_shared_memory(dtype): # thread usage: M for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] with tvm.transform.PassContext(config={"tir.add_lower_pass": [ @@ -66,6 +68,7 @@ def check_shared_memory(dtype): check_shared_memory('float32') check_shared_memory('int8x4') +@tvm.testing.requires_gpu def test_local_memory(): N = 1024 M = 128 @@ -83,7 +86,7 @@ def test_local_memory(): # thread usage: M for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] @@ -101,6 +104,7 @@ def test_local_memory(): tvm.build(s, [A, B], target) assert valid[0] +@tvm.testing.requires_gpu def test_num_thread(): N = 1024 M = 128 @@ -118,7 +122,7 @@ def test_num_thread(): # thread usage: N for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] @@ -152,6 +156,7 @@ def test_num_thread(): tvm.build(s, [A, B], target) assert valid[0] +@tvm.testing.requires_gpu def test_multiple_kernels(): N = 1024 @@ -168,7 +173,7 @@ def test_multiple_kernels(): # thread usage: N for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] @@ -186,6 +191,7 @@ def test_multiple_kernels(): tvm.build(s, [A, C], target) assert valid[0] +@tvm.testing.requires_gpu def test_wrong_bind(): N = 1024 @@ -199,7 +205,7 @@ def test_wrong_bind(): s[B].bind(s[B].op.axis[1], te.thread_axis("threadIdx.x")) for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] @@ -208,6 +214,7 @@ def test_wrong_bind(): tvm.build(s, [A, B], target) assert not valid[0] +@tvm.testing.requires_gpu def test_vectorize(): N = 1024 @@ -224,7 +231,7 @@ def test_vectorize(): s[B].vectorize(ji) for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] @@ -233,6 +240,7 @@ def test_vectorize(): tvm.lower(s, [A, B]) assert not valid[0] +@tvm.testing.requires_gpu def test_vthread(): N = 1024 @@ -245,7 +253,7 @@ def test_vthread(): s[B].bind(s[B].op.axis[1], te.thread_axis("vthread")) for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: + if not tvm.testing.device_enabled(target): continue valid = [None] diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py index 386fceb150e3..7022e285a4c0 100644 --- a/tests/python/unittest/test_tir_analysis_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -17,6 +17,7 @@ import tvm import pytest from tvm import te +import tvm.testing # The following DLDeviceType/TVMDeviceExtType values # are originally defined in dlpack.h and c_runtime_api.h. @@ -27,6 +28,7 @@ # All computations are bound. # So VerifyMemory pass is expected to succeed. # +@tvm.testing.uses_gpu def test_verify_memory_all_bind(): n = te.var("n") A = te.placeholder((n,), name='A') @@ -41,15 +43,17 @@ def test_verify_memory_all_bind(): mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices + other_devices: - binded_mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.transform.VerifyMemory()(binded_mod) + if tvm.testing.device_enabled(dev_type): + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + tvm.tir.transform.VerifyMemory()(binded_mod) # Computations are not bound. # So VerifyMemory pass fails when device type is GPU. # +@tvm.testing.uses_gpu def test_verify_memory_not_bind(): n = te.var("n") A = te.placeholder((n,), name='A') @@ -61,20 +65,23 @@ def test_verify_memory_not_bind(): mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices: - binded_mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - with pytest.raises(RuntimeError): - tvm.tir.transform.VerifyMemory()(binded_mod) + if tvm.testing.device_enabled(dev_type): + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + with pytest.raises(RuntimeError): + tvm.tir.transform.VerifyMemory()(binded_mod) for dev_type in other_devices: - binded_mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.transform.VerifyMemory()(binded_mod) + if tvm.testing.device_enabled(dev_type): + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + tvm.tir.transform.VerifyMemory()(binded_mod) # Computations are partially bound. # So VerifyMemory pass fails when device type is GPU. # +@tvm.testing.uses_gpu def test_verify_memory_partially_bind(): n = te.var("n") A = te.placeholder((n,), name='A') @@ -91,15 +98,17 @@ def test_verify_memory_partially_bind(): mod = tvm. lower(s, [A, B, C, D]) for dev_type in gpu_devices: - binded_mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - with pytest.raises(RuntimeError): - tvm.tir.transform.VerifyMemory()(binded_mod) + if tvm.testing.device_enabled(dev_type): + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + with pytest.raises(RuntimeError): + tvm.tir.transform.VerifyMemory()(binded_mod) for dev_type in other_devices: - binded_mod = tvm.tir.transform.Apply( - lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.transform.VerifyMemory()(binded_mod) + if tvm.testing.device_enabled(dev_type): + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + tvm.tir.transform.VerifyMemory()(binded_mod) diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 7ee1e539204b..f7e8f2f5ef87 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -119,6 +119,7 @@ def assert_simplified_equal(index_simplified, index_direct): assert_simplified_equal(index_simplified, index_direct) +@tvm.testing.requires_llvm def test_buffer_broadcast(): m0, m1, m2 = te.size_var("m0"), te.size_var("m1"), te.size_var("m2") n0, n1, n2 = te.size_var("n0"), te.size_var("n1"), te.size_var("n2") @@ -134,8 +135,6 @@ def test_buffer_broadcast(): s = te.create_schedule(C.op) def check(): - if not tvm.runtime.enabled("llvm"): - return fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) ctx = tvm.cpu(0) a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) @@ -147,6 +146,7 @@ def check(): check() +@tvm.testing.requires_llvm def test_buffer_broadcast_expr(): n0, m0, x = te.size_var('n0'), te.size_var('m0'), te.size_var('x') n1, m1 = te.size_var('n1'), te.size_var('m1') @@ -162,8 +162,6 @@ def test_buffer_broadcast_expr(): s = te.create_schedule(C.op) def check_stride(): - if not tvm.runtime.enabled("llvm"): - return fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add', binds={A:Ab, B:Bb, C:Cc}) ctx = tvm.cpu(0) @@ -174,8 +172,6 @@ def check_stride(): tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) def check_no_stride(): - if not tvm.runtime.enabled("llvm"): - return fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add', binds={A: Ab, B: Bb, C: Cc}) ctx = tvm.cpu(0) @@ -186,8 +182,6 @@ def check_no_stride(): tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) def check_auto_bind(): - if not tvm.runtime.enabled("llvm"): - return # Let build bind buffers fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add') ctx = tvm.cpu(0) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 95047f5344c7..7664806923b8 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -17,6 +17,7 @@ import tvm from tvm import te import numpy as np +import tvm.testing def test_for(): ib = tvm.tir.ir_builder.create() @@ -90,7 +91,7 @@ def test_device_ir(A, B, C): name="vector_add", dtype=dtype) s = te.create_schedule(C.op) def check_target(target): - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return # build and invoke the kernel. fadd = tvm.build(s, [A, B, C], target) @@ -103,6 +104,7 @@ def check_target(target): tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) check_target("llvm") +@tvm.testing.requires_gpu def test_gpu(): n = te.size_var('n') dtype = "float32" @@ -133,7 +135,7 @@ def test_device_ir(A, B, C): stmt = tvm.te.schedule.ScheduleOps(s, bounds) def check_target(target): n = 1024 - if not tvm.runtime.enabled(target): + if not tvm.testing.device_enabled(target): return # build and invoke the kernel. fadd = tvm.build(s, [A, B, C], target) diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 186a52d12da1..7c93b4ef3305 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -19,7 +19,7 @@ from tvm import relay import numpy as np import pytest -from tvm.relay.testing import ctx_list +from tvm.testing import enabled_targets var_list = [] @@ -711,7 +711,7 @@ def test_hoisting_op_conv(): kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) params = {'w': tvm.nd.array(kernel)} - for target, ctx in ctx_list(): + for target, ctx in enabled_targets(): with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build_module.build(mod, target=target, params=params) m = tvm.contrib.graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index fa27fddf4c98..bb35f32a50a2 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -25,6 +25,7 @@ def collect_visit(stmt, f): return ret +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_llvm(index_a, index_b): n = te.size_var("n") @@ -43,6 +44,7 @@ def test_out_of_bounds_llvm(index_a, index_b): c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx) fadd (a, b, c) +@tvm.testing.requires_llvm def test_in_bounds_llvm(): n = te.size_var("n") A = te.placeholder ((n,), name='A') @@ -59,6 +61,7 @@ def test_in_bounds_llvm(): c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx) fadd (a, b, c) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): n = tvm.runtime.convert(nn) @@ -80,6 +83,7 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): c = tvm.nd.array(np.zeros(n, dtype=c.dtype), ctx) f(a, b, c) +@tvm.testing.requires_llvm def test_in_bounds_vectorize_llvm(): n = 512 lanes = 2 @@ -105,6 +109,7 @@ def test_in_bounds_vectorize_llvm(): f(a, c) tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1) +@tvm.testing.requires_llvm def test_in_bounds_loop_partition_basic_llvm(): n = te.size_var('n') A = te.placeholder((n, ), name='A') @@ -122,6 +127,7 @@ def test_in_bounds_loop_partition_basic_llvm(): t = tvm.nd.empty((32,), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b): n = te.size_var('n') @@ -186,6 +192,7 @@ def collect_branch_stmt (x): assert(len(branch_collector) == 2) +@tvm.testing.requires_llvm def test_in_bounds_const_loop_partition_llvm(): with tvm.transform.PassContext(config={ "tir.instrument_bound_checkers": True, @@ -207,6 +214,7 @@ def test_in_bounds_const_loop_partition_llvm(): t = tvm.nd.empty((n,), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b): with tvm.transform.PassContext(config={ @@ -229,6 +237,7 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b): t = tvm.nd.empty((n,), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm def test_in_bounds_conv_llvm(loop_tiling=False): HSTR = WSTR = 1 in_channel = 128 @@ -264,6 +273,7 @@ def test_in_bounds_conv_llvm(loop_tiling=False): conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), "float32", ctx) f(data_input, kernel_input, conv_out) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False): HSTR = WSTR = 1 @@ -307,6 +317,7 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), "float32", ctx) f(data_input, kernel_input, conv_out) +@tvm.testing.requires_llvm def test_in_bounds_tensors_with_same_shapes1D_llvm(): n = te.size_var('n') k = te.size_var('k') @@ -325,6 +336,7 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm(): t = tvm.nd.empty((32,), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape): n = te.size_var('n') @@ -344,6 +356,7 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape t = tvm.nd.empty((c_shape,), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm def test_in_bounds_tensors_with_same_shapes2D_llvm(): n = te.size_var('n') k = te.size_var('k') @@ -362,6 +375,7 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm(): t = tvm.nd.empty((32, 32), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape): n = te.size_var('n') @@ -381,6 +395,7 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape t = tvm.nd.empty((c_shape[0],c_shape[1]), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm def test_in_bounds_tensors_with_same_shapes3D_llvm(): n = te.size_var('n') k = te.size_var('k') @@ -400,6 +415,7 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm(): t = tvm.nd.empty((32, 32, 32), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape): n = te.size_var('n') @@ -420,10 +436,9 @@ def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape t = tvm.nd.empty((c_shape[0],c_shape[1],c_shape[2]), T.dtype, ctx) f(a, b, t) +@tvm.testing.requires_llvm @pytest.mark.xfail def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm(): - if not tvm.runtime.enabled("llvm"): - return n = 64 A = te.placeholder((n, ), name='A') scale = te.placeholder((), name='scale') diff --git a/tests/python/unittest/test_tir_transform_lower_intrin.py b/tests/python/unittest/test_tir_transform_lower_intrin.py index 98c79e414226..fbd4ce62efe8 100644 --- a/tests/python/unittest/test_tir_transform_lower_intrin.py +++ b/tests/python/unittest/test_tir_transform_lower_intrin.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te import numpy as np @@ -47,9 +48,6 @@ def make_binds(i): C = te.compute((n,), make_binds) s = te.create_schedule([C.op]) - if not tvm.runtime.enabled("llvm"): - return - f = tvm.build(s, [A, B, C], "llvm") a = tvm.nd.array(np.array([x for x, y in data], dtype=expr.dtype)) b = tvm.nd.array(np.array([y for x, y in data], dtype=expr.dtype)) @@ -69,6 +67,7 @@ def get_ref_data(): return list(itertools.product(x, y)) +@tvm.testing.requires_llvm def test_lower_floordiv(): data = get_ref_data() for dtype in ["int32", "int64", "int16"]: @@ -92,6 +91,7 @@ def test_lower_floordiv(): check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b) +@tvm.testing.requires_llvm def test_lower_floormod(): data = get_ref_data() for dtype in ["int32", "int64", "int16"]: diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 5801200c15da..eecc7f1bc4e9 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -19,7 +19,9 @@ from tvm.contrib.nvcc import have_fp16 import numpy as np +import tvm.testing +@tvm.testing.requires_cuda def test_lower_warp_memory_local_scope(): m = 128 A = te.placeholder((m,), name='A') @@ -47,6 +49,7 @@ def test_lower_warp_memory_local_scope(): assert(fdevice.body.body.value.value == "local") assert(fdevice.body.body.body.extents[0].value == 2) +@tvm.testing.requires_cuda def test_lower_warp_memory_correct_indices(): n = 32 A = te.placeholder((2, n, n), name='A', dtype="float32") @@ -83,11 +86,10 @@ def test_lower_warp_memory_correct_indices(): assert "threadIdx.x" in idx_names assert "threadIdx.y" not in idx_names +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_lower_warp_memory_cuda_end_to_end(): def check_cuda(dtype): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -127,11 +129,10 @@ def check_cuda(dtype): check_cuda("float32") check_cuda("float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_lower_warp_memory_cuda_half_a_warp(): def check_cuda(dtype): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -170,11 +171,10 @@ def check_cuda(dtype): check_cuda("float32") check_cuda("float16") +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_lower_warp_memory_cuda_2_buffers(): def check_cuda(dtype): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -218,6 +218,7 @@ def check_cuda(dtype): check_cuda("float32") check_cuda("float16") +@tvm.testing.requires_gpu def test_lower_warp_memory_roundup(): def check(device, m): A = te.placeholder((m,), name='A') @@ -246,7 +247,7 @@ def check(device, m): tvm.testing.assert_allclose(B_nd.asnumpy(), B_np) for device in ['cuda', 'rocm']: - if not tvm.context(device, 0).exist or not tvm.runtime.enabled(device): + if not tvm.testing.device_enabled(device): print("skip because", device,"is not enabled..") continue check(device, m=31) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 3ff6804cf7e0..75b3193ecf3f 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -16,7 +16,9 @@ # under the License. import tvm from tvm import te +import tvm.testing +@tvm.testing.requires_cuda def test_thread_storage_sync(): m = te.size_var('m') l = te.size_var('l') diff --git a/tests/scripts/setup-pytest-env.sh b/tests/scripts/setup-pytest-env.sh index 61c079aa4744..475ce1ce1c53 100755 --- a/tests/scripts/setup-pytest-env.sh +++ b/tests/scripts/setup-pytest-env.sh @@ -20,9 +20,9 @@ set +u if [[ ! -z $CI_PYTEST_ADD_OPTIONS ]]; then - export PYTEST_ADDOPTS="-v $CI_PYTEST_ADD_OPTIONS" + export PYTEST_ADDOPTS="-v $CI_PYTEST_ADD_OPTIONS $PYTEST_ADDOPTS" else - export PYTEST_ADDOPTS="-v " + export PYTEST_ADDOPTS="-v $PYTEST_ADDOPTS" fi set -u diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index e5f9b20e3325..3c5839bc7e1c 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -24,6 +24,8 @@ source tests/scripts/setup-pytest-env.sh export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 +export TVM_TEST_TARGETS="llvm;cuda" + find . -type f -path "*.pyc" | xargs rm -f # Rebuild cython diff --git a/tests/scripts/task_python_frontend_cpu.sh b/tests/scripts/task_python_frontend_cpu.sh index 10354e588720..6dfcabc2cd37 100755 --- a/tests/scripts/task_python_frontend_cpu.sh +++ b/tests/scripts/task_python_frontend_cpu.sh @@ -25,6 +25,8 @@ source tests/scripts/setup-pytest-env.sh export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 +export TVM_TEST_TARGETS="llvm" + find . -type f -path "*.pyc" | xargs rm -f # Rebuild cython diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index d61895c45973..741f15ba4a94 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -63,7 +63,7 @@ TVM_FFI=ctypes python3 -m pytest apps/dso_plugin_module TVM_FFI=ctypes python3 -m pytest tests/python/integration TVM_FFI=ctypes python3 -m pytest tests/python/contrib -TVM_FFI=ctypes python3 -m pytest tests/python/relay +TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" TVM_FFI=ctypes python3 -m pytest tests/python/relay # Do not enable OpenGL # TVM_FFI=cython python -m pytest tests/webgl diff --git a/tests/scripts/task_python_integration_gpuonly.sh b/tests/scripts/task_python_integration_gpuonly.sh index 6b2755a66f50..c2a9e0c15abe 100755 --- a/tests/scripts/task_python_integration_gpuonly.sh +++ b/tests/scripts/task_python_integration_gpuonly.sh @@ -16,4 +16,8 @@ # specific language governing permissions and limitations # under the License. +export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;vulkan;nvptx;opencl -device=mali,aocl_sw_emu" +export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" +export TVM_RELAY_TEST_TARGETS="cuda" + ./tests/scripts/task_python_integration.sh diff --git a/tests/scripts/task_python_unittest_gpuonly.sh b/tests/scripts/task_python_unittest_gpuonly.sh index 637be67626f0..56722b16a364 100755 --- a/tests/scripts/task_python_unittest_gpuonly.sh +++ b/tests/scripts/task_python_unittest_gpuonly.sh @@ -16,4 +16,7 @@ # specific language governing permissions and limitations # under the License. +export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;vulkan;nvptx;opencl -device=mali,aocl_sw_emu" +export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" + ./tests/scripts/task_python_unittest.sh diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index 46162e116496..3643c8db0a33 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -27,7 +27,6 @@ from tvm import te from matplotlib import pyplot as plt -from tvm.relay.testing.config import ctx_list from tvm import relay from tvm.contrib import graph_runtime from tvm.contrib.download import download_testdata @@ -70,7 +69,6 @@ model_name = supported_model[0] dshape = (1, 3, 512, 512) -target_list = ctx_list() ###################################################################### # Download and pre-process demo image @@ -105,9 +103,11 @@ def run(lib, ctx): class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2) return class_IDs, scores, bounding_boxs -for target, ctx in target_list: - lib = build(target) - class_IDs, scores, bounding_boxs = run(lib, ctx) +for target in ["llvm", "cuda"]: + ctx = tvm.context(target, 0) + if ctx.exist: + lib = build(target) + class_IDs, scores, bounding_boxs = run(lib, ctx) ###################################################################### # Display result