Skip to content

Commit

Permalink
Add support for AOT in external code generation tests (apache#8591)
Browse files Browse the repository at this point in the history
This adds support for the external code generation tests to use AOT. As
part of this the existing logic in check_result was split out into
multiple functions, this allows selectively disabling those that aren't
supported such as JSON outputs not being supported in AOT. I've replaced
existing checks to skip tests with @pytest.mark.skipif macros as they've
been moved out of the `check_result` function.
  • Loading branch information
Mousius authored and ylc committed Jan 13, 2022
1 parent 81bca08 commit ec2a091
Showing 1 changed file with 80 additions and 60 deletions.
140 changes: 80 additions & 60 deletions tests/python/relay/test_external_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for graph partitioning."""

import os
import sys
from collections import OrderedDict
import numpy as np
import pytest

import tvm
from tvm import te
import tvm.relay.testing
import tvm.relay.transform

from tvm import relay
from tvm import runtime
from tvm.relay import transform
from tvm import relay, runtime
from tvm.contrib import utils
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.op.annotation import compiler_begin, compiler_end
from aot.aot_test_utils import compile_and_run


def update_lib(lib):
Expand All @@ -48,37 +46,39 @@ def update_lib(lib):
return lib


def check_vm_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()):
    """Compile ``mod`` with the Relay VM, run it and compare against ``result``.

    Parameters
    ----------
    mod : tvm.IRModule
        The partitioned module under test.
    map_inputs : dict
        Mapping of input names to numpy arrays fed to the VM via ``vm.run``.
    out_shape : tuple
        Expected output shape. Unused here; accepted for signature parity with
        the other ``check_*`` helpers so tests can parametrize over them.
    result : numpy.ndarray
        Expected output values.
    tol : float
        Relative and absolute tolerance for the output comparison.
    target : str
        TVM build target.
    device : tvm.runtime.Device
        Device to execute on.
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        exe = relay.vm.compile(mod, target=target)
    code, lib = exe.save()
    # Link the external C source emitted by the codegen into the runtime lib.
    lib = update_lib(lib)
    exe = runtime.vm.Executable.load_exec(code, lib)
    vm = runtime.vm.VirtualMachine(exe, device)
    out = vm.run(**map_inputs)
    tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol)


