From 1e7e39880c53841530787cdf391b21b27e6a166c Mon Sep 17 00:00:00 2001
From: Peter Yeh
Date: Tue, 10 Dec 2019 15:41:33 -0800
Subject: [PATCH] Add AMD codeGen unit tests

---
Review notes (this section is ignored by git am):
 * Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of this patch. If a context line does not match the target file
   exactly, apply with `git apply --ignore-whitespace`.
 * The skip condition is defined once as `requires_rocm` instead of being
   repeated on every test.
 * The __main__ runner catches unittest.SkipTest, because calling a
   skipIf-decorated function directly raises it when ROCm is unavailable.
 * The post-image blob id on the index line predates these amendments.

 tests/python/unittest/test_codegen_rocm.py | 114 +++++++++++++++++++--
 1 file changed, 106 insertions(+), 8 deletions(-)

diff --git a/tests/python/unittest/test_codegen_rocm.py b/tests/python/unittest/test_codegen_rocm.py
index 2077372cd5b6..bba72e053142 100644
--- a/tests/python/unittest/test_codegen_rocm.py
+++ b/tests/python/unittest/test_codegen_rocm.py
@@ -16,13 +16,22 @@
 # under the License.
 import tvm
 import numpy as np
+import unittest
 
+# Thread and block axes shared by the schedules below.
+tx = tvm.thread_axis("threadIdx.x")
+ty = tvm.thread_axis("threadIdx.y")
+bx = tvm.thread_axis("blockIdx.x")
+by = tvm.thread_axis("blockIdx.y")
+
+# Evaluated once at import time and applied to every test in this file.
+requires_rocm = unittest.skipIf(
+    not tvm.rocm(0).exist or not tvm.module.enabled("rocm"),
+    "skip because rocm is not enabled..")
+
 
+@requires_rocm
 def test_rocm_cross_thread_reduction():
-    if not tvm.rocm(0).exist or not tvm.module.enabled("rocm"):
-        print("skip because rocm is not enabled..")
-        return
-
     # based on the reduction tutorial
     n = tvm.var("n")
     m = tvm.var("m")
@@ -33,9 +42,8 @@ def test_rocm_cross_thread_reduction():
     ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
     BF = s.rfactor(B, ki)
     xo, xi = s[B].split(s[B].op.axis[0], factor=32)
-    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
-    s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
-    tx = tvm.thread_axis("threadIdx.x")
+    s[B].bind(xo, bx)
+    s[B].bind(xi, ty)
     s[B].bind(s[B].op.reduce_axis[0], tx)
     s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
     s[B].set_store_predicate(tx.var.equal(0))
@@ -49,6 +57,96 @@ def test_rocm_cross_thread_reduction():
     tvm.testing.assert_allclose(
         b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
 
+
+@requires_rocm
+def test_rocm_inf_nan():
+    """Build and launch kernels that store -inf, inf and nan constants."""
+    def check_inf_nan(ctx, n, value, dtype):
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        inf_value = tvm.const(value, dtype=dtype)
+        C = tvm.compute((n,), lambda i: inf_value, name='C')
+        s = tvm.create_schedule(C.op)
+        s[C].bind(s[C].op.axis[0], tx)
+        fun = tvm.build(s, [A, C], "rocm")
+        a = tvm.nd.empty((n,), A.dtype, ctx)
+        c = tvm.nd.empty((n,), A.dtype, ctx)
+        # Only need to test compiling here
+        fun(a, c)
+
+    ctx = tvm.rocm(0)
+
+    for value in (-float('inf'), float('inf'), float('nan')):
+        for dtype in ('float32', 'float64'):
+            check_inf_nan(ctx, 1, value, dtype)
+
+
+@requires_rocm
+def test_rocm_reducition_binding():
+    """Bind the split output axis of a reordered reduction to blockIdx.x."""
+    # Scheduling only: nothing is built or launched in this test.
+    k = tvm.reduce_axis((0, 32), 'k')
+    A = tvm.placeholder((96, 32), name='A')
+    B = tvm.compute((96,), lambda m: tvm.sum(A[m, k], axis=k), name='B')
+    s = tvm.create_schedule(B.op)
+
+    s[B].reorder(B.op.reduce_axis[0], B.op.axis[0])
+
+    mo, _ = s[B].split(B.op.axis[0], 32)
+    s[B].bind(mo, bx)
+
+
+@requires_rocm
+def test_rocm_copy():
+    """Copy random host arrays to the ROCm device and read them back."""
+    def check_rocm(dtype, n):
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        ctx = tvm.rocm(0)
+        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
+        b_np = a.asnumpy()
+        tvm.testing.assert_allclose(a_np, b_np)
+        tvm.testing.assert_allclose(a_np, a.asnumpy())
+
+    for _ in range(100):
+        dtype = np.random.choice(["float32", "float16", "int8", "int32"])
+        logN = np.random.randint(1, 15)
+        perturb = np.random.uniform(low=0.5, high=1.5)
+        check_rocm(dtype, int(perturb * (2 ** logN)))
+
+
+@requires_rocm
+def test_rocm_vectorize_add():
+    """Add one to arrays of 2-lane vector elements and compare with numpy."""
+    num_thread = 8
+
+    def check_rocm(dtype, n, lanes):
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
+        s = tvm.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(xo, bx)
+        s[B].bind(xi, tx)
+        fun = tvm.build(s, [A, B], "rocm")
+        ctx = tvm.rocm(0)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
+            np.random.uniform(size=(n, lanes)))
+        c = tvm.nd.empty((n,), B.dtype, ctx)
+        fun(a, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
+
+    check_rocm("float32", 64, 2)
+    check_rocm("float16", 64, 2)
+
 
 if __name__ == "__main__":
-    test_rocm_cross_thread_reduction()
+    for test in (test_rocm_cross_thread_reduction,
+                 test_rocm_inf_nan,
+                 test_rocm_reducition_binding,
+                 test_rocm_copy,
+                 test_rocm_vectorize_add):
+        try:
+            test()
+        except unittest.SkipTest as skipped:
+            # Calling a skipIf-decorated function directly raises SkipTest;
+            # report the reason instead of aborting the script.
+            print(skipped)