diff --git a/tests/python/unittest/test_target_codegen_bool.py b/tests/python/unittest/test_target_codegen_bool.py
index bd6cf27ccaa4..0b6616537430 100644
--- a/tests/python/unittest/test_target_codegen_bool.py
+++ b/tests/python/unittest/test_target_codegen_bool.py
@@ -21,59 +21,54 @@ import numpy as np
 import tvm.testing
 
+arr_size = tvm.testing.parameter(32)
 
-@tvm.testing.uses_gpu
-def test_cmp_load_store():
-    n = 32
-    A = te.placeholder((n,), name="A")
-    B = te.placeholder((n,), name="B")
+
+@tvm.testing.fixture
+def compute(arr_size):
+    A = te.placeholder((arr_size,), name="A")
+    B = te.placeholder((arr_size,), name="B")
     C = te.compute(A.shape, lambda *i: A(*i) > B(*i), name="C")
     D = te.compute(C.shape, lambda *i: tvm.tir.all(C(*i), A(*i) > 1).astype("float32"), name="D")
+    return [A, B, C, D]
 
-    def check_llvm():
-        if not tvm.testing.device_enabled("llvm"):
-            return
+
+@tvm.testing.fixture
+def schedule(target, compute):
+    target = tvm.target.Target(target)
+    A, B, C, D = compute
+    if target.kind.name == "llvm":
         s = te.create_schedule(D.op)
         xo, xi = s[C].split(C.op.axis[0], factor=4)
         xo1, xo2 = s[C].split(xo, factor=13)
         s[C].parallel(xo2)
-        # BUILD and invoke the kernel.
-        f = tvm.build(s, [A, B, D], "llvm")
-        dev = tvm.cpu(0)
-        a_np = np.random.uniform(size=n).astype(A.dtype)
-        a = tvm.nd.array(a_np, dev)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
-        d = tvm.nd.array(np.zeros(n, dtype=D.dtype), dev)
-        f(a, b, d)
-        np.testing.assert_equal(
-            d.numpy(),
-            np.logical_and(a.numpy() > b.numpy(), a.numpy() > 1).astype("float32"),
-        )
-
-    def check_device(device):
-        if not tvm.testing.device_enabled(device):
-            return
-        dev = tvm.device(device, 0)
+    else:
         s = te.create_schedule(D.op)
         for stage in [C, D]:
             xo, xi = s[stage].split(stage.op.axis[0], factor=4)
             s[stage].bind(xo, te.thread_axis("blockIdx.x"))
             s[stage].bind(xi, te.thread_axis("threadIdx.x"))
-        f = tvm.build(s, [A, B, D], device)
-        a_np = np.random.uniform(size=n).astype(A.dtype)
-        a = tvm.nd.array(a_np, dev)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
-        d = tvm.nd.array(np.zeros(n, dtype=D.dtype), dev)
-        f(a, b, d)
-        np.testing.assert_equal(
-            d.numpy(),
-            np.logical_and(a.numpy() > b.numpy(), a.numpy() > 1).astype("float32"),
-        )
-
-    check_llvm()
-    for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
-        check_device(device)
+    return s
+
+
+@tvm.testing.uses_gpu
+def test_cmp_load_store(target, dev, arr_size, compute, schedule):
+    A, B, _, D = compute
+    f = tvm.build(schedule, [A, B, D], target)
+
+    a_np = np.random.uniform(size=arr_size).astype(A.dtype)
+    b_np = np.random.uniform(size=arr_size).astype(B.dtype)
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    d = tvm.nd.array(np.zeros(arr_size, dtype=D.dtype), dev)
+    f(a, b, d)
+    np.testing.assert_equal(
+        d.numpy(),
+        np.logical_and(a_np > b_np, a_np > 1).astype("float32"),
+    )
 
 
 if __name__ == "__main__":
-    test_cmp_load_store()
+    sys.exit(pytest.main(sys.argv))
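For readers unfamiliar with the tvm.testing machinery adopted above, here is a minimal stand-alone sketch (not part of this patch) of how `tvm.testing.parameter`, `tvm.testing.fixture`, and the automatically supplied `target`/`dev` fixtures fit together. The file contents, the trivial add-one compute, and the restriction to llvm are illustrative assumptions only.

# Hypothetical example of the parameter/fixture pattern used in the diff above;
# assumes a TVM checkout whose conftest.py enables the tvm.testing pytest plugin
# (as tests/python/unittest does).
import sys

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import te

arr_size = tvm.testing.parameter(16, 64)  # each value becomes a pytest parametrization


@tvm.testing.fixture
def compute(arr_size):
    # Fixtures may request parameters (and other fixtures) by argument name.
    A = te.placeholder((arr_size,), name="A")
    B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
    return A, B


@tvm.testing.parametrize_targets("llvm")
def test_add_one(target, dev, arr_size, compute):
    A, B = compute
    s = te.create_schedule(B.op)
    f = tvm.build(s, [A, B], target)

    a = tvm.nd.array(np.random.uniform(size=arr_size).astype(A.dtype), dev)
    b = tvm.nd.array(np.zeros(arr_size, dtype=B.dtype), dev)
    f(a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1)


if __name__ == "__main__":
    sys.exit(pytest.main(sys.argv))

Each test then runs once per (target, arr_size) combination, which is what lets the patch drop the hand-rolled check_llvm/check_device helpers.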
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index c8cddf8b9598..0551fcd54855 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -15,10 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import random
 import re
 import sys
+import threading
 
 import numpy as np
+import pytest
 
 import tvm
 import tvm.testing
@@ -26,102 +29,92 @@ from tvm.topi.math import cast
 
 
-def check_mod(mod, x_np, res_np):
-    target = "vulkan"
-    dev = tvm.device(target, 0)
-    ex = relay.create_executor("vm", mod=mod, device=dev, target=target)
-    res = ex.evaluate()(x_np).numpy()
-    tvm.testing.assert_allclose(res, res_np, atol=1e-5)
+def randint_loguniform(low=1, high=32768, size=None):
+    logN = np.random.uniform(low=np.log(low), high=np.log(high), size=size)
+    N = np.exp(logN).astype(int)
+    return np.unique(N)
 
 
-@tvm.testing.requires_vulkan
-def test_vector_comparison():
-    target = "vulkan"
+dtype = tvm.testing.parameter("float32", "int32", "float16", "int8")
+fuzz_arr_size = tvm.testing.parameter(*randint_loguniform(size=25))
+
+
+# Explicitly specify a target, as this test is looking at the
+# generated shader code, and is not running on an actual device.
+@tvm.testing.parametrize_targets(
+    " ".join(
+        [
+            "vulkan",
+            "-supports_int8=1",
+            "-supports_8bit_buffer=1",
+            "-supports_storage_buffer_storage_class=1",
+            "-supports_float16=1",
+            "-supports_16bit_buffer=1",
+        ]
+    )
+)
+def test_vector_comparison(target, dtype):
+    n = (1024,)
+    A = te.placeholder(n, dtype=dtype, name="A")
+    B = te.compute(
+        A.shape,
+        lambda i: tvm.tir.Select(
+            A[i] >= 0, A[i] + tvm.tir.const(1, dtype), tvm.tir.const(0, dtype)
+        ),
+        name="B",
+    )
+    s = te.create_schedule(B.op)
+
+    (bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
+    (tx, vx) = s[B].split(tx, factor=4)
+    s[B].bind(bx, te.thread_axis("blockIdx.x"))
+    s[B].bind(tx, te.thread_axis("threadIdx.x"))
+    s[B].vectorize(vx)
+    f = tvm.build(s, [A, B], target)
+
+    # Verify we generate the boolx4 type declaration and the OpSelect
+    # v4{float,half,int} instruction
+    assembly = f.imported_modules[0].get_source()
+    matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
+    assert len(matches) == 1
+    matches = re.findall("OpSelect %v4.*", assembly)
+    assert len(matches) == 1
+
+
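The space-joined string in the decorator above produces a single composite target such as "vulkan -supports_int8=1 ...", so the shader-inspection test can assume those capabilities without querying a physical device. A small illustrative sketch (not part of the patch) of how such a string parses and how later tests branch on the resulting attributes; only the flags actually named in the decorator are used here.

# Illustrative only: parsing a Vulkan target string with capability flags does
# not require a Vulkan device or driver.
import tvm

target = tvm.target.Target("vulkan -supports_float16=1 -supports_int8=1")
assert target.kind.name == "vulkan"

# Capability flags show up in target.attrs; flags not set fall back to the default.
supports_fp16 = target.attrs.get("supports_float16", False)
supports_ssbo = target.attrs.get("supports_storage_buffer_storage_class", False)
if not supports_fp16:
    raise RuntimeError("the flag set in the target string should be visible here")
if not supports_ssbo:
    print("storage-buffer storage class not advertised by this target string")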
+def test_array_copy(dev, dtype, fuzz_arr_size):
+    a_np = np.random.uniform(size=(fuzz_arr_size,)).astype(dtype)
+    a = tvm.nd.empty((fuzz_arr_size,), dtype, dev).copyfrom(a_np)
+    b_np = a.numpy()
+    tvm.testing.assert_allclose(a_np, b_np)
+    tvm.testing.assert_allclose(a_np, a.numpy())
+
+
+@tvm.testing.exclude_targets("llvm")
+def test_array_vectorize_add(target, dev, dtype):
+    arr_size = 64
+    lanes = 2
 
-    def check_correct_assembly(dtype):
-        n = (1024,)
-        A = te.placeholder(n, dtype=dtype, name="A")
-        B = te.compute(
-            A.shape,
-            lambda i: tvm.tir.Select(
-                A[i] >= 0, A[i] + tvm.tir.const(1, dtype), tvm.tir.const(0, dtype)
-            ),
-            name="B",
-        )
-        s = te.create_schedule(B.op)
-
-        (bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
-        (tx, vx) = s[B].split(tx, factor=4)
-        s[B].bind(bx, te.thread_axis("blockIdx.x"))
-        s[B].bind(tx, te.thread_axis("threadIdx.x"))
-        s[B].vectorize(vx)
-        f = tvm.build(s, [A, B], target)
-
-        # Verify we generate the boolx4 type declaration and the OpSelect
-        # v4{float,half,int} instruction
-        assembly = f.imported_modules[0].get_source()
-        matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
-        assert len(matches) == 1
-        matches = re.findall("OpSelect %v4.*", assembly)
-        assert len(matches) == 1
-
-    check_correct_assembly("float32")
-    check_correct_assembly("int32")
-    check_correct_assembly("float16")
-
-
-tx = te.thread_axis("threadIdx.x")
-bx = te.thread_axis("blockIdx.x")
-
-
-@tvm.testing.requires_vulkan
-def test_vulkan_copy():
-    def check_vulkan(dtype, n):
-        A = te.placeholder((n,), name="A", dtype=dtype)
-        dev = tvm.vulkan(0)
-        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
-        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(a_np)
-        b_np = a.numpy()
-        tvm.testing.assert_allclose(a_np, b_np)
-        tvm.testing.assert_allclose(a_np, a.numpy())
-
-    for _ in range(100):
-        dtype = np.random.choice(["float32", "float16", "int8", "int32"])
-        logN = np.random.randint(1, 15)
-        peturb = np.random.uniform(low=0.5, high=1.5)
-        check_vulkan(dtype, int(peturb * (2 ** logN)))
-
-
-@tvm.testing.requires_vulkan
-def test_vulkan_vectorize_add():
     num_thread = 8
 
-    def check_vulkan(dtype, n, lanes):
-        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
-        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
-        s = te.create_schedule(B.op)
-        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
-        s[B].bind(xo, bx)
-        s[B].bind(xi, tx)
-        fun = tvm.build(s, [A, B], "vulkan")
-        dev = tvm.vulkan(0)
-        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
-        c = tvm.nd.empty((n,), B.dtype, dev)
-        fun(a, c)
-        tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)
-
-    check_vulkan("float32", 64, 2)
-    check_vulkan("float16", 64, 2)
-
-
-@tvm.testing.requires_vulkan
-def test_vulkan_stress():
+    A = te.placeholder((arr_size,), name="A", dtype="%sx%d" % (dtype, lanes))
+    B = te.compute((arr_size,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
+    s = te.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+    s[B].bind(xo, te.thread_axis("blockIdx.x"))
+    s[B].bind(xi, te.thread_axis("threadIdx.x"))
+    fun = tvm.build(s, [A, B], target)
+    a = tvm.nd.empty((arr_size,), A.dtype, dev).copyfrom(np.random.uniform(size=(arr_size, lanes)))
+    c = tvm.nd.empty((arr_size,), B.dtype, dev)
+    fun(a, c)
+    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)
+
+
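The `"%sx%d" % (dtype, lanes)` string above builds a vectorized scalar type such as "float32x2". The following CPU-only sketch (not part of the patch, and using llvm rather than a GPU target so it needs no device driver) shows the same lanes-suffixed buffer round trip; the llvm availability is an assumption of the sketch, not of the patch.

# Minimal sketch of a lanes-suffixed dtype, assuming the llvm target is enabled
# in the local TVM build.
import numpy as np

import tvm
import tvm.testing
from tvm import te

n, lanes = 64, 2
A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes)
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

dev = tvm.cpu(0)
# An (n, lanes) numpy array fills an n-element buffer of 2-lane vectors.
a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
b = tvm.nd.empty((n,), B.dtype, dev)
f(a, b)
tvm.testing.assert_allclose(b.numpy(), a.numpy() + 1)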
""" - import random - import threading n = 1024 num_thread = 64 @@ -144,15 +137,14 @@ def build_f(f_ref): C = C_f() s = te.create_schedule(C.op) xo, xi = s[C].split(C.op.axis[0], factor=num_thread) - s[C].bind(xo, bx) - s[C].bind(xi, tx) - fun = tvm.build(s, [A, B, C], "vulkan") + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(xi, te.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, B, C], target) return (fun, ref) fs = [ build_f(random.choice(functions)) for _ in range(np.random.randint(low=1, high=10)) ] - dev = tvm.vulkan(0) a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n,))) b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np.random.uniform(size=(n,))) cs = [tvm.nd.empty((n,), A.dtype, dev) for _ in fs] @@ -171,8 +163,20 @@ def build_f(f_ref): run_stress() -@tvm.testing.requires_vulkan -def test_vulkan_bool_load(): +@tvm.testing.exclude_targets("llvm") +def test_vulkan_bool_load(target, dev): + arr_size = 1024 + + target = tvm.target.Target(target) + if target.kind.name == "vulkan": + supports_int8_buffer = target.attrs.get("supports_int8", False) and target.attrs.get( + "supports_8bit_buffer", False + ) + if not supports_int8_buffer: + pytest.xfail( + "Vulkan target does not support int8 buffer access, used to transfer booleans" + ) + def do_copy(A, B, n): ib = tvm.tir.ir_builder.create() A = ib.buffer_ptr(A) @@ -191,16 +195,13 @@ def do_copy(A, B, n): return ib.get() - n = 1024 - A = te.placeholder((n,), name="A", dtype="bool") - B = te.placeholder((n,), name="B", dtype="int32") - - target = "vulkan" + A = te.placeholder((arr_size,), name="A", dtype="bool") + B = te.placeholder((arr_size,), name="B", dtype="int32") B = te.extern( A.shape, [A], - lambda ins, outs: do_copy(ins[0], outs[0], n), + lambda ins, outs: do_copy(ins[0], outs[0], arr_size), name="bool_copy_ir", dtype="int32", ) @@ -209,9 +210,8 @@ def do_copy(A, B, n): with tvm.transform.PassContext(opt_level=3): func = tvm.build(s, [A, B], target) - dev = tvm.device(target, 0) - a_np = np.random.uniform(size=n) > 0.5 - b_np = np.zeros((n,), dtype="int32") + a_np = np.random.uniform(size=arr_size) > 0.5 + b_np = np.zeros((arr_size,), dtype="int32") a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) func(a, b) @@ -219,8 +219,13 @@ def do_copy(A, B, n): tvm.testing.assert_allclose(b.numpy(), ref) -@tvm.testing.requires_vulkan -def test_vulkan_pushconstants(): +def check_mod(target, dev, mod, x_np, res_np): + ex = relay.create_executor("vm", mod=mod, device=dev, target=target) + res = ex.evaluate()(x_np).numpy() + tvm.testing.assert_allclose(res, res_np, atol=1e-5) + + +def test_sqrt(target, dev): # Three 32 bit pushconstants: any_dim, stride, stride dtype = "float32" x = relay.var("x", shape=(relay.Any(),), dtype=dtype) @@ -229,8 +234,10 @@ def test_vulkan_pushconstants(): x_np = np.random.uniform(size=(10,)).astype(dtype) res_np = np.sqrt(x_np) - check_mod(mod, x_np, res_np) + check_mod(target, dev, mod, x_np, res_np) + +def test_argsort(target, dev): # One 64 bit and one 32 bit constants dtype = "int32" x = relay.var("x", shape=(relay.Any(),), dtype=dtype) @@ -239,8 +246,10 @@ def test_vulkan_pushconstants(): x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype) res_np = np.argsort(x_np) - check_mod(mod, x_np, res_np) + check_mod(target, dev, mod, x_np, res_np) + +def test_cumsum(target, dev): # One 64 bit and one 32 bit constants dtype = "int32" x = relay.var("x", shape=(relay.Any(),), dtype=dtype) @@ -249,11 +258,10 @@ def test_vulkan_pushconstants(): x_np = np.random.randint(0, 
@@ -249,11 +258,10 @@
     x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
     res_np = np.cumsum(x_np)
 
-    check_mod(mod, x_np, res_np)
+    check_mod(target, dev, mod, x_np, res_np)
 
 
-@tvm.testing.requires_vulkan
-def test_vulkan_unique():
+def test_unique(target, dev):
     dtype = "int32"
     x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
     mod = tvm.IRModule()
@@ -261,62 +269,70 @@
     mod["main"] = relay.Function([x], relay.op.strided_slice(unique, begin=[0], end=num_unique))
     x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
     res_np = np.unique(x_np)
-    check_mod(mod, x_np, res_np)
-
+    check_mod(target, dev, mod, x_np, res_np)
 
-@tvm.testing.requires_vulkan
-def test_vulkan_constant_passing():
-    target = "vulkan"
 
-    def test_scalar_params(num_int_params):
-        n = te.var("n")
-        scalars = [te.var("scale{}".format(i)) for i in range(num_int_params)]
-        scalar_sum = scalars[0]
-        for s in scalars[1:]:
-            scalar_sum += s
+vulkan_parameter_impl = tvm.testing.parameter("push_constants", "ubo")
+vulkan_parameter_dtype = tvm.testing.parameter("int32", "float32", "int64")
 
-        A = te.placeholder((n,), name="A")
-        B = te.compute(A.shape, lambda i: scalar_sum + A[i], name="B")
 
+# Only run on vulkan because extremely large numbers of input
+# parameters can crash cuda/llvm compiler.
+@tvm.testing.parametrize_targets("vulkan -from_device=0")
+def test_vulkan_constant_passing(target, dev, vulkan_parameter_impl, vulkan_parameter_dtype):
+    target = tvm.target.Target(target)
+    dtype = vulkan_parameter_dtype
 
-        s = te.create_schedule(B.op)
-        xo, xi = s[B].split(B.op.axis[0], factor=64)
-        s[B].bind(xo, bx)
-        s[B].bind(xi, tx)
-        f_add = tvm.build(s, scalars + [A, B], target)
-
-        n = 1024
-        scalars = [1 for _ in scalars]
-        dev = tvm.vulkan(0)
-        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
-        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
-        f_add(*scalars, a, b)
-
-        tvm.testing.assert_allclose(a.numpy() + sum(scalars), b.numpy())
+    if not target.attrs.get("supports_int64", False):
+        pytest.xfail("Vulkan target does not support Int64 variables")
 
     # f_add has 3+num_int_params scalar parameters. The other three
     # are length_n, stride1, and stride2.
+    if vulkan_parameter_impl == "push_constants":
+        # 4 params, 32 bytes. Within 128-byte spec-guaranteed size of
+        # push constants. Uses push constants.
+        num_int_params = 1
+    else:
+        # 24 params, 192 bytes. May be above spec-guaranteed size of 128
+        # bytes for push constants. Uses either push constants or UBO,
+        # depending on the device.
+        max_push_constants_size = int(target.attrs.get("max_push_constants_size", 128))
+        max_int_params_in_push = max_push_constants_size // 8 - 3
+        num_int_params = max_int_params_in_push + 1
 
-    # 4 params, 32 bytes. Within 128-byte spec-guaranteed size of
-    # push constants. Uses push constants.
-    test_scalar_params(1)
+    n = te.var("n")
+    scalars = [te.var("scale{}".format(i), dtype=dtype) for i in range(num_int_params)]
+    scalar_sum = scalars[0]
+    for s in scalars[1:]:
+        scalar_sum += s
 
-    # 24 params, 192 bytes. Too big for push constants, uses uniform
-    # buffer.
-    test_scalar_params(20)
+    A = te.placeholder((n,), name="A", dtype=dtype)
+    B = te.compute(A.shape, lambda i: scalar_sum + A[i], name="B")
 
-    # 2047 params, 16376 bytes, just below 16kB of uniform buffer
-    # space guaranteed by the vulkan spec.
-    test_scalar_params(2044)
+    s = te.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=64)
+    s[B].bind(xo, te.thread_axis("blockIdx.x"))
+    s[B].bind(xi, te.thread_axis("threadIdx.x"))
+    f_add = tvm.build(s, scalars + [A, B], target)
+
+    n = 1024
+    scalars = np.array([1 for _ in scalars]).astype(dtype)
+    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
+    f_add(*scalars, a, b)
+
+    tvm.testing.assert_allclose(a.numpy() + sum(scalars), b.numpy())
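The push-constant vs. UBO split above comes down to byte arithmetic: worst-case 8-byte (int64) scalars plus the three implicit length/stride arguments must fit in the device's max_push_constants_size, and the Vulkan spec only guarantees 128 bytes. A back-of-the-envelope sketch (not part of the patch) of that sizing rule:

# Sketch of the sizing rule used to force either the push-constant or the UBO
# path; 8 bytes per scalar is the worst case (int64) the test exercises.
def max_explicit_scalars_in_push(max_push_constants_size=128, bytes_per_scalar=8):
    implicit_args = 3  # length_n, stride1, stride2, as noted in the comment above
    return max_push_constants_size // bytes_per_scalar - implicit_args


assert max_explicit_scalars_in_push(128) == 13  # 13 int64 scalars still fit exactly
# num_int_params = max_int_params_in_push + 1 then guarantees the "ubo" case
# spills past the guaranteed push-constant budget on a 128-byte device.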
 
 
-@tvm.testing.parametrize_targets("vulkan")
 def test_vulkan_while_if(target, dev):
+    target = tvm.target.Target(target)
+
     def do_compute(A, B, n):
         ib = tvm.tir.ir_builder.create()
         A = ib.buffer_ptr(A)
         B = ib.buffer_ptr(B)
 
-        ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
+        if "gpu" in target.keys:
+            ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
 
         iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
         iterations[0] = 0
@@ -359,7 +375,7 @@ def do_compute(A, B, n):
     tvm.testing.assert_allclose(b.numpy(), [210])
 
 
-@tvm.testing.parametrize_targets("vulkan")
+@tvm.testing.exclude_targets("llvm")
 def test_vulkan_local_threadidx(target, dev):
     # To access the thread index, the vulkan runtime accesses a global
     # array of thread indices, storing the result in a local variable.
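The `"gpu" in target.keys` guard above is what lets the same IR-builder test run unchanged on both llvm and GPU targets: only GPU target kinds need the blockIdx.x thread-extent scope attribute. A quick illustrative check (not part of the patch) of what `Target.keys` looks like for the target kinds involved; it only constructs Target objects, so no device is needed, and the printed values are indicative rather than guaranteed.

# Illustrative only: Target.keys distinguishes CPU and GPU target kinds
# without touching a physical device.
import tvm

for name in ["llvm", "vulkan", "cuda"]:
    tgt = tvm.target.Target(name)
    print(name, list(tgt.keys), "gpu" in tgt.keys)
# Expected along the lines of:
#   llvm   ['cpu']            False
#   vulkan ['vulkan', 'gpu']  True
#   cuda   ['cuda', 'gpu']    True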