def check_graph_executor_result(
    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()
):
    """Build ``mod`` for the graph executor, run it and compare against ``result``.

    Parameters
    ----------
    mod : tvm.IRModule
        The partitioned module under test.
    map_inputs : dict
        Mapping of input names to numpy arrays set via ``set_input``.
    out_shape : tuple
        Shape of the output buffer to fetch the result into.
    result : numpy.ndarray
        Expected output values.
    tol : float
        Relative and absolute tolerance for the output comparison.
    target : str
        TVM build target.
    device : tvm.runtime.Device
        Device to execute on.
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        # `graph_json` instead of `json` to avoid shadowing the stdlib module name.
        graph_json, lib, _ = relay.build(mod, target=target)
    # Link the external C source emitted by the codegen into the runtime lib.
    lib = update_lib(lib)
    rt_mod = tvm.contrib.graph_executor.create(graph_json, lib, device)

    for name, data in map_inputs.items():
        rt_mod.set_input(name, data)
    rt_mod.run()
    out = tvm.nd.empty(out_shape, device=device)
    out = rt_mod.get_output(0, out)

    tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol)

def check_aot_executor_result(
    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()
):
    """Compile ``mod`` ahead-of-time and run it, comparing against ``result``.

    ``out_shape``, ``tol``, ``target`` and ``device`` are accepted only for
    signature parity with the other ``check_*`` helpers so tests can
    parametrize over all of them; the AOT test infrastructure performs the
    output comparison itself.

    NOTE(review): the empty-string argument is presumably extra compiler
    options for `compile_and_run` — confirm against aot_test_utils.
    """
    use_calculated_workspaces = True
    compile_and_run(mod, list(map_inputs.values()), [result], "", use_calculated_workspaces)


def set_external_func_attr(func, compiler, ext_symbol):
Expand All @@ -88,7 +88,11 @@ def set_external_func_attr(func, compiler, ext_symbol):
return func


def test_multi_node_subgraph():
@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
@pytest.mark.parametrize(
"check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result]
)
def test_multi_node_subgraph(check_result):
x = relay.var("x", shape=(10, 10))
w0 = relay.var("w0", shape=(10, 10))
w1 = relay.var("w1", shape=(10, 10))
Expand Down Expand Up @@ -138,8 +142,7 @@ def test_multi_node_subgraph():
for _ in range(8):
w_data.append(np.random.rand(10, 10).astype("float32"))

map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
map_inputs["x"] = x_data
map_inputs = OrderedDict([("x", x_data)] + [("w{}".format(i), w_data[i]) for i in range(8)])
check_result(
mod,
map_inputs,
Expand All @@ -155,7 +158,11 @@ def test_multi_node_subgraph():
)


def test_extern_gcc_single_op():
@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
@pytest.mark.parametrize(
"check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result]
)
def test_extern_gcc_single_op(check_result):
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))

Expand All @@ -172,7 +179,11 @@ def test_extern_gcc_single_op():
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)


def test_extern_gcc_single_op_int():
@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
@pytest.mark.parametrize(
"check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result]
)
def test_extern_gcc_single_op_int(check_result):
x = relay.var("x", shape=(8, 8), dtype="int32")
y = relay.var("y", shape=(8, 8), dtype="int32")

Expand All @@ -189,7 +200,11 @@ def test_extern_gcc_single_op_int():
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)


def test_extern_gcc():
@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
@pytest.mark.parametrize(
"check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result]
)
def test_extern_gcc(check_result):
x = relay.var("x", shape=(2, 2))
y = relay.var("y", shape=(2, 2))

Expand Down Expand Up @@ -221,9 +236,17 @@ def test_extern_gcc():
x_data = np.random.rand(2, 2).astype("float32")
y_data = np.random.rand(2, 2).astype("float32")

check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
inputs = OrderedDict(
[
("y", y_data),
("x", x_data),
]
)

check_result(mod, inputs, (2, 2), (y_data * y_data) - (x_data + x_data))


@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
def test_extern_gcc_consts():
@tvm._ffi.register_func("relay.ext.ccompiler.constant_updater")
def constant_updater(expr, symbol):
Expand Down Expand Up @@ -257,11 +280,13 @@ def constant_updater(expr, symbol):
tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater")


def test_extern_dnnl():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return

@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
@pytest.mark.skipif(
not tvm.get_global_func("relay.ext.dnnl", True),
reason="skip because DNNL codegen is not available",
)
@pytest.mark.parametrize("check_result", [check_vm_result, check_graph_executor_result])
def test_extern_dnnl(check_result):
dtype = "float32"
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
Expand Down Expand Up @@ -297,11 +322,13 @@ def test_extern_dnnl():
)


def test_extern_dnnl_const():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return

@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
@pytest.mark.skipif(
not tvm.get_global_func("relay.ext.dnnl", True),
reason="skip because DNNL codegen is not available",
)
@pytest.mark.parametrize("check_result", [check_vm_result, check_graph_executor_result])
def test_extern_dnnl_const(check_result):
dtype = "float32"
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
Expand Down Expand Up @@ -349,7 +376,7 @@ def test_load_params_with_constants_in_ext_codegen():
zce = compiler_end(z, "ccompiler")
mod["main"] = relay.Function([x, y], zce)
mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.PartitionGraph()(mod)
mod = relay.transform.PartitionGraph()(mod)

graph_module = relay.build(mod, target="llvm", params=params)
# Params will be stored in metadata module.
Expand All @@ -360,11 +387,4 @@ def test_load_params_with_constants_in_ext_codegen():


if __name__ == "__main__":
    # Let pytest discover and run every test in this file (required now that
    # tests are parametrized/skipped via pytest marks), forwarding any extra
    # command-line arguments (e.g. -k filters) and propagating the exit code.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit ec2a091

Please sign in to comment.