Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 7 additions & 52 deletions apps/android_rpc/tests/android_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,37 +56,15 @@ def test_rpc_module():
tracker = rpc.connect_tracker(tracker_host, tracker_port)
remote = tracker.request(key, priority=0, session_timeout=60)

# Compile the Graph for CPU target
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
s[B].pragma(xo, "parallel_launch_point")
s[B].pragma(xi, "parallel_barrier_when_finish")
f = tvm.build(s, [A, B], target, name="myadd_cpu")
path_dso_cpu = temp.relpath("cpu_lib.so")
f.export_library(path_dso_cpu, fcompile=ndk.create_shared)
mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]).with_attr("global_symbol", "myadd"))
sch = tvm.tir.Schedule(mod)
(x,) = sch.get_loops(block=sch.get_block("B"))
xo, xi = sch.split(i, [None, 32])
sch.bind(xo, "blockIdx.x")
sch.bind(xi, "threadIdx.x")

# Execute the portable graph on cpu target
print("Run CPU test ...")
dev = remote.cpu(0)
remote.upload(path_dso_cpu)
f2 = remote.load_module("cpu_lib.so")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
time_f = f2.time_evaluator(f2.entry_name, dev, number=10)
cost = time_f(a, b).mean
print("%g secs/op\n" % cost)
np.testing.assert_equal(b.numpy(), a.numpy() + 1)

# Compile the Graph for OpenCL target
if test_opencl:
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))
# Build the dynamic lib.
# To skip OpenCL and run on the CPU only, pass `target` itself as the build target.
f = tvm.build(s, [A, B], tvm.target.Target("opencl", host=target), name="myadd")
f = tvm.build(sch.mod, target=tvm.target.Target("opencl", host=target))
path_dso_cl = temp.relpath("dev_lib_cl.so")
f.export_library(path_dso_cl, fcompile=ndk.create_shared)

Expand All @@ -101,29 +79,6 @@ def test_rpc_module():
print("%g secs/op\n" % cost)
np.testing.assert_equal(b.numpy(), a.numpy() + 1)

# Compile the Graph for Vulkan target
if test_vulkan:
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))
# Build the dynamic lib.
# To skip Vulkan and run on the CPU only, pass `target` itself as the build target.
f = tvm.build(s, [A, B], tvm.target.Target("vulkan", host=target), name="myadd")
path_dso_vulkan = temp.relpath("dev_lib_vulkan.so")
f.export_library(path_dso_vulkan, fcompile=ndk.create_shared)

print("Run GPU(Vulkan Flavor) test ...")
dev = remote.vulkan(0)
remote.upload(path_dso_vulkan)
f1 = remote.load_module("dev_lib_vulkan.so")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
time_f = f1.time_evaluator(f1.entry_name, dev, number=10)
cost = time_f(a, b).mean
print("%g secs/op\n" % cost)
np.testing.assert_equal(b.numpy(), a.numpy() + 1)


# Run the RPC module test directly when this file is executed as a script.
if __name__ == "__main__":
test_rpc_module()
33 changes: 8 additions & 25 deletions apps/ios_rpc/tests/ios_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,19 @@ def test_rpc_module(host, port, key, mode):
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
temp = utils.tempdir()
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))
mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]).with_attr("global_symbol", "myadd"))
sch = tvm.tir.Schedule(mod)
(i,) = sch.get_loops(block=sch.get_block("B"))
i0, i1 = sch.split(i, [None, 32])
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")

# Build the dynamic lib.
# To skip Metal and run on the CPU only, pass `target` itself as the build target.
f = tvm.build(s, [A, B], tvm.target.Target("metal", host=target), name="myadd")
f = tvm.build(sch.mod, target=tvm.target.Target("metal", host=target))
path_dso1 = temp.relpath("dev_lib.dylib")
f.export_library(path_dso1, fcompile=xcode.create_dylib, arch=arch, sdk=sdk)

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
s[B].pragma(xo, "parallel_launch_point")
s[B].pragma(xi, "parallel_barrier_when_finish")
f = tvm.build(s, [A, B], target, name="myadd_cpu")
path_dso2 = temp.relpath("cpu_lib.dylib")
f.export_library(path_dso2, fcompile=xcode.create_dylib, arch=arch, sdk=sdk)

# connect to the proxy
if mode == "tracker":
remote = MODES[mode](host, port).request(key)
Expand All @@ -84,17 +78,6 @@ def test_rpc_module(host, port, key, mode):
cost = time_f(a, b).mean
print("Metal: %g secs/op" % cost)
np.testing.assert_equal(b.numpy(), a.numpy() + 1)
# CPU
dev = remote.cpu(0)
remote.upload(path_dso2)
f2 = remote.load_module("cpu_lib.dylib")
a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
time_f = f2.time_evaluator(f2.entry_name, dev, number=10)
cost = time_f(a, b).mean
print("CPU: %g secs/op" % cost)
np.testing.assert_equal(b.numpy(), a.numpy() + 1)


if __name__ == "__main__":
Expand Down
Loading
Loading