[TESTS] Refactor tests to run on either the GPU or CPU.
Much of the time spent in testing is duplicated work between CPU and GPU
test nodes. The main reason is that there is no way to control which
TVM devices are enabled at runtime, so tests that use LLVM will run on
both GPU and CPU nodes.

This patch adds an environment variable, TVM_TEST_DEVICES (a
semicolon-separated list of device targets), which controls which TVM
devices tests should use. Devices not listed in TVM_TEST_DEVICES remain
available to TVM itself, so tests must explicitly check that the desired
device is enabled with `tvm.testing.device_enabled` or enumerate all
enabled devices with `tvm.testing.enabled_devices`. All tests have been
retrofitted with these checks.

This patch also provides the decorator `@tvm.testing.gpu` to mark a test
as possibly using the GPU. Tests that require the GPU can use
`@tvm.testing.requires_gpu`. Tests without either flag will not be run
on GPU nodes.
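
As an illustration, a retrofitted test might look like the following
sketch (a hypothetical test, not taken from this patch; it uses only the
helpers and decorators added here):

    import numpy as np
    import tvm
    import tvm.testing
    from tvm import te

    # Marked `gpu` so it runs on both CPU-only and GPU CI nodes; each
    # optional backend is guarded with device_enabled().
    @tvm.testing.gpu
    def test_vector_add():
        n = 16
        A = te.placeholder((n,), name="A")
        B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
        for target in ["llvm", "cuda"]:
            if not tvm.testing.device_enabled(target):
                continue
            s = te.create_schedule(B.op)
            if target == "cuda":
                # GPU targets need the loop bound to a thread axis.
                s[B].bind(B.op.axis[0], te.thread_axis("threadIdx.x"))
            f = tvm.build(s, [A, B], target)
            ctx = tvm.context(target, 0)
            a = tvm.nd.array(np.random.rand(n).astype(A.dtype), ctx)
            b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
            f(a, b)
            tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)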
tkonolige committed Aug 24, 2020
1 parent 37cbbd7 commit 324066a
Showing 123 changed files with 1,527 additions and 1,008 deletions.
8 changes: 4 additions & 4 deletions Jenkinsfile
@@ -202,8 +202,8 @@ stage('Unit Test') {
       unpack_lib('gpu', tvm_multilib)
       timeout(time: max_time, unit: 'MINUTES') {
         sh "${docker_run} ${ci_gpu} ./tests/scripts/task_sphinx_precheck.sh"
-        sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest.sh"
-        sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration.sh"
+        sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest.sh gpu"
+        sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration.sh gpu"
       }
     }
   }
@@ -214,8 +214,8 @@ stage('Unit Test') {
       init_git()
       unpack_lib('i386', tvm_multilib)
       timeout(time: max_time, unit: 'MINUTES') {
-        sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_unittest.sh"
-        sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration.sh"
+        sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_unittest.sh cpu"
+        sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration.sh cpu"
         sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_vta_fsim.sh"
       }
     }
6 changes: 6 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -18,6 +18,7 @@
 # pylint: disable=invalid-name
 import ctypes
 import json
+import os
 import numpy as np
 from .base import _LIB, check_call

@@ -182,6 +183,7 @@ class TVMContext(ctypes.Structure):
         'hexagon': 14,
         'webgpu': 15,
     }
+
     def __init__(self, device_type, device_id):
         super(TVMContext, self).__init__()
         self.device_type = device_type
@@ -197,6 +199,10 @@ def _GetDeviceAttr(self, device_type, device_id, attr_id):
     @property
     def exist(self):
         """Whether this device exists."""
+        allowed_ctxs = os.environ.get("TVM_TEST_CTXS")
+        if allowed_ctxs is not None:
+            # Map names such as "llvm" or "cuda" to their integer codes so
+            # the comparison matches self.device_type, which is an int.
+            allowed = {TVMContext.STR2MASK[name] for name in allowed_ctxs.split(",")}
+            if self.device_type not in allowed:
+                return False
         return self._GetDeviceAttr(
             self.device_type, self.device_id, 0) != 0
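
Restricting contexts changes what `exist` reports; a minimal sketch,
assuming a build with LLVM enabled and the integer-to-name mapping fix
above:

    import os
    import tvm

    # With TVM_TEST_CTXS restricted to llvm, other contexts report
    # exist == False even when the hardware is actually present.
    os.environ["TVM_TEST_CTXS"] = "llvm"
    print(tvm.context("cuda", 0).exist)  # False: "cuda" is not listed
    print(tvm.context("llvm", 0).exist)  # True if this build enables LLVM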
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/__init__.py
@@ -25,6 +25,7 @@
 import tvm.relay as relay
 import tvm.relay.op as op
 from tvm.relay import Prelude
+from tvm.testing import enabled_devices

 from . import mlp
 from . import resnet
@@ -41,7 +42,6 @@
 from . import temp_op_attr
 from . import synthetic

-from .config import ctx_list
 from .init import create_workload
 from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
 from .py_converter import to_python, run_as_python
@@ -125,7 +125,7 @@ def check_grad(func,
     if test_inputs is None:
         test_inputs = inputs

-    for target, ctx in ctx_list():
+    for target, ctx in enabled_devices():
         intrp = relay.create_executor(ctx=ctx, target=target)

         # Get analytic gradients.
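
Downstream tests that still import the removed `ctx_list()` helper
migrate mechanically to the new API; a brief sketch (note the device
list now comes from the semicolon-separated TVM_TEST_DEVICES rather
than the old comma-separated RELAY_TEST_TARGETS):

    from tvm.testing import enabled_devices

    for target, ctx in enabled_devices():
        # Each pair is a target string plus its TVMContext, ready for
        # relay.create_executor(ctx=ctx, target=target) or tvm.build.
        print(target, ctx)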
104 changes: 104 additions & 0 deletions python/tvm/testing.py
@@ -18,12 +18,15 @@
 # pylint: disable=invalid-name,unnecessary-comprehension
 """ TVM testing utilities """
 import logging
+import os
+import pytest
 import numpy as np
 import tvm
 import tvm.arith
 import tvm.tir
 import tvm.te
 import tvm._ffi
+from tvm.contrib import nvcc


 def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
@@ -285,4 +288,105 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
         constraints_trans.dst_to_src, constraints_trans.src_to_dst)


+def gpu(f):
+    """Mark to differentiate tests that use the GPU in some capacity. These
+    tests will be run on both CPU-only nodes and nodes with GPUs.
+    To mark a test that must have a GPU present to run, use `@requires_gpu`.
+    """
+    return pytest.mark.gpu(f)
+
+
+def requires_gpu(f):
+    """Mark a test as requiring a GPU to run. Tests with this mark will not
+    be run unless a GPU is present.
+    """
+    return pytest.mark.skipif(not tvm.gpu().exist, reason="No GPU present")(gpu(f))
+
+
+def requires_cuda(f):
+    """Mark a test as requiring the CUDA runtime. Because CUDA needs a
+    physical GPU, this also applies `@requires_gpu`.
+    """
+    return pytest.mark.cuda(
+        pytest.mark.skipif(
+            not tvm.runtime.enabled("cuda"), reason="CUDA support not enabled"
+        )(requires_gpu(f))
+    )
+
+
+def requires_opencl(f):
+    """Mark a test as requiring the OpenCL runtime. This does not mean the
+    test also requires a GPU; for that, use both `@requires_gpu` and
+    `@requires_opencl`.
+    """
+    return pytest.mark.opencl(
+        pytest.mark.skipif(
+            not tvm.runtime.enabled("opencl"), reason="OpenCL support not enabled"
+        )(f)
+    )
+
+
+def requires_tensorcore(f):
+    """Mark a test as requiring an NVIDIA tensor core to run. Tests with
+    this mark will not be run unless a tensor core is present.
+    """
+    return pytest.mark.tensorcore(
+        pytest.mark.skipif(
+            not tvm.gpu().exist or not nvcc.have_tensorcore(tvm.gpu(0).compute_version),
+            reason="No tensor core present",
+        )(f)
+    )
+
+
+def _get_backends():
+    backend_str = os.environ.get("TVM_TEST_DEVICES", "")
+    if len(backend_str) == 0:
+        backend_str = DEFAULT_TEST_DEVICES
+    backends = {
+        dev
+        for dev in backend_str.split(";")
+        if len(dev) > 0 and tvm.context(dev, 0).exist and tvm.runtime.enabled(dev)
+    }
+    if len(backends) == 0:
+        logging.warning(
+            "None of the following backends are supported by this build of TVM: %s. "
+            "Try setting TVM_TEST_DEVICES to a supported backend. Defaulting to llvm.",
+            backend_str,
+        )
+        return {"llvm"}
+    return backends
+
+
+DEFAULT_TEST_DEVICES = (
+    "llvm;cuda;opencl;metal;rocm;vulkan;nvptx;"
+    "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
+)
+TEST_DEVICES = _get_backends()
+
+
+def device_enabled(device):
+    """Check if a device should be used when testing. This allows the user
+    to control which devices they are testing against. In tests, use this
+    to check whether a device that is an optional part of the test should
+    be exercised.
+    """
+    return device in TEST_DEVICES
+
+
+def enabled_devices():
+    """Get all enabled devices with their associated contexts. Here, enabled
+    means that TVM was built with support for the device and the device name
+    appears in the TVM_TEST_DEVICES environment variable. If TVM_TEST_DEVICES
+    is not set, it defaults to the DEFAULT_TEST_DEVICES variable in this
+    module.
+
+    Returns
+    -------
+    targets: list
+        A list of (device, context) pairs for all enabled devices
+    """
+    return [(dev, tvm.context(dev, 0)) for dev in TEST_DEVICES]
+
+
 tvm._ffi._init_api("testing", __name__)
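
One caveat worth noting: TEST_DEVICES is computed once when the module is
imported, so TVM_TEST_DEVICES must be set beforehand. A minimal sketch of
the intended workflow (the device list here is just an example):

    import os

    # Set before importing tvm.testing: TEST_DEVICES is evaluated at import.
    os.environ["TVM_TEST_DEVICES"] = "llvm;cuda"

    import tvm.testing

    print(tvm.testing.device_enabled("cuda"))  # True only if this build has CUDA
    for target, ctx in tvm.testing.enabled_devices():
        print(target, ctx)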
1 change: 1 addition & 0 deletions tests/lint/check_file_type.py
@@ -122,6 +122,7 @@
     "docs/_static/css/tvm_theme.css",
     "docs/_static/img/tvm-logo-small.png",
     "docs/_static/img/tvm-logo-square.png",
+    "tests/python/pytest.ini",
 }

17 changes: 3 additions & 14 deletions python/tvm/relay/testing/config.py → tests/python/conftest.py
@@ -14,18 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Configuration about tests"""
-from __future__ import absolute_import as _abs
-
-import os
-import tvm
-
-
-def ctx_list():
-    """Get context list for testcases"""
-    device_list = os.environ.get("RELAY_TEST_TARGETS", "")
-    device_list = (device_list.split(",") if device_list
-                   else ["llvm", "cuda"])
-    device_list = set(device_list)
-    res = [(device, tvm.context(device, 0)) for device in device_list]
-    return [x for x in res if x[1].exist]
+import tvm.testing
+
+
+def pytest_configure(config):
+    print("Enabled backends:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_devices())))
7 changes: 4 additions & 3 deletions tests/python/contrib/test_cblas.py
@@ -22,6 +22,7 @@
 from tvm.contrib import cblas
 from tvm.contrib import mkl
 from tvm.contrib import mkldnn
+from tvm.testing import device_enabled

 def verify_matmul_add(m, l, n, lib, transa=False, transb=False, dtype="float32"):
     bias = te.var('bias', dtype=dtype)
@@ -41,7 +42,7 @@ def get_numpy(a, b, bb, transa, transb):
         return np.dot(a, b) + bb

     def verify(target="llvm"):
-        if not tvm.runtime.enabled(target):
+        if not device_enabled(target):
             print("skip because %s is not enabled..." % target)
             return
         if not tvm.get_global_func(lib.__name__ + ".matmul", True):
@@ -107,7 +108,7 @@ def get_numpy(a, b, bb, transa, transb):
         return np.dot(a, b) + bb

     def verify(target="llvm"):
-        if not tvm.runtime.enabled(target):
+        if not device_enabled(target):
             print("skip because %s is not enabled..." % target)
             return
         if not tvm.get_global_func("tvm.contrib.mkl.matmul_u8s8s32", True):
@@ -153,7 +154,7 @@ def get_numpy(a, b, transa, transb):
         return tvm.topi.testing.batch_matmul(a, b)

     def verify(target="llvm"):
-        if not tvm.runtime.enabled(target):
+        if not device_enabled(target):
             print("skip because %s is not enabled..." % target)
             return
         if not tvm.get_global_func(lib.__name__ + ".matmul", True):
13 changes: 4 additions & 9 deletions tests/python/contrib/test_cublas.py
@@ -19,6 +19,7 @@
 import numpy as np
 from tvm.contrib import cublas
 from tvm.contrib import cublaslt
+from tvm.testing import requires_cuda

 def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
     n = 1024
@@ -30,9 +31,6 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
     s = te.create_schedule(C.op)

     def verify(target="cuda"):
-        if not tvm.runtime.enabled(target):
-            print("skip because %s is not enabled..." % target)
-            return
         if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
             print("skip because extern function is not available")
             return
@@ -64,9 +62,6 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
     s = te.create_schedule(C.op)

     def verify(target="cuda"):
-        if not tvm.runtime.enabled(target):
-            print("skip because %s is not enabled..." % target)
-            return
         if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
             print("skip because extern function is not available")
             return
@@ -115,9 +110,6 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
     s = te.create_schedule(C.op)

     def verify(target="cuda"):
-        if not tvm.runtime.enabled(target):
-            print("skip because %s is not enabled..." % target)
-            return
         if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
             print("skip because extern function is not available")
             return
@@ -132,15 +124,18 @@ def verify(target="cuda"):
             b.asnumpy().astype(C.dtype)).astype(C.dtype), rtol=rtol)
     verify()

+@requires_cuda
 def test_matmul_add():
     verify_matmul_add('float', 'float', rtol=1e-3)
     verify_matmul_add('float16', 'float')
     verify_matmul_add('float16', 'float16', rtol=1e-2)
     verify_matmul_add('int8', 'int32')

+@requires_cuda
 def test_matmul_add_igemm():
     verify_matmul_add_igemm('int8', 'int32')

+@requires_cuda
 def test_batch_matmul():
     verify_batch_matmul('float', 'float')
     verify_batch_matmul('float16', 'float')
13 changes: 4 additions & 9 deletions tests/python/contrib/test_cudnn.py
@@ -20,6 +20,7 @@
 from tvm.contrib.nvcc import have_fp16
 import numpy as np
 import tvm.topi.testing
+from tvm.testing import requires_gpu

 def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     in_channel = 4
@@ -36,9 +37,6 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     height = 32
     width = 32

-    if not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled...")
-        return
     if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
         print("skip because cudnn is not enabled...")
         return
@@ -87,6 +85,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     f(x, w, y)
     tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-2, rtol=1e-2)

+@requires_gpu
 def test_conv2d():
     verify_conv2d("float32", "float32", tensor_format=0)
     verify_conv2d("float16", "float32", tensor_format=1)
@@ -118,9 +117,6 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     height = 32
     width = 32

-    if not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled...")
-        return
     if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
         print("skip because cudnn is not enabled...")
         return
@@ -161,6 +157,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     f(x, w, y)
     tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4)

+@requires_gpu
 def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0)
     verify_conv3d("float32", "float32", tensor_format=0, groups=2)
@@ -195,10 +192,8 @@ def verify_softmax_4d(shape, dtype="float32"):
     f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)

+@requires_gpu
 def test_softmax():
-    if not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled...")
-        return
     if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
         print("skip because cudnn is not enabled...")
         return