Skip to content

Commit

Permalink
[Vulkan][Unittests] Add parametrization to vulkan unit tests. (apache…
Browse files Browse the repository at this point in the history
…#8348)

This also switches to using `vulkan -from_device=0` by default, and
marks tests as `pytest.xfail` if the device does not support the
functionality being tested.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 2c1daf6 commit 47d56a1
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 186 deletions.
71 changes: 33 additions & 38 deletions tests/python/unittest/test_target_codegen_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,59 +21,54 @@
import numpy as np
import tvm.testing

arr_size = tvm.testing.parameter(32)

@tvm.testing.uses_gpu
def test_cmp_load_store():
n = 32
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")

@tvm.testing.fixture
def compute(arr_size):
A = te.placeholder((arr_size,), name="A")
B = te.placeholder((arr_size,), name="B")
C = te.compute(A.shape, lambda *i: A(*i) > B(*i), name="C")
D = te.compute(C.shape, lambda *i: tvm.tir.all(C(*i), A(*i) > 1).astype("float32"), name="D")
return [A, B, C, D]

def check_llvm():
if not tvm.testing.device_enabled("llvm"):
return

@tvm.testing.fixture
def schedule(target, compute):
target = tvm.target.Target(target)
A, B, C, D = compute
if target.kind.name == "llvm":
s = te.create_schedule(D.op)
xo, xi = s[C].split(C.op.axis[0], factor=4)
xo1, xo2 = s[C].split(xo, factor=13)
s[C].parallel(xo2)
# BUILD and invoke the kernel.
f = tvm.build(s, [A, B, D], "llvm")
dev = tvm.cpu(0)
a_np = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), dev)
f(a, b, d)
np.testing.assert_equal(
d.numpy(),
np.logical_and(a.numpy() > b.numpy(), a.numpy() > 1).astype("float32"),
)

def check_device(device):
if not tvm.testing.device_enabled(device):
return
dev = tvm.device(device, 0)
else:
s = te.create_schedule(D.op)
for stage in [C, D]:
xo, xi = s[stage].split(stage.op.axis[0], factor=4)
s[stage].bind(xo, te.thread_axis("blockIdx.x"))
s[stage].bind(xi, te.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B, D], device)
a_np = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), dev)
f(a, b, d)
np.testing.assert_equal(
d.numpy(),
np.logical_and(a.numpy() > b.numpy(), a.numpy() > 1).astype("float32"),
)

check_llvm()
for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
check_device(device)
return s


@tvm.testing.uses_gpu
def test_cmp_load_store(target, dev, arr_size, compute, schedule):
A, B, _, D = compute
f = tvm.build(schedule, [A, B, D], target)

a_np = np.random.uniform(size=arr_size).astype(A.dtype)
b_np = np.random.uniform(size=arr_size).astype(B.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
d = tvm.nd.array(np.zeros(arr_size, dtype=D.dtype), dev)
f(a, b, d)
np.testing.assert_equal(
d.numpy(),
np.logical_and(a_np > b_np, a_np > 1).astype("float32"),
)


if __name__ == "__main__":
test_cmp_load_store()
sys.exit(pytest.main(sys.argv))
Loading

0 comments on commit 47d56a1

Please sign in to comment